This is page 5 of 25. Use http://codebase.md/beehiveinnovations/gemini-mcp-server?lines=true&page={x}, replacing {x} with the page number, to view the full context.
# Directory Structure
```
├── .claude
│ ├── commands
│ │ └── fix-github-issue.md
│ └── settings.json
├── .coveragerc
├── .dockerignore
├── .env.example
├── .gitattributes
├── .github
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.yml
│ │ ├── config.yml
│ │ ├── documentation.yml
│ │ ├── feature_request.yml
│ │ └── tool_addition.yml
│ ├── pull_request_template.md
│ └── workflows
│ ├── docker-pr.yml
│ ├── docker-release.yml
│ ├── semantic-pr.yml
│ ├── semantic-release.yml
│ └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── AGENTS.md
├── CHANGELOG.md
├── claude_config_example.json
├── CLAUDE.md
├── clink
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── claude.py
│ │ ├── codex.py
│ │ └── gemini.py
│ ├── constants.py
│ ├── models.py
│ ├── parsers
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── claude.py
│ │ ├── codex.py
│ │ └── gemini.py
│ └── registry.py
├── code_quality_checks.ps1
├── code_quality_checks.sh
├── communication_simulator_test.py
├── conf
│ ├── __init__.py
│ ├── azure_models.json
│ ├── cli_clients
│ │ ├── claude.json
│ │ ├── codex.json
│ │ └── gemini.json
│ ├── custom_models.json
│ ├── dial_models.json
│ ├── gemini_models.json
│ ├── openai_models.json
│ ├── openrouter_models.json
│ └── xai_models.json
├── config.py
├── docker
│ ├── README.md
│ └── scripts
│ ├── build.ps1
│ ├── build.sh
│ ├── deploy.ps1
│ ├── deploy.sh
│ └── healthcheck.py
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── adding_providers.md
│ ├── adding_tools.md
│ ├── advanced-usage.md
│ ├── ai_banter.md
│ ├── ai-collaboration.md
│ ├── azure_openai.md
│ ├── configuration.md
│ ├── context-revival.md
│ ├── contributions.md
│ ├── custom_models.md
│ ├── docker-deployment.md
│ ├── gemini-setup.md
│ ├── getting-started.md
│ ├── index.md
│ ├── locale-configuration.md
│ ├── logging.md
│ ├── model_ranking.md
│ ├── testing.md
│ ├── tools
│ │ ├── analyze.md
│ │ ├── apilookup.md
│ │ ├── challenge.md
│ │ ├── chat.md
│ │ ├── clink.md
│ │ ├── codereview.md
│ │ ├── consensus.md
│ │ ├── debug.md
│ │ ├── docgen.md
│ │ ├── listmodels.md
│ │ ├── planner.md
│ │ ├── precommit.md
│ │ ├── refactor.md
│ │ ├── secaudit.md
│ │ ├── testgen.md
│ │ ├── thinkdeep.md
│ │ ├── tracer.md
│ │ └── version.md
│ ├── troubleshooting.md
│ ├── vcr-testing.md
│ └── wsl-setup.md
├── examples
│ ├── claude_config_macos.json
│ └── claude_config_wsl.json
├── LICENSE
├── providers
│ ├── __init__.py
│ ├── azure_openai.py
│ ├── base.py
│ ├── custom.py
│ ├── dial.py
│ ├── gemini.py
│ ├── openai_compatible.py
│ ├── openai.py
│ ├── openrouter.py
│ ├── registries
│ │ ├── __init__.py
│ │ ├── azure.py
│ │ ├── base.py
│ │ ├── custom.py
│ │ ├── dial.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ ├── openrouter.py
│ │ └── xai.py
│ ├── registry_provider_mixin.py
│ ├── registry.py
│ ├── shared
│ │ ├── __init__.py
│ │ ├── model_capabilities.py
│ │ ├── model_response.py
│ │ ├── provider_type.py
│ │ └── temperature.py
│ └── xai.py
├── pyproject.toml
├── pytest.ini
├── README.md
├── requirements-dev.txt
├── requirements.txt
├── run_integration_tests.ps1
├── run_integration_tests.sh
├── run-server.ps1
├── run-server.sh
├── scripts
│ └── sync_version.py
├── server.py
├── simulator_tests
│ ├── __init__.py
│ ├── base_test.py
│ ├── conversation_base_test.py
│ ├── log_utils.py
│ ├── test_analyze_validation.py
│ ├── test_basic_conversation.py
│ ├── test_chat_simple_validation.py
│ ├── test_codereview_validation.py
│ ├── test_consensus_conversation.py
│ ├── test_consensus_three_models.py
│ ├── test_consensus_workflow_accurate.py
│ ├── test_content_validation.py
│ ├── test_conversation_chain_validation.py
│ ├── test_cross_tool_comprehensive.py
│ ├── test_cross_tool_continuation.py
│ ├── test_debug_certain_confidence.py
│ ├── test_debug_validation.py
│ ├── test_line_number_validation.py
│ ├── test_logs_validation.py
│ ├── test_model_thinking_config.py
│ ├── test_o3_model_selection.py
│ ├── test_o3_pro_expensive.py
│ ├── test_ollama_custom_url.py
│ ├── test_openrouter_fallback.py
│ ├── test_openrouter_models.py
│ ├── test_per_tool_deduplication.py
│ ├── test_planner_continuation_history.py
│ ├── test_planner_validation_old.py
│ ├── test_planner_validation.py
│ ├── test_precommitworkflow_validation.py
│ ├── test_prompt_size_limit_bug.py
│ ├── test_refactor_validation.py
│ ├── test_secaudit_validation.py
│ ├── test_testgen_validation.py
│ ├── test_thinkdeep_validation.py
│ ├── test_token_allocation_validation.py
│ ├── test_vision_capability.py
│ └── test_xai_models.py
├── systemprompts
│ ├── __init__.py
│ ├── analyze_prompt.py
│ ├── chat_prompt.py
│ ├── clink
│ │ ├── codex_codereviewer.txt
│ │ ├── default_codereviewer.txt
│ │ ├── default_planner.txt
│ │ └── default.txt
│ ├── codereview_prompt.py
│ ├── consensus_prompt.py
│ ├── debug_prompt.py
│ ├── docgen_prompt.py
│ ├── generate_code_prompt.py
│ ├── planner_prompt.py
│ ├── precommit_prompt.py
│ ├── refactor_prompt.py
│ ├── secaudit_prompt.py
│ ├── testgen_prompt.py
│ ├── thinkdeep_prompt.py
│ └── tracer_prompt.py
├── tests
│ ├── __init__.py
│ ├── CASSETTE_MAINTENANCE.md
│ ├── conftest.py
│ ├── gemini_cassettes
│ │ ├── chat_codegen
│ │ │ └── gemini25_pro_calculator
│ │ │ └── mldev.json
│ │ ├── chat_cross
│ │ │ └── step1_gemini25_flash_number
│ │ │ └── mldev.json
│ │ └── consensus
│ │ └── step2_gemini25_flash_against
│ │ └── mldev.json
│ ├── http_transport_recorder.py
│ ├── mock_helpers.py
│ ├── openai_cassettes
│ │ ├── chat_cross_step2_gpt5_reminder.json
│ │ ├── chat_gpt5_continuation.json
│ │ ├── chat_gpt5_moon_distance.json
│ │ ├── consensus_step1_gpt5_for.json
│ │ └── o3_pro_basic_math.json
│ ├── pii_sanitizer.py
│ ├── sanitize_cassettes.py
│ ├── test_alias_target_restrictions.py
│ ├── test_auto_mode_comprehensive.py
│ ├── test_auto_mode_custom_provider_only.py
│ ├── test_auto_mode_model_listing.py
│ ├── test_auto_mode_provider_selection.py
│ ├── test_auto_mode.py
│ ├── test_auto_model_planner_fix.py
│ ├── test_azure_openai_provider.py
│ ├── test_buggy_behavior_prevention.py
│ ├── test_cassette_semantic_matching.py
│ ├── test_challenge.py
│ ├── test_chat_codegen_integration.py
│ ├── test_chat_cross_model_continuation.py
│ ├── test_chat_openai_integration.py
│ ├── test_chat_simple.py
│ ├── test_clink_claude_agent.py
│ ├── test_clink_claude_parser.py
│ ├── test_clink_codex_agent.py
│ ├── test_clink_gemini_agent.py
│ ├── test_clink_gemini_parser.py
│ ├── test_clink_integration.py
│ ├── test_clink_parsers.py
│ ├── test_clink_tool.py
│ ├── test_collaboration.py
│ ├── test_config.py
│ ├── test_consensus_integration.py
│ ├── test_consensus_schema.py
│ ├── test_consensus.py
│ ├── test_conversation_continuation_integration.py
│ ├── test_conversation_field_mapping.py
│ ├── test_conversation_file_features.py
│ ├── test_conversation_memory.py
│ ├── test_conversation_missing_files.py
│ ├── test_custom_openai_temperature_fix.py
│ ├── test_custom_provider.py
│ ├── test_debug.py
│ ├── test_deploy_scripts.py
│ ├── test_dial_provider.py
│ ├── test_directory_expansion_tracking.py
│ ├── test_disabled_tools.py
│ ├── test_docker_claude_desktop_integration.py
│ ├── test_docker_config_complete.py
│ ├── test_docker_healthcheck.py
│ ├── test_docker_implementation.py
│ ├── test_docker_mcp_validation.py
│ ├── test_docker_security.py
│ ├── test_docker_volume_persistence.py
│ ├── test_file_protection.py
│ ├── test_gemini_token_usage.py
│ ├── test_image_support_integration.py
│ ├── test_image_validation.py
│ ├── test_integration_utf8.py
│ ├── test_intelligent_fallback.py
│ ├── test_issue_245_simple.py
│ ├── test_large_prompt_handling.py
│ ├── test_line_numbers_integration.py
│ ├── test_listmodels_restrictions.py
│ ├── test_listmodels.py
│ ├── test_mcp_error_handling.py
│ ├── test_model_enumeration.py
│ ├── test_model_metadata_continuation.py
│ ├── test_model_resolution_bug.py
│ ├── test_model_restrictions.py
│ ├── test_o3_pro_output_text_fix.py
│ ├── test_o3_temperature_fix_simple.py
│ ├── test_openai_compatible_token_usage.py
│ ├── test_openai_provider.py
│ ├── test_openrouter_provider.py
│ ├── test_openrouter_registry.py
│ ├── test_parse_model_option.py
│ ├── test_per_tool_model_defaults.py
│ ├── test_pii_sanitizer.py
│ ├── test_pip_detection_fix.py
│ ├── test_planner.py
│ ├── test_precommit_workflow.py
│ ├── test_prompt_regression.py
│ ├── test_prompt_size_limit_bug_fix.py
│ ├── test_provider_retry_logic.py
│ ├── test_provider_routing_bugs.py
│ ├── test_provider_utf8.py
│ ├── test_providers.py
│ ├── test_rate_limit_patterns.py
│ ├── test_refactor.py
│ ├── test_secaudit.py
│ ├── test_server.py
│ ├── test_supported_models_aliases.py
│ ├── test_thinking_modes.py
│ ├── test_tools.py
│ ├── test_tracer.py
│ ├── test_utf8_localization.py
│ ├── test_utils.py
│ ├── test_uvx_resource_packaging.py
│ ├── test_uvx_support.py
│ ├── test_workflow_file_embedding.py
│ ├── test_workflow_metadata.py
│ ├── test_workflow_prompt_size_validation_simple.py
│ ├── test_workflow_utf8.py
│ ├── test_xai_provider.py
│ ├── transport_helpers.py
│ └── triangle.png
├── tools
│ ├── __init__.py
│ ├── analyze.py
│ ├── apilookup.py
│ ├── challenge.py
│ ├── chat.py
│ ├── clink.py
│ ├── codereview.py
│ ├── consensus.py
│ ├── debug.py
│ ├── docgen.py
│ ├── listmodels.py
│ ├── models.py
│ ├── planner.py
│ ├── precommit.py
│ ├── refactor.py
│ ├── secaudit.py
│ ├── shared
│ │ ├── __init__.py
│ │ ├── base_models.py
│ │ ├── base_tool.py
│ │ ├── exceptions.py
│ │ └── schema_builders.py
│ ├── simple
│ │ ├── __init__.py
│ │ └── base.py
│ ├── testgen.py
│ ├── thinkdeep.py
│ ├── tracer.py
│ ├── version.py
│ └── workflow
│ ├── __init__.py
│ ├── base.py
│ ├── schema_builders.py
│ └── workflow_mixin.py
├── utils
│ ├── __init__.py
│ ├── client_info.py
│ ├── conversation_memory.py
│ ├── env.py
│ ├── file_types.py
│ ├── file_utils.py
│ ├── image_utils.py
│ ├── model_context.py
│ ├── model_restrictions.py
│ ├── security_config.py
│ ├── storage_backend.py
│ └── token_utils.py
└── zen-mcp-server
```
# Files
--------------------------------------------------------------------------------
/tests/test_conversation_field_mapping.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Test that conversation history is correctly mapped to tool-specific fields
3 | """
4 |
5 | from datetime import datetime
6 | from unittest.mock import patch
7 |
8 | import pytest
9 |
10 | from server import reconstruct_thread_context
11 | from utils.conversation_memory import ConversationTurn, ThreadContext
12 |
13 |
14 | @pytest.mark.asyncio
15 | @pytest.mark.no_mock_provider
16 | async def test_conversation_history_field_mapping():
17 | """Test that enhanced prompts are mapped to prompt field for all tools"""
18 |
19 | # Test data for different tools - all use 'prompt' now
20 | test_cases = [
21 | {
22 | "tool_name": "analyze",
23 | "original_value": "What does this code do?",
24 | },
25 | {
26 | "tool_name": "chat",
27 | "original_value": "Explain this concept",
28 | },
29 | {
30 | "tool_name": "debug",
31 | "original_value": "Getting undefined error",
32 | },
33 | {
34 | "tool_name": "codereview",
35 | "original_value": "Review this implementation",
36 | },
37 | {
38 | "tool_name": "thinkdeep",
39 | "original_value": "My analysis so far",
40 | },
41 | ]
42 |
43 | for test_case in test_cases:
44 | # Create real conversation context
45 | mock_context = ThreadContext(
46 | thread_id="test-thread-123",
47 | tool_name=test_case["tool_name"],
48 | created_at=datetime.now().isoformat(),
49 | last_updated_at=datetime.now().isoformat(),
50 | turns=[
51 | ConversationTurn(
52 | role="user",
53 | content="Previous user message",
54 | timestamp=datetime.now().isoformat(),
55 | files=["/test/file1.py"],
56 | ),
57 | ConversationTurn(
58 | role="assistant",
59 | content="Previous assistant response",
60 | timestamp=datetime.now().isoformat(),
61 | ),
62 | ],
63 | initial_context={},
64 | )
65 |
66 | # Mock get_thread to return our test context
67 | with patch("utils.conversation_memory.get_thread", return_value=mock_context):
68 | with patch("utils.conversation_memory.add_turn", return_value=True):
69 | # Create arguments with continuation_id and use a test model
70 | arguments = {
71 | "continuation_id": "test-thread-123",
72 | "prompt": test_case["original_value"],
73 | "absolute_file_paths": ["/test/file2.py"],
74 | "model": "flash", # Use test model to avoid provider errors
75 | }
76 |
77 | # Call reconstruct_thread_context
78 | enhanced_args = await reconstruct_thread_context(arguments)
79 |
80 | # Verify the enhanced prompt is in the prompt field
81 | assert "prompt" in enhanced_args
82 | enhanced_value = enhanced_args["prompt"]
83 |
84 | # Should contain conversation history
85 | assert "=== CONVERSATION HISTORY" in enhanced_value # Allow for both formats
86 | assert "Previous user message" in enhanced_value
87 | assert "Previous assistant response" in enhanced_value
88 |
89 | # Should contain the new user input
90 | assert "=== NEW USER INPUT ===" in enhanced_value
91 | assert test_case["original_value"] in enhanced_value
92 |
93 | # Should have token budget
94 | assert "_remaining_tokens" in enhanced_args
95 | assert enhanced_args["_remaining_tokens"] > 0
96 |
97 |
98 | @pytest.mark.asyncio
99 | @pytest.mark.no_mock_provider
100 | async def test_unknown_tool_defaults_to_prompt():
101 | """Test that unknown tools default to using 'prompt' field"""
102 |
103 | mock_context = ThreadContext(
104 | thread_id="test-thread-456",
105 | tool_name="unknown_tool",
106 | created_at=datetime.now().isoformat(),
107 | last_updated_at=datetime.now().isoformat(),
108 | turns=[
109 | ConversationTurn(
110 | role="user",
111 | content="First message",
112 | timestamp=datetime.now().isoformat(),
113 | ),
114 | ConversationTurn(
115 | role="assistant",
116 | content="First response",
117 | timestamp=datetime.now().isoformat(),
118 | ),
119 | ],
120 | initial_context={},
121 | )
122 |
123 | with patch("utils.conversation_memory.get_thread", return_value=mock_context):
124 | with patch("utils.conversation_memory.add_turn", return_value=True):
125 | arguments = {
126 | "continuation_id": "test-thread-456",
127 | "prompt": "User input",
128 | "model": "flash", # Use test model for real integration
129 | }
130 |
131 | enhanced_args = await reconstruct_thread_context(arguments)
132 |
133 | # Should default to 'prompt' field
134 | assert "prompt" in enhanced_args
135 | assert "=== CONVERSATION HISTORY" in enhanced_args["prompt"] # Allow for both formats
136 | assert "First message" in enhanced_args["prompt"]
137 | assert "First response" in enhanced_args["prompt"]
138 | assert "User input" in enhanced_args["prompt"]
139 |
140 |
141 | @pytest.mark.asyncio
142 | async def test_tool_parameter_standardization():
143 | """Test that workflow tools use standardized investigation pattern"""
144 | from tools.analyze import AnalyzeWorkflowRequest
145 | from tools.codereview import CodeReviewRequest
146 | from tools.debug import DebugInvestigationRequest
147 | from tools.precommit import PrecommitRequest
148 | from tools.thinkdeep import ThinkDeepWorkflowRequest
149 |
150 | # Test analyze tool uses workflow pattern
151 | analyze = AnalyzeWorkflowRequest(
152 | step="What does this do?",
153 | step_number=1,
154 | total_steps=1,
155 | next_step_required=False,
156 | findings="Initial analysis",
157 | relevant_files=["/test.py"],
158 | )
159 | assert analyze.step == "What does this do?"
160 |
161 | # Debug tool now uses self-investigation pattern with different fields
162 | debug = DebugInvestigationRequest(
163 | step="Investigating error",
164 | step_number=1,
165 | total_steps=3,
166 | next_step_required=True,
167 | findings="Initial error analysis",
168 | )
169 | assert debug.step == "Investigating error"
170 | assert debug.findings == "Initial error analysis"
171 |
172 | # Test codereview tool uses workflow fields
173 | review = CodeReviewRequest(
174 | step="Initial code review investigation",
175 | step_number=1,
176 | total_steps=2,
177 | next_step_required=True,
178 | findings="Initial review findings",
179 | relevant_files=["/test.py"],
180 | )
181 | assert review.step == "Initial code review investigation"
182 | assert review.findings == "Initial review findings"
183 |
184 | # Test thinkdeep tool uses workflow pattern
185 | think = ThinkDeepWorkflowRequest(
186 | step="My analysis", step_number=1, total_steps=1, next_step_required=False, findings="Initial thinking analysis"
187 | )
188 | assert think.step == "My analysis"
189 |
190 | # Test precommit tool uses workflow fields
191 | precommit = PrecommitRequest(
192 | step="Validating changes for commit",
193 | step_number=1,
194 | total_steps=2,
195 | next_step_required=True,
196 | findings="Initial validation findings",
197 | path="/repo", # path only needed for step 1
198 | )
199 | assert precommit.step == "Validating changes for commit"
200 | assert precommit.findings == "Initial validation findings"
201 |
```
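The assertions above pin down the general shape of the reconstructed prompt without fixing its exact wording. As a rough illustration only (the real headers and turn formatting come from utils.conversation_memory and may differ in detail), the enhanced prompt argument looks approximately like this:

```python
# Illustrative only: approximate shape of enhanced_args["prompt"] after
# reconstruct_thread_context(); the exact headers and turn layout are produced
# by utils.conversation_memory and are not guaranteed to match this sketch.
APPROXIMATE_ENHANCED_PROMPT = """\
=== CONVERSATION HISTORY (CONTINUATION) ===
Turn 1 (user): Previous user message
Turn 2 (assistant): Previous assistant response

=== NEW USER INPUT ===
What does this code do?
"""
```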
--------------------------------------------------------------------------------
/providers/custom.py:
--------------------------------------------------------------------------------
```python
1 | """Custom API provider implementation."""
2 |
3 | import logging
4 |
5 | from utils.env import get_env
6 |
7 | from .openai_compatible import OpenAICompatibleProvider
8 | from .registries.custom import CustomEndpointModelRegistry
9 | from .registries.openrouter import OpenRouterModelRegistry
10 | from .shared import ModelCapabilities, ProviderType
11 |
12 |
13 | class CustomProvider(OpenAICompatibleProvider):
14 | """Adapter for self-hosted or local OpenAI-compatible endpoints.
15 |
16 | Role
17 | Provide a uniform bridge between the MCP server and user-managed
18 | OpenAI-compatible services (Ollama, vLLM, LM Studio, bespoke gateways).
19 | By subclassing :class:`OpenAICompatibleProvider` it inherits request and
20 | token handling, while the custom registry exposes locally defined model
21 | metadata.
22 |
23 | Notable behaviour
24 |         * Loads model definitions and aliases from :class:`CustomEndpointModelRegistry`,
25 |           falling back to :class:`OpenRouterModelRegistry` so custom deployments
26 |           share the same metadata pipeline as OpenRouter itself.
27 | * Normalises version-tagged model names (``model:latest``) and applies
28 | restriction policies just like cloud providers, ensuring consistent
29 | behaviour across environments.
30 | """
31 |
32 | FRIENDLY_NAME = "Custom API"
33 |
34 | # Model registry for managing configurations and aliases
35 | _registry: CustomEndpointModelRegistry | None = None
36 |
37 | def __init__(self, api_key: str = "", base_url: str = "", **kwargs):
38 | """Initialize Custom provider for local/self-hosted models.
39 |
40 | This provider supports any OpenAI-compatible API endpoint including:
41 | - Ollama (typically no API key required)
42 | - vLLM (may require API key)
43 | - LM Studio (may require API key)
44 | - Text Generation WebUI (may require API key)
45 | - Enterprise/self-hosted APIs (typically require API key)
46 |
47 | Args:
48 | api_key: API key for the custom endpoint. Can be empty string for
49 | providers that don't require authentication (like Ollama).
50 | Falls back to CUSTOM_API_KEY environment variable if not provided.
51 | base_url: Base URL for the custom API endpoint (e.g., 'http://localhost:11434/v1').
52 | Falls back to CUSTOM_API_URL environment variable if not provided.
53 | **kwargs: Additional configuration passed to parent OpenAI-compatible provider
54 |
55 | Raises:
56 | ValueError: If no base_url is provided via parameter or environment variable
57 | """
58 | # Fall back to environment variables only if not provided
59 | if not base_url:
60 | base_url = get_env("CUSTOM_API_URL", "") or ""
61 | if not api_key:
62 | api_key = get_env("CUSTOM_API_KEY", "") or ""
63 |
64 | if not base_url:
65 | raise ValueError(
66 | "Custom API URL must be provided via base_url parameter or CUSTOM_API_URL environment variable"
67 | )
68 |
69 | # For Ollama and other providers that don't require authentication,
70 | # set a dummy API key to avoid OpenAI client header issues
71 | if not api_key:
72 | api_key = "dummy-key-for-unauthenticated-endpoint"
73 | logging.debug("Using dummy API key for unauthenticated custom endpoint")
74 |
75 | logging.info(f"Initializing Custom provider with endpoint: {base_url}")
76 |
77 | self._alias_cache: dict[str, str] = {}
78 |
79 | super().__init__(api_key, base_url=base_url, **kwargs)
80 |
81 | # Initialize model registry
82 | if CustomProvider._registry is None:
83 | CustomProvider._registry = CustomEndpointModelRegistry()
84 | # Log loaded models and aliases only on first load
85 | models = self._registry.list_models()
86 | aliases = self._registry.list_aliases()
87 | logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
88 |
89 | # ------------------------------------------------------------------
90 | # Capability surface
91 | # ------------------------------------------------------------------
92 | def _lookup_capabilities(
93 | self,
94 | canonical_name: str,
95 | requested_name: str | None = None,
96 | ) -> ModelCapabilities | None:
97 | """Return capabilities for models explicitly marked as custom."""
98 |
99 | builtin = super()._lookup_capabilities(canonical_name, requested_name)
100 | if builtin is not None:
101 | return builtin
102 |
103 | registry_entry = self._registry.resolve(canonical_name)
104 | if registry_entry:
105 | registry_entry.provider = ProviderType.CUSTOM
106 | return registry_entry
107 |
108 | logging.debug(
109 | "Custom provider cannot resolve model '%s'; ensure it is declared in custom_models.json",
110 | canonical_name,
111 | )
112 | return None
113 |
114 | def get_provider_type(self) -> ProviderType:
115 | """Identify this provider for restriction and logging logic."""
116 |
117 | return ProviderType.CUSTOM
118 |
119 | # ------------------------------------------------------------------
120 | # Registry helpers
121 | # ------------------------------------------------------------------
122 |
123 | def _resolve_model_name(self, model_name: str) -> str:
124 | """Resolve registry aliases and strip version tags for local models."""
125 |
126 | cache_key = model_name.lower()
127 | if cache_key in self._alias_cache:
128 | return self._alias_cache[cache_key]
129 |
130 | config = self._registry.resolve(model_name)
131 | if config:
132 | if config.model_name != model_name:
133 | logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
134 | resolved = config.model_name
135 | self._alias_cache[cache_key] = resolved
136 | self._alias_cache.setdefault(resolved.lower(), resolved)
137 | return resolved
138 |
139 | if ":" in model_name:
140 | base_model = model_name.split(":")[0]
141 | logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
142 |
143 | base_config = self._registry.resolve(base_model)
144 | if base_config:
145 | logging.debug("Resolved base model '%s' to '%s'", base_model, base_config.model_name)
146 | resolved = base_config.model_name
147 | self._alias_cache[cache_key] = resolved
148 | self._alias_cache.setdefault(resolved.lower(), resolved)
149 | return resolved
150 | self._alias_cache[cache_key] = base_model
151 | return base_model
152 |
153 |         logging.debug(f"Model '{model_name}' not found in custom registry, trying OpenRouter registry")
154 | # Attempt to resolve via OpenRouter registry so aliases still map cleanly
155 | openrouter_registry = OpenRouterModelRegistry()
156 | openrouter_config = openrouter_registry.resolve(model_name)
157 | if openrouter_config:
158 | resolved = openrouter_config.model_name
159 | self._alias_cache[cache_key] = resolved
160 | self._alias_cache.setdefault(resolved.lower(), resolved)
161 | return resolved
162 |
163 | self._alias_cache[cache_key] = model_name
164 | return model_name
165 |
166 | def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
167 | """Expose registry capabilities for models marked as custom."""
168 |
169 | if not self._registry:
170 | return {}
171 |
172 | capabilities = {}
173 | for model in self._registry.list_models():
174 | config = self._registry.resolve(model)
175 | if config:
176 | capabilities[model] = config
177 | return capabilities
178 |
```
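As a minimal sketch (not taken from the repository), the provider can be pointed at a local Ollama endpoint; the URL and model tag below are placeholders, and the same values could instead come from the CUSTOM_API_URL / CUSTOM_API_KEY fallbacks described in __init__:

```python
# Hedged sketch: wiring CustomProvider to a local Ollama endpoint.
# "http://localhost:11434/v1" and "llama3.2:latest" are placeholder values.
from providers.custom import CustomProvider
from providers.shared import ProviderType

# Ollama usually needs no API key; the provider substitutes a dummy key internally.
provider = CustomProvider(base_url="http://localhost:11434/v1")

assert provider.get_provider_type() is ProviderType.CUSTOM
# Version tags such as ":latest" are stripped or resolved through the registry.
print(provider._resolve_model_name("llama3.2:latest"))
```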
--------------------------------------------------------------------------------
/simulator_tests/test_line_number_validation.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Test to validate line number handling across different tools
3 | """
4 |
5 | import json
6 | import os
7 |
8 | from .base_test import BaseSimulatorTest
9 |
10 |
11 | class LineNumberValidationTest(BaseSimulatorTest):
12 | """Test that validates correct line number handling in chat, analyze, and refactor tools"""
13 |
14 | @property
15 | def test_name(self) -> str:
16 | return "line_number_validation"
17 |
18 | @property
19 | def test_description(self) -> str:
20 | return "Line number handling validation across tools"
21 |
22 | def run_test(self) -> bool:
23 | """Test line number handling in different tools"""
24 | try:
25 | self.logger.info("Test: Line number handling validation")
26 |
27 | # Setup test files
28 | self.setup_test_files()
29 |
30 | # Create a test file with known content
31 | test_file_content = '''# Example code with specific elements
32 | def calculate_total(items):
33 | """Calculate total with tax"""
34 | subtotal = 0
35 | tax_rate = 0.08 # Line 5 - tax_rate defined
36 |
37 | for item in items: # Line 7 - loop starts
38 | if item.price > 0:
39 | subtotal += item.price
40 |
41 | tax_amount = subtotal * tax_rate # Line 11
42 | return subtotal + tax_amount
43 |
44 | def validate_data(data):
45 | """Validate input data""" # Line 15
46 | required_fields = ["name", "email", "age"] # Line 16
47 |
48 | for field in required_fields:
49 | if field not in data:
50 | raise ValueError(f"Missing field: {field}")
51 |
52 | return True # Line 22
53 | '''
54 |
55 | test_file_path = os.path.join(self.test_dir, "line_test.py")
56 | with open(test_file_path, "w") as f:
57 | f.write(test_file_content)
58 |
59 | self.logger.info(f"Created test file: {test_file_path}")
60 |
61 | # Test 1: Chat tool asking about specific line
62 | self.logger.info(" 1.1: Testing chat tool with line number question")
63 | content, continuation_id = self.call_mcp_tool(
64 | "chat",
65 | {
66 | "prompt": "Where is tax_rate defined in this file? Please tell me the exact line number.",
67 | "absolute_file_paths": [test_file_path],
68 | "model": "flash",
69 | },
70 | )
71 |
72 | if content:
73 | # Check if the response mentions line 5
74 |                 if "line 5" in content.lower():
75 | self.logger.info(" ✅ Chat tool correctly identified tax_rate at line 5")
76 | else:
77 | self.logger.warning(f" ⚠️ Chat tool response didn't mention line 5: {content[:200]}...")
78 | else:
79 | self.logger.error(" ❌ Chat tool request failed")
80 | return False
81 |
82 | # Test 2: Analyze tool with line number reference
83 | self.logger.info(" 1.2: Testing analyze tool with line number analysis")
84 | content, continuation_id = self.call_mcp_tool(
85 | "analyze",
86 | {
87 | "prompt": "What happens between lines 7-11 in this code? Focus on the loop logic.",
88 | "absolute_file_paths": [test_file_path],
89 | "model": "flash",
90 | },
91 | )
92 |
93 | if content:
94 | # Check if the response references the loop
95 | if any(term in content.lower() for term in ["loop", "iterate", "line 7", "lines 7"]):
96 | self.logger.info(" ✅ Analyze tool correctly analyzed the specified line range")
97 | else:
98 | self.logger.warning(" ⚠️ Analyze tool response unclear about line range")
99 | else:
100 | self.logger.error(" ❌ Analyze tool request failed")
101 | return False
102 |
103 | # Test 3: Refactor tool with line number precision
104 | self.logger.info(" 1.3: Testing refactor tool line number precision")
105 | content, continuation_id = self.call_mcp_tool(
106 | "refactor",
107 | {
108 | "prompt": "Analyze this code for refactoring opportunities",
109 | "absolute_file_paths": [test_file_path],
110 | "refactor_type": "codesmells",
111 | "model": "flash",
112 | },
113 | )
114 |
115 | if content:
116 | try:
117 | # Parse the JSON response
118 | result = json.loads(content)
119 | if result.get("status") == "refactor_analysis_complete":
120 | opportunities = result.get("refactor_opportunities", [])
121 | if opportunities:
122 | # Check if line numbers are precise
123 | has_line_refs = any(
124 | opp.get("start_line") is not None and opp.get("end_line") is not None
125 | for opp in opportunities
126 | )
127 | if has_line_refs:
128 | self.logger.info(" ✅ Refactor tool provided precise line number references")
129 | # Log some examples
130 | for opp in opportunities[:2]:
131 | if opp.get("start_line"):
132 | self.logger.info(
133 | f" - Issue at lines {opp['start_line']}-{opp['end_line']}: {opp.get('issue', '')[:50]}..."
134 | )
135 | else:
136 | self.logger.warning(" ⚠️ Refactor tool response missing line numbers")
137 | else:
138 | self.logger.info(" ℹ️ No refactoring opportunities found (code might be too clean)")
139 | except json.JSONDecodeError:
140 | self.logger.warning(" ⚠️ Refactor tool response not valid JSON")
141 | else:
142 | self.logger.error(" ❌ Refactor tool request failed")
143 | return False
144 |
145 | # Test 4: Validate log patterns
146 | self.logger.info(" 1.4: Validating line number processing in logs")
147 |
148 | # Get logs from server
149 | try:
150 | log_file_path = "logs/mcp_server.log"
151 | with open(log_file_path) as f:
152 | lines = f.readlines()
153 | logs = "".join(lines[-500:])
154 | except Exception as e:
155 | self.logger.error(f"Failed to read server logs: {e}")
156 | logs = ""
157 | pass
158 |
159 | # Check for line number formatting patterns
160 |             line_number_patterns = ["Line numbers for", "enabled", "│", "line number"]  # "│" is the line number separator
161 |
162 | found_patterns = 0
163 | for pattern in line_number_patterns:
164 | if pattern in logs:
165 | found_patterns += 1
166 |
167 | self.logger.info(f" Found {found_patterns}/{len(line_number_patterns)} line number patterns in logs")
168 |
169 | if found_patterns >= 2:
170 | self.logger.info(" ✅ Line number processing confirmed in logs")
171 | else:
172 | self.logger.warning(" ⚠️ Limited line number processing evidence in logs")
173 |
174 | self.logger.info(" ✅ Line number validation test completed successfully")
175 | return True
176 |
177 | except Exception as e:
178 | self.logger.error(f"Line number validation test failed: {type(e).__name__}: {e}")
179 | return False
180 |
```
--------------------------------------------------------------------------------
/utils/model_context.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Model context management for dynamic token allocation.
3 |
4 | This module provides a clean abstraction for model-specific token management,
5 | ensuring that token limits are properly calculated based on the current model
6 | being used, not global constants.
7 |
8 | CONVERSATION MEMORY INTEGRATION:
9 | This module works closely with the conversation memory system to provide
10 | optimal token allocation for multi-turn conversations:
11 |
12 | 1. DUAL PRIORITIZATION STRATEGY SUPPORT:
13 | - Provides separate token budgets for conversation history vs. files
14 | - Enables the conversation memory system to apply newest-first prioritization
15 | - Ensures optimal balance between context preservation and new content
16 |
17 | 2. MODEL-SPECIFIC ALLOCATION:
18 | - Dynamic allocation based on model capabilities (context window size)
19 | - Conservative allocation for smaller models (O3: 200K context)
20 | - Generous allocation for larger models (Gemini: 1M+ context)
21 | - Adapts token distribution ratios based on model capacity
22 |
23 | 3. CROSS-TOOL CONSISTENCY:
24 | - Provides consistent token budgets across different tools
25 | - Enables seamless conversation continuation between tools
26 | - Supports conversation reconstruction with proper budget management
27 | """
28 |
29 | import logging
30 | from dataclasses import dataclass
31 | from typing import Any, Optional
32 |
33 | from config import DEFAULT_MODEL
34 | from providers import ModelCapabilities, ModelProviderRegistry
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | @dataclass
40 | class TokenAllocation:
41 | """Token allocation strategy for a model."""
42 |
43 | total_tokens: int
44 | content_tokens: int
45 | response_tokens: int
46 | file_tokens: int
47 | history_tokens: int
48 |
49 | @property
50 | def available_for_prompt(self) -> int:
51 | """Tokens available for the actual prompt after allocations."""
52 | return self.content_tokens - self.file_tokens - self.history_tokens
53 |
54 |
55 | class ModelContext:
56 | """
57 | Encapsulates model-specific information and token calculations.
58 |
59 | This class provides a single source of truth for all model-related
60 | token calculations, ensuring consistency across the system.
61 | """
62 |
63 | def __init__(self, model_name: str, model_option: Optional[str] = None):
64 | self.model_name = model_name
65 | self.model_option = model_option # Store optional model option (e.g., "for", "against", etc.)
66 | self._provider = None
67 | self._capabilities = None
68 | self._token_allocation = None
69 |
70 | @property
71 | def provider(self):
72 | """Get the model provider lazily."""
73 | if self._provider is None:
74 | self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
75 | if not self._provider:
76 | available_models = ModelProviderRegistry.get_available_model_names()
77 | if available_models:
78 | available_text = ", ".join(available_models)
79 | else:
80 | available_text = (
81 | "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
82 | )
83 |
84 | raise ValueError(
85 | f"Model '{self.model_name}' is not available with current API keys. Available models: {available_text}."
86 | )
87 | return self._provider
88 |
89 | @property
90 | def capabilities(self) -> ModelCapabilities:
91 | """Get model capabilities lazily."""
92 | if self._capabilities is None:
93 | self._capabilities = self.provider.get_capabilities(self.model_name)
94 | return self._capabilities
95 |
96 | def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
97 | """
98 | Calculate token allocation based on model capacity and conversation requirements.
99 |
100 | This method implements the core token budget calculation that supports the
101 | dual prioritization strategy used in conversation memory and file processing:
102 |
103 | TOKEN ALLOCATION STRATEGY:
104 | 1. CONTENT vs RESPONSE SPLIT:
105 | - Smaller models (< 300K): 60% content, 40% response (conservative)
106 | - Larger models (≥ 300K): 80% content, 20% response (generous)
107 |
108 | 2. CONTENT SUB-ALLOCATION:
109 | - File tokens: 30-40% of content budget for newest file versions
110 | - History tokens: 40-50% of content budget for conversation context
111 | - Remaining: Available for tool-specific prompt content
112 |
113 | 3. CONVERSATION MEMORY INTEGRATION:
114 | - History allocation enables conversation reconstruction in reconstruct_thread_context()
115 | - File allocation supports newest-first file prioritization in tools
116 | - Remaining budget passed to tools via _remaining_tokens parameter
117 |
118 | Args:
119 | reserved_for_response: Override response token reservation
120 |
121 | Returns:
122 | TokenAllocation with calculated budgets for dual prioritization strategy
123 | """
124 | total_tokens = self.capabilities.context_window
125 |
126 | # Dynamic allocation based on model capacity
127 | if total_tokens < 300_000:
128 | # Smaller context models (O3): Conservative allocation
129 | content_ratio = 0.6 # 60% for content
130 | response_ratio = 0.4 # 40% for response
131 | file_ratio = 0.3 # 30% of content for files
132 | history_ratio = 0.5 # 50% of content for history
133 | else:
134 | # Larger context models (Gemini): More generous allocation
135 | content_ratio = 0.8 # 80% for content
136 | response_ratio = 0.2 # 20% for response
137 | file_ratio = 0.4 # 40% of content for files
138 | history_ratio = 0.4 # 40% of content for history
139 |
140 | # Calculate allocations
141 | content_tokens = int(total_tokens * content_ratio)
142 | response_tokens = reserved_for_response or int(total_tokens * response_ratio)
143 |
144 | # Sub-allocations within content budget
145 | file_tokens = int(content_tokens * file_ratio)
146 | history_tokens = int(content_tokens * history_ratio)
147 |
148 | allocation = TokenAllocation(
149 | total_tokens=total_tokens,
150 | content_tokens=content_tokens,
151 | response_tokens=response_tokens,
152 | file_tokens=file_tokens,
153 | history_tokens=history_tokens,
154 | )
155 |
156 | logger.debug(f"Token allocation for {self.model_name}:")
157 | logger.debug(f" Total: {allocation.total_tokens:,}")
158 | logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
159 | logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
160 | logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
161 | logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")
162 |
163 | return allocation
164 |
165 | def estimate_tokens(self, text: str) -> int:
166 | """
167 | Estimate token count for text using model-specific tokenizer.
168 |
169 | For now, uses simple estimation. Can be enhanced with model-specific
170 | tokenizers (tiktoken for OpenAI, etc.) in the future.
171 | """
172 | # TODO: Integrate model-specific tokenizers
173 | # For now, use conservative estimation
174 | return len(text) // 3 # Conservative estimate
175 |
176 | @classmethod
177 | def from_arguments(cls, arguments: dict[str, Any]) -> "ModelContext":
178 | """Create ModelContext from tool arguments."""
179 | model_name = arguments.get("model") or DEFAULT_MODEL
180 | return cls(model_name)
181 |
```
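To make the ratio arithmetic above concrete, here is a small sketch that reproduces the same split for two hypothetical context sizes: a 200K-token model on the conservative path and a 1M-token model on the generous path. It mirrors the ratios in calculate_token_allocation rather than calling it, so no provider lookup is required; the context sizes are illustrative.

```python
# Hedged sketch: reproduces the ratio arithmetic from calculate_token_allocation
# for two illustrative context sizes; it does not call ModelContext itself.

def sketch_allocation(total_tokens: int) -> dict[str, int]:
    if total_tokens < 300_000:
        # Conservative split for smaller-context models
        content_ratio, response_ratio, file_ratio, history_ratio = 0.6, 0.4, 0.3, 0.5
    else:
        # Generous split for large-context models
        content_ratio, response_ratio, file_ratio, history_ratio = 0.8, 0.2, 0.4, 0.4
    content = int(total_tokens * content_ratio)
    return {
        "content": content,
        "response": int(total_tokens * response_ratio),
        "files": int(content * file_ratio),
        "history": int(content * history_ratio),
    }

print(sketch_allocation(200_000))
# {'content': 120000, 'response': 80000, 'files': 36000, 'history': 60000}
print(sketch_allocation(1_000_000))
# {'content': 800000, 'response': 200000, 'files': 320000, 'history': 320000}
```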
--------------------------------------------------------------------------------
/tools/challenge.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Challenge tool - Encourages critical thinking and thoughtful disagreement
3 |
4 | This tool takes a user's statement and returns it wrapped in instructions that
5 | encourage the CLI agent to challenge ideas and think critically before agreeing. It helps
6 | avoid reflexive agreement by prompting deeper analysis and genuine evaluation.
7 |
8 | This is a simple, self-contained tool that doesn't require AI model access.
9 | """
10 |
11 | from typing import TYPE_CHECKING, Any, Optional
12 |
13 | from pydantic import Field
14 |
15 | if TYPE_CHECKING:
16 | from tools.models import ToolModelCategory
17 |
18 | from config import TEMPERATURE_ANALYTICAL
19 | from tools.shared.base_models import ToolRequest
20 | from tools.shared.exceptions import ToolExecutionError
21 |
22 | from .simple.base import SimpleTool
23 |
24 | # Field descriptions for the Challenge tool
25 | CHALLENGE_FIELD_DESCRIPTIONS = {
26 | "prompt": (
27 | "Statement to scrutinize. If you invoke `challenge` manually, strip the word 'challenge' and pass just the statement. "
28 | "Automatic invocations send the full user message as-is; do not modify it."
29 | ),
30 | }
31 |
32 |
33 | class ChallengeRequest(ToolRequest):
34 | """Request model for Challenge tool"""
35 |
36 | prompt: str = Field(..., description=CHALLENGE_FIELD_DESCRIPTIONS["prompt"])
37 |
38 |
39 | class ChallengeTool(SimpleTool):
40 | """
41 | Challenge tool for encouraging critical thinking and avoiding automatic agreement.
42 |
43 | This tool wraps user statements in instructions that encourage the CLI agent to:
44 | - Challenge ideas and think critically before responding
45 | - Evaluate whether they actually agree or disagree
46 | - Provide thoughtful analysis rather than reflexive agreement
47 |
48 | The tool is self-contained and doesn't require AI model access - it simply
49 | transforms the input prompt into a structured critical thinking challenge.
50 | """
51 |
52 | def get_name(self) -> str:
53 | return "challenge"
54 |
55 | def get_description(self) -> str:
56 | return (
57 | "Prevents reflexive agreement by forcing critical thinking and reasoned analysis when a statement is challenged. "
58 | "Trigger automatically when a user critically questions, disagrees or appears to push back on earlier answers, and use it manually to sanity-check contentious claims."
59 | )
60 |
61 | def get_system_prompt(self) -> str:
62 | # Challenge tool doesn't need a system prompt since it doesn't call AI
63 | return ""
64 |
65 | def get_default_temperature(self) -> float:
66 | return TEMPERATURE_ANALYTICAL
67 |
68 | def get_model_category(self) -> "ToolModelCategory":
69 | """Challenge doesn't need a model category since it doesn't use AI"""
70 | from tools.models import ToolModelCategory
71 |
72 | return ToolModelCategory.FAST_RESPONSE # Default, but not used
73 |
74 | def requires_model(self) -> bool:
75 | """
76 | Challenge tool doesn't require model resolution at the MCP boundary.
77 |
78 | Like the planner tool, this is a pure data processing tool that transforms
79 | the input without calling external AI models.
80 |
81 | Returns:
82 | bool: False - challenge doesn't need AI model access
83 | """
84 | return False
85 |
86 | def get_request_model(self):
87 | """Return the Challenge-specific request model"""
88 | return ChallengeRequest
89 |
90 | def get_input_schema(self) -> dict[str, Any]:
91 | """
92 | Generate input schema for the challenge tool.
93 |
94 | Since this tool doesn't require a model, we exclude model-related fields.
95 | """
96 | schema = {
97 | "type": "object",
98 | "properties": {
99 | "prompt": {
100 | "type": "string",
101 | "description": CHALLENGE_FIELD_DESCRIPTIONS["prompt"],
102 | },
103 | },
104 | "required": ["prompt"],
105 | }
106 |
107 | return schema
108 |
109 | async def execute(self, arguments: dict[str, Any]) -> list:
110 | """
111 | Execute the challenge tool by wrapping the prompt in critical thinking instructions.
112 |
113 | This is the main execution method that transforms the user's statement into
114 | a structured challenge that encourages thoughtful re-evaluation.
115 | """
116 | import json
117 |
118 | from mcp.types import TextContent
119 |
120 | try:
121 | # Validate request
122 | request = self.get_request_model()(**arguments)
123 |
124 | # Wrap the prompt in challenge instructions
125 | wrapped_prompt = self._wrap_prompt_for_challenge(request.prompt)
126 |
127 | # Return the wrapped prompt as the response
128 | response_data = {
129 | "status": "challenge_accepted",
130 | "original_statement": request.prompt,
131 | "challenge_prompt": wrapped_prompt,
132 | "instructions": (
133 | "Present the challenge_prompt to yourself and follow its instructions. "
134 | "Reassess the statement carefully and critically before responding. "
135 | "If, after reflection, you find reasons to disagree or qualify it, explain your reasoning. "
136 | "Likewise, if you find reasons to agree, articulate them clearly and justify your agreement."
137 | ),
138 | }
139 |
140 | return [TextContent(type="text", text=json.dumps(response_data, indent=2, ensure_ascii=False))]
141 |
142 | except ToolExecutionError:
143 | raise
144 | except Exception as e:
145 | import logging
146 |
147 | logger = logging.getLogger(__name__)
148 | logger.error(f"Error in challenge tool execution: {e}", exc_info=True)
149 |
150 | error_data = {
151 | "status": "error",
152 | "error": str(e),
153 | "content": f"Failed to create challenge prompt: {str(e)}",
154 | }
155 |
156 | raise ToolExecutionError(json.dumps(error_data, ensure_ascii=False)) from e
157 |
158 | def _wrap_prompt_for_challenge(self, prompt: str) -> str:
159 | """
160 | Wrap the user's statement in instructions that encourage critical challenge.
161 |
162 | Args:
163 | prompt: The original user statement to wrap
164 |
165 | Returns:
166 | The statement wrapped in challenge instructions
167 | """
168 | return (
169 | f"CRITICAL REASSESSMENT – Do not automatically agree:\n\n"
170 | f'"{prompt}"\n\n'
171 | f"Carefully evaluate the statement above. Is it accurate, complete, and well-reasoned? "
172 | f"Investigate if needed before replying, and stay focused. If you identify flaws, gaps, or misleading "
173 | f"points, explain them clearly. Likewise, if you find the reasoning sound, explain why it holds up. "
174 | f"Respond with thoughtful analysis—stay to the point and avoid reflexive agreement."
175 | )
176 |
177 | # Required method implementations from SimpleTool
178 |
179 | async def prepare_prompt(self, request: ChallengeRequest) -> str:
180 | """Not used since challenge doesn't call AI models"""
181 | return ""
182 |
183 | def format_response(self, response: str, request: ChallengeRequest, model_info: Optional[dict] = None) -> str:
184 | """Not used since challenge doesn't call AI models"""
185 | return response
186 |
187 | def get_tool_fields(self) -> dict[str, dict[str, Any]]:
188 | """Tool-specific field definitions for Challenge"""
189 | return {
190 | "prompt": {
191 | "type": "string",
192 | "description": CHALLENGE_FIELD_DESCRIPTIONS["prompt"],
193 | },
194 | }
195 |
196 | def get_required_fields(self) -> list[str]:
197 | """Required fields for Challenge tool"""
198 | return ["prompt"]
199 |
```
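Since the tool never contacts a model, it can be exercised directly in a few lines. A minimal sketch (the statement text is invented) showing the JSON payload that execute() returns:

```python
# Hedged sketch: calling the challenge tool directly; the statement is made up,
# and the keys below mirror the response_data dict built in execute().
import asyncio
import json

from tools.challenge import ChallengeTool


async def main() -> None:
    tool = ChallengeTool()
    result = await tool.execute({"prompt": "Tail-call optimization makes recursion free."})
    payload = json.loads(result[0].text)
    print(payload["status"])            # "challenge_accepted"
    print(payload["challenge_prompt"])  # statement wrapped in critical-reassessment instructions


asyncio.run(main())
```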
--------------------------------------------------------------------------------
/tests/test_precommit_workflow.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for the workflow-based PrecommitTool
3 |
4 | Tests the core functionality of the precommit workflow tool including:
5 | - Tool metadata and configuration
6 | - Request model validation
7 | - Workflow step handling
8 | - Tool categorization
9 | """
10 |
11 | import pytest
12 |
13 | from tools.models import ToolModelCategory
14 | from tools.precommit import PrecommitRequest, PrecommitTool
15 |
16 |
17 | class TestPrecommitWorkflowTool:
18 | """Test suite for the workflow-based PrecommitTool"""
19 |
20 | def test_tool_metadata(self):
21 | """Test basic tool metadata"""
22 | tool = PrecommitTool()
23 |
24 | assert tool.get_name() == "precommit"
25 | assert "git changes" in tool.get_description()
26 | assert "systematic analysis" in tool.get_description()
27 |
28 | def test_tool_model_category(self):
29 | """Test that precommit tool uses extended reasoning category"""
30 | tool = PrecommitTool()
31 | assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
32 |
33 | def test_default_temperature(self):
34 | """Test analytical temperature setting"""
35 | tool = PrecommitTool()
36 | temp = tool.get_default_temperature()
37 | # Should be analytical temperature (0.2)
38 | assert temp == 0.2
39 |
40 | def test_request_model_basic_validation(self):
41 | """Test basic request model validation"""
42 | # Valid minimal workflow request
43 | request = PrecommitRequest(
44 | step="Initial validation step",
45 | step_number=1,
46 | total_steps=3,
47 | next_step_required=True,
48 | findings="Initial findings",
49 | path="/test/repo", # Required for step 1
50 | )
51 |
52 | assert request.step == "Initial validation step"
53 | assert request.step_number == 1
54 | assert request.total_steps == 3
55 | assert request.next_step_required is True
56 | assert request.findings == "Initial findings"
57 | assert request.path == "/test/repo"
58 |
59 | def test_request_model_step_one_validation(self):
60 | """Test that step 1 requires path field"""
61 | # Step 1 without path should fail
62 | with pytest.raises(ValueError, match="Step 1 requires 'path' field"):
63 | PrecommitRequest(
64 | step="Initial validation step",
65 | step_number=1,
66 | total_steps=3,
67 | next_step_required=True,
68 | findings="Initial findings",
69 | # Missing path for step 1
70 | )
71 |
72 | def test_request_model_later_steps_no_path_required(self):
73 | """Test that later steps don't require path"""
74 | # Step 2+ without path should be fine
75 | request = PrecommitRequest(
76 | step="Continued validation",
77 | step_number=2,
78 | total_steps=3,
79 | next_step_required=True,
80 | findings="Detailed findings",
81 | # No path needed for step 2+
82 | )
83 |
84 | assert request.step_number == 2
85 | assert request.path is None
86 |
87 | def test_request_model_optional_fields(self):
88 | """Test optional workflow fields"""
89 | request = PrecommitRequest(
90 | step="Validation with optional fields",
91 | step_number=1,
92 | total_steps=2,
93 | next_step_required=False,
94 | findings="Comprehensive findings",
95 | path="/test/repo",
96 | precommit_type="external",
97 | files_checked=["/file1.py", "/file2.py"],
98 | relevant_files=["/file1.py"],
99 | relevant_context=["function_name", "class_name"],
100 | issues_found=[{"severity": "medium", "description": "Test issue"}],
101 | images=["/screenshot.png"],
102 | )
103 |
104 | assert request.precommit_type == "external"
105 | assert len(request.files_checked) == 2
106 | assert len(request.relevant_files) == 1
107 | assert len(request.relevant_context) == 2
108 | assert len(request.issues_found) == 1
109 | assert len(request.images) == 1
110 |
111 | def test_precommit_specific_fields(self):
112 | """Test precommit-specific configuration fields"""
113 | request = PrecommitRequest(
114 | step="Validation with git config",
115 | step_number=1,
116 | total_steps=1,
117 | next_step_required=False,
118 | findings="Complete validation",
119 | path="/repo",
120 | compare_to="main",
121 | include_staged=True,
122 | include_unstaged=False,
123 | focus_on="security issues",
124 | severity_filter="high",
125 | )
126 |
127 | assert request.compare_to == "main"
128 | assert request.include_staged is True
129 | assert request.include_unstaged is False
130 | assert request.focus_on == "security issues"
131 | assert request.severity_filter == "high"
132 |
133 | def test_precommit_type_validation(self):
134 | """Test precommit type validation"""
135 | valid_types = ["external", "internal"]
136 |
137 | for precommit_type in valid_types:
138 | request = PrecommitRequest(
139 | step="Test precommit type",
140 | step_number=1,
141 | total_steps=1,
142 | next_step_required=False,
143 | findings="Test findings",
144 | path="/repo",
145 | precommit_type=precommit_type,
146 | )
147 | assert request.precommit_type == precommit_type
148 |
149 | # Test default is external
150 | request = PrecommitRequest(
151 | step="Test default type",
152 | step_number=1,
153 | total_steps=1,
154 | next_step_required=False,
155 | findings="Test findings",
156 | path="/repo",
157 | )
158 | assert request.precommit_type == "external"
159 |
160 | def test_severity_filter_options(self):
161 | """Test severity filter validation"""
162 | valid_severities = ["critical", "high", "medium", "low", "all"]
163 |
164 | for severity in valid_severities:
165 | request = PrecommitRequest(
166 | step="Test severity filter",
167 | step_number=1,
168 | total_steps=1,
169 | next_step_required=False,
170 | findings="Test findings",
171 | path="/repo",
172 | severity_filter=severity,
173 | )
174 | assert request.severity_filter == severity
175 |
176 | def test_input_schema_generation(self):
177 | """Test that input schema is generated correctly"""
178 | tool = PrecommitTool()
179 | schema = tool.get_input_schema()
180 |
181 | # Check basic schema structure
182 | assert schema["type"] == "object"
183 | assert "properties" in schema
184 | assert "required" in schema
185 |
186 | # Check required fields are present
187 | required_fields = {"step", "step_number", "total_steps", "next_step_required", "findings"}
188 | assert all(field in schema["properties"] for field in required_fields)
189 |
190 | # Check model field is present and configured correctly
191 | assert "model" in schema["properties"]
192 | assert schema["properties"]["model"]["type"] == "string"
193 |
194 | def test_workflow_request_model_method(self):
195 | """Test get_workflow_request_model returns correct model"""
196 | tool = PrecommitTool()
197 | assert tool.get_workflow_request_model() == PrecommitRequest
198 | assert tool.get_request_model() == PrecommitRequest
199 |
200 | def test_system_prompt_integration(self):
201 | """Test system prompt integration"""
202 | tool = PrecommitTool()
203 | system_prompt = tool.get_system_prompt()
204 |
205 | # Should get the precommit prompt
206 | assert isinstance(system_prompt, str)
207 | assert len(system_prompt) > 0
208 |
```
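The path requirement only applies to the opening step, so a continuation step can drop it, as in this short sketch (field values are invented):

```python
# Hedged sketch: a step-2 request for the same review; 'path' may be omitted
# because the validator exercised above only enforces it when step_number == 1.
from tools.precommit import PrecommitRequest

step2 = PrecommitRequest(
    step="Examine the staged diff for regressions",
    step_number=2,
    total_steps=3,
    next_step_required=True,
    findings="No blocking issues found so far",
)
assert step2.path is None
```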
--------------------------------------------------------------------------------
/tests/test_chat_cross_model_continuation.py:
--------------------------------------------------------------------------------
```python
1 | """Cross-provider continuation tests for ChatTool."""
2 |
3 | from __future__ import annotations
4 |
5 | import json
6 | import os
7 | import re
8 | import uuid
9 | from pathlib import Path
10 |
11 | import pytest
12 |
13 | from providers.registry import ModelProviderRegistry
14 | from providers.shared import ProviderType
15 | from tests.transport_helpers import inject_transport
16 | from tools.chat import ChatTool
17 |
18 | CASSETTE_DIR = Path(__file__).parent / "openai_cassettes"
19 | CASSETTE_DIR.mkdir(exist_ok=True)
20 | OPENAI_CASSETTE_PATH = CASSETTE_DIR / "chat_cross_step2_gpt5_reminder.json"
21 |
22 | GEMINI_CASSETTE_DIR = Path(__file__).parent / "gemini_cassettes"
23 | GEMINI_CASSETTE_DIR.mkdir(exist_ok=True)
24 | GEMINI_REPLAY_ID = "chat_cross/step1_gemini25_flash_number/mldev"
25 | GEMINI_REPLAY_PATH = GEMINI_CASSETTE_DIR / "chat_cross" / "step1_gemini25_flash_number" / "mldev.json"
26 |
27 | FIXED_THREAD_ID = uuid.UUID("dbadc23e-c0f4-4853-982f-6c5bc722b5de")
28 |
29 |
30 | WORD_TO_NUMBER = {
31 | "one": 1,
32 | "two": 2,
33 | "three": 3,
34 | "four": 4,
35 | "five": 5,
36 | "six": 6,
37 | "seven": 7,
38 | "eight": 8,
39 | "nine": 9,
40 | "ten": 10,
41 | }
42 |
43 |
44 | def _extract_number(text: str) -> str:
45 | digit_match = re.search(r"\b(\d{1,2})\b", text)
46 | if digit_match:
47 | return digit_match.group(1)
48 |
49 | lower_text = text.lower()
50 | for word, value in WORD_TO_NUMBER.items():
51 | if re.search(rf"\b{word}\b", lower_text):
52 | return str(value)
53 | return ""
54 |
55 |
56 | @pytest.mark.asyncio
57 | @pytest.mark.no_mock_provider
58 | async def test_chat_cross_model_continuation(monkeypatch, tmp_path):
59 | """Verify continuation across Gemini then OpenAI using recorded interactions."""
60 |
61 | env_updates = {
62 | "DEFAULT_MODEL": "auto",
63 | "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY", ""),
64 | "GEMINI_API_KEY": os.getenv("GEMINI_API_KEY", ""),
65 | }
66 | keys_to_clear = [
67 | "XAI_API_KEY",
68 | "OPENROUTER_API_KEY",
69 | "ANTHROPIC_API_KEY",
70 | "MISTRAL_API_KEY",
71 | "CUSTOM_API_KEY",
72 | "CUSTOM_API_URL",
73 | ]
74 |
75 | recording_mode = not OPENAI_CASSETTE_PATH.exists() or not GEMINI_REPLAY_PATH.exists()
76 | if recording_mode:
77 | openai_key = env_updates["OPENAI_API_KEY"].strip()
78 | gemini_key = env_updates["GEMINI_API_KEY"].strip()
79 | if (not openai_key or openai_key.startswith("dummy")) or (not gemini_key or gemini_key.startswith("dummy")):
80 | pytest.skip(
81 | "Cross-provider cassette missing and OPENAI_API_KEY/GEMINI_API_KEY not configured. Provide real keys to record."
82 | )
83 |
84 | GEMINI_REPLAY_PATH.parent.mkdir(parents=True, exist_ok=True)
85 |
86 | # Step 1 – Gemini picks a number
87 | with monkeypatch.context() as m:
88 | m.setenv("DEFAULT_MODEL", env_updates["DEFAULT_MODEL"])
89 | m.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-flash")
90 | m.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
91 | if recording_mode:
92 | m.setenv("OPENAI_API_KEY", env_updates["OPENAI_API_KEY"])
93 | m.setenv("GEMINI_API_KEY", env_updates["GEMINI_API_KEY"])
94 | m.setenv("GOOGLE_GENAI_CLIENT_MODE", "record")
95 | else:
96 | m.setenv("OPENAI_API_KEY", "dummy-key-for-replay")
97 | m.setenv("GEMINI_API_KEY", "dummy-key-for-replay")
98 | m.setenv("GOOGLE_GENAI_CLIENT_MODE", "replay")
99 |
100 | m.setenv("GOOGLE_GENAI_REPLAYS_DIRECTORY", str(GEMINI_CASSETTE_DIR))
101 | m.setenv("GOOGLE_GENAI_REPLAY_ID", GEMINI_REPLAY_ID)
102 |
103 | for key in keys_to_clear:
104 | m.delenv(key, raising=False)
105 |
106 | ModelProviderRegistry.reset_for_testing()
107 | from providers.gemini import GeminiModelProvider
108 | from providers.openai import OpenAIModelProvider
109 |
110 | ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
111 | ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
112 |
113 | from utils import conversation_memory
114 |
115 | m.setattr(conversation_memory.uuid, "uuid4", lambda: FIXED_THREAD_ID)
116 |
117 | chat_tool = ChatTool()
118 | working_directory = str(tmp_path)
119 |
120 | step1_args = {
121 | "prompt": "Pick a number between 1 and 10 and respond with JUST that number.",
122 | "model": "gemini-2.5-flash",
123 | "temperature": 0.2,
124 | "working_directory_absolute_path": working_directory,
125 | }
126 |
127 | step1_result = await chat_tool.execute(step1_args)
128 | assert step1_result and step1_result[0].type == "text"
129 |
130 | step1_data = json.loads(step1_result[0].text)
131 | assert step1_data["status"] in {"success", "continuation_available"}
132 | assert step1_data.get("metadata", {}).get("provider_used") == "google"
133 | continuation_offer = step1_data.get("continuation_offer")
134 | assert continuation_offer is not None
135 | continuation_id = continuation_offer["continuation_id"]
136 | assert continuation_id
137 |
138 | chosen_number = _extract_number(step1_data["content"])
139 | assert chosen_number.isdigit()
140 | assert 1 <= int(chosen_number) <= 10
141 |
142 | # Ensure replay is flushed for Gemini recordings
143 | gemini_provider = ModelProviderRegistry.get_provider_for_model("gemini-2.5-flash")
144 | if gemini_provider is not None:
145 | try:
146 | client = gemini_provider.client
147 | if hasattr(client, "close"):
148 | client.close()
149 | finally:
150 | if hasattr(gemini_provider, "_client"):
151 | gemini_provider._client = None
152 |
153 | assert GEMINI_REPLAY_PATH.exists()
154 |
155 | # Step 2 – gpt-5 recalls the number via continuation
156 | with monkeypatch.context() as m:
157 | if recording_mode:
158 | m.setenv("OPENAI_API_KEY", env_updates["OPENAI_API_KEY"])
159 | m.setenv("GEMINI_API_KEY", env_updates["GEMINI_API_KEY"])
160 | m.setenv("GOOGLE_GENAI_CLIENT_MODE", "record")
161 | else:
162 | m.setenv("OPENAI_API_KEY", "dummy-key-for-replay")
163 | m.setenv("GEMINI_API_KEY", "dummy-key-for-replay")
164 | m.setenv("GOOGLE_GENAI_CLIENT_MODE", "replay")
165 |
166 | m.setenv("DEFAULT_MODEL", env_updates["DEFAULT_MODEL"])
167 | m.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-flash")
168 | m.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
169 | m.setenv("GOOGLE_GENAI_REPLAYS_DIRECTORY", str(GEMINI_CASSETTE_DIR))
170 | m.setenv("GOOGLE_GENAI_REPLAY_ID", GEMINI_REPLAY_ID)
171 | for key in keys_to_clear:
172 | m.delenv(key, raising=False)
173 |
174 | ModelProviderRegistry.reset_for_testing()
175 | from providers.gemini import GeminiModelProvider
176 | from providers.openai import OpenAIModelProvider
177 |
178 | ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
179 | ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
180 |
181 | inject_transport(monkeypatch, OPENAI_CASSETTE_PATH)
182 |
183 | chat_tool = ChatTool()
184 | step2_args = {
185 | "prompt": "Remind me, what number did you pick, respond with JUST that number.",
186 | "model": "gpt-5",
187 | "continuation_id": continuation_id,
188 | "temperature": 0.2,
189 | "working_directory_absolute_path": working_directory,
190 | }
191 |
192 | step2_result = await chat_tool.execute(step2_args)
193 | assert step2_result and step2_result[0].type == "text"
194 |
195 | step2_data = json.loads(step2_result[0].text)
196 | assert step2_data["status"] in {"success", "continuation_available"}
197 | assert step2_data.get("metadata", {}).get("provider_used") == "openai"
198 |
199 | recalled_number = _extract_number(step2_data["content"])
200 | assert recalled_number == chosen_number
201 |
202 | assert OPENAI_CASSETTE_PATH.exists()
203 |
204 | ModelProviderRegistry.reset_for_testing()
205 |
```
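The `_extract_number` helper above normalises model replies (a bare digit or a spelled-out word) before the two providers' answers are compared. A minimal standalone sketch of that extraction logic, with a few invented sample replies to illustrate the expected behaviour:

```python
# Minimal standalone sketch of the number-extraction logic used in the test above.
# The sample replies are invented purely for illustration.
import re

WORD_TO_NUMBER = {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
                  "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10}


def extract_number(text: str) -> str:
    # Prefer a literal 1-2 digit number, then fall back to spelled-out words.
    digit_match = re.search(r"\b(\d{1,2})\b", text)
    if digit_match:
        return digit_match.group(1)
    lower_text = text.lower()
    for word, value in WORD_TO_NUMBER.items():
        if re.search(rf"\b{word}\b", lower_text):
            return str(value)
    return ""


if __name__ == "__main__":
    assert extract_number("I pick 7.") == "7"
    assert extract_number("Seven, of course.") == "7"
    assert extract_number("No digits at all.") == ""
```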
--------------------------------------------------------------------------------
/tests/test_auto_mode_model_listing.py:
--------------------------------------------------------------------------------
```python
1 | """Tests covering model restriction-aware error messaging in auto mode."""
2 |
3 | import asyncio
4 | import importlib
5 | import json
6 |
7 | import pytest
8 |
9 | import utils.env as env_config
10 | import utils.model_restrictions as model_restrictions
11 | from providers.gemini import GeminiModelProvider
12 | from providers.openai import OpenAIModelProvider
13 | from providers.openrouter import OpenRouterProvider
14 | from providers.registry import ModelProviderRegistry
15 | from providers.shared import ProviderType
16 | from providers.xai import XAIModelProvider
17 | from tools.shared.exceptions import ToolExecutionError
18 |
19 |
20 | def _extract_available_models(message: str) -> list[str]:
21 | """Parse the available model list from the error message."""
22 |
23 | marker = "Available models: "
24 | if marker not in message:
25 | raise AssertionError(f"Expected '{marker}' in message: {message}")
26 |
27 | start = message.index(marker) + len(marker)
28 | end = message.find(". Suggested", start)
29 | if end == -1:
30 | end = len(message)
31 |
32 | available_segment = message[start:end].strip()
33 | if not available_segment:
34 | return []
35 |
36 | return [item.strip() for item in available_segment.split(",")]
37 |
38 |
39 | @pytest.fixture
40 | def reset_registry():
41 | """Ensure registry and restriction service state is isolated."""
42 |
43 | ModelProviderRegistry.reset_for_testing()
44 | model_restrictions._restriction_service = None
45 | env_config.reload_env()
46 | yield
47 | ModelProviderRegistry.reset_for_testing()
48 | model_restrictions._restriction_service = None
49 |
50 |
51 | def _register_core_providers(*, include_xai: bool = False):
52 | ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
53 | ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
54 | ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
55 | if include_xai:
56 | ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
57 |
58 |
59 | @pytest.mark.no_mock_provider
60 | def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
61 | """Error payload should surface only the allowed models for each provider."""
62 |
63 | monkeypatch.setenv("DEFAULT_MODEL", "auto")
64 | monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
65 | monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
66 | monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
67 | monkeypatch.delenv("XAI_API_KEY", raising=False)
68 | # Ensure Azure provider stays disabled regardless of developer workstation env
69 | for azure_var in (
70 | "AZURE_OPENAI_API_KEY",
71 | "AZURE_OPENAI_ENDPOINT",
72 | "AZURE_OPENAI_ALLOWED_MODELS",
73 | "AZURE_MODELS_CONFIG_PATH",
74 | ):
75 | monkeypatch.delenv(azure_var, raising=False)
76 | monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
77 | env_config.reload_env({"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
78 | try:
79 | import dotenv
80 |
81 | monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
82 | except ModuleNotFoundError:
83 | pass
84 |
85 | monkeypatch.setenv("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro")
86 | monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "gpt-5")
87 | monkeypatch.setenv("OPENROUTER_ALLOWED_MODELS", "gpt5nano")
88 | monkeypatch.setenv("XAI_ALLOWED_MODELS", "")
89 |
90 | import config
91 |
92 | importlib.reload(config)
93 |
94 | _register_core_providers()
95 |
96 | import server
97 |
98 | importlib.reload(server)
99 |
100 | # Reload may have re-applied .env overrides; enforce our test configuration
101 | for key, value in (
102 | ("DEFAULT_MODEL", "auto"),
103 | ("GEMINI_API_KEY", "test-gemini"),
104 | ("OPENAI_API_KEY", "test-openai"),
105 | ("OPENROUTER_API_KEY", "test-openrouter"),
106 | ("GOOGLE_ALLOWED_MODELS", "gemini-2.5-pro"),
107 | ("OPENAI_ALLOWED_MODELS", "gpt-5"),
108 | ("OPENROUTER_ALLOWED_MODELS", "gpt5nano"),
109 | ("XAI_ALLOWED_MODELS", ""),
110 | ):
111 | monkeypatch.setenv(key, value)
112 |
113 | for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
114 | monkeypatch.delenv(var, raising=False)
115 | for azure_var in (
116 | "AZURE_OPENAI_API_KEY",
117 | "AZURE_OPENAI_ENDPOINT",
118 | "AZURE_OPENAI_ALLOWED_MODELS",
119 | "AZURE_MODELS_CONFIG_PATH",
120 | ):
121 | monkeypatch.delenv(azure_var, raising=False)
122 |
123 | ModelProviderRegistry.reset_for_testing()
124 | model_restrictions._restriction_service = None
125 | server.configure_providers()
126 |
127 | with pytest.raises(ToolExecutionError) as exc_info:
128 | asyncio.run(
129 | server.handle_call_tool(
130 | "chat",
131 | {
132 | "model": "gpt5mini",
133 | "prompt": "Tell me about your strengths",
134 | },
135 | )
136 | )
137 |
138 | payload = json.loads(exc_info.value.payload)
139 | assert payload["status"] == "error"
140 |
141 | available_models = _extract_available_models(payload["content"])
142 | assert set(available_models) == {"gemini-2.5-pro", "gpt-5", "gpt5nano", "openai/gpt-5-nano"}
143 |
144 |
145 | @pytest.mark.no_mock_provider
146 | def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, reset_registry):
147 | """When no restrictions are set, the full high-capability catalogue should appear."""
148 |
149 | monkeypatch.setenv("DEFAULT_MODEL", "auto")
150 | monkeypatch.setenv("GEMINI_API_KEY", "test-gemini")
151 | monkeypatch.setenv("OPENAI_API_KEY", "test-openai")
152 | monkeypatch.setenv("OPENROUTER_API_KEY", "test-openrouter")
153 | monkeypatch.setenv("XAI_API_KEY", "test-xai")
154 | monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
155 | for azure_var in (
156 | "AZURE_OPENAI_API_KEY",
157 | "AZURE_OPENAI_ENDPOINT",
158 | "AZURE_OPENAI_ALLOWED_MODELS",
159 | "AZURE_MODELS_CONFIG_PATH",
160 | ):
161 | monkeypatch.delenv(azure_var, raising=False)
162 | env_config.reload_env({"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
163 | try:
164 | import dotenv
165 |
166 | monkeypatch.setattr(dotenv, "dotenv_values", lambda *_args, **_kwargs: {"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
167 | except ModuleNotFoundError:
168 | pass
169 |
170 | for var in (
171 | "GOOGLE_ALLOWED_MODELS",
172 | "OPENAI_ALLOWED_MODELS",
173 | "OPENROUTER_ALLOWED_MODELS",
174 | "XAI_ALLOWED_MODELS",
175 | "DIAL_ALLOWED_MODELS",
176 | ):
177 | monkeypatch.delenv(var, raising=False)
178 |
179 | import config
180 |
181 | importlib.reload(config)
182 |
183 | _register_core_providers(include_xai=True)
184 |
185 | import server
186 |
187 | importlib.reload(server)
188 |
189 | for key, value in (
190 | ("DEFAULT_MODEL", "auto"),
191 | ("GEMINI_API_KEY", "test-gemini"),
192 | ("OPENAI_API_KEY", "test-openai"),
193 | ("OPENROUTER_API_KEY", "test-openrouter"),
194 | ):
195 | monkeypatch.setenv(key, value)
196 |
197 | for var in (
198 | "GOOGLE_ALLOWED_MODELS",
199 | "OPENAI_ALLOWED_MODELS",
200 | "OPENROUTER_ALLOWED_MODELS",
201 | "XAI_ALLOWED_MODELS",
202 | "DIAL_ALLOWED_MODELS",
203 | "CUSTOM_API_URL",
204 | "CUSTOM_API_KEY",
205 | ):
206 | monkeypatch.delenv(var, raising=False)
207 |
208 | ModelProviderRegistry.reset_for_testing()
209 | model_restrictions._restriction_service = None
210 | server.configure_providers()
211 |
212 | with pytest.raises(ToolExecutionError) as exc_info:
213 | asyncio.run(
214 | server.handle_call_tool(
215 | "chat",
216 | {
217 | "model": "dummymodel",
218 | "prompt": "Hi there",
219 | },
220 | )
221 | )
222 |
223 | payload = json.loads(exc_info.value.payload)
224 | assert payload["status"] == "error"
225 |
226 | available_models = _extract_available_models(payload["content"])
227 | assert "gemini-2.5-pro" in available_models
228 | assert "gpt-5" in available_models
229 | assert "grok-4" in available_models
230 | assert len(available_models) >= 5
231 |
```
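The `_extract_available_models` helper assumes the error payload embeds an `Available models: ...` list, optionally followed by a `. Suggested` suffix. A standalone sketch of that parsing, run against an invented sample message:

```python
# Standalone sketch of the error-message parsing exercised above. The sample
# message is invented for illustration; only the "Available models: " marker
# and the optional ". Suggested" suffix are assumed, matching the parser.
def extract_available_models(message: str) -> list[str]:
    marker = "Available models: "
    if marker not in message:
        raise AssertionError(f"Expected '{marker}' in message: {message}")
    start = message.index(marker) + len(marker)
    end = message.find(". Suggested", start)
    if end == -1:
        end = len(message)
    segment = message[start:end].strip()
    return [item.strip() for item in segment.split(",")] if segment else []


sample = "Model 'gpt5mini' is not available. Available models: gemini-2.5-pro, gpt-5. Suggested: gpt-5."
assert extract_available_models(sample) == ["gemini-2.5-pro", "gpt-5"]
```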
--------------------------------------------------------------------------------
/simulator_tests/test_prompt_size_limit_bug.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Prompt Size Limit Bug Test
4 |
5 | This test reproduces a critical bug where the prompt size limit check
6 | incorrectly includes conversation history when validating incoming prompts
7 | from Claude to MCP. The limit should ONLY apply to the actual prompt text
8 | sent by the user, not the entire conversation context.
9 |
10 | Bug Scenario:
11 | - User starts a conversation with chat tool
12 | - Continues conversation multiple times (building up history)
13 | - On subsequent continuation, a short prompt (150 chars) triggers
14 | "resend_prompt" error claiming >50k characters
15 |
16 | Expected Behavior:
17 | - Only count the actual prompt parameter for size limit
18 | - Conversation history should NOT count toward prompt size limit
19 | - Only the user's actual input should be validated against 50k limit
20 | """
21 |
22 | from .conversation_base_test import ConversationBaseTest
23 |
24 |
25 | class PromptSizeLimitBugTest(ConversationBaseTest):
26 | """Test to reproduce and verify fix for prompt size limit bug"""
27 |
28 | @property
29 | def test_name(self) -> str:
30 | return "prompt_size_limit_bug"
31 |
32 | @property
33 | def test_description(self) -> str:
34 | return "Reproduce prompt size limit bug with conversation continuation"
35 |
36 | def run_test(self) -> bool:
37 | """Test prompt size limit bug reproduction using in-process calls"""
38 | try:
39 | self.logger.info("🐛 Test: Prompt size limit bug reproduction (in-process)")
40 |
41 | # Setup test environment
42 | self.setUp()
43 |
44 | # Create a test file to provide context
45 | test_file_content = """
46 | # Test SwiftUI-like Framework Implementation
47 |
48 | struct ContentView: View {
49 | @State private var counter = 0
50 |
51 | var body: some View {
52 | VStack {
53 | Text("Count: \\(counter)")
54 | Button("Increment") {
55 | counter += 1
56 | }
57 | }
58 | }
59 | }
60 |
61 | class Renderer {
62 | static let shared = Renderer()
63 |
64 | func render(view: View) {
65 | // Implementation details for UIKit/AppKit rendering
66 | }
67 | }
68 |
69 | protocol View {
70 | var body: some View { get }
71 | }
72 | """
73 | test_file_path = self.create_additional_test_file("SwiftFramework.swift", test_file_content)
74 |
75 | # Step 1: Start initial conversation
76 | self.logger.info(" Step 1: Start conversation with initial context")
77 |
78 | initial_prompt = "I'm building a SwiftUI-like framework. Can you help me design the architecture?"
79 |
80 | response1, continuation_id = self.call_mcp_tool_direct(
81 | "chat",
82 | {
83 | "prompt": initial_prompt,
84 | "absolute_file_paths": [test_file_path],
85 | "model": "flash",
86 | },
87 | )
88 |
89 | if not response1 or not continuation_id:
90 | self.logger.error(" ❌ Failed to start initial conversation")
91 | return False
92 |
93 | self.logger.info(f" ✅ Initial conversation started: {continuation_id[:8]}...")
94 |
95 | # Step 2: Continue conversation multiple times to build substantial history
96 | conversation_prompts = [
97 | "That's helpful! Can you elaborate on the View protocol design?",
98 | "How should I implement the State property wrapper?",
99 | "What's the best approach for the VStack layout implementation?",
100 | "Should I use UIKit directly or create an abstraction layer?",
101 | "Smart approach! For the rendering layer, would you suggest UIKit/AppKit directly?",
102 | ]
103 |
104 | for i, prompt in enumerate(conversation_prompts, 2):
105 | self.logger.info(f" Step {i}: Continue conversation (exchange {i})")
106 |
107 | response, _ = self.call_mcp_tool_direct(
108 | "chat",
109 | {
110 | "prompt": prompt,
111 | "continuation_id": continuation_id,
112 | "model": "flash",
113 | },
114 | )
115 |
116 | if not response:
117 | self.logger.error(f" ❌ Failed at exchange {i}")
118 | return False
119 |
120 | self.logger.info(f" ✅ Exchange {i} completed")
121 |
122 | # Step 3: Send short prompt that should NOT trigger size limit
123 | self.logger.info(" Step 7: Send short prompt (should NOT trigger size limit)")
124 |
125 | # This is a very short prompt - should not trigger the bug after fix
126 | short_prompt = "Thanks! This gives me a solid foundation to start prototyping."
127 |
128 | self.logger.info(f" Short prompt length: {len(short_prompt)} characters")
129 |
130 | response_final, _ = self.call_mcp_tool_direct(
131 | "chat",
132 | {
133 | "prompt": short_prompt,
134 | "continuation_id": continuation_id,
135 | "model": "flash",
136 | },
137 | )
138 |
139 | if not response_final:
140 | self.logger.error(" ❌ Final short prompt failed")
141 | return False
142 |
143 | # Parse the response to check for the bug
144 | import json
145 |
146 | try:
147 | response_data = json.loads(response_final)
148 | status = response_data.get("status", "")
149 |
150 | if status == "resend_prompt":
151 | # This is the bug! Short prompt incorrectly triggering size limit
152 | metadata = response_data.get("metadata", {})
153 | prompt_size = metadata.get("prompt_size", 0)
154 |
155 | self.logger.error(
156 | f" 🐛 BUG STILL EXISTS: Short prompt ({len(short_prompt)} chars) triggered resend_prompt"
157 | )
158 | self.logger.error(f" Reported prompt_size: {prompt_size} (should be ~{len(short_prompt)})")
159 | self.logger.error(" This indicates conversation history is still being counted")
160 |
161 | return False # Bug still exists
162 |
163 | elif status in ["success", "continuation_available"]:
164 | self.logger.info(" ✅ Short prompt processed correctly - bug appears to be FIXED!")
165 | self.logger.info(f" Prompt length: {len(short_prompt)} chars, Status: {status}")
166 | return True
167 |
168 | else:
169 | self.logger.warning(f" ⚠️ Unexpected status: {status}")
170 | # Check if this might be a non-JSON response (successful execution)
171 | if len(response_final) > 0 and not response_final.startswith('{"'):
172 | self.logger.info(" ✅ Non-JSON response suggests successful tool execution")
173 | return True
174 | return False
175 |
176 | except json.JSONDecodeError:
177 | # Non-JSON response often means successful tool execution
178 | self.logger.info(" ✅ Non-JSON response suggests successful tool execution (bug likely fixed)")
179 | self.logger.debug(f" Response preview: {response_final[:200]}...")
180 | return True
181 |
182 | except Exception as e:
183 | self.logger.error(f"Prompt size limit bug test failed: {e}")
184 | import traceback
185 |
186 | self.logger.debug(f"Full traceback: {traceback.format_exc()}")
187 | return False
188 |
189 |
190 | def main():
191 | """Run the prompt size limit bug test"""
192 | import sys
193 |
194 | verbose = "--verbose" in sys.argv or "-v" in sys.argv
195 | test = PromptSizeLimitBugTest(verbose=verbose)
196 |
197 | success = test.run_test()
198 | if success:
199 | print("Bug reproduction test completed - check logs for details")
200 | else:
201 | print("Test failed to complete")
202 | sys.exit(0 if success else 1)
203 |
204 |
205 | if __name__ == "__main__":
206 | main()
207 |
```
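For clarity, a simplified illustration of the behaviour the test asserts: only the `prompt` argument is measured against the limit, never the rebuilt conversation history. The 50k figure comes from the docstring above; the function and constant names below are illustrative, not the server's actual API:

```python
# Illustration of the intended size check: history length must be irrelevant.
# MCP_PROMPT_SIZE_LIMIT and needs_resend are illustrative names, not the
# server's real identifiers; the 50_000 figure comes from the test docstring.
MCP_PROMPT_SIZE_LIMIT = 50_000


def needs_resend(prompt: str, conversation_history: str) -> bool:
    # Correct behaviour: only the user's actual prompt counts toward the limit.
    return len(prompt) > MCP_PROMPT_SIZE_LIMIT


assert needs_resend("Thanks! This gives me a solid foundation.", "x" * 200_000) is False
assert needs_resend("y" * 60_000, "") is True
```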
--------------------------------------------------------------------------------
/systemprompts/debug_prompt.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Debug tool system prompt
3 | """
4 |
5 | DEBUG_ISSUE_PROMPT = """
6 | ROLE
7 | You are an expert debugging assistant receiving systematic investigation findings from another AI agent.
8 | The agent has performed methodical investigation work following systematic debugging methodology.
9 | Your role is to provide expert analysis based on the comprehensive investigation presented to you.
10 |
11 | SYSTEMATIC INVESTIGATION CONTEXT
12 | The agent has followed a systematic investigation approach:
13 | 1. Methodical examination of error reports and symptoms
14 | 2. Step-by-step code analysis and evidence collection
15 | 3. Use of tracer tool for complex method interactions when needed
16 | 4. Hypothesis formation and testing against actual code
17 | 5. Documentation of findings and investigation evolution
18 |
19 | You are receiving:
20 | 1. Issue description and original symptoms
21 | 2. The agent's systematic investigation findings (comprehensive analysis)
22 | 3. Essential files identified as critical for understanding the issue
23 | 4. Error context, logs, and diagnostic information
24 | 5. Tracer tool analysis results (if complex flow analysis was needed)
25 |
26 | TRACER TOOL INTEGRATION AWARENESS
27 | If the agent used the tracer tool during investigation, the findings will include:
28 | - Method call flow analysis
29 | - Class dependency mapping
30 | - Side effect identification
31 | - Execution path tracing
32 | This provides deep understanding of how code interactions contribute to the issue.
33 |
34 | CRITICAL LINE NUMBER INSTRUCTIONS
35 | Code is presented with line number markers "LINE│ code". These markers are for reference ONLY and MUST NOT be
36 | included in any code you generate. Always reference specific line numbers in your replies so that exact
37 | positions can be located when needed. Include a very short code excerpt alongside for clarity.
38 | Include context_start_text and context_end_text as backup references. Never include "LINE│" markers in generated code
39 | snippets.
40 |
41 | WORKFLOW CONTEXT
42 | Your task is to analyze the systematic investigation given to you and provide expert debugging analysis back to the
43 | agent, who will then present the findings to the user in a consolidated format.
44 |
45 | STRUCTURED JSON OUTPUT FORMAT
46 | You MUST respond with a properly formatted JSON object following this exact schema.
47 | Do NOT include any text before or after the JSON. The response must be valid JSON only.
48 |
49 | IF MORE INFORMATION IS NEEDED:
50 | If you lack critical information to proceed, you MUST only respond with the following:
51 | {
52 | "status": "files_required_to_continue",
53 | "mandatory_instructions": "<your critical instructions for the agent>",
54 | "files_needed": ["[file name here]", "[or some folder/]"]
55 | }
56 |
57 | IF NO BUG FOUND AFTER THOROUGH INVESTIGATION:
58 | If after a very thorough investigation, no concrete evidence of a bug is found correlating to reported symptoms, you
59 | MUST only respond with the following:
60 | {
61 | "status": "no_bug_found",
62 | "summary": "<summary of what was thoroughly investigated>",
63 | "investigation_steps": ["<step 1>", "<step 2>", "..."],
64 | "areas_examined": ["<code areas>", "<potential failure points>", "..."],
65 | "confidence_level": "High|Medium|Low",
66 | "alternative_explanations": ["<possible misunderstanding>", "<user expectation mismatch>", "..."],
67 | "recommended_questions": ["<question 1 to clarify the issue>", "<question 2 to gather more context>", "..."],
68 | "next_steps": ["<suggested actions to better understand the reported issue>"]
69 | }
70 |
71 | FOR COMPLETE ANALYSIS:
72 | {
73 | "status": "analysis_complete",
74 | "summary": "<brief description of the problem and its impact>",
75 | "investigation_steps": [
76 | "<step 1: what you analyzed first>",
77 | "<step 2: what you discovered next>",
78 | "<step 3: how findings evolved>",
79 | "..."
80 | ],
81 | "hypotheses": [
82 | {
83 | "name": "<HYPOTHESIS NAME>",
84 | "confidence": "High|Medium|Low",
85 | "root_cause": "<technical explanation>",
86 | "evidence": "<logs or code clues supporting this hypothesis>",
87 | "correlation": "<how symptoms map to the cause>",
88 | "validation": "<quick test to confirm>",
89 | "minimal_fix": "<smallest change to resolve the issue>",
90 | "regression_check": "<why this fix is safe>",
91 | "file_references": ["<file:line format for exact locations>"],
92 | "function_name": "<optional: specific function/method name if identified>",
93 | "start_line": "<optional: starting line number if specific location identified>",
94 | "end_line": "<optional: ending line number if specific location identified>",
95 | "context_start_text": "<optional: exact text from start line for verification>",
96 | "context_end_text": "<optional: exact text from end line for verification>"
97 | }
98 | ],
99 | "key_findings": [
100 | "<finding 1: important discoveries made during analysis>",
101 | "<finding 2: code patterns or issues identified>",
102 | "<finding 3: invalidated assumptions or refined understanding>"
103 | ],
104 | "immediate_actions": [
105 | "<action 1: steps to take regardless of which hypothesis is correct>",
106 | "<action 2: additional logging or monitoring needed>"
107 | ],
108 | "recommended_tools": [
109 | "<tool recommendation if additional analysis needed, e.g., 'tracer tool for call flow analysis'>"
110 | ],
111 | "prevention_strategy": "<optional: targeted measures to prevent this exact issue from recurring>",
112 | "investigation_summary": "<comprehensive summary of the complete investigation process and final conclusions>"
113 | }
114 |
115 | CRITICAL DEBUGGING PRINCIPLES:
116 | 1. Bugs can ONLY be found and fixed from given code - these cannot be made up or imagined
117 | 2. Focus ONLY on the reported issue - avoid suggesting extensive refactoring or unrelated improvements
118 | 3. Propose minimal fixes that address the specific problem without introducing regressions
119 | 4. Document your investigation process systematically for future reference
120 | 5. Rank hypotheses by likelihood based on evidence from the actual code and logs provided
121 | 6. Always include specific file:line references for exact locations of issues
122 | 7. CRITICAL: If the agent's investigation finds no concrete evidence of a bug correlating to reported symptoms,
123 | you should consider that the reported issue may not actually exist, may be a misunderstanding, or may be
124 | conflated with something else entirely. In such cases, recommend gathering more information from the user
125 | through targeted questioning rather than continuing to hunt for non-existent bugs
126 |
127 | PRECISE LOCATION REFERENCES:
128 | When you identify specific code locations for hypotheses, include optional precision fields:
129 | - function_name: The exact function/method name where the issue occurs
130 | - start_line/end_line: Line numbers from the LINE│ markers (for reference ONLY - never include LINE│ in generated code)
131 | - context_start_text/context_end_text: Exact text from those lines for verification
132 | - These fields help the agent locate exact positions for implementing fixes
133 |
134 | REGRESSION PREVENTION: Before suggesting any fix, thoroughly analyze the proposed change to ensure it does not
135 | introduce new issues or break existing functionality. Consider:
136 | - How the change might affect other parts of the codebase
137 | - Whether the fix could impact related features or workflows
138 | - If the solution maintains backward compatibility
139 | - What potential side effects or unintended consequences might occur
140 |
141 | Your debugging approach should generate focused hypotheses ranked by likelihood, with emphasis on identifying
142 | the exact root cause and implementing minimal, targeted fixes while maintaining comprehensive documentation
143 | of the investigation process.
144 |
145 | Your analysis should build upon the agent's systematic investigation to provide:
146 | - Expert validation of hypotheses
147 | - Additional insights based on systematic findings
148 | - Specific implementation guidance for fixes
149 | - Regression prevention analysis
150 | """
151 |
```
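The prompt above defines three mutually exclusive response statuses. A hedged sketch of how a caller might branch on them; the handler logic and messages are placeholders, only the status values and top-level keys mirror the schema in the prompt:

```python
# Sketch of dispatching on the three response statuses defined by
# DEBUG_ISSUE_PROMPT above. Handler behaviour is a placeholder; the status
# values and JSON keys follow the schema in the prompt text.
import json


def handle_debug_response(raw: str) -> str:
    data = json.loads(raw)  # the prompt requires valid JSON with no surrounding text
    status = data["status"]
    if status == "files_required_to_continue":
        return f"Agent must supply: {', '.join(data['files_needed'])}"
    if status == "no_bug_found":
        return f"No bug found ({data['confidence_level']} confidence): {data['summary']}"
    if status == "analysis_complete":
        top = data["hypotheses"][0]
        return f"Top hypothesis: {top['name']} ({top['confidence']})"
    raise ValueError(f"Unexpected status: {status}")


example = (
    '{"status": "no_bug_found", "summary": "Reported symptom not reproducible", '
    '"investigation_steps": [], "areas_examined": [], "confidence_level": "Medium", '
    '"alternative_explanations": [], "recommended_questions": [], "next_steps": []}'
)
print(handle_debug_response(example))
```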
--------------------------------------------------------------------------------
/tests/test_docker_config_complete.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Complete configuration test for Docker MCP
3 | """
4 |
5 | import os
6 | from pathlib import Path
7 | from unittest.mock import patch
8 |
9 | import pytest
10 |
11 |
12 | class TestDockerMCPConfiguration:
13 | """Docker MCP configuration tests"""
14 |
15 | def test_dockerfile_configuration(self):
16 | """Test Dockerfile configuration"""
17 | project_root = Path(__file__).parent.parent
18 | dockerfile = project_root / "Dockerfile"
19 |
20 | if not dockerfile.exists():
21 | pytest.skip("Dockerfile not found")
22 |
23 | content = dockerfile.read_text()
24 |
25 | # Essential checks
26 | assert "FROM python:" in content
27 | assert "COPY" in content or "ADD" in content
28 | assert "server.py" in content
29 |
30 | # Recommended security checks
31 | security_checks = [
32 | "USER " in content, # Non-root user
33 | "WORKDIR" in content, # Defined working directory
34 | ]
35 |
36 | # At least one security practice should be present
37 | if any(security_checks):
38 | assert True, "Security best practices detected"
39 |
40 | def test_environment_file_template(self):
41 | """Test environment file template"""
42 | project_root = Path(__file__).parent.parent
43 | env_example = project_root / ".env.example"
44 |
45 | if env_example.exists():
46 | content = env_example.read_text()
47 |
48 | # Essential variables
49 | essential_vars = ["GEMINI_API_KEY", "OPENAI_API_KEY", "LOG_LEVEL"]
50 |
51 | for var in essential_vars:
52 | assert f"{var}=" in content, f"Variable {var} missing"
53 |
54 | # Docker-specific variables should also be present
55 | docker_vars = ["COMPOSE_PROJECT_NAME", "TZ", "LOG_MAX_SIZE"]
56 | for var in docker_vars:
57 | assert f"{var}=" in content, f"Docker variable {var} missing"
58 |
59 | def test_logs_directory_setup(self):
60 | """Test logs directory setup"""
61 | project_root = Path(__file__).parent.parent
62 | logs_dir = project_root / "logs"
63 |
64 | # The logs directory should exist or be creatable
65 | if not logs_dir.exists():
66 | try:
67 | logs_dir.mkdir(exist_ok=True)
68 | created = True
69 | except Exception:
70 | created = False
71 |
72 | assert created, "Logs directory should be creatable"
73 | else:
74 | assert logs_dir.is_dir(), "logs should be a directory"
75 |
76 |
77 | class TestDockerCommandValidation:
78 | """Docker command validation tests"""
79 |
80 | @patch("subprocess.run")
81 | def test_docker_build_command(self, mock_run):
82 | """Test docker build command"""
83 | mock_run.return_value.returncode = 0
84 |
85 | # Standard build command
86 | build_cmd = ["docker", "build", "-t", "zen-mcp-server:latest", "."]
87 |
88 | import subprocess
89 |
90 | subprocess.run(build_cmd, capture_output=True)
91 | mock_run.assert_called_once()
92 |
93 | @patch("subprocess.run")
94 | def test_docker_run_mcp_command(self, mock_run):
95 | """Test docker run command for MCP"""
96 | mock_run.return_value.returncode = 0
97 |
98 | # Run command for MCP
99 | run_cmd = [
100 | "docker",
101 | "run",
102 | "--rm",
103 | "-i",
104 | "--env-file",
105 | ".env",
106 | "-v",
107 | "logs:/app/logs",
108 | "zen-mcp-server:latest",
109 | "python",
110 | "server.py",
111 | ]
112 |
113 | import subprocess
114 |
115 | subprocess.run(run_cmd, capture_output=True)
116 | mock_run.assert_called_once()
117 |
118 | def test_docker_command_structure(self):
119 | """Test Docker command structure"""
120 |
121 | # Recommended MCP command
122 | mcp_cmd = [
123 | "docker",
124 | "run",
125 | "--rm",
126 | "-i",
127 | "--env-file",
128 | "/path/to/.env",
129 | "-v",
130 | "/path/to/logs:/app/logs",
131 | "zen-mcp-server:latest",
132 | "python",
133 | "server.py",
134 | ]
135 |
136 | # Structure checks
137 | assert mcp_cmd[0] == "docker"
138 | assert "run" in mcp_cmd
139 | assert "--rm" in mcp_cmd # Automatic cleanup
140 | assert "-i" in mcp_cmd # Interactive mode
141 | assert "--env-file" in mcp_cmd # Environment variables
142 | assert "zen-mcp-server:latest" in mcp_cmd # Image
143 |
144 |
145 | class TestIntegrationChecks:
146 | """Integration checks"""
147 |
148 | def test_complete_setup_checklist(self):
149 | """Test complete setup checklist"""
150 | project_root = Path(__file__).parent.parent
151 |
152 | # Checklist for essential files
153 | essential_files = {
154 | "Dockerfile": project_root / "Dockerfile",
155 | "server.py": project_root / "server.py",
156 | "requirements.txt": project_root / "requirements.txt",
157 | "docker-compose.yml": project_root / "docker-compose.yml",
158 | }
159 |
160 | missing_files = []
161 | for name, path in essential_files.items():
162 | if not path.exists():
163 | missing_files.append(name)
164 |
165 | # Allow some missing files for flexibility
166 | critical_files = ["Dockerfile", "server.py"]
167 | missing_critical = [f for f in missing_files if f in critical_files]
168 |
169 | assert not missing_critical, f"Critical files missing: {missing_critical}"
170 |
171 | def test_mcp_integration_readiness(self):
172 | """Test MCP integration readiness"""
173 | project_root = Path(__file__).parent.parent
174 |
175 | # MCP integration checks
176 | checks = {
177 | "dockerfile": (project_root / "Dockerfile").exists(),
178 | "server_script": (project_root / "server.py").exists(),
179 | "logs_dir": (project_root / "logs").exists() or True,
180 | }
181 |
182 | # At least critical elements must be present
183 | critical_checks = ["dockerfile", "server_script"]
184 | missing_critical = [k for k in critical_checks if not checks[k]]
185 |
186 | assert not missing_critical, f"Critical elements missing: {missing_critical}"
187 |
188 | # Readiness score
189 | ready_score = sum(checks.values()) / len(checks)
190 | assert ready_score >= 0.75, f"Insufficient readiness score: {ready_score:.2f}"
191 |
192 |
193 | class TestErrorHandling:
194 | """Error handling tests"""
195 |
196 | def test_missing_api_key_handling(self):
197 | """Test handling of missing API key"""
198 |
199 | # Simulate environment without API keys
200 | with patch.dict(os.environ, {}, clear=True):
201 | api_keys = [os.getenv("GEMINI_API_KEY"), os.getenv("OPENAI_API_KEY"), os.getenv("XAI_API_KEY")]
202 |
203 | has_api_key = any(key for key in api_keys)
204 |
205 | # No key should be present
206 | assert not has_api_key, "No API key detected (expected for test)"
207 |
208 | # System should handle this gracefully
209 | error_handled = True # Simulate error handling
210 | assert error_handled, "API key error handling implemented"
211 |
212 | def test_docker_not_available_handling(self):
213 | """Test handling of Docker not available"""
214 |
215 | @patch("subprocess.run")
216 | def simulate_docker_unavailable(mock_run):
217 | # Simulate Docker not available
218 | mock_run.side_effect = FileNotFoundError("docker: command not found")
219 |
220 | try:
221 | import subprocess
222 |
223 | subprocess.run(["docker", "--version"], capture_output=True)
224 | docker_available = True
225 | except FileNotFoundError:
226 | docker_available = False
227 |
228 | # Docker is not available - expected error
229 | assert not docker_available, "Docker unavailable (simulation)"
230 |
231 | # System should provide a clear error message
232 | error_message_clear = True # Simulation
233 | assert error_message_clear, "Clear Docker error message"
234 |
235 | simulate_docker_unavailable()
236 |
237 |
238 | if __name__ == "__main__":
239 | pytest.main([__file__, "-v"])
240 |
```
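The recommended MCP invocation checked above can be assembled programmatically. A small sketch, with placeholder paths for the env file and logs directory; the flag layout mirrors the command structure the tests verify:

```python
# Sketch of assembling the MCP-style `docker run` invocation validated above.
# The env-file and logs paths are placeholders.
def build_mcp_run_command(env_file: str, logs_dir: str, image: str = "zen-mcp-server:latest") -> list[str]:
    return [
        "docker", "run",
        "--rm",                      # clean up the container on exit
        "-i",                        # keep stdin open for the MCP stdio transport
        "--env-file", env_file,      # API keys and configuration
        "-v", f"{logs_dir}:/app/logs",
        image,
        "python", "server.py",
    ]


cmd = build_mcp_run_command("/path/to/.env", "/path/to/logs")
assert cmd[0] == "docker" and "--rm" in cmd and "-i" in cmd
```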
--------------------------------------------------------------------------------
/utils/file_types.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | File type definitions and constants for file processing
3 |
4 | This module centralizes all file type and extension definitions used
5 | throughout the MCP server for consistent file handling.
6 | """
7 |
8 | # Programming language file extensions - core code files
9 | PROGRAMMING_LANGUAGES = {
10 | ".py", # Python
11 | ".js", # JavaScript
12 | ".ts", # TypeScript
13 | ".jsx", # React JavaScript
14 | ".tsx", # React TypeScript
15 | ".java", # Java
16 | ".cpp", # C++
17 | ".c", # C
18 | ".h", # C/C++ Header
19 | ".hpp", # C++ Header
20 | ".cs", # C#
21 | ".go", # Go
22 | ".rs", # Rust
23 | ".rb", # Ruby
24 | ".php", # PHP
25 | ".swift", # Swift
26 | ".kt", # Kotlin
27 | ".scala", # Scala
28 | ".r", # R
29 | ".m", # Objective-C
30 | ".mm", # Objective-C++
31 | }
32 |
33 | # Script and shell file extensions
34 | SCRIPTS = {
35 | ".sql", # SQL
36 | ".sh", # Shell
37 | ".bash", # Bash
38 | ".zsh", # Zsh
39 | ".fish", # Fish shell
40 | ".ps1", # PowerShell
41 | ".bat", # Batch
42 | ".cmd", # Command
43 | }
44 |
45 | # Configuration and data file extensions
46 | CONFIGS = {
47 | ".yml", # YAML
48 | ".yaml", # YAML
49 | ".json", # JSON
50 | ".xml", # XML
51 | ".toml", # TOML
52 | ".ini", # INI
53 | ".cfg", # Config
54 | ".conf", # Config
55 | ".properties", # Properties
56 | ".env", # Environment
57 | }
58 |
59 | # Documentation and markup file extensions
60 | DOCS = {
61 | ".txt", # Text
62 | ".md", # Markdown
63 | ".rst", # reStructuredText
64 | ".tex", # LaTeX
65 | }
66 |
67 | # Web development file extensions
68 | WEB = {
69 | ".html", # HTML
70 | ".css", # CSS
71 | ".scss", # Sass
72 | ".sass", # Sass
73 | ".less", # Less
74 | }
75 |
76 | # Additional text file extensions for logs and data
77 | TEXT_DATA = {
78 | ".log", # Log files
79 | ".csv", # CSV
80 | ".tsv", # TSV
81 | ".gitignore", # Git ignore
82 | ".dockerfile", # Dockerfile
83 | ".makefile", # Make
84 | ".cmake", # CMake
85 | ".gradle", # Gradle
86 | ".sbt", # SBT
87 | ".pom", # Maven POM
88 | ".lock", # Lock files
89 | ".changeset", # Precommit changeset
90 | }
91 |
92 | # Image file extensions - limited to what AI models actually support
93 | # Based on OpenAI and Gemini supported formats: PNG, JPEG, GIF, WebP
94 | IMAGES = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
95 |
96 | # Binary executable and library extensions
97 | BINARIES = {
98 | ".exe", # Windows executable
99 | ".dll", # Windows library
100 | ".so", # Linux shared object
101 | ".dylib", # macOS dynamic library
102 | ".bin", # Binary
103 | ".class", # Java class
104 | }
105 |
106 | # Archive and package file extensions
107 | ARCHIVES = {
108 | ".jar",
109 | ".war",
110 | ".ear", # Java archives
111 | ".zip",
112 | ".tar",
113 | ".gz", # General archives
114 | ".7z",
115 | ".rar", # Compression
116 | ".deb",
117 | ".rpm", # Linux packages
118 | ".dmg",
119 | ".pkg", # macOS packages
120 | }
121 |
122 | # Derived sets for different use cases
123 | CODE_EXTENSIONS = PROGRAMMING_LANGUAGES | SCRIPTS | CONFIGS | DOCS | WEB
124 | PROGRAMMING_EXTENSIONS = PROGRAMMING_LANGUAGES # For line numbering
125 | TEXT_EXTENSIONS = CODE_EXTENSIONS | TEXT_DATA
126 | IMAGE_EXTENSIONS = IMAGES
127 | BINARY_EXTENSIONS = BINARIES | ARCHIVES
128 |
129 | # All extensions by category for easy access
130 | FILE_CATEGORIES = {
131 | "programming": PROGRAMMING_LANGUAGES,
132 | "scripts": SCRIPTS,
133 | "configs": CONFIGS,
134 | "docs": DOCS,
135 | "web": WEB,
136 | "text_data": TEXT_DATA,
137 | "images": IMAGES,
138 | "binaries": BINARIES,
139 | "archives": ARCHIVES,
140 | }
141 |
142 |
143 | def get_file_category(file_path: str) -> str:
144 | """
145 | Determine the category of a file based on its extension.
146 |
147 | Args:
148 | file_path: Path to the file
149 |
150 | Returns:
151 | Category name or "unknown" if not recognized
152 | """
153 | from pathlib import Path
154 |
155 | extension = Path(file_path).suffix.lower()
156 |
157 | for category, extensions in FILE_CATEGORIES.items():
158 | if extension in extensions:
159 | return category
160 |
161 | return "unknown"
162 |
163 |
164 | def is_code_file(file_path: str) -> bool:
165 | """Check if a file is a code file (programming language)."""
166 | from pathlib import Path
167 |
168 | return Path(file_path).suffix.lower() in PROGRAMMING_LANGUAGES
169 |
170 |
171 | def is_text_file(file_path: str) -> bool:
172 | """Check if a file is a text file."""
173 | from pathlib import Path
174 |
175 | return Path(file_path).suffix.lower() in TEXT_EXTENSIONS
176 |
177 |
178 | def is_binary_file(file_path: str) -> bool:
179 | """Check if a file is a binary file."""
180 | from pathlib import Path
181 |
182 | return Path(file_path).suffix.lower() in BINARY_EXTENSIONS
183 |
184 |
185 | # File-type specific token-to-byte ratios for accurate token estimation
186 | # Based on empirical analysis of file compression characteristics and tokenization patterns
187 | TOKEN_ESTIMATION_RATIOS = {
188 | # Programming languages
189 | ".py": 3.5, # Python - moderate verbosity
190 | ".js": 3.2, # JavaScript - compact syntax
191 | ".ts": 3.3, # TypeScript - type annotations add tokens
192 | ".jsx": 3.1, # React JSX - JSX tags are tokenized efficiently
193 | ".tsx": 3.0, # React TSX - combination of TypeScript + JSX
194 | ".java": 3.6, # Java - verbose syntax, long identifiers
195 | ".cpp": 3.7, # C++ - preprocessor directives, templates
196 | ".c": 3.8, # C - function definitions, struct declarations
197 | ".go": 3.9, # Go - explicit error handling, package names
198 | ".rs": 3.5, # Rust - similar to Python in verbosity
199 | ".php": 3.3, # PHP - mixed HTML/code, variable prefixes
200 | ".rb": 3.6, # Ruby - descriptive method names
201 | ".swift": 3.4, # Swift - modern syntax, type inference
202 | ".kt": 3.5, # Kotlin - similar to modern languages
203 | ".scala": 3.2, # Scala - functional programming, concise
204 | # Scripts and configuration
205 | ".sh": 4.1, # Shell scripts - commands and paths
206 | ".bat": 4.0, # Batch files - similar to shell
207 | ".ps1": 3.8, # PowerShell - more structured than bash
208 | ".sql": 3.8, # SQL - keywords and table/column names
209 | # Data and configuration formats
210 | ".json": 2.5, # JSON - lots of punctuation and quotes
211 | ".yaml": 3.0, # YAML - structured but readable
212 | ".yml": 3.0, # YAML (alternative extension)
213 | ".xml": 2.8, # XML - tags and attributes
214 | ".toml": 3.2, # TOML - similar to config files
215 | # Documentation and text
216 | ".md": 4.2, # Markdown - natural language with formatting
217 | ".txt": 4.0, # Plain text - mostly natural language
218 | ".rst": 4.1, # reStructuredText - documentation format
219 | # Web technologies
220 | ".html": 2.9, # HTML - tags and attributes
221 | ".css": 3.4, # CSS - properties and selectors
222 | # Logs and data
223 | ".log": 4.5, # Log files - timestamps, messages, stack traces
224 | ".csv": 3.1, # CSV - data with delimiters
225 | # Infrastructure files
226 | ".dockerfile": 3.7, # Dockerfile - commands and paths
227 | ".tf": 3.5, # Terraform - infrastructure as code
228 | }
229 |
230 |
231 | def get_token_estimation_ratio(file_path: str) -> float:
232 | """
233 | Get the token estimation ratio for a file based on its extension.
234 |
235 | Args:
236 | file_path: Path to the file
237 |
238 | Returns:
239 | Token-to-byte ratio for the file type (default: 3.5 for unknown types)
240 | """
241 | from pathlib import Path
242 |
243 | extension = Path(file_path).suffix.lower()
244 | return TOKEN_ESTIMATION_RATIOS.get(extension, 3.5) # Conservative default
245 |
246 |
247 | # MIME type mappings for image files - limited to what AI models actually support
248 | # Based on OpenAI and Gemini supported formats: PNG, JPEG, GIF, WebP
249 | IMAGE_MIME_TYPES = {
250 | ".jpg": "image/jpeg",
251 | ".jpeg": "image/jpeg",
252 | ".png": "image/png",
253 | ".gif": "image/gif",
254 | ".webp": "image/webp",
255 | }
256 |
257 |
258 | def get_image_mime_type(extension: str) -> str:
259 | """
260 | Get the MIME type for an image file extension.
261 |
262 | Args:
263 | extension: File extension (with or without leading dot)
264 |
265 | Returns:
266 | MIME type string (default: image/jpeg for unknown extensions)
267 | """
268 | if not extension.startswith("."):
269 | extension = "." + extension
270 | extension = extension.lower()
271 | return IMAGE_MIME_TYPES.get(extension, "image/jpeg")
272 |
```
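A usage sketch for the helpers above. The token arithmetic at the end is an assumption: the module documents per-extension ratios but not the exact formula the server applies, so dividing the byte size by the ratio is shown only as a plausible estimate:

```python
# Usage sketch for utils/file_types.py. The token estimate is an assumption
# (bytes divided by the per-extension ratio), not necessarily the server's formula.
from utils.file_types import (
    get_file_category,
    get_image_mime_type,
    get_token_estimation_ratio,
    is_code_file,
)

assert get_file_category("src/app.py") == "programming"
assert get_file_category("notes/unknown.xyz") == "unknown"
assert is_code_file("lib/module.rs") is True
assert get_image_mime_type("PNG") == "image/png"

ratio = get_token_estimation_ratio("README.md")  # 4.2 for Markdown
print(f"~{int(10_000 / ratio)} tokens for a 10 kB Markdown file (rough estimate)")
```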
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Pytest configuration for Zen MCP Server tests
3 | """
4 |
5 | import asyncio
6 | import importlib
7 | import os
8 | import sys
9 | from pathlib import Path
10 |
11 | import pytest
12 |
13 | # Ensure the parent directory is in the Python path for imports
14 | parent_dir = Path(__file__).resolve().parent.parent
15 | if str(parent_dir) not in sys.path:
16 | sys.path.insert(0, str(parent_dir))
17 |
18 | import utils.env as env_config # noqa: E402
19 |
20 | # Ensure tests operate with runtime environment rather than .env overrides during imports
21 | env_config.reload_env({"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
22 |
23 | # Set default model to a specific value for tests to avoid auto mode
24 | # This prevents all tests from failing due to missing model parameter
25 | os.environ["DEFAULT_MODEL"] = "gemini-2.5-flash"
26 |
27 | # Force reload of config module to pick up the env var
28 | import config # noqa: E402
29 |
30 | importlib.reload(config)
31 |
32 | # Note: This creates a test sandbox environment
33 | # Tests create their own temporary directories as needed
34 |
35 | # Configure asyncio for Windows compatibility
36 | if sys.platform == "win32":
37 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
38 |
39 | # Register providers for all tests
40 | from providers.gemini import GeminiModelProvider # noqa: E402
41 | from providers.openai import OpenAIModelProvider # noqa: E402
42 | from providers.registry import ModelProviderRegistry # noqa: E402
43 | from providers.shared import ProviderType # noqa: E402
44 | from providers.xai import XAIModelProvider # noqa: E402
45 |
46 | # Register providers at test startup
47 | ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
48 | ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
49 | ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
50 |
51 | # Register CUSTOM provider if CUSTOM_API_URL is available (for integration tests)
52 | # But only if we're actually running integration tests, not unit tests
53 | if os.getenv("CUSTOM_API_URL") and "test_prompt_regression.py" in os.getenv("PYTEST_CURRENT_TEST", ""):
54 | from providers.custom import CustomProvider # noqa: E402
55 |
56 | def custom_provider_factory(api_key=None):
57 | """Factory function that creates CustomProvider with proper parameters."""
58 | base_url = os.getenv("CUSTOM_API_URL", "")
59 | return CustomProvider(api_key=api_key or "", base_url=base_url)
60 |
61 | ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
62 |
63 |
64 | @pytest.fixture
65 | def project_path(tmp_path):
66 | """
67 | Provides a temporary directory for tests.
68 | This ensures all file operations during tests are isolated.
69 | """
70 | # Create a subdirectory for this specific test
71 | test_dir = tmp_path / "test_workspace"
72 | test_dir.mkdir(parents=True, exist_ok=True)
73 |
74 | return test_dir
75 |
76 |
77 | def _set_dummy_keys_if_missing():
78 | """Set dummy API keys only when they are completely absent."""
79 | for var in ("GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY"):
80 | if not os.environ.get(var):
81 | os.environ[var] = "dummy-key-for-tests"
82 |
83 |
84 | # Pytest configuration
85 | def pytest_configure(config):
86 | """Configure pytest with custom markers"""
87 | config.addinivalue_line("markers", "asyncio: mark test as async")
88 | config.addinivalue_line("markers", "no_mock_provider: disable automatic provider mocking")
89 | # Assume we need dummy keys until we learn otherwise
90 | config._needs_dummy_keys = True
91 |
92 |
93 | def pytest_collection_modifyitems(session, config, items):
94 | """Hook that runs after test collection to check for no_mock_provider markers."""
95 | # Always set dummy keys if real keys are missing
96 | # This ensures tests work in CI even with no_mock_provider marker
97 | _set_dummy_keys_if_missing()
98 |
99 |
100 | @pytest.fixture(autouse=True)
101 | def mock_provider_availability(request, monkeypatch):
102 | """
103 | Automatically mock provider availability for all tests to prevent
104 | effective auto mode from being triggered when DEFAULT_MODEL is unavailable.
105 |
106 | This fixture ensures that when tests run with dummy API keys,
107 | the tools don't require model selection unless explicitly testing auto mode.
108 | """
109 | # Skip this fixture for tests that need real providers
110 | if hasattr(request, "node"):
111 | marker = request.node.get_closest_marker("no_mock_provider")
112 | if marker:
113 | return
114 |
115 | # Ensure providers are registered (in case other tests cleared the registry)
116 | from providers.shared import ProviderType
117 |
118 | registry = ModelProviderRegistry()
119 |
120 | if ProviderType.GOOGLE not in registry._providers:
121 | ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
122 | if ProviderType.OPENAI not in registry._providers:
123 | ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
124 | if ProviderType.XAI not in registry._providers:
125 | ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
126 |
127 | # Ensure CUSTOM provider is registered if needed for integration tests
128 | if (
129 | os.getenv("CUSTOM_API_URL")
130 | and "test_prompt_regression.py" in os.getenv("PYTEST_CURRENT_TEST", "")
131 | and ProviderType.CUSTOM not in registry._providers
132 | ):
133 | from providers.custom import CustomProvider
134 |
135 | def custom_provider_factory(api_key=None):
136 | base_url = os.getenv("CUSTOM_API_URL", "")
137 | return CustomProvider(api_key=api_key or "", base_url=base_url)
138 |
139 | ModelProviderRegistry.register_provider(ProviderType.CUSTOM, custom_provider_factory)
140 |
141 | # Also mock is_effective_auto_mode for all BaseTool instances to return False
142 | # unless we're specifically testing auto mode behavior
143 | from tools.shared.base_tool import BaseTool
144 |
145 | def mock_is_effective_auto_mode(self):
146 | # If this is an auto mode test file or specific auto mode test, use the real logic
147 | test_file = request.node.fspath.basename if hasattr(request, "node") and hasattr(request.node, "fspath") else ""
148 | test_name = request.node.name if hasattr(request, "node") else ""
149 |
150 | # Allow auto mode for tests in auto mode files or with auto in the name
151 | if (
152 | "auto_mode" in test_file.lower()
153 | or "auto" in test_name.lower()
154 | or "intelligent_fallback" in test_file.lower()
155 | or "per_tool_model_defaults" in test_file.lower()
156 | ):
157 | # Call original method logic
158 | from config import DEFAULT_MODEL
159 |
160 | if DEFAULT_MODEL.lower() == "auto":
161 | return True
162 | provider = ModelProviderRegistry.get_provider_for_model(DEFAULT_MODEL)
163 | return provider is None
164 | # For all other tests, return False to disable auto mode
165 | return False
166 |
167 | monkeypatch.setattr(BaseTool, "is_effective_auto_mode", mock_is_effective_auto_mode)
168 |
169 |
170 | @pytest.fixture(autouse=True)
171 | def clear_model_restriction_env(monkeypatch):
172 | """Ensure per-test isolation from user-defined model restriction env vars."""
173 |
174 | restriction_vars = [
175 | "OPENAI_ALLOWED_MODELS",
176 | "GOOGLE_ALLOWED_MODELS",
177 | "XAI_ALLOWED_MODELS",
178 | "OPENROUTER_ALLOWED_MODELS",
179 | "DIAL_ALLOWED_MODELS",
180 | ]
181 |
182 | for var in restriction_vars:
183 | monkeypatch.delenv(var, raising=False)
184 |
185 |
186 | @pytest.fixture(autouse=True)
187 | def disable_force_env_override(monkeypatch):
188 | """Default tests to runtime environment visibility unless they explicitly opt in."""
189 |
190 | monkeypatch.setenv("ZEN_MCP_FORCE_ENV_OVERRIDE", "false")
191 | env_config.reload_env({"ZEN_MCP_FORCE_ENV_OVERRIDE": "false"})
192 | monkeypatch.setenv("DEFAULT_MODEL", "gemini-2.5-flash")
193 | monkeypatch.setenv("MAX_CONVERSATION_TURNS", "50")
194 |
195 | import importlib
196 |
197 | import config
198 | import utils.conversation_memory as conversation_memory
199 |
200 | importlib.reload(config)
201 | importlib.reload(conversation_memory)
202 |
203 | try:
204 | yield
205 | finally:
206 | env_config.reload_env()
207 |
```
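Individual tests opt out of the autouse mocking above via the `no_mock_provider` marker registered in `pytest_configure`. A minimal sketch of such a test; the test body is a placeholder:

```python
# Sketch of opting out of the autouse provider mocking configured above.
# The `no_mock_provider` marker is checked by mock_provider_availability
# before it patches anything; the test body here is only a placeholder.
import pytest


@pytest.mark.no_mock_provider
def test_with_real_provider_resolution():
    # With the marker applied, is_effective_auto_mode and the provider registry
    # behave as configured by the test itself rather than the fixture defaults.
    ...
```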
--------------------------------------------------------------------------------
/systemprompts/generate_code_prompt.py:
--------------------------------------------------------------------------------
```python
1 | """System prompt fragment enabling structured code generation exports.
2 |
3 | This prompt is injected into the system prompt for models that have the
4 | 'allow_code_generation' capability enabled. It instructs the model to output
5 | complete, working code in a structured format that coding agents can parse
6 | and apply automatically.
7 |
8 | The structured format uses XML-like tags to clearly delineate:
9 | - New files to create (<NEWFILE>)
10 | - Existing files to update (<UPDATED_EXISTING_FILE>)
11 | - Step-by-step instructions for the coding agent
12 |
13 | This enables:
14 | 1. Automated code extraction and application
15 | 2. Clear separation between instructions and implementation
16 | 3. Complete, runnable code without manual edits
17 | 4. Precise change tracking across multiple files
18 | """
19 |
20 | GENERATE_CODE_PROMPT = """
21 | # Structured Code Generation Protocol
22 |
23 | **WHEN TO USE THIS PROTOCOL:**
24 |
25 | Use this structured format ONLY when you are explicitly tasked with substantial code generation, such as:
26 | - Creating new features from scratch with multiple files or significant code and you have been asked to help implement this
27 | - Major refactoring across multiple files or large sections of code and you have been tasked to help do this
28 | - Implementing new modules, components, or subsystems and you have been tasked to help with the implementation
29 | - Large-scale updates affecting substantial portions of the codebase that you have been asked to help implement
30 |
31 | **WHEN NOT TO USE THIS PROTOCOL:**
32 |
33 | Do NOT use this format for minor changes:
34 | - Small tweaks to existing functions or methods (1-20 lines)
35 | - Bug fixes in isolated sections
36 | - Simple algorithm improvements
37 | - Minor refactoring of a single function
38 | - Adding/removing a few lines of code
39 | - Quick parameter adjustments or config changes
40 |
41 | For minor changes:
42 | - Follow the existing instructions provided earlier in your system prompt, such as the CRITICAL LINE NUMBER INSTRUCTIONS.
43 | - Use inline code blocks with proper line number references and direct explanations instead of this structured format.
44 |
45 | **IMPORTANT:** This protocol is for SUBSTANTIAL implementation work when explicitly requested, such as:
46 | - "implement feature X"
47 | - "create module Y"
48 | - "refactor system Z"
49 | - "rewrite the authentication logic"
50 | - "redesign the data processing pipeline"
51 | - "rebuild the algorithm from scratch"
52 | - "convert this approach to use a different pattern"
53 | - "create a complete implementation of..."
54 | - "build out the entire workflow for..."
55 |
56 | If the request is for explanation, analysis, debugging, planning, or discussion WITHOUT substantial code generation, respond normally without this structured format.
57 |
58 | ## Core Requirements (for substantial code generation tasks)
59 |
60 | 1. **Complete, Working Code**: Every code block must be fully functional without requiring additional edits. Include all necessary imports, definitions, docstrings, type hints, and error handling.
61 |
62 | 2. **Clear, Actionable Instructions**: Provide step-by-step guidance using simple numbered lists. Each instruction should map directly to file blocks that follow.
63 |
64 | 3. **Structured Output Format**: All generated code MUST be contained within a single `<GENERATED-CODE>` block using the exact structure defined below.
65 |
66 | 4. **Minimal External Commentary**: Keep any text outside the `<GENERATED-CODE>` block brief. Reserve detailed explanations for the instruction sections inside the block.
67 |
68 | ## Required Structure
69 |
70 | Use this exact format (do not improvise tag names or reorder components):
71 |
72 | ```
73 | <GENERATED-CODE>
74 | [Step-by-step instructions for the coding agent]
75 | 1. Create new file [filename] with [description]
76 | 2. Update existing file [filename] by [description]
77 | 3. [Additional steps as needed]
78 |
79 | <NEWFILE: path/to/new_file.py>
80 | [Complete file contents with all necessary components:
81 | - File-level docstring
82 | - All imports (standard library, third-party, local)
83 | - All class/function definitions with complete implementations
84 | - All necessary helper functions
85 | - Inline comments for complex logic
86 | - Type hints where applicable]
87 | </NEWFILE>
88 |
89 | [Additional instructions for the next file, if needed]
90 |
91 | <NEWFILE: path/to/another_file.py>
92 | [Complete, working code for this file - no partial implementations or placeholders]
93 | </NEWFILE>
94 |
95 | [Instructions for updating existing files]
96 |
97 | <UPDATED_EXISTING_FILE: existing/path.py>
98 | [Complete replacement code for the modified sections, routines, or lines that need updating:
99 | - Full function/method bodies (not just the changed lines)
100 | - Complete class definitions if modifying class methods
101 | - All necessary imports if adding new dependencies
102 | - Preserve existing code structure and style]
103 | </UPDATED_EXISTING_FILE>
104 |
105 | [If additional files need updates (based on existing code that was shared with you earlier), repeat the UPDATED_EXISTING_FILE block]
106 |
107 | <UPDATED_EXISTING_FILE: another/existing/file.py>
108 | [Complete code for this file's modifications]
109 | </UPDATED_EXISTING_FILE>
110 |
111 | [For file deletions, explicitly state in instructions with justification:
112 | "Delete file path/to/obsolete.py - no longer needed because [reason]"]
113 | </GENERATED-CODE>
114 | ```
115 |
116 | ## Critical Rules
117 |
118 | **Completeness:**
119 | - Never output partial code snippets or placeholder comments like "# rest of code here"
120 | - Include complete function/class implementations from start to finish
121 | - Add all required imports at the file level
122 | - Include proper error handling and edge case logic
123 |
124 | **Accuracy:**
125 | - Match the existing codebase indentation style (tabs vs spaces)
126 | - Preserve language-specific formatting conventions
127 | - Include trailing newlines where required by language tooling
128 | - Use correct file paths relative to project root
129 |
130 | **Clarity:**
131 | - Number instructions sequentially (1, 2, 3...)
132 | - Map each instruction to specific file blocks below it
133 | - Explain *why* changes are needed, not just *what* changes
134 | - Highlight any breaking changes or migration steps required
135 |
136 | **Structure:**
137 | - Use `<NEWFILE: ...>` for files that don't exist yet
138 | - Use `<UPDATED_EXISTING_FILE: ...>` for modifying existing files
139 | - Place instructions between file blocks to provide context
140 | - Keep the single `<GENERATED-CODE>` wrapper around everything
141 |
142 | ## Special Cases
143 |
144 | **No Changes Needed:**
145 | If the task doesn't require file creation or modification, explicitly state:
146 | "No file changes required. The existing implementation already handles [requirement]."
147 | Do not emit an empty `<GENERATED-CODE>` block.
148 |
149 | **Configuration Changes:**
150 | If modifying configuration files (JSON, YAML, TOML), include complete file contents with the changes applied, not just the changed lines.
151 |
152 | **Test Files:**
153 | When generating tests, include complete test suites with:
154 | - All necessary test fixtures and setup
155 | - Multiple test cases covering happy path and edge cases
156 | - Proper teardown and cleanup
157 | - Clear test descriptions and assertions
158 |
159 | **Documentation:**
160 | Include docstrings for all public functions, classes, and modules using the project's documentation style (Google, NumPy, Sphinx, etc.).
161 |
162 | ## Context Awareness
163 |
164 | **CRITICAL:** Your implementation builds upon the ongoing conversation context:
165 | - All previously shared files, requirements, and constraints remain relevant
166 | - If updating existing code discussed earlier, reference it and preserve unmodified sections
167 | - If the user shared code for improvement, your generated code should build upon it, not replace everything
168 | - The coding agent has full conversation history—your instructions should reference prior discussion as needed
169 |
170 | Your generated code is NOT standalone—it's a continuation of the collaborative session with full context awareness.
171 |
172 | ## Remember
173 |
174 | The coding agent depends on this structured format to:
175 | - Parse and extract code automatically
176 | - Apply changes to the correct files within the conversation context
177 | - Validate completeness before execution
178 | - Track modifications across the codebase
179 |
180 | Always prioritize clarity, completeness, correctness, and context awareness over brevity.
181 | """
182 |
```
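
The protocol above is designed to be machine-parseable. As a purely illustrative sketch (the coding agent's real parser is not shown on this page), the `<NEWFILE: …>` and `<UPDATED_EXISTING_FILE: …>` tags could be extracted along these lines:

```python
# Illustrative sketch only: extract file blocks from a <GENERATED-CODE> response.
# This is not the agent's actual parser; it just mirrors the tag structure
# defined in the prompt above.
import re

_FILE_BLOCK = re.compile(
    r"<(?P<tag>NEWFILE|UPDATED_EXISTING_FILE): (?P<path>[^>]+)>\n"
    r"(?P<body>.*?)"
    r"</(?P=tag)>",
    re.DOTALL,
)


def extract_file_blocks(generated: str) -> dict[str, str]:
    """Map each file path to its complete contents from a generated-code block."""
    return {m.group("path").strip(): m.group("body") for m in _FILE_BLOCK.finditer(generated)}
```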
--------------------------------------------------------------------------------
/tools/shared/base_models.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Base models for Zen MCP tools.
3 |
4 | This module contains the shared Pydantic models used across all tools,
5 | extracted to avoid circular imports and promote code reuse.
6 |
7 | Key Models:
8 | - ToolRequest: Base request model for all tools
9 | - WorkflowRequest: Extended request model for workflow-based tools
10 | - ConsolidatedFindings: Model for tracking workflow progress
11 | """
12 |
13 | import logging
14 | from typing import Optional
15 |
16 | from pydantic import BaseModel, Field, field_validator
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | # Shared field descriptions to avoid duplication
22 | COMMON_FIELD_DESCRIPTIONS = {
23 | "model": "Model to run. Supply a name if requested by the user or stay in auto mode. When in auto mode, use `listmodels` tool for model discovery.",
24 | "temperature": "0 = deterministic · 1 = creative.",
25 | "thinking_mode": "Reasoning depth: minimal, low, medium, high, or max.",
26 | "continuation_id": (
27 | "Unique thread continuation ID for multi-turn conversations. Works across different tools. "
28 | "ALWAYS reuse the last continuation_id you were given—this preserves full conversation context, "
29 | "files, and findings so the agent can resume seamlessly."
30 | ),
31 | "images": "Optional absolute image paths or base64 blobs for visual context.",
32 | "absolute_file_paths": "Full paths to relevant code",
33 | }
34 |
35 | # Workflow-specific field descriptions
36 | WORKFLOW_FIELD_DESCRIPTIONS = {
37 | "step": "Current work step content and findings from your overall work",
38 | "step_number": "Current step number in work sequence (starts at 1)",
39 | "total_steps": "Estimated total steps needed to complete work",
40 | "next_step_required": "Whether another work step is needed. When false, aim to reduce total_steps to match step_number to avoid mismatch.",
41 | "findings": "Important findings, evidence and insights discovered in this step",
42 | "files_checked": "List of files examined during this work step",
43 | "relevant_files": "Files identified as relevant to issue/goal (FULL absolute paths to real files/folders - DO NOT SHORTEN)",
44 | "relevant_context": "Methods/functions identified as involved in the issue",
45 | "issues_found": "Issues identified with severity levels during work",
46 | "confidence": (
47 | "Confidence level: exploring (just starting), low (early investigation), "
48 | "medium (some evidence), high (strong evidence), very_high (comprehensive understanding), "
49 | "almost_certain (near complete confidence), certain (100% confidence locally - no external validation needed)"
50 | ),
51 | "hypothesis": "Current theory about issue/goal based on work",
52 | "use_assistant_model": (
53 | "Use assistant model for expert analysis after workflow steps. "
54 | "False skips expert analysis, relies solely on your personal investigation. "
55 | "Defaults to True for comprehensive validation."
56 | ),
57 | }
58 |
59 |
60 | class ToolRequest(BaseModel):
61 | """
62 | Base request model for all Zen MCP tools.
63 |
64 | This model defines common fields that all tools accept, including
65 | model selection, temperature control, and conversation threading.
66 | Tool-specific request models should inherit from this class.
67 | """
68 |
69 | # Model configuration
70 | model: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["model"])
71 | temperature: Optional[float] = Field(None, ge=0.0, le=1.0, description=COMMON_FIELD_DESCRIPTIONS["temperature"])
72 | thinking_mode: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["thinking_mode"])
73 |
74 | # Conversation support
75 | continuation_id: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["continuation_id"])
76 |
77 | # Visual context
78 | images: Optional[list[str]] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["images"])
79 |
80 |
81 | class BaseWorkflowRequest(ToolRequest):
82 | """
83 | Minimal base request model for workflow tools.
84 |
85 | This provides only the essential fields that ALL workflow tools need,
86 | allowing for maximum flexibility in tool-specific implementations.
87 | """
88 |
89 | # Core workflow fields that ALL workflow tools need
90 | step: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["step"])
91 | step_number: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
92 | total_steps: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
93 | next_step_required: bool = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])
94 |
95 |
96 | class WorkflowRequest(BaseWorkflowRequest):
97 | """
98 | Extended request model for workflow-based tools.
99 |
100 |     This model extends BaseWorkflowRequest (and thus ToolRequest) with fields specific to the workflow
101 | pattern, where tools perform multi-step work with forced pauses between steps.
102 |
103 | Used by: debug, precommit, codereview, refactor, thinkdeep, analyze
104 | """
105 |
106 | # Required workflow fields
107 | step: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["step"])
108 | step_number: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
109 | total_steps: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
110 | next_step_required: bool = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])
111 |
112 | # Work tracking fields
113 | findings: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["findings"])
114 | files_checked: list[str] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["files_checked"])
115 | relevant_files: list[str] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"])
116 | relevant_context: list[str] = Field(
117 | default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"]
118 | )
119 | issues_found: list[dict] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["issues_found"])
120 | confidence: str = Field("low", description=WORKFLOW_FIELD_DESCRIPTIONS["confidence"])
121 |
122 | # Optional workflow fields
123 | hypothesis: Optional[str] = Field(None, description=WORKFLOW_FIELD_DESCRIPTIONS["hypothesis"])
124 | use_assistant_model: Optional[bool] = Field(True, description=WORKFLOW_FIELD_DESCRIPTIONS["use_assistant_model"])
125 |
126 | @field_validator("files_checked", "relevant_files", "relevant_context", mode="before")
127 | @classmethod
128 | def convert_string_to_list(cls, v):
129 | """Convert string inputs to empty lists to handle malformed inputs gracefully."""
130 | if isinstance(v, str):
131 | logger.warning(f"Field received string '{v}' instead of list, converting to empty list")
132 | return []
133 | return v
134 |
135 |
136 | class ConsolidatedFindings(BaseModel):
137 | """
138 | Model for tracking consolidated findings across workflow steps.
139 |
140 | This model accumulates findings, files, methods, and issues
141 | discovered during multi-step work. It's used by
142 | BaseWorkflowMixin to track progress across workflow steps.
143 | """
144 |
145 | files_checked: set[str] = Field(default_factory=set, description="All files examined across all steps")
146 | relevant_files: set[str] = Field(
147 | default_factory=set,
148 | description="Subset of files_checked identified as relevant for work at hand",
149 | )
150 | relevant_context: set[str] = Field(
151 | default_factory=set, description="All methods/functions identified during overall work"
152 | )
153 | findings: list[str] = Field(default_factory=list, description="Chronological findings from each work step")
154 | hypotheses: list[dict] = Field(default_factory=list, description="Evolution of hypotheses across steps")
155 | issues_found: list[dict] = Field(default_factory=list, description="All issues with severity levels")
156 | images: list[str] = Field(default_factory=list, description="Images collected during work")
157 | confidence: str = Field("low", description="Latest confidence level from steps")
158 |
159 |
160 | # Tool-specific field descriptions are now declared in each tool file
161 | # This keeps concerns separated and makes each tool self-contained
162 |
```
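
A short usage sketch for the models above, assuming `tools.shared.base_models` is importable as shown. It exercises the required workflow fields, the default `confidence`, and the `convert_string_to_list` validator that coerces malformed string inputs into empty lists:

```python
from tools.shared.base_models import ConsolidatedFindings, WorkflowRequest

step = WorkflowRequest(
    step="Investigate the failing import",
    step_number=1,
    total_steps=2,
    next_step_required=True,
    findings="Traceback points at a circular import",
    files_checked="not-a-list",  # malformed on purpose
)
assert step.files_checked == []   # validator coerced the string to an empty list
assert step.confidence == "low"   # default confidence level

consolidated = ConsolidatedFindings()
consolidated.findings.append(f"Step {step.step_number}: {step.findings}")
consolidated.relevant_files.update(step.relevant_files)
```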
--------------------------------------------------------------------------------
/tests/test_auto_model_planner_fix.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for the auto model planner fix.
3 |
4 | This test confirms that the planner tool no longer fails when DEFAULT_MODEL is "auto"
5 | and only basic providers (Google/OpenAI) are configured, while ensuring other tools
6 | still properly require model resolution.
7 | """
8 |
9 | from unittest.mock import patch
10 |
11 | from mcp.types import TextContent
12 |
13 | from tools.chat import ChatTool
14 | from tools.planner import PlannerTool
15 | from tools.shared.base_tool import BaseTool
16 |
17 |
18 | class TestAutoModelPlannerFix:
19 | """Test the fix for auto model resolution with planner tool."""
20 |
21 | def test_planner_requires_model_false(self):
22 | """Test that planner tool returns False for requires_model."""
23 | planner = PlannerTool()
24 | assert planner.requires_model() is False
25 |
26 | def test_chat_requires_model_true(self):
27 | """Test that chat tool returns True for requires_model (default behavior)."""
28 | chat = ChatTool()
29 | assert chat.requires_model() is True
30 |
31 | def test_base_tool_requires_model_default(self):
32 | """Test that BaseTool default implementation returns True."""
33 |
34 | # Create a mock tool that doesn't override requires_model
35 | class MockTool(BaseTool):
36 | def get_name(self):
37 | return "mock"
38 |
39 | def get_description(self):
40 | return "Mock tool"
41 |
42 | def get_input_schema(self):
43 | return {}
44 |
45 | def get_system_prompt(self):
46 | return "Mock prompt"
47 |
48 | def get_request_model(self):
49 | from tools.shared.base_models import ToolRequest
50 |
51 | return ToolRequest
52 |
53 | async def prepare_prompt(self, request):
54 | return "Mock prompt"
55 |
56 | mock_tool = MockTool()
57 | assert mock_tool.requires_model() is True
58 |
59 | @patch("config.DEFAULT_MODEL", "auto")
60 | @patch("providers.registry.ModelProviderRegistry.get_provider_for_model")
61 | def test_auto_model_error_before_fix_simulation(self, mock_get_provider):
62 | """
63 | Simulate the error that would occur before the fix.
64 |
65 | This test simulates what would happen if server.py didn't check requires_model()
66 | and tried to resolve "auto" as a literal model name.
67 | """
68 | # Mock the scenario where no provider is found for "auto"
69 | mock_get_provider.return_value = None
70 |
71 | # This should return None, simulating the "No provider found for model auto" error
72 | result = mock_get_provider("auto")
73 | assert result is None
74 |
75 | # Verify that the mock was called with "auto"
76 | mock_get_provider.assert_called_with("auto")
77 |
78 | @patch("server.DEFAULT_MODEL", "auto")
79 | async def test_planner_execution_bypasses_model_resolution(self):
80 | """
81 | Test that planner tool execution works even when DEFAULT_MODEL is "auto".
82 |
83 | This test confirms that the fix allows planner to work regardless of
84 | model configuration since it doesn't need model resolution.
85 | """
86 | planner = PlannerTool()
87 |
88 | # Test with minimal planner arguments
89 | arguments = {"step": "Test planning step", "step_number": 1, "total_steps": 1, "next_step_required": False}
90 |
91 | # This should work without any model resolution
92 | result = await planner.execute(arguments)
93 |
94 | # Verify we got a result
95 | assert isinstance(result, list)
96 | assert len(result) > 0
97 | assert isinstance(result[0], TextContent)
98 |
99 | # Parse the JSON response to verify it's valid
100 | import json
101 |
102 | response_data = json.loads(result[0].text)
103 | assert response_data["status"] == "planning_complete"
104 | assert response_data["step_number"] == 1
105 |
106 | @patch("config.DEFAULT_MODEL", "auto")
107 | def test_server_model_resolution_logic(self):
108 | """
109 | Test the server-side logic that checks requires_model() before model resolution.
110 |
111 | This simulates the key fix in server.py where we check tool.requires_model()
112 | before attempting model resolution.
113 | """
114 | planner = PlannerTool()
115 | chat = ChatTool()
116 |
117 | # Simulate the server logic
118 | def simulate_server_model_resolution(tool, model_name):
119 | """Simulate the fixed server logic from server.py"""
120 | if not tool.requires_model():
121 | # Skip model resolution for tools that don't require models
122 | return "SKIP_MODEL_RESOLUTION"
123 | else:
124 | # Would normally do model resolution here
125 | return f"RESOLVE_MODEL_{model_name}"
126 |
127 | # Test planner (should skip model resolution)
128 | result = simulate_server_model_resolution(planner, "auto")
129 | assert result == "SKIP_MODEL_RESOLUTION"
130 |
131 | # Test chat (should attempt model resolution)
132 | result = simulate_server_model_resolution(chat, "auto")
133 | assert result == "RESOLVE_MODEL_auto"
134 |
135 | def test_provider_registry_auto_handling(self):
136 | """
137 | Test that the provider registry correctly handles model resolution.
138 |
139 | This tests the scenario where providers don't recognize "auto" as a model.
140 | """
141 | from providers.registry import ModelProviderRegistry
142 |
143 | # This should return None since "auto" is not a real model name
144 | provider = ModelProviderRegistry.get_provider_for_model("auto")
145 | assert provider is None, "Provider registry should not find a provider for literal 'auto'"
146 |
147 | @patch("config.DEFAULT_MODEL", "auto")
148 | async def test_end_to_end_planner_with_auto_mode(self):
149 | """
150 | End-to-end test of planner tool execution in auto mode.
151 |
152 | This test verifies that the complete flow works when DEFAULT_MODEL is "auto"
153 | and the planner tool is used.
154 | """
155 | planner = PlannerTool()
156 |
157 | # Verify the tool doesn't require model resolution
158 | assert not planner.requires_model()
159 |
160 | # Test a multi-step planning scenario
161 | step1_args = {
162 | "step": "Analyze the current system architecture",
163 | "step_number": 1,
164 | "total_steps": 3,
165 | "next_step_required": True,
166 | }
167 |
168 | result1 = await planner.execute(step1_args)
169 | assert len(result1) > 0
170 |
171 | # Parse and verify the response
172 | import json
173 |
174 | response1 = json.loads(result1[0].text)
175 | assert response1["status"] == "pause_for_planning"
176 | assert response1["next_step_required"] is True
177 | assert "continuation_id" in response1
178 |
179 | # Test step 2 with continuation
180 | continuation_id = response1["continuation_id"]
181 | step2_args = {
182 | "step": "Design the microservices architecture",
183 | "step_number": 2,
184 | "total_steps": 3,
185 | "next_step_required": True,
186 | "continuation_id": continuation_id,
187 | }
188 |
189 | result2 = await planner.execute(step2_args)
190 | assert len(result2) > 0
191 |
192 | response2 = json.loads(result2[0].text)
193 | assert response2["status"] == "pause_for_planning"
194 | assert response2["step_number"] == 2
195 |
196 | def test_other_tools_still_require_models(self):
197 | """
198 | Verify that other tools still properly require model resolution.
199 |
200 | This ensures our fix doesn't break existing functionality.
201 | Note: Debug tool requires model resolution for expert analysis phase.
202 | """
203 | from tools.analyze import AnalyzeTool
204 | from tools.chat import ChatTool
205 | from tools.debug import DebugIssueTool
206 |
207 | # Test various tools still require models
208 | tools_requiring_models = [ChatTool(), AnalyzeTool(), DebugIssueTool()]
209 |
210 | for tool in tools_requiring_models:
211 | assert tool.requires_model() is True, f"{tool.get_name()} should require model resolution"
212 |
213 | # Note: Debug tool requires model resolution for expert analysis phase
214 | # Only planner truly manages its own model calls and doesn't need resolution
215 |
```
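
For context, the dispatch behaviour these tests simulate can be sketched roughly as below. The real logic lives in `server.py`, which is not included on this page, so treat the function as a hypothetical stand-in rather than the actual implementation:

```python
# Hypothetical stand-in for the server-side check exercised by the tests above.
from providers.registry import ModelProviderRegistry


def resolve_provider_if_needed(tool, requested_model: str):
    if not tool.requires_model():
        # Tools such as planner skip model resolution entirely, so a
        # DEFAULT_MODEL of "auto" never reaches the provider registry for them.
        return None
    provider = ModelProviderRegistry.get_provider_for_model(requested_model)
    if provider is None:
        raise ValueError(f"No provider found for model {requested_model}")
    return provider
```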
--------------------------------------------------------------------------------
/tests/test_challenge.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Tests for Challenge tool - validating critical challenge prompt wrapper
3 |
4 | This module contains unit tests to ensure that the Challenge tool
5 | properly wraps statements to encourage critical thinking and avoid
6 | automatic agreement patterns.
7 | """
8 |
9 | import json
10 | from unittest.mock import patch
11 |
12 | import pytest
13 |
14 | from tools.challenge import ChallengeRequest, ChallengeTool
15 | from tools.shared.exceptions import ToolExecutionError
16 |
17 |
18 | class TestChallengeTool:
19 | """Test suite for Challenge tool"""
20 |
21 | def setup_method(self):
22 | """Set up test fixtures"""
23 | self.tool = ChallengeTool()
24 |
25 | def test_tool_metadata(self):
26 | """Test that tool metadata matches requirements"""
27 | assert self.tool.get_name() == "challenge"
28 | assert "reflexive agreement" in self.tool.get_description()
29 | assert "critical thinking" in self.tool.get_description()
30 | assert "reasoned analysis" in self.tool.get_description()
31 | assert self.tool.get_default_temperature() == 0.2 # TEMPERATURE_ANALYTICAL
32 |
33 | def test_requires_model(self):
34 | """Test that challenge tool doesn't require a model"""
35 | assert self.tool.requires_model() is False
36 |
37 | def test_schema_structure(self):
38 | """Test that schema has correct structure and excludes model fields"""
39 | schema = self.tool.get_input_schema()
40 |
41 | # Basic schema structure
42 | assert schema["type"] == "object"
43 | assert "properties" in schema
44 | assert "required" in schema
45 |
46 | # Required fields
47 | assert "prompt" in schema["required"]
48 | assert len(schema["required"]) == 1 # Only prompt is required
49 |
50 | # Properties
51 | properties = schema["properties"]
52 | assert "prompt" in properties
53 |
54 | # Should NOT have model-related fields since it doesn't require a model
55 | assert "model" not in properties
56 | assert "temperature" not in properties
57 | assert "thinking_mode" not in properties
58 | assert "continuation_id" not in properties
59 |
60 | def test_request_model_validation(self):
61 | """Test that the request model validates correctly"""
62 | # Test valid request
63 | request = ChallengeRequest(prompt="The sky is green")
64 | assert request.prompt == "The sky is green"
65 |
66 | # Test with longer prompt
67 | long_prompt = (
68 | "Machine learning models always produce accurate results and should be trusted without verification"
69 | )
70 | request = ChallengeRequest(prompt=long_prompt)
71 | assert request.prompt == long_prompt
72 |
73 | def test_required_fields(self):
74 | """Test that required fields are enforced"""
75 | from pydantic import ValidationError
76 |
77 | # Missing prompt should raise validation error
78 | with pytest.raises(ValidationError):
79 | ChallengeRequest()
80 |
81 | @pytest.mark.asyncio
82 | async def test_execute_success(self):
83 | """Test successful execution of challenge tool"""
84 | arguments = {"prompt": "All software bugs are caused by syntax errors"}
85 |
86 | result = await self.tool.execute(arguments)
87 |
88 | # Should return a list with TextContent
89 | assert len(result) == 1
90 | assert result[0].type == "text"
91 |
92 | # Parse the JSON response
93 | response_data = json.loads(result[0].text)
94 |
95 | # Check response structure
96 | assert response_data["status"] == "challenge_accepted"
97 | assert response_data["original_statement"] == "All software bugs are caused by syntax errors"
98 | assert "challenge_prompt" in response_data
99 | assert "instructions" in response_data
100 |
101 | # Check that the challenge prompt contains critical thinking instructions
102 | challenge_prompt = response_data["challenge_prompt"]
103 | assert "CRITICAL REASSESSMENT – Do not automatically agree" in challenge_prompt
104 | assert "Carefully evaluate the statement above" in challenge_prompt
105 | assert response_data["original_statement"] in challenge_prompt
106 | assert "flaws, gaps, or misleading points" in challenge_prompt
107 | assert "thoughtful analysis" in challenge_prompt
108 |
109 | @pytest.mark.asyncio
110 | async def test_execute_error_handling(self):
111 | """Test error handling in execute method"""
112 | # Test with invalid arguments (non-dict)
113 | with patch.object(self.tool, "get_request_model", side_effect=Exception("Test error")):
114 | with pytest.raises(ToolExecutionError) as exc_info:
115 | await self.tool.execute({"prompt": "test"})
116 |
117 | response_data = json.loads(exc_info.value.payload)
118 | assert response_data["status"] == "error"
119 | assert "Test error" in response_data["error"]
120 |
121 | def test_wrap_prompt_for_challenge(self):
122 | """Test the prompt wrapping functionality"""
123 | original_prompt = "Python is the best programming language"
124 | wrapped = self.tool._wrap_prompt_for_challenge(original_prompt)
125 |
126 | # Check structure
127 | assert "CRITICAL REASSESSMENT – Do not automatically agree" in wrapped
128 | assert "Carefully evaluate the statement above" in wrapped
129 | assert f'"{original_prompt}"' in wrapped
130 | assert "flaws, gaps, or misleading points" in wrapped
131 | assert "thoughtful analysis" in wrapped
132 |
133 | def test_multiple_prompts(self):
134 | """Test that tool handles various types of prompts correctly"""
135 | test_prompts = [
136 | "All code should be written in assembly for maximum performance",
137 | "Comments are unnecessary if code is self-documenting",
138 | "Testing is a waste of time for experienced developers",
139 | "Global variables make code easier to understand",
140 | "The more design patterns used, the better the code",
141 | ]
142 |
143 | for prompt in test_prompts:
144 | request = ChallengeRequest(prompt=prompt)
145 | wrapped = self.tool._wrap_prompt_for_challenge(request.prompt)
146 |
147 | # Each wrapped prompt should contain the original
148 | assert prompt in wrapped
149 | assert "CRITICAL REASSESSMENT" in wrapped
150 |
151 | def test_tool_fields(self):
152 | """Test tool-specific field definitions"""
153 | fields = self.tool.get_tool_fields()
154 |
155 | assert "prompt" in fields
156 | assert fields["prompt"]["type"] == "string"
157 | assert "Statement to scrutinize" in fields["prompt"]["description"]
158 | assert "strip the word 'challenge'" in fields["prompt"]["description"]
159 |
160 | def test_required_fields_list(self):
161 | """Test required fields list"""
162 | required = self.tool.get_required_fields()
163 | assert required == ["prompt"]
164 |
165 | @pytest.mark.asyncio
166 | async def test_not_used_methods(self):
167 | """Test that methods not used by challenge tool work correctly"""
168 | request = ChallengeRequest(prompt="test")
169 |
170 | # These methods aren't used since challenge doesn't call AI
171 | prompt = await self.tool.prepare_prompt(request)
172 | assert prompt == ""
173 |
174 | response = self.tool.format_response("test response", request)
175 | assert response == "test response"
176 |
177 | def test_special_characters_in_prompt(self):
178 | """Test handling of special characters in prompts"""
179 | special_prompt = 'The "best" way to handle errors is to use try/except: pass'
180 | request = ChallengeRequest(prompt=special_prompt)
181 | wrapped = self.tool._wrap_prompt_for_challenge(request.prompt)
182 |
183 | # Should handle quotes properly
184 | assert special_prompt in wrapped
185 |
186 | @pytest.mark.asyncio
187 | async def test_unicode_support(self):
188 | """Test that tool handles unicode characters correctly"""
189 | unicode_prompt = "软件开发中最重要的是写代码,测试不重要 🚀"
190 | arguments = {"prompt": unicode_prompt}
191 |
192 | result = await self.tool.execute(arguments)
193 | response_data = json.loads(result[0].text)
194 |
195 | assert response_data["original_statement"] == unicode_prompt
196 | assert unicode_prompt in response_data["challenge_prompt"]
197 |
198 |
199 | if __name__ == "__main__":
200 | pytest.main([__file__])
201 |
```
--------------------------------------------------------------------------------
/clink/agents/base.py:
--------------------------------------------------------------------------------
```python
1 | """Execute configured CLI agents for the clink tool and parse output."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | import logging
7 | import os
8 | import shlex
9 | import shutil
10 | import tempfile
11 | import time
12 | from collections.abc import Sequence
13 | from dataclasses import dataclass
14 | from pathlib import Path
15 |
16 | from clink.constants import DEFAULT_STREAM_LIMIT
17 | from clink.models import ResolvedCLIClient, ResolvedCLIRole
18 | from clink.parsers import BaseParser, ParsedCLIResponse, ParserError, get_parser
19 |
20 | logger = logging.getLogger("clink.agent")
21 |
22 |
23 | @dataclass
24 | class AgentOutput:
25 | """Container returned by CLI agents after successful execution."""
26 |
27 | parsed: ParsedCLIResponse
28 | sanitized_command: list[str]
29 | returncode: int
30 | stdout: str
31 | stderr: str
32 | duration_seconds: float
33 | parser_name: str
34 | output_file_content: str | None = None
35 |
36 |
37 | class CLIAgentError(RuntimeError):
38 | """Raised when a CLI agent fails (non-zero exit, timeout, parse errors)."""
39 |
40 | def __init__(self, message: str, *, returncode: int | None = None, stdout: str = "", stderr: str = "") -> None:
41 | super().__init__(message)
42 | self.returncode = returncode
43 | self.stdout = stdout
44 | self.stderr = stderr
45 |
46 |
47 | class BaseCLIAgent:
48 | """Execute a configured CLI command and parse its output."""
49 |
50 | def __init__(self, client: ResolvedCLIClient):
51 | self.client = client
52 | self._parser: BaseParser = get_parser(client.parser)
53 | self._logger = logging.getLogger(f"clink.runner.{client.name}")
54 |
55 | async def run(
56 | self,
57 | *,
58 | role: ResolvedCLIRole,
59 | prompt: str,
60 | system_prompt: str | None = None,
61 | files: Sequence[str],
62 | images: Sequence[str],
63 | ) -> AgentOutput:
64 | # Files and images are already embedded into the prompt by the tool; they are
65 | # accepted here only to keep parity with SimpleTool callers.
66 | _ = (files, images)
67 | # The runner simply executes the configured CLI command for the selected role.
68 | command = self._build_command(role=role, system_prompt=system_prompt)
69 | env = self._build_environment()
70 |
71 | # Resolve executable path for cross-platform compatibility (especially Windows)
72 | executable_name = command[0]
73 | resolved_executable = shutil.which(executable_name)
74 | if resolved_executable is None:
75 | raise CLIAgentError(
76 | f"Executable '{executable_name}' not found in PATH for CLI '{self.client.name}'. "
77 | f"Ensure the command is installed and accessible."
78 | )
79 | command[0] = resolved_executable
80 |
81 | sanitized_command = list(command)
82 |
83 | cwd = str(self.client.working_dir) if self.client.working_dir else None
84 | limit = DEFAULT_STREAM_LIMIT
85 |
86 | stdout_text = ""
87 | stderr_text = ""
88 | output_file_content: str | None = None
89 | start_time = time.monotonic()
90 |
91 | output_file_path: Path | None = None
92 | command_with_output_flag = list(command)
93 |
94 | if self.client.output_to_file:
95 | fd, tmp_path = tempfile.mkstemp(prefix="clink-", suffix=".json")
96 | os.close(fd)
97 | output_file_path = Path(tmp_path)
98 | flag_template = self.client.output_to_file.flag_template
99 | try:
100 | rendered_flag = flag_template.format(path=str(output_file_path))
101 | except KeyError as exc: # pragma: no cover - defensive
102 | raise CLIAgentError(f"Invalid output flag template '{flag_template}': missing placeholder {exc}")
103 | command_with_output_flag.extend(shlex.split(rendered_flag))
104 | sanitized_command = list(command_with_output_flag)
105 |
106 | self._logger.debug("Executing CLI command: %s", " ".join(sanitized_command))
107 | if cwd:
108 | self._logger.debug("Working directory: %s", cwd)
109 |
110 | try:
111 | process = await asyncio.create_subprocess_exec(
112 | *command_with_output_flag,
113 | stdin=asyncio.subprocess.PIPE,
114 | stdout=asyncio.subprocess.PIPE,
115 | stderr=asyncio.subprocess.PIPE,
116 | cwd=cwd,
117 | limit=limit,
118 | env=env,
119 | )
120 | except FileNotFoundError as exc:
121 | raise CLIAgentError(f"Executable not found for CLI '{self.client.name}': {exc}") from exc
122 |
123 | try:
124 | stdout_bytes, stderr_bytes = await asyncio.wait_for(
125 | process.communicate(prompt.encode("utf-8")),
126 | timeout=self.client.timeout_seconds,
127 | )
128 | except asyncio.TimeoutError as exc:
129 | process.kill()
130 | await process.communicate()
131 | raise CLIAgentError(
132 | f"CLI '{self.client.name}' timed out after {self.client.timeout_seconds} seconds",
133 | returncode=None,
134 | ) from exc
135 |
136 | duration = time.monotonic() - start_time
137 | return_code = process.returncode
138 | stdout_text = stdout_bytes.decode("utf-8", errors="replace")
139 | stderr_text = stderr_bytes.decode("utf-8", errors="replace")
140 |
141 | if output_file_path and output_file_path.exists():
142 | output_file_content = output_file_path.read_text(encoding="utf-8", errors="replace")
143 | if self.client.output_to_file and self.client.output_to_file.cleanup:
144 | try:
145 | output_file_path.unlink()
146 | except OSError: # pragma: no cover - best effort cleanup
147 | pass
148 |
149 | if output_file_content and not stdout_text.strip():
150 | stdout_text = output_file_content
151 |
152 | if return_code != 0:
153 | recovered = self._recover_from_error(
154 | returncode=return_code,
155 | stdout=stdout_text,
156 | stderr=stderr_text,
157 | sanitized_command=sanitized_command,
158 | duration_seconds=duration,
159 | output_file_content=output_file_content,
160 | )
161 | if recovered is not None:
162 | return recovered
163 |
164 | if return_code != 0:
165 | raise CLIAgentError(
166 | f"CLI '{self.client.name}' exited with status {return_code}",
167 | returncode=return_code,
168 | stdout=stdout_text,
169 | stderr=stderr_text,
170 | )
171 |
172 | try:
173 | parsed = self._parser.parse(stdout_text, stderr_text)
174 | except ParserError as exc:
175 | raise CLIAgentError(
176 | f"Failed to parse output from CLI '{self.client.name}': {exc}",
177 | returncode=return_code,
178 | stdout=stdout_text,
179 | stderr=stderr_text,
180 | ) from exc
181 |
182 | return AgentOutput(
183 | parsed=parsed,
184 | sanitized_command=sanitized_command,
185 | returncode=return_code,
186 | stdout=stdout_text,
187 | stderr=stderr_text,
188 | duration_seconds=duration,
189 | parser_name=self._parser.name,
190 | output_file_content=output_file_content,
191 | )
192 |
193 | def _build_command(self, *, role: ResolvedCLIRole, system_prompt: str | None) -> list[str]:
194 | base = list(self.client.executable)
195 | base.extend(self.client.internal_args)
196 | base.extend(self.client.config_args)
197 | base.extend(role.role_args)
198 |
199 | return base
200 |
201 | def _build_environment(self) -> dict[str, str]:
202 | env = os.environ.copy()
203 | env.update(self.client.env)
204 | return env
205 |
206 | # ------------------------------------------------------------------
207 | # Error recovery hooks
208 | # ------------------------------------------------------------------
209 |
210 | def _recover_from_error(
211 | self,
212 | *,
213 | returncode: int,
214 | stdout: str,
215 | stderr: str,
216 | sanitized_command: list[str],
217 | duration_seconds: float,
218 | output_file_content: str | None,
219 | ) -> AgentOutput | None:
220 | """Hook for subclasses to convert CLI errors into successful outputs.
221 |
222 | Return an AgentOutput to treat the failure as success, or None to signal
223 | that normal error handling should proceed.
224 | """
225 |
226 | return None
227 |
```
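
The `_recover_from_error` hook above is the extension point for CLI-specific agents (the real `claude`/`codex`/`gemini` agents under `clink/agents/` are not shown on this page). A rough sketch of a subclass that salvages parseable output from a non-zero exit, using only the classes defined above:

```python
# Sketch only: a CLI-specific agent that treats a non-zero exit as success when
# stdout still contains a parseable payload.
from clink.agents.base import AgentOutput, BaseCLIAgent
from clink.parsers import ParserError


class LenientCLIAgent(BaseCLIAgent):
    def _recover_from_error(
        self,
        *,
        returncode,
        stdout,
        stderr,
        sanitized_command,
        duration_seconds,
        output_file_content,
    ):
        if not stdout.strip():
            return None  # nothing to salvage; let normal error handling raise
        try:
            parsed = self._parser.parse(stdout, stderr)
        except ParserError:
            return None
        return AgentOutput(
            parsed=parsed,
            sanitized_command=sanitized_command,
            returncode=returncode,
            stdout=stdout,
            stderr=stderr,
            duration_seconds=duration_seconds,
            parser_name=self._parser.name,
            output_file_content=output_file_content,
        )
```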
--------------------------------------------------------------------------------
/providers/openrouter.py:
--------------------------------------------------------------------------------
```python
1 | """OpenRouter provider implementation."""
2 |
3 | import logging
4 |
5 | from utils.env import get_env
6 |
7 | from .openai_compatible import OpenAICompatibleProvider
8 | from .registries.openrouter import OpenRouterModelRegistry
9 | from .shared import (
10 | ModelCapabilities,
11 | ProviderType,
12 | RangeTemperatureConstraint,
13 | )
14 |
15 |
16 | class OpenRouterProvider(OpenAICompatibleProvider):
17 | """Client for OpenRouter's multi-model aggregation service.
18 |
19 | Role
20 | Surface OpenRouter’s dynamic catalogue through the same interface as
21 | native providers so tools can reference OpenRouter models and aliases
22 | without special cases.
23 |
24 | Characteristics
25 | * Pulls live model definitions from :class:`OpenRouterModelRegistry`
26 | (aliases, provider-specific metadata, capability hints)
27 | * Applies alias-aware restriction checks before exposing models to the
28 | registry or tooling
29 | * Reuses :class:`OpenAICompatibleProvider` infrastructure for request
30 | execution so OpenRouter endpoints behave like standard OpenAI-style
31 | APIs.
32 | """
33 |
34 | FRIENDLY_NAME = "OpenRouter"
35 |
36 | # Custom headers required by OpenRouter
37 | DEFAULT_HEADERS = {
38 | "HTTP-Referer": get_env("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server")
39 | or "https://github.com/BeehiveInnovations/zen-mcp-server",
40 | "X-Title": get_env("OPENROUTER_TITLE", "Zen MCP Server") or "Zen MCP Server",
41 | }
42 |
43 | # Model registry for managing configurations and aliases
44 | _registry: OpenRouterModelRegistry | None = None
45 |
46 | def __init__(self, api_key: str, **kwargs):
47 | """Initialize OpenRouter provider.
48 |
49 | Args:
50 | api_key: OpenRouter API key
51 | **kwargs: Additional configuration
52 | """
53 | base_url = "https://openrouter.ai/api/v1"
54 | self._alias_cache: dict[str, str] = {}
55 | super().__init__(api_key, base_url=base_url, **kwargs)
56 |
57 | # Initialize model registry
58 | if OpenRouterProvider._registry is None:
59 | OpenRouterProvider._registry = OpenRouterModelRegistry()
60 | # Log loaded models and aliases only on first load
61 | models = self._registry.list_models()
62 | aliases = self._registry.list_aliases()
63 | logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
64 |
65 | # ------------------------------------------------------------------
66 | # Capability surface
67 | # ------------------------------------------------------------------
68 |
69 | def _lookup_capabilities(
70 | self,
71 | canonical_name: str,
72 | requested_name: str | None = None,
73 | ) -> ModelCapabilities | None:
74 | """Fetch OpenRouter capabilities from the registry or build a generic fallback."""
75 |
76 | capabilities = self._registry.get_capabilities(canonical_name)
77 | if capabilities:
78 | return capabilities
79 |
80 | base_identifier = canonical_name.split(":", 1)[0]
81 | if "/" in base_identifier:
82 | logging.debug(
83 | "Using generic OpenRouter capabilities for %s (provider/model format detected)", canonical_name
84 | )
85 | generic = ModelCapabilities(
86 | provider=ProviderType.OPENROUTER,
87 | model_name=canonical_name,
88 | friendly_name=self.FRIENDLY_NAME,
89 | intelligence_score=9,
90 | context_window=32_768,
91 | max_output_tokens=32_768,
92 | supports_extended_thinking=False,
93 | supports_system_prompts=True,
94 | supports_streaming=True,
95 | supports_function_calling=False,
96 | temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
97 | )
98 | generic._is_generic = True
99 | return generic
100 |
101 | logging.debug(
102 | "Rejecting unknown OpenRouter model '%s' (no provider prefix); requires explicit configuration",
103 | canonical_name,
104 | )
105 | return None
106 |
107 | # ------------------------------------------------------------------
108 | # Provider identity
109 | # ------------------------------------------------------------------
110 |
111 | def get_provider_type(self) -> ProviderType:
112 | """Identify this provider for restrictions and logging."""
113 | return ProviderType.OPENROUTER
114 |
115 | # ------------------------------------------------------------------
116 | # Registry helpers
117 | # ------------------------------------------------------------------
118 |
119 | def list_models(
120 | self,
121 | *,
122 | respect_restrictions: bool = True,
123 | include_aliases: bool = True,
124 | lowercase: bool = False,
125 | unique: bool = False,
126 | ) -> list[str]:
127 | """Return formatted OpenRouter model names, respecting alias-aware restrictions."""
128 |
129 | if not self._registry:
130 | return []
131 |
132 | from utils.model_restrictions import get_restriction_service
133 |
134 | restriction_service = get_restriction_service() if respect_restrictions else None
135 | allowed_configs: dict[str, ModelCapabilities] = {}
136 |
137 | for model_name in self._registry.list_models():
138 | config = self._registry.resolve(model_name)
139 | if not config:
140 | continue
141 |
142 | # Custom models belong to CustomProvider; skip them here so the two
143 | # providers don't race over the same registrations (important for tests
144 | # that stub the registry with minimal objects lacking attrs).
145 | if config.provider == ProviderType.CUSTOM:
146 | continue
147 |
148 | if restriction_service:
149 | allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
150 |
151 | if not allowed and config.aliases:
152 | for alias in config.aliases:
153 | if restriction_service.is_allowed(self.get_provider_type(), alias):
154 | allowed = True
155 | break
156 |
157 | if not allowed:
158 | continue
159 |
160 | allowed_configs[model_name] = config
161 |
162 | if not allowed_configs:
163 | return []
164 |
165 | # When restrictions are in place, don't include aliases to avoid confusion
166 | # Only return the canonical model names that are actually allowed
167 | actual_include_aliases = include_aliases and not respect_restrictions
168 |
169 | return ModelCapabilities.collect_model_names(
170 | allowed_configs,
171 | include_aliases=actual_include_aliases,
172 | lowercase=lowercase,
173 | unique=unique,
174 | )
175 |
176 | # ------------------------------------------------------------------
177 | # Registry helpers
178 | # ------------------------------------------------------------------
179 |
180 | def _resolve_model_name(self, model_name: str) -> str:
181 | """Resolve aliases defined in the OpenRouter registry."""
182 |
183 | cache_key = model_name.lower()
184 | if cache_key in self._alias_cache:
185 | return self._alias_cache[cache_key]
186 |
187 | config = self._registry.resolve(model_name)
188 | if config:
189 | if config.model_name != model_name:
190 | logging.debug("Resolved model alias '%s' to '%s'", model_name, config.model_name)
191 | resolved = config.model_name
192 | self._alias_cache[cache_key] = resolved
193 | self._alias_cache.setdefault(resolved.lower(), resolved)
194 | return resolved
195 |
196 | logging.debug(f"Model '{model_name}' not found in registry, using as-is")
197 | self._alias_cache[cache_key] = model_name
198 | return model_name
199 |
200 | def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
201 | """Expose registry-backed OpenRouter capabilities."""
202 |
203 | if not self._registry:
204 | return {}
205 |
206 | capabilities: dict[str, ModelCapabilities] = {}
207 | for model_name in self._registry.list_models():
208 | config = self._registry.resolve(model_name)
209 | if not config:
210 | continue
211 |
212 | # See note in list_models: respect the CustomProvider boundary.
213 | if config.provider == ProviderType.CUSTOM:
214 | continue
215 |
216 | capabilities[model_name] = config
217 | return capabilities
218 |
```
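
Because alias handling is delegated to `OpenRouterModelRegistry`, the catalogue can be inspected without instantiating the provider or supplying an API key. A small sketch (the alias string is a placeholder, not a guaranteed entry; pick one from `list_aliases()`):

```python
from providers.registries.openrouter import OpenRouterModelRegistry

registry = OpenRouterModelRegistry()
print(f"{len(registry.list_models())} models, {len(registry.list_aliases())} aliases loaded")

config = registry.resolve("example-alias")  # placeholder alias for illustration
if config:
    print(f"resolves to canonical name: {config.model_name}")
else:
    # OpenRouterProvider only builds generic fallback capabilities for names with a
    # provider prefix such as "vendor/model"; bare unknown names are rejected.
    print("unknown name")
```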
--------------------------------------------------------------------------------
/tests/test_auto_mode_custom_provider_only.py:
--------------------------------------------------------------------------------
```python
1 | """Test auto mode with only custom provider configured to reproduce the reported issue."""
2 |
3 | import importlib
4 | import os
5 | from unittest.mock import patch
6 |
7 | import pytest
8 |
9 | from providers.registry import ModelProviderRegistry
10 | from providers.shared import ProviderType
11 |
12 |
13 | @pytest.mark.no_mock_provider
14 | class TestAutoModeCustomProviderOnly:
15 | """Test auto mode when only custom provider is configured."""
16 |
17 | def setup_method(self):
18 | """Set up clean state before each test."""
19 | # Save original environment state for restoration
20 | self._original_env = {}
21 | for key in [
22 | "GEMINI_API_KEY",
23 | "OPENAI_API_KEY",
24 | "XAI_API_KEY",
25 | "OPENROUTER_API_KEY",
26 | "CUSTOM_API_URL",
27 | "CUSTOM_API_KEY",
28 | "DEFAULT_MODEL",
29 | ]:
30 | self._original_env[key] = os.environ.get(key)
31 |
32 | # Clear restriction service cache
33 | import utils.model_restrictions
34 |
35 | utils.model_restrictions._restriction_service = None
36 |
37 | # Clear provider registry by resetting singleton instance
38 | ModelProviderRegistry._instance = None
39 |
40 | def teardown_method(self):
41 | """Clean up after each test."""
42 | # Restore original environment
43 | for key, value in self._original_env.items():
44 | if value is not None:
45 | os.environ[key] = value
46 | elif key in os.environ:
47 | del os.environ[key]
48 |
49 | # Reload config to pick up the restored environment
50 | import config
51 |
52 | importlib.reload(config)
53 |
54 | # Clear restriction service cache
55 | import utils.model_restrictions
56 |
57 | utils.model_restrictions._restriction_service = None
58 |
59 | # Clear provider registry by resetting singleton instance
60 | ModelProviderRegistry._instance = None
61 |
62 | def test_reproduce_auto_mode_custom_provider_only_issue(self):
63 | """Test the fix for auto mode failing when only custom provider is configured."""
64 |
65 | # Set up environment with ONLY custom provider configured
66 | test_env = {
67 | "CUSTOM_API_URL": "http://localhost:11434/v1",
68 | "CUSTOM_API_KEY": "", # Empty for Ollama-style
69 | "DEFAULT_MODEL": "auto",
70 | }
71 |
72 | # Clear all other provider keys
73 | clear_keys = ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]
74 |
75 | with patch.dict(os.environ, test_env, clear=False):
76 | # Ensure other provider keys are not set
77 | for key in clear_keys:
78 | if key in os.environ:
79 | del os.environ[key]
80 |
81 | # Reload config to pick up auto mode
82 | import config
83 |
84 | importlib.reload(config)
85 |
86 | # Register only the custom provider (simulating server startup)
87 | from providers.custom import CustomProvider
88 |
89 | ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
90 |
91 | # This should now work after the fix
92 | # The fix added support for custom provider registry system in get_available_models()
93 | available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
94 |
95 | # This assertion should now pass after the fix
96 | assert available_models, (
97 | "Expected custom provider models to be available. "
98 | "This test verifies the fix for auto mode failing with custom providers."
99 | )
100 |
101 | def test_custom_provider_models_available_via_registry(self):
102 | """Test that custom provider has models available via its registry system."""
103 |
104 | # Set up environment with only custom provider
105 | test_env = {
106 | "CUSTOM_API_URL": "http://localhost:11434/v1",
107 | "CUSTOM_API_KEY": "",
108 | }
109 |
110 | with patch.dict(os.environ, test_env, clear=False):
111 | # Clear other provider keys
112 | for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
113 | if key in os.environ:
114 | del os.environ[key]
115 |
116 | # Register custom provider
117 | from providers.custom import CustomProvider
118 |
119 | ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
120 |
121 | # Get the provider instance
122 | custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
123 | assert custom_provider is not None, "Custom provider should be available"
124 |
125 | # Verify it has a registry with models
126 | assert hasattr(custom_provider, "_registry"), "Custom provider should have _registry"
127 | assert custom_provider._registry is not None, "Registry should be initialized"
128 |
129 | # Get models from registry
130 | models = custom_provider._registry.list_models()
131 | aliases = custom_provider._registry.list_aliases()
132 |
133 | # Should have some models and aliases available
134 | assert models, "Custom provider registry should have models"
135 | assert aliases, "Custom provider registry should have aliases"
136 |
137 | print(f"Available models: {len(models)}")
138 | print(f"Available aliases: {len(aliases)}")
139 |
140 | def test_custom_provider_validate_model_name(self):
141 | """Test that custom provider can validate model names."""
142 |
143 | # Set up environment with only custom provider
144 | test_env = {
145 | "CUSTOM_API_URL": "http://localhost:11434/v1",
146 | "CUSTOM_API_KEY": "",
147 | }
148 |
149 | with patch.dict(os.environ, test_env, clear=False):
150 | # Register custom provider
151 | from providers.custom import CustomProvider
152 |
153 | ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
154 |
155 | # Get the provider instance
156 | custom_provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
157 | assert custom_provider is not None
158 |
159 | # Test that it can validate some typical custom model names
160 | test_models = ["llama3.2", "llama3.2:latest", "local-model", "ollama-model"]
161 |
162 | for model in test_models:
163 | is_valid = custom_provider.validate_model_name(model)
164 | print(f"Model '{model}' validation: {is_valid}")
165 | # Should validate at least some local-style models
166 | # (The exact validation logic may vary based on registry content)
167 |
168 | def test_auto_mode_fallback_with_custom_only_should_work(self):
169 | """Test that auto mode fallback should work when only custom provider is available."""
170 |
171 | # Set up environment with only custom provider
172 | test_env = {
173 | "CUSTOM_API_URL": "http://localhost:11434/v1",
174 | "CUSTOM_API_KEY": "",
175 | "DEFAULT_MODEL": "auto",
176 | }
177 |
178 | with patch.dict(os.environ, test_env, clear=False):
179 | # Clear other provider keys
180 | for key in ["GEMINI_API_KEY", "OPENAI_API_KEY", "XAI_API_KEY", "OPENROUTER_API_KEY", "DIAL_API_KEY"]:
181 | if key in os.environ:
182 | del os.environ[key]
183 |
184 | # Reload config
185 | import config
186 |
187 | importlib.reload(config)
188 |
189 | # Register custom provider
190 | from providers.custom import CustomProvider
191 |
192 | ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
193 |
194 | # This should work and return a fallback model from custom provider
195 | # Currently fails because get_preferred_fallback_model doesn't consider custom models
196 | from tools.models import ToolModelCategory
197 |
198 | try:
199 | fallback_model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
200 | print(f"Fallback model for FAST_RESPONSE: {fallback_model}")
201 |
202 | # Should get a valid model name, not the hardcoded fallback
203 | assert (
204 | fallback_model != "gemini-2.5-flash"
205 | ), "Should not fallback to hardcoded Gemini model when custom provider is available"
206 |
207 | except Exception as e:
208 | pytest.fail(f"Getting fallback model failed: {e}")
209 |
```
--------------------------------------------------------------------------------
/docs/adding_tools.md:
--------------------------------------------------------------------------------
```markdown
1 | # Adding Tools to Zen MCP Server
2 |
3 | Zen MCP tools are Python classes that inherit from the shared infrastructure in `tools/shared/base_tool.py`.
4 | Every tool must provide a request model (Pydantic), a system prompt, and the methods the base class marks as
5 | abstract. The quickest path to a working tool is to copy an existing implementation that matches your use case
6 | (`tools/chat.py` for simple request/response tools, `tools/consensus.py` or `tools/codereview.py` for workflows).
7 | This document captures the minimal steps required to add a new tool without drifting from the current codebase.
8 |
9 | ## 1. Pick the Tool Architecture
10 |
11 | Zen supports two architectures, implemented in `tools/simple/base.py` and `tools/workflow/base.py`.
12 |
13 | - **SimpleTool** (`SimpleTool`): single MCP call – request comes in, you build one prompt, call the model, return.
14 | The base class handles schema generation, conversation threading, file loading, temperature bounds, retries,
15 | and response formatting hooks.
16 | - **WorkflowTool** (`WorkflowTool`): multi-step workflows driven by `BaseWorkflowMixin`. The tool accumulates
17 | findings across steps, forces Claude to pause between investigations, and optionally calls an expert model at
18 | the end. Use this whenever you need structured multi-step work (debug, code review, consensus, etc.).
19 |
20 | If you are unsure, compare `tools/chat.py` (SimpleTool) and `tools/consensus.py` (WorkflowTool) to see the patterns.
21 |
22 | ## 2. Common Responsibilities
23 |
24 | Regardless of architecture, subclasses of `BaseTool` must provide:
25 |
26 | - `get_name()`: unique string identifier used in the MCP registry.
27 | - `get_description()`: concise, action-oriented summary for clients.
28 | - `get_system_prompt()`: import your prompt from `systemprompts/` and return it.
29 | - `get_input_schema()`: leverage the schema builders (`SchemaBuilder` or `WorkflowSchemaBuilder`) or override to
30 | match an existing contract exactly.
31 | - `get_request_model()`: return the Pydantic model used to validate the incoming arguments.
32 | - `async prepare_prompt(...)`: assemble the content sent to the model. You can reuse helpers like
33 | `prepare_chat_style_prompt` or `build_standard_prompt`.
34 |
35 | The base class already handles model selection (`ToolModelCategory`), conversation memory, token budgeting, safety
36 | failures, retries, and serialization. Override hooks like `get_default_temperature`, `get_model_category`, or
37 | `format_response` only when you need behaviour different from the defaults.
38 |
39 | ## 3. Implementing a Simple Tool
40 |
41 | 1. **Define a request model** that inherits from `tools.shared.base_models.ToolRequest` to describe the fields and
42 | validation rules for your tool.
43 | 2. **Implement the tool class** by inheriting from `SimpleTool` and overriding the required methods. Most tools can
44 | rely on `SchemaBuilder` and the shared field constants already exposed on `SimpleTool`.
45 |
46 | ```python
47 | from pydantic import Field
48 | from systemprompts import CHAT_PROMPT
49 | from tools.shared.base_models import ToolRequest
50 | from tools.simple.base import SimpleTool
51 |
52 | class ChatRequest(ToolRequest):
53 | prompt: str = Field(..., description="Your question or idea.")
54 | absolute_file_paths: list[str] | None = Field(default_factory=list)
55 | working_directory_absolute_path: str = Field(
56 | ...,
57 | description="Absolute path to an existing directory where generated code can be saved.",
58 | )
59 |
60 | class ChatTool(SimpleTool):
61 | def get_name(self) -> str: # required by BaseTool
62 | return "chat"
63 |
64 | def get_description(self) -> str:
65 | return "General chat and collaborative thinking partner."
66 |
67 | def get_system_prompt(self) -> str:
68 | return CHAT_PROMPT
69 |
70 | def get_request_model(self):
71 | return ChatRequest
72 |
73 | def get_tool_fields(self) -> dict[str, dict[str, object]]:
74 | return {
75 | "prompt": {"type": "string", "description": "Your question."},
76 | "absolute_file_paths": SimpleTool.FILES_FIELD,
77 | "working_directory_absolute_path": {
78 | "type": "string",
79 | "description": "Absolute path to an existing directory for generated code artifacts.",
80 | },
81 | }
82 |
83 | def get_required_fields(self) -> list[str]:
84 | return ["prompt", "working_directory_absolute_path"]
85 |
86 | async def prepare_prompt(self, request: ChatRequest) -> str:
87 | return self.prepare_chat_style_prompt(request)
88 | ```
89 |
90 | Only implement `get_input_schema()` manually if you must preserve an existing schema contract (see
91 | `tools/chat.py` for an example). Otherwise `SimpleTool.get_input_schema()` merges your field definitions with the
92 | common parameters (temperature, model, continuation_id, etc.).
93 |
94 | ## 4. Implementing a Workflow Tool
95 |
96 | Workflow tools extend `WorkflowTool`, which mixes in `BaseWorkflowMixin` for step tracking and expert analysis.
97 |
98 | 1. **Create a request model** that inherits from `tools.shared.base_models.WorkflowRequest` (or a subclass) and add
99 | any tool-specific fields or validators. Examples: `CodeReviewRequest`, `ConsensusRequest`.
100 | 2. **Override the workflow hooks** to steer the investigation. At minimum you must implement
101 | `get_required_actions(...)`; override `should_call_expert_analysis(...)` and
102 | `prepare_expert_analysis_context(...)` when the expert model call should happen conditionally.
103 | 3. **Expose the schema** either by returning `WorkflowSchemaBuilder.build_schema(...)` (the default implementation on
104 | `WorkflowTool` already does this) or by overriding `get_input_schema()` if you need custom descriptions/enums.
105 |
106 | ```python
107 | from pydantic import Field
108 | from systemprompts import CONSENSUS_PROMPT
109 | from tools.shared.base_models import WorkflowRequest
110 | from tools.workflow.base import WorkflowTool
111 |
112 | class ConsensusRequest(WorkflowRequest):
113 |     models: list[dict] = Field(..., description="Models to consult (with optional stance).")
114 |
115 | class ConsensusTool(WorkflowTool):
116 |     def get_name(self) -> str:
117 |         return "consensus"
118 |
119 |     def get_description(self) -> str:
120 |         return "Multi-model consensus workflow with expert synthesis."
121 |
122 |     def get_system_prompt(self) -> str:
123 |         return CONSENSUS_PROMPT
124 |
125 |     def get_workflow_request_model(self):
126 |         return ConsensusRequest
127 |
128 |     def get_required_actions(self, step_number: int, confidence: str, findings: str, total_steps: int, request=None) -> list[str]:
129 |         if step_number == 1:
130 |             return ["Write the shared proposal all models will evaluate."]
131 |         return ["Summarize the latest model response before moving on."]
132 |
133 |     def should_call_expert_analysis(self, consolidated_findings, request=None) -> bool:
134 |         return not (request and request.next_step_required)
135 |
136 |     def prepare_expert_analysis_context(self, consolidated_findings) -> str:
137 |         return "\n".join(consolidated_findings.findings)
138 | ```
139 |
140 | `WorkflowTool` already records work history, merges findings, and handles continuation IDs. Use helpers such as
141 | `get_standard_required_actions` when you want default guidance, and override `requires_expert_analysis()` if the tool
142 | never calls out to the assistant model.
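
As an example, a purely mechanical workflow that only tracks steps and findings could opt out of the expert call altogether. The override below is a sketch: `requires_expert_analysis` is the hook named above, while the tool itself is hypothetical and omits the other required methods shown in the `ConsensusTool` example.

```python
from tools.workflow.base import WorkflowTool


class ChecklistTool(WorkflowTool):
    # Sketch only: get_name, get_description, get_system_prompt,
    # get_workflow_request_model and get_required_actions are omitted.

    def requires_expert_analysis(self) -> bool:
        # This workflow never calls out to the assistant model.
        return False
```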
143 |
144 | ## 5. Register the Tool
145 |
146 | 1. **Create or reuse a system prompt** in `systemprompts/your_tool_prompt.py` and export it from
147 | `systemprompts/__init__.py`.
148 | 2. **Expose the tool class** from `tools/__init__.py` so that `server.py` can import it.
149 | 3. **Add an instance to the `TOOLS` dictionary** in `server.py` so the tool becomes callable via MCP (see the sketch after this list).
150 | 4. **(Optional) Add a prompt template** to `PROMPT_TEMPLATES` in `server.py` if you want clients to show a canned
151 | launch command.
152 | 5. Confirm that the `DISABLED_TOOLS` environment variable handling covers the new tool if you need to toggle it.
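
For steps 2 to 4, the `server.py` side might look roughly like the sketch below. Treat it as illustrative only: the exact shape of `TOOLS` and `PROMPT_TEMPLATES` can differ between versions, so mirror an existing entry rather than copying this literally.

```python
# server.py (sketch): step 3 from the list above.
from tools import ConsensusTool

TOOLS = {
    # ...existing tool instances, keyed by each tool's get_name() value...
    "consensus": ConsensusTool(),
}

# Step 4 (optional): add a matching "consensus" entry to PROMPT_TEMPLATES,
# mirroring the structure of the entries already defined in server.py.
```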
153 |
154 | ## 6. Validate the Tool
155 |
156 | - Run unit tests that cover any new request/response logic (a request-model test sketch follows this list): `python -m pytest tests/ -v -m "not integration"`.
157 | - Add a simulator scenario under `simulator_tests/` to exercise the tool end-to-end, then run it with
158 | `python communication_simulator_test.py --individual <case>` or `--quick` for the fast smoke suite.
159 | - If the tool interacts with external providers or multiple models, consider integration coverage via
160 | `./run_integration_tests.sh --with-simulator`.
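
As a starting point for the unit-test bullet above, a request-model test might look like the sketch below. It assumes the `ChatRequest` example from section 3 and a hypothetical import path; adjust both to your actual tool.

```python
# tests/test_chat_request_validation.py (sketch)
import pytest
from pydantic import ValidationError

from tools.chat import ChatRequest  # hypothetical path -- point at your request model


def test_missing_working_directory_is_rejected():
    with pytest.raises(ValidationError):
        ChatRequest(prompt="hello")


def test_minimal_valid_payload():
    request = ChatRequest(prompt="hello", working_directory_absolute_path="/tmp")
    assert request.absolute_file_paths == []
```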
161 |
162 | Following the steps above keeps new tools aligned with the existing infrastructure and avoids drift between the
163 | documentation and the actual base classes.
164 |
```