This is page 18 of 25. Use http://codebase.md/beehiveinnovations/gemini-mcp-server?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .claude
│ ├── commands
│ │ └── fix-github-issue.md
│ └── settings.json
├── .coveragerc
├── .dockerignore
├── .env.example
├── .gitattributes
├── .github
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.yml
│ │ ├── config.yml
│ │ ├── documentation.yml
│ │ ├── feature_request.yml
│ │ └── tool_addition.yml
│ ├── pull_request_template.md
│ └── workflows
│ ├── docker-pr.yml
│ ├── docker-release.yml
│ ├── semantic-pr.yml
│ ├── semantic-release.yml
│ └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── AGENTS.md
├── CHANGELOG.md
├── claude_config_example.json
├── CLAUDE.md
├── clink
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── claude.py
│ │ ├── codex.py
│ │ └── gemini.py
│ ├── constants.py
│ ├── models.py
│ ├── parsers
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── claude.py
│ │ ├── codex.py
│ │ └── gemini.py
│ └── registry.py
├── code_quality_checks.ps1
├── code_quality_checks.sh
├── communication_simulator_test.py
├── conf
│ ├── __init__.py
│ ├── azure_models.json
│ ├── cli_clients
│ │ ├── claude.json
│ │ ├── codex.json
│ │ └── gemini.json
│ ├── custom_models.json
│ ├── dial_models.json
│ ├── gemini_models.json
│ ├── openai_models.json
│ ├── openrouter_models.json
│ └── xai_models.json
├── config.py
├── docker
│ ├── README.md
│ └── scripts
│ ├── build.ps1
│ ├── build.sh
│ ├── deploy.ps1
│ ├── deploy.sh
│ └── healthcheck.py
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── adding_providers.md
│ ├── adding_tools.md
│ ├── advanced-usage.md
│ ├── ai_banter.md
│ ├── ai-collaboration.md
│ ├── azure_openai.md
│ ├── configuration.md
│ ├── context-revival.md
│ ├── contributions.md
│ ├── custom_models.md
│ ├── docker-deployment.md
│ ├── gemini-setup.md
│ ├── getting-started.md
│ ├── index.md
│ ├── locale-configuration.md
│ ├── logging.md
│ ├── model_ranking.md
│ ├── testing.md
│ ├── tools
│ │ ├── analyze.md
│ │ ├── apilookup.md
│ │ ├── challenge.md
│ │ ├── chat.md
│ │ ├── clink.md
│ │ ├── codereview.md
│ │ ├── consensus.md
│ │ ├── debug.md
│ │ ├── docgen.md
│ │ ├── listmodels.md
│ │ ├── planner.md
│ │ ├── precommit.md
│ │ ├── refactor.md
│ │ ├── secaudit.md
│ │ ├── testgen.md
│ │ ├── thinkdeep.md
│ │ ├── tracer.md
│ │ └── version.md
│ ├── troubleshooting.md
│ ├── vcr-testing.md
│ └── wsl-setup.md
├── examples
│ ├── claude_config_macos.json
│ └── claude_config_wsl.json
├── LICENSE
├── providers
│ ├── __init__.py
│ ├── azure_openai.py
│ ├── base.py
│ ├── custom.py
│ ├── dial.py
│ ├── gemini.py
│ ├── openai_compatible.py
│ ├── openai.py
│ ├── openrouter.py
│ ├── registries
│ │ ├── __init__.py
│ │ ├── azure.py
│ │ ├── base.py
│ │ ├── custom.py
│ │ ├── dial.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ ├── openrouter.py
│ │ └── xai.py
│ ├── registry_provider_mixin.py
│ ├── registry.py
│ ├── shared
│ │ ├── __init__.py
│ │ ├── model_capabilities.py
│ │ ├── model_response.py
│ │ ├── provider_type.py
│ │ └── temperature.py
│ └── xai.py
├── pyproject.toml
├── pytest.ini
├── README.md
├── requirements-dev.txt
├── requirements.txt
├── run_integration_tests.ps1
├── run_integration_tests.sh
├── run-server.ps1
├── run-server.sh
├── scripts
│ └── sync_version.py
├── server.py
├── simulator_tests
│ ├── __init__.py
│ ├── base_test.py
│ ├── conversation_base_test.py
│ ├── log_utils.py
│ ├── test_analyze_validation.py
│ ├── test_basic_conversation.py
│ ├── test_chat_simple_validation.py
│ ├── test_codereview_validation.py
│ ├── test_consensus_conversation.py
│ ├── test_consensus_three_models.py
│ ├── test_consensus_workflow_accurate.py
│ ├── test_content_validation.py
│ ├── test_conversation_chain_validation.py
│ ├── test_cross_tool_comprehensive.py
│ ├── test_cross_tool_continuation.py
│ ├── test_debug_certain_confidence.py
│ ├── test_debug_validation.py
│ ├── test_line_number_validation.py
│ ├── test_logs_validation.py
│ ├── test_model_thinking_config.py
│ ├── test_o3_model_selection.py
│ ├── test_o3_pro_expensive.py
│ ├── test_ollama_custom_url.py
│ ├── test_openrouter_fallback.py
│ ├── test_openrouter_models.py
│ ├── test_per_tool_deduplication.py
│ ├── test_planner_continuation_history.py
│ ├── test_planner_validation_old.py
│ ├── test_planner_validation.py
│ ├── test_precommitworkflow_validation.py
│ ├── test_prompt_size_limit_bug.py
│ ├── test_refactor_validation.py
│ ├── test_secaudit_validation.py
│ ├── test_testgen_validation.py
│ ├── test_thinkdeep_validation.py
│ ├── test_token_allocation_validation.py
│ ├── test_vision_capability.py
│ └── test_xai_models.py
├── systemprompts
│ ├── __init__.py
│ ├── analyze_prompt.py
│ ├── chat_prompt.py
│ ├── clink
│ │ ├── codex_codereviewer.txt
│ │ ├── default_codereviewer.txt
│ │ ├── default_planner.txt
│ │ └── default.txt
│ ├── codereview_prompt.py
│ ├── consensus_prompt.py
│ ├── debug_prompt.py
│ ├── docgen_prompt.py
│ ├── generate_code_prompt.py
│ ├── planner_prompt.py
│ ├── precommit_prompt.py
│ ├── refactor_prompt.py
│ ├── secaudit_prompt.py
│ ├── testgen_prompt.py
│ ├── thinkdeep_prompt.py
│ └── tracer_prompt.py
├── tests
│ ├── __init__.py
│ ├── CASSETTE_MAINTENANCE.md
│ ├── conftest.py
│ ├── gemini_cassettes
│ │ ├── chat_codegen
│ │ │ └── gemini25_pro_calculator
│ │ │ └── mldev.json
│ │ ├── chat_cross
│ │ │ └── step1_gemini25_flash_number
│ │ │ └── mldev.json
│ │ └── consensus
│ │ └── step2_gemini25_flash_against
│ │ └── mldev.json
│ ├── http_transport_recorder.py
│ ├── mock_helpers.py
│ ├── openai_cassettes
│ │ ├── chat_cross_step2_gpt5_reminder.json
│ │ ├── chat_gpt5_continuation.json
│ │ ├── chat_gpt5_moon_distance.json
│ │ ├── consensus_step1_gpt5_for.json
│ │ └── o3_pro_basic_math.json
│ ├── pii_sanitizer.py
│ ├── sanitize_cassettes.py
│ ├── test_alias_target_restrictions.py
│ ├── test_auto_mode_comprehensive.py
│ ├── test_auto_mode_custom_provider_only.py
│ ├── test_auto_mode_model_listing.py
│ ├── test_auto_mode_provider_selection.py
│ ├── test_auto_mode.py
│ ├── test_auto_model_planner_fix.py
│ ├── test_azure_openai_provider.py
│ ├── test_buggy_behavior_prevention.py
│ ├── test_cassette_semantic_matching.py
│ ├── test_challenge.py
│ ├── test_chat_codegen_integration.py
│ ├── test_chat_cross_model_continuation.py
│ ├── test_chat_openai_integration.py
│ ├── test_chat_simple.py
│ ├── test_clink_claude_agent.py
│ ├── test_clink_claude_parser.py
│ ├── test_clink_codex_agent.py
│ ├── test_clink_gemini_agent.py
│ ├── test_clink_gemini_parser.py
│ ├── test_clink_integration.py
│ ├── test_clink_parsers.py
│ ├── test_clink_tool.py
│ ├── test_collaboration.py
│ ├── test_config.py
│ ├── test_consensus_integration.py
│ ├── test_consensus_schema.py
│ ├── test_consensus.py
│ ├── test_conversation_continuation_integration.py
│ ├── test_conversation_field_mapping.py
│ ├── test_conversation_file_features.py
│ ├── test_conversation_memory.py
│ ├── test_conversation_missing_files.py
│ ├── test_custom_openai_temperature_fix.py
│ ├── test_custom_provider.py
│ ├── test_debug.py
│ ├── test_deploy_scripts.py
│ ├── test_dial_provider.py
│ ├── test_directory_expansion_tracking.py
│ ├── test_disabled_tools.py
│ ├── test_docker_claude_desktop_integration.py
│ ├── test_docker_config_complete.py
│ ├── test_docker_healthcheck.py
│ ├── test_docker_implementation.py
│ ├── test_docker_mcp_validation.py
│ ├── test_docker_security.py
│ ├── test_docker_volume_persistence.py
│ ├── test_file_protection.py
│ ├── test_gemini_token_usage.py
│ ├── test_image_support_integration.py
│ ├── test_image_validation.py
│ ├── test_integration_utf8.py
│ ├── test_intelligent_fallback.py
│ ├── test_issue_245_simple.py
│ ├── test_large_prompt_handling.py
│ ├── test_line_numbers_integration.py
│ ├── test_listmodels_restrictions.py
│ ├── test_listmodels.py
│ ├── test_mcp_error_handling.py
│ ├── test_model_enumeration.py
│ ├── test_model_metadata_continuation.py
│ ├── test_model_resolution_bug.py
│ ├── test_model_restrictions.py
│ ├── test_o3_pro_output_text_fix.py
│ ├── test_o3_temperature_fix_simple.py
│ ├── test_openai_compatible_token_usage.py
│ ├── test_openai_provider.py
│ ├── test_openrouter_provider.py
│ ├── test_openrouter_registry.py
│ ├── test_parse_model_option.py
│ ├── test_per_tool_model_defaults.py
│ ├── test_pii_sanitizer.py
│ ├── test_pip_detection_fix.py
│ ├── test_planner.py
│ ├── test_precommit_workflow.py
│ ├── test_prompt_regression.py
│ ├── test_prompt_size_limit_bug_fix.py
│ ├── test_provider_retry_logic.py
│ ├── test_provider_routing_bugs.py
│ ├── test_provider_utf8.py
│ ├── test_providers.py
│ ├── test_rate_limit_patterns.py
│ ├── test_refactor.py
│ ├── test_secaudit.py
│ ├── test_server.py
│ ├── test_supported_models_aliases.py
│ ├── test_thinking_modes.py
│ ├── test_tools.py
│ ├── test_tracer.py
│ ├── test_utf8_localization.py
│ ├── test_utils.py
│ ├── test_uvx_resource_packaging.py
│ ├── test_uvx_support.py
│ ├── test_workflow_file_embedding.py
│ ├── test_workflow_metadata.py
│ ├── test_workflow_prompt_size_validation_simple.py
│ ├── test_workflow_utf8.py
│ ├── test_xai_provider.py
│ ├── transport_helpers.py
│ └── triangle.png
├── tools
│ ├── __init__.py
│ ├── analyze.py
│ ├── apilookup.py
│ ├── challenge.py
│ ├── chat.py
│ ├── clink.py
│ ├── codereview.py
│ ├── consensus.py
│ ├── debug.py
│ ├── docgen.py
│ ├── listmodels.py
│ ├── models.py
│ ├── planner.py
│ ├── precommit.py
│ ├── refactor.py
│ ├── secaudit.py
│ ├── shared
│ │ ├── __init__.py
│ │ ├── base_models.py
│ │ ├── base_tool.py
│ │ ├── exceptions.py
│ │ └── schema_builders.py
│ ├── simple
│ │ ├── __init__.py
│ │ └── base.py
│ ├── testgen.py
│ ├── thinkdeep.py
│ ├── tracer.py
│ ├── version.py
│ └── workflow
│ ├── __init__.py
│ ├── base.py
│ ├── schema_builders.py
│ └── workflow_mixin.py
├── utils
│ ├── __init__.py
│ ├── client_info.py
│ ├── conversation_memory.py
│ ├── env.py
│ ├── file_types.py
│ ├── file_utils.py
│ ├── image_utils.py
│ ├── model_context.py
│ ├── model_restrictions.py
│ ├── security_config.py
│ ├── storage_backend.py
│ └── token_utils.py
└── zen-mcp-server
```
# Files
--------------------------------------------------------------------------------
/tools/secaudit.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | SECAUDIT Workflow tool - Comprehensive security audit with systematic investigation
3 |
4 | This tool provides a structured workflow for comprehensive security assessment and analysis.
5 | It guides the CLI agent through systematic investigation steps with forced pauses between each step
6 | to ensure thorough security examination, vulnerability identification, and compliance assessment
7 | before proceeding. The tool supports complex security scenarios including OWASP Top 10 coverage,
8 | compliance framework mapping, and technology-specific security patterns.
9 |
10 | Key features:
11 | - Step-by-step security audit workflow with progress tracking
12 | - Context-aware file embedding (references during investigation, full content for analysis)
13 | - Automatic security issue tracking with severity classification
14 | - Expert analysis integration with external models
15 | - Support for focused security audits (OWASP, compliance, technology-specific)
16 | - Confidence-based workflow optimization
17 | - Risk-based prioritization and remediation planning
18 | """
19 |
20 | import logging
21 | from typing import TYPE_CHECKING, Any, Literal, Optional
22 |
23 | from pydantic import Field, model_validator
24 |
25 | if TYPE_CHECKING:
26 | from tools.models import ToolModelCategory
27 |
28 | from config import TEMPERATURE_ANALYTICAL
29 | from systemprompts import SECAUDIT_PROMPT
30 | from tools.shared.base_models import WorkflowRequest
31 |
32 | from .workflow.base import WorkflowTool
33 |
logger = logging.getLogger(__name__)

# Field descriptions for the security audit workflow. The same strings feed the
# SecauditRequest Field(...) descriptions and the overrides in get_input_schema().
SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS = dict(
    step=(
        "Step 1: outline the audit strategy (OWASP Top 10, auth, validation, etc.). Later steps: report findings. MANDATORY: use `relevant_files` for code references and avoid large snippets."
    ),
    step_number="Current security-audit step number (starts at 1).",
    total_steps="Expected number of audit steps; adjust as new risks surface.",
    next_step_required="True while additional threat analysis remains; set False once you are ready to hand off for validation.",
    findings="Summarize vulnerabilities, auth issues, validation gaps, compliance notes, and positives; update prior findings as needed.",
    files_checked="Absolute paths for every file inspected, including rejected candidates.",
    relevant_files="Absolute paths for security-relevant files (auth modules, configs, sensitive code).",
    relevant_context="Security-critical classes/methods (e.g. 'AuthService.login', 'encryption_helper').",
    issues_found="Security issues with severity (critical/high/medium/low) and descriptions (vulns, auth flaws, injection, crypto, config).",
    confidence="exploring/low/medium/high/very_high/almost_certain/certain. 'certain' blocks external validation—use only when fully complete.",
    images="Optional absolute paths to diagrams or threat models that inform the audit.",
    security_scope="Security context (web, mobile, API, cloud, etc.) including stack, user types, data sensitivity, and threat landscape.",
    threat_level="Assess the threat level: low (internal/low-risk), medium (customer-facing/business data), high (regulated or sensitive), critical (financial/healthcare/PII).",
    compliance_requirements="Applicable compliance frameworks or standards (SOC2, PCI DSS, HIPAA, GDPR, ISO 27001, NIST, etc.).",
    audit_focus="Primary focus area: owasp, compliance, infrastructure, dependencies, or comprehensive.",
    severity_filter="Minimum severity to include when reporting security issues.",
)
57 |
58 |
class SecauditRequest(WorkflowRequest):
    """
    Payload for one step of the security audit workflow.

    Combines the shared workflow bookkeeping (step counters, findings, file
    lists, confidence) with audit-specific settings: scope, threat level,
    compliance frameworks, focus area and severity filter.
    """

    # --- Step bookkeeping: required on every call ---
    step: str = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["step"])
    step_number: int = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
    total_steps: int = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
    next_step_required: bool = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])

    # --- Investigation tracking ---
    findings: str = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["findings"])
    files_checked: list[str] = Field(
        default_factory=list,
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["files_checked"],
    )
    relevant_files: list[str] = Field(
        default_factory=list,
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"],
    )
    relevant_context: list[str] = Field(
        default_factory=list,
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"],
    )
    issues_found: list[dict] = Field(
        default_factory=list,
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["issues_found"],
    )
    confidence: Optional[str] = Field(default="low", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["confidence"])

    # Visual context (diagrams, threat models); absent unless supplied.
    images: Optional[list[str]] = Field(default=None, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["images"])

    # --- Audit-specific configuration ---
    security_scope: Optional[str] = Field(
        default=None,
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["security_scope"],
    )
    threat_level: Optional[Literal["low", "medium", "high", "critical"]] = Field(
        default="medium",
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["threat_level"],
    )
    compliance_requirements: Optional[list[str]] = Field(
        default_factory=list,
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["compliance_requirements"],
    )
    audit_focus: Optional[Literal["owasp", "compliance", "infrastructure", "dependencies", "comprehensive"]] = Field(
        default="comprehensive",
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["audit_focus"],
    )
    severity_filter: Optional[Literal["critical", "high", "medium", "low", "all"]] = Field(
        default="all",
        description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["severity_filter"],
    )

    @model_validator(mode="after")
    def validate_security_audit_request(self):
        """
        Log advisory warnings about incomplete or unrecognised audit settings.

        This validator never rejects a request. It only emits warnings and then
        returns the model unchanged.
        """
        # Step 1 is where the scope is expected; a missing scope is only reported.
        if self.step_number == 1 and not self.security_scope:
            logger.warning("Security scope not provided for security audit - defaulting to general application")

        # Exact (case-sensitive) match against the frameworks this tool recognises.
        known_frameworks = {"SOC2", "PCI DSS", "HIPAA", "GDPR", "ISO 27001", "NIST", "FedRAMP", "FISMA"}
        for requirement in self.compliance_requirements or []:
            if requirement in known_frameworks:
                continue
            logger.warning(f"Unknown compliance requirement: {requirement}")

        return self
117 |
118 |
119 | class SecauditTool(WorkflowTool):
120 | """
121 | Comprehensive security audit workflow tool.
122 |
123 | Provides systematic security assessment through multi-step investigation
124 | covering OWASP Top 10, compliance requirements, and technology-specific
125 | security patterns. Follows established WorkflowTool patterns while adding
126 | security-specific capabilities.
127 | """
128 |
def __init__(self):
    """Initialise the base workflow tool, then reset per-audit state."""
    super().__init__()
    # Filled on step 1 by prepare_step_data() from the request's audit settings.
    self.security_config = dict()
    # Set by store_initial_issue(); used as the request header in the expert context.
    self.initial_request = None
133 |
def get_name(self) -> str:
    """Return the tool's unique identifier."""
    tool_name = "secaudit"
    return tool_name
137 |
def get_description(self) -> str:
    """Return the human-readable summary of what the security audit tool does."""
    sentences = (
        "Performs comprehensive security audit with systematic vulnerability assessment. ",
        "Use for OWASP Top 10 analysis, compliance evaluation, threat modeling, and security architecture review. ",
        "Guides through structured security investigation with expert validation.",
    )
    return "".join(sentences)
145 |
def get_system_prompt(self) -> str:
    """Return SECAUDIT_PROMPT, the system prompt used for expert security analysis."""
    prompt = SECAUDIT_PROMPT
    return prompt
149 |
def get_default_temperature(self) -> float:
    """Return the sampling temperature for audits (the TEMPERATURE_ANALYTICAL setting)."""
    temperature = TEMPERATURE_ANALYTICAL
    return temperature
153 |
def get_model_category(self) -> "ToolModelCategory":
    """Return EXTENDED_REASONING as the model category for security audits."""
    # Imported lazily; at module level the name is only available under TYPE_CHECKING.
    from tools.models import ToolModelCategory

    category = ToolModelCategory.EXTENDED_REASONING
    return category
159 |
def get_workflow_request_model(self) -> type:
    """Return SecauditRequest, the pydantic model that validates each workflow step."""
    request_model = SecauditRequest
    return request_model
163 |
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
    """
    Get security audit tool field definitions.

    Returns a shallow copy of SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS, so a
    caller that mutates the result cannot alter the module-level table. That
    table also supplies the SecauditRequest field descriptions and the
    get_input_schema() overrides.

    NOTE(review): despite the return annotation, the values are plain
    description strings, not per-field schema dicts. The full JSON schema
    is built by get_input_schema(). Confirm which shape callers expect.
    """
    return dict(SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS)
172 |
def get_required_actions(
    self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
) -> list[str]:
    """
    Return the checklist of security-domain actions for the given audit step.

    Steps 1-6 each map to a fixed checklist. Any other step number gets a
    generic "keep investigating" checklist. Only ``step_number`` affects the
    result; the other parameters are accepted to match the workflow hook
    signature.
    """
    checklists_by_step = {
        1: (
            "Identify application type, technology stack, and security scope",
            "Map attack surface, entry points, and data flows",
            "Determine relevant security standards and compliance requirements",
            "Establish threat landscape and risk context for the application",
        ),
        2: (
            "Analyze authentication mechanisms and session management",
            "Check authorization controls, access patterns, and privilege escalation risks",
            "Assess multi-factor authentication, password policies, and account security",
            "Review identity and access management implementations",
        ),
        3: (
            "Examine input validation and sanitization mechanisms across all entry points",
            "Check for injection vulnerabilities (SQL, XSS, Command, LDAP, NoSQL)",
            "Review data encryption, sensitive data handling, and cryptographic implementations",
            "Analyze API input validation, rate limiting, and request/response security",
        ),
        4: (
            "Conduct OWASP Top 10 (2021) systematic review across all categories",
            "Check each OWASP category methodically with specific findings and evidence",
            "Cross-reference findings with application context and technology stack",
            "Prioritize vulnerabilities based on exploitability and business impact",
        ),
        5: (
            "Analyze third-party dependencies for known vulnerabilities and outdated versions",
            "Review configuration security, default settings, and hardening measures",
            "Check for hardcoded secrets, credentials, and sensitive information exposure",
            "Assess logging, monitoring, incident response, and security observability",
        ),
        6: (
            "Evaluate compliance requirements and identify gaps in controls",
            "Assess business impact and risk levels of all identified findings",
            "Create prioritized remediation roadmap with timeline and effort estimates",
            "Document comprehensive security posture and recommendations",
        ),
    }
    follow_up_checklist = (
        "Continue systematic security investigation based on emerging findings",
        "Deep-dive into specific security concerns identified in previous steps",
        "Validate security hypotheses and confirm vulnerability assessments",
    )
    # Return a fresh list each time so callers may mutate it freely.
    return list(checklists_by_step.get(step_number, follow_up_checklist))
230 |
def should_call_expert_analysis(self, consolidated_findings, request=None) -> bool:
    """
    Decide whether the investigation warrants an expert-model review.

    Returns False if a request is given and the user turned the assistant
    model off. Otherwise returns True when there is at least one relevant file,
    at least two recorded findings, or at least one reported issue.
    """
    if request:
        if not self.get_request_use_assistant_model(request):
            return False

    if len(consolidated_findings.relevant_files) > 0:
        return True
    if len(consolidated_findings.findings) >= 2:
        return True
    return len(consolidated_findings.issues_found) > 0
248 |
def prepare_expert_analysis_context(self, consolidated_findings) -> str:
    """
    Prepare comprehensive context for expert security model analysis.

    Sections are emitted in this order, each only when it has content:
    the original request, the agent's investigation summary, the security
    configuration captured on step 1, relevant files, security-critical code
    elements, identified issues (grouped by severity), the assessment
    evolution (hypotheses), and referenced images.

    Args:
        consolidated_findings: accumulated workflow state; reads its
            ``relevant_files``, ``relevant_context``, ``issues_found``,
            ``hypotheses`` and ``images`` attributes.

    Returns:
        The context sections joined with newlines.
    """
    context_parts = [
        f"=== SECURITY AUDIT REQUEST ===\n{self.initial_request or 'Security audit workflow initiated'}\n=== END REQUEST ==="
    ]

    # Add investigation summary
    investigation_summary = self._build_security_audit_summary(consolidated_findings)
    context_parts.append(
        f"\n=== AGENT'S SECURITY INVESTIGATION ===\n{investigation_summary}\n=== END INVESTIGATION ==="
    )

    # Add security configuration, skipping unset values. If nothing is set, omit
    # the section entirely rather than sending an empty header/footer pair.
    config_lines = []
    for key, value in (self.security_config or {}).items():
        if not value:
            continue
        if isinstance(value, (list, tuple)):
            # e.g. ["SOC2", "HIPAA"] -> "SOC2, HIPAA" instead of a Python repr.
            value = ", ".join(str(item) for item in value)
        config_lines.append(f"- {key}: {value}")
    if config_lines:
        config_text = "\n".join(config_lines)
        context_parts.append(f"\n=== SECURITY CONFIGURATION ===\n{config_text}\n=== END CONFIGURATION ===")

    # Add relevant files if available
    if consolidated_findings.relevant_files:
        files_text = "\n".join(f"- {file}" for file in consolidated_findings.relevant_files)
        context_parts.append(f"\n=== RELEVANT FILES ===\n{files_text}\n=== END FILES ===")

    # Add relevant security elements if available
    if consolidated_findings.relevant_context:
        methods_text = "\n".join(f"- {method}" for method in consolidated_findings.relevant_context)
        context_parts.append(
            f"\n=== SECURITY-CRITICAL CODE ELEMENTS ===\n{methods_text}\n=== END CODE ELEMENTS ==="
        )

    # Add security issues found if available
    if consolidated_findings.issues_found:
        issues_text = self._format_security_issues(consolidated_findings.issues_found)
        context_parts.append(f"\n=== SECURITY ISSUES IDENTIFIED ===\n{issues_text}\n=== END ISSUES ===")

    # Add assessment evolution. Use .get() so one incomplete hypothesis entry
    # does not raise KeyError and abort the whole expert call.
    if consolidated_findings.hypotheses:
        assessments_text = "\n".join(
            f"Step {h.get('step', '?')} ({h.get('confidence', 'unknown')} confidence): {h.get('hypothesis', '')}"
            for h in consolidated_findings.hypotheses
        )
        context_parts.append(f"\n=== ASSESSMENT EVOLUTION ===\n{assessments_text}\n=== END ASSESSMENTS ===")

    # Add images if available
    if consolidated_findings.images:
        images_text = "\n".join(f"- {img}" for img in consolidated_findings.images)
        context_parts.append(
            f"\n=== VISUAL SECURITY INFORMATION ===\n{images_text}\n=== END VISUAL INFORMATION ==="
        )

    return "\n".join(context_parts)
304 |
def _format_security_issues(self, issues_found: list[dict]) -> str:
    """
    Format security issues for expert analysis, grouped by severity.

    Groups are emitted in the order critical, high, medium, low. Severity is
    matched case-insensitively and ignoring surrounding whitespace. A missing,
    empty or ``None`` severity counts as "low". Unrecognised severities (e.g.
    "info") are listed under LOW with a ``[LABEL]`` prefix, so they stay
    visible.

    Args:
        issues_found: issue dicts; each is read for its optional "severity"
            and "description" keys.

    Returns:
        A multi-line string, or a fixed message when there are no issues.
    """
    if not issues_found:
        return "No security issues identified during systematic investigation."

    # Dict insertion order fixes the output order: most severe first.
    severity_groups: dict[str, list[str]] = {"critical": [], "high": [], "medium": [], "low": []}

    for issue in issues_found:
        raw_severity = issue.get("severity")
        # str() guards against non-string severities; strip() tolerates padding.
        severity = str(raw_severity).strip().lower() if raw_severity else "low"
        if not severity:
            severity = "low"
        description = issue.get("description") or "No description provided"
        if severity in severity_groups:
            severity_groups[severity].append(description)
        else:
            severity_groups["low"].append(f"[{severity.upper()}] {description}")

    formatted_issues = []
    for severity, descriptions in severity_groups.items():
        if descriptions:
            formatted_issues.append(f"\n{severity.upper()} SEVERITY:")
            formatted_issues.extend(f" • {item}" for item in descriptions)

    return "\n".join(formatted_issues) if formatted_issues else "No security issues identified."
333 |
def _build_security_audit_summary(self, consolidated_findings) -> str:
    """
    Render a plain-text digest of the audit so far.

    The digest starts with count lines (findings are counted as steps),
    followed by every recorded finding in order.
    """
    tallies = (
        ("Total steps", consolidated_findings.findings),
        ("Files examined", consolidated_findings.files_checked),
        ("Relevant files identified", consolidated_findings.relevant_files),
        ("Security-critical elements analyzed", consolidated_findings.relevant_context),
        ("Security issues identified", consolidated_findings.issues_found),
    )

    lines = ["=== SYSTEMATIC SECURITY AUDIT INVESTIGATION SUMMARY ==="]
    lines.extend(f"{label}: {len(items)}" for label, items in tallies)
    lines.append("")
    lines.append("=== INVESTIGATION PROGRESSION ===")
    lines.extend(consolidated_findings.findings)
    return "\n".join(lines)
351 |
def get_input_schema(self) -> dict[str, Any]:
    """
    Generate the input schema for the security audit workflow.

    Builds security-audit overrides from SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS:
    types, descriptions, plus enums and defaults for the audit settings. These
    are passed to WorkflowSchemaBuilder together with this tool's model-field
    schema, auto-mode flag and name.

    Returns:
        The schema dict produced by WorkflowSchemaBuilder.build_schema().
    """
    from .workflow.schema_builders import WorkflowSchemaBuilder

    # Security audit workflow-specific field overrides
    secaudit_field_overrides = {
        "step": {
            "type": "string",
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["step"],
        },
        "step_number": {
            "type": "integer",
            "minimum": 1,
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["step_number"],
        },
        "total_steps": {
            "type": "integer",
            "minimum": 1,
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["total_steps"],
        },
        "next_step_required": {
            "type": "boolean",
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"],
        },
        "findings": {
            "type": "string",
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["findings"],
        },
        "files_checked": {
            "type": "array",
            "items": {"type": "string"},
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["files_checked"],
        },
        "relevant_files": {
            "type": "array",
            "items": {"type": "string"},
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"],
        },
        # SecauditRequest accepts relevant_context as list[str]. Listing it here
        # means the advertised schema carries the security-specific description,
        # like every other request field.
        "relevant_context": {
            "type": "array",
            "items": {"type": "string"},
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"],
        },
        "confidence": {
            "type": "string",
            "enum": ["exploring", "low", "medium", "high", "very_high", "almost_certain", "certain"],
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["confidence"],
        },
        "issues_found": {
            "type": "array",
            "items": {"type": "object"},
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["issues_found"],
        },
        "images": {
            "type": "array",
            "items": {"type": "string"},
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["images"],
        },
        # Security audit-specific fields (for step 1)
        "security_scope": {
            "type": "string",
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["security_scope"],
        },
        "threat_level": {
            "type": "string",
            "enum": ["low", "medium", "high", "critical"],
            "default": "medium",
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["threat_level"],
        },
        "compliance_requirements": {
            "type": "array",
            "items": {"type": "string"},
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["compliance_requirements"],
        },
        "audit_focus": {
            "type": "string",
            "enum": ["owasp", "compliance", "infrastructure", "dependencies", "comprehensive"],
            "default": "comprehensive",
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["audit_focus"],
        },
        "severity_filter": {
            "type": "string",
            "enum": ["critical", "high", "medium", "low", "all"],
            "default": "all",
            "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["severity_filter"],
        },
    }

    # Use WorkflowSchemaBuilder with security audit-specific tool fields
    return WorkflowSchemaBuilder.build_schema(
        tool_specific_fields=secaudit_field_overrides,
        model_field_schema=self.get_model_field_schema(),
        auto_mode=self.is_effective_auto_mode(),
        tool_name=self.get_name(),
    )
442 |
443 | # Hook method overrides for security audit-specific behavior
444 |
def prepare_step_data(self, request) -> dict:
    """
    Translate a secaudit request into the generic step-data dict.

    ``findings`` is copied into ``hypothesis`` so the generic workflow
    machinery can read it under that key. ``images`` falls back to an empty
    list. On step 1 the audit settings are also recorded on
    ``self.security_config``.
    """
    passthrough_fields = (
        "step",
        "step_number",
        "findings",
        "files_checked",
        "relevant_files",
        "relevant_context",
        "issues_found",
        "confidence",
    )
    step_data = {name: getattr(request, name) for name in passthrough_fields}
    step_data["hypothesis"] = request.findings
    step_data["images"] = request.images or []

    if request.step_number == 1:
        audit_settings = ("security_scope", "threat_level", "compliance_requirements", "audit_focus", "severity_filter")
        self.security_config = {name: getattr(request, name) for name in audit_settings}

    return step_data
471 |
def should_skip_expert_analysis(self, request, consolidated_findings) -> bool:
    """Skip the expert review only on the final step when the agent reports "certain" confidence."""
    if request.next_step_required:
        return False
    return request.confidence == "certain"
475 |
def store_initial_issue(self, step_description: str):
    """Remember the initial step description as ``self.initial_request`` for expert analysis."""
    setattr(self, "initial_request", step_description)
479 |
def should_include_files_in_expert_prompt(self) -> bool:
    """Always embed file contents in the expert prompt for a comprehensive security audit."""
    include_files = True
    return include_files
483 |
def should_embed_system_prompt(self) -> bool:
    """Always embed the system prompt in expert analysis so the expert has the audit context."""
    embed_prompt = True
    return embed_prompt
487 |
def get_expert_thinking_mode(self) -> str:
    """Request the "high" thinking mode so the expert security analysis is thorough."""
    thinking_mode = "high"
    return thinking_mode
491 |
def get_expert_analysis_instruction(self) -> str:
    """Return the instruction that asks the expert model for a final OWASP-style security review."""
    sentences = (
        "Please provide comprehensive security analysis based on the investigation findings.",
        "Focus on identifying any remaining vulnerabilities, validating the completeness of the analysis, "
        "and providing final recommendations for security improvements, following the OWASP-based "
        "format specified in the system prompt.",
    )
    return " ".join(sentences)
500 |
def get_completion_next_steps_message(self, expert_analysis_used: bool = False) -> str:
    """
    Build the closing instruction shown once the security audit is complete.

    Args:
        expert_analysis_used: when True, the text from
            ``self.get_expert_analysis_guidance()`` is appended after a blank
            line (only if that guidance is non-empty).
    """
    summary_instructions = (
        "SECURITY AUDIT IS COMPLETE. You MUST now summarize and present ALL security findings organized by "
        "severity (Critical → High → Medium → Low), specific code locations with line numbers, and exact "
        "remediation steps for each vulnerability. Clearly prioritize the top 3 security issues that need "
        "immediate attention. Provide concrete, actionable guidance for each vulnerability—make it easy for "
        "developers to understand exactly what needs to be fixed and how to implement the security improvements."
    )

    if not expert_analysis_used:
        return summary_instructions

    guidance = self.get_expert_analysis_guidance()
    if not guidance:
        return summary_instructions
    return f"{summary_instructions}\n\n{guidance}"
520 |
def get_expert_analysis_guidance(self) -> str:
    """
    Return the reminder that expert findings must be validated, not accepted blindly.
    """
    sentences = (
        "IMPORTANT: Analysis from an assistant model has been provided above.",
        "You MUST critically evaluate and validate the expert security findings rather than accepting them blindly.",
        "Cross-reference the expert analysis with your own investigation findings, verify that suggested "
        "security improvements are appropriate for this application's context and threat model, and ensure "
        "recommendations align with the project's security requirements.",
        "Present a synthesis that combines your systematic security review with validated expert insights, "
        "clearly distinguishing between vulnerabilities you've independently confirmed and additional insights "
        "from expert analysis.",
    )
    return " ".join(sentences)
533 |
def get_step_guidance_message(self, request) -> str:
    """
    Return only the ``next_steps`` text of the step-specific security audit guidance.
    """
    step_number, confidence = request.step_number, request.confidence
    guidance = self.get_security_audit_step_guidance(step_number, confidence, request)
    return guidance["next_steps"]
540 |
def get_security_audit_step_guidance(self, step_number: int, confidence: str, request) -> dict[str, Any]:
    """
    Build the ``next_steps`` instruction for one step of the security audit.

    Step 1 always gets the "investigate before calling again" instruction.
    Later steps branch on ``confidence``:
      * "exploring" / "low"  -> STOP message listing every required action;
      * "medium" / "high"    -> WAIT message listing every required action;
      * anything else        -> PAUSE message quoting the first two actions.

    Returns:
        dict with a single ``"next_steps"`` string.
    """
    # Required actions are fetched for every step, including step 1 where they are not shown.
    actions = self.get_required_actions(step_number, confidence, request.findings, request.total_steps)
    tool = self.get_name()
    following_step = step_number + 1

    def numbered_actions() -> str:
        # "1. first\n2. second\n..." — one numbered line per required action
        return "\n".join(f"{position}. {action}" for position, action in enumerate(actions, start=1))

    if step_number == 1:
        text = (
            f"MANDATORY: DO NOT call the {tool} tool again immediately. You MUST first examine "
            "the code files thoroughly using appropriate tools. CRITICAL AWARENESS: You need to understand "
            "the security landscape, identify potential vulnerabilities across OWASP Top 10 categories, "
            "and look for authentication flaws, injection points, cryptographic issues, and authorization bypasses. "
            "Use file reading tools, security analysis, and systematic examination to gather comprehensive information. "
            f"Only call {tool} again AFTER completing your security investigation. When you call "
            f"{tool} next time, use step_number: {following_step} and report specific "
            "files examined, vulnerabilities found, and security assessments discovered."
        )
    elif confidence in ("exploring", "low"):
        text = (
            f"STOP! Do NOT call {tool} again yet. Based on your findings, you've identified areas that need "
            f"deeper security analysis. MANDATORY ACTIONS before calling {tool} step {following_step}:\n"
            f"{numbered_actions()}"
            f"\n\nOnly call {tool} again with step_number: {following_step} AFTER "
            "completing these security audit tasks."
        )
    elif confidence in ("medium", "high"):
        text = (
            f"WAIT! Your security audit needs final verification. DO NOT call {tool} immediately. REQUIRED ACTIONS:\n"
            f"{numbered_actions()}"
            "\n\nREMEMBER: Ensure you have identified all significant vulnerabilities across all severity levels and "
            "verified the completeness of your security review. Document findings with specific file references and "
            f"line numbers where applicable, then call {tool} with step_number: {following_step}."
        )
    else:
        leading_actions = ", ".join(actions[:2])
        text = (
            f"PAUSE SECURITY AUDIT. Before calling {tool} step {following_step}, you MUST examine more code thoroughly. "
            f"Required: {leading_actions}. "
            f"Your next {tool} call (step_number: {following_step}) must include "
            f"NEW evidence from actual security analysis, not just theories. NO recursive {tool} calls "
            "without investigation work!"
        )

    return {"next_steps": text}
587 |
def customize_workflow_response(self, response_data: dict, request) -> dict:
    """
    Rename generic workflow response fields to their security-audit names.

    On step 1 this also records ``request.step`` as ``self.initial_request``.
    When ``request.relevant_files`` is non-empty, it also replaces
    ``self.security_config`` with the audit settings plus the relevant files.

    Statuses ``<tool>_in_progress``, ``pause_for_<tool>``, ``<tool>_required``
    and ``<tool>_complete`` are translated. ``<tool>_status`` becomes
    ``security_audit_status``, enriched with per-severity issue counts and the
    request confidence. The ``complete_<tool>`` and ``<tool>_complete`` keys
    are renamed. ``response_data`` is modified in place and returned.
    """
    if request.step_number == 1:
        self.initial_request = request.step
        if request.relevant_files:
            self.security_config = {
                "relevant_files": request.relevant_files,
                "security_scope": request.security_scope,
                "threat_level": request.threat_level,
                "compliance_requirements": request.compliance_requirements,
                "audit_focus": request.audit_focus,
                "severity_filter": request.severity_filter,
            }

    name = self.get_name()
    renamed_statuses = {
        f"{name}_in_progress": "security_audit_in_progress",
        f"pause_for_{name}": "pause_for_security_audit",
        f"{name}_required": "security_audit_required",
        f"{name}_complete": "security_audit_complete",
    }
    current_status = response_data["status"]
    if current_status in renamed_statuses:
        response_data["status"] = renamed_statuses[current_status]

    generic_status_key = f"{name}_status"
    if generic_status_key in response_data:
        audit_status = response_data.pop(generic_status_key)
        response_data["security_audit_status"] = audit_status
        # Count issues per severity; issues without a severity are bucketed as "unknown".
        severity_counts = {}
        audit_status["vulnerabilities_by_severity"] = severity_counts
        for finding in self.consolidated_findings.issues_found:
            level = finding.get("severity", "unknown")
            severity_counts[level] = severity_counts.get(level, 0) + 1
        audit_status["audit_confidence"] = self.get_request_confidence(request)

    key_renames = (
        (f"complete_{name}", "complete_security_audit"),
        (f"{name}_complete", "security_audit_complete"),
    )
    for generic_key, audit_key in key_renames:
        if generic_key in response_data:
            response_data[audit_key] = response_data.pop(generic_key)

    return response_data
639 |
640 | # Override inheritance hooks for security audit-specific behavior
641 |
def get_completion_status(self) -> str:
    """Status string this tool reports on completion ("security_analysis_complete")."""
    completion_status = "security_analysis_complete"
    return completion_status
645 |
def get_completion_data_key(self) -> str:
    """Response key under which the completed audit data is stored."""
    data_key = "complete_security_audit"
    return data_key
649 |
def get_final_analysis_from_request(self, request):
    """The final analysis for a security audit is the request's ``findings`` field."""
    final_analysis = request.findings
    return final_analysis
653 |
def get_confidence_level(self, request) -> str:
    """Always report "certain"; the request is accepted for interface compatibility but unused."""
    confidence_level = "certain"
    return confidence_level
657 |
def get_completion_message(self) -> str:
    """Message shown when the audit finishes with certain confidence and expert analysis is skipped."""
    sentences = (
        "Security audit complete with CERTAIN confidence.",
        "You have identified all significant vulnerabilities and provided comprehensive security analysis.",
        "MANDATORY: Present the user with the complete security audit results categorized by severity, "
        "and IMMEDIATELY proceed with implementing the highest priority security fixes "
        "or provide specific guidance for vulnerability remediation.",
        "Focus on actionable security recommendations.",
    )
    return " ".join(sentences)
666 |
def get_skip_reason(self) -> str:
    """Explanation recorded when expert analysis is skipped for this tool."""
    reason = "Completed comprehensive security audit with full confidence locally"
    return reason
670 |
def get_skip_expert_analysis_status(self) -> str:
    """Status string reported for the skipped expert analysis."""
    skip_status = "skipped_due_to_certain_audit_confidence"
    return skip_status
674 |
def prepare_work_summary(self) -> str:
    """Summarise the consolidated findings via ``self._build_security_audit_summary``."""
    findings = self.consolidated_findings
    return self._build_security_audit_summary(findings)
678 |
def get_request_model(self):
    """Request model class used to validate this tool's arguments."""
    request_model = SecauditRequest
    return request_model
682 |
async def prepare_prompt(self, request: SecauditRequest) -> str:
    """Unused hook: workflow tools run through execute_workflow(), so this returns an empty string."""
    unused_prompt = ""
    return unused_prompt
686 |
```
--------------------------------------------------------------------------------
/simulator_tests/test_testgen_validation.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | TestGen Tool Validation Test
4 |
5 | Tests the testgen tool's capabilities using the workflow architecture.
6 | This validates that the workflow-based implementation guides Claude through
7 | systematic test generation analysis before creating comprehensive test suites.
8 | """
9 |
10 | import json
11 | from typing import Optional
12 |
13 | from .conversation_base_test import ConversationBaseTest
14 |
15 |
16 | class TestGenValidationTest(ConversationBaseTest):
17 | """Test testgen tool with workflow architecture"""
18 |
@property
def test_name(self) -> str:
    """Short identifier for this simulator test."""
    identifier = "testgen_validation"
    return identifier
22 |
@property
def test_description(self) -> str:
    """Human-readable description of what this simulator test covers."""
    description = "TestGen tool validation with step-by-step test planning"
    return description
26 |
def run_test(self) -> bool:
    """Run every testgen validation scenario in order, stopping at the first failure.

    Calls ``self.setUp()`` and creates the fixture files first. Any exception
    is logged and reported as a failure.

    Returns:
        True only if every scenario returned a truthy result.
    """
    self.setUp()

    # Scenario methods are looked up lazily so each is resolved just before it runs.
    scenario_names = (
        "_test_single_test_generation_session",  # 1: single multi-step session
        "_test_generation_with_pattern_following",  # 2: follow existing test patterns
        "_test_complete_generation_with_analysis",  # 3: final step with expert analysis
        "_test_certain_confidence",  # 4: certain confidence skips expert analysis
        "_test_context_aware_file_embedding",  # 5: file embedding per step type
        "_test_multi_step_test_planning",  # 6: multi-step planning
    )

    try:
        self.logger.info("Test: TestGen tool validation")
        self._create_test_code_files()

        for scenario_name in scenario_names:
            if not getattr(self, scenario_name)():
                return False

        self.logger.info(" ✅ All testgen validation tests passed")
        return True

    except Exception as e:
        self.logger.error(f"TestGen validation test failed: {e}")
        return False
68 |
def _create_test_code_files(self):
    """Write the calculator fixture module and an existing pytest file to use as a pattern.

    Each file is created with ``self.create_additional_test_file`` and its
    returned path is stored on ``self.calculator_file`` /
    ``self.existing_test_file`` respectively.
    """
    # Module under test: six arithmetic helpers, several with explicit error branches.
    calculator_source = """#!/usr/bin/env python3
\"\"\"
Simple calculator module for demonstration
\"\"\"

def add(a, b):
    \"\"\"Add two numbers\"\"\"
    return a + b

def subtract(a, b):
    \"\"\"Subtract b from a\"\"\"
    return a - b

def multiply(a, b):
    \"\"\"Multiply two numbers\"\"\"
    return a * b

def divide(a, b):
    \"\"\"Divide a by b\"\"\"
    if b == 0:
        raise ValueError("Cannot divide by zero")
    return a / b

def calculate_percentage(value, percentage):
    \"\"\"Calculate percentage of a value\"\"\"
    if percentage < 0:
        raise ValueError("Percentage cannot be negative")
    if percentage > 100:
        raise ValueError("Percentage cannot exceed 100")
    return (value * percentage) / 100

def power(base, exponent):
    \"\"\"Calculate base raised to exponent\"\"\"
    if base == 0 and exponent < 0:
        raise ValueError("Cannot raise 0 to negative power")
    return base ** exponent
"""

    # Existing test file whose style (class-based pytest tests) the tool should imitate.
    pattern_test_source = """#!/usr/bin/env python3
import pytest
from calculator import add, subtract

class TestCalculatorBasic:
    \"\"\"Test basic calculator operations\"\"\"

    def test_add_positive_numbers(self):
        \"\"\"Test adding two positive numbers\"\"\"
        assert add(2, 3) == 5
        assert add(10, 20) == 30

    def test_add_negative_numbers(self):
        \"\"\"Test adding negative numbers\"\"\"
        assert add(-5, -3) == -8
        assert add(-10, 5) == -5

    def test_subtract_positive(self):
        \"\"\"Test subtracting positive numbers\"\"\"
        assert subtract(10, 3) == 7
        assert subtract(5, 5) == 0
"""

    fixtures = (
        ("calculator_file", "calculator.py", calculator_source, "calculator module"),
        ("existing_test_file", "test_calculator_basic.py", pattern_test_source, "existing test file"),
    )
    for attribute, filename, source, label in fixtures:
        created_path = self.create_additional_test_file(filename, source)
        setattr(self, attribute, created_path)
        self.logger.info(f" ✅ Created {label}: {created_path}")
140 |
def _test_single_test_generation_session(self) -> bool:
    """Drive a two-step testgen session and check step structure and status tracking.

    Step 1 must pause for test analysis. Step 2 must additionally report at
    least three identified test scenarios and echo the "medium" confidence.
    On success the session's continuation_id is stored on
    ``self.test_continuation_id`` so later scenarios can resume it.

    Returns:
        True when both steps validate; False on any failure (errors are logged).
    """
    try:
        self.logger.info(" 1.1: Testing single test generation session")

        # Step 1: Start investigation
        self.logger.info(" 1.1.1: Step 1 - Initial test planning")
        response1, continuation_id = self.call_mcp_tool(
            "testgen",
            {
                "step": "I need to generate comprehensive tests for the calculator module. Let me start by analyzing the code structure and understanding the functionality.",
                "step_number": 1,
                "total_steps": 4,
                "next_step_required": True,
                "findings": "Calculator module contains 6 functions: add, subtract, multiply, divide, calculate_percentage, and power. Each has specific error conditions that need testing.",
                "files_checked": [self.calculator_file],
                "relevant_files": [self.calculator_file],
                "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
            },
        )

        if not response1 or not continuation_id:
            self.logger.error("Failed to get initial test planning response")
            return False

        response1_data = self._parse_testgen_response(response1)
        if not response1_data:
            return False

        if not self._validate_step_response(response1_data, 1, 4, True, "pause_for_test_analysis"):
            return False

        self.logger.info(f" ✅ Step 1 successful, continuation_id: {continuation_id}")

        # Step 2: Analyze test requirements
        self.logger.info(" 1.1.2: Step 2 - Test requirements analysis")
        response2, _ = self.call_mcp_tool(
            "testgen",
            {
                "step": "Now analyzing the test requirements for each function, identifying edge cases and boundary conditions.",
                "step_number": 2,
                "total_steps": 4,
                "next_step_required": True,
                "findings": "Identified key test scenarios: (1) divide - zero division error, (2) calculate_percentage - negative/over 100 validation, (3) power - zero to negative power error. Need tests for normal cases and edge cases.",
                "files_checked": [self.calculator_file],
                "relevant_files": [self.calculator_file],
                "relevant_context": ["divide", "calculate_percentage", "power"],
                "confidence": "medium",
                "continuation_id": continuation_id,
            },
        )

        if not response2:
            self.logger.error("Failed to continue test planning to step 2")
            return False

        response2_data = self._parse_testgen_response(response2)
        # Same guard as step 1: a parse failure is already logged, so stop here
        # rather than let validation add a misleading "Expected status" error.
        if not response2_data:
            return False

        if not self._validate_step_response(response2_data, 2, 4, True, "pause_for_test_analysis"):
            return False

        # Check test generation status tracking
        test_status = response2_data.get("test_generation_status", {})
        if test_status.get("test_scenarios_identified", 0) < 3:
            self.logger.error("Test scenarios not properly tracked")
            return False

        if test_status.get("analysis_confidence") != "medium":
            self.logger.error("Confidence level not properly tracked")
            return False

        self.logger.info(" ✅ Step 2 successful with proper tracking")

        # Store continuation_id for the next scenario
        self.test_continuation_id = continuation_id
        return True

    except Exception as e:
        self.logger.error(f"Single test generation session test failed: {e}")
        return False
222 |
def _test_generation_with_pattern_following(self) -> bool:
    """Start a testgen session that references an existing pytest file and advance it one step.

    Only checks that both calls return a response and that step 1 yields a
    continuation_id; response contents are not validated here.
    """
    try:
        self.logger.info(" 1.2: Testing test generation with pattern following")

        # Step 1: new session that points at the existing test file as a style reference
        self.logger.info(" 1.2.1: Start test generation with pattern reference")
        opening_request = {
            "step": "Generating tests for remaining calculator functions following existing test patterns",
            "step_number": 1,
            "total_steps": 3,
            "next_step_required": True,
            "findings": "Found existing test pattern using pytest with class-based organization and descriptive test names",
            "files_checked": [self.calculator_file, self.existing_test_file],
            "relevant_files": [self.calculator_file, self.existing_test_file],
            "relevant_context": ["TestCalculatorBasic", "multiply", "divide", "calculate_percentage", "power"],
        }
        first_reply, session_id = self.call_mcp_tool("testgen", opening_request)
        if not (first_reply and session_id):
            self.logger.error("Failed to start pattern following test")
            return False

        # Step 2: continue the same session with the pattern analysis
        self.logger.info(" 1.2.2: Step 2 - Pattern analysis")
        analysis_request = {
            "step": "Analyzing the existing test patterns to maintain consistency",
            "step_number": 2,
            "total_steps": 3,
            "next_step_required": True,
            "findings": "Existing tests use: class-based organization (TestCalculatorBasic), descriptive method names (test_operation_scenario), multiple assertions per test, pytest framework",
            "files_checked": [self.existing_test_file],
            "relevant_files": [self.calculator_file, self.existing_test_file],
            "confidence": "high",
            "continuation_id": session_id,
        }
        second_reply, _ = self.call_mcp_tool("testgen", analysis_request)
        if not second_reply:
            self.logger.error("Failed to continue to step 2")
            return False

        self.logger.info(" ✅ Pattern analysis successful")
        return True

    except Exception as e:
        self.logger.error(f"Pattern following test failed: {e}")
        return False
275 |
def _test_complete_generation_with_analysis(self) -> bool:
    """Finish a testgen session with next_step_required=False and check the expert analysis.

    Reuses ``self.test_continuation_id`` when an earlier scenario stored one,
    otherwise starts a fresh one-step session first. The final response must
    have status "calling_expert_analysis", ``test_generation_complete``, an
    ``expert_analysis`` entry, and a ``complete_test_generation`` summary with
    relevant context. Weak expert content only produces a warning.

    Returns:
        True on success; False on any failed check (errors are logged).
    """
    try:
        self.logger.info(" 1.3: Testing complete test generation with expert analysis")

        # Use the continuation from the first scenario, or start fresh
        continuation_id = getattr(self, "test_continuation_id", None)
        if not continuation_id:
            self.logger.info(" 1.3.0: Starting fresh test generation")
            response0, continuation_id = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Analyzing calculator module for comprehensive test generation",
                    "step_number": 1,
                    "total_steps": 2,
                    "next_step_required": True,
                    "findings": "Identified 6 functions needing tests with various edge cases",
                    "files_checked": [self.calculator_file],
                    "relevant_files": [self.calculator_file],
                    "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
                },
            )
            if not response0 or not continuation_id:
                self.logger.error("Failed to start fresh test generation")
                return False

        # Final step - next_step_required=False triggers expert analysis
        self.logger.info(" 1.3.1: Final step - complete test planning")
        response_final, _ = self.call_mcp_tool(
            "testgen",
            {
                "step": "Test planning complete. Identified all test scenarios including edge cases, error conditions, and boundary values for comprehensive coverage.",
                "step_number": 2,
                "total_steps": 2,
                "next_step_required": False,
                "findings": "Complete test plan: normal operations, edge cases (zero, negative), error conditions (divide by zero, invalid percentage, zero to negative power), boundary values",
                "files_checked": [self.calculator_file],
                "relevant_files": [self.calculator_file],
                "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
                "confidence": "high",
                "continuation_id": continuation_id,
                "model": "flash",  # Use flash for expert analysis
            },
        )

        if not response_final:
            self.logger.error("Failed to complete test generation")
            return False

        response_final_data = self._parse_testgen_response(response_final)
        if not response_final_data:
            return False

        if response_final_data.get("status") != "calling_expert_analysis":
            self.logger.error(
                f"Expected status 'calling_expert_analysis', got '{response_final_data.get('status')}'"
            )
            return False

        if not response_final_data.get("test_generation_complete"):
            self.logger.error("Expected test_generation_complete=true for final step")
            return False

        if "expert_analysis" not in response_final_data:
            self.logger.error("Missing expert_analysis in final response")
            return False

        expert_analysis = response_final_data.get("expert_analysis", {})

        # Heuristic content check: count test-related keywords in the serialized analysis
        analysis_text = json.dumps(expert_analysis, ensure_ascii=False).lower()
        test_indicators = ["test", "edge", "boundary", "error", "coverage", "pytest"]
        found_indicators = sum(1 for indicator in test_indicators if indicator in analysis_text)

        if found_indicators >= 4:
            self.logger.info(" ✅ Expert analysis provided comprehensive test suggestions")
        else:
            # Denominator derived from the list so it stays correct if indicators change
            self.logger.warning(
                f" ⚠️ Expert analysis may not have fully addressed test generation "
                f"(found {found_indicators}/{len(test_indicators)} indicators)"
            )

        if "complete_test_generation" not in response_final_data:
            self.logger.error("Missing complete_test_generation in final response")
            return False

        complete_generation = response_final_data["complete_test_generation"]
        if not complete_generation.get("relevant_context"):
            self.logger.error("Missing relevant context in complete test generation")
            return False

        self.logger.info(" ✅ Complete test generation with expert analysis successful")
        return True

    except Exception as e:
        self.logger.error(f"Complete test generation test failed: {e}")
        return False
378 |
def _test_certain_confidence(self) -> bool:
    """Check that a final step with "certain" confidence skips expert analysis.

    Expects status "test_generation_complete_ready_for_implementation",
    ``skip_expert_analysis`` set, and an expert_analysis status of
    "skipped_due_to_certain_test_confidence".
    """
    try:
        self.logger.info(" 1.4: Testing certain confidence behavior")

        self.logger.info(" 1.4.1: Certain confidence test generation")
        certain_request = {
            "step": "I have fully analyzed the code and identified all test scenarios with 100% certainty. Test plan is complete.",
            "step_number": 1,
            "total_steps": 1,
            "next_step_required": False,  # final step
            "findings": "Complete test coverage plan: all functions covered with normal cases, edge cases, and error conditions. Ready for implementation.",
            "files_checked": [self.calculator_file],
            "relevant_files": [self.calculator_file],
            "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
            "confidence": "certain",  # should make the tool skip expert analysis
            "model": "flash",
        }
        reply_text, _ = self.call_mcp_tool("testgen", certain_request)
        if not reply_text:
            self.logger.error("Failed to test certain confidence")
            return False

        reply = self._parse_testgen_response(reply_text)
        if not reply:
            return False

        wanted_status = "test_generation_complete_ready_for_implementation"
        actual_status = reply.get("status")
        if actual_status != wanted_status:
            self.logger.error(f"Expected status '{wanted_status}', got '{actual_status}'")
            return False

        if not reply.get("skip_expert_analysis"):
            self.logger.error("Expected skip_expert_analysis=true for certain confidence")
            return False

        skipped_analysis = reply.get("expert_analysis", {})
        if skipped_analysis.get("status") != "skipped_due_to_certain_test_confidence":
            self.logger.error("Expert analysis should be skipped for certain confidence")
            return False

        self.logger.info(" ✅ Certain confidence behavior working correctly")
        return True

    except Exception as e:
        self.logger.error(f"Certain confidence test failed: {e}")
        return False
432 |
def call_mcp_tool(self, tool_name: str, params: dict) -> tuple[Optional[str], Optional[str]]:
    """Call an MCP tool in-process and pull the testgen continuation_id from its reply.

    Uses ``call_mcp_tool_direct`` so conversation memory persists across calls.

    Returns:
        (response_text, continuation_id), or (None, None) when the tool returned no text.
    """
    response_text, _discarded = self.call_mcp_tool_direct(tool_name, params)
    if response_text:
        return response_text, self._extract_testgen_continuation_id(response_text)
    return None, None
445 |
def _extract_testgen_continuation_id(self, response_text: str) -> Optional[str]:
    """Return the ``continuation_id`` field of a testgen JSON response.

    Returns None (logging at debug level) when the text is not valid JSON,
    is not a string, or does not decode to a JSON object. A top-level JSON
    array or scalar has no ``continuation_id``, and calling ``.get`` on it
    would raise.
    """
    try:
        response_data = json.loads(response_text)
    except (json.JSONDecodeError, TypeError) as e:
        self.logger.debug(f"Failed to parse response for testgen continuation_id: {e}")
        return None

    if not isinstance(response_data, dict):
        self.logger.debug(
            f"Testgen response is a JSON {type(response_data).__name__}, not an object; no continuation_id"
        )
        return None

    return response_data.get("continuation_id")
456 |
def _parse_testgen_response(self, response_text: str) -> dict:
    """Parse a testgen tool response, which is expected to be a JSON object.

    Returns:
        The decoded dict, or ``{}`` (after logging an error) when the text is
        not valid JSON, is not a string, or decodes to something other than an
        object. Callers treat an empty dict as failure and call ``.get`` on the
        result, so a non-dict must never be returned.
    """
    try:
        parsed = json.loads(response_text)
    except (json.JSONDecodeError, TypeError) as e:
        self.logger.error(f"Failed to parse testgen response as JSON: {e}")
        # str() so a non-string response cannot raise while building the log message
        self.logger.error(f"Response text: {str(response_text)[:500]}...")
        return {}

    if not isinstance(parsed, dict):
        self.logger.error(f"Expected testgen response to be a JSON object, got {type(parsed).__name__}")
        return {}

    return parsed
467 |
def _validate_step_response(
    self,
    response_data: dict,
    expected_step: int,
    expected_total: int,
    expected_next_required: bool,
    expected_status: str,
) -> bool:
    """Check the common structure of an intermediate testgen step response.

    Verifies, in order: status, step_number, total_steps and
    next_step_required equal the expected values; a ``test_generation_status``
    key is present; and ``next_steps`` guidance is non-empty. The first
    mismatch is logged and yields False. Any exception (e.g. a non-dict
    response) is also logged and yields False.
    """
    try:
        exact_checks = (
            ("status", expected_status, "Expected status '{}', got '{}'"),
            ("step_number", expected_step, "Expected step_number {}, got {}"),
            ("total_steps", expected_total, "Expected total_steps {}, got {}"),
            ("next_step_required", expected_next_required, "Expected next_step_required {}, got {}"),
        )
        for field_name, wanted, message_template in exact_checks:
            actual = response_data.get(field_name)
            if actual != wanted:
                self.logger.error(message_template.format(wanted, actual))
                return False

        if "test_generation_status" not in response_data:
            self.logger.error("Missing test_generation_status in response")
            return False

        if not response_data.get("next_steps"):
            self.logger.error("Missing next_steps guidance in response")
            return False

        return True

    except Exception as e:
        self.logger.error(f"Error validating step response: {e}")
        return False
515 |
def _test_context_aware_file_embedding(self) -> bool:
    """Verify that testgen references files mid-workflow and embeds them on the final step.

    Two small helper modules are written to disk, then two ``testgen`` calls are issued:

    * an intermediate step on a fresh conversation (``next_step_required=True``), whose
      response must report a ``file_context`` of type ``reference_only``;
    * a final step on the same continuation (``next_step_required=False``), whose response
      must report ``fully_embedded`` file context and status ``calling_expert_analysis``.

    Returns:
        True when every expectation holds, False otherwise. Unexpected exceptions are
        logged and reported as False.
    """
    try:
        self.logger.info(" 1.5: Testing context-aware file embedding")

        # Fixture sources for the two modules under test.
        utils_source = """#!/usr/bin/env python3
def validate_number(n):
    \"\"\"Validate if input is a number\"\"\"
    return isinstance(n, (int, float))

def format_result(result):
    \"\"\"Format calculation result\"\"\"
    if isinstance(result, float):
        return round(result, 2)
    return result
"""

        math_source = """#!/usr/bin/env python3
import math

def factorial(n):
    \"\"\"Calculate factorial of n\"\"\"
    if n < 0:
        raise ValueError("Factorial not defined for negative numbers")
    return math.factorial(n)

def is_prime(n):
    \"\"\"Check if number is prime\"\"\"
    if n < 2:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
"""

        utils_path = self.create_additional_test_file("utils.py", utils_source)
        math_path = self.create_additional_test_file("math_helpers.py", math_source)

        # --- Phase 1: intermediate step, files should only be referenced -------------
        self.logger.info(" 1.5.1: New conversation intermediate step (should reference only)")
        intermediate_request = {
            "step": "Starting test generation for utility modules",
            "step_number": 1,
            "total_steps": 3,
            "next_step_required": True,  # not the last step -> reference, don't embed
            "findings": "Initial analysis of utility functions",
            "files_checked": [utils_path, math_path],
            "relevant_files": [utils_path],
            "relevant_context": ["validate_number", "format_result"],
            "confidence": "low",
            "model": "flash",
        }
        first_raw, thread_id = self.call_mcp_tool("testgen", intermediate_request)

        if not (first_raw and thread_id):
            self.logger.error("Failed to start context-aware file embedding test")
            return False

        first = self._parse_testgen_response(first_raw)
        if not first:
            return False

        first_context_type = first.get("file_context", {}).get("type")
        if first_context_type != "reference_only":
            self.logger.error(f"Expected reference_only file context, got: {first_context_type}")
            return False

        self.logger.info(" ✅ Intermediate step correctly uses reference_only file context")

        # --- Phase 2: final step, files should be embedded for expert analysis -------
        self.logger.info(" 1.5.2: Final step (should embed files)")
        final_request = {
            "step": "Test planning complete - all test scenarios identified",
            "step_number": 2,
            "total_steps": 2,
            "next_step_required": False,  # last step -> embed full file contents
            "continuation_id": thread_id,
            "findings": "Complete test plan for all utility functions with edge cases",
            "files_checked": [utils_path, math_path],
            "relevant_files": [utils_path, math_path],
            "relevant_context": ["validate_number", "format_result", "factorial", "is_prime"],
            "confidence": "high",
            "model": "flash",
        }
        final_raw, _ = self.call_mcp_tool("testgen", final_request)

        if not final_raw:
            self.logger.error("Failed to complete to final step")
            return False

        final = self._parse_testgen_response(final_raw)
        if not final:
            return False

        final_context_type = final.get("file_context", {}).get("type")
        if final_context_type != "fully_embedded":
            self.logger.error(
                f"Expected fully_embedded file context for final step, got: {final_context_type}"
            )
            return False

        # The final step must hand off to the expert model.
        if final.get("status") != "calling_expert_analysis":
            self.logger.error("Final step should trigger expert analysis")
            return False

        self.logger.info(" ✅ Context-aware file embedding test completed successfully")
        return True

    except Exception as e:
        self.logger.error(f"Context-aware file embedding test failed: {e}")
        return False
637 |
def _test_multi_step_test_planning(self) -> bool:
    """Drive a four-step testgen workflow over an async ``DataProcessor`` fixture.

    Steps 1-3 are intermediate (``next_step_required=True``) and step 4 is final. Checks:

    * step 1 must report a ``reference_only`` file context;
    * steps 2 and 3 only need to return a response;
    * step 4 must report status ``calling_expert_analysis`` and a ``fully_embedded``
      file context.

    Returns:
        True if every check passes, False otherwise. Failures are logged, and unexpected
        exceptions are logged and converted into False.
    """
    try:
        self.logger.info(" 1.6: Testing multi-step test planning")

        # Create a complex class to test
        complex_code = """#!/usr/bin/env python3
import asyncio
from typing import List, Dict, Optional

class DataProcessor:
    \"\"\"Complex data processor with async operations\"\"\"

    def __init__(self, batch_size: int = 100):
        self.batch_size = batch_size
        self.processed_count = 0
        self.error_count = 0
        self.cache: Dict[str, any] = {}

    async def process_batch(self, items: List[dict]) -> List[dict]:
        \"\"\"Process a batch of items asynchronously\"\"\"
        if not items:
            return []

        if len(items) > self.batch_size:
            raise ValueError(f"Batch size {len(items)} exceeds limit {self.batch_size}")

        results = []
        for item in items:
            try:
                result = await self._process_single_item(item)
                results.append(result)
                self.processed_count += 1
            except Exception as e:
                self.error_count += 1
                results.append({"error": str(e), "item": item})

        return results

    async def _process_single_item(self, item: dict) -> dict:
        \"\"\"Process a single item with caching\"\"\"
        item_id = item.get('id')
        if not item_id:
            raise ValueError("Item must have an ID")

        # Check cache
        if item_id in self.cache:
            return self.cache[item_id]

        # Simulate async processing
        await asyncio.sleep(0.01)

        processed = {
            'id': item_id,
            'processed': True,
            'value': item.get('value', 0) * 2
        }

        # Cache result
        self.cache[item_id] = processed
        return processed

    def get_stats(self) -> Dict[str, int]:
        \"\"\"Get processing statistics\"\"\"
        return {
            'processed': self.processed_count,
            'errors': self.error_count,
            'cache_size': len(self.cache),
            'success_rate': self.processed_count / (self.processed_count + self.error_count) if (self.processed_count + self.error_count) > 0 else 0
        }
"""

        # Create test file
        processor_file = self.create_additional_test_file("data_processor.py", complex_code)

        # Step 1: Start investigation
        self.logger.info(" 1.6.1: Step 1 - Start complex test planning")
        response1, continuation_id = self.call_mcp_tool(
            "testgen",
            {
                "step": "Analyzing complex DataProcessor class for comprehensive test generation",
                "step_number": 1,
                "total_steps": 4,
                "next_step_required": True,
                "findings": "DataProcessor is an async class with caching, error handling, and statistics. Need async test patterns.",
                "files_checked": [processor_file],
                "relevant_files": [processor_file],
                "relevant_context": ["DataProcessor", "process_batch", "_process_single_item", "get_stats"],
                "confidence": "low",
                "model": "flash",
            },
        )

        if not response1 or not continuation_id:
            self.logger.error("Failed to start multi-step test planning")
            return False

        response1_data = self._parse_testgen_response(response1)
        # Guard against an unparseable response, the same way the other testgen checks do.
        # Without this guard, calling `.get` on None raises AttributeError. The real cause
        # is then only visible through the generic exception handler below.
        if not response1_data:
            return False

        # Validate step 1: intermediate steps must only reference files
        file_context1 = response1_data.get("file_context", {})
        if file_context1.get("type") != "reference_only":
            self.logger.error("Step 1 should use reference_only file context")
            return False

        self.logger.info(" ✅ Step 1: Started complex test planning")

        # Step 2: Analyze async patterns
        self.logger.info(" 1.6.2: Step 2 - Async pattern analysis")
        response2, _ = self.call_mcp_tool(
            "testgen",
            {
                "step": "Analyzing async patterns and edge cases for testing",
                "step_number": 2,
                "total_steps": 4,
                "next_step_required": True,
                "continuation_id": continuation_id,
                "findings": "Key test areas: async batch processing, cache behavior, error handling, batch size limits, empty items, statistics calculation",
                "files_checked": [processor_file],
                "relevant_files": [processor_file],
                "relevant_context": ["process_batch", "_process_single_item"],
                "confidence": "medium",
                "model": "flash",
            },
        )

        if not response2:
            self.logger.error("Failed to continue to step 2")
            return False

        self.logger.info(" ✅ Step 2: Async patterns analyzed")

        # Step 3: Edge case identification
        self.logger.info(" 1.6.3: Step 3 - Edge case identification")
        response3, _ = self.call_mcp_tool(
            "testgen",
            {
                "step": "Identifying all edge cases and boundary conditions",
                "step_number": 3,
                "total_steps": 4,
                "next_step_required": True,
                "continuation_id": continuation_id,
                "findings": "Edge cases: empty batch, oversized batch, items without ID, cache hits/misses, concurrent processing, error accumulation",
                "files_checked": [processor_file],
                "relevant_files": [processor_file],
                "confidence": "high",
                "model": "flash",
            },
        )

        if not response3:
            self.logger.error("Failed to continue to step 3")
            return False

        self.logger.info(" ✅ Step 3: Edge cases identified")

        # Step 4: Final test plan with expert analysis
        self.logger.info(" 1.6.4: Step 4 - Complete test plan")
        response4, _ = self.call_mcp_tool(
            "testgen",
            {
                "step": "Test planning complete with comprehensive coverage strategy",
                "step_number": 4,
                "total_steps": 4,
                "next_step_required": False,  # Final step
                "continuation_id": continuation_id,
                "findings": "Complete async test suite plan: unit tests for each method, integration tests for batch processing, edge case coverage, performance tests",
                "files_checked": [processor_file],
                "relevant_files": [processor_file],
                "confidence": "high",
                "model": "flash",
            },
        )

        if not response4:
            self.logger.error("Failed to complete to final step")
            return False

        response4_data = self._parse_testgen_response(response4)
        # Same guard as step 1: an unparseable final response is a plain failure.
        if not response4_data:
            return False

        # Validate final step: must hand off to expert analysis with files embedded
        if response4_data.get("status") != "calling_expert_analysis":
            self.logger.error("Final step should trigger expert analysis")
            return False

        file_context4 = response4_data.get("file_context", {})
        if file_context4.get("type") != "fully_embedded":
            self.logger.error("Final step should use fully_embedded file context")
            return False

        self.logger.info(" ✅ Multi-step test planning completed successfully")
        return True

    except Exception as e:
        self.logger.error(f"Multi-step test planning test failed: {e}")
        return False
834 |
```
--------------------------------------------------------------------------------
/tests/test_model_restrictions.py:
--------------------------------------------------------------------------------
```python
1 | """Tests for model restriction functionality."""
2 |
3 | import os
4 | from unittest.mock import MagicMock, patch
5 |
6 | import pytest
7 |
8 | from providers.gemini import GeminiModelProvider
9 | from providers.openai import OpenAIModelProvider
10 | from providers.shared import ProviderType
11 | from utils.model_restrictions import ModelRestrictionService
12 |
13 |
class TestModelRestrictionService:
    """Test cases for ModelRestrictionService.

    Tests that set allow-lists build their environment with ``_restriction_env`` and patch
    ``os.environ`` with ``clear=True``. This stops ``*_ALLOWED_MODELS`` values exported in
    the developer's shell from leaking into assertions about providers that are meant to be
    unrestricted.
    """

    @staticmethod
    def _restriction_env(overrides):
        """Return a copy of ``os.environ`` with every ``*_ALLOWED_MODELS`` variable removed,
        then ``overrides`` applied on top.

        Meant to be used as ``patch.dict(os.environ, env, clear=True)``. Only the
        restrictions named in ``overrides`` are active, while unrelated variables (API keys,
        paths, ...) are kept. ``patch.dict`` restores the original environment on exit.
        """
        env = {key: value for key, value in os.environ.items() if not key.endswith("_ALLOWED_MODELS")}
        env.update(overrides)
        return env

    def test_no_restrictions_by_default(self):
        """Test that no restrictions exist when env vars are not set."""
        with patch.dict(os.environ, {}, clear=True):
            service = ModelRestrictionService()

            # Should allow all models
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")
            assert service.is_allowed(ProviderType.OPENROUTER, "openai/o3")

            # Should have no restrictions
            assert not service.has_restrictions(ProviderType.OPENAI)
            assert not service.has_restrictions(ProviderType.GOOGLE)
            assert not service.has_restrictions(ProviderType.OPENROUTER)

    def test_load_single_model_restriction(self):
        """Test loading a single allowed model."""
        env = self._restriction_env({"OPENAI_ALLOWED_MODELS": "o3-mini"})
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            # Should only allow o3-mini
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")
            assert not service.is_allowed(ProviderType.OPENAI, "o4-mini")

            # Google and OpenRouter should have no restrictions
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")

    def test_load_multiple_models_restriction(self):
        """Test loading multiple allowed models."""
        env = self._restriction_env({"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"})
        with patch.dict(os.environ, env, clear=True):
            # Instantiate providers so alias resolution for allow-lists is available
            openai_provider = OpenAIModelProvider(api_key="test-key")
            gemini_provider = GeminiModelProvider(api_key="test-key")

            from providers.registry import ModelProviderRegistry

            def fake_get_provider(provider_type, force_new=False):
                mapping = {
                    ProviderType.OPENAI: openai_provider,
                    ProviderType.GOOGLE: gemini_provider,
                }
                return mapping.get(provider_type)

            with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):

                service = ModelRestrictionService()

                # Check OpenAI models
                assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
                assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
                assert not service.is_allowed(ProviderType.OPENAI, "o3")

                # Check Google models
                assert service.is_allowed(ProviderType.GOOGLE, "flash")
                assert service.is_allowed(ProviderType.GOOGLE, "pro")
                assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")

    def test_case_insensitive_and_whitespace_handling(self):
        """Test that model names are case-insensitive and whitespace is trimmed."""
        env = self._restriction_env({"OPENAI_ALLOWED_MODELS": " O3-MINI , o4-Mini "})
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            # Should work with any case
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.OPENAI, "O3-MINI")
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
            assert service.is_allowed(ProviderType.OPENAI, "O4-Mini")

    def test_empty_string_allows_all(self):
        """Test that empty string allows all models (same as unset)."""
        env = self._restriction_env({"OPENAI_ALLOWED_MODELS": "", "GOOGLE_ALLOWED_MODELS": "flash"})
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            # OpenAI should allow all models (empty string = no restrictions)
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")

            # Google should only allow flash (and its resolved name)
            assert service.is_allowed(ProviderType.GOOGLE, "flash")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
            assert not service.is_allowed(ProviderType.GOOGLE, "pro")
            assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")

    def test_filter_models(self):
        """Test filtering a list of models based on restrictions."""
        env = self._restriction_env({"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"})
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            models = ["o3", "o3-mini", "o4-mini", "o3-pro"]
            filtered = service.filter_models(ProviderType.OPENAI, models)

            assert filtered == ["o3-mini", "o4-mini"]

    def test_get_allowed_models(self):
        """Test getting the set of allowed models."""
        env = self._restriction_env({"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"})
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            allowed = service.get_allowed_models(ProviderType.OPENAI)
            assert allowed == {"o3-mini", "o4-mini"}

            # No restrictions for Google
            assert service.get_allowed_models(ProviderType.GOOGLE) is None

    def test_shorthand_names_in_restrictions(self):
        """Test that shorthand names work in restrictions."""
        env = self._restriction_env({"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"})
        with patch.dict(os.environ, env, clear=True):
            # Instantiate providers so the registry can resolve aliases
            OpenAIModelProvider(api_key="test-key")
            GeminiModelProvider(api_key="test-key")

            service = ModelRestrictionService()

            # When providers check models, they pass both resolved and original names
            # OpenAI: 'o4mini' shorthand allows o4-mini
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini")  # How providers actually call it
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")  # Canonical should also be allowed

            # OpenAI: 'o3mini' shorthand allows the canonical o3-mini; o3 stays blocked
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")

            # Google should allow both models via shorthands
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")

            # Resolved name plus shorthand original (the provider call shape) is allowed too
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini", "o3mini")

    def test_validation_against_known_models(self, caplog):
        """Test validation warnings for unknown models."""
        env = self._restriction_env({"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mimi"})  # Note the typo: o4-mimi
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            # Create mock provider with known models
            mock_provider = MagicMock()
            mock_provider.MODEL_CAPABILITIES = {
                "o3": {"context_window": 200000},
                "o3-mini": {"context_window": 200000},
                "o4-mini": {"context_window": 200000},
            }
            mock_provider.list_models.return_value = ["o3", "o3-mini", "o4-mini"]

            provider_instances = {ProviderType.OPENAI: mock_provider}
            service.validate_against_known_models(provider_instances)

            # Should have logged a warning about the typo
            assert "o4-mimi" in caplog.text
            assert "not a recognized" in caplog.text

    def test_openrouter_model_restrictions(self):
        """Test OpenRouter model restrictions functionality."""
        env = self._restriction_env({"OPENROUTER_ALLOWED_MODELS": "opus,sonnet"})
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            # Should only allow specified OpenRouter models
            assert service.is_allowed(ProviderType.OPENROUTER, "opus")
            assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4", "opus")  # With original name
            assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")
            assert not service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-haiku")
            assert not service.is_allowed(ProviderType.OPENROUTER, "mistral-large")

            # Other providers should have no restrictions
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.GOOGLE, "pro")

            # Should have restrictions for OpenRouter
            assert service.has_restrictions(ProviderType.OPENROUTER)
            assert not service.has_restrictions(ProviderType.OPENAI)
            assert not service.has_restrictions(ProviderType.GOOGLE)

    def test_openrouter_filter_models(self):
        """Test filtering OpenRouter models based on restrictions."""
        env = self._restriction_env({"OPENROUTER_ALLOWED_MODELS": "opus,mistral"})
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            models = ["opus", "sonnet", "haiku", "mistral", "llama"]
            filtered = service.filter_models(ProviderType.OPENROUTER, models)

            assert filtered == ["opus", "mistral"]

    def test_combined_provider_restrictions(self):
        """Test that restrictions work correctly when set for multiple providers."""
        env = self._restriction_env(
            {
                "OPENAI_ALLOWED_MODELS": "o3-mini",
                "GOOGLE_ALLOWED_MODELS": "flash",
                "OPENROUTER_ALLOWED_MODELS": "opus,sonnet",
            }
        )
        with patch.dict(os.environ, env, clear=True):
            service = ModelRestrictionService()

            # OpenAI restrictions
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")

            # Google restrictions
            assert service.is_allowed(ProviderType.GOOGLE, "flash")
            assert not service.is_allowed(ProviderType.GOOGLE, "pro")

            # OpenRouter restrictions
            assert service.is_allowed(ProviderType.OPENROUTER, "opus")
            assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
            assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")

            # All providers should have restrictions
            assert service.has_restrictions(ProviderType.OPENAI)
            assert service.has_restrictions(ProviderType.GOOGLE)
            assert service.has_restrictions(ProviderType.OPENROUTER)
234 |
235 |
class TestProviderIntegration:
    """Test integration with actual providers."""

    @pytest.fixture(autouse=True)
    def _discard_cached_restriction_service(self):
        """Clear the module-level restriction-service cache after every test.

        Each test resets ``utils.model_restrictions._restriction_service`` to None while its
        ``*_ALLOWED_MODELS`` patch is active, so the service is rebuilt from that patched
        allow-list. Without this teardown, the rebuilt service stays cached after
        ``patch.dict`` restores the environment. Its allow-list would then presumably keep
        applying to later tests in the same session.
        """
        yield
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini"})
    def test_openai_provider_respects_restrictions(self):
        """Test that OpenAI provider respects restrictions."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = OpenAIModelProvider(api_key="test-key")

        # Should validate allowed model
        assert provider.validate_model_name("o3-mini")

        # Should not validate disallowed model
        assert not provider.validate_model_name("o3")

        # get_capabilities should raise for disallowed model
        with pytest.raises(ValueError) as exc_info:
            provider.get_capabilities("o3")
        assert "not allowed by restriction policy" in str(exc_info.value)

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash,flash"})
    def test_gemini_provider_respects_restrictions(self):
        """Test that Gemini provider respects restrictions."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = GeminiModelProvider(api_key="test-key")

        # Should validate allowed models (both shorthand and full name allowed)
        assert provider.validate_model_name("flash")
        assert provider.validate_model_name("gemini-2.5-flash")

        # Should not validate disallowed model
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")

        # get_capabilities should raise for disallowed model
        with pytest.raises(ValueError) as exc_info:
            provider.get_capabilities("pro")
        assert "not allowed by restriction policy" in str(exc_info.value)

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"})
    def test_gemini_parameter_order_regression_protection(self):
        """Test that prevents regression of parameter order bug in is_allowed calls.

        This test specifically catches the bug where parameters were incorrectly
        passed as (provider, user_input, resolved_name) instead of
        (provider, resolved_name, user_input).

        The bug was subtle because the is_allowed method uses OR logic, so it
        worked in most cases by accident. This test creates a scenario where
        the parameter order matters.
        """
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = GeminiModelProvider(api_key="test-key")

        from providers.registry import ModelProviderRegistry

        with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):

            # Test case: Only alias "flash" is allowed, not the full name
            # If parameters are in wrong order, this test will catch it

            # Should allow "flash" alias
            assert provider.validate_model_name("flash")

            # Should allow getting capabilities for "flash"
            capabilities = provider.get_capabilities("flash")
            assert capabilities.model_name == "gemini-2.5-flash"

            # Canonical form should also be allowed now that alias is on the allowlist
            assert provider.validate_model_name("gemini-2.5-flash")
            # Unrelated models remain blocked
            assert not provider.validate_model_name("pro")
            assert not provider.validate_model_name("gemini-2.5-pro")

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
    def test_gemini_parameter_order_edge_case_full_name_only(self):
        """Test parameter order with only full name allowed, not alias.

        This is the reverse scenario - only the full canonical name is allowed,
        not the shorthand alias. This tests that the parameter order is correct
        when resolving aliases.
        """
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = GeminiModelProvider(api_key="test-key")

        # Should allow full name
        assert provider.validate_model_name("gemini-2.5-flash")

        # Should also allow alias that resolves to allowed full name
        # This works because is_allowed checks both resolved_name and original_name
        assert provider.validate_model_name("flash")

        # Should not allow "pro" alias
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")
348 |
class TestCustomProviderOpenRouterRestrictions:
    """Test custom provider integration with OpenRouter restrictions."""

    @pytest.fixture(autouse=True)
    def _discard_cached_restriction_service(self):
        """Clear the module-level restriction-service cache after every test.

        The tests reset ``utils.model_restrictions._restriction_service`` under a patched
        ``OPENROUTER_ALLOWED_MODELS``. Dropping the cache on teardown keeps that patched
        allow-list from outliving the test after ``patch.dict`` restores the environment.
        """
        yield
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_respects_openrouter_restrictions(self):
        """Test that custom provider correctly defers OpenRouter models to OpenRouter provider."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # CustomProvider should NOT validate OpenRouter models - they should be deferred to OpenRouter
        assert not provider.validate_model_name("opus")
        assert not provider.validate_model_name("sonnet")
        assert not provider.validate_model_name("haiku")

        # Should still validate custom models defined in conf/custom_models.json
        assert provider.validate_model_name("local-llama")

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_openrouter_capabilities_restrictions(self):
        """Test that custom provider's get_capabilities correctly handles OpenRouter models."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # For OpenRouter models, CustomProvider should defer by raising
        with pytest.raises(ValueError):
            provider.get_capabilities("opus")

        # Should raise for disallowed OpenRouter model (still defers)
        with pytest.raises(ValueError):
            provider.get_capabilities("haiku")

        # Should still work for custom models
        capabilities = provider.get_capabilities("local-llama")
        assert capabilities.provider == ProviderType.CUSTOM

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus"}, clear=False)
    def test_custom_provider_no_openrouter_key_ignores_restrictions(self):
        """Test that when OpenRouter key is not set, cloud models are rejected regardless of restrictions."""
        # Make sure OPENROUTER_API_KEY is not set (patch.dict restores it on exit)
        os.environ.pop("OPENROUTER_API_KEY", None)
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # Should not validate OpenRouter models when key is not available
        assert not provider.validate_model_name("opus")  # Even though it's in allowed list
        assert not provider.validate_model_name("haiku")

        # Should still validate custom models
        assert provider.validate_model_name("local-llama")

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_empty_restrictions_allows_all_openrouter(self):
        """Test that custom provider correctly defers OpenRouter models regardless of restrictions."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # CustomProvider should NOT validate OpenRouter models - they should be deferred to OpenRouter
        assert not provider.validate_model_name("opus")
        assert not provider.validate_model_name("sonnet")
        assert not provider.validate_model_name("haiku")
435 |
class TestRegistryIntegration:
    """Test integration with ModelProviderRegistry."""

    def teardown_method(self, method):
        """Drop any restriction service cached while a test's environment was patched."""
        import utils.model_restrictions

        # Tests below patch *_ALLOWED_MODELS; without this reset the cached service
        # would keep enforcing those allow-lists in unrelated later tests.
        utils.model_restrictions._restriction_service = None

    @staticmethod
    def _make_list_models(mock_provider, provider_type):
        """Build a ``list_models`` stand-in for a mocked provider.

        The returned callable mirrors the keyword-only signature of the real
        ``list_models``: it walks ``mock_provider.MODEL_CAPABILITIES`` at call time
        and filters entries through the restriction service for ``provider_type``.
        String values are treated as aliases (filtered on their target, listed
        under the alias name when ``include_aliases`` is true).
        """

        def list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
            models = []
            for model_name, config in mock_provider.MODEL_CAPABILITIES.items():
                is_alias = isinstance(config, str)
                target_model = config if is_alias else model_name
                if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
                    continue
                if is_alias and not include_aliases:
                    continue
                models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                # dict preserves insertion order, so this keeps the first occurrence.
                models = list(dict.fromkeys(models))
            return models

        return list_models

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GOOGLE_ALLOWED_MODELS": "flash"})
    def test_registry_with_shorthand_restrictions(self):
        """Smoke-test ``get_available_models`` with shorthand-only allow-lists.

        The result is deliberately not asserted: this only guards against the call
        raising when the allow-lists hold aliases ("mini", "flash") rather than
        canonical model names.
        """
        # Clear cached restriction service so the patched environment is read
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.registry import ModelProviderRegistry

        # Clear registry cache
        ModelProviderRegistry.clear_cache()

        # NOTE(review): whether alias entries are expanded to canonical names here
        # (e.g. "mini" -> gpt-5-mini) is not pinned down; add an assertion once the
        # intended get_available_models behaviour for aliases is settled.
        ModelProviderRegistry.get_available_models(respect_restrictions=True)

    @patch("providers.registry.ModelProviderRegistry.get_provider")
    def test_get_available_models_respects_restrictions(self, mock_get_provider):
        """Test that registry filters models based on restrictions."""
        from providers.registry import ModelProviderRegistry

        # Mock providers
        mock_openai = MagicMock()
        mock_openai.MODEL_CAPABILITIES = {
            "o3": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI
        mock_openai.list_models = MagicMock(side_effect=self._make_list_models(mock_openai, ProviderType.OPENAI))

        mock_gemini = MagicMock()
        mock_gemini.MODEL_CAPABILITIES = {
            "gemini-2.5-pro": {"context_window": 1048576},
            "gemini-2.5-flash": {"context_window": 1048576},
        }
        mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE
        mock_gemini.list_models = MagicMock(side_effect=self._make_list_models(mock_gemini, ProviderType.GOOGLE))

        def get_provider_side_effect(provider_type):
            if provider_type == ProviderType.OPENAI:
                return mock_openai
            elif provider_type == ProviderType.GOOGLE:
                return mock_gemini
            return None

        mock_get_provider.side_effect = get_provider_side_effect

        # Swap mocked providers into the registry singleton, remembering the
        # originals so later tests do not inherit MagicMock provider classes.
        registry = ModelProviderRegistry()
        original_providers = registry._providers
        registry._providers = {
            ProviderType.OPENAI: type(mock_openai),
            ProviderType.GOOGLE: type(mock_gemini),
        }

        try:
            with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini", "GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}):
                # Clear cached restriction service
                import utils.model_restrictions

                utils.model_restrictions._restriction_service = None

                available = ModelProviderRegistry.get_available_models(respect_restrictions=True)

                # Should only include allowed models
                assert "o3-mini" in available
                assert "o3" not in available
                assert "gemini-2.5-flash" in available
                assert "gemini-2.5-pro" not in available
        finally:
            registry._providers = original_providers
583 |
584 |
class TestShorthandRestrictions:
    """Test that shorthand model names work correctly in restrictions."""

    def teardown_method(self, method):
        """Drop the restriction service cached under this test's patched environment."""
        import utils.model_restrictions

        # The decorators below restore os.environ, but not the cached service built
        # from it; reset it so later tests start from a clean slate.
        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GOOGLE_ALLOWED_MODELS": "flash"})
    def test_providers_validate_shorthands_correctly(self):
        """Test that providers correctly validate shorthand names."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # Test OpenAI provider
        openai_provider = OpenAIModelProvider(api_key="test-key")
        gemini_provider = GeminiModelProvider(api_key="test-key")

        from providers.registry import ModelProviderRegistry

        def registry_side_effect(provider_type, force_new=False):
            mapping = {
                ProviderType.OPENAI: openai_provider,
                ProviderType.GOOGLE: gemini_provider,
            }
            return mapping.get(provider_type)

        with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
            assert openai_provider.validate_model_name("mini")  # Should work with shorthand
            assert openai_provider.validate_model_name("gpt-5-mini")  # Canonical resolved from shorthand
            assert not openai_provider.validate_model_name("o4-mini")  # Unrelated model still blocked
            assert not openai_provider.validate_model_name("o3-mini")

            # Test Gemini provider
            assert gemini_provider.validate_model_name("flash")  # Should work with shorthand
            assert gemini_provider.validate_model_name("gemini-2.5-flash")  # Canonical allowed
            assert not gemini_provider.validate_model_name("pro")  # Not allowed

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
    def test_multiple_shorthands_for_same_model(self):
        """Test that several shorthands in one allow-list each unlock their own model."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        openai_provider = OpenAIModelProvider(api_key="test-key")

        # Both shorthands should work
        assert openai_provider.validate_model_name("mini")  # mini -> gpt-5-mini
        assert openai_provider.validate_model_name("o3mini")  # o3mini -> o3-mini

        # o4-mini is listed explicitly; o3-mini is reachable through the "o3mini" shorthand
        assert openai_provider.validate_model_name("o4-mini")  # Explicitly allowed
        assert openai_provider.validate_model_name("o3-mini")  # Allowed via shorthand

        # Other models should not work
        assert not openai_provider.validate_model_name("o3")
        assert not openai_provider.validate_model_name("o3-pro")

    @patch.dict(
        os.environ,
        {"OPENAI_ALLOWED_MODELS": "mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,gemini-2.5-flash"},
    )
    def test_both_shorthand_and_full_name_allowed(self):
        """Test that we can allow both shorthand and full names."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # OpenAI - both mini and o4-mini are allowed
        openai_provider = OpenAIModelProvider(api_key="test-key")
        assert openai_provider.validate_model_name("mini")
        assert openai_provider.validate_model_name("o4-mini")

        # Gemini - both flash and full name are allowed
        gemini_provider = GeminiModelProvider(api_key="test-key")
        assert gemini_provider.validate_model_name("flash")
        assert gemini_provider.validate_model_name("gemini-2.5-flash")
662 |
663 |
class TestAutoModeWithRestrictions:
    """Test auto mode behavior with restrictions."""

    def teardown_method(self, method):
        """Drop the restriction service cached while a test's environment was patched."""
        import utils.model_restrictions

        # patch.dict / monkeypatch restore os.environ but not the cached service.
        utils.model_restrictions._restriction_service = None

    @patch("providers.registry.ModelProviderRegistry.get_provider")
    def test_fallback_model_respects_restrictions(self, mock_get_provider):
        """Test that fallback model selection respects restrictions."""
        from providers.registry import ModelProviderRegistry
        from tools.models import ToolModelCategory

        # Mock providers
        mock_openai = MagicMock()
        mock_openai.MODEL_CAPABILITIES = {
            "o3": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
            "o4-mini": {"context_window": 200000},
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI

        def openai_list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
            models = []
            for model_name, config in mock_openai.MODEL_CAPABILITIES.items():
                if isinstance(config, str):
                    # Alias entry: filter on its target, list it under the alias name.
                    target_model = config
                    if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, target_model):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(ProviderType.OPENAI, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                models = list(dict.fromkeys(models))
            return models

        mock_openai.list_models = MagicMock(side_effect=openai_list_models)

        # Add get_preferred_model method to mock to match new implementation
        def get_preferred_model(category, allowed_models):
            # Simple preference logic for testing - just return first allowed model
            return allowed_models[0] if allowed_models else None

        mock_openai.get_preferred_model = get_preferred_model

        def get_provider_side_effect(provider_type):
            if provider_type == ProviderType.OPENAI:
                return mock_openai
            return None

        mock_get_provider.side_effect = get_provider_side_effect

        # Swap the mocked provider into the registry singleton and restore the
        # original mapping afterwards so later tests are unaffected.
        registry = ModelProviderRegistry()
        original_providers = registry._providers
        registry._providers = {ProviderType.OPENAI: type(mock_openai)}

        try:
            with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}):
                # Clear cached restriction service
                import utils.model_restrictions

                utils.model_restrictions._restriction_service = None

                # o4-mini is the only allowed OpenAI model, so it must be chosen
                # over o3 / o3-mini even though those are registered.
                model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
                assert model == "o4-mini"
        finally:
            registry._providers = original_providers

    def test_fallback_with_shorthand_restrictions(self, monkeypatch):
        """Test fallback model selection with shorthand restrictions."""
        # Use monkeypatch to set environment variables with automatic cleanup
        monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "mini")
        monkeypatch.setenv("GEMINI_API_KEY", "")
        monkeypatch.setenv("OPENAI_API_KEY", "test-key")

        # Clear caches and reset registry
        import utils.model_restrictions
        from providers.registry import ModelProviderRegistry
        from tools.models import ToolModelCategory

        utils.model_restrictions._restriction_service = None

        # Store original providers for restoration
        registry = ModelProviderRegistry()
        original_providers = registry._providers.copy()
        original_initialized = registry._initialized_providers.copy()

        try:
            # Clear registry and register only OpenAI and Gemini providers
            ModelProviderRegistry._instance = None
            from providers.gemini import GeminiModelProvider
            from providers.openai import OpenAIModelProvider

            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Even with "mini" restriction, fallback should work if provider handles it correctly
            # This tests the real-world scenario
            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

            # The fallback depends on how get_available_models handles aliases:
            # "mini" (an alias for gpt-5-mini) may be reported as-is or resolved.
            assert model in ["mini", "gpt-5-mini", "o4-mini", "gemini-2.5-flash"]
        finally:
            # Restore original registry state
            registry = ModelProviderRegistry()
            registry._providers.clear()
            registry._initialized_providers.clear()
            registry._providers.update(original_providers)
            registry._initialized_providers.update(original_initialized)
790 |
```
--------------------------------------------------------------------------------
/providers/openai_compatible.py:
--------------------------------------------------------------------------------
```python
1 | """Base class for OpenAI-compatible API providers."""
2 |
3 | import copy
4 | import ipaddress
5 | import logging
6 | from typing import Optional
7 | from urllib.parse import urlparse
8 |
9 | from openai import OpenAI
10 |
11 | from utils.env import get_env, suppress_env_vars
12 | from utils.image_utils import validate_image
13 |
14 | from .base import ModelProvider
15 | from .shared import (
16 | ModelCapabilities,
17 | ModelResponse,
18 | ProviderType,
19 | )
20 |
21 |
22 | class OpenAICompatibleProvider(ModelProvider):
23 | """Shared implementation for OpenAI API lookalikes.
24 |
25 | The class owns HTTP client configuration (timeouts, proxy hardening,
26 | custom headers) and normalises the OpenAI SDK responses into
27 | :class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
28 | provide capability metadata and any provider-specific request tweaks.
29 | """
30 |
31 | DEFAULT_HEADERS = {}
32 | FRIENDLY_NAME = "OpenAI Compatible"
33 |
def __init__(self, api_key: str, base_url: str = None, **kwargs):
    """Configure credentials, endpoint, allow-list and HTTP timeouts.

    Args:
        api_key: Key used to authenticate requests (may be empty).
        base_url: Optional API endpoint. When given it is validated, and a
            warning is logged if it is non-local and no ``api_key`` was supplied.
        **kwargs: Forwarded to the base provider. ``organization`` is read here,
            and the whole mapping is passed to ``_configure_timeouts`` (which
            honours ``connect_timeout``/``read_timeout``/``write_timeout``/``pool_timeout``).
    """
    # The alias cache must exist before the allow-list is parsed, so it is
    # created ahead of everything else.
    self._allowed_alias_cache: dict[str, str] = {}
    super().__init__(api_key, **kwargs)

    self._client = None  # built lazily by the ``client`` property
    self.base_url = base_url
    self.organization = kwargs.get("organization")
    self.allowed_models = self._parse_allowed_models()
    self.timeout_config = self._configure_timeouts(**kwargs)

    if not self.base_url:
        return

    self._validate_base_url()

    # Talking to a remote host without credentials is allowed but suspicious.
    if not self._is_localhost_url() and not api_key:
        logging.warning(
            f"Using external URL '{self.base_url}' without API key. "
            "This may be insecure. Consider setting an API key for authentication."
        )
62 |
def _ensure_model_allowed(
    self,
    capabilities: ModelCapabilities,
    canonical_name: str,
    requested_name: str,
) -> None:
    """Apply the provider's own ``*_ALLOWED_MODELS`` list after the base-class checks.

    ``super()._ensure_model_allowed`` runs first. Then, if an allow-list is
    configured, the model passes when either the requested or the canonical
    name (case-insensitive) is listed, or when a listed entry resolves through
    ``_resolve_model_name`` to the same canonical name. Resolutions are memoised
    in ``_allowed_alias_cache``, and a canonical name matched that way is added
    to ``allowed_models`` so the next lookup is a direct hit.

    Raises:
        ValueError: If an allow-list is configured and the model is not on it
            (in addition to anything the base-class check raises).
    """
    super()._ensure_model_allowed(capabilities, canonical_name, requested_name)

    if self.allowed_models is None:
        return

    canonical = canonical_name.lower()
    if requested_name.lower() in self.allowed_models or canonical in self.allowed_models:
        return

    # Snapshot the entries: a successful match below mutates the set.
    for entry in list(self.allowed_models):
        resolved = self._allowed_alias_cache.get(entry)
        if resolved is None:
            try:
                target = self._resolve_model_name(entry)
            except Exception:
                # Entries that cannot be resolved simply never match.
                continue
            if not target:
                continue
            resolved = target.lower()
            self._allowed_alias_cache[entry] = resolved

        if resolved != canonical:
            continue

        # Alias matched: remember the canonical name for future lookups.
        self._allowed_alias_cache[canonical] = canonical
        self.allowed_models.add(canonical)
        return

    raise ValueError(
        f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
    )
105 |
def _parse_allowed_models(self) -> Optional[set[str]]:
    """Read this provider's model allow-list from the environment.

    The variable name is ``<PROVIDER>_ALLOWED_MODELS``, where ``<PROVIDER>`` is
    the upper-cased ``get_provider_type().value``. Its value is split on commas;
    entries are stripped, lower-cased and blank ones dropped. When a non-empty
    list is found, the alias cache is reset.

    Returns:
        The set of allowed (lower-case) model names, or None when the variable
        is unset or contains no usable entries.
    """
    env_var = f"{self.get_provider_type().value.upper()}_ALLOWED_MODELS"
    raw_value = get_env(env_var, "") or ""

    # An empty string splits to [""], which the filter discards -> empty set.
    allowed = {entry.strip().lower() for entry in raw_value.split(",") if entry.strip()}
    if allowed:
        logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(allowed)}")
        self._allowed_alias_cache = {}
        return allowed

    # Only proxy-style providers get a hint about how to restrict models.
    if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
        logging.info(
            f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. "
            f"To restrict access, set {env_var} with comma-separated model names."
        )
    return None
133 |
def _configure_timeouts(self, **kwargs):
    """Build the ``httpx.Timeout`` used by the OpenAI SDK client.

    Defaults depend on ``self.base_url``:

    * no ``base_url``: connect 30s, read/write/pool 600s
    * localhost / private-network ``base_url``: connect 60s, read/write/pool 1800s
      (local inference and extended-thinking models can be slow)
    * any other ``base_url``: connect 45s, read/write/pool 900s

    Each value can be overridden, in priority order, by the ``<kind>_timeout``
    keyword argument or the ``CUSTOM_<KIND>_TIMEOUT`` environment variable
    (``kind`` is connect/read/write/pool). A blank environment value is treated
    as unset.

    Returns:
        httpx.Timeout with the resolved connect/read/write/pool values (seconds).

    Raises:
        ValueError: If a ``CUSTOM_*_TIMEOUT`` variable holds a non-numeric value.
    """
    import httpx

    if self.base_url and self._is_localhost_url():
        defaults = {"connect": 60.0, "read": 1800.0, "write": 1800.0, "pool": 1800.0}
        logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
    elif self.base_url:
        defaults = {"connect": 45.0, "read": 900.0, "write": 900.0, "pool": 900.0}
        logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
    else:
        defaults = {"connect": 30.0, "read": 600.0, "write": 600.0, "pool": 600.0}

    def _resolve(kind: str):
        """Return the timeout for ``kind``: keyword argument, then env var, then default."""
        explicit = kwargs.get(f"{kind}_timeout")
        if explicit is not None:
            return explicit

        env_name = f"CUSTOM_{kind.upper()}_TIMEOUT"
        raw = get_env(env_name)
        # float("") would crash provider construction; an empty value means "unset".
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            return float(defaults[kind])
        try:
            return float(raw)
        except ValueError as exc:
            raise ValueError(f"Invalid {env_name} value {raw!r}: expected a number of seconds") from exc

    connect_timeout = _resolve("connect")
    read_timeout = _resolve("read")
    write_timeout = _resolve("write")
    pool_timeout = _resolve("pool")

    timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)

    logging.debug(
        f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
        f"Write: {write_timeout}s, Pool: {pool_timeout}s"
    )

    return timeout
196 |
197 | def _is_localhost_url(self) -> bool:
198 | """Check if the base URL points to localhost or local network.
199 |
200 | Returns:
201 | True if URL is localhost or local network, False otherwise
202 | """
203 | if not self.base_url:
204 | return False
205 |
206 | try:
207 | parsed = urlparse(self.base_url)
208 | hostname = parsed.hostname
209 |
210 | # Check for common localhost patterns
211 | if hostname in ["localhost", "127.0.0.1", "::1"]:
212 | return True
213 |
214 | # Check for private network ranges (local network)
215 | if hostname:
216 | try:
217 | ip = ipaddress.ip_address(hostname)
218 | return ip.is_private or ip.is_loopback
219 | except ValueError:
220 | # Not an IP address, might be a hostname
221 | pass
222 |
223 | return False
224 | except Exception:
225 | return False
226 |
227 | def _validate_base_url(self) -> None:
228 | """Validate base URL for security (SSRF protection).
229 |
230 | Raises:
231 | ValueError: If URL is invalid or potentially unsafe
232 | """
233 | if not self.base_url:
234 | return
235 |
236 | try:
237 | parsed = urlparse(self.base_url)
238 |
239 | # Check URL scheme - only allow http/https
240 | if parsed.scheme not in ("http", "https"):
241 | raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
242 |
243 | # Check hostname exists
244 | if not parsed.hostname:
245 | raise ValueError("URL must include a hostname")
246 |
247 | # Check port is valid (if specified)
248 | port = parsed.port
249 | if port is not None and (port < 1 or port > 65535):
250 | raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
251 | except Exception as e:
252 | if isinstance(e, ValueError):
253 | raise
254 | raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
255 |
@property
def client(self):
    """Return the OpenAI SDK client, creating it on first access.

    The client is built while the HTTP(S)/ALL proxy environment variables are
    suppressed. It wraps an ``httpx.Client`` that uses ``self.timeout_config``
    (or a 30s timeout when that is missing/empty), follows redirects, and uses
    ``self._test_transport`` when that attribute exists. ``base_url``,
    ``organization`` and a copy of ``DEFAULT_HEADERS`` are passed when set.

    If that construction fails, a minimal client built from only the API key
    and ``base_url`` is tried; if that also fails, its error propagates.
    """
    if self._client is not None:
        return self._client

    import httpx

    with suppress_env_vars("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"):
        try:
            if hasattr(self, "timeout_config") and self.timeout_config:
                timeout_config = self.timeout_config
            else:
                timeout_config = httpx.Timeout(30.0)

            # No proxy arguments on purpose (``proxies`` was removed in httpx 0.28.0).
            http_options = {"timeout": timeout_config, "follow_redirects": True}
            if hasattr(self, "_test_transport"):
                # Injected by tests for HTTP recording/replay.
                http_options["transport"] = self._test_transport
            http_client = httpx.Client(**http_options)

            client_kwargs = {"api_key": self.api_key, "http_client": http_client}
            if self.base_url:
                client_kwargs["base_url"] = self.base_url
            if self.organization:
                client_kwargs["organization"] = self.organization
            if self.DEFAULT_HEADERS:
                client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

            logging.debug(
                "OpenAI client initialized with custom httpx client and timeout: %s",
                timeout_config,
            )
            self._client = OpenAI(**client_kwargs)
        except Exception as e:
            logging.warning(
                "Failed to create client with custom httpx, falling back to minimal config: %s",
                e,
            )
            fallback_kwargs = {"api_key": self.api_key}
            if self.base_url:
                fallback_kwargs["base_url"] = self.base_url
            try:
                self._client = OpenAI(**fallback_kwargs)
            except Exception as fallback_error:
                logging.error("Even minimal OpenAI client creation failed: %s", fallback_error)
                raise

    return self._client
330 |
331 | def _sanitize_for_logging(self, params: dict) -> dict:
332 | """Sanitize sensitive data from parameters before logging.
333 |
334 | Args:
335 | params: Dictionary of API parameters
336 |
337 | Returns:
338 | dict: Sanitized copy of parameters safe for logging
339 | """
340 | sanitized = copy.deepcopy(params)
341 |
342 | # Sanitize messages content
343 | if "input" in sanitized:
344 | for msg in sanitized.get("input", []):
345 | if isinstance(msg, dict) and "content" in msg:
346 | for content_item in msg.get("content", []):
347 | if isinstance(content_item, dict) and "text" in content_item:
348 | # Truncate long text and add ellipsis
349 | text = content_item["text"]
350 | if len(text) > 100:
351 | content_item["text"] = text[:100] + "... [truncated]"
352 |
353 | # Remove any API keys that might be in headers/auth
354 | sanitized.pop("api_key", None)
355 | sanitized.pop("authorization", None)
356 |
357 | return sanitized
358 |
359 | def _safe_extract_output_text(self, response) -> str:
360 | """Safely extract output_text from o3-pro response with validation.
361 |
362 | Args:
363 | response: Response object from OpenAI SDK
364 |
365 | Returns:
366 | str: The output text content
367 |
368 | Raises:
369 | ValueError: If output_text is missing, None, or not a string
370 | """
371 | logging.debug(f"Response object type: {type(response)}")
372 | logging.debug(f"Response attributes: {dir(response)}")
373 |
374 | if not hasattr(response, "output_text"):
375 | raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
376 |
377 | content = response.output_text
378 | logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
379 |
380 | if content is None:
381 | raise ValueError("o3-pro returned None for output_text")
382 |
383 | if not isinstance(content, str):
384 | raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
385 |
386 | return content
387 |
def _generate_with_responses_endpoint(
    self,
    model_name: str,
    messages: list,
    temperature: float,
    max_output_tokens: Optional[int] = None,
    capabilities: Optional[ModelCapabilities] = None,
    **kwargs,
) -> ModelResponse:
    """Generate content via the ``/v1/responses`` endpoint used for reasoning models.

    Chat-style ``messages`` are converted into Responses API ``input`` items and
    sent with ``self.client.responses.create``, wrapped in ``_run_with_retries``.

    Args:
        model_name: Model identifier sent to the API.
        messages: Dicts with ``role`` and ``content``; system and user turns are
            sent as user input, assistant turns as assistant output, and any
            other role is dropped.
        temperature: Accepted for signature parity; it is not included in the request.
        max_output_tokens: Optional cap on generated tokens, sent as the Responses
            API ``max_output_tokens`` parameter when truthy.
        capabilities: Optional capabilities; a truthy ``default_reasoning_effort``
            replaces the ``"medium"`` default effort.
        **kwargs: Accepted for signature compatibility; not forwarded.

    Returns:
        ModelResponse whose metadata records ``"endpoint": "responses"``.

    Raises:
        RuntimeError: If the request ultimately fails (retry policy decided by
            ``_run_with_retries``); the underlying error is chained.
    """
    import json

    # Convert messages to the correct format for responses endpoint
    input_messages = []
    for message in messages:
        role = message.get("role", "")
        content = message.get("content", "")

        if role in ("system", "user"):
            # System content is sent as a plain user turn rather than prefixed
            # with "System:" (intended to avoid o3-pro policy violations).
            input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
        elif role == "assistant":
            input_messages.append({"role": "assistant", "content": [{"type": "output_text", "text": content}]})

    effort = "medium"
    if capabilities and capabilities.default_reasoning_effort:
        effort = capabilities.default_reasoning_effort

    # Only parameters the responses endpoint accepts are included; chat-completion
    # parameters such as temperature are deliberately left out.
    completion_params = {
        "model": model_name,
        "input": input_messages,
        "reasoning": {"effort": effort},
        "store": True,
    }

    # The Responses API names its output cap ``max_output_tokens``;
    # ``max_completion_tokens`` belongs to Chat Completions and is not a
    # parameter of ``responses.create``.
    if max_output_tokens:
        completion_params["max_output_tokens"] = max_output_tokens

    # Retry logic with progressive delays
    max_retries = 4
    retry_delays = [1, 3, 5, 8]
    attempts_made = 0

    def _attempt() -> ModelResponse:
        nonlocal attempts_made
        attempts_made += 1

        sanitized_params = self._sanitize_for_logging(completion_params)
        logging.info(
            f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
        )

        response = self.client.responses.create(**completion_params)

        content = self._safe_extract_output_text(response)

        usage = None
        if hasattr(response, "usage"):
            usage = self._extract_usage(response)
        elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
            input_tokens = getattr(response, "input_tokens", 0) or 0
            output_tokens = getattr(response, "output_tokens", 0) or 0
            usage = {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            }

        return ModelResponse(
            content=content,
            usage=usage,
            model_name=model_name,
            friendly_name=self.FRIENDLY_NAME,
            provider=self.get_provider_type(),
            metadata={
                "model": getattr(response, "model", model_name),
                "id": getattr(response, "id", ""),
                "created": getattr(response, "created_at", 0),
                "endpoint": "responses",
            },
        )

    try:
        return self._run_with_retries(
            operation=_attempt,
            max_attempts=max_retries,
            delays=retry_delays,
            log_prefix="responses endpoint",
        )
    except Exception as exc:
        attempts = max(attempts_made, 1)
        error_msg = f"responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
        logging.error(error_msg)
        raise RuntimeError(error_msg) from exc
490 |
def generate_content(
    self,
    prompt: str,
    model_name: str,
    system_prompt: Optional[str] = None,
    temperature: float = 0.3,
    max_output_tokens: Optional[int] = None,
    images: Optional[list[str]] = None,
    **kwargs,
) -> ModelResponse:
    """Generate content using the OpenAI-compatible API.

    Validates the model against the allow-list and resolves its capabilities.
    Capabilities control temperature handling, image support and Responses API
    routing. The method then builds a chat-completions request and executes it
    through ``self._run_with_retries``.

    Args:
        prompt: User prompt to send to the model.
        model_name: Canonical model name or its alias.
        system_prompt: Optional system prompt for model behavior.
        temperature: Requested sampling temperature. The model's capabilities may
            adjust it, or drop it entirely (``None``) for models that do not
            accept sampling parameters.
        max_output_tokens: Maximum tokens to generate. Only sent when the model
            accepts sampling parameters.
        images: Optional list of image paths or data URLs to include with the
            prompt. Only used for models whose capabilities support images.
        **kwargs: Additional provider-specific parameters.
            - For chat completions, only ``seed`` and ``stop`` are forwarded for
              every model. ``top_p``, ``frequency_penalty`` and
              ``presence_penalty`` are forwarded only to models that accept
              sampling parameters.
            - ``stream`` is ignored: this method always requests a complete,
              non-streamed response.
            - On the Responses API path, kwargs are passed through unchanged.

    Returns:
        ModelResponse with generated content and metadata.

    Raises:
        ValueError: If ``model_name`` is not in the allowed models list.
        RuntimeError: If the API call still fails after all retry attempts.
    """
    # Validate model name against allow-list
    if not self.validate_model_name(model_name):
        raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

    capabilities: Optional[ModelCapabilities]
    try:
        capabilities = self.get_capabilities(model_name)
    except Exception as exc:
        logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
        capabilities = None

    # Let the capabilities adjust the temperature; a None result means the model
    # does not accept sampling parameters at all.
    if capabilities:
        effective_temperature = capabilities.get_effective_temperature(temperature)
        if effective_temperature is not None and effective_temperature != temperature:
            logging.debug(
                f"Adjusting temperature from {temperature} to {effective_temperature} for model {model_name}"
            )
    else:
        effective_temperature = temperature

    if effective_temperature is not None:
        self.validate_parameters(model_name, effective_temperature)

    resolved_model = self._resolve_model_name(model_name)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # User message: text first, then any images the model can accept.
    user_content = [{"type": "text", "text": prompt}]
    if images and capabilities and capabilities.supports_images:
        for image_path in images:
            try:
                image_content = self._process_image(image_path)
                if image_content:
                    user_content.append(image_content)
            except Exception as e:
                # Skip the failing image but keep the text and the remaining images.
                logging.warning(f"Failed to process image {image_path}: {e}")
    elif images:
        # Reached only when capabilities are unknown or report no image support.
        logging.warning(f"Model {resolved_model} does not support images, ignoring {len(images)} image(s)")

    if len(user_content) == 1:
        # Text only: the plain string form is the most widely supported format.
        messages.append({"role": "user", "content": prompt})
    else:
        messages.append({"role": "user", "content": user_content})

    # MCP does not consume streamed output, so always request a complete response.
    completion_params = {
        "model": resolved_model,
        "messages": messages,
        "stream": False,
    }

    supports_sampling = effective_temperature is not None
    if supports_sampling:
        completion_params["temperature"] = effective_temperature
        # Models that reject temperature also reject max_tokens, so this is nested.
        if max_output_tokens:
            completion_params["max_tokens"] = max_output_tokens

    # Forward the supported optional parameters. "stream" is deliberately NOT
    # forwarded: _attempt below reads response.choices[0].message.content, which
    # assumes a complete (non-streamed) response object.
    sampling_only_keys = ("top_p", "frequency_penalty", "presence_penalty")
    for key, value in kwargs.items():
        if key in sampling_only_keys:
            if supports_sampling:
                completion_params[key] = value
        elif key in ("seed", "stop"):
            completion_params[key] = value

    # Prefer capability metadata; fall back to the static model table when the
    # capabilities could not be resolved.
    use_responses_api = False
    if capabilities is not None:
        use_responses_api = getattr(capabilities, "use_openai_response_api", False)
    else:
        static_capabilities = self.get_all_model_capabilities().get(resolved_model)
        if static_capabilities is not None:
            use_responses_api = getattr(static_capabilities, "use_openai_response_api", False)

    if use_responses_api:
        # These models require the /v1/responses endpoint; a failure there must not
        # fall back to chat/completions.
        return self._generate_with_responses_endpoint(
            model_name=resolved_model,
            messages=messages,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            capabilities=capabilities,
            **kwargs,
        )

    # Retry policy: 4 attempts with progressive delays of 1s, 3s, 5s, 8s.
    max_retries = 4
    retry_delays = [1, 3, 5, 8]
    attempt_counter = {"value": 0}  # mutable so the closure can count attempts

    def _attempt() -> ModelResponse:
        attempt_counter["value"] += 1
        response = self.client.chat.completions.create(**completion_params)

        content = response.choices[0].message.content
        usage = self._extract_usage(response)

        return ModelResponse(
            content=content,
            usage=usage,
            model_name=resolved_model,
            friendly_name=self.FRIENDLY_NAME,
            provider=self.get_provider_type(),
            metadata={
                "finish_reason": response.choices[0].finish_reason,
                "model": response.model,
                "id": response.id,
                "created": response.created,
            },
        )

    try:
        return self._run_with_retries(
            operation=_attempt,
            max_attempts=max_retries,
            delays=retry_delays,
            log_prefix=f"{self.FRIENDLY_NAME} API ({resolved_model})",
        )
    except Exception as exc:
        attempts = max(attempt_counter["value"], 1)
        error_msg = (
            f"{self.FRIENDLY_NAME} API error for model {resolved_model} after {attempts} attempt"
            f"{'s' if attempts > 1 else ''}: {exc}"
        )
        logging.error(error_msg)
        raise RuntimeError(error_msg) from exc
667 |
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
    """Best-effort validation of model parameters for proxy-style providers.

    Delegates to the parent class validator. Proxied models may only have generic
    capability data, so any failure (including the validator rejecting the
    parameters) is logged as a warning rather than raised.

    Args:
        model_name: Canonical model name or its alias.
        temperature: Temperature to validate.
        **kwargs: Additional parameters forwarded to the parent validator.
    """
    try:
        caps = self.get_capabilities(model_name)
        uses_generic_caps = hasattr(caps, "_is_generic")
        if uses_generic_caps:
            logging.debug(
                f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
            )
        super().validate_parameters(model_name, temperature, **kwargs)
    except Exception as e:
        # Accuracy of capabilities is not guaranteed here, so never fail the call.
        logging.warning(f"Parameter validation limited for {model_name}: {e}")
694 |
695 | def _extract_usage(self, response) -> dict[str, int]:
696 | """Extract token usage from OpenAI response.
697 |
698 | Args:
699 | response: OpenAI API response object
700 |
701 | Returns:
702 | Dictionary with usage statistics
703 | """
704 | usage = {}
705 |
706 | if hasattr(response, "usage") and response.usage:
707 | # Safely extract token counts with None handling
708 | usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
709 | usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
710 | usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0
711 |
712 | return usage
713 |
def count_tokens(self, text: str, model_name: str) -> int:
    """Count tokens in ``text`` using OpenAI-compatible tokenizer tables when available.

    If tiktoken is installed, uses its encoding for the resolved model. Models
    tiktoken does not know fall back to ``cl100k_base``. If tiktoken is missing or
    tokenization fails, defers to the parent class implementation.

    Args:
        text: Text to tokenize.
        model_name: Canonical model name or alias; resolved before the lookup.

    Returns:
        Number of tokens.
    """
    resolved_model = self._resolve_model_name(model_name)

    try:
        import tiktoken
    except ImportError as exc:
        logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)
    else:
        try:
            try:
                encoding = tiktoken.encoding_for_model(resolved_model)
            except KeyError:
                # Unknown model name: use the general-purpose encoding.
                encoding = tiktoken.get_encoding("cl100k_base")
            return len(encoding.encode(text))
        except Exception as exc:
            # tiktoken is present but could not tokenize; report it as such rather
            # than as "unavailable".
            logging.debug("tiktoken token counting failed for %s: %s", resolved_model, exc)

    return super().count_tokens(text, model_name)
733 |
734 | def _is_error_retryable(self, error: Exception) -> bool:
735 | """Determine if an error should be retried based on structured error codes.
736 |
737 | Uses OpenAI API error structure instead of text pattern matching for reliability.
738 |
739 | Args:
740 | error: Exception from OpenAI API call
741 |
742 | Returns:
743 | True if error should be retried, False otherwise
744 | """
745 | error_str = str(error).lower()
746 |
747 | # Check for 429 errors first - these need special handling
748 | if "429" in error_str:
749 | # Try to extract structured error information
750 | error_type = None
751 | error_code = None
752 |
753 | # Parse structured error from OpenAI API response
754 | # Format: "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}"
755 | try:
756 | import ast
757 | import json
758 | import re
759 |
760 | # Extract JSON part from error string using regex
761 | # Look for pattern: {...} (from first { to last })
762 | json_match = re.search(r"\{.*\}", str(error))
763 | if json_match:
764 | json_like_str = json_match.group(0)
765 |
766 | # First try: parse as Python literal (handles single quotes safely)
767 | try:
768 | error_data = ast.literal_eval(json_like_str)
769 | except (ValueError, SyntaxError):
770 | # Fallback: try JSON parsing with simple quote replacement
771 | # (for cases where it's already valid JSON or simple replacements work)
772 | json_str = json_like_str.replace("'", '"')
773 | error_data = json.loads(json_str)
774 |
775 | if "error" in error_data:
776 | error_info = error_data["error"]
777 | error_type = error_info.get("type")
778 | error_code = error_info.get("code")
779 |
780 | except (json.JSONDecodeError, ValueError, SyntaxError, AttributeError):
781 | # Fall back to checking hasattr for OpenAI SDK exception objects
782 | if hasattr(error, "response") and hasattr(error.response, "json"):
783 | try:
784 | response_data = error.response.json()
785 | if "error" in response_data:
786 | error_info = response_data["error"]
787 | error_type = error_info.get("type")
788 | error_code = error_info.get("code")
789 | except Exception:
790 | pass
791 |
792 | # Determine if 429 is retryable based on structured error codes
793 | if error_type == "tokens":
794 | # Token-related 429s are typically non-retryable (request too large)
795 | logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})")
796 | return False
797 | elif error_code in ["invalid_request_error", "context_length_exceeded"]:
798 | # These are permanent failures
799 | logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})")
800 | return False
801 | else:
802 | # Other 429s (like requests per minute) are retryable
803 | logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})")
804 | return True
805 |
806 | # For non-429 errors, check if they're retryable
807 | retryable_indicators = [
808 | "timeout",
809 | "connection",
810 | "network",
811 | "temporary",
812 | "unavailable",
813 | "retry",
814 | "408", # Request timeout
815 | "500", # Internal server error
816 | "502", # Bad gateway
817 | "503", # Service unavailable
818 | "504", # Gateway timeout
819 | "ssl", # SSL errors
820 | "handshake", # Handshake failures
821 | ]
822 |
823 | return any(indicator in error_str for indicator in retryable_indicators)
824 |
825 | def _process_image(self, image_path: str) -> Optional[dict]:
826 | """Process an image for OpenAI-compatible API."""
827 | try:
828 | if image_path.startswith("data:"):
829 | # Validate the data URL
830 | validate_image(image_path)
831 | # Handle data URL: data:image/png;base64,iVBORw0...
832 | return {"type": "image_url", "image_url": {"url": image_path}}
833 | else:
834 | # Use base class validation
835 | image_bytes, mime_type = validate_image(image_path)
836 |
837 | # Read and encode the image
838 | import base64
839 |
840 | image_data = base64.b64encode(image_bytes).decode()
841 | logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")
842 |
843 | # Create data URL for OpenAI API
844 | data_url = f"data:{mime_type};base64,{image_data}"
845 |
846 | return {"type": "image_url", "image_url": {"url": data_url}}
847 |
848 | except ValueError as e:
849 | logging.warning(str(e))
850 | return None
851 | except Exception as e:
852 | logging.error(f"Error processing image {image_path}: {e}")
853 | return None
854 |
```
--------------------------------------------------------------------------------
/tools/tracer.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Tracer Workflow tool - Step-by-step code tracing and dependency analysis
3 |
4 | This tool provides a structured workflow for comprehensive code tracing and analysis.
5 | It guides the CLI agent through systematic investigation steps with forced pauses between each step
6 | to ensure thorough code examination, dependency mapping, and execution flow analysis before proceeding.
7 |
8 | The tracer guides users through sequential code analysis with full context awareness and
9 | the ability to revise and adapt as understanding deepens.
10 |
11 | Key features:
12 | - Sequential tracing with systematic investigation workflow
13 | - Support for precision tracing (execution flow) and dependencies tracing (structural relationships)
14 | - Self-contained completion with detailed output formatting instructions
15 | - Context-aware analysis that builds understanding step by step
16 | - No external expert analysis needed - provides comprehensive guidance internally
17 |
18 | Perfect for: method/function execution flow analysis, dependency mapping, call chain tracing,
19 | structural relationship analysis, architectural understanding, and code comprehension.
20 | """
21 |
22 | import logging
23 | from typing import TYPE_CHECKING, Any, Literal, Optional
24 |
25 | from pydantic import Field, field_validator
26 |
27 | if TYPE_CHECKING:
28 | from tools.models import ToolModelCategory
29 |
30 | from config import TEMPERATURE_ANALYTICAL
31 | from systemprompts import TRACER_PROMPT
32 | from tools.shared.base_models import WorkflowRequest
33 |
34 | from .workflow.base import WorkflowTool
35 |
36 | logger = logging.getLogger(__name__)
37 |
# Descriptions for the tracer workflow's request fields. They are shared by the
# pydantic request model and the JSON schema returned by get_tool_fields.
TRACER_WORKFLOW_FIELD_DESCRIPTIONS = dict(
    step=(
        "The plan for the current tracing step. Step 1: State the tracing strategy. Later steps: Report findings and adapt the plan. "
        "CRITICAL: For 'precision' mode, focus on execution flow and call chains. For 'dependencies' mode, focus on structural relationships. "
        "If trace_mode is 'ask' in step 1, you MUST prompt the user to choose a mode."
    ),
    step_number=(
        "The index of the current step in the tracing sequence, beginning at 1. Each step should build upon or "
        "revise the previous one."
    ),
    total_steps=(
        "Your current estimate for how many steps will be needed to complete the tracing analysis. "
        "Adjust as new findings emerge."
    ),
    next_step_required=(
        "Set to true if you plan to continue the investigation with another step. False means you believe the "
        "tracing analysis is complete and ready for final output formatting."
    ),
    findings=(
        "Summary of discoveries from this step, including execution paths, dependency relationships, call chains, and structural patterns. "
        "IMPORTANT: Document both direct (immediate calls) and indirect (transitive, side effects) relationships."
    ),
    files_checked="List all files examined (absolute paths). Include even ruled-out files to track exploration path.",
    relevant_files=(
        "Subset of files_checked directly relevant to the tracing target (absolute paths). Include implementation files, "
        "dependencies, or files demonstrating key relationships."
    ),
    relevant_context=(
        "List methods/functions central to the tracing analysis, in 'ClassName.methodName' or 'functionName' format. "
        "Prioritize those in the execution flow or dependency chain."
    ),
    confidence=(
        "Your confidence in the tracing analysis. Use: 'exploring', 'low', 'medium', 'high', 'very_high', 'almost_certain', 'certain'. "
        "CRITICAL: 'certain' implies the analysis is 100% complete locally and PREVENTS external model validation."
    ),
    trace_mode="Type of tracing: 'ask' (default - prompts user to choose mode), 'precision' (execution flow) or 'dependencies' (structural relationships)",
    target_description=(
        "Description of what to trace and WHY. Include context about what you're trying to understand or analyze."
    ),
    images="Optional paths to architecture diagrams or flow charts that help understand the tracing context.",
)
82 |
83 |
class TracerRequest(WorkflowRequest):
    """Request model for a single step of the tracer workflow.

    Adds the tracer-specific ``trace_mode``, ``target_description`` and ``images``
    fields to WorkflowRequest. It re-declares the inherited investigation and
    common fields that tracing does not use with ``exclude=True``, which keeps
    them out of model serialization.
    """

    # Step bookkeeping: required on every call.
    step: str = Field(default=..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["step"])
    step_number: int = Field(default=..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
    total_steps: int = Field(default=..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
    next_step_required: bool = Field(
        default=..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"]
    )

    # Investigation progress reported by the agent.
    findings: str = Field(default=..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["findings"])
    files_checked: list[str] = Field(
        default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["files_checked"]
    )
    relevant_files: list[str] = Field(
        default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"]
    )
    relevant_context: list[str] = Field(
        default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"]
    )
    confidence: Optional[str] = Field(
        default="exploring", description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["confidence"]
    )

    # Tracer-specific inputs (supplied on step 1 to initialize the trace).
    trace_mode: Optional[Literal["precision", "dependencies", "ask"]] = Field(
        default="ask", description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["trace_mode"]
    )
    target_description: Optional[str] = Field(
        default=None, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["target_description"]
    )
    images: Optional[list[str]] = Field(default=None, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["images"])

    # Inherited fields tracing ignores; exclude=True keeps them out of serialization.
    issues_found: list[dict] = Field(default_factory=list, exclude=True, description="Tracing doesn't track issues")
    hypothesis: Optional[str] = Field(default=None, exclude=True, description="Tracing doesn't use hypothesis")
    temperature: Optional[float] = Field(default=None, exclude=True)
    thinking_mode: Optional[str] = Field(default=None, exclude=True)
    use_assistant_model: Optional[bool] = Field(default=False, exclude=True, description="Tracing is self-contained")

    @field_validator("step_number")
    @classmethod
    def validate_step_number(cls, v):
        """Reject step numbers below 1."""
        if v >= 1:
            return v
        raise ValueError("step_number must be at least 1")

    @field_validator("total_steps")
    @classmethod
    def validate_total_steps(cls, v):
        """Reject step-count estimates below 1."""
        if v >= 1:
            return v
        raise ValueError("total_steps must be at least 1")
136 |
137 |
138 | class TracerTool(WorkflowTool):
139 | """
140 | Tracer workflow tool for step-by-step code tracing and dependency analysis.
141 |
142 | This tool implements a structured tracing workflow that guides users through
143 | methodical investigation steps, ensuring thorough code examination, dependency
144 | mapping, and execution flow analysis before reaching conclusions. It supports
145 | both precision tracing (execution flow) and dependencies tracing (structural relationships).
146 | """
147 |
def __init__(self):
    """Initialize the workflow base, then reset tracer-specific state.

    ``initial_request`` and ``trace_config`` are populated on step 1 by
    customize_workflow_response.
    """
    super().__init__()
    self.initial_request, self.trace_config = None, {}
152 |
def get_name(self) -> str:
    """Return the MCP tool identifier for this workflow."""
    tool_name = "tracer"
    return tool_name
155 |
def get_description(self) -> str:
    """Return the tool description advertised to MCP clients."""
    description = (
        "Performs systematic code tracing with modes for execution flow or dependency mapping. "
        "Use for method execution analysis, call chain tracing, dependency mapping, and architectural understanding. "
        "Supports precision mode (execution flow) and dependencies mode (structural relationships)."
    )
    return description
162 |
def get_system_prompt(self) -> str:
    """Return the tracer's system prompt text (TRACER_PROMPT)."""
    prompt_text = TRACER_PROMPT
    return prompt_text
165 |
def get_default_temperature(self) -> float:
    """Return the analytical default temperature (TEMPERATURE_ANALYTICAL)."""
    default_temperature = TEMPERATURE_ANALYTICAL
    return default_temperature
168 |
def get_model_category(self) -> "ToolModelCategory":
    """Return EXTENDED_REASONING: code tracing calls for analytical reasoning."""
    # Imported lazily; the module-level import only exists under TYPE_CHECKING.
    from tools.models import ToolModelCategory

    category = ToolModelCategory.EXTENDED_REASONING
    return category
174 |
def requires_model(self) -> bool:
    """Report that no model needs to be resolved at the MCP boundary.

    The tracer only organizes tracing steps and formats guidance; it does not
    call an external AI model itself.

    Returns:
        bool: Always False.
    """
    needs_model = False
    return needs_model
186 |
def get_workflow_request_model(self):
    """Return the pydantic model that validates tracer step requests."""
    request_model = TracerRequest
    return request_model
190 |
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
    """Return JSON-schema fragments for the fields only the tracer defines.

    These are merged with the standard workflow fields when the input schema is built.
    """
    descriptions = TRACER_WORKFLOW_FIELD_DESCRIPTIONS

    trace_mode_schema = {
        "type": "string",
        "enum": ["precision", "dependencies", "ask"],
        "description": descriptions["trace_mode"],
    }
    target_schema = {
        "type": "string",
        "description": descriptions["target_description"],
    }
    images_schema = {
        "type": "array",
        "items": {"type": "string"},
        "description": descriptions["images"],
    }

    return {
        "trace_mode": trace_mode_schema,
        "target_description": target_schema,
        "images": images_schema,
    }
210 |
def get_input_schema(self) -> dict[str, Any]:
    """Build the tracer's MCP input schema with WorkflowSchemaBuilder.

    Several fields are excluded:
    - Investigation fields tracing never uses: issues_found and hypothesis.
    - Common fields it does not need: temperature and thinking_mode.
    - absolute_file_paths, since tracing uses relevant_files instead.

    ``target_description`` and ``trace_mode`` are marked required, because step 1
    needs them.
    """
    from .workflow.schema_builders import WorkflowSchemaBuilder

    builder_kwargs = {
        "tool_specific_fields": self.get_tool_fields(),
        "required_fields": ["target_description", "trace_mode"],
        "model_field_schema": self.get_model_field_schema(),
        "auto_mode": self.is_effective_auto_mode(),
        "tool_name": self.get_name(),
        "excluded_workflow_fields": ["issues_found", "hypothesis"],
        "excluded_common_fields": ["temperature", "thinking_mode", "absolute_file_paths"],
    }
    return WorkflowSchemaBuilder.build_schema(**builder_kwargs)
237 |
238 | # ================================================================================
239 | # Abstract Methods - Required Implementation from BaseWorkflowMixin
240 | # ================================================================================
241 |
def get_required_actions(
    self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
) -> list[str]:
    """Return the checklist the agent must complete before the next tracing step.

    Selection order:
    - Step 1 depends on the trace mode. In "ask" mode the agent must get a mode
      choice from the user; otherwise it starts locating the target code.
    - Later steps depend on ``confidence``: "exploring"/"low" call for deeper
      tracing, "medium"/"high" for verification, and any other value for
      general continued tracing.

    ``findings``, ``total_steps`` and ``request`` are accepted for interface
    compatibility and are not used.
    """
    if step_number == 1:
        if self.get_trace_mode() != "ask":
            # Mode already chosen: begin the actual investigation.
            return [
                "Search for and locate the target method/function/class/module in the codebase",
                "Read and understand the implementation of the target code",
                "Identify the file location, complete signature, and basic structure",
                "Begin mapping immediate relationships (what it calls, what calls it)",
                "Understand the context and purpose of the target code",
            ]
        return [
            "MUST ask user to choose between precision or dependencies mode",
            "Explain precision mode: traces execution flow, call chains, and usage patterns (best for methods/functions)",
            "Explain dependencies mode: maps structural relationships and bidirectional dependencies (best for classes/modules)",
            "Wait for user's mode selection before proceeding with investigation",
        ]

    if confidence in ("exploring", "low"):
        return [
            "Trace deeper into the execution flow or dependency relationships",
            "Examine how the target code is used throughout the codebase",
            "Map additional layers of dependencies or call chains",
            "Look for conditional execution paths, error handling, and edge cases",
            "Understand the broader architectural context and patterns",
        ]

    if confidence in ("medium", "high"):
        return [
            "Verify completeness of the traced relationships and execution paths",
            "Check for any missed dependencies, usage patterns, or execution branches",
            "Confirm understanding of side effects, state changes, and external interactions",
            "Validate that the tracing covers all significant code relationships",
            "Prepare comprehensive findings for final output formatting",
        ]

    return [
        "Continue systematic tracing of code relationships and execution paths",
        "Gather more evidence using appropriate code analysis techniques",
        "Test assumptions about code behavior and dependency relationships",
        "Look for patterns that enhance understanding of the code structure",
        "Focus on areas that haven't been thoroughly traced yet",
    ]
291 |
def should_call_expert_analysis(self, consolidated_findings, request=None) -> bool:
    """Never request expert analysis; the tracer completes its workflow on its own.

    Both arguments are accepted for interface compatibility and ignored.
    """
    expert_needed = False
    return expert_needed
295 |
def prepare_expert_analysis_context(self, consolidated_findings) -> str:
    """Return an empty context: the tracer never sends findings to an expert model."""
    empty_context = ""
    return empty_context
299 |
def requires_expert_analysis(self) -> bool:
    """Report that no expert analysis phase exists (self-contained, like the planner)."""
    has_expert_phase = False
    return has_expert_phase
303 |
304 | # ================================================================================
305 | # Workflow Customization - Match Planner Behavior
306 | # ================================================================================
307 |
def prepare_step_data(self, request) -> dict:
    """Flatten a tracer request into the step-data dict used by the workflow.

    ``issues_found`` and ``hypothesis`` are always empty because tracing does not
    use them. A missing confidence becomes "exploring", and missing images become
    an empty list. ``trace_mode`` and ``target_description`` are carried through
    as tracer-specific extras.
    """
    return dict(
        step=request.step,
        step_number=request.step_number,
        findings=request.findings,
        files_checked=request.files_checked,
        relevant_files=request.relevant_files,
        relevant_context=request.relevant_context,
        issues_found=[],
        confidence=request.confidence or "exploring",
        hypothesis=None,
        images=request.images or [],
        trace_mode=request.trace_mode,
        target_description=request.target_description,
    )
328 |
def build_base_response(self, request, continuation_id: Optional[str] = None) -> dict:
    """Build the common response payload for a tracer step.

    Args:
        request: The current step request. Its ``step_number``, ``total_steps``,
            ``next_step_required`` and ``step`` are echoed back.
        continuation_id: Optional conversation thread id. It is included in the
            response only when truthy.

    Returns:
        A dict containing:
        - the status and the echoed step fields;
        - a ``<tool name>_status`` summary with counts from
          ``self.consolidated_findings`` plus the current confidence;
        - tracer metadata taken from ``self.trace_config``;
        - ``continuation_id`` when one was given.
    """
    tool_name = self.get_name()
    findings = self.consolidated_findings
    # work_history holds earlier steps; +1 counts the step being handled now.
    current_step_count = len(self.work_history) + 1

    response_data = {
        "status": f"{tool_name}_in_progress",
        "step_number": request.step_number,
        "total_steps": request.total_steps,
        "next_step_required": request.next_step_required,
        "step_content": request.step,
        f"{tool_name}_status": {
            "files_checked": len(findings.files_checked),
            "relevant_files": len(findings.relevant_files),
            "relevant_context": len(findings.relevant_context),
            "issues_found": len(findings.issues_found),
            "images_collected": len(findings.images),
            "current_confidence": self.get_request_confidence(request),
            "step_history_length": current_step_count,
        },
        "metadata": {
            # trace_config is only populated once step 1 has been processed.
            "trace_mode": self.trace_config.get("trace_mode", "unknown"),
            "target_description": self.trace_config.get("target_description", ""),
            "step_history_length": current_step_count,
        },
    }

    if continuation_id:
        response_data["continuation_id"] = continuation_id

    return response_data
362 |
def handle_work_continuation(self, response_data: dict, request) -> dict:
    """Mark the response as paused and attach tracer guidance for the next step.

    Sets ``status`` to ``pause_for_<tool>`` and ``<tool>_required`` to True. It
    also stores the checklist from get_required_actions under
    ``required_actions``.

    ``next_steps`` is chosen as follows:
    - Step 1 in "ask" mode: instructions to obtain a mode choice from the user.
    - Step 1 otherwise: instructions to investigate before the next call.
    - Later steps with "exploring"/"low" confidence: a mandatory numbered
      checklist.
    - Later steps with "medium"/"high" confidence: a verification checklist.
    - Anything else: generic continuation text.

    A missing confidence is treated as "exploring" both for the checklist and
    for choosing the text, so the two always agree.

    Args:
        response_data: Response dict to update in place.
        request: Current step request. Its ``step_number``, ``confidence``,
            ``findings`` and ``total_steps`` are read.

    Returns:
        The updated ``response_data``.
    """
    tool_name = self.get_name()
    response_data["status"] = f"pause_for_{tool_name}"
    response_data[f"{tool_name}_required"] = True

    # Resolve the default once. Otherwise a None confidence would get
    # "exploring" actions but the generic next_steps text that omits them.
    confidence = request.confidence or "exploring"
    required_actions = self.get_required_actions(
        request.step_number, confidence, request.findings, request.total_steps
    )
    response_data["required_actions"] = required_actions

    next_step = request.step_number + 1
    numbered_actions = "\n".join(f"{i + 1}. {action}" for i, action in enumerate(required_actions))

    if request.step_number == 1:
        if self.get_trace_mode() == "ask":
            response_data["next_steps"] = (
                "STOP! You MUST ask the user to choose a tracing mode before proceeding. "
                "Present these options clearly:\n\n"
                "**PRECISION MODE**: Traces execution flow, call chains, and usage patterns. "
                "Best for understanding how a specific method or function works, what it calls, "
                "and how data flows through the execution path.\n\n"
                "**DEPENDENCIES MODE**: Maps structural relationships and bidirectional dependencies. "
                "Best for understanding how a class or module relates to other components, "
                "what depends on it, and what it depends on.\n\n"
                f"After the user selects a mode, call {tool_name} again with step_number: 1 "
                "but with the chosen trace_mode (either 'precision' or 'dependencies')."
            )
        else:
            response_data["next_steps"] = (
                f"MANDATORY: DO NOT call the {tool_name} tool again immediately. You MUST first investigate "
                "the codebase to understand the target code. CRITICAL AWARENESS: You need to find and understand "
                "the target method/function/class/module, examine its implementation, and begin mapping its "
                "relationships. Use file reading tools, code search, and systematic examination to gather "
                f"comprehensive information about the target. Only call {tool_name} again AFTER completing "
                f"your investigation. When you call {tool_name} next time, use step_number: {next_step} "
                "and report specific files examined, code structure discovered, and initial relationship findings."
            )
    elif confidence in ["exploring", "low"]:
        response_data["next_steps"] = (
            f"STOP! Do NOT call {tool_name} again yet. Based on your findings, you've identified areas that need "
            f"deeper tracing analysis. MANDATORY ACTIONS before calling {tool_name} step {next_step}:\n"
            + numbered_actions
            + f"\n\nOnly call {tool_name} again with step_number: {next_step} AFTER "
            + "completing these tracing investigations."
        )
    elif confidence in ["medium", "high"]:
        response_data["next_steps"] = (
            f"WAIT! Your tracing analysis needs final verification. DO NOT call {tool_name} immediately. "
            "REQUIRED ACTIONS:\n"
            + numbered_actions
            + "\n\nREMEMBER: Ensure you have traced all significant relationships and execution paths. "
            f"Document findings with specific file references and method signatures, then call {tool_name} "
            f"with step_number: {next_step}."
        )
    else:
        # General investigation needed
        remaining_steps = request.total_steps - request.step_number
        response_data["next_steps"] = (
            f"Continue systematic tracing with step {next_step}. Approximately {remaining_steps} steps remaining. "
            "Focus on deepening your understanding of the code relationships and execution patterns."
        )

    return response_data
431 |
def customize_workflow_response(self, response_data: dict, request) -> dict:
    """
    Customize response to match tracer tool format with output instructions.

    On step 1 the request's trace mode and target description are remembered
    on the tool (``self.trace_config``) together with the step text
    (``self.initial_request``). Metadata, when present, is annotated with the
    current request's trace configuration. A request in "ask" mode is flagged
    as a mode-selection phase. On the final step (``next_step_required`` is
    false) structured output guidance and mode-specific rendering
    instructions are attached. Finally, generic ``<tool>_*`` statuses are
    renamed to tracer-specific ones.

    Args:
        response_data: Response dict produced by the workflow; updated in place.
        request: Current tracer request. Reads ``step_number``, ``step``,
            ``trace_mode``, ``target_description`` and ``next_step_required``.

    Returns:
        dict: The same ``response_data`` object, updated.
    """
    # Store trace configuration on first step
    if request.step_number == 1:
        self.initial_request = request.step
        self.trace_config = {
            "trace_mode": request.trace_mode,
            "target_description": request.target_description,
        }

    # Update metadata with trace configuration
    if "metadata" in response_data:
        response_data["metadata"]["trace_mode"] = request.trace_mode or "unknown"
        response_data["metadata"]["target_description"] = request.target_description or ""

    # If in ask mode, mark this as mode selection phase
    if request.trace_mode == "ask":
        response_data["mode_selection_required"] = True
        response_data["status"] = "mode_selection_required"

    # Add tracer-specific output instructions for final steps
    if not request.next_step_required:
        response_data["tracing_complete"] = True
        response_data["trace_summary"] = f"TRACING COMPLETE: {request.step}"

        # Resolve the mode used for rendering. The mode remembered on step 1
        # wins. It may be missing (this instance never saw step 1) or not a
        # concrete mode (None / "ask"). In that case fall back to the current
        # request's mode, then to "precision". Without this, the format would
        # read e.g. "None_trace_analysis" and silently get the dependencies
        # instructions.
        renderable_modes = ("precision", "dependencies")
        stored_config = getattr(self, "trace_config", None) or {}
        stored_mode = stored_config.get("trace_mode")
        if stored_mode in renderable_modes:
            trace_mode = stored_mode
        elif request.trace_mode in renderable_modes:
            trace_mode = request.trace_mode
        else:
            trace_mode = "precision"

        # Get mode-specific output instructions
        rendering_instructions = self._get_rendering_instructions(trace_mode)

        response_data["output"] = {
            "instructions": (
                "This is a structured tracing analysis response. Present the comprehensive tracing findings "
                "using the specific rendering format for the trace mode. Follow the exact formatting guidelines "
                "provided in rendering_instructions. Include all discovered relationships, execution paths, "
                "and dependencies with precise file references and line numbers."
            ),
            "format": f"{trace_mode}_trace_analysis",
            "rendering_instructions": rendering_instructions,
            "presentation_guidelines": {
                "completed_trace": (
                    "Use the exact rendering format specified for the trace mode. Include comprehensive "
                    "diagrams, tables, and structured analysis. Reference specific file paths and line numbers. "
                    "Follow formatting rules precisely."
                ),
                "step_content": "Present as main analysis with clear structure and actionable insights.",
                "continuation": "Use continuation_id for related tracing sessions or follow-up analysis",
            },
        }
        response_data["next_steps"] = (
            f"Tracing analysis complete. Present the comprehensive {trace_mode} trace analysis to the user "
            f"using the exact rendering format specified in the output instructions. Follow the formatting "
            f"guidelines precisely, including diagrams, tables, and file references. After presenting the "
            f"analysis, offer to help with related tracing tasks or use the continuation_id for follow-up analysis."
        )

    # Convert generic status names to tracer-specific ones
    tool_name = self.get_name()
    status_mapping = {
        f"{tool_name}_in_progress": "tracing_in_progress",
        f"pause_for_{tool_name}": "pause_for_tracing",
        f"{tool_name}_required": "tracing_required",
        f"{tool_name}_complete": "tracing_complete",
    }

    # Use .get so a response without a status is left untouched instead of raising KeyError.
    current_status = response_data.get("status")
    if current_status in status_mapping:
        response_data["status"] = status_mapping[current_status]

    return response_data
502 |
503 | def _get_rendering_instructions(self, trace_mode: str) -> str:
504 | """
505 | Get mode-specific rendering instructions for the CLI agent.
506 |
507 | Args:
508 | trace_mode: Either "precision" or "dependencies"
509 |
510 | Returns:
511 | str: Complete rendering instructions for the specified mode
512 | """
513 | if trace_mode == "precision":
514 | return self._get_precision_rendering_instructions()
515 | else: # dependencies mode
516 | return self._get_dependencies_rendering_instructions()
517 |
def _get_precision_rendering_instructions(self) -> str:
    """
    Get rendering instructions for precision trace mode.

    Returns:
        str: Constant Markdown text, beginning with a newline. It tells the
        CLI agent to draw the call flow in the vertical indented style
        (``↓`` for sequential calls, ``→`` for alternative branches). It also
        requires four additional views: a branching table, side effects,
        usage points and entry points.
    """
    # The returned text is a fixed prompt sent to the agent; any edit here changes what it is asked to render.
    return """
## MANDATORY RENDERING INSTRUCTIONS FOR PRECISION TRACE

You MUST render the trace analysis using ONLY the Vertical Indented Flow Style:

### CALL FLOW DIAGRAM - Vertical Indented Style

**EXACT FORMAT TO FOLLOW:**
```
[ClassName::MethodName] (file: /complete/file/path.ext, line: ##)
↓
[AnotherClass::calledMethod] (file: /path/to/file.ext, line: ##)
↓
[ThirdClass::nestedMethod] (file: /path/file.ext, line: ##)
↓
[DeeperClass::innerCall] (file: /path/inner.ext, line: ##) ? if some_condition
↓
[ServiceClass::processData] (file: /services/service.ext, line: ##)
↓
[RepositoryClass::saveData] (file: /data/repo.ext, line: ##)
↓
[ClientClass::sendRequest] (file: /clients/client.ext, line: ##)
↓
[EmailService::sendEmail] (file: /email/service.ext, line: ##) ⚠️ ambiguous branch
→
[SMSService::sendSMS] (file: /sms/service.ext, line: ##) ⚠️ ambiguous branch
```

**CRITICAL FORMATTING RULES:**

1. **Method Names**: Use the actual naming convention of the project language you're analyzing. Automatically detect and adapt to the project's conventions (camelCase, snake_case, PascalCase, etc.) based on the codebase structure and file extensions.

2. **Vertical Flow Arrows**:
- Use `↓` for standard sequential calls (vertical flow)
- Use `→` for parallel/alternative calls (horizontal branch)
- NEVER use other arrow types

3. **Indentation Logic**:
- Start at column 0 for entry point
- Indent 2 spaces for each nesting level
- Maintain consistent indentation for same call depth
- Sibling calls at same level should have same indentation

4. **Conditional Calls**:
- Add `? if condition_description` after method for conditional execution
- Use actual condition names from code when possible

5. **Ambiguous Branches**:
- Mark with `⚠️ ambiguous branch` when execution path is uncertain
- Use `→` to show alternative paths at same indentation level

6. **File Path Format**:
- Use complete relative paths from project root
- Include actual file extensions from the project
- Show exact line numbers where method is defined

### ADDITIONAL ANALYSIS VIEWS

**1. BRANCHING & SIDE EFFECT TABLE**

| Location | Condition | Branches | Uncertain |
|----------|-----------|----------|-----------|
| CompleteFileName.ext:## | if actual_condition_from_code | method1(), method2(), else skip | No |
| AnotherFile.ext:## | if boolean_check | callMethod(), else return | No |
| ThirdFile.ext:## | if validation_passes | processData(), else throw | Yes |

**2. SIDE EFFECTS**
```
Side Effects:
- [database] Specific database operation description (CompleteFileName.ext:##)
- [network] Specific network call description (CompleteFileName.ext:##)
- [filesystem] Specific file operation description (CompleteFileName.ext:##)
- [state] State changes or property modifications (CompleteFileName.ext:##)
- [memory] Memory allocation or cache operations (CompleteFileName.ext:##)
```

**3. USAGE POINTS**
```
Usage Points:
1. FileName.ext:## - Context description of where/why it's called
2. AnotherFile.ext:## - Context description of usage scenario
3. ThirdFile.ext:## - Context description of calling pattern
4. FourthFile.ext:## - Context description of integration point
```

**4. ENTRY POINTS**
```
Entry Points:
- ClassName::methodName (context: where this flow typically starts)
- AnotherClass::entryMethod (context: alternative entry scenario)
- ThirdClass::triggerMethod (context: event-driven entry point)
```

**ABSOLUTE REQUIREMENTS:**
- Use ONLY the vertical indented style for the call flow diagram
- Present ALL FOUR additional analysis views (Branching Table, Side Effects, Usage Points, Entry Points)
- Adapt method naming to match the project's programming language conventions
- Use exact file paths and line numbers from the actual codebase
- DO NOT invent or guess method names or locations
- Follow indentation rules precisely for call hierarchy
- Mark uncertain execution paths clearly
- Provide contextual descriptions in Usage Points and Entry Points sections
- Include comprehensive side effects categorization (database, network, filesystem, state, memory)"""
623 |
def _get_dependencies_rendering_instructions(self) -> str:
    """
    Get rendering instructions for dependencies trace mode.

    Returns:
        str: Constant Markdown text, beginning with a newline. It tells the
        CLI agent to draw a bidirectional arrow diagram: incoming callers,
        then the bracketed target, then outgoing dependencies. It also asks
        for a type-relationships section and a dependency table
        (Type / From/To / Method / File / Line).
    """
    # The returned text is a fixed prompt sent to the agent; any edit here changes what it is asked to render.
    return """
## MANDATORY RENDERING INSTRUCTIONS FOR DEPENDENCIES TRACE

You MUST render the trace analysis using ONLY the Bidirectional Arrow Flow Style:

### DEPENDENCY FLOW DIAGRAM - Bidirectional Arrow Style

**EXACT FORMAT TO FOLLOW:**
```
INCOMING DEPENDENCIES → [TARGET_CLASS/MODULE] → OUTGOING DEPENDENCIES

CallerClass::callerMethod ←────┐
AnotherCaller::anotherMethod ←─┤
ThirdCaller::thirdMethod ←─────┤
│
[TARGET_CLASS/MODULE]
│
├────→ FirstDependency::method
├────→ SecondDependency::method
└────→ ThirdDependency::method

TYPE RELATIONSHIPS:
InterfaceName ──implements──→ [TARGET_CLASS] ──extends──→ BaseClass
DTOClass ──uses──→ [TARGET_CLASS] ──uses──→ EntityClass
```

**CRITICAL FORMATTING RULES:**

1. **Target Placement**: Always place the target class/module in square brackets `[TARGET_NAME]` at the center
2. **Incoming Dependencies**: Show on the left side with `←` arrows pointing INTO the target
3. **Outgoing Dependencies**: Show on the right side with `→` arrows pointing OUT FROM the target
4. **Arrow Alignment**: Use consistent spacing and alignment for visual clarity
5. **Method Naming**: Use the project's actual naming conventions detected from the codebase
6. **File References**: Include complete file paths and line numbers

**VISUAL LAYOUT RULES:**

1. **Header Format**: Always start with the flow direction indicator
2. **Left Side (Incoming)**:
- List all callers with `←` arrows
- Use `┐`, `┤`, `┘` box drawing characters for clean connection lines
- Align arrows consistently

3. **Center (Target)**:
- Enclose target in square brackets
- Position centrally between incoming and outgoing

4. **Right Side (Outgoing)**:
- List all dependencies with `→` arrows
- Use `├`, `└` box drawing characters for branching
- Maintain consistent spacing

5. **Type Relationships Section**:
- Use `──relationship──→` format with double hyphens
- Show inheritance, implementation, and usage relationships
- Place below the main flow diagram

**DEPENDENCY TABLE:**

| Type | From/To | Method | File | Line |
|------|---------|--------|------|------|
| incoming_call | From: CallerClass | callerMethod | /complete/path/file.ext | ## |
| outgoing_call | To: TargetClass | targetMethod | /complete/path/file.ext | ## |
| implements | Self: ThisClass | — | /complete/path/file.ext | — |
| extends | Self: ThisClass | — | /complete/path/file.ext | — |
| uses_type | Self: ThisClass | — | /complete/path/file.ext | — |

**ABSOLUTE REQUIREMENTS:**
- Use ONLY the bidirectional arrow flow style shown above
- Automatically detect and use the project's naming conventions
- Use exact file paths and line numbers from the actual codebase
- DO NOT invent or guess method/class names
- Maintain visual alignment and consistent spacing
- Include type relationships section when applicable
- Show clear directional flow with proper arrows"""
701 |
702 | # ================================================================================
703 | # Hook Method Overrides for Tracer-Specific Behavior
704 | # ================================================================================
705 |
def get_completion_status(self) -> str:
    """Return the tracer's own status string for a finished trace."""
    completed_status = "tracing_complete"
    return completed_status
709 |
def get_completion_data_key(self) -> str:
    """Return the key name the tracer uses for its completion data."""
    data_key = "complete_tracing"
    return data_key
713 |
def get_completion_message(self) -> str:
    """Return the tracer's closing instruction telling the agent to present the finished trace."""
    sentences = (
        "Tracing analysis complete.",
        "Present the comprehensive trace analysis to the user using the specified rendering format",
        "and offer to help with related tracing tasks.",
    )
    return " ".join(sentences)
720 |
def get_skip_reason(self) -> str:
    """Return the explanation given when the tracer skips external expert analysis."""
    reason = "Tracer is self-contained and completes analysis without external assistance"
    return reason
724 |
def get_skip_expert_analysis_status(self) -> str:
    """Return the status recorded when expert analysis is skipped by the tracer's design."""
    skip_status = "skipped_by_tool_design"
    return skip_status
728 |
def store_initial_issue(self, step_description: str):
    """
    Remember the description of the initial tracing request.

    Args:
        step_description: Step text to keep as the initial tracing description.
    """
    setattr(self, "initial_tracing_description", step_description)
732 |
def get_initial_request(self, fallback_step: str) -> str:
    """
    Return the stored initial tracing description.

    Args:
        fallback_step: Returned when no initial description has been stored yet.
    """
    return getattr(self, "initial_tracing_description", fallback_step)
739 |
def get_request_confidence(self, request) -> str:
    """
    Return the request's confidence level for the tracer workflow.

    Falls back to "exploring" when the request has no ``confidence``
    attribute or its value is falsy.
    """
    declared = getattr(request, "confidence", None)
    return declared if declared else "exploring"
746 |
def get_trace_mode(self) -> str:
    """
    Get current trace mode. Override for custom trace mode handling.

    Returns:
        str: The ``trace_mode`` remembered in ``self.trace_config``, or "ask"
        when no configuration is stored or the stored mode is empty/None.
        The stored value is copied verbatim from the step-1 request and may
        be None, so the falsy check keeps the ``str`` return contract.
    """
    try:
        stored_mode = self.trace_config.get("trace_mode")
    except AttributeError:
        # No trace configuration recorded on this instance yet.
        return "ask"
    return stored_mode or "ask"
753 |
754 | # Required abstract methods from BaseTool
def get_request_model(self):
    """Return ``TracerRequest``, the request model class for this tool."""
    request_model = TracerRequest
    return request_model
758 |
async def prepare_prompt(self, request) -> str:
    """
    Return an empty prompt.

    Workflow tools drive their steps through execute_workflow(), so this
    required hook ignores ``request`` and produces nothing.
    """
    empty_prompt = ""
    return empty_prompt
762 |
```