#
tokens: 47058/50000 5/353 files (page 18/25)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 18 of 25. Use http://codebase.md/beehiveinnovations/gemini-mcp-server?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .claude
│   ├── commands
│   │   └── fix-github-issue.md
│   └── settings.json
├── .coveragerc
├── .dockerignore
├── .env.example
├── .gitattributes
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   ├── documentation.yml
│   │   ├── feature_request.yml
│   │   └── tool_addition.yml
│   ├── pull_request_template.md
│   └── workflows
│       ├── docker-pr.yml
│       ├── docker-release.yml
│       ├── semantic-pr.yml
│       ├── semantic-release.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── AGENTS.md
├── CHANGELOG.md
├── claude_config_example.json
├── CLAUDE.md
├── clink
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── claude.py
│   │   ├── codex.py
│   │   └── gemini.py
│   ├── constants.py
│   ├── models.py
│   ├── parsers
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── claude.py
│   │   ├── codex.py
│   │   └── gemini.py
│   └── registry.py
├── code_quality_checks.ps1
├── code_quality_checks.sh
├── communication_simulator_test.py
├── conf
│   ├── __init__.py
│   ├── azure_models.json
│   ├── cli_clients
│   │   ├── claude.json
│   │   ├── codex.json
│   │   └── gemini.json
│   ├── custom_models.json
│   ├── dial_models.json
│   ├── gemini_models.json
│   ├── openai_models.json
│   ├── openrouter_models.json
│   └── xai_models.json
├── config.py
├── docker
│   ├── README.md
│   └── scripts
│       ├── build.ps1
│       ├── build.sh
│       ├── deploy.ps1
│       ├── deploy.sh
│       └── healthcheck.py
├── docker-compose.yml
├── Dockerfile
├── docs
│   ├── adding_providers.md
│   ├── adding_tools.md
│   ├── advanced-usage.md
│   ├── ai_banter.md
│   ├── ai-collaboration.md
│   ├── azure_openai.md
│   ├── configuration.md
│   ├── context-revival.md
│   ├── contributions.md
│   ├── custom_models.md
│   ├── docker-deployment.md
│   ├── gemini-setup.md
│   ├── getting-started.md
│   ├── index.md
│   ├── locale-configuration.md
│   ├── logging.md
│   ├── model_ranking.md
│   ├── testing.md
│   ├── tools
│   │   ├── analyze.md
│   │   ├── apilookup.md
│   │   ├── challenge.md
│   │   ├── chat.md
│   │   ├── clink.md
│   │   ├── codereview.md
│   │   ├── consensus.md
│   │   ├── debug.md
│   │   ├── docgen.md
│   │   ├── listmodels.md
│   │   ├── planner.md
│   │   ├── precommit.md
│   │   ├── refactor.md
│   │   ├── secaudit.md
│   │   ├── testgen.md
│   │   ├── thinkdeep.md
│   │   ├── tracer.md
│   │   └── version.md
│   ├── troubleshooting.md
│   ├── vcr-testing.md
│   └── wsl-setup.md
├── examples
│   ├── claude_config_macos.json
│   └── claude_config_wsl.json
├── LICENSE
├── providers
│   ├── __init__.py
│   ├── azure_openai.py
│   ├── base.py
│   ├── custom.py
│   ├── dial.py
│   ├── gemini.py
│   ├── openai_compatible.py
│   ├── openai.py
│   ├── openrouter.py
│   ├── registries
│   │   ├── __init__.py
│   │   ├── azure.py
│   │   ├── base.py
│   │   ├── custom.py
│   │   ├── dial.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   ├── openrouter.py
│   │   └── xai.py
│   ├── registry_provider_mixin.py
│   ├── registry.py
│   ├── shared
│   │   ├── __init__.py
│   │   ├── model_capabilities.py
│   │   ├── model_response.py
│   │   ├── provider_type.py
│   │   └── temperature.py
│   └── xai.py
├── pyproject.toml
├── pytest.ini
├── README.md
├── requirements-dev.txt
├── requirements.txt
├── run_integration_tests.ps1
├── run_integration_tests.sh
├── run-server.ps1
├── run-server.sh
├── scripts
│   └── sync_version.py
├── server.py
├── simulator_tests
│   ├── __init__.py
│   ├── base_test.py
│   ├── conversation_base_test.py
│   ├── log_utils.py
│   ├── test_analyze_validation.py
│   ├── test_basic_conversation.py
│   ├── test_chat_simple_validation.py
│   ├── test_codereview_validation.py
│   ├── test_consensus_conversation.py
│   ├── test_consensus_three_models.py
│   ├── test_consensus_workflow_accurate.py
│   ├── test_content_validation.py
│   ├── test_conversation_chain_validation.py
│   ├── test_cross_tool_comprehensive.py
│   ├── test_cross_tool_continuation.py
│   ├── test_debug_certain_confidence.py
│   ├── test_debug_validation.py
│   ├── test_line_number_validation.py
│   ├── test_logs_validation.py
│   ├── test_model_thinking_config.py
│   ├── test_o3_model_selection.py
│   ├── test_o3_pro_expensive.py
│   ├── test_ollama_custom_url.py
│   ├── test_openrouter_fallback.py
│   ├── test_openrouter_models.py
│   ├── test_per_tool_deduplication.py
│   ├── test_planner_continuation_history.py
│   ├── test_planner_validation_old.py
│   ├── test_planner_validation.py
│   ├── test_precommitworkflow_validation.py
│   ├── test_prompt_size_limit_bug.py
│   ├── test_refactor_validation.py
│   ├── test_secaudit_validation.py
│   ├── test_testgen_validation.py
│   ├── test_thinkdeep_validation.py
│   ├── test_token_allocation_validation.py
│   ├── test_vision_capability.py
│   └── test_xai_models.py
├── systemprompts
│   ├── __init__.py
│   ├── analyze_prompt.py
│   ├── chat_prompt.py
│   ├── clink
│   │   ├── codex_codereviewer.txt
│   │   ├── default_codereviewer.txt
│   │   ├── default_planner.txt
│   │   └── default.txt
│   ├── codereview_prompt.py
│   ├── consensus_prompt.py
│   ├── debug_prompt.py
│   ├── docgen_prompt.py
│   ├── generate_code_prompt.py
│   ├── planner_prompt.py
│   ├── precommit_prompt.py
│   ├── refactor_prompt.py
│   ├── secaudit_prompt.py
│   ├── testgen_prompt.py
│   ├── thinkdeep_prompt.py
│   └── tracer_prompt.py
├── tests
│   ├── __init__.py
│   ├── CASSETTE_MAINTENANCE.md
│   ├── conftest.py
│   ├── gemini_cassettes
│   │   ├── chat_codegen
│   │   │   └── gemini25_pro_calculator
│   │   │       └── mldev.json
│   │   ├── chat_cross
│   │   │   └── step1_gemini25_flash_number
│   │   │       └── mldev.json
│   │   └── consensus
│   │       └── step2_gemini25_flash_against
│   │           └── mldev.json
│   ├── http_transport_recorder.py
│   ├── mock_helpers.py
│   ├── openai_cassettes
│   │   ├── chat_cross_step2_gpt5_reminder.json
│   │   ├── chat_gpt5_continuation.json
│   │   ├── chat_gpt5_moon_distance.json
│   │   ├── consensus_step1_gpt5_for.json
│   │   └── o3_pro_basic_math.json
│   ├── pii_sanitizer.py
│   ├── sanitize_cassettes.py
│   ├── test_alias_target_restrictions.py
│   ├── test_auto_mode_comprehensive.py
│   ├── test_auto_mode_custom_provider_only.py
│   ├── test_auto_mode_model_listing.py
│   ├── test_auto_mode_provider_selection.py
│   ├── test_auto_mode.py
│   ├── test_auto_model_planner_fix.py
│   ├── test_azure_openai_provider.py
│   ├── test_buggy_behavior_prevention.py
│   ├── test_cassette_semantic_matching.py
│   ├── test_challenge.py
│   ├── test_chat_codegen_integration.py
│   ├── test_chat_cross_model_continuation.py
│   ├── test_chat_openai_integration.py
│   ├── test_chat_simple.py
│   ├── test_clink_claude_agent.py
│   ├── test_clink_claude_parser.py
│   ├── test_clink_codex_agent.py
│   ├── test_clink_gemini_agent.py
│   ├── test_clink_gemini_parser.py
│   ├── test_clink_integration.py
│   ├── test_clink_parsers.py
│   ├── test_clink_tool.py
│   ├── test_collaboration.py
│   ├── test_config.py
│   ├── test_consensus_integration.py
│   ├── test_consensus_schema.py
│   ├── test_consensus.py
│   ├── test_conversation_continuation_integration.py
│   ├── test_conversation_field_mapping.py
│   ├── test_conversation_file_features.py
│   ├── test_conversation_memory.py
│   ├── test_conversation_missing_files.py
│   ├── test_custom_openai_temperature_fix.py
│   ├── test_custom_provider.py
│   ├── test_debug.py
│   ├── test_deploy_scripts.py
│   ├── test_dial_provider.py
│   ├── test_directory_expansion_tracking.py
│   ├── test_disabled_tools.py
│   ├── test_docker_claude_desktop_integration.py
│   ├── test_docker_config_complete.py
│   ├── test_docker_healthcheck.py
│   ├── test_docker_implementation.py
│   ├── test_docker_mcp_validation.py
│   ├── test_docker_security.py
│   ├── test_docker_volume_persistence.py
│   ├── test_file_protection.py
│   ├── test_gemini_token_usage.py
│   ├── test_image_support_integration.py
│   ├── test_image_validation.py
│   ├── test_integration_utf8.py
│   ├── test_intelligent_fallback.py
│   ├── test_issue_245_simple.py
│   ├── test_large_prompt_handling.py
│   ├── test_line_numbers_integration.py
│   ├── test_listmodels_restrictions.py
│   ├── test_listmodels.py
│   ├── test_mcp_error_handling.py
│   ├── test_model_enumeration.py
│   ├── test_model_metadata_continuation.py
│   ├── test_model_resolution_bug.py
│   ├── test_model_restrictions.py
│   ├── test_o3_pro_output_text_fix.py
│   ├── test_o3_temperature_fix_simple.py
│   ├── test_openai_compatible_token_usage.py
│   ├── test_openai_provider.py
│   ├── test_openrouter_provider.py
│   ├── test_openrouter_registry.py
│   ├── test_parse_model_option.py
│   ├── test_per_tool_model_defaults.py
│   ├── test_pii_sanitizer.py
│   ├── test_pip_detection_fix.py
│   ├── test_planner.py
│   ├── test_precommit_workflow.py
│   ├── test_prompt_regression.py
│   ├── test_prompt_size_limit_bug_fix.py
│   ├── test_provider_retry_logic.py
│   ├── test_provider_routing_bugs.py
│   ├── test_provider_utf8.py
│   ├── test_providers.py
│   ├── test_rate_limit_patterns.py
│   ├── test_refactor.py
│   ├── test_secaudit.py
│   ├── test_server.py
│   ├── test_supported_models_aliases.py
│   ├── test_thinking_modes.py
│   ├── test_tools.py
│   ├── test_tracer.py
│   ├── test_utf8_localization.py
│   ├── test_utils.py
│   ├── test_uvx_resource_packaging.py
│   ├── test_uvx_support.py
│   ├── test_workflow_file_embedding.py
│   ├── test_workflow_metadata.py
│   ├── test_workflow_prompt_size_validation_simple.py
│   ├── test_workflow_utf8.py
│   ├── test_xai_provider.py
│   ├── transport_helpers.py
│   └── triangle.png
├── tools
│   ├── __init__.py
│   ├── analyze.py
│   ├── apilookup.py
│   ├── challenge.py
│   ├── chat.py
│   ├── clink.py
│   ├── codereview.py
│   ├── consensus.py
│   ├── debug.py
│   ├── docgen.py
│   ├── listmodels.py
│   ├── models.py
│   ├── planner.py
│   ├── precommit.py
│   ├── refactor.py
│   ├── secaudit.py
│   ├── shared
│   │   ├── __init__.py
│   │   ├── base_models.py
│   │   ├── base_tool.py
│   │   ├── exceptions.py
│   │   └── schema_builders.py
│   ├── simple
│   │   ├── __init__.py
│   │   └── base.py
│   ├── testgen.py
│   ├── thinkdeep.py
│   ├── tracer.py
│   ├── version.py
│   └── workflow
│       ├── __init__.py
│       ├── base.py
│       ├── schema_builders.py
│       └── workflow_mixin.py
├── utils
│   ├── __init__.py
│   ├── client_info.py
│   ├── conversation_memory.py
│   ├── env.py
│   ├── file_types.py
│   ├── file_utils.py
│   ├── image_utils.py
│   ├── model_context.py
│   ├── model_restrictions.py
│   ├── security_config.py
│   ├── storage_backend.py
│   └── token_utils.py
└── zen-mcp-server
```

# Files

--------------------------------------------------------------------------------
/tools/secaudit.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | SECAUDIT Workflow tool - Comprehensive security audit with systematic investigation
  3 | 
  4 | This tool provides a structured workflow for comprehensive security assessment and analysis.
  5 | It guides the CLI agent through systematic investigation steps with forced pauses between each step
  6 | to ensure thorough security examination, vulnerability identification, and compliance assessment
  7 | before proceeding. The tool supports complex security scenarios including OWASP Top 10 coverage,
  8 | compliance framework mapping, and technology-specific security patterns.
  9 | 
 10 | Key features:
 11 | - Step-by-step security audit workflow with progress tracking
 12 | - Context-aware file embedding (references during investigation, full content for analysis)
 13 | - Automatic security issue tracking with severity classification
 14 | - Expert analysis integration with external models
 15 | - Support for focused security audits (OWASP, compliance, technology-specific)
 16 | - Confidence-based workflow optimization
 17 | - Risk-based prioritization and remediation planning
 18 | """
 19 | 
 20 | import logging
 21 | from typing import TYPE_CHECKING, Any, Literal, Optional
 22 | 
 23 | from pydantic import Field, model_validator
 24 | 
 25 | if TYPE_CHECKING:
 26 |     from tools.models import ToolModelCategory
 27 | 
 28 | from config import TEMPERATURE_ANALYTICAL
 29 | from systemprompts import SECAUDIT_PROMPT
 30 | from tools.shared.base_models import WorkflowRequest
 31 | 
 32 | from .workflow.base import WorkflowTool
 33 | 
 34 | logger = logging.getLogger(__name__)
 35 | 
 36 | # Tool-specific field descriptions for security audit workflow
 37 | SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS = {
 38 |     "step": (
 39 |         "Step 1: outline the audit strategy (OWASP Top 10, auth, validation, etc.). Later steps: report findings. MANDATORY: use `relevant_files` for code references and avoid large snippets."
 40 |     ),
 41 |     "step_number": "Current security-audit step number (starts at 1).",
 42 |     "total_steps": "Expected number of audit steps; adjust as new risks surface.",
 43 |     "next_step_required": "True while additional threat analysis remains; set False once you are ready to hand off for validation.",
 44 |     "findings": "Summarize vulnerabilities, auth issues, validation gaps, compliance notes, and positives; update prior findings as needed.",
 45 |     "files_checked": "Absolute paths for every file inspected, including rejected candidates.",
 46 |     "relevant_files": "Absolute paths for security-relevant files (auth modules, configs, sensitive code).",
 47 |     "relevant_context": "Security-critical classes/methods (e.g. 'AuthService.login', 'encryption_helper').",
 48 |     "issues_found": "Security issues with severity (critical/high/medium/low) and descriptions (vulns, auth flaws, injection, crypto, config).",
 49 |     "confidence": "exploring/low/medium/high/very_high/almost_certain/certain. 'certain' blocks external validation—use only when fully complete.",
 50 |     "images": "Optional absolute paths to diagrams or threat models that inform the audit.",
 51 |     "security_scope": "Security context (web, mobile, API, cloud, etc.) including stack, user types, data sensitivity, and threat landscape.",
 52 |     "threat_level": "Assess the threat level: low (internal/low-risk), medium (customer-facing/business data), high (regulated or sensitive), critical (financial/healthcare/PII).",
 53 |     "compliance_requirements": "Applicable compliance frameworks or standards (SOC2, PCI DSS, HIPAA, GDPR, ISO 27001, NIST, etc.).",
 54 |     "audit_focus": "Primary focus area: owasp, compliance, infrastructure, dependencies, or comprehensive.",
 55 |     "severity_filter": "Minimum severity to include when reporting security issues.",
 56 | }
 57 | 
 58 | 
 59 | class SecauditRequest(WorkflowRequest):
 60 |     """Request model for security audit workflow investigation steps"""
 61 | 
 62 |     # Required fields for each investigation step
 63 |     step: str = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["step"])
 64 |     step_number: int = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
 65 |     total_steps: int = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
 66 |     next_step_required: bool = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])
 67 | 
 68 |     # Investigation tracking fields
 69 |     findings: str = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["findings"])
 70 |     files_checked: list[str] = Field(
 71 |         default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["files_checked"]
 72 |     )
 73 |     relevant_files: list[str] = Field(
 74 |         default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"]
 75 |     )
 76 |     relevant_context: list[str] = Field(
 77 |         default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"]
 78 |     )
 79 |     issues_found: list[dict] = Field(
 80 |         default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["issues_found"]
 81 |     )
 82 |     confidence: Optional[str] = Field("low", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["confidence"])
 83 | 
 84 |     # Optional images for visual context
 85 |     images: Optional[list[str]] = Field(default=None, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["images"])
 86 | 
 87 |     # Security audit-specific fields
 88 |     security_scope: Optional[str] = Field(None, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["security_scope"])
 89 |     threat_level: Optional[Literal["low", "medium", "high", "critical"]] = Field(
 90 |         "medium", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["threat_level"]
 91 |     )
 92 |     compliance_requirements: Optional[list[str]] = Field(
 93 |         default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["compliance_requirements"]
 94 |     )
 95 |     audit_focus: Optional[Literal["owasp", "compliance", "infrastructure", "dependencies", "comprehensive"]] = Field(
 96 |         "comprehensive", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["audit_focus"]
 97 |     )
 98 |     severity_filter: Optional[Literal["critical", "high", "medium", "low", "all"]] = Field(
 99 |         "all", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["severity_filter"]
100 |     )
101 | 
102 |     @model_validator(mode="after")
103 |     def validate_security_audit_request(self):
104 |         """Validate security audit request parameters"""
105 |         # Ensure security scope is provided for comprehensive audits
106 |         if self.step_number == 1 and not self.security_scope:
107 |             logger.warning("Security scope not provided for security audit - defaulting to general application")
108 | 
109 |         # Validate compliance requirements format
110 |         if self.compliance_requirements:
111 |             valid_compliance = {"SOC2", "PCI DSS", "HIPAA", "GDPR", "ISO 27001", "NIST", "FedRAMP", "FISMA"}
112 |             for req in self.compliance_requirements:
113 |                 if req not in valid_compliance:
114 |                     logger.warning(f"Unknown compliance requirement: {req}")
115 | 
116 |         return self
117 | 
118 | 
119 | class SecauditTool(WorkflowTool):
120 |     """
121 |     Comprehensive security audit workflow tool.
122 | 
123 |     Provides systematic security assessment through multi-step investigation
124 |     covering OWASP Top 10, compliance requirements, and technology-specific
125 |     security patterns. Follows established WorkflowTool patterns while adding
126 |     security-specific capabilities.
127 |     """
128 | 
129 |     def __init__(self):
130 |         super().__init__()
131 |         self.initial_request = None
132 |         self.security_config = {}
133 | 
134 |     def get_name(self) -> str:
135 |         """Return the unique name of the tool."""
136 |         return "secaudit"
137 | 
138 |     def get_description(self) -> str:
139 |         """Return a description of the tool."""
140 |         return (
141 |             "Performs comprehensive security audit with systematic vulnerability assessment. "
142 |             "Use for OWASP Top 10 analysis, compliance evaluation, threat modeling, and security architecture review. "
143 |             "Guides through structured security investigation with expert validation."
144 |         )
145 | 
146 |     def get_system_prompt(self) -> str:
147 |         """Return the system prompt for expert security analysis."""
148 |         return SECAUDIT_PROMPT
149 | 
150 |     def get_default_temperature(self) -> float:
151 |         """Return the temperature for security audit analysis"""
152 |         return TEMPERATURE_ANALYTICAL
153 | 
154 |     def get_model_category(self) -> "ToolModelCategory":
155 |         """Return the model category for security audit"""
156 |         from tools.models import ToolModelCategory
157 | 
158 |         return ToolModelCategory.EXTENDED_REASONING
159 | 
160 |     def get_workflow_request_model(self) -> type:
161 |         """Return the workflow request model class"""
162 |         return SecauditRequest
163 | 
164 |     def get_tool_fields(self) -> dict[str, dict[str, Any]]:
165 |         """
166 |         Get security audit tool field definitions.
167 | 
168 |         Returns comprehensive field definitions including security-specific
169 |         parameters while maintaining compatibility with existing workflow patterns.
170 |         """
171 |         return SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS
172 | 
173 |     def get_required_actions(
174 |         self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
175 |     ) -> list[str]:
176 |         """
177 |         Provide step-specific guidance for systematic security analysis.
178 | 
179 |         Each step focuses on specific security domains to ensure comprehensive
180 |         coverage without missing critical security aspects.
181 |         """
182 |         if step_number == 1:
183 |             return [
184 |                 "Identify application type, technology stack, and security scope",
185 |                 "Map attack surface, entry points, and data flows",
186 |                 "Determine relevant security standards and compliance requirements",
187 |                 "Establish threat landscape and risk context for the application",
188 |             ]
189 |         elif step_number == 2:
190 |             return [
191 |                 "Analyze authentication mechanisms and session management",
192 |                 "Check authorization controls, access patterns, and privilege escalation risks",
193 |                 "Assess multi-factor authentication, password policies, and account security",
194 |                 "Review identity and access management implementations",
195 |             ]
196 |         elif step_number == 3:
197 |             return [
198 |                 "Examine input validation and sanitization mechanisms across all entry points",
199 |                 "Check for injection vulnerabilities (SQL, XSS, Command, LDAP, NoSQL)",
200 |                 "Review data encryption, sensitive data handling, and cryptographic implementations",
201 |                 "Analyze API input validation, rate limiting, and request/response security",
202 |             ]
203 |         elif step_number == 4:
204 |             return [
205 |                 "Conduct OWASP Top 10 (2021) systematic review across all categories",
206 |                 "Check each OWASP category methodically with specific findings and evidence",
207 |                 "Cross-reference findings with application context and technology stack",
208 |                 "Prioritize vulnerabilities based on exploitability and business impact",
209 |             ]
210 |         elif step_number == 5:
211 |             return [
212 |                 "Analyze third-party dependencies for known vulnerabilities and outdated versions",
213 |                 "Review configuration security, default settings, and hardening measures",
214 |                 "Check for hardcoded secrets, credentials, and sensitive information exposure",
215 |                 "Assess logging, monitoring, incident response, and security observability",
216 |             ]
217 |         elif step_number == 6:
218 |             return [
219 |                 "Evaluate compliance requirements and identify gaps in controls",
220 |                 "Assess business impact and risk levels of all identified findings",
221 |                 "Create prioritized remediation roadmap with timeline and effort estimates",
222 |                 "Document comprehensive security posture and recommendations",
223 |             ]
224 |         else:
225 |             return [
226 |                 "Continue systematic security investigation based on emerging findings",
227 |                 "Deep-dive into specific security concerns identified in previous steps",
228 |                 "Validate security hypotheses and confirm vulnerability assessments",
229 |             ]
230 | 
231 |     def should_call_expert_analysis(self, consolidated_findings, request=None) -> bool:
232 |         """
233 |         Determine when to call expert security analysis.
234 | 
235 |         Expert analysis is triggered when the security audit has meaningful findings
236 |         unless the user requested to skip assistant model.
237 |         """
238 |         # Check if user requested to skip assistant model
239 |         if request and not self.get_request_use_assistant_model(request):
240 |             return False
241 | 
242 |         # Check if we have meaningful investigation data
243 |         return (
244 |             len(consolidated_findings.relevant_files) > 0
245 |             or len(consolidated_findings.findings) >= 2
246 |             or len(consolidated_findings.issues_found) > 0
247 |         )
248 | 
249 |     def prepare_expert_analysis_context(self, consolidated_findings) -> str:
250 |         """
251 |         Prepare comprehensive context for expert security model analysis.
252 | 
253 |         Provides security-specific context including scope, threat level,
254 |         compliance requirements, and systematic findings for expert validation.
255 |         """
256 |         context_parts = [
257 |             f"=== SECURITY AUDIT REQUEST ===\n{self.initial_request or 'Security audit workflow initiated'}\n=== END REQUEST ==="
258 |         ]
259 | 
260 |         # Add investigation summary
261 |         investigation_summary = self._build_security_audit_summary(consolidated_findings)
262 |         context_parts.append(
263 |             f"\n=== AGENT'S SECURITY INVESTIGATION ===\n{investigation_summary}\n=== END INVESTIGATION ==="
264 |         )
265 | 
266 |         # Add security configuration context if available
267 |         if self.security_config:
268 |             config_text = "\n".join(f"- {key}: {value}" for key, value in self.security_config.items() if value)
269 |             context_parts.append(f"\n=== SECURITY CONFIGURATION ===\n{config_text}\n=== END CONFIGURATION ===")
270 | 
271 |         # Add relevant files if available
272 |         if consolidated_findings.relevant_files:
273 |             files_text = "\n".join(f"- {file}" for file in consolidated_findings.relevant_files)
274 |             context_parts.append(f"\n=== RELEVANT FILES ===\n{files_text}\n=== END FILES ===")
275 | 
276 |         # Add relevant security elements if available
277 |         if consolidated_findings.relevant_context:
278 |             methods_text = "\n".join(f"- {method}" for method in consolidated_findings.relevant_context)
279 |             context_parts.append(
280 |                 f"\n=== SECURITY-CRITICAL CODE ELEMENTS ===\n{methods_text}\n=== END CODE ELEMENTS ==="
281 |             )
282 | 
283 |         # Add security issues found if available
284 |         if consolidated_findings.issues_found:
285 |             issues_text = self._format_security_issues(consolidated_findings.issues_found)
286 |             context_parts.append(f"\n=== SECURITY ISSUES IDENTIFIED ===\n{issues_text}\n=== END ISSUES ===")
287 | 
288 |         # Add assessment evolution if available
289 |         if consolidated_findings.hypotheses:
290 |             assessments_text = "\n".join(
291 |                 f"Step {h['step']} ({h['confidence']} confidence): {h['hypothesis']}"
292 |                 for h in consolidated_findings.hypotheses
293 |             )
294 |             context_parts.append(f"\n=== ASSESSMENT EVOLUTION ===\n{assessments_text}\n=== END ASSESSMENTS ===")
295 | 
296 |         # Add images if available
297 |         if consolidated_findings.images:
298 |             images_text = "\n".join(f"- {img}" for img in consolidated_findings.images)
299 |             context_parts.append(
300 |                 f"\n=== VISUAL SECURITY INFORMATION ===\n{images_text}\n=== END VISUAL INFORMATION ==="
301 |             )
302 | 
303 |         return "\n".join(context_parts)
304 | 
305 |     def _format_security_issues(self, issues_found: list[dict]) -> str:
306 |         """
307 |         Format security issues for expert analysis.
308 | 
309 |         Organizes security findings by severity for clear expert review.
310 |         """
311 |         if not issues_found:
312 |             return "No security issues identified during systematic investigation."
313 | 
314 |         # Group issues by severity
315 |         severity_groups = {"critical": [], "high": [], "medium": [], "low": []}
316 | 
317 |         for issue in issues_found:
318 |             severity = issue.get("severity", "low").lower()
319 |             description = issue.get("description", "No description provided")
320 |             if severity in severity_groups:
321 |                 severity_groups[severity].append(description)
322 |             else:
323 |                 severity_groups["low"].append(f"[{severity.upper()}] {description}")
324 | 
325 |         formatted_issues = []
326 |         for severity in ["critical", "high", "medium", "low"]:
327 |             if severity_groups[severity]:
328 |                 formatted_issues.append(f"\n{severity.upper()} SEVERITY:")
329 |                 for issue in severity_groups[severity]:
330 |                     formatted_issues.append(f"  • {issue}")
331 | 
332 |         return "\n".join(formatted_issues) if formatted_issues else "No security issues identified."
333 | 
334 |     def _build_security_audit_summary(self, consolidated_findings) -> str:
335 |         """Prepare a comprehensive summary of the security audit investigation."""
336 |         summary_parts = [
337 |             "=== SYSTEMATIC SECURITY AUDIT INVESTIGATION SUMMARY ===",
338 |             f"Total steps: {len(consolidated_findings.findings)}",
339 |             f"Files examined: {len(consolidated_findings.files_checked)}",
340 |             f"Relevant files identified: {len(consolidated_findings.relevant_files)}",
341 |             f"Security-critical elements analyzed: {len(consolidated_findings.relevant_context)}",
342 |             f"Security issues identified: {len(consolidated_findings.issues_found)}",
343 |             "",
344 |             "=== INVESTIGATION PROGRESSION ===",
345 |         ]
346 | 
347 |         for finding in consolidated_findings.findings:
348 |             summary_parts.append(finding)
349 | 
350 |         return "\n".join(summary_parts)
351 | 
def get_input_schema(self) -> dict[str, Any]:
    """Build the JSON input schema for the secaudit workflow tool.

    Each workflow field and each security-specific option is declared with its
    type constraints and the description text taken from
    ``SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS``. ``WorkflowSchemaBuilder`` then
    combines these overrides with the model field schema.

    Returns:
        The schema dict produced by ``WorkflowSchemaBuilder.build_schema``.
    """
    from .workflow.schema_builders import WorkflowSchemaBuilder

    def described(field_name: str, **constraints: Any) -> dict[str, Any]:
        # "description" goes after the constraints, so each entry keeps the
        # same key order as a hand-written literal.
        return {**constraints, "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS[field_name]}

    overrides = {
        # Core workflow fields
        "step": described("step", type="string"),
        "step_number": described("step_number", type="integer", minimum=1),
        "total_steps": described("total_steps", type="integer", minimum=1),
        "next_step_required": described("next_step_required", type="boolean"),
        "findings": described("findings", type="string"),
        "files_checked": described("files_checked", type="array", items={"type": "string"}),
        "relevant_files": described("relevant_files", type="array", items={"type": "string"}),
        "confidence": described(
            "confidence",
            type="string",
            enum=["exploring", "low", "medium", "high", "very_high", "almost_certain", "certain"],
        ),
        "issues_found": described("issues_found", type="array", items={"type": "object"}),
        "images": described("images", type="array", items={"type": "string"}),
        # Security audit-specific fields (for step 1)
        "security_scope": described("security_scope", type="string"),
        "threat_level": described(
            "threat_level",
            type="string",
            enum=["low", "medium", "high", "critical"],
            default="medium",
        ),
        "compliance_requirements": described(
            "compliance_requirements", type="array", items={"type": "string"}
        ),
        "audit_focus": described(
            "audit_focus",
            type="string",
            enum=["owasp", "compliance", "infrastructure", "dependencies", "comprehensive"],
            default="comprehensive",
        ),
        "severity_filter": described(
            "severity_filter",
            type="string",
            enum=["critical", "high", "medium", "low", "all"],
            default="all",
        ),
    }

    return WorkflowSchemaBuilder.build_schema(
        tool_specific_fields=overrides,
        model_field_schema=self.get_model_field_schema(),
        auto_mode=self.is_effective_auto_mode(),
        tool_name=self.get_name(),
    )
442 | 
443 |     # Hook method overrides for security audit-specific behavior
444 | 
def prepare_step_data(self, request) -> dict:
    """Translate a secaudit request into the generic step-data dict.

    On step 1, the request's audit options are also captured on
    ``self.security_config``.

    Args:
        request: The current secaudit request.

    Returns:
        Step data with the core workflow fields, plus ``hypothesis`` (a copy of
        ``findings``) and ``images`` (defaulting to an empty list).
    """
    copied_fields = (
        "step",
        "step_number",
        "findings",
        "files_checked",
        "relevant_files",
        "relevant_context",
        "issues_found",
        "confidence",
    )
    step_data = {field: getattr(request, field) for field in copied_fields}
    # Security audits have no separate hypothesis; findings are mirrored for compatibility.
    step_data["hypothesis"] = request.findings
    step_data["images"] = request.images or []

    if request.step_number == 1:
        config_fields = (
            "security_scope",
            "threat_level",
            "compliance_requirements",
            "audit_focus",
            "severity_filter",
        )
        self.security_config = {field: getattr(request, field) for field in config_fields}

    return step_data
471 | 
def should_skip_expert_analysis(self, request, consolidated_findings) -> bool:
    """Decide whether the expert-model review can be bypassed.

    Skipping happens only on the final step (``next_step_required`` is falsy)
    and only when the agent reported ``"certain"`` confidence.
    """
    if request.next_step_required:
        return False
    return request.confidence == "certain"
475 | 
def store_initial_issue(self, step_description: str):
    """Remember the opening request text for later expert analysis.

    Args:
        step_description: Text of the first step, kept on ``self.initial_request``.
    """
    opening_request = step_description
    self.initial_request = opening_request
479 | 
def should_include_files_in_expert_prompt(self) -> bool:
    """Report that file contents belong in the expert-analysis prompt.

    Returns:
        Always True, so the expert sees the audited files.
    """
    include_files = True
    return include_files
483 | 
def should_embed_system_prompt(self) -> bool:
    """Report that the system prompt is embedded in the expert request.

    Returns:
        Always True.
    """
    embed_prompt = True
    return embed_prompt
487 | 
def get_expert_thinking_mode(self) -> str:
    """Thinking-mode setting requested for the expert security review.

    Returns:
        The string ``"high"``.
    """
    mode = "high"
    return mode
491 | 
def get_expert_analysis_instruction(self) -> str:
    """Instruction text sent to the expert model for the final security review.

    Returns:
        A single-paragraph request for a comprehensive, OWASP-formatted analysis.
    """
    sentences = (
        "Please provide comprehensive security analysis based on the investigation findings.",
        "Focus on identifying any remaining vulnerabilities, validating the completeness of the analysis,",
        "and providing final recommendations for security improvements, following the OWASP-based",
        "format specified in the system prompt.",
    )
    return " ".join(sentences)
500 | 
def get_completion_next_steps_message(self, expert_analysis_used: bool = False) -> str:
    """Return the closing instructions shown once the security audit is finished.

    Args:
        expert_analysis_used: Whether an expert model reviewed the audit. When
            True, any non-empty text from ``get_expert_analysis_guidance`` is
            appended after a blank line.

    Returns:
        The completion instructions, optionally followed by expert guidance.
    """
    summary_request = (
        "SECURITY AUDIT IS COMPLETE. You MUST now summarize and present ALL security findings organized by "
        "severity (Critical → High → Medium → Low), specific code locations with line numbers, and exact "
        "remediation steps for each vulnerability. Clearly prioritize the top 3 security issues that need "
        "immediate attention. Provide concrete, actionable guidance for each vulnerability—make it easy for "
        "developers to understand exactly what needs to be fixed and how to implement the security improvements."
    )

    if not expert_analysis_used:
        return summary_request

    # The guidance hook is consulted only when an expert actually contributed.
    guidance = self.get_expert_analysis_guidance()
    return f"{summary_request}\n\n{guidance}" if guidance else summary_request
520 | 
def get_expert_analysis_guidance(self) -> str:
    """Tell the agent how to treat expert findings in a security audit.

    Returns:
        A paragraph asking the agent to validate, not blindly accept, the expert
        review and to separate confirmed issues from expert-only insights.
    """
    sentences = (
        "IMPORTANT: Analysis from an assistant model has been provided above. You MUST critically evaluate and validate",
        "the expert security findings rather than accepting them blindly. Cross-reference the expert analysis with",
        "your own investigation findings, verify that suggested security improvements are appropriate for this",
        "application's context and threat model, and ensure recommendations align with the project's security requirements.",
        "Present a synthesis that combines your systematic security review with validated expert insights, clearly",
        "distinguishing between vulnerabilities you've independently confirmed and additional insights from expert analysis.",
    )
    return " ".join(sentences)
533 | 
def get_step_guidance_message(self, request) -> str:
    """Return the ``next_steps`` instruction text for the request's current step.

    Delegates to ``get_security_audit_step_guidance`` using the request's step
    number and confidence.
    """
    step, confidence = request.step_number, request.confidence
    return self.get_security_audit_step_guidance(step, confidence, request)["next_steps"]
540 | 
def get_security_audit_step_guidance(self, step_number: int, confidence: str, request) -> dict[str, Any]:
    """Build the ``next_steps`` instruction text for the current audit step.

    Args:
        step_number: Index of the step just reported (step 1 gets the initial
            investigation prompt).
        confidence: Confidence reported for this step. It selects the wording for
            later steps: "exploring"/"low", "medium"/"high", or anything else.
        request: The current request. Its ``findings`` and ``total_steps`` are
            forwarded to ``get_required_actions``.

    Returns:
        ``{"next_steps": <instruction text>}``.
    """
    # Computed for every step, including step 1 where the list is not used.
    actions = self.get_required_actions(step_number, confidence, request.findings, request.total_steps)
    tool = self.get_name()
    following = step_number + 1

    if step_number == 1:
        text = (
            f"MANDATORY: DO NOT call the {tool} tool again immediately. You MUST first examine "
            f"the code files thoroughly using appropriate tools. CRITICAL AWARENESS: You need to understand "
            f"the security landscape, identify potential vulnerabilities across OWASP Top 10 categories, "
            f"and look for authentication flaws, injection points, cryptographic issues, and authorization bypasses. "
            f"Use file reading tools, security analysis, and systematic examination to gather comprehensive information. "
            f"Only call {tool} again AFTER completing your security investigation. When you call "
            f"{tool} next time, use step_number: {following} and report specific "
            f"files examined, vulnerabilities found, and security assessments discovered."
        )
    elif confidence in ("exploring", "low"):
        checklist = "\n".join(f"{index}. {action}" for index, action in enumerate(actions, start=1))
        text = (
            f"STOP! Do NOT call {tool} again yet. Based on your findings, you've identified areas that need "
            f"deeper security analysis. MANDATORY ACTIONS before calling {tool} step {following}:\n"
            f"{checklist}"
            f"\n\nOnly call {tool} again with step_number: {following} AFTER "
            "completing these security audit tasks."
        )
    elif confidence in ("medium", "high"):
        checklist = "\n".join(f"{index}. {action}" for index, action in enumerate(actions, start=1))
        text = (
            f"WAIT! Your security audit needs final verification. DO NOT call {tool} immediately. REQUIRED ACTIONS:\n"
            f"{checklist}"
            "\n\nREMEMBER: Ensure you have identified all significant vulnerabilities across all severity levels and "
            "verified the completeness of your security review. Document findings with specific file references and "
            f"line numbers where applicable, then call {tool} with step_number: {following}."
        )
    else:
        # Only the first two required actions are quoted in this terse prompt.
        first_two = ", ".join(actions[:2])
        text = (
            f"PAUSE SECURITY AUDIT. Before calling {tool} step {following}, you MUST examine more code thoroughly. "
            f"Required: {first_two}. "
            f"Your next {tool} call (step_number: {following}) must include "
            f"NEW evidence from actual security analysis, not just theories. NO recursive {tool} calls "
            "without investigation work!"
        )

    return {"next_steps": text}
587 | 
def customize_workflow_response(self, response_data: dict, request) -> dict:
    """Rename generic workflow statuses and keys to their security-audit equivalents.

    On step 1, the request's ``step`` text is stored on ``self.initial_request``.
    If ``relevant_files`` were supplied, ``self.security_config`` is also
    replaced with those files plus the audit configuration fields.

    Args:
        response_data: Response dict produced by the generic workflow. It is
            modified in place.
        request: The current secaudit request.

    Returns:
        The same ``response_data`` object, with ``<tool>_...`` statuses and keys
        renamed. ``security_audit_status`` (when present) gains a per-severity
        vulnerability count and the request's confidence.
    """
    if request.step_number == 1:
        self.initial_request = request.step
        # Store security configuration for expert analysis
        if request.relevant_files:
            self.security_config = {
                "relevant_files": request.relevant_files,
                "security_scope": request.security_scope,
                "threat_level": request.threat_level,
                "compliance_requirements": request.compliance_requirements,
                "audit_focus": request.audit_focus,
                "severity_filter": request.severity_filter,
            }

    tool_name = self.get_name()

    # Convert generic status names to security audit-specific ones. A response
    # with no "status" key is left untouched instead of raising KeyError.
    status_mapping = {
        f"{tool_name}_in_progress": "security_audit_in_progress",
        f"pause_for_{tool_name}": "pause_for_security_audit",
        f"{tool_name}_required": "security_audit_required",
        f"{tool_name}_complete": "security_audit_complete",
    }
    status = response_data.get("status")
    if status in status_mapping:
        response_data["status"] = status_mapping[status]

    # Rename the status block and add security audit-specific tracking fields.
    generic_status_key = f"{tool_name}_status"
    if generic_status_key in response_data:
        audit_status = response_data.pop(generic_status_key)
        response_data["security_audit_status"] = audit_status

        severity_counts = {}
        for issue in self.consolidated_findings.issues_found:
            # Issues without an explicit severity are grouped under "unknown".
            severity = issue.get("severity", "unknown")
            severity_counts[severity] = severity_counts.get(severity, 0) + 1
        audit_status["vulnerabilities_by_severity"] = severity_counts
        audit_status["audit_confidence"] = self.get_request_confidence(request)

    # Map complete_secaudit to complete_security_audit
    if f"complete_{tool_name}" in response_data:
        response_data["complete_security_audit"] = response_data.pop(f"complete_{tool_name}")

    # Map the completion flag to match security audit workflow
    if f"{tool_name}_complete" in response_data:
        response_data["security_audit_complete"] = response_data.pop(f"{tool_name}_complete")

    return response_data
639 | 
640 |     # Override inheritance hooks for security audit-specific behavior
641 | 
def get_completion_status(self) -> str:
    """Status string reported when a security audit finishes.

    Returns:
        ``"security_analysis_complete"``.
    """
    status = "security_analysis_complete"
    return status
645 | 
def get_completion_data_key(self) -> str:
    """Response key under which the completed audit payload is stored.

    Returns:
        ``"complete_security_audit"``.
    """
    key = "complete_security_audit"
    return key
649 | 
def get_final_analysis_from_request(self, request):
    """Return the final analysis text, which for audits is ``request.findings``."""
    final_analysis = request.findings
    return final_analysis
653 | 
def get_confidence_level(self, request) -> str:
    """Confidence label reported on completion; the request is not inspected.

    Returns:
        ``"certain"``.
    """
    level = "certain"
    return level
657 | 
def get_completion_message(self) -> str:
    """Message shown when the audit completes locally with certain confidence.

    Returns:
        Instructions to present results by severity and act on the top fixes.
    """
    sentences = (
        "Security audit complete with CERTAIN confidence. You have identified all significant vulnerabilities",
        "and provided comprehensive security analysis. MANDATORY: Present the user with the complete security audit results",
        "categorized by severity, and IMMEDIATELY proceed with implementing the highest priority security fixes",
        "or provide specific guidance for vulnerability remediation. Focus on actionable security recommendations.",
    )
    return " ".join(sentences)
666 | 
def get_skip_reason(self) -> str:
    """Human-readable reason recorded when expert analysis is skipped."""
    reason = "Completed comprehensive security audit with full confidence locally"
    return reason
670 | 
def get_skip_expert_analysis_status(self) -> str:
    """Status value recorded when expert analysis is skipped for certain confidence."""
    status = "skipped_due_to_certain_audit_confidence"
    return status
674 | 
def prepare_work_summary(self) -> str:
    """Summarize the audit so far from ``self.consolidated_findings``.

    Returns:
        The text produced by ``_build_security_audit_summary``.
    """
    findings = self.consolidated_findings
    return self._build_security_audit_summary(findings)
678 | 
def get_request_model(self):
    """Return the request model class used to validate secaudit arguments.

    Returns:
        The ``SecauditRequest`` class.
    """
    model_class = SecauditRequest
    return model_class
682 | 
async def prepare_prompt(self, request: SecauditRequest) -> str:
    """Unused prompt hook kept to satisfy the base tool interface.

    Workflow tools build their prompts via ``execute_workflow()``, so this
    hook always yields an empty string.
    """
    empty_prompt = ""
    return empty_prompt
686 | 
```

--------------------------------------------------------------------------------
/simulator_tests/test_testgen_validation.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | TestGen Tool Validation Test
  4 | 
  5 | Tests the testgen tool's capabilities using the workflow architecture.
  6 | This validates that the workflow-based implementation guides Claude through
  7 | systematic test generation analysis before creating comprehensive test suites.
  8 | """
  9 | 
 10 | import json
 11 | from typing import Optional
 12 | 
 13 | from .conversation_base_test import ConversationBaseTest
 14 | 
 15 | 
 16 | class TestGenValidationTest(ConversationBaseTest):
 17 |     """Test testgen tool with workflow architecture"""
 18 | 
 19 |     @property
 20 |     def test_name(self) -> str:
 21 |         return "testgen_validation"
 22 | 
 23 |     @property
 24 |     def test_description(self) -> str:
 25 |         return "TestGen tool validation with step-by-step test planning"
 26 | 
 27 |     def run_test(self) -> bool:
 28 |         """Test testgen tool capabilities"""
 29 |         # Set up the test environment
 30 |         self.setUp()
 31 | 
 32 |         try:
 33 |             self.logger.info("Test: TestGen tool validation")
 34 | 
 35 |             # Create sample code files to test
 36 |             self._create_test_code_files()
 37 | 
 38 |             # Test 1: Single investigation session with multiple steps
 39 |             if not self._test_single_test_generation_session():
 40 |                 return False
 41 | 
 42 |             # Test 2: Test generation with pattern following
 43 |             if not self._test_generation_with_pattern_following():
 44 |                 return False
 45 | 
 46 |             # Test 3: Complete test generation with expert analysis
 47 |             if not self._test_complete_generation_with_analysis():
 48 |                 return False
 49 | 
 50 |             # Test 4: Certain confidence behavior
 51 |             if not self._test_certain_confidence():
 52 |                 return False
 53 | 
 54 |             # Test 5: Context-aware file embedding
 55 |             if not self._test_context_aware_file_embedding():
 56 |                 return False
 57 | 
 58 |             # Test 6: Multi-step test planning
 59 |             if not self._test_multi_step_test_planning():
 60 |                 return False
 61 | 
 62 |             self.logger.info("  ✅ All testgen validation tests passed")
 63 |             return True
 64 | 
 65 |         except Exception as e:
 66 |             self.logger.error(f"TestGen validation test failed: {e}")
 67 |             return False
 68 | 
 69 |     def _create_test_code_files(self):
 70 |         """Create sample code files for test generation"""
 71 |         # Create a calculator module with various functions
 72 |         calculator_code = """#!/usr/bin/env python3
 73 | \"\"\"
 74 | Simple calculator module for demonstration
 75 | \"\"\"
 76 | 
 77 | def add(a, b):
 78 |     \"\"\"Add two numbers\"\"\"
 79 |     return a + b
 80 | 
 81 | def subtract(a, b):
 82 |     \"\"\"Subtract b from a\"\"\"
 83 |     return a - b
 84 | 
 85 | def multiply(a, b):
 86 |     \"\"\"Multiply two numbers\"\"\"
 87 |     return a * b
 88 | 
 89 | def divide(a, b):
 90 |     \"\"\"Divide a by b\"\"\"
 91 |     if b == 0:
 92 |         raise ValueError("Cannot divide by zero")
 93 |     return a / b
 94 | 
 95 | def calculate_percentage(value, percentage):
 96 |     \"\"\"Calculate percentage of a value\"\"\"
 97 |     if percentage < 0:
 98 |         raise ValueError("Percentage cannot be negative")
 99 |     if percentage > 100:
100 |         raise ValueError("Percentage cannot exceed 100")
101 |     return (value * percentage) / 100
102 | 
103 | def power(base, exponent):
104 |     \"\"\"Calculate base raised to exponent\"\"\"
105 |     if base == 0 and exponent < 0:
106 |         raise ValueError("Cannot raise 0 to negative power")
107 |     return base ** exponent
108 | """
109 | 
110 |         # Create test file
111 |         self.calculator_file = self.create_additional_test_file("calculator.py", calculator_code)
112 |         self.logger.info(f"  ✅ Created calculator module: {self.calculator_file}")
113 | 
114 |         # Create a simple existing test file to use as pattern
115 |         existing_test = """#!/usr/bin/env python3
116 | import pytest
117 | from calculator import add, subtract
118 | 
119 | class TestCalculatorBasic:
120 |     \"\"\"Test basic calculator operations\"\"\"
121 | 
122 |     def test_add_positive_numbers(self):
123 |         \"\"\"Test adding two positive numbers\"\"\"
124 |         assert add(2, 3) == 5
125 |         assert add(10, 20) == 30
126 | 
127 |     def test_add_negative_numbers(self):
128 |         \"\"\"Test adding negative numbers\"\"\"
129 |         assert add(-5, -3) == -8
130 |         assert add(-10, 5) == -5
131 | 
132 |     def test_subtract_positive(self):
133 |         \"\"\"Test subtracting positive numbers\"\"\"
134 |         assert subtract(10, 3) == 7
135 |         assert subtract(5, 5) == 0
136 | """
137 | 
138 |         self.existing_test_file = self.create_additional_test_file("test_calculator_basic.py", existing_test)
139 |         self.logger.info(f"  ✅ Created existing test file: {self.existing_test_file}")
140 | 
141 |     def _test_single_test_generation_session(self) -> bool:
142 |         """Test a complete test generation session with multiple steps"""
143 |         try:
144 |             self.logger.info("  1.1: Testing single test generation session")
145 | 
146 |             # Step 1: Start investigation
147 |             self.logger.info("    1.1.1: Step 1 - Initial test planning")
148 |             response1, continuation_id = self.call_mcp_tool(
149 |                 "testgen",
150 |                 {
151 |                     "step": "I need to generate comprehensive tests for the calculator module. Let me start by analyzing the code structure and understanding the functionality.",
152 |                     "step_number": 1,
153 |                     "total_steps": 4,
154 |                     "next_step_required": True,
155 |                     "findings": "Calculator module contains 6 functions: add, subtract, multiply, divide, calculate_percentage, and power. Each has specific error conditions that need testing.",
156 |                     "files_checked": [self.calculator_file],
157 |                     "relevant_files": [self.calculator_file],
158 |                     "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
159 |                 },
160 |             )
161 | 
162 |             if not response1 or not continuation_id:
163 |                 self.logger.error("Failed to get initial test planning response")
164 |                 return False
165 | 
166 |             # Parse and validate JSON response
167 |             response1_data = self._parse_testgen_response(response1)
168 |             if not response1_data:
169 |                 return False
170 | 
171 |             # Validate step 1 response structure
172 |             if not self._validate_step_response(response1_data, 1, 4, True, "pause_for_test_analysis"):
173 |                 return False
174 | 
175 |             self.logger.info(f"    ✅ Step 1 successful, continuation_id: {continuation_id}")
176 | 
177 |             # Step 2: Analyze test requirements
178 |             self.logger.info("    1.1.2: Step 2 - Test requirements analysis")
179 |             response2, _ = self.call_mcp_tool(
180 |                 "testgen",
181 |                 {
182 |                     "step": "Now analyzing the test requirements for each function, identifying edge cases and boundary conditions.",
183 |                     "step_number": 2,
184 |                     "total_steps": 4,
185 |                     "next_step_required": True,
186 |                     "findings": "Identified key test scenarios: (1) divide - zero division error, (2) calculate_percentage - negative/over 100 validation, (3) power - zero to negative power error. Need tests for normal cases and edge cases.",
187 |                     "files_checked": [self.calculator_file],
188 |                     "relevant_files": [self.calculator_file],
189 |                     "relevant_context": ["divide", "calculate_percentage", "power"],
190 |                     "confidence": "medium",
191 |                     "continuation_id": continuation_id,
192 |                 },
193 |             )
194 | 
195 |             if not response2:
196 |                 self.logger.error("Failed to continue test planning to step 2")
197 |                 return False
198 | 
199 |             response2_data = self._parse_testgen_response(response2)
200 |             if not self._validate_step_response(response2_data, 2, 4, True, "pause_for_test_analysis"):
201 |                 return False
202 | 
203 |             # Check test generation status tracking
204 |             test_status = response2_data.get("test_generation_status", {})
205 |             if test_status.get("test_scenarios_identified", 0) < 3:
206 |                 self.logger.error("Test scenarios not properly tracked")
207 |                 return False
208 | 
209 |             if test_status.get("analysis_confidence") != "medium":
210 |                 self.logger.error("Confidence level not properly tracked")
211 |                 return False
212 | 
213 |             self.logger.info("    ✅ Step 2 successful with proper tracking")
214 | 
215 |             # Store continuation_id for next test
216 |             self.test_continuation_id = continuation_id
217 |             return True
218 | 
219 |         except Exception as e:
220 |             self.logger.error(f"Single test generation session test failed: {e}")
221 |             return False
222 | 
223 |     def _test_generation_with_pattern_following(self) -> bool:
224 |         """Test test generation following existing patterns"""
225 |         try:
226 |             self.logger.info("  1.2: Testing test generation with pattern following")
227 | 
228 |             # Start a new investigation with existing test patterns
229 |             self.logger.info("    1.2.1: Start test generation with pattern reference")
230 |             response1, continuation_id = self.call_mcp_tool(
231 |                 "testgen",
232 |                 {
233 |                     "step": "Generating tests for remaining calculator functions following existing test patterns",
234 |                     "step_number": 1,
235 |                     "total_steps": 3,
236 |                     "next_step_required": True,
237 |                     "findings": "Found existing test pattern using pytest with class-based organization and descriptive test names",
238 |                     "files_checked": [self.calculator_file, self.existing_test_file],
239 |                     "relevant_files": [self.calculator_file, self.existing_test_file],
240 |                     "relevant_context": ["TestCalculatorBasic", "multiply", "divide", "calculate_percentage", "power"],
241 |                 },
242 |             )
243 | 
244 |             if not response1 or not continuation_id:
245 |                 self.logger.error("Failed to start pattern following test")
246 |                 return False
247 | 
248 |             # Step 2: Analyze patterns
249 |             self.logger.info("    1.2.2: Step 2 - Pattern analysis")
250 |             response2, _ = self.call_mcp_tool(
251 |                 "testgen",
252 |                 {
253 |                     "step": "Analyzing the existing test patterns to maintain consistency",
254 |                     "step_number": 2,
255 |                     "total_steps": 3,
256 |                     "next_step_required": True,
257 |                     "findings": "Existing tests use: class-based organization (TestCalculatorBasic), descriptive method names (test_operation_scenario), multiple assertions per test, pytest framework",
258 |                     "files_checked": [self.existing_test_file],
259 |                     "relevant_files": [self.calculator_file, self.existing_test_file],
260 |                     "confidence": "high",
261 |                     "continuation_id": continuation_id,
262 |                 },
263 |             )
264 | 
265 |             if not response2:
266 |                 self.logger.error("Failed to continue to step 2")
267 |                 return False
268 | 
269 |             self.logger.info("    ✅ Pattern analysis successful")
270 |             return True
271 | 
272 |         except Exception as e:
273 |             self.logger.error(f"Pattern following test failed: {e}")
274 |             return False
275 | 
276 |     def _test_complete_generation_with_analysis(self) -> bool:
277 |         """Test complete test generation ending with expert analysis"""
278 |         try:
279 |             self.logger.info("  1.3: Testing complete test generation with expert analysis")
280 | 
281 |             # Use the continuation from first test or start fresh
282 |             continuation_id = getattr(self, "test_continuation_id", None)
283 |             if not continuation_id:
284 |                 # Start fresh if no continuation available
285 |                 self.logger.info("    1.3.0: Starting fresh test generation")
286 |                 response0, continuation_id = self.call_mcp_tool(
287 |                     "testgen",
288 |                     {
289 |                         "step": "Analyzing calculator module for comprehensive test generation",
290 |                         "step_number": 1,
291 |                         "total_steps": 2,
292 |                         "next_step_required": True,
293 |                         "findings": "Identified 6 functions needing tests with various edge cases",
294 |                         "files_checked": [self.calculator_file],
295 |                         "relevant_files": [self.calculator_file],
296 |                         "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
297 |                     },
298 |                 )
299 |                 if not response0 or not continuation_id:
300 |                     self.logger.error("Failed to start fresh test generation")
301 |                     return False
302 | 
303 |             # Final step - trigger expert analysis
304 |             self.logger.info("    1.3.1: Final step - complete test planning")
305 |             response_final, _ = self.call_mcp_tool(
306 |                 "testgen",
307 |                 {
308 |                     "step": "Test planning complete. Identified all test scenarios including edge cases, error conditions, and boundary values for comprehensive coverage.",
309 |                     "step_number": 2,
310 |                     "total_steps": 2,
311 |                     "next_step_required": False,  # Final step - triggers expert analysis
312 |                     "findings": "Complete test plan: normal operations, edge cases (zero, negative), error conditions (divide by zero, invalid percentage, zero to negative power), boundary values",
313 |                     "files_checked": [self.calculator_file],
314 |                     "relevant_files": [self.calculator_file],
315 |                     "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
316 |                     "confidence": "high",
317 |                     "continuation_id": continuation_id,
318 |                     "model": "flash",  # Use flash for expert analysis
319 |                 },
320 |             )
321 | 
322 |             if not response_final:
323 |                 self.logger.error("Failed to complete test generation")
324 |                 return False
325 | 
326 |             response_final_data = self._parse_testgen_response(response_final)
327 |             if not response_final_data:
328 |                 return False
329 | 
330 |             # Validate final response structure
331 |             if response_final_data.get("status") != "calling_expert_analysis":
332 |                 self.logger.error(
333 |                     f"Expected status 'calling_expert_analysis', got '{response_final_data.get('status')}'"
334 |                 )
335 |                 return False
336 | 
337 |             if not response_final_data.get("test_generation_complete"):
338 |                 self.logger.error("Expected test_generation_complete=true for final step")
339 |                 return False
340 | 
341 |             # Check for expert analysis
342 |             if "expert_analysis" not in response_final_data:
343 |                 self.logger.error("Missing expert_analysis in final response")
344 |                 return False
345 | 
346 |             expert_analysis = response_final_data.get("expert_analysis", {})
347 | 
348 |             # Check for expected analysis content
349 |             analysis_text = json.dumps(expert_analysis, ensure_ascii=False).lower()
350 | 
351 |             # Look for test generation indicators
352 |             test_indicators = ["test", "edge", "boundary", "error", "coverage", "pytest"]
353 |             found_indicators = sum(1 for indicator in test_indicators if indicator in analysis_text)
354 | 
355 |             if found_indicators >= 4:
356 |                 self.logger.info("    ✅ Expert analysis provided comprehensive test suggestions")
357 |             else:
358 |                 self.logger.warning(
359 |                     f"    ⚠️ Expert analysis may not have fully addressed test generation (found {found_indicators}/6 indicators)"
360 |                 )
361 | 
362 |             # Check complete test generation summary
363 |             if "complete_test_generation" not in response_final_data:
364 |                 self.logger.error("Missing complete_test_generation in final response")
365 |                 return False
366 | 
367 |             complete_generation = response_final_data["complete_test_generation"]
368 |             if not complete_generation.get("relevant_context"):
369 |                 self.logger.error("Missing relevant context in complete test generation")
370 |                 return False
371 | 
372 |             self.logger.info("    ✅ Complete test generation with expert analysis successful")
373 |             return True
374 | 
375 |         except Exception as e:
376 |             self.logger.error(f"Complete test generation test failed: {e}")
377 |             return False
378 | 
379 |     def _test_certain_confidence(self) -> bool:
380 |         """Test certain confidence behavior - should skip expert analysis"""
381 |         try:
382 |             self.logger.info("  1.4: Testing certain confidence behavior")
383 | 
384 |             # Test certain confidence - should skip expert analysis
385 |             self.logger.info("    1.4.1: Certain confidence test generation")
386 |             response_certain, _ = self.call_mcp_tool(
387 |                 "testgen",
388 |                 {
389 |                     "step": "I have fully analyzed the code and identified all test scenarios with 100% certainty. Test plan is complete.",
390 |                     "step_number": 1,
391 |                     "total_steps": 1,
392 |                     "next_step_required": False,  # Final step
393 |                     "findings": "Complete test coverage plan: all functions covered with normal cases, edge cases, and error conditions. Ready for implementation.",
394 |                     "files_checked": [self.calculator_file],
395 |                     "relevant_files": [self.calculator_file],
396 |                     "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
397 |                     "confidence": "certain",  # This should skip expert analysis
398 |                     "model": "flash",
399 |                 },
400 |             )
401 | 
402 |             if not response_certain:
403 |                 self.logger.error("Failed to test certain confidence")
404 |                 return False
405 | 
406 |             response_certain_data = self._parse_testgen_response(response_certain)
407 |             if not response_certain_data:
408 |                 return False
409 | 
410 |             # Validate certain confidence response - should skip expert analysis
411 |             if response_certain_data.get("status") != "test_generation_complete_ready_for_implementation":
412 |                 self.logger.error(
413 |                     f"Expected status 'test_generation_complete_ready_for_implementation', got '{response_certain_data.get('status')}'"
414 |                 )
415 |                 return False
416 | 
417 |             if not response_certain_data.get("skip_expert_analysis"):
418 |                 self.logger.error("Expected skip_expert_analysis=true for certain confidence")
419 |                 return False
420 | 
421 |             expert_analysis = response_certain_data.get("expert_analysis", {})
422 |             if expert_analysis.get("status") != "skipped_due_to_certain_test_confidence":
423 |                 self.logger.error("Expert analysis should be skipped for certain confidence")
424 |                 return False
425 | 
426 |             self.logger.info("    ✅ Certain confidence behavior working correctly")
427 |             return True
428 | 
429 |         except Exception as e:
430 |             self.logger.error(f"Certain confidence test failed: {e}")
431 |             return False
432 | 
433 |     def call_mcp_tool(self, tool_name: str, params: dict) -> tuple[Optional[str], Optional[str]]:
434 |         """Call an MCP tool in-process - override for testgen-specific response handling"""
435 |         # Use in-process implementation to maintain conversation memory
436 |         response_text, _ = self.call_mcp_tool_direct(tool_name, params)
437 | 
438 |         if not response_text:
439 |             return None, None
440 | 
441 |         # Extract continuation_id from testgen response specifically
442 |         continuation_id = self._extract_testgen_continuation_id(response_text)
443 | 
444 |         return response_text, continuation_id
445 | 
446 |     def _extract_testgen_continuation_id(self, response_text: str) -> Optional[str]:
447 |         """Extract continuation_id from testgen response"""
448 |         try:
449 |             # Parse the response
450 |             response_data = json.loads(response_text)
451 |             return response_data.get("continuation_id")
452 | 
453 |         except json.JSONDecodeError as e:
454 |             self.logger.debug(f"Failed to parse response for testgen continuation_id: {e}")
455 |             return None
456 | 
457 |     def _parse_testgen_response(self, response_text: str) -> dict:
458 |         """Parse testgen tool JSON response"""
459 |         try:
460 |             # Parse the response - it should be direct JSON
461 |             return json.loads(response_text)
462 | 
463 |         except json.JSONDecodeError as e:
464 |             self.logger.error(f"Failed to parse testgen response as JSON: {e}")
465 |             self.logger.error(f"Response text: {response_text[:500]}...")
466 |             return {}
467 | 
468 |     def _validate_step_response(
469 |         self,
470 |         response_data: dict,
471 |         expected_step: int,
472 |         expected_total: int,
473 |         expected_next_required: bool,
474 |         expected_status: str,
475 |     ) -> bool:
476 |         """Validate a test generation step response structure"""
477 |         try:
478 |             # Check status
479 |             if response_data.get("status") != expected_status:
480 |                 self.logger.error(f"Expected status '{expected_status}', got '{response_data.get('status')}'")
481 |                 return False
482 | 
483 |             # Check step number
484 |             if response_data.get("step_number") != expected_step:
485 |                 self.logger.error(f"Expected step_number {expected_step}, got {response_data.get('step_number')}")
486 |                 return False
487 | 
488 |             # Check total steps
489 |             if response_data.get("total_steps") != expected_total:
490 |                 self.logger.error(f"Expected total_steps {expected_total}, got {response_data.get('total_steps')}")
491 |                 return False
492 | 
493 |             # Check next_step_required
494 |             if response_data.get("next_step_required") != expected_next_required:
495 |                 self.logger.error(
496 |                     f"Expected next_step_required {expected_next_required}, got {response_data.get('next_step_required')}"
497 |                 )
498 |                 return False
499 | 
500 |             # Check test_generation_status exists
501 |             if "test_generation_status" not in response_data:
502 |                 self.logger.error("Missing test_generation_status in response")
503 |                 return False
504 | 
505 |             # Check next_steps guidance
506 |             if not response_data.get("next_steps"):
507 |                 self.logger.error("Missing next_steps guidance in response")
508 |                 return False
509 | 
510 |             return True
511 | 
512 |         except Exception as e:
513 |             self.logger.error(f"Error validating step response: {e}")
514 |             return False
515 | 
516 |     def _test_context_aware_file_embedding(self) -> bool:
517 |         """Test context-aware file embedding optimization"""
518 |         try:
519 |             self.logger.info("  1.5: Testing context-aware file embedding")
520 | 
521 |             # Create additional test files
522 |             utils_code = """#!/usr/bin/env python3
523 | def validate_number(n):
524 |     \"\"\"Validate if input is a number\"\"\"
525 |     return isinstance(n, (int, float))
526 | 
527 | def format_result(result):
528 |     \"\"\"Format calculation result\"\"\"
529 |     if isinstance(result, float):
530 |         return round(result, 2)
531 |     return result
532 | """
533 | 
534 |             math_helpers_code = """#!/usr/bin/env python3
535 | import math
536 | 
537 | def factorial(n):
538 |     \"\"\"Calculate factorial of n\"\"\"
539 |     if n < 0:
540 |         raise ValueError("Factorial not defined for negative numbers")
541 |     return math.factorial(n)
542 | 
543 | def is_prime(n):
544 |     \"\"\"Check if number is prime\"\"\"
545 |     if n < 2:
546 |         return False
547 |     for i in range(2, int(n**0.5) + 1):
548 |         if n % i == 0:
549 |             return False
550 |     return True
551 | """
552 | 
553 |             # Create test files
554 |             utils_file = self.create_additional_test_file("utils.py", utils_code)
555 |             math_file = self.create_additional_test_file("math_helpers.py", math_helpers_code)
556 | 
557 |             # Test 1: New conversation, intermediate step - should only reference files
558 |             self.logger.info("    1.5.1: New conversation intermediate step (should reference only)")
559 |             response1, continuation_id = self.call_mcp_tool(
560 |                 "testgen",
561 |                 {
562 |                     "step": "Starting test generation for utility modules",
563 |                     "step_number": 1,
564 |                     "total_steps": 3,
565 |                     "next_step_required": True,  # Intermediate step
566 |                     "findings": "Initial analysis of utility functions",
567 |                     "files_checked": [utils_file, math_file],
568 |                     "relevant_files": [utils_file],  # This should be referenced, not embedded
569 |                     "relevant_context": ["validate_number", "format_result"],
570 |                     "confidence": "low",
571 |                     "model": "flash",
572 |                 },
573 |             )
574 | 
575 |             if not response1 or not continuation_id:
576 |                 self.logger.error("Failed to start context-aware file embedding test")
577 |                 return False
578 | 
579 |             response1_data = self._parse_testgen_response(response1)
580 |             if not response1_data:
581 |                 return False
582 | 
583 |             # Check file context - should be reference_only for intermediate step
584 |             file_context = response1_data.get("file_context", {})
585 |             if file_context.get("type") != "reference_only":
586 |                 self.logger.error(f"Expected reference_only file context, got: {file_context.get('type')}")
587 |                 return False
588 | 
589 |             self.logger.info("    ✅ Intermediate step correctly uses reference_only file context")
590 | 
591 |             # Test 2: Final step - should embed files for expert analysis
592 |             self.logger.info("    1.5.2: Final step (should embed files)")
593 |             response2, _ = self.call_mcp_tool(
594 |                 "testgen",
595 |                 {
596 |                     "step": "Test planning complete - all test scenarios identified",
597 |                     "step_number": 2,
598 |                     "total_steps": 2,
599 |                     "next_step_required": False,  # Final step - should embed files
600 |                     "continuation_id": continuation_id,
601 |                     "findings": "Complete test plan for all utility functions with edge cases",
602 |                     "files_checked": [utils_file, math_file],
603 |                     "relevant_files": [utils_file, math_file],  # Should be fully embedded
604 |                     "relevant_context": ["validate_number", "format_result", "factorial", "is_prime"],
605 |                     "confidence": "high",
606 |                     "model": "flash",
607 |                 },
608 |             )
609 | 
610 |             if not response2:
611 |                 self.logger.error("Failed to complete to final step")
612 |                 return False
613 | 
614 |             response2_data = self._parse_testgen_response(response2)
615 |             if not response2_data:
616 |                 return False
617 | 
618 |             # Check file context - should be fully_embedded for final step
619 |             file_context2 = response2_data.get("file_context", {})
620 |             if file_context2.get("type") != "fully_embedded":
621 |                 self.logger.error(
622 |                     f"Expected fully_embedded file context for final step, got: {file_context2.get('type')}"
623 |                 )
624 |                 return False
625 | 
626 |             # Verify expert analysis was called for final step
627 |             if response2_data.get("status") != "calling_expert_analysis":
628 |                 self.logger.error("Final step should trigger expert analysis")
629 |                 return False
630 | 
631 |             self.logger.info("    ✅ Context-aware file embedding test completed successfully")
632 |             return True
633 | 
634 |         except Exception as e:
635 |             self.logger.error(f"Context-aware file embedding test failed: {e}")
636 |             return False
637 | 
638 |     def _test_multi_step_test_planning(self) -> bool:
639 |         """Test multi-step test planning with complex code"""
640 |         try:
641 |             self.logger.info("  1.6: Testing multi-step test planning")
642 | 
643 |             # Create a complex class to test
644 |             complex_code = """#!/usr/bin/env python3
645 | import asyncio
646 | from typing import List, Dict, Optional
647 | 
648 | class DataProcessor:
649 |     \"\"\"Complex data processor with async operations\"\"\"
650 | 
651 |     def __init__(self, batch_size: int = 100):
652 |         self.batch_size = batch_size
653 |         self.processed_count = 0
654 |         self.error_count = 0
655 |         self.cache: Dict[str, any] = {}
656 | 
657 |     async def process_batch(self, items: List[dict]) -> List[dict]:
658 |         \"\"\"Process a batch of items asynchronously\"\"\"
659 |         if not items:
660 |             return []
661 | 
662 |         if len(items) > self.batch_size:
663 |             raise ValueError(f"Batch size {len(items)} exceeds limit {self.batch_size}")
664 | 
665 |         results = []
666 |         for item in items:
667 |             try:
668 |                 result = await self._process_single_item(item)
669 |                 results.append(result)
670 |                 self.processed_count += 1
671 |             except Exception as e:
672 |                 self.error_count += 1
673 |                 results.append({"error": str(e), "item": item})
674 | 
675 |         return results
676 | 
677 |     async def _process_single_item(self, item: dict) -> dict:
678 |         \"\"\"Process a single item with caching\"\"\"
679 |         item_id = item.get('id')
680 |         if not item_id:
681 |             raise ValueError("Item must have an ID")
682 | 
683 |         # Check cache
684 |         if item_id in self.cache:
685 |             return self.cache[item_id]
686 | 
687 |         # Simulate async processing
688 |         await asyncio.sleep(0.01)
689 | 
690 |         processed = {
691 |             'id': item_id,
692 |             'processed': True,
693 |             'value': item.get('value', 0) * 2
694 |         }
695 | 
696 |         # Cache result
697 |         self.cache[item_id] = processed
698 |         return processed
699 | 
700 |     def get_stats(self) -> Dict[str, int]:
701 |         \"\"\"Get processing statistics\"\"\"
702 |         return {
703 |             'processed': self.processed_count,
704 |             'errors': self.error_count,
705 |             'cache_size': len(self.cache),
706 |             'success_rate': self.processed_count / (self.processed_count + self.error_count) if (self.processed_count + self.error_count) > 0 else 0
707 |         }
708 | """
709 | 
710 |             # Create test file
711 |             processor_file = self.create_additional_test_file("data_processor.py", complex_code)
712 | 
713 |             # Step 1: Start investigation
714 |             self.logger.info("    1.6.1: Step 1 - Start complex test planning")
715 |             response1, continuation_id = self.call_mcp_tool(
716 |                 "testgen",
717 |                 {
718 |                     "step": "Analyzing complex DataProcessor class for comprehensive test generation",
719 |                     "step_number": 1,
720 |                     "total_steps": 4,
721 |                     "next_step_required": True,
722 |                     "findings": "DataProcessor is an async class with caching, error handling, and statistics. Need async test patterns.",
723 |                     "files_checked": [processor_file],
724 |                     "relevant_files": [processor_file],
725 |                     "relevant_context": ["DataProcessor", "process_batch", "_process_single_item", "get_stats"],
726 |                     "confidence": "low",
727 |                     "model": "flash",
728 |                 },
729 |             )
730 | 
731 |             if not response1 or not continuation_id:
732 |                 self.logger.error("Failed to start multi-step test planning")
733 |                 return False
734 | 
735 |             response1_data = self._parse_testgen_response(response1)
736 | 
737 |             # Validate step 1
738 |             file_context1 = response1_data.get("file_context", {})
739 |             if file_context1.get("type") != "reference_only":
740 |                 self.logger.error("Step 1 should use reference_only file context")
741 |                 return False
742 | 
743 |             self.logger.info("    ✅ Step 1: Started complex test planning")
744 | 
745 |             # Step 2: Analyze async patterns
746 |             self.logger.info("    1.6.2: Step 2 - Async pattern analysis")
747 |             response2, _ = self.call_mcp_tool(
748 |                 "testgen",
749 |                 {
750 |                     "step": "Analyzing async patterns and edge cases for testing",
751 |                     "step_number": 2,
752 |                     "total_steps": 4,
753 |                     "next_step_required": True,
754 |                     "continuation_id": continuation_id,
755 |                     "findings": "Key test areas: async batch processing, cache behavior, error handling, batch size limits, empty items, statistics calculation",
756 |                     "files_checked": [processor_file],
757 |                     "relevant_files": [processor_file],
758 |                     "relevant_context": ["process_batch", "_process_single_item"],
759 |                     "confidence": "medium",
760 |                     "model": "flash",
761 |                 },
762 |             )
763 | 
764 |             if not response2:
765 |                 self.logger.error("Failed to continue to step 2")
766 |                 return False
767 | 
768 |             self.logger.info("    ✅ Step 2: Async patterns analyzed")
769 | 
770 |             # Step 3: Edge case identification
771 |             self.logger.info("    1.6.3: Step 3 - Edge case identification")
772 |             response3, _ = self.call_mcp_tool(
773 |                 "testgen",
774 |                 {
775 |                     "step": "Identifying all edge cases and boundary conditions",
776 |                     "step_number": 3,
777 |                     "total_steps": 4,
778 |                     "next_step_required": True,
779 |                     "continuation_id": continuation_id,
780 |                     "findings": "Edge cases: empty batch, oversized batch, items without ID, cache hits/misses, concurrent processing, error accumulation",
781 |                     "files_checked": [processor_file],
782 |                     "relevant_files": [processor_file],
783 |                     "confidence": "high",
784 |                     "model": "flash",
785 |                 },
786 |             )
787 | 
788 |             if not response3:
789 |                 self.logger.error("Failed to continue to step 3")
790 |                 return False
791 | 
792 |             self.logger.info("    ✅ Step 3: Edge cases identified")
793 | 
794 |             # Step 4: Final test plan with expert analysis
795 |             self.logger.info("    1.6.4: Step 4 - Complete test plan")
796 |             response4, _ = self.call_mcp_tool(
797 |                 "testgen",
798 |                 {
799 |                     "step": "Test planning complete with comprehensive coverage strategy",
800 |                     "step_number": 4,
801 |                     "total_steps": 4,
802 |                     "next_step_required": False,  # Final step
803 |                     "continuation_id": continuation_id,
804 |                     "findings": "Complete async test suite plan: unit tests for each method, integration tests for batch processing, edge case coverage, performance tests",
805 |                     "files_checked": [processor_file],
806 |                     "relevant_files": [processor_file],
807 |                     "confidence": "high",
808 |                     "model": "flash",
809 |                 },
810 |             )
811 | 
812 |             if not response4:
813 |                 self.logger.error("Failed to complete to final step")
814 |                 return False
815 | 
816 |             response4_data = self._parse_testgen_response(response4)
817 | 
818 |             # Validate final step
819 |             if response4_data.get("status") != "calling_expert_analysis":
820 |                 self.logger.error("Final step should trigger expert analysis")
821 |                 return False
822 | 
823 |             file_context4 = response4_data.get("file_context", {})
824 |             if file_context4.get("type") != "fully_embedded":
825 |                 self.logger.error("Final step should use fully_embedded file context")
826 |                 return False
827 | 
828 |             self.logger.info("    ✅ Multi-step test planning completed successfully")
829 |             return True
830 | 
831 |         except Exception as e:
832 |             self.logger.error(f"Multi-step test planning test failed: {e}")
833 |             return False
834 | 
```

--------------------------------------------------------------------------------
/tests/test_model_restrictions.py:
--------------------------------------------------------------------------------

```python
  1 | """Tests for model restriction functionality."""
  2 | 
  3 | import os
  4 | from unittest.mock import MagicMock, patch
  5 | 
  6 | import pytest
  7 | 
  8 | from providers.gemini import GeminiModelProvider
  9 | from providers.openai import OpenAIModelProvider
 10 | from providers.shared import ProviderType
 11 | from utils.model_restrictions import ModelRestrictionService
 12 | 
 13 | 
 14 | class TestModelRestrictionService:
 15 |     """Test cases for ModelRestrictionService."""
 16 | 
 17 |     def test_no_restrictions_by_default(self):
 18 |         """Test that no restrictions exist when env vars are not set."""
 19 |         with patch.dict(os.environ, {}, clear=True):
 20 |             service = ModelRestrictionService()
 21 | 
 22 |             # Should allow all models
 23 |             assert service.is_allowed(ProviderType.OPENAI, "o3")
 24 |             assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
 25 |             assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
 26 |             assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
 27 |             assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")
 28 |             assert service.is_allowed(ProviderType.OPENROUTER, "openai/o3")
 29 | 
 30 |             # Should have no restrictions
 31 |             assert not service.has_restrictions(ProviderType.OPENAI)
 32 |             assert not service.has_restrictions(ProviderType.GOOGLE)
 33 |             assert not service.has_restrictions(ProviderType.OPENROUTER)
 34 | 
 35 |     def test_load_single_model_restriction(self):
 36 |         """Test loading a single allowed model."""
 37 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini"}):
 38 |             service = ModelRestrictionService()
 39 | 
 40 |             # Should only allow o3-mini
 41 |             assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
 42 |             assert not service.is_allowed(ProviderType.OPENAI, "o3")
 43 |             assert not service.is_allowed(ProviderType.OPENAI, "o4-mini")
 44 | 
 45 |             # Google and OpenRouter should have no restrictions
 46 |             assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
 47 |             assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")
 48 | 
 49 |     def test_load_multiple_models_restriction(self):
 50 |         """Test loading multiple allowed models."""
 51 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
 52 |             # Instantiate providers so alias resolution for allow-lists is available
 53 |             openai_provider = OpenAIModelProvider(api_key="test-key")
 54 |             gemini_provider = GeminiModelProvider(api_key="test-key")
 55 | 
 56 |             from providers.registry import ModelProviderRegistry
 57 | 
 58 |             def fake_get_provider(provider_type, force_new=False):
 59 |                 mapping = {
 60 |                     ProviderType.OPENAI: openai_provider,
 61 |                     ProviderType.GOOGLE: gemini_provider,
 62 |                 }
 63 |                 return mapping.get(provider_type)
 64 | 
 65 |             with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
 66 | 
 67 |                 service = ModelRestrictionService()
 68 | 
 69 |                 # Check OpenAI models
 70 |                 assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
 71 |                 assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
 72 |                 assert not service.is_allowed(ProviderType.OPENAI, "o3")
 73 | 
 74 |                 # Check Google models
 75 |                 assert service.is_allowed(ProviderType.GOOGLE, "flash")
 76 |                 assert service.is_allowed(ProviderType.GOOGLE, "pro")
 77 |                 assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
 78 | 
 79 |     def test_case_insensitive_and_whitespace_handling(self):
 80 |         """Test that model names are case-insensitive and whitespace is trimmed."""
 81 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": " O3-MINI , o4-Mini "}):
 82 |             service = ModelRestrictionService()
 83 | 
 84 |             # Should work with any case
 85 |             assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
 86 |             assert service.is_allowed(ProviderType.OPENAI, "O3-MINI")
 87 |             assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
 88 |             assert service.is_allowed(ProviderType.OPENAI, "O4-Mini")
 89 | 
 90 |     def test_empty_string_allows_all(self):
 91 |         """Test that empty string allows all models (same as unset)."""
 92 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "", "GOOGLE_ALLOWED_MODELS": "flash"}):
 93 |             service = ModelRestrictionService()
 94 | 
 95 |             # OpenAI should allow all models (empty string = no restrictions)
 96 |             assert service.is_allowed(ProviderType.OPENAI, "o3")
 97 |             assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
 98 |             assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
 99 | 
100 |             # Google should only allow flash (and its resolved name)
101 |             assert service.is_allowed(ProviderType.GOOGLE, "flash")
102 |             assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
103 |             assert not service.is_allowed(ProviderType.GOOGLE, "pro")
104 |             assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")
105 | 
106 |     def test_filter_models(self):
107 |         """Test filtering a list of models based on restrictions."""
108 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
109 |             service = ModelRestrictionService()
110 | 
111 |             models = ["o3", "o3-mini", "o4-mini", "o3-pro"]
112 |             filtered = service.filter_models(ProviderType.OPENAI, models)
113 | 
114 |             assert filtered == ["o3-mini", "o4-mini"]
115 | 
116 |     def test_get_allowed_models(self):
117 |         """Test getting the set of allowed models."""
118 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
119 |             service = ModelRestrictionService()
120 | 
121 |             allowed = service.get_allowed_models(ProviderType.OPENAI)
122 |             assert allowed == {"o3-mini", "o4-mini"}
123 | 
124 |             # No restrictions for Google
125 |             assert service.get_allowed_models(ProviderType.GOOGLE) is None
126 | 
127 |     def test_shorthand_names_in_restrictions(self):
128 |         """Test that shorthand names work in restrictions."""
129 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
130 |             # Instantiate providers so the registry can resolve aliases
131 |             OpenAIModelProvider(api_key="test-key")
132 |             GeminiModelProvider(api_key="test-key")
133 | 
134 |             service = ModelRestrictionService()
135 | 
136 |             # When providers check models, they pass both resolved and original names
137 |             # OpenAI: 'o4mini' shorthand allows o4-mini
138 |             assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini")  # How providers actually call it
139 |             assert service.is_allowed(ProviderType.OPENAI, "o4-mini")  # Canonical should also be allowed
140 | 
141 |             # OpenAI: o3-mini allowed directly
142 |             assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
143 |             assert not service.is_allowed(ProviderType.OPENAI, "o3")
144 | 
145 |             # Google should allow both models via shorthands
146 |             assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
147 |             assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")
148 | 
149 |             # Also test that full names work when specified in restrictions
150 |             assert service.is_allowed(ProviderType.OPENAI, "o3-mini", "o3mini")  # Even with shorthand
151 | 
152 |     def test_validation_against_known_models(self, caplog):
153 |         """Test validation warnings for unknown models."""
154 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mimi"}):  # Note the typo: o4-mimi
155 |             service = ModelRestrictionService()
156 | 
157 |             # Create mock provider with known models
158 |             mock_provider = MagicMock()
159 |             mock_provider.MODEL_CAPABILITIES = {
160 |                 "o3": {"context_window": 200000},
161 |                 "o3-mini": {"context_window": 200000},
162 |                 "o4-mini": {"context_window": 200000},
163 |             }
164 |             mock_provider.list_models.return_value = ["o3", "o3-mini", "o4-mini"]
165 | 
166 |             provider_instances = {ProviderType.OPENAI: mock_provider}
167 |             service.validate_against_known_models(provider_instances)
168 | 
169 |             # Should have logged a warning about the typo
170 |             assert "o4-mimi" in caplog.text
171 |             assert "not a recognized" in caplog.text
172 | 
173 |     def test_openrouter_model_restrictions(self):
174 |         """Test OpenRouter model restrictions functionality."""
175 |         with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet"}):
176 |             service = ModelRestrictionService()
177 | 
178 |             # Should only allow specified OpenRouter models
179 |             assert service.is_allowed(ProviderType.OPENROUTER, "opus")
180 |             assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
181 |             assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4", "opus")  # With original name
182 |             assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")
183 |             assert not service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-haiku")
184 |             assert not service.is_allowed(ProviderType.OPENROUTER, "mistral-large")
185 | 
186 |             # Other providers should have no restrictions
187 |             assert service.is_allowed(ProviderType.OPENAI, "o3")
188 |             assert service.is_allowed(ProviderType.GOOGLE, "pro")
189 | 
190 |             # Should have restrictions for OpenRouter
191 |             assert service.has_restrictions(ProviderType.OPENROUTER)
192 |             assert not service.has_restrictions(ProviderType.OPENAI)
193 |             assert not service.has_restrictions(ProviderType.GOOGLE)
194 | 
195 |     def test_openrouter_filter_models(self):
196 |         """Test filtering OpenRouter models based on restrictions."""
197 |         with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,mistral"}):
198 |             service = ModelRestrictionService()
199 | 
200 |             models = ["opus", "sonnet", "haiku", "mistral", "llama"]
201 |             filtered = service.filter_models(ProviderType.OPENROUTER, models)
202 | 
203 |             assert filtered == ["opus", "mistral"]
204 | 
205 |     def test_combined_provider_restrictions(self):
206 |         """Test that restrictions work correctly when set for multiple providers."""
207 |         with patch.dict(
208 |             os.environ,
209 |             {
210 |                 "OPENAI_ALLOWED_MODELS": "o3-mini",
211 |                 "GOOGLE_ALLOWED_MODELS": "flash",
212 |                 "OPENROUTER_ALLOWED_MODELS": "opus,sonnet",
213 |             },
214 |         ):
215 |             service = ModelRestrictionService()
216 | 
217 |             # OpenAI restrictions
218 |             assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
219 |             assert not service.is_allowed(ProviderType.OPENAI, "o3")
220 | 
221 |             # Google restrictions
222 |             assert service.is_allowed(ProviderType.GOOGLE, "flash")
223 |             assert not service.is_allowed(ProviderType.GOOGLE, "pro")
224 | 
225 |             # OpenRouter restrictions
226 |             assert service.is_allowed(ProviderType.OPENROUTER, "opus")
227 |             assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
228 |             assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")
229 | 
230 |             # All providers should have restrictions
231 |             assert service.has_restrictions(ProviderType.OPENAI)
232 |             assert service.has_restrictions(ProviderType.GOOGLE)
233 |             assert service.has_restrictions(ProviderType.OPENROUTER)
234 | 
235 | 
class TestProviderIntegration:
    """Test that concrete providers enforce the *_ALLOWED_MODELS restrictions."""

    @staticmethod
    def _reset_restriction_cache():
        """Forget the module-level cached restriction service so it is rebuilt on next use."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Drop the restriction service cached while this test's environment was patched.

        Each test builds the cached service under a patched environment; without
        this reset, that service may outlive the test and apply its restrictions
        to tests that run afterwards.
        """
        self._reset_restriction_cache()

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini"})
    def test_openai_provider_respects_restrictions(self):
        """Test that OpenAI provider respects restrictions."""
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        provider = OpenAIModelProvider(api_key="test-key")

        # Should validate allowed model
        assert provider.validate_model_name("o3-mini")

        # Should not validate disallowed model
        assert not provider.validate_model_name("o3")

        # get_capabilities should raise for disallowed model
        with pytest.raises(ValueError) as exc_info:
            provider.get_capabilities("o3")
        assert "not allowed by restriction policy" in str(exc_info.value)

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash,flash"})
    def test_gemini_provider_respects_restrictions(self):
        """Test that Gemini provider respects restrictions."""
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        provider = GeminiModelProvider(api_key="test-key")

        # Should validate allowed models (both shorthand and full name allowed)
        assert provider.validate_model_name("flash")
        assert provider.validate_model_name("gemini-2.5-flash")

        # Should not validate disallowed model
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")

        # get_capabilities should raise for disallowed model
        with pytest.raises(ValueError) as exc_info:
            provider.get_capabilities("pro")
        assert "not allowed by restriction policy" in str(exc_info.value)

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"})
    def test_gemini_parameter_order_regression_protection(self):
        """Test that prevents regression of parameter order bug in is_allowed calls.

        This test specifically catches the bug where parameters were incorrectly
        passed as (provider, user_input, resolved_name) instead of
        (provider, resolved_name, user_input).

        The bug was subtle because the is_allowed method uses OR logic, so it
        worked in most cases by accident. This test creates a scenario where
        the parameter order matters.
        """
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        provider = GeminiModelProvider(api_key="test-key")

        from providers.registry import ModelProviderRegistry

        with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):

            # Only the alias "flash" is on the allowlist; the canonical name is
            # admitted through alias resolution. Swapped is_allowed arguments
            # would break one of the checks below.

            # Should allow "flash" alias
            assert provider.validate_model_name("flash")

            # Should allow getting capabilities for "flash"
            capabilities = provider.get_capabilities("flash")
            assert capabilities.model_name == "gemini-2.5-flash"

            # Canonical form should also be allowed now that alias is on the allowlist
            assert provider.validate_model_name("gemini-2.5-flash")
            # Unrelated models remain blocked
            assert not provider.validate_model_name("pro")
            assert not provider.validate_model_name("gemini-2.5-pro")

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
    def test_gemini_parameter_order_edge_case_full_name_only(self):
        """Test parameter order with only full name allowed, not alias.

        This is the reverse scenario - only the full canonical name is allowed,
        not the shorthand alias. This tests that the parameter order is correct
        when resolving aliases.
        """
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        provider = GeminiModelProvider(api_key="test-key")

        # Should allow full name
        assert provider.validate_model_name("gemini-2.5-flash")

        # Should also allow alias that resolves to allowed full name
        # This works because is_allowed checks both resolved_name and original_name
        assert provider.validate_model_name("flash")

        # Should not allow "pro" alias
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")
347 | 
348 | 
class TestCustomProviderOpenRouterRestrictions:
    """Test custom provider integration with OpenRouter restrictions."""

    @staticmethod
    def _reset_restriction_cache():
        """Forget the module-level cached restriction service so it is rebuilt on next use."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Drop the restriction service cached while this test's environment was patched.

        Each test builds the cached service under a patched environment; without
        this reset, that service may outlive the test and apply its restrictions
        to tests that run afterwards.
        """
        self._reset_restriction_cache()

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_respects_openrouter_restrictions(self):
        """Test that custom provider correctly defers OpenRouter models to OpenRouter provider."""
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # CustomProvider should NOT validate OpenRouter models - they should be deferred to OpenRouter
        assert not provider.validate_model_name("opus")
        assert not provider.validate_model_name("sonnet")
        assert not provider.validate_model_name("haiku")

        # Should still validate custom models defined in conf/custom_models.json
        assert provider.validate_model_name("local-llama")

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_openrouter_capabilities_restrictions(self):
        """Test that custom provider's get_capabilities correctly handles OpenRouter models."""
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # For OpenRouter models, CustomProvider should defer by raising
        with pytest.raises(ValueError):
            provider.get_capabilities("opus")

        # Should raise for disallowed OpenRouter model (still defers)
        with pytest.raises(ValueError):
            provider.get_capabilities("haiku")

        # Should still work for custom models
        capabilities = provider.get_capabilities("local-llama")
        assert capabilities.provider == ProviderType.CUSTOM

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus"}, clear=False)
    def test_custom_provider_no_openrouter_key_ignores_restrictions(self):
        """Test that when OpenRouter key is not set, cloud models are rejected regardless of restrictions."""
        # Remove the key for this test only; patch.dict restores os.environ on exit
        os.environ.pop("OPENROUTER_API_KEY", None)
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # Should not validate OpenRouter models when key is not available
        assert not provider.validate_model_name("opus")  # Even though it's in allowed list
        assert not provider.validate_model_name("haiku")

        # Should still validate custom models
        assert provider.validate_model_name("local-llama")

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_empty_restrictions_allows_all_openrouter(self):
        """Test that custom provider correctly defers OpenRouter models regardless of restrictions."""
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # CustomProvider should NOT validate OpenRouter models - they should be deferred to OpenRouter
        assert not provider.validate_model_name("opus")
        assert not provider.validate_model_name("sonnet")
        assert not provider.validate_model_name("haiku")
434 | 
435 | 
class TestRegistryIntegration:
    """Test integration with ModelProviderRegistry."""

    @staticmethod
    def _reset_restriction_cache():
        """Forget the module-level cached restriction service so it is rebuilt on next use."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Drop the restriction service cached while this test's environment was patched."""
        self._reset_restriction_cache()

    @staticmethod
    def _make_list_models(mock_provider, provider_type):
        """Build a keyword-only ``list_models`` fake driven by ``mock_provider.MODEL_CAPABILITIES``.

        String values in MODEL_CAPABILITIES are treated as aliases pointing at a
        target model: the alias is listed (when ``include_aliases`` is set) only
        if the target passes the restriction check. Any other entry is listed
        when its own name passes the check. MODEL_CAPABILITIES is read at call
        time, so later changes to the mock are respected.
        """

        def list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
            models = []
            for model_name, config in mock_provider.MODEL_CAPABILITIES.items():
                if isinstance(config, str):
                    # Alias entry: gate on the model it points at
                    if restriction_service and not restriction_service.is_allowed(provider_type, config):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                # dict keeps insertion order, so this drops repeats while keeping first occurrences
                models = list(dict.fromkeys(models))
            return models

        return list_models

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GOOGLE_ALLOWED_MODELS": "flash"})
    def test_registry_with_shorthand_restrictions(self):
        """Smoke test: get_available_models runs when allowlists contain only shorthands.

        No assertion is made on the returned models; the test only checks that
        the call completes without raising.
        """
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        from providers.registry import ModelProviderRegistry

        # Clear registry cache
        ModelProviderRegistry.clear_cache()

        # NOTE(review): the result is intentionally unchecked. An earlier comment
        # claimed it would be empty because "mini" is not mapped to "o4-mini", but
        # TestShorthandRestrictions expects "mini" to resolve to gpt-5-mini, so that
        # claim looks stale. Add a real assertion once the intended alias handling
        # of get_available_models is confirmed.
        ModelProviderRegistry.get_available_models(respect_restrictions=True)

    @patch("providers.registry.ModelProviderRegistry.get_provider")
    def test_get_available_models_respects_restrictions(self, mock_get_provider):
        """Test that registry filters models based on restrictions."""
        from providers.registry import ModelProviderRegistry

        # Mock providers
        mock_openai = MagicMock()
        mock_openai.MODEL_CAPABILITIES = {
            "o3": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI
        mock_openai.list_models = MagicMock(side_effect=self._make_list_models(mock_openai, ProviderType.OPENAI))

        mock_gemini = MagicMock()
        mock_gemini.MODEL_CAPABILITIES = {
            "gemini-2.5-pro": {"context_window": 1048576},
            "gemini-2.5-flash": {"context_window": 1048576},
        }
        mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE
        mock_gemini.list_models = MagicMock(side_effect=self._make_list_models(mock_gemini, ProviderType.GOOGLE))

        providers_by_type = {
            ProviderType.OPENAI: mock_openai,
            ProviderType.GOOGLE: mock_gemini,
        }

        def get_provider_side_effect(provider_type, force_new=False):
            return providers_by_type.get(provider_type)

        mock_get_provider.side_effect = get_provider_side_effect

        # ModelProviderRegistry() appears to hand out shared state (other tests
        # reset ModelProviderRegistry._instance), so restore the provider map
        # afterwards instead of leaving mock classes registered for later tests.
        registry = ModelProviderRegistry()
        original_providers = registry._providers
        registry._providers = {
            ProviderType.OPENAI: type(mock_openai),
            ProviderType.GOOGLE: type(mock_gemini),
        }

        try:
            with patch.dict(
                os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini", "GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}
            ):
                # Start from a fresh restriction service that sees the patched env
                self._reset_restriction_cache()

                available = ModelProviderRegistry.get_available_models(respect_restrictions=True)

                # Should only include allowed models
                assert "o3-mini" in available
                assert "o3" not in available
                assert "gemini-2.5-flash" in available
                assert "gemini-2.5-pro" not in available
        finally:
            registry._providers = original_providers
583 | 
584 | 
class TestShorthandRestrictions:
    """Test that shorthand model names work correctly in restrictions."""

    @staticmethod
    def _reset_restriction_cache():
        """Forget the module-level cached restriction service so it is rebuilt on next use."""
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    def teardown_method(self):
        """Drop the restriction service cached while this test's environment was patched."""
        self._reset_restriction_cache()

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GOOGLE_ALLOWED_MODELS": "flash"})
    def test_providers_validate_shorthands_correctly(self):
        """Test that providers correctly validate shorthand names."""
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        openai_provider = OpenAIModelProvider(api_key="test-key")
        gemini_provider = GeminiModelProvider(api_key="test-key")

        from providers.registry import ModelProviderRegistry

        def registry_side_effect(provider_type, force_new=False):
            mapping = {
                ProviderType.OPENAI: openai_provider,
                ProviderType.GOOGLE: gemini_provider,
            }
            return mapping.get(provider_type)

        with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
            # OpenAI provider: "mini" resolves to gpt-5-mini
            assert openai_provider.validate_model_name("mini")  # Should work with shorthand
            assert openai_provider.validate_model_name("gpt-5-mini")  # Canonical resolved from shorthand
            assert not openai_provider.validate_model_name("o4-mini")  # Unrelated model still blocked
            assert not openai_provider.validate_model_name("o3-mini")

            # Gemini provider
            assert gemini_provider.validate_model_name("flash")  # Should work with shorthand
            assert gemini_provider.validate_model_name("gemini-2.5-flash")  # Canonical allowed
            assert not gemini_provider.validate_model_name("pro")  # Not allowed

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
    def test_multiple_shorthands_for_same_model(self):
        """Test that several shorthand entries in one allowlist each take effect."""
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        openai_provider = OpenAIModelProvider(api_key="test-key")

        # Both shorthands are listed verbatim, so both validate
        assert openai_provider.validate_model_name("mini")
        assert openai_provider.validate_model_name("o3mini")

        # Resolved names should be allowed when their shorthands are present
        assert openai_provider.validate_model_name("o4-mini")  # Explicitly allowed
        assert openai_provider.validate_model_name("o3-mini")  # Allowed via the "o3mini" shorthand

        # Other models should not work
        assert not openai_provider.validate_model_name("o3")
        assert not openai_provider.validate_model_name("o3-pro")

    @patch.dict(
        os.environ,
        {"OPENAI_ALLOWED_MODELS": "mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,gemini-2.5-flash"},
    )
    def test_both_shorthand_and_full_name_allowed(self):
        """Test that we can allow both shorthand and full names."""
        # Start from a fresh restriction service that sees the patched env
        self._reset_restriction_cache()

        # OpenAI - both mini and o4-mini are allowed
        openai_provider = OpenAIModelProvider(api_key="test-key")
        assert openai_provider.validate_model_name("mini")
        assert openai_provider.validate_model_name("o4-mini")

        # Gemini - both flash and full name are allowed
        gemini_provider = GeminiModelProvider(api_key="test-key")
        assert gemini_provider.validate_model_name("flash")
        assert gemini_provider.validate_model_name("gemini-2.5-flash")
662 | 
663 | 
664 | class TestAutoModeWithRestrictions:
665 |     """Test auto mode behavior with restrictions."""
666 | 
667 |     @patch("providers.registry.ModelProviderRegistry.get_provider")
668 |     def test_fallback_model_respects_restrictions(self, mock_get_provider):
669 |         """Test that fallback model selection respects restrictions."""
670 |         from providers.registry import ModelProviderRegistry
671 |         from tools.models import ToolModelCategory
672 | 
673 |         # Mock providers
674 |         mock_openai = MagicMock()
675 |         mock_openai.MODEL_CAPABILITIES = {
676 |             "o3": {"context_window": 200000},
677 |             "o3-mini": {"context_window": 200000},
678 |             "o4-mini": {"context_window": 200000},
679 |         }
680 |         mock_openai.get_provider_type.return_value = ProviderType.OPENAI
681 | 
682 |         def openai_list_models(
683 |             *,
684 |             respect_restrictions: bool = True,
685 |             include_aliases: bool = True,
686 |             lowercase: bool = False,
687 |             unique: bool = False,
688 |         ):
689 |             from utils.model_restrictions import get_restriction_service
690 | 
691 |             restriction_service = get_restriction_service() if respect_restrictions else None
692 |             models = []
693 |             for model_name, config in mock_openai.MODEL_CAPABILITIES.items():
694 |                 if isinstance(config, str):
695 |                     target_model = config
696 |                     if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
697 |                         continue
698 |                     if include_aliases:
699 |                         models.append(model_name)
700 |                 else:
701 |                     if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
702 |                         continue
703 |                     models.append(model_name)
704 |             if lowercase:
705 |                 models = [m.lower() for m in models]
706 |             if unique:
707 |                 seen = set()
708 |                 ordered = []
709 |                 for name in models:
710 |                     if name in seen:
711 |                         continue
712 |                     seen.add(name)
713 |                     ordered.append(name)
714 |                 models = ordered
715 |             return models
716 | 
717 |         mock_openai.list_models = MagicMock(side_effect=openai_list_models)
718 | 
719 |         # Add get_preferred_model method to mock to match new implementation
720 |         def get_preferred_model(category, allowed_models):
721 |             # Simple preference logic for testing - just return first allowed model
722 |             return allowed_models[0] if allowed_models else None
723 | 
724 |         mock_openai.get_preferred_model = get_preferred_model
725 | 
726 |         def get_provider_side_effect(provider_type):
727 |             if provider_type == ProviderType.OPENAI:
728 |                 return mock_openai
729 |             return None
730 | 
731 |         mock_get_provider.side_effect = get_provider_side_effect
732 | 
733 |         # Set up registry
734 |         registry = ModelProviderRegistry()
735 |         registry._providers = {ProviderType.OPENAI: type(mock_openai)}
736 | 
737 |         with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}):
738 |             # Clear cached restriction service
739 |             import utils.model_restrictions
740 | 
741 |             utils.model_restrictions._restriction_service = None
742 | 
743 |             # Should pick o4-mini instead of o3-mini for fast response
744 |             model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
745 |             assert model == "o4-mini"
746 | 
747 |     def test_fallback_with_shorthand_restrictions(self, monkeypatch):
748 |         """FAST_RESPONSE fallback still yields a model when OPENAI_ALLOWED_MODELS is the shorthand "mini"."""
749 |         # Use monkeypatch to set environment variables with automatic cleanup
750 |         monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "mini")
751 |         monkeypatch.setenv("GEMINI_API_KEY", "")
752 |         monkeypatch.setenv("OPENAI_API_KEY", "test-key")
753 | 
754 |         # Clear caches and reset registry
755 |         import utils.model_restrictions
756 |         from providers.registry import ModelProviderRegistry
757 |         from tools.models import ToolModelCategory
758 | 
759 |         utils.model_restrictions._restriction_service = None  # drop cached service; presumably rebuilt from env on next use
760 | 
761 |         # Snapshot the current provider maps so the finally block can restore them
762 |         registry = ModelProviderRegistry()
763 |         original_providers = registry._providers.copy()
764 |         original_initialized = registry._initialized_providers.copy()
765 | 
766 |         try:
767 |             # Clear registry and register only OpenAI and Gemini providers
768 |             ModelProviderRegistry._instance = None  # NOTE(review): the old instance is never reinstated; finally patches whichever instance exists then
769 |             from providers.gemini import GeminiModelProvider
770 |             from providers.openai import OpenAIModelProvider
771 | 
772 |             ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
773 |             ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
774 | 
775 |             # Even with "mini" restriction, fallback should work if provider handles it correctly
776 |             # This tests the real-world scenario
777 |             model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
778 | 
779 |             # The fallback will depend on how get_available_models handles aliases
780 |             # When "mini" is allowed, it's returned as the allowed model
781 |             # "mini" is now an alias for gpt-5-mini, but the list shows "mini" itself
782 |             assert model in ["mini", "gpt-5-mini", "o4-mini", "gemini-2.5-flash"]
783 |         finally:  # NOTE(review): _restriction_service is not reset here, so the "mini" restriction may stay cached for later tests
784 |             # Copy the snapshotted provider maps back onto the current registry instance
785 |             registry = ModelProviderRegistry()
786 |             registry._providers.clear()
787 |             registry._initialized_providers.clear()
788 |             registry._providers.update(original_providers)
789 |             registry._initialized_providers.update(original_initialized)
790 | 
```

--------------------------------------------------------------------------------
/providers/openai_compatible.py:
--------------------------------------------------------------------------------

```python
  1 | """Base class for OpenAI-compatible API providers."""
  2 | 
  3 | import copy
  4 | import ipaddress
  5 | import logging
  6 | from typing import Optional
  7 | from urllib.parse import urlparse
  8 | 
  9 | from openai import OpenAI
 10 | 
 11 | from utils.env import get_env, suppress_env_vars
 12 | from utils.image_utils import validate_image
 13 | 
 14 | from .base import ModelProvider
 15 | from .shared import (
 16 |     ModelCapabilities,
 17 |     ModelResponse,
 18 |     ProviderType,
 19 | )
 20 | 
 21 | 
 22 | class OpenAICompatibleProvider(ModelProvider):
 23 |     """Shared implementation for OpenAI API lookalikes.
 24 | 
 25 |     The class owns HTTP client configuration (timeouts, proxy hardening,
 26 |     custom headers) and normalises the OpenAI SDK responses into
 27 |     :class:`~providers.shared.ModelResponse`.  Concrete subclasses only need to
 28 |     provide capability metadata and any provider-specific request tweaks.
 29 |     """
 30 | 
 31 |     DEFAULT_HEADERS = {}
 32 |     FRIENDLY_NAME = "OpenAI Compatible"
 33 | 
 34 |     def __init__(self, api_key: str, base_url: Optional[str] = None, **kwargs):
 35 |         """Initialize the provider with API key and optional base URL.
 36 | 
 37 |         Args:
 38 |             api_key: API key for authentication (a warning is logged if empty with a non-local base_url)
 39 |             base_url: Base URL for the API endpoint; checked for http/https scheme, hostname and port when set
 40 |             **kwargs: Forwarded to the base provider; ``organization`` and ``*_timeout`` keys are also read here
 41 |         """
 42 |         self._allowed_alias_cache: dict[str, str] = {}  # NOTE(review): set before super().__init__, presumably in case base init touches it
 43 |         super().__init__(api_key, **kwargs)
 44 |         self._client = None  # built lazily by the ``client`` property
 45 |         self.base_url = base_url
 46 |         self.organization = kwargs.get("organization")
 47 |         self.allowed_models = self._parse_allowed_models()  # None means no provider allowlist
 48 | 
 49 |         # Configure timeouts - especially important for custom/local endpoints
 50 |         self.timeout_config = self._configure_timeouts(**kwargs)
 51 | 
 52 |         # Validate scheme/hostname/port of base_url (raises ValueError)
 53 |         if self.base_url:
 54 |             self._validate_base_url()
 55 | 
 56 |         # Warn if using external URL without authentication
 57 |         if self.base_url and not self._is_localhost_url() and not api_key:
 58 |             logging.warning(
 59 |                 f"Using external URL '{self.base_url}' without API key. "
 60 |                 "This may be insecure. Consider setting an API key for authentication."
 61 |             )
 62 | 
 63 |     def _ensure_model_allowed(
 64 |         self,
 65 |         capabilities: ModelCapabilities,
 66 |         canonical_name: str,
 67 |         requested_name: str,
 68 |     ) -> None:
 69 |         """Run the base restriction check first, then enforce this provider's *_ALLOWED_MODELS allowlist."""
 70 |         # This override raises ValueError unless the requested or canonical name is allowed (directly or via an allowlisted alias).
 71 |         super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
 72 | 
 73 |         if self.allowed_models is not None:  # None: no provider allowlist configured
 74 |             requested = requested_name.lower()
 75 |             canonical = canonical_name.lower()
 76 |             # Fast path: exact case-insensitive match; otherwise resolve each allowlist entry as an alias.
 77 |             if requested not in self.allowed_models and canonical not in self.allowed_models:
 78 |                 allowed = False
 79 |                 for allowed_entry in list(self.allowed_models):  # iterate a copy: a match adds to the set below
 80 |                     normalized_resolved = self._allowed_alias_cache.get(allowed_entry)
 81 |                     if normalized_resolved is None:
 82 |                         try:
 83 |                             resolved_name = self._resolve_model_name(allowed_entry)
 84 |                         except Exception:  # unresolvable entries are skipped (and not cached, so retried next call)
 85 |                             continue
 86 | 
 87 |                         if not resolved_name:
 88 |                             continue
 89 | 
 90 |                         normalized_resolved = resolved_name.lower()
 91 |                         self._allowed_alias_cache[allowed_entry] = normalized_resolved
 92 | 
 93 |                     if normalized_resolved == canonical:
 94 |                         # Canonical match discovered via alias resolution – mark as allowed and
 95 |                         # memoise the canonical entry for future lookups.
 96 |                         allowed = True
 97 |                         self._allowed_alias_cache[canonical] = canonical
 98 |                         self.allowed_models.add(canonical)  # NOTE: also widens the list shown in later error messages
 99 |                         break
100 | 
101 |                 if not allowed:
102 |                     raise ValueError(
103 |                         f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
104 |                     )
105 | 
106 |     def _parse_allowed_models(self) -> Optional[set[str]]:
107 |         """Parse the comma-separated ``<PROVIDER>_ALLOWED_MODELS`` environment variable.
108 | 
109 |         Returns:
110 |             Set of stripped, lowercased model names, or None when the variable is unset/empty or has no entries
111 |         """
112 |         # Env var name is the upper-cased provider type value followed by _ALLOWED_MODELS
113 |         provider_type = self.get_provider_type().value.upper()
114 |         env_var = f"{provider_type}_ALLOWED_MODELS"
115 |         models_str = get_env(env_var, "") or ""
116 | 
117 |         if models_str:
118 |             # Parse and normalize to lowercase for case-insensitive comparison
119 |             models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
120 |             if models:
121 |                 logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
122 |                 self._allowed_alias_cache = {}  # reset alias memo; entries are resolved lazily in _ensure_model_allowed
123 |                 return models
124 | 
125 |         # Only providers other than Google/OpenAI log the "no allow-list" notice
126 |         if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
127 |             logging.info(
128 |                 f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. "
129 |                 f"To restrict access, set {env_var} with comma-separated model names."
130 |             )
131 | 
132 |         return None
133 | 
134 |     def _configure_timeouts(self, **kwargs):
135 |         """Build httpx timeouts from the base URL's locality plus optional kwarg/env overrides.
136 | 
137 |         Custom URLs and local models often need longer timeouts due to:
138 |         - Network latency on local networks
139 |         - Extended thinking models taking longer to respond
140 |         - Local inference being slower than cloud APIs
141 | 
142 |         Returns:
143 |             httpx.Timeout with connect/read/write/pool values in seconds
144 |         """
145 |         import httpx
146 | 
147 |         # Defaults used when no base_url is set
148 |         default_connect = 30.0  # 30 seconds for connection (vs OpenAI's 5s)
149 |         default_read = 600.0  # 10 minutes for reading (same as OpenAI default)
150 |         default_write = 600.0  # 10 minutes for writing
151 |         default_pool = 600.0  # 10 minutes for pool
152 | 
153 |         # Local/private-network URLs get the longest defaults; other custom URLs get intermediate ones
154 |         if self.base_url and self._is_localhost_url():
155 |             default_connect = 60.0  # 1 minute for local connections
156 |             default_read = 1800.0  # 30 minutes for local models (extended thinking)
157 |             default_write = 1800.0  # 30 minutes for local models
158 |             default_pool = 1800.0  # 30 minutes for local models
159 |             logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
160 |         elif self.base_url:
161 |             default_connect = 45.0  # 45 seconds for custom remote endpoints
162 |             default_read = 900.0  # 15 minutes for custom remote endpoints
163 |             default_write = 900.0  # 15 minutes for custom remote endpoints
164 |             default_pool = 900.0  # 15 minutes for custom remote endpoints
165 |             logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
166 | 
167 |         # Per field: explicit kwarg, else CUSTOM_*_TIMEOUT env var, else the default chosen above
168 |         connect_timeout = kwargs.get("connect_timeout")
169 |         if connect_timeout is None:
170 |             connect_timeout_raw = get_env("CUSTOM_CONNECT_TIMEOUT")  # NOTE(review): non-numeric env values make float() raise ValueError
171 |             connect_timeout = float(connect_timeout_raw) if connect_timeout_raw is not None else float(default_connect)
172 | 
173 |         read_timeout = kwargs.get("read_timeout")
174 |         if read_timeout is None:
175 |             read_timeout_raw = get_env("CUSTOM_READ_TIMEOUT")
176 |             read_timeout = float(read_timeout_raw) if read_timeout_raw is not None else float(default_read)
177 | 
178 |         write_timeout = kwargs.get("write_timeout")
179 |         if write_timeout is None:
180 |             write_timeout_raw = get_env("CUSTOM_WRITE_TIMEOUT")
181 |             write_timeout = float(write_timeout_raw) if write_timeout_raw is not None else float(default_write)
182 | 
183 |         pool_timeout = kwargs.get("pool_timeout")
184 |         if pool_timeout is None:
185 |             pool_timeout_raw = get_env("CUSTOM_POOL_TIMEOUT")
186 |             pool_timeout = float(pool_timeout_raw) if pool_timeout_raw is not None else float(default_pool)
187 | 
188 |         timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)
189 | 
190 |         logging.debug(
191 |             f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
192 |             f"Write: {write_timeout}s, Pool: {pool_timeout}s"
193 |         )
194 | 
195 |         return timeout
196 | 
197 |     def _is_localhost_url(self) -> bool:
198 |         """Check if the base URL points to localhost or local network.
199 | 
200 |         Returns:
201 |             True for host 'localhost' or a literal loopback/private IP; False otherwise (including parse errors)
202 |         """
203 |         if not self.base_url:
204 |             return False
205 | 
206 |         try:
207 |             parsed = urlparse(self.base_url)
208 |             hostname = parsed.hostname
209 | 
210 |             # Check for common localhost patterns
211 |             if hostname in ["localhost", "127.0.0.1", "::1"]:
212 |                 return True
213 | 
214 |             # Literal IP hosts: private ranges and loopback count as local
215 |             if hostname:
216 |                 try:
217 |                     ip = ipaddress.ip_address(hostname)
218 |                     return ip.is_private or ip.is_loopback
219 |                 except ValueError:
220 |                     # Not an IP literal; hostnames other than "localhost" are treated as remote
221 |                     pass
222 | 
223 |             return False
224 |         except Exception:  # any parsing failure is treated as non-local
225 |             return False
226 | 
227 |     def _validate_base_url(self) -> None:
228 |         """Validate base URL for security (SSRF protection).
229 | 
230 |         Raises:
231 |             ValueError: If URL is invalid or potentially unsafe
232 |         """
233 |         if not self.base_url:
234 |             return
235 | 
236 |         try:
237 |             parsed = urlparse(self.base_url)
238 | 
239 |             # Check URL scheme - only allow http/https
240 |             if parsed.scheme not in ("http", "https"):
241 |                 raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
242 | 
243 |             # Check hostname exists
244 |             if not parsed.hostname:
245 |                 raise ValueError("URL must include a hostname")
246 | 
247 |             # Check port is valid (if specified)
248 |             port = parsed.port
249 |             if port is not None and (port < 1 or port > 65535):
250 |                 raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
251 |         except Exception as e:
252 |             if isinstance(e, ValueError):
253 |                 raise
254 |             raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
255 | 
256 |     @property
257 |     def client(self):
258 |         """Lazily build and cache the OpenAI client (proxy env vars suppressed, custom timeouts and headers applied)."""
259 |         if self._client is None:
260 |             import httpx
261 | 
262 |             proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]
263 |             # Presumably hides these vars during construction so httpx/OpenAI do not auto-configure a proxy
264 |             with suppress_env_vars(*proxy_env_vars):
265 |                 try:
266 |                     # Create a custom httpx client that explicitly avoids proxy parameters
267 |                     timeout_config = (
268 |                         self.timeout_config
269 |                         if hasattr(self, "timeout_config") and self.timeout_config
270 |                         else httpx.Timeout(30.0)
271 |                     )
272 | 
273 |                     # Create httpx client with minimal config to avoid proxy conflicts
274 |                     # Note: proxies parameter was removed in httpx 0.28.0
275 |                     # Check for test transport injection
276 |                     if hasattr(self, "_test_transport"):
277 |                         # Use custom transport for testing (HTTP recording/replay)
278 |                         http_client = httpx.Client(
279 |                             transport=self._test_transport,
280 |                             timeout=timeout_config,
281 |                             follow_redirects=True,
282 |                         )
283 |                     else:
284 |                         # Normal production client
285 |                         http_client = httpx.Client(
286 |                             timeout=timeout_config,
287 |                             follow_redirects=True,
288 |                         )
289 | 
290 |                     # Keep client initialization minimal to avoid proxy parameter conflicts
291 |                     client_kwargs = {
292 |                         "api_key": self.api_key,
293 |                         "http_client": http_client,
294 |                     }
295 | 
296 |                     if self.base_url:
297 |                         client_kwargs["base_url"] = self.base_url
298 | 
299 |                     if self.organization:
300 |                         client_kwargs["organization"] = self.organization
301 | 
302 |                     # Add default headers if any
303 |                     if self.DEFAULT_HEADERS:
304 |                         client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
305 | 
306 |                     logging.debug(
307 |                         "OpenAI client initialized with custom httpx client and timeout: %s",
308 |                         timeout_config,
309 |                     )
310 | 
311 |                     # Create OpenAI client with custom httpx client
312 |                     self._client = OpenAI(**client_kwargs)  # NOTE(review): if this raises, http_client above is never closed
313 | 
314 |                 except Exception as e:
315 |                     # Fallback client; NOTE(review): it drops organization, DEFAULT_HEADERS and timeout_config
316 |                     logging.warning(
317 |                         "Failed to create client with custom httpx, falling back to minimal config: %s",
318 |                         e,
319 |                     )
320 |                     try:
321 |                         minimal_kwargs = {"api_key": self.api_key}
322 |                         if self.base_url:
323 |                             minimal_kwargs["base_url"] = self.base_url
324 |                         self._client = OpenAI(**minimal_kwargs)
325 |                     except Exception as fallback_error:
326 |                         logging.error("Even minimal OpenAI client creation failed: %s", fallback_error)
327 |                         raise
328 | 
329 |         return self._client
330 | 
331 |     def _sanitize_for_logging(self, params: dict) -> dict:
332 |         """Sanitize sensitive data from parameters before logging.
333 | 
334 |         Args:
335 |             params: Dictionary of API parameters (not modified; a deep copy is sanitized)
336 | 
337 |         Returns:
338 |             dict: Copy with 'input' text items cut to 100 chars and api_key/authorization keys removed
339 |         """
340 |         sanitized = copy.deepcopy(params)
341 | 
342 |         # Truncate text items of Responses-API style "input" messages (a chat "messages" key is left untouched)
343 |         if "input" in sanitized:
344 |             for msg in sanitized.get("input", []):
345 |                 if isinstance(msg, dict) and "content" in msg:
346 |                     for content_item in msg.get("content", []):
347 |                         if isinstance(content_item, dict) and "text" in content_item:
348 |                             # Truncate long text and add ellipsis
349 |                             text = content_item["text"]
350 |                             if len(text) > 100:
351 |                                 content_item["text"] = text[:100] + "... [truncated]"
352 | 
353 |         # Remove any API keys that might be in headers/auth
354 |         sanitized.pop("api_key", None)
355 |         sanitized.pop("authorization", None)
356 | 
357 |         return sanitized
358 | 
359 |     def _safe_extract_output_text(self, response) -> str:
360 |         """Safely extract output_text from o3-pro response with validation.
361 | 
362 |         Args:
363 |             response: Response object from OpenAI SDK
364 | 
365 |         Returns:
366 |             str: The output text content
367 | 
368 |         Raises:
369 |             ValueError: If output_text is missing, None, or not a string
370 |         """
371 |         logging.debug(f"Response object type: {type(response)}")  # NOTE(review): f-string args are evaluated even with DEBUG off
372 |         logging.debug(f"Response attributes: {dir(response)}")  # NOTE(review): dir() runs on every call regardless of log level
373 | 
374 |         if not hasattr(response, "output_text"):
375 |             raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
376 | 
377 |         content = response.output_text
378 |         logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")  # formats the full text eagerly
379 | 
380 |         if content is None:
381 |             raise ValueError("o3-pro returned None for output_text")
382 | 
383 |         if not isinstance(content, str):
384 |             raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
385 | 
386 |         return content
387 | 
388 |     def _generate_with_responses_endpoint(
389 |         self,
390 |         model_name: str,
391 |         messages: list,
392 |         temperature: float,
393 |         max_output_tokens: Optional[int] = None,
394 |         capabilities: Optional[ModelCapabilities] = None,
395 |         **kwargs,
396 |     ) -> ModelResponse:
397 |         """Generate content using the /v1/responses endpoint for reasoning models."""
398 |         # Convert messages to the correct format for responses endpoint
399 |         input_messages = []
400 | 
401 |         for message in messages:
402 |             role = message.get("role", "")
403 |             content = message.get("content", "")
404 | 
405 |             if role == "system":
406 |                 # For o3-pro, system messages should be handled carefully to avoid policy violations
407 |                 # Instead of prefixing with "System:", we'll include the system content naturally
408 |                 input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
409 |             elif role == "user":
410 |                 input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
411 |             elif role == "assistant":
412 |                 input_messages.append({"role": "assistant", "content": [{"type": "output_text", "text": content}]})
413 | 
414 |         # Prepare completion parameters for responses endpoint
415 |         # Based on OpenAI documentation, use nested reasoning object for responses endpoint
416 |         effort = "medium"
417 |         if capabilities and capabilities.default_reasoning_effort:
418 |             effort = capabilities.default_reasoning_effort
419 | 
420 |         completion_params = {
421 |             "model": model_name,
422 |             "input": input_messages,
423 |             "reasoning": {"effort": effort},
424 |             "store": True,
425 |         }
426 | 
427 |         # Add max tokens if specified (using max_completion_tokens for responses endpoint)
428 |         if max_output_tokens:
429 |             completion_params["max_completion_tokens"] = max_output_tokens
430 | 
431 |         # For responses endpoint, we only add parameters that are explicitly supported
432 |         # Remove unsupported chat completion parameters that may cause API errors
433 | 
434 |         # Retry logic with progressive delays
435 |         max_retries = 4
436 |         retry_delays = [1, 3, 5, 8]
437 |         attempt_counter = {"value": 0}
438 | 
439 |         def _attempt() -> ModelResponse:
440 |             attempt_counter["value"] += 1
441 |             import json
442 | 
443 |             sanitized_params = self._sanitize_for_logging(completion_params)
444 |             logging.info(
445 |                 f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
446 |             )
447 | 
448 |             response = self.client.responses.create(**completion_params)
449 | 
450 |             content = self._safe_extract_output_text(response)
451 | 
452 |             usage = None
453 |             if hasattr(response, "usage"):
454 |                 usage = self._extract_usage(response)
455 |             elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
456 |                 input_tokens = getattr(response, "input_tokens", 0) or 0
457 |                 output_tokens = getattr(response, "output_tokens", 0) or 0
458 |                 usage = {
459 |                     "input_tokens": input_tokens,
460 |                     "output_tokens": output_tokens,
461 |                     "total_tokens": input_tokens + output_tokens,
462 |                 }
463 | 
464 |             return ModelResponse(
465 |                 content=content,
466 |                 usage=usage,
467 |                 model_name=model_name,
468 |                 friendly_name=self.FRIENDLY_NAME,
469 |                 provider=self.get_provider_type(),
470 |                 metadata={
471 |                     "model": getattr(response, "model", model_name),
472 |                     "id": getattr(response, "id", ""),
473 |                     "created": getattr(response, "created_at", 0),
474 |                     "endpoint": "responses",
475 |                 },
476 |             )
477 | 
478 |         try:
479 |             return self._run_with_retries(
480 |                 operation=_attempt,
481 |                 max_attempts=max_retries,
482 |                 delays=retry_delays,
483 |                 log_prefix="responses endpoint",
484 |             )
485 |         except Exception as exc:
486 |             attempts = max(attempt_counter["value"], 1)
487 |             error_msg = f"responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
488 |             logging.error(error_msg)
489 |             raise RuntimeError(error_msg) from exc
490 | 
491 |     def generate_content(
492 |         self,
493 |         prompt: str,
494 |         model_name: str,
495 |         system_prompt: Optional[str] = None,
496 |         temperature: float = 0.3,
497 |         max_output_tokens: Optional[int] = None,
498 |         images: Optional[list[str]] = None,
499 |         **kwargs,
500 |     ) -> ModelResponse:
501 |         """Generate content using the OpenAI-compatible API.
502 | 
503 |         Args:
504 |             prompt: User prompt to send to the model
505 |             model_name: Canonical model name or its alias
506 |             system_prompt: Optional system prompt for model behavior
507 |             temperature: Sampling temperature
508 |             max_output_tokens: Maximum tokens to generate
509 |             images: Optional list of image paths or data URLs to include with the prompt (for vision models)
510 |             **kwargs: Additional provider-specific parameters
511 | 
512 |         Returns:
513 |             ModelResponse with generated content and metadata
514 |         """
515 |         # Validate model name against allow-list
516 |         if not self.validate_model_name(model_name):
517 |             raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
518 | 
519 |         capabilities: Optional[ModelCapabilities]
520 |         try:
521 |             capabilities = self.get_capabilities(model_name)
522 |         except Exception as exc:
523 |             logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
524 |             capabilities = None
525 | 
526 |         # Get effective temperature for this model from capabilities when available
527 |         if capabilities:
528 |             effective_temperature = capabilities.get_effective_temperature(temperature)
529 |             if effective_temperature is not None and effective_temperature != temperature:
530 |                 logging.debug(
531 |                     f"Adjusting temperature from {temperature} to {effective_temperature} for model {model_name}"
532 |                 )
533 |         else:
534 |             effective_temperature = temperature
535 | 
536 |         # Only validate if temperature is not None (meaning the model supports it)
537 |         if effective_temperature is not None:
538 |             # Validate parameters with the effective temperature
539 |             self.validate_parameters(model_name, effective_temperature)
540 | 
541 |         # Resolve to canonical model name
542 |         resolved_model = self._resolve_model_name(model_name)
543 | 
544 |         # Prepare messages
545 |         messages = []
546 |         if system_prompt:
547 |             messages.append({"role": "system", "content": system_prompt})
548 | 
549 |         # Prepare user message with text and potentially images
550 |         user_content = []
551 |         user_content.append({"type": "text", "text": prompt})
552 | 
553 |         # Add images if provided and model supports vision
554 |         if images and capabilities and capabilities.supports_images:
555 |             for image_path in images:
556 |                 try:
557 |                     image_content = self._process_image(image_path)
558 |                     if image_content:
559 |                         user_content.append(image_content)
560 |                 except Exception as e:
561 |                     logging.warning(f"Failed to process image {image_path}: {e}")
562 |                     # Continue with other images and text
563 |                     continue
564 |         elif images and (not capabilities or not capabilities.supports_images):
565 |             logging.warning(f"Model {resolved_model} does not support images, ignoring {len(images)} image(s)")
566 | 
567 |         # Add user message
568 |         if len(user_content) == 1:
569 |             # Only text content, use simple string format for compatibility
570 |             messages.append({"role": "user", "content": prompt})
571 |         else:
572 |             # Text + images, use content array format
573 |             messages.append({"role": "user", "content": user_content})
574 | 
575 |         # Prepare completion parameters
576 |         # Always disable streaming for OpenRouter
577 |         # MCP doesn't use streaming, and this avoids issues with O3 model access
578 |         completion_params = {
579 |             "model": resolved_model,
580 |             "messages": messages,
581 |             "stream": False,
582 |         }
583 | 
584 |         # Use the effective temperature we calculated earlier
585 |         supports_sampling = effective_temperature is not None
586 | 
587 |         if supports_sampling:
588 |             completion_params["temperature"] = effective_temperature
589 | 
590 |         # Add max tokens if specified and model supports it
591 |         # O3/O4 models that don't support temperature also don't support max_tokens
592 |         if max_output_tokens and supports_sampling:
593 |             completion_params["max_tokens"] = max_output_tokens
594 | 
595 |         # Add any additional OpenAI-specific parameters
596 |         # Use capabilities to filter parameters for reasoning models
597 |         for key, value in kwargs.items():
598 |             if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
599 |                 # Reasoning models (those that don't support temperature) also don't support these parameters
600 |                 if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty", "stream"]:
601 |                     continue  # Skip unsupported parameters for reasoning models
602 |                 completion_params[key] = value
603 | 
604 |         # Check if this model needs the Responses API endpoint
605 |         # Prefer capability metadata; fall back to static map when capabilities unavailable
606 |         use_responses_api = False
607 |         if capabilities is not None:
608 |             use_responses_api = getattr(capabilities, "use_openai_response_api", False)
609 |         else:
610 |             static_capabilities = self.get_all_model_capabilities().get(resolved_model)
611 |             if static_capabilities is not None:
612 |                 use_responses_api = getattr(static_capabilities, "use_openai_response_api", False)
613 | 
614 |         if use_responses_api:
615 |             # These models require the /v1/responses endpoint for stateful context
616 |             # If it fails, we should not fall back to chat/completions
617 |             return self._generate_with_responses_endpoint(
618 |                 model_name=resolved_model,
619 |                 messages=messages,
620 |                 temperature=temperature,
621 |                 max_output_tokens=max_output_tokens,
622 |                 capabilities=capabilities,
623 |                 **kwargs,
624 |             )
625 | 
626 |         # Retry logic with progressive delays
627 |         max_retries = 4  # Total of 4 attempts
628 |         retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
629 |         attempt_counter = {"value": 0}
630 | 
631 |         def _attempt() -> ModelResponse:
632 |             attempt_counter["value"] += 1
633 |             response = self.client.chat.completions.create(**completion_params)
634 | 
635 |             content = response.choices[0].message.content
636 |             usage = self._extract_usage(response)
637 | 
638 |             return ModelResponse(
639 |                 content=content,
640 |                 usage=usage,
641 |                 model_name=resolved_model,
642 |                 friendly_name=self.FRIENDLY_NAME,
643 |                 provider=self.get_provider_type(),
644 |                 metadata={
645 |                     "finish_reason": response.choices[0].finish_reason,
646 |                     "model": response.model,
647 |                     "id": response.id,
648 |                     "created": response.created,
649 |                 },
650 |             )
651 | 
652 |         try:
653 |             return self._run_with_retries(
654 |                 operation=_attempt,
655 |                 max_attempts=max_retries,
656 |                 delays=retry_delays,
657 |                 log_prefix=f"{self.FRIENDLY_NAME} API ({resolved_model})",
658 |             )
659 |         except Exception as exc:
660 |             attempts = max(attempt_counter["value"], 1)
661 |             error_msg = (
662 |                 f"{self.FRIENDLY_NAME} API error for model {resolved_model} after {attempts} attempt"
663 |                 f"{'s' if attempts > 1 else ''}: {exc}"
664 |             )
665 |             logging.error(error_msg)
666 |             raise RuntimeError(error_msg) from exc
667 | 
668 |     def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
669 |         """Validate model parameters.
670 | 
671 |         For proxy providers, this may use generic capabilities.
672 | 
673 |         Args:
674 |             model_name: Canonical model name or its alias
675 |             temperature: Temperature to validate
676 |             **kwargs: Additional parameters to validate
677 |         """
678 |         try:
679 |             capabilities = self.get_capabilities(model_name)
680 | 
681 |             # Check if we're using generic capabilities
682 |             if hasattr(capabilities, "_is_generic"):
683 |                 logging.debug(
684 |                     f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
685 |                 )
686 | 
687 |             # Validate temperature using parent class method
688 |             super().validate_parameters(model_name, temperature, **kwargs)
689 | 
690 |         except Exception as e:
691 |             # For proxy providers, we might not have accurate capabilities
692 |             # Log warning but don't fail
693 |             logging.warning(f"Parameter validation limited for {model_name}: {e}")
694 | 
695 |     def _extract_usage(self, response) -> dict[str, int]:
696 |         """Extract token usage from OpenAI response.
697 | 
698 |         Args:
699 |             response: OpenAI API response object
700 | 
701 |         Returns:
702 |             Dictionary with usage statistics
703 |         """
704 |         usage = {}
705 | 
706 |         if hasattr(response, "usage") and response.usage:
707 |             # Safely extract token counts with None handling
708 |             usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
709 |             usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
710 |             usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0
711 | 
712 |         return usage
713 | 
714 |     def count_tokens(self, text: str, model_name: str) -> int:
715 |         """Count tokens using OpenAI-compatible tokenizer tables when available."""
716 | 
717 |         resolved_model = self._resolve_model_name(model_name)
718 | 
719 |         try:
720 |             import tiktoken
721 | 
722 |             try:
723 |                 encoding = tiktoken.encoding_for_model(resolved_model)
724 |             except KeyError:
725 |                 encoding = tiktoken.get_encoding("cl100k_base")
726 | 
727 |             return len(encoding.encode(text))
728 | 
729 |         except (ImportError, Exception) as exc:
730 |             logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)
731 | 
732 |         return super().count_tokens(text, model_name)
733 | 
734 |     def _is_error_retryable(self, error: Exception) -> bool:
735 |         """Determine if an error should be retried based on structured error codes.
736 | 
737 |         Uses OpenAI API error structure instead of text pattern matching for reliability.
738 | 
739 |         Args:
740 |             error: Exception from OpenAI API call
741 | 
742 |         Returns:
743 |             True if error should be retried, False otherwise
744 |         """
745 |         error_str = str(error).lower()
746 | 
747 |         # Check for 429 errors first - these need special handling
748 |         if "429" in error_str:
749 |             # Try to extract structured error information
750 |             error_type = None
751 |             error_code = None
752 | 
753 |             # Parse structured error from OpenAI API response
754 |             # Format: "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}"
755 |             try:
756 |                 import ast
757 |                 import json
758 |                 import re
759 | 
760 |                 # Extract JSON part from error string using regex
761 |                 # Look for pattern: {...} (from first { to last })
762 |                 json_match = re.search(r"\{.*\}", str(error))
763 |                 if json_match:
764 |                     json_like_str = json_match.group(0)
765 | 
766 |                     # First try: parse as Python literal (handles single quotes safely)
767 |                     try:
768 |                         error_data = ast.literal_eval(json_like_str)
769 |                     except (ValueError, SyntaxError):
770 |                         # Fallback: try JSON parsing with simple quote replacement
771 |                         # (for cases where it's already valid JSON or simple replacements work)
772 |                         json_str = json_like_str.replace("'", '"')
773 |                         error_data = json.loads(json_str)
774 | 
775 |                     if "error" in error_data:
776 |                         error_info = error_data["error"]
777 |                         error_type = error_info.get("type")
778 |                         error_code = error_info.get("code")
779 | 
780 |             except (json.JSONDecodeError, ValueError, SyntaxError, AttributeError):
781 |                 # Fall back to checking hasattr for OpenAI SDK exception objects
782 |                 if hasattr(error, "response") and hasattr(error.response, "json"):
783 |                     try:
784 |                         response_data = error.response.json()
785 |                         if "error" in response_data:
786 |                             error_info = response_data["error"]
787 |                             error_type = error_info.get("type")
788 |                             error_code = error_info.get("code")
789 |                     except Exception:
790 |                         pass
791 | 
792 |             # Determine if 429 is retryable based on structured error codes
793 |             if error_type == "tokens":
794 |                 # Token-related 429s are typically non-retryable (request too large)
795 |                 logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})")
796 |                 return False
797 |             elif error_code in ["invalid_request_error", "context_length_exceeded"]:
798 |                 # These are permanent failures
799 |                 logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})")
800 |                 return False
801 |             else:
802 |                 # Other 429s (like requests per minute) are retryable
803 |                 logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})")
804 |                 return True
805 | 
806 |         # For non-429 errors, check if they're retryable
807 |         retryable_indicators = [
808 |             "timeout",
809 |             "connection",
810 |             "network",
811 |             "temporary",
812 |             "unavailable",
813 |             "retry",
814 |             "408",  # Request timeout
815 |             "500",  # Internal server error
816 |             "502",  # Bad gateway
817 |             "503",  # Service unavailable
818 |             "504",  # Gateway timeout
819 |             "ssl",  # SSL errors
820 |             "handshake",  # Handshake failures
821 |         ]
822 | 
823 |         return any(indicator in error_str for indicator in retryable_indicators)
824 | 
825 |     def _process_image(self, image_path: str) -> Optional[dict]:
826 |         """Process an image for OpenAI-compatible API."""
827 |         try:
828 |             if image_path.startswith("data:"):
829 |                 # Validate the data URL
830 |                 validate_image(image_path)
831 |                 # Handle data URL: ...
832 |                 return {"type": "image_url", "image_url": {"url": image_path}}
833 |             else:
834 |                 # Use base class validation
835 |                 image_bytes, mime_type = validate_image(image_path)
836 | 
837 |                 # Read and encode the image
838 |                 import base64
839 | 
840 |                 image_data = base64.b64encode(image_bytes).decode()
841 |                 logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")
842 | 
843 |                 # Create data URL for OpenAI API
844 |                 data_url = f"data:{mime_type};base64,{image_data}"
845 | 
846 |                 return {"type": "image_url", "image_url": {"url": data_url}}
847 | 
848 |         except ValueError as e:
849 |             logging.warning(str(e))
850 |             return None
851 |         except Exception as e:
852 |             logging.error(f"Error processing image {image_path}: {e}")
853 |             return None
854 | 
```

--------------------------------------------------------------------------------
/tools/tracer.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Tracer Workflow tool - Step-by-step code tracing and dependency analysis
  3 | 
  4 | This tool provides a structured workflow for comprehensive code tracing and analysis.
  5 | It guides the CLI agent through systematic investigation steps with forced pauses between each step
  6 | to ensure thorough code examination, dependency mapping, and execution flow analysis before proceeding.
  7 | 
  8 | The tracer guides users through sequential code analysis with full context awareness and
  9 | the ability to revise and adapt as understanding deepens.
 10 | 
 11 | Key features:
 12 | - Sequential tracing with systematic investigation workflow
 13 | - Support for precision tracing (execution flow) and dependencies tracing (structural relationships)
 14 | - Self-contained completion with detailed output formatting instructions
 15 | - Context-aware analysis that builds understanding step by step
 16 | - No external expert analysis needed - provides comprehensive guidance internally
 17 | 
 18 | Perfect for: method/function execution flow analysis, dependency mapping, call chain tracing,
 19 | structural relationship analysis, architectural understanding, and code comprehension.
 20 | """
 21 | 
 22 | import logging
 23 | from typing import TYPE_CHECKING, Any, Literal, Optional
 24 | 
 25 | from pydantic import Field, field_validator
 26 | 
 27 | if TYPE_CHECKING:
 28 |     from tools.models import ToolModelCategory
 29 | 
 30 | from config import TEMPERATURE_ANALYTICAL
 31 | from systemprompts import TRACER_PROMPT
 32 | from tools.shared.base_models import WorkflowRequest
 33 | 
 34 | from .workflow.base import WorkflowTool
 35 | 
 36 | logger = logging.getLogger(__name__)
 37 | 
 38 | # Tool-specific field descriptions for tracer workflow
 39 | TRACER_WORKFLOW_FIELD_DESCRIPTIONS = {
 40 |     "step": (
 41 |         "The plan for the current tracing step. Step 1: State the tracing strategy. Later steps: Report findings and adapt the plan. "
 42 |         "CRITICAL: For 'precision' mode, focus on execution flow and call chains. For 'dependencies' mode, focus on structural relationships. "
 43 |         "If trace_mode is 'ask' in step 1, you MUST prompt the user to choose a mode."
 44 |     ),
 45 |     "step_number": (
 46 |         "The index of the current step in the tracing sequence, beginning at 1. Each step should build upon or "
 47 |         "revise the previous one."
 48 |     ),
 49 |     "total_steps": (
 50 |         "Your current estimate for how many steps will be needed to complete the tracing analysis. "
 51 |         "Adjust as new findings emerge."
 52 |     ),
 53 |     "next_step_required": (
 54 |         "Set to true if you plan to continue the investigation with another step. False means you believe the "
 55 |         "tracing analysis is complete and ready for final output formatting."
 56 |     ),
 57 |     "findings": (
 58 |         "Summary of discoveries from this step, including execution paths, dependency relationships, call chains, and structural patterns. "
 59 |         "IMPORTANT: Document both direct (immediate calls) and indirect (transitive, side effects) relationships."
 60 |     ),
 61 |     "files_checked": (
 62 |         "List all files examined (absolute paths). Include even ruled-out files to track exploration path."
 63 |     ),
 64 |     "relevant_files": (
 65 |         "Subset of files_checked directly relevant to the tracing target (absolute paths). Include implementation files, "
 66 |         "dependencies, or files demonstrating key relationships."
 67 |     ),
 68 |     "relevant_context": (
 69 |         "List methods/functions central to the tracing analysis, in 'ClassName.methodName' or 'functionName' format. "
 70 |         "Prioritize those in the execution flow or dependency chain."
 71 |     ),
 72 |     "confidence": (
 73 |         "Your confidence in the tracing analysis. Use: 'exploring', 'low', 'medium', 'high', 'very_high', 'almost_certain', 'certain'. "
 74 |         "CRITICAL: 'certain' implies the analysis is 100% complete locally and PREVENTS external model validation."
 75 |     ),
 76 |     "trace_mode": "Type of tracing: 'ask' (default - prompts user to choose mode), 'precision' (execution flow) or 'dependencies' (structural relationships)",
 77 |     "target_description": (
 78 |         "Description of what to trace and WHY. Include context about what you're trying to understand or analyze."
 79 |     ),
 80 |     "images": ("Optional paths to architecture diagrams or flow charts that help understand the tracing context."),
 81 | }
 82 | 
 83 | 
 84 | class TracerRequest(WorkflowRequest):
 85 |     """Request model for tracer workflow investigation steps"""
 86 | 
 87 |     # Required fields for each investigation step
 88 |     step: str = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["step"])
 89 |     step_number: int = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
 90 |     total_steps: int = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
 91 |     next_step_required: bool = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])
 92 | 
 93 |     # Investigation tracking fields
 94 |     findings: str = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["findings"])
 95 |     files_checked: list[str] = Field(
 96 |         default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["files_checked"]
 97 |     )
 98 |     relevant_files: list[str] = Field(
 99 |         default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"]
100 |     )
101 |     relevant_context: list[str] = Field(
102 |         default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"]
103 |     )
104 |     confidence: Optional[str] = Field("exploring", description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["confidence"])
105 | 
106 |     # Tracer-specific fields (used in step 1 to initialize)
107 |     trace_mode: Optional[Literal["precision", "dependencies", "ask"]] = Field(
108 |         "ask", description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["trace_mode"]
109 |     )
110 |     target_description: Optional[str] = Field(
111 |         None, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["target_description"]
112 |     )
113 |     images: Optional[list[str]] = Field(default=None, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["images"])
114 | 
115 |     # Exclude fields not relevant to tracing workflow
116 |     issues_found: list[dict] = Field(default_factory=list, exclude=True, description="Tracing doesn't track issues")
117 |     hypothesis: Optional[str] = Field(default=None, exclude=True, description="Tracing doesn't use hypothesis")
118 |     # Exclude other non-tracing fields
119 |     temperature: Optional[float] = Field(default=None, exclude=True)
120 |     thinking_mode: Optional[str] = Field(default=None, exclude=True)
121 |     use_assistant_model: Optional[bool] = Field(default=False, exclude=True, description="Tracing is self-contained")
122 | 
123 |     @field_validator("step_number")
124 |     @classmethod
125 |     def validate_step_number(cls, v):
126 |         if v < 1:
127 |             raise ValueError("step_number must be at least 1")
128 |         return v
129 | 
130 |     @field_validator("total_steps")
131 |     @classmethod
132 |     def validate_total_steps(cls, v):
133 |         if v < 1:
134 |             raise ValueError("total_steps must be at least 1")
135 |         return v
136 | 
137 | 
138 | class TracerTool(WorkflowTool):
139 |     """
140 |     Tracer workflow tool for step-by-step code tracing and dependency analysis.
141 | 
142 |     This tool implements a structured tracing workflow that guides users through
143 |     methodical investigation steps, ensuring thorough code examination, dependency
144 |     mapping, and execution flow analysis before reaching conclusions. It supports
145 |     both precision tracing (execution flow) and dependencies tracing (structural relationships).
146 |     """
147 | 
148 |     def __init__(self):
149 |         super().__init__()
150 |         self.initial_request = None
151 |         self.trace_config = {}
152 | 
153 |     def get_name(self) -> str:
154 |         return "tracer"
155 | 
156 |     def get_description(self) -> str:
157 |         return (
158 |             "Performs systematic code tracing with modes for execution flow or dependency mapping. "
159 |             "Use for method execution analysis, call chain tracing, dependency mapping, and architectural understanding. "
160 |             "Supports precision mode (execution flow) and dependencies mode (structural relationships)."
161 |         )
162 | 
163 |     def get_system_prompt(self) -> str:
164 |         return TRACER_PROMPT
165 | 
166 |     def get_default_temperature(self) -> float:
167 |         return TEMPERATURE_ANALYTICAL
168 | 
169 |     def get_model_category(self) -> "ToolModelCategory":
170 |         """Tracer requires analytical reasoning for code analysis"""
171 |         from tools.models import ToolModelCategory
172 | 
173 |         return ToolModelCategory.EXTENDED_REASONING
174 | 
175 |     def requires_model(self) -> bool:
176 |         """
177 |         Tracer tool doesn't require model resolution at the MCP boundary.
178 | 
179 |         The tracer is a structured workflow tool that organizes tracing steps
180 |         and provides detailed output formatting guidance without calling external AI models.
181 | 
182 |         Returns:
183 |             bool: False - tracer doesn't need AI model access
184 |         """
185 |         return False
186 | 
187 |     def get_workflow_request_model(self):
188 |         """Return the tracer-specific request model."""
189 |         return TracerRequest
190 | 
191 |     def get_tool_fields(self) -> dict[str, dict[str, Any]]:
192 |         """Return tracing-specific field definitions beyond the standard workflow fields."""
193 |         return {
194 |             # Tracer-specific fields
195 |             "trace_mode": {
196 |                 "type": "string",
197 |                 "enum": ["precision", "dependencies", "ask"],
198 |                 "description": TRACER_WORKFLOW_FIELD_DESCRIPTIONS["trace_mode"],
199 |             },
200 |             "target_description": {
201 |                 "type": "string",
202 |                 "description": TRACER_WORKFLOW_FIELD_DESCRIPTIONS["target_description"],
203 |             },
204 |             "images": {
205 |                 "type": "array",
206 |                 "items": {"type": "string"},
207 |                 "description": TRACER_WORKFLOW_FIELD_DESCRIPTIONS["images"],
208 |             },
209 |         }
210 | 
211 |     def get_input_schema(self) -> dict[str, Any]:
212 |         """Generate input schema using WorkflowSchemaBuilder with field exclusion."""
213 |         from .workflow.schema_builders import WorkflowSchemaBuilder
214 | 
215 |         # Exclude investigation-specific fields that tracing doesn't need
216 |         excluded_workflow_fields = [
217 |             "issues_found",  # Tracing doesn't track issues
218 |             "hypothesis",  # Tracing doesn't use hypothesis
219 |         ]
220 | 
221 |         # Exclude common fields that tracing doesn't need
222 |         excluded_common_fields = [
223 |             "temperature",  # Tracing doesn't need temperature control
224 |             "thinking_mode",  # Tracing doesn't need thinking mode
225 |             "absolute_file_paths",  # Tracing uses relevant_files instead
226 |         ]
227 | 
228 |         return WorkflowSchemaBuilder.build_schema(
229 |             tool_specific_fields=self.get_tool_fields(),
230 |             required_fields=["target_description", "trace_mode"],  # Step 1 requires these
231 |             model_field_schema=self.get_model_field_schema(),
232 |             auto_mode=self.is_effective_auto_mode(),
233 |             tool_name=self.get_name(),
234 |             excluded_workflow_fields=excluded_workflow_fields,
235 |             excluded_common_fields=excluded_common_fields,
236 |         )
237 | 
238 |     # ================================================================================
239 |     # Abstract Methods - Required Implementation from BaseWorkflowMixin
240 |     # ================================================================================
241 | 
242 |     def get_required_actions(
243 |         self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
244 |     ) -> list[str]:
245 |         """Define required actions for each tracing phase."""
246 |         if step_number == 1:
247 |             # Check if we're in ask mode and need to prompt for mode selection
248 |             if self.get_trace_mode() == "ask":
249 |                 return [
250 |                     "MUST ask user to choose between precision or dependencies mode",
251 |                     "Explain precision mode: traces execution flow, call chains, and usage patterns (best for methods/functions)",
252 |                     "Explain dependencies mode: maps structural relationships and bidirectional dependencies (best for classes/modules)",
253 |                     "Wait for user's mode selection before proceeding with investigation",
254 |                 ]
255 | 
256 |             # Initial tracing investigation tasks (when mode is already selected)
257 |             return [
258 |                 "Search for and locate the target method/function/class/module in the codebase",
259 |                 "Read and understand the implementation of the target code",
260 |                 "Identify the file location, complete signature, and basic structure",
261 |                 "Begin mapping immediate relationships (what it calls, what calls it)",
262 |                 "Understand the context and purpose of the target code",
263 |             ]
264 |         elif confidence in ["exploring", "low"]:
265 |             # Need deeper investigation
266 |             return [
267 |                 "Trace deeper into the execution flow or dependency relationships",
268 |                 "Examine how the target code is used throughout the codebase",
269 |                 "Map additional layers of dependencies or call chains",
270 |                 "Look for conditional execution paths, error handling, and edge cases",
271 |                 "Understand the broader architectural context and patterns",
272 |             ]
273 |         elif confidence in ["medium", "high"]:
274 |             # Close to completion - need final verification
275 |             return [
276 |                 "Verify completeness of the traced relationships and execution paths",
277 |                 "Check for any missed dependencies, usage patterns, or execution branches",
278 |                 "Confirm understanding of side effects, state changes, and external interactions",
279 |                 "Validate that the tracing covers all significant code relationships",
280 |                 "Prepare comprehensive findings for final output formatting",
281 |             ]
282 |         else:
283 |             # General investigation needed
284 |             return [
285 |                 "Continue systematic tracing of code relationships and execution paths",
286 |                 "Gather more evidence using appropriate code analysis techniques",
287 |                 "Test assumptions about code behavior and dependency relationships",
288 |                 "Look for patterns that enhance understanding of the code structure",
289 |                 "Focus on areas that haven't been thoroughly traced yet",
290 |             ]
291 | 
292 |     def should_call_expert_analysis(self, consolidated_findings, request=None) -> bool:
293 |         """Tracer is self-contained and doesn't need expert analysis."""
294 |         return False
295 | 
296 |     def prepare_expert_analysis_context(self, consolidated_findings) -> str:
297 |         """Tracer doesn't use expert analysis."""
298 |         return ""
299 | 
300 |     def requires_expert_analysis(self) -> bool:
301 |         """Tracer is self-contained like the planner tool."""
302 |         return False
303 | 
304 |     # ================================================================================
305 |     # Workflow Customization - Match Planner Behavior
306 |     # ================================================================================
307 | 
308 |     def prepare_step_data(self, request) -> dict:
309 |         """
310 |         Prepare step data from request with tracer-specific fields.
311 |         """
312 |         step_data = {
313 |             "step": request.step,
314 |             "step_number": request.step_number,
315 |             "findings": request.findings,
316 |             "files_checked": request.files_checked,
317 |             "relevant_files": request.relevant_files,
318 |             "relevant_context": request.relevant_context,
319 |             "issues_found": [],  # Tracer doesn't track issues
320 |             "confidence": request.confidence or "exploring",
321 |             "hypothesis": None,  # Tracer doesn't use hypothesis
322 |             "images": request.images or [],
323 |             # Tracer-specific fields
324 |             "trace_mode": request.trace_mode,
325 |             "target_description": request.target_description,
326 |         }
327 |         return step_data
328 | 
329 |     def build_base_response(self, request, continuation_id: str = None) -> dict:
330 |         """
331 |         Build the base response structure with tracer-specific fields.
332 |         """
333 |         # Use work_history from workflow mixin for consistent step tracking
334 |         current_step_count = len(self.work_history) + 1
335 | 
336 |         response_data = {
337 |             "status": f"{self.get_name()}_in_progress",
338 |             "step_number": request.step_number,
339 |             "total_steps": request.total_steps,
340 |             "next_step_required": request.next_step_required,
341 |             "step_content": request.step,
342 |             f"{self.get_name()}_status": {
343 |                 "files_checked": len(self.consolidated_findings.files_checked),
344 |                 "relevant_files": len(self.consolidated_findings.relevant_files),
345 |                 "relevant_context": len(self.consolidated_findings.relevant_context),
346 |                 "issues_found": len(self.consolidated_findings.issues_found),
347 |                 "images_collected": len(self.consolidated_findings.images),
348 |                 "current_confidence": self.get_request_confidence(request),
349 |                 "step_history_length": current_step_count,
350 |             },
351 |             "metadata": {
352 |                 "trace_mode": self.trace_config.get("trace_mode", "unknown"),
353 |                 "target_description": self.trace_config.get("target_description", ""),
354 |                 "step_history_length": current_step_count,
355 |             },
356 |         }
357 | 
358 |         if continuation_id:
359 |             response_data["continuation_id"] = continuation_id
360 | 
361 |         return response_data
362 | 
363 |     def handle_work_continuation(self, response_data: dict, request) -> dict:
364 |         """
365 |         Handle work continuation with tracer-specific guidance.
366 |         """
367 |         response_data["status"] = f"pause_for_{self.get_name()}"
368 |         response_data[f"{self.get_name()}_required"] = True
369 | 
370 |         # Get tracer-specific required actions
371 |         required_actions = self.get_required_actions(
372 |             request.step_number, request.confidence or "exploring", request.findings, request.total_steps
373 |         )
374 |         response_data["required_actions"] = required_actions
375 | 
376 |         # Generate step-specific guidance
377 |         if request.step_number == 1:
378 |             # Check if we're in ask mode and need to prompt for mode selection
379 |             if self.get_trace_mode() == "ask":
380 |                 response_data["next_steps"] = (
381 |                     f"STOP! You MUST ask the user to choose a tracing mode before proceeding. "
382 |                     f"Present these options clearly:\\n\\n"
383 |                     f"**PRECISION MODE**: Traces execution flow, call chains, and usage patterns. "
384 |                     f"Best for understanding how a specific method or function works, what it calls, "
385 |                     f"and how data flows through the execution path.\\n\\n"
386 |                     f"**DEPENDENCIES MODE**: Maps structural relationships and bidirectional dependencies. "
387 |                     f"Best for understanding how a class or module relates to other components, "
388 |                     f"what depends on it, and what it depends on.\\n\\n"
389 |                     f"After the user selects a mode, call {self.get_name()} again with step_number: 1 "
390 |                     f"but with the chosen trace_mode (either 'precision' or 'dependencies')."
391 |                 )
392 |             else:
393 |                 response_data["next_steps"] = (
394 |                     f"MANDATORY: DO NOT call the {self.get_name()} tool again immediately. You MUST first investigate "
395 |                     f"the codebase to understand the target code. CRITICAL AWARENESS: You need to find and understand "
396 |                     f"the target method/function/class/module, examine its implementation, and begin mapping its "
397 |                     f"relationships. Use file reading tools, code search, and systematic examination to gather "
398 |                     f"comprehensive information about the target. Only call {self.get_name()} again AFTER completing "
399 |                     f"your investigation. When you call {self.get_name()} next time, use step_number: {request.step_number + 1} "
400 |                     f"and report specific files examined, code structure discovered, and initial relationship findings."
401 |                 )
402 |         elif request.confidence in ["exploring", "low"]:
403 |             next_step = request.step_number + 1
404 |             response_data["next_steps"] = (
405 |                 f"STOP! Do NOT call {self.get_name()} again yet. Based on your findings, you've identified areas that need "
406 |                 f"deeper tracing analysis. MANDATORY ACTIONS before calling {self.get_name()} step {next_step}:\\n"
407 |                 + "\\n".join(f"{i+1}. {action}" for i, action in enumerate(required_actions))
408 |                 + f"\\n\\nOnly call {self.get_name()} again with step_number: {next_step} AFTER "
409 |                 + "completing these tracing investigations."
410 |             )
411 |         elif request.confidence in ["medium", "high"]:
412 |             next_step = request.step_number + 1
413 |             response_data["next_steps"] = (
414 |                 f"WAIT! Your tracing analysis needs final verification. DO NOT call {self.get_name()} immediately. "
415 |                 f"REQUIRED ACTIONS:\\n"
416 |                 + "\\n".join(f"{i+1}. {action}" for i, action in enumerate(required_actions))
417 |                 + f"\\n\\nREMEMBER: Ensure you have traced all significant relationships and execution paths. "
418 |                 f"Document findings with specific file references and method signatures, then call {self.get_name()} "
419 |                 f"with step_number: {next_step}."
420 |             )
421 |         else:
422 |             # General investigation needed
423 |             next_step = request.step_number + 1
424 |             remaining_steps = request.total_steps - request.step_number
425 |             response_data["next_steps"] = (
426 |                 f"Continue systematic tracing with step {next_step}. Approximately {remaining_steps} steps remaining. "
427 |                 f"Focus on deepening your understanding of the code relationships and execution patterns."
428 |             )
429 | 
430 |         return response_data
431 | 
432 |     def customize_workflow_response(self, response_data: dict, request) -> dict:
433 |         """
434 |         Customize response to match tracer tool format with output instructions.
435 |         """
436 |         # Store trace configuration on first step
437 |         if request.step_number == 1:
438 |             self.initial_request = request.step
439 |             self.trace_config = {
440 |                 "trace_mode": request.trace_mode,
441 |                 "target_description": request.target_description,
442 |             }
443 | 
444 |             # Update metadata with trace configuration
445 |             if "metadata" in response_data:
446 |                 response_data["metadata"]["trace_mode"] = request.trace_mode or "unknown"
447 |                 response_data["metadata"]["target_description"] = request.target_description or ""
448 | 
449 |             # If in ask mode, mark this as mode selection phase
450 |             if request.trace_mode == "ask":
451 |                 response_data["mode_selection_required"] = True
452 |                 response_data["status"] = "mode_selection_required"
453 | 
454 |         # Add tracer-specific output instructions for final steps
455 |         if not request.next_step_required:
456 |             response_data["tracing_complete"] = True
457 |             response_data["trace_summary"] = f"TRACING COMPLETE: {request.step}"
458 | 
459 |             # Get mode-specific output instructions
460 |             trace_mode = self.trace_config.get("trace_mode", "precision")
461 |             rendering_instructions = self._get_rendering_instructions(trace_mode)
462 | 
463 |             response_data["output"] = {
464 |                 "instructions": (
465 |                     "This is a structured tracing analysis response. Present the comprehensive tracing findings "
466 |                     "using the specific rendering format for the trace mode. Follow the exact formatting guidelines "
467 |                     "provided in rendering_instructions. Include all discovered relationships, execution paths, "
468 |                     "and dependencies with precise file references and line numbers."
469 |                 ),
470 |                 "format": f"{trace_mode}_trace_analysis",
471 |                 "rendering_instructions": rendering_instructions,
472 |                 "presentation_guidelines": {
473 |                     "completed_trace": (
474 |                         "Use the exact rendering format specified for the trace mode. Include comprehensive "
475 |                         "diagrams, tables, and structured analysis. Reference specific file paths and line numbers. "
476 |                         "Follow formatting rules precisely."
477 |                     ),
478 |                     "step_content": "Present as main analysis with clear structure and actionable insights.",
479 |                     "continuation": "Use continuation_id for related tracing sessions or follow-up analysis",
480 |                 },
481 |             }
482 |             response_data["next_steps"] = (
483 |                 f"Tracing analysis complete. Present the comprehensive {trace_mode} trace analysis to the user "
484 |                 f"using the exact rendering format specified in the output instructions. Follow the formatting "
485 |                 f"guidelines precisely, including diagrams, tables, and file references. After presenting the "
486 |                 f"analysis, offer to help with related tracing tasks or use the continuation_id for follow-up analysis."
487 |             )
488 | 
489 |         # Convert generic status names to tracer-specific ones
490 |         tool_name = self.get_name()
491 |         status_mapping = {
492 |             f"{tool_name}_in_progress": "tracing_in_progress",
493 |             f"pause_for_{tool_name}": "pause_for_tracing",
494 |             f"{tool_name}_required": "tracing_required",
495 |             f"{tool_name}_complete": "tracing_complete",
496 |         }
497 | 
498 |         if response_data["status"] in status_mapping:
499 |             response_data["status"] = status_mapping[response_data["status"]]
500 | 
501 |         return response_data
502 | 
503 |     def _get_rendering_instructions(self, trace_mode: str) -> str:
504 |         """
505 |         Get mode-specific rendering instructions for the CLI agent.
506 | 
507 |         Args:
508 |             trace_mode: Either "precision" or "dependencies"
509 | 
510 |         Returns:
511 |             str: Complete rendering instructions for the specified mode
512 |         """
513 |         if trace_mode == "precision":
514 |             return self._get_precision_rendering_instructions()
515 |         else:  # dependencies mode
516 |             return self._get_dependencies_rendering_instructions()
517 | 
518 |     def _get_precision_rendering_instructions(self) -> str:
519 |         """Get rendering instructions for precision trace mode."""
520 |         return """
521 | ## MANDATORY RENDERING INSTRUCTIONS FOR PRECISION TRACE
522 | 
523 | You MUST render the trace analysis using ONLY the Vertical Indented Flow Style:
524 | 
525 | ### CALL FLOW DIAGRAM - Vertical Indented Style
526 | 
527 | **EXACT FORMAT TO FOLLOW:**
528 | ```
529 | [ClassName::MethodName] (file: /complete/file/path.ext, line: ##)
530 | ↓
531 | [AnotherClass::calledMethod] (file: /path/to/file.ext, line: ##)
532 | ↓
533 | [ThirdClass::nestedMethod] (file: /path/file.ext, line: ##)
534 |   ↓
535 |   [DeeperClass::innerCall] (file: /path/inner.ext, line: ##) ? if some_condition
536 |   ↓
537 |   [ServiceClass::processData] (file: /services/service.ext, line: ##)
538 |     ↓
539 |     [RepositoryClass::saveData] (file: /data/repo.ext, line: ##)
540 |     ↓
541 |     [ClientClass::sendRequest] (file: /clients/client.ext, line: ##)
542 |       ↓
543 |       [EmailService::sendEmail] (file: /email/service.ext, line: ##) ⚠️ ambiguous branch
544 |       →
545 |       [SMSService::sendSMS] (file: /sms/service.ext, line: ##) ⚠️ ambiguous branch
546 | ```
547 | 
548 | **CRITICAL FORMATTING RULES:**
549 | 
550 | 1. **Method Names**: Use the actual naming convention of the project language you're analyzing. Automatically detect and adapt to the project's conventions (camelCase, snake_case, PascalCase, etc.) based on the codebase structure and file extensions.
551 | 
552 | 2. **Vertical Flow Arrows**:
553 |    - Use `↓` for standard sequential calls (vertical flow)
554 |    - Use `→` for parallel/alternative calls (horizontal branch)
555 |    - NEVER use other arrow types
556 | 
557 | 3. **Indentation Logic**:
558 |    - Start at column 0 for entry point
559 |    - Indent 2 spaces for each nesting level
560 |    - Maintain consistent indentation for same call depth
561 |    - Sibling calls at same level should have same indentation
562 | 
563 | 4. **Conditional Calls**:
564 |    - Add `? if condition_description` after method for conditional execution
565 |    - Use actual condition names from code when possible
566 | 
567 | 5. **Ambiguous Branches**:
568 |    - Mark with `⚠️ ambiguous branch` when execution path is uncertain
569 |    - Use `→` to show alternative paths at same indentation level
570 | 
571 | 6. **File Path Format**:
572 |    - Use complete relative paths from project root
573 |    - Include actual file extensions from the project
574 |    - Show exact line numbers where method is defined
575 | 
576 | ### ADDITIONAL ANALYSIS VIEWS
577 | 
578 | **1. BRANCHING & SIDE EFFECT TABLE**
579 | 
580 | | Location | Condition | Branches | Uncertain |
581 | |----------|-----------|----------|-----------|
582 | | CompleteFileName.ext:## | if actual_condition_from_code | method1(), method2(), else skip | No |
583 | | AnotherFile.ext:## | if boolean_check | callMethod(), else return | No |
584 | | ThirdFile.ext:## | if validation_passes | processData(), else throw | Yes |
585 | 
586 | **2. SIDE EFFECTS**
587 | ```
588 | Side Effects:
589 | - [database] Specific database operation description (CompleteFileName.ext:##)
590 | - [network] Specific network call description (CompleteFileName.ext:##)
591 | - [filesystem] Specific file operation description (CompleteFileName.ext:##)
592 | - [state] State changes or property modifications (CompleteFileName.ext:##)
593 | - [memory] Memory allocation or cache operations (CompleteFileName.ext:##)
594 | ```
595 | 
596 | **3. USAGE POINTS**
597 | ```
598 | Usage Points:
599 | 1. FileName.ext:## - Context description of where/why it's called
600 | 2. AnotherFile.ext:## - Context description of usage scenario
601 | 3. ThirdFile.ext:## - Context description of calling pattern
602 | 4. FourthFile.ext:## - Context description of integration point
603 | ```
604 | 
605 | **4. ENTRY POINTS**
606 | ```
607 | Entry Points:
608 | - ClassName::methodName (context: where this flow typically starts)
609 | - AnotherClass::entryMethod (context: alternative entry scenario)
610 | - ThirdClass::triggerMethod (context: event-driven entry point)
611 | ```
612 | 
613 | **ABSOLUTE REQUIREMENTS:**
614 | - Use ONLY the vertical indented style for the call flow diagram
615 | - Present ALL FOUR additional analysis views (Branching Table, Side Effects, Usage Points, Entry Points)
616 | - Adapt method naming to match the project's programming language conventions
617 | - Use exact file paths and line numbers from the actual codebase
618 | - DO NOT invent or guess method names or locations
619 | - Follow indentation rules precisely for call hierarchy
620 | - Mark uncertain execution paths clearly
621 | - Provide contextual descriptions in Usage Points and Entry Points sections
622 | - Include comprehensive side effects categorization (database, network, filesystem, state, memory)"""
623 | 
624 |     def _get_dependencies_rendering_instructions(self) -> str:
625 |         """Get rendering instructions for dependencies trace mode."""
626 |         return """
627 | ## MANDATORY RENDERING INSTRUCTIONS FOR DEPENDENCIES TRACE
628 | 
629 | You MUST render the trace analysis using ONLY the Bidirectional Arrow Flow Style:
630 | 
631 | ### DEPENDENCY FLOW DIAGRAM - Bidirectional Arrow Style
632 | 
633 | **EXACT FORMAT TO FOLLOW:**
634 | ```
635 | INCOMING DEPENDENCIES → [TARGET_CLASS/MODULE] → OUTGOING DEPENDENCIES
636 | 
637 | CallerClass::callerMethod ←────┐
638 | AnotherCaller::anotherMethod ←─┤
639 | ThirdCaller::thirdMethod ←─────┤
640 |                                │
641 |                     [TARGET_CLASS/MODULE]
642 |                                │
643 |                                ├────→ FirstDependency::method
644 |                                ├────→ SecondDependency::method
645 |                                └────→ ThirdDependency::method
646 | 
647 | TYPE RELATIONSHIPS:
648 | InterfaceName ──implements──→ [TARGET_CLASS] ──extends──→ BaseClass
649 | DTOClass ──uses──→ [TARGET_CLASS] ──uses──→ EntityClass
650 | ```
651 | 
652 | **CRITICAL FORMATTING RULES:**
653 | 
654 | 1. **Target Placement**: Always place the target class/module in square brackets `[TARGET_NAME]` at the center
655 | 2. **Incoming Dependencies**: Show on the left side with `←` arrows pointing INTO the target
656 | 3. **Outgoing Dependencies**: Show on the right side with `→` arrows pointing OUT FROM the target
657 | 4. **Arrow Alignment**: Use consistent spacing and alignment for visual clarity
658 | 5. **Method Naming**: Use the project's actual naming conventions detected from the codebase
659 | 6. **File References**: Include complete file paths and line numbers
660 | 
661 | **VISUAL LAYOUT RULES:**
662 | 
663 | 1. **Header Format**: Always start with the flow direction indicator
664 | 2. **Left Side (Incoming)**:
665 |    - List all callers with `←` arrows
666 |    - Use `┐`, `┤`, `┘` box drawing characters for clean connection lines
667 |    - Align arrows consistently
668 | 
669 | 3. **Center (Target)**:
670 |    - Enclose target in square brackets
671 |    - Position centrally between incoming and outgoing
672 | 
673 | 4. **Right Side (Outgoing)**:
674 |    - List all dependencies with `→` arrows
675 |    - Use `├`, `└` box drawing characters for branching
676 |    - Maintain consistent spacing
677 | 
678 | 5. **Type Relationships Section**:
679 |    - Use `──relationship──→` format with double hyphens
680 |    - Show inheritance, implementation, and usage relationships
681 |    - Place below the main flow diagram
682 | 
683 | **DEPENDENCY TABLE:**
684 | 
685 | | Type | From/To | Method | File | Line |
686 | |------|---------|--------|------|------|
687 | | incoming_call | From: CallerClass | callerMethod | /complete/path/file.ext | ## |
688 | | outgoing_call | To: TargetClass | targetMethod | /complete/path/file.ext | ## |
689 | | implements | Self: ThisClass | — | /complete/path/file.ext | — |
690 | | extends | Self: ThisClass | — | /complete/path/file.ext | — |
691 | | uses_type | Self: ThisClass | — | /complete/path/file.ext | — |
692 | 
693 | **ABSOLUTE REQUIREMENTS:**
694 | - Use ONLY the bidirectional arrow flow style shown above
695 | - Automatically detect and use the project's naming conventions
696 | - Use exact file paths and line numbers from the actual codebase
697 | - DO NOT invent or guess method/class names
698 | - Maintain visual alignment and consistent spacing
699 | - Include type relationships section when applicable
700 | - Show clear directional flow with proper arrows"""
701 | 
702 |     # ================================================================================
703 |     # Hook Method Overrides for Tracer-Specific Behavior
704 |     # ================================================================================
705 | 
706 |     def get_completion_status(self) -> str:
707 |         """Tracer uses tracing-specific status."""
708 |         return "tracing_complete"
709 | 
710 |     def get_completion_data_key(self) -> str:
711 |         """Tracer uses 'complete_tracing' key."""
712 |         return "complete_tracing"
713 | 
714 |     def get_completion_message(self) -> str:
715 |         """Tracer-specific completion message."""
716 |         return (
717 |             "Tracing analysis complete. Present the comprehensive trace analysis to the user "
718 |             "using the specified rendering format and offer to help with related tracing tasks."
719 |         )
720 | 
721 |     def get_skip_reason(self) -> str:
722 |         """Tracer-specific skip reason."""
723 |         return "Tracer is self-contained and completes analysis without external assistance"
724 | 
725 |     def get_skip_expert_analysis_status(self) -> str:
726 |         """Tracer-specific expert analysis skip status."""
727 |         return "skipped_by_tool_design"
728 | 
729 |     def store_initial_issue(self, step_description: str):
730 |         """Store initial tracing description."""
731 |         self.initial_tracing_description = step_description
732 | 
733 |     def get_initial_request(self, fallback_step: str) -> str:
734 |         """Get initial tracing description."""
735 |         try:
736 |             return self.initial_tracing_description
737 |         except AttributeError:
738 |             return fallback_step
739 | 
740 |     def get_request_confidence(self, request) -> str:
741 |         """Get confidence from request for tracer workflow."""
742 |         try:
743 |             return request.confidence or "exploring"
744 |         except AttributeError:
745 |             return "exploring"
746 | 
747 |     def get_trace_mode(self) -> str:
748 |         """Get current trace mode. Override for custom trace mode handling."""
749 |         try:
750 |             return self.trace_config.get("trace_mode", "ask")
751 |         except AttributeError:
752 |             return "ask"
753 | 
754 |     # Required abstract methods from BaseTool
755 |     def get_request_model(self):
756 |         """Return the tracer-specific request model."""
757 |         return TracerRequest
758 | 
759 |     async def prepare_prompt(self, request) -> str:
760 |         """Not used - workflow tools use execute_workflow()."""
761 |         return ""  # Workflow tools use execute_workflow() directly
762 | 
```
Page 18/25FirstPrevNextLast