This is page 1 of 2. Use http://codebase.md/rusiaaman/wcgw?page={x} to view the full context.
# Directory Structure
```
├── .github
│ └── workflows
│ ├── python-publish.yml
│ ├── python-tests.yml
│ └── python-types.yml
├── .gitignore
├── .gitmodules
├── .python-version
├── .vscode
│ └── settings.json
├── CLAUDE.md
├── Dockerfile
├── LICENSE
├── pyproject.toml
├── README.md
├── src
│ ├── wcgw
│ │ ├── __init__.py
│ │ ├── client
│ │ │ ├── __init__.py
│ │ │ ├── bash_state
│ │ │ │ ├── bash_state.py
│ │ │ │ └── parser
│ │ │ │ ├── __init__.py
│ │ │ │ └── bash_statement_parser.py
│ │ │ ├── common.py
│ │ │ ├── diff-instructions.txt
│ │ │ ├── encoder
│ │ │ │ └── __init__.py
│ │ │ ├── file_ops
│ │ │ │ ├── diff_edit.py
│ │ │ │ ├── extensions.py
│ │ │ │ └── search_replace.py
│ │ │ ├── mcp_server
│ │ │ │ ├── __init__.py
│ │ │ │ ├── Readme.md
│ │ │ │ └── server.py
│ │ │ ├── memory.py
│ │ │ ├── modes.py
│ │ │ ├── repo_ops
│ │ │ │ ├── display_tree.py
│ │ │ │ ├── file_stats.py
│ │ │ │ ├── path_prob.py
│ │ │ │ ├── paths_model.vocab
│ │ │ │ ├── paths_tokens.model
│ │ │ │ └── repo_context.py
│ │ │ ├── schema_generator.py
│ │ │ ├── tool_prompts.py
│ │ │ └── tools.py
│ │ ├── py.typed
│ │ └── types_.py
│ └── wcgw_cli
│ ├── __init__.py
│ ├── __main__.py
│ ├── anthropic_client.py
│ ├── cli.py
│ ├── openai_client.py
│ └── openai_utils.py
├── static
│ ├── claude-ss.jpg
│ ├── computer-use.jpg
│ ├── example.jpg
│ ├── rocket-icon.png
│ ├── ss1.png
│ └── workflow-demo.gif
├── tests
│ ├── test_bash_parser_complex.py
│ ├── test_bash_parser.py
│ ├── test_bg_commands.py
│ ├── test_edit.py
│ ├── test_file_range_tracking.py
│ ├── test_mcp_server.py
│ ├── test_readfiles.py
│ └── test_tools.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
```
```
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
```
3.12
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
.venv
.env
.wcgw
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
```
--------------------------------------------------------------------------------
/src/wcgw/client/mcp_server/Readme.md:
--------------------------------------------------------------------------------
```markdown
# The doc has moved to main Readme.md

```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
# Shell and Coding agent for Claude and other mcp clients
Empowering chat applications to code, build and run on your local machine.
wcgw is an MCP server with tightly integrated shell and code editing tools.
⚠️ Warning: do not allow the BashCommand tool to run without reviewing the command; it may result in data loss.
[Tests](https://github.com/rusiaaman/wcgw/actions/workflows/python-tests.yml)
[Mypy strict](https://github.com/rusiaaman/wcgw/actions/workflows/python-types.yml)
[Build](https://github.com/rusiaaman/wcgw/actions/workflows/python-publish.yml)
[Coverage](https://codecov.io/gh/rusiaaman/wcgw)
## Demo

## Updates
- [6 Oct 2025] Model can now run multiple commands in background. ZSH is now a supported shell. Multiplexing improvements.
- [27 Apr 2025] Removed support for GPTs over relay server. Only MCP server is supported in version >= 5.
- [24 Mar 2025] Improved writing and editing experience for sonnet 3.7, CLAUDE.md gets loaded automatically.
- [16 Feb 2025] You can now attach to the working terminal that the AI uses. See the "attach-to-terminal" section below.
- [15 Jan 2025] Modes introduced: architect, code-writer, and all powerful wcgw mode.
- [8 Jan 2025] Context saving tool for saving relevant file paths along with a description in a single file. Can be used as a task checkpoint or for knowledge transfer.
- [29 Dec 2024] Syntax checking on file writing and edits is now stable. Made `initialize` tool call useful; sending smart repo structure to claude if any repo is referenced. Large file handling is also now improved.
- [9 Dec 2024] [Vscode extension to paste context on Claude app](https://marketplace.visualstudio.com/items?itemName=AmanRusia.wcgw)
## 🚀 Highlights
- ⚡ **Create, Execute, Iterate**: Ask claude to keep running compiler checks till all errors are fixed, or ask it to keep checking for the status of a long running command till it's done.
- ⚡ **Large file edit**: Supports large file incremental edits to avoid token limit issues. Smartly selects when to do small edits or large rewrite based on % of change needed.
- ⚡ **Syntax checking on edits**: Reports feedback to the LLM if its edits have any syntax errors, so that it can redo it.
- ⚡ **Interactive Command Handling**: Supports interactive commands using arrow keys, interrupt, and ansi escape sequences.
- ⚡ **File protections**:
- The AI needs to read a file at least once before it's allowed to edit or rewrite it. This avoids accidental overwrites.
- Avoids context filling up while reading very large files. Files get chunked based on token length.
- On initialisation the provided workspace's directory structure is returned after selecting important files (based on .gitignore as well as a statistical approach)
- File edits based on search-replace try to locate the correct search block when it has multiple matches, using the previously applied search blocks; otherwise the edit fails (for correctness).
- File edits use whitespace-tolerant matching, with warnings on issues like indentation mismatch. If there's no match, the closest match is returned to the AI so it can fix its mistake.
- Using Aider-like search and replace, which has better performance than tool call based search and replace.
- ⚡ **Shell optimizations**:
- Current working directory is always returned after any shell command to prevent AI from getting lost.
- Command polling exits after a quick timeout to avoid slow feedback. However, status checking has wait tolerance based on fresh output streaming from a command. These two approaches combined provide a good shell interaction experience.
- Supports multiple concurrent background commands alongside the main interactive shell.
- ⚡ **Saving repo context in a single file**: Task checkpointing using "ContextSave" tool saves detailed context in a single file. Tasks can later be resumed in a new chat asking "Resume `task id`". The saved file can be used to do other kinds of knowledge transfer, such as taking help from another AI.
- ⚡ **Easily switch between various modes**:
- Ask it to run in 'architect' mode for planning. Inspired by Aider's architect mode, work with Claude to come up with a plan first. This leads to better accuracy and prevents premature file editing.
- Ask it to run in 'code-writer' mode for code editing and project building. You can provide specific paths with wildcard support to prevent other files from getting edited.
- By default it runs in 'wcgw' mode that has no restrictions and full authorisation.
- More details in [Modes section](#modes)
- ⚡ **Runs in a multiplexed terminal**: Use the [vscode extension](https://marketplace.visualstudio.com/items?itemName=AmanRusia.wcgw) or run `screen -x` to attach to the terminal the AI runs commands on. See the history, interrupt a process, or interact with the same terminal the AI uses.
- ⚡ **Automatically loads CLAUDE.md/AGENTS.md**: Loads the "CLAUDE.md" or "AGENTS.md" file in the project root and sends it as instructions during initialisation. Instructions in a global "~/.wcgw/CLAUDE.md" or "~/.wcgw/AGENTS.md" file are also loaded and added along with the project-specific CLAUDE.md. The file name is case sensitive. CLAUDE.md is attached if present; otherwise AGENTS.md is attached.
## Top use cases examples
- Solve problem X using python, create and run test cases and fix any issues. Do it in a temporary directory
- Find instances of code with X behavior in my repository
- Git clone https://github.com/my/repo in my home directory, then understand the project, set up the environment and build
- Create a golang htmx tailwind webapp, then open browser to see if it works (use with puppeteer mcp)
- Edit or update a large file
- In a separate branch create feature Y, then use github cli to create a PR to original branch
- Command X is failing in Y directory, please run and fix issues
- Using X virtual environment run Y command
- Using CLI tools, create, build and test an Android app. Finally run it in the emulator for me to use
- Fix all mypy issues in my repo at X path.
- Using 'screen' run my server in background instead, then run another api server in bg, finally run the frontend build. Keep checking logs for any issues in all three
- Create repo wide unittest cases. Keep iterating through files and creating cases. Also keep running the tests after each update. Do not modify original code.
## Claude setup (using mcp)
### Mac and linux
First install `uv` using homebrew `brew install uv`
(**Important:** use homebrew to install uv. Otherwise make sure `uv` is present in a global location like /usr/bin/)
Then create or update `claude_desktop_config.json` (~/Library/Application Support/Claude/claude_desktop_config.json) with the following JSON.
```json
{
"mcpServers": {
"wcgw": {
"command": "uvx",
"args": ["wcgw@latest"]
}
}
}
```
Then restart the Claude app.
**Optional: Force a specific shell**
To use a specific shell (bash or zsh), add the `--shell` argument:
```json
{
"mcpServers": {
"wcgw": {
"command": "uvx",
"args": ["wcgw@latest", "--shell", "/bin/bash"]
}
}
}
```
_If there's an error in setting up_
- If there's an error like "uv ENOENT", make sure `uv` is installed. Then run 'which uv' in the terminal, and use its output in place of "uv" in the configuration.
- If there's still an issue, check that `uv tool run --python 3.12 wcgw` runs in your terminal. It should have no output and shouldn't exit.
- Try removing ~/.cache/uv folder
- Try using `uv` version `0.6.0` for which this tool was tested.
- Debug the mcp server using `npx @modelcontextprotocol/[email protected] uv tool run --python 3.12 wcgw`
### Windows on wsl
This MCP server works only under WSL on Windows.
To set it up, [install uv](https://docs.astral.sh/uv/getting-started/installation/)
Then add or update the claude config file `%APPDATA%\Claude\claude_desktop_config.json` with the following
```json
{
"mcpServers": {
"wcgw": {
"command": "wsl.exe",
"args": ["uvx", "wcgw@latest"]
}
}
}
```
If you encounter an error, execute `wsl uv tool run --python 3.12 wcgw` in the command prompt. If you get the error `/bin/bash: line 1: uv: command not found`, it means uv was not installed globally and you need to point to the correct path of uv.
1. Find where uv is installed:
```bash
whereis uv
```
Example output:
```uv: /home/mywsl/.local/bin/uv```
2. Test the full path works:
```
wsl /home/mywsl/.local/bin/uv tool run --python 3.12 wcgw
```
3. Update the config with the full path:
```
{
"mcpServers": {
"wcgw": {
"command": "wsl.exe",
"args": ["/home/mywsl/.local/bin/uv", "tool", "run", "--python", "3.12", "wcgw"]
}
}
}
```
Replace `/home/mywsl/.local/bin/uv` with your actual uv path from step 1.
### Usage
Wait for a few seconds. You should be able to see this icon if everything goes right.

over here

Then ask claude to execute shell commands, read files, edit files, run your code, etc.
#### Task checkpoint or knowledge transfer
- You can do a task checkpoint or a knowledge transfer by attaching "KnowledgeTransfer" prompt using "Attach from MCP" button.
- On running "KnowledgeTransfer" prompt, the "ContextSave" tool will be called saving the task description and all file content together in a single file. An id for the task will be generated.
- In a new chat you can say "Resume '<task id>'"; the AI should then call "Initialize" with the task id and load the context from there.
- Or you can directly open the file generated and share it with another AI for help.
#### Modes
There are three built-in modes. You may ask Claude to run in one of the modes, like "Use 'architect' mode"
| **Mode** | **Description** | **Allows** | **Denies** | **Invoke prompt** |
|-----------------|-----------------------------------------------------------------------------|---------------------------------------------------------|----------------------------------------------|----------------------------------------------------------------------------------------------------|
| **Architect** | Designed for you to work with Claude to investigate and understand your repo. | Read-only commands | FileEdit and Write tool | Run in mode='architect' |
| **Code-writer** | For code writing and development | Specified path globs for editing or writing, specified commands | FileEdit for paths not matching specified glob, Write for paths not matching specified glob | Run in code writer mode, only 'tests/**' allowed, only uv command allowed |
| **wcgw** | Default mode with everything allowed | Everything | Nothing | No prompt, or "Run in wcgw mode" |
Note: in code-writer mode either all commands are allowed or none are allowed for now. If you give a list of allowed commands, Claude is instructed to run only those commands, but no actual check happens. (WIP)
#### Attach to the working terminal to investigate
NEW: the [vscode extension](https://marketplace.visualstudio.com/items?itemName=AmanRusia.wcgw) now automatically attaches to the running terminal
if the workspace path matches.
If you have the `screen` command installed, wcgw automatically runs inside a screen session. Once the wcgw MCP server has started, you can list the screen sessions:
`screen -ls`
And note down the wcgw screen name which will be something like `93358.wcgw.235521` where the last number is in the hour-minute-second format.
You can then attach to the session using `screen -x 93358.wcgw.235521`
You may interrupt any running command safely.
You can interact with the terminal safely, for example to enter passwords or other text. (Warning: if you run a new command yourself, any new LLM command will interrupt it.)
You shouldn't exit the session using `exit` or Ctrl-D; instead, use `ctrl+a+d` to safely detach without destroying the screen session.
Include the following in ~/.screenrc for a better scrolling experience
```
defscrollback 10000
termcapinfo xterm* ti@:te@
```
### [Optional] VS Code extension
https://marketplace.visualstudio.com/items?itemName=AmanRusia.wcgw
Commands:
- Select a text and press `cmd+'` and then enter instructions. This will switch the app to Claude and paste a text containing your instructions, file path, workspace dir, and the selected text.
## Examples

## Using mcp server over docker
First build the docker image `docker build -t wcgw https://github.com/rusiaaman/wcgw.git`
Then you can update `/Users/username/Library/Application Support/Claude/claude_desktop_config.json` to have
```
{
"mcpServers": {
"wcgw": {
"command": "docker",
"args": [
"run",
"-i",
"--rm",
"--mount",
"type=bind,src=/Users/username/Desktop,dst=/workspace/Desktop",
"wcgw"
]
}
}
}
```
## [Optional] Local shell access with openai API key or anthropic API key
### Openai
Add `OPENAI_API_KEY` and `OPENAI_ORG_ID` env variables.
Then run
`uvx wcgw wcgw_local --limit 0.1` # Cost limit $0.1
You can now directly write messages or press enter key to open vim for multiline message and text pasting.
### Anthropic
Add `ANTHROPIC_API_KEY` env variable.
Then run
`uvx wcgw wcgw_local --claude`
You can now directly write messages or press enter key to open vim for multiline message and text pasting.
## Tools
The server provides the following MCP tools:
**Shell Operations:**
- `Initialize`: Reset shell and set up workspace environment
- Parameters: `any_workspace_path` (string), `initial_files_to_read` (string[]), `mode_name` ("wcgw"|"architect"|"code_writer"), `task_id_to_resume` (string)
- `BashCommand`: Execute shell commands with timeout control
- Parameters: `command` (string), `wait_for_seconds` (int, optional)
- Alternatively: `send_text` (string) or `send_specials` (["Enter"|"Key-up"|...]) or `send_ascii` (int[]), `wait_for_seconds` (int, optional), for interacting with an already running program
**File Operations:**
- `ReadFiles`: Read content from one or more files
- Parameters: `file_paths` (string[])
- `WriteIfEmpty`: Create new files or write to empty files
- Parameters: `file_path` (string), `file_content` (string)
- `FileEdit`: Edit existing files using search/replace blocks
- Parameters: `file_path` (string), `file_edit_using_search_replace_blocks` (string)
- `ReadImage`: Read image files for display/processing
- Parameters: `file_path` (string)
**Project Management:**
- `ContextSave`: Save project context and files for Knowledge Transfer or saving task checkpoints to be resumed later
- Parameters: `id` (string), `project_root_path` (string), `description` (string), `relevant_file_globs` (string[])
All tools support absolute paths and include built-in protections against common errors. See the [MCP specification](https://modelcontextprotocol.io/) for detailed protocol information.
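For illustration, here is a sketch of the arguments a client could pass when calling the `Initialize` tool, using only the parameters listed above (the values themselves are placeholders, not a definitive request):
```python
initialize_args = {
    "any_workspace_path": "/Users/username/projects/myrepo",  # workspace to operate in
    "initial_files_to_read": ["README.md"],                   # files to load into context up front
    "mode_name": "wcgw",                                      # or "architect" / "code_writer"
    "task_id_to_resume": "",                                  # set to a saved task id to resume it
}
```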
```
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
```markdown
# Alignment instructions to contribute to this repository
## Hard rules
- Make sure mypy --strict passes for these two folders `uv run mypy --strict src/wcgw src/wcgw_cli`.
- Use `list` directly for typing like `list[str]` no need to import `List`. Same thing for `tuple`, `set`, etc.
- Do not give function parameters optional default values. All arguments must be passed explicitly by the caller.
- This library uses `uv` as package manager. To add a package `uv add numpy`. To run pytest `uv run pytest` and so on.
## Coding mantras
### Reduce states and dependencies between the states
- Don't introduce any state unless really necessary.
- If anything can be derived, avoid storing it or passing it.
#### Python `Exception` guideline 1
- Exceptions thrown inside functions are hidden extra state and should be avoided.
- Parse, don't validate: avoid throwing validation errors by letting the types prevent bad values from being passed in the first place.
### Put burden on type checker not the code reader
- No hidden contracts and assumptions.
- Don't assume any relationship between two states unless it's encoded in the type of the state.
- Any contract should be enforced by the way types are constructed.
- If it's just not possible due to complexity to type in such a way to avoid hidden contract, add in docstring details.
#### Python `Exception` guideline 2
- When you can't avoid it, instead of enforcing the hidden contract as a hard failure at runtime, try to return some sensible value instead.
_Example_
In PIL, drawing boxes outside the image bounds does nothing, but it doesn't fail either, making edge cases cleaner to deal with.
- A function's signature (along with its types) should be enough to understand its purpose.
- This can be achieved by typing parameters to accept only narrow types.
### Functions should be as pure as possible
- Avoid mutating mutable input parameters; instead, return newly derived values and leave it up to the caller to update the state if required.
- It should be clear from the function signature what the function computes; this also reinforces the previous point about not mutating input parameters. A minimal sketch follows below.
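An illustrative sketch of the pure-function mantra (not code from this repository):
```python
def add_label(labels: list[str], new_label: str) -> list[str]:
    """Pure: returns a new list instead of mutating the caller's `labels`."""
    return [*labels, new_label]

def add_label_impure(labels: list[str], new_label: str) -> None:
    """Impure: mutates the input and returns nothing, hiding the effect from the signature."""
    labels.append(new_label)
```
The first signature makes the computed value explicit; the caller decides whether to keep or replace the old list.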
```
--------------------------------------------------------------------------------
/src/wcgw/client/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/wcgw_cli/__init__.py:
--------------------------------------------------------------------------------
```python
from .cli import app
```
--------------------------------------------------------------------------------
/src/wcgw_cli/__main__.py:
--------------------------------------------------------------------------------
```python
from .cli import app
app()
```
--------------------------------------------------------------------------------
/src/wcgw/__init__.py:
--------------------------------------------------------------------------------
```python
from .client.mcp_server import main as mcp_server
# Export mcp_server as the default entry point for wcgw
listen = mcp_server
```
--------------------------------------------------------------------------------
/src/wcgw/client/bash_state/parser/__init__.py:
--------------------------------------------------------------------------------
```python
"""
Parser for bash statements using tree-sitter.
This module provides functionality to parse and identify individual bash statements.
"""
from .bash_statement_parser import BashStatementParser, Statement
```
--------------------------------------------------------------------------------
/.github/workflows/python-types.yml:
--------------------------------------------------------------------------------
```yaml
name: Mypy strict
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
typecheck:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "${{ matrix.python-version }}"
- name: Install dependencies
run: |
pip install uv
- name: Run type checks
run: |
uv run --frozen --group types --python "${{ matrix.python-version }}" mypy --strict src/wcgw
```
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
```json
{
"python.defaultInterpreterPath": ".venv/bin/python",
"mypy-type-checker.interpreter": [
".venv/bin/python"
],
"mypy.extraArguments": [
"--enable-incomplete-feature=NewGenericSyntax",
"--strict"
],
"mypy-type-checker.args": [
"--enable-incomplete-feature=NewGenericSyntax",
"--strict",
],
"editor.formatOnSave": true,
"editor.renderWhitespace": "selection",
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.codeActionsOnSave": {
"source.fixAll": "always",
"source.organizeImports": "always",
}
},
"python-envs.pythonProjects": []
}
```
--------------------------------------------------------------------------------
/src/wcgw/client/mcp_server/__init__.py:
--------------------------------------------------------------------------------
```python
# mypy: disable-error-code="import-untyped"
import asyncio
import importlib.metadata
import typer
from typer import Typer
from wcgw.client.mcp_server import server
main = Typer()
@main.command()
def app(
version: bool = typer.Option(
False, "--version", "-v", help="Show version and exit"
),
shell: str = typer.Option(
"", "--shell", help="Path to shell executable (defaults to $SHELL or /bin/bash)"
),
) -> None:
"""Main entry point for the package."""
if version:
version_ = importlib.metadata.version("wcgw")
print(f"wcgw version: {version_}")
raise typer.Exit()
asyncio.run(server.main(shell))
# Optionally expose other important items at package level
__all__ = ["main", "server"]
```
--------------------------------------------------------------------------------
/.github/workflows/python-tests.yml:
--------------------------------------------------------------------------------
```yaml
name: Python Test
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "${{ matrix.python-version }}"
- name: Install dependencies
run: |
pip install uv
- name: Run tests with coverage
run: |
uv run --frozen --group tests --python "${{ matrix.python-version }}" pytest --cov=wcgw --cov-report=xml --cov-report=term-missing
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4
if: success()
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
```
--------------------------------------------------------------------------------
/src/wcgw_cli/cli.py:
--------------------------------------------------------------------------------
```python
import importlib.metadata
from typing import Optional
import typer
from typer import Typer
from wcgw_cli.anthropic_client import loop as claude_loop
from wcgw_cli.openai_client import loop as openai_loop
app = Typer(pretty_exceptions_show_locals=False)
@app.command()
def loop(
claude: bool = False,
first_message: Optional[str] = None,
limit: Optional[float] = None,
resume: Optional[str] = None,
version: bool = typer.Option(False, "--version", "-v"),
) -> tuple[str, float]:
if version:
version_ = importlib.metadata.version("wcgw")
print(f"wcgw version: {version_}")
exit()
if claude:
return claude_loop(
first_message=first_message,
limit=limit,
resume=resume,
)
else:
return openai_loop(
first_message=first_message,
limit=limit,
resume=resume,
)
if __name__ == "__main__":
app()
```
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
```yaml
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
```
--------------------------------------------------------------------------------
/src/wcgw/client/common.py:
--------------------------------------------------------------------------------
```python
import select
import sys
import termios
import tty
from typing import Literal
from pydantic import BaseModel
class CostData(BaseModel):
cost_per_1m_input_tokens: float
cost_per_1m_output_tokens: float
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ParsedChatCompletionMessage,
)
History = list[ChatCompletionMessageParam]
Models = Literal["gpt-4o-2024-08-06", "gpt-4o-mini"]
def discard_input() -> None:
try:
# Get the file descriptor for stdin
fd = sys.stdin.fileno()
# Save current terminal settings
old_settings = termios.tcgetattr(fd)
try:
# Switch terminal to non-canonical mode where input is read immediately
tty.setcbreak(fd)
# Discard all input
while True:
# Check if there is input to be read
if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
sys.stdin.read(
1
) # Read one character at a time to flush the input buffer
else:
break
finally:
# Restore old terminal settings
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
except (termios.error, ValueError) as e:
# Handle the error gracefully
print(f"Warning: Unable to discard input. Error: {e}")
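# Illustrative usage sketch (assumed call site, not from this module): flush stray
# keypresses typed while the model was responding so they don't leak into the next prompt.
# discard_input()
# user_message = input("> ")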
```
--------------------------------------------------------------------------------
/src/wcgw/client/encoder/__init__.py:
--------------------------------------------------------------------------------
```python
import threading
from typing import Callable, Protocol, TypeVar, cast
import tokenizers # type: ignore[import-untyped]
T = TypeVar("T")
class EncoderDecoder(Protocol[T]):
def encoder(self, text: str) -> list[T]: ...
def decoder(self, tokens: list[T]) -> str: ...
class LazyEncoder:
def __init__(self) -> None:
self._tokenizer: tokenizers.Tokenizer | None = None
self._init_lock = threading.Lock()
self._init_thread = threading.Thread(target=self._initialize, daemon=True)
self._init_thread.start()
def _initialize(self) -> None:
with self._init_lock:
if self._tokenizer is None:
self._tokenizer = tokenizers.Tokenizer.from_pretrained(
"Xenova/claude-tokenizer"
)
def _ensure_initialized(self) -> None:
if self._tokenizer is None:
with self._init_lock:
if self._tokenizer is None:
self._init_thread.join()
def encoder(self, text: str) -> list[int]:
self._ensure_initialized()
assert self._tokenizer is not None, "Couldn't initialize tokenizer"
return cast(list[int], self._tokenizer.encode(text).ids)
def decoder(self, tokens: list[int]) -> str:
self._ensure_initialized()
assert self._tokenizer is not None, "Couldn't initialize tokenizer"
return cast(str, self._tokenizer.decode(tokens))
def get_default_encoder() -> EncoderDecoder[int]:
return LazyEncoder()
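# Illustrative usage sketch (not part of the original module):
# enc = get_default_encoder()
# token_ids = enc.encoder("def hello() -> None: ...")
# text = enc.decoder(token_ids)
# The tokenizer is fetched lazily in a background daemon thread on first construction.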
```
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
```dockerfile
# Generated by https://smithery.ai. See: https://smithery.ai/docs/config#dockerfile
# Start from a Python base image with the necessary tools
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim AS uv
# Set the working directory in the container
WORKDIR /app
# Copy the project's pyproject.toml and lock file for dependency installation
COPY pyproject.toml /app/
COPY uv.lock /app/
COPY README.md /app/
# Enable bytecode compilation and set link mode to copy for dependencies
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
# No need for git as we don't need to clone submodules anymore
RUN apt-get update && rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --no-install-project --no-dev --no-editable
# Copy the entire project into the container
COPY src /app/src
# No need to clone the submodule as it has been removed
# Install the project
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --no-dev --no-editable
# Use a smaller image to run the application
FROM python:3.12-slim-bookworm
RUN apt-get update && apt-get install -y screen && rm -rf /var/lib/apt/lists/*
# Create app user and group
RUN groupadd -r app && useradd -r -g app app
# Set the working directory in the container
WORKDIR /workspace
# Copy the installed application from the previous stage
COPY --from=uv --chown=app:app /app/.venv /app/.venv
# No need to copy the submodule as it has been removed
# Add the virtual environment to the PATH
ENV PATH="/app/.venv/bin:$PATH"
# Specify the command to run on container start
ENTRYPOINT ["wcgw_mcp"]
```
--------------------------------------------------------------------------------
/src/wcgw/client/diff-instructions.txt:
--------------------------------------------------------------------------------
```
Instructions for editing files.
# Example
## Input file
```
import numpy as np
from impls import impl1, impl2
def hello():
"print a greeting"
print("hello")
def call_hello():
"call hello"
hello()
print("Called")
impl1()
hello()
impl2()
```
## Edit format on the input file
```
<<<<<<< SEARCH
from impls import impl1, impl2
=======
from impls import impl1, impl2
from hello import hello as hello_renamed
>>>>>>> REPLACE
<<<<<<< SEARCH
def hello():
"print a greeting"
print("hello")
=======
>>>>>>> REPLACE
<<<<<<< SEARCH
def call_hello():
"call hello"
hello()
=======
def call_hello_renamed():
"call hello renamed"
hello_renamed()
>>>>>>> REPLACE
<<<<<<< SEARCH
impl1()
hello()
impl2()
=======
impl1()
hello_renamed()
impl2()
>>>>>>> REPLACE
```
# *SEARCH/REPLACE block* Rules:
Every "<<<<<<< SEARCH" section must *EXACTLY MATCH* the existing file content, character for character, including all comments, docstrings, whitespaces, etc.
Including multiple unique *SEARCH/REPLACE* blocks if needed.
Include enough and only enough lines in each SEARCH section to uniquely match each set of lines that need to change.
Keep *SEARCH/REPLACE* blocks concise.
Break large *SEARCH/REPLACE* blocks into a series of smaller blocks that each change a small portion of the file.
Include just the changing lines, and a few surrounding lines (0-3 lines) if needed for uniqueness.
Other than for uniqueness, avoid including those lines which do not change in search (and replace) blocks. Target 0-3 non trivial extra lines per block.
Preserve leading spaces and indentations in both SEARCH and REPLACE blocks.
```
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
```toml
[project]
authors = [{ name = "Aman Rusia", email = "[email protected]" }]
name = "wcgw"
version = "5.5.1"
description = "Shell and coding agent for Claude and other mcp clients"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"openai>=1.46.0",
"typer>=0.12.5",
"rich>=13.8.1",
"python-dotenv>=1.0.1",
"pexpect>=4.9.0",
"toml>=0.10.2",
"petname>=2.6",
"pyte>=0.8.2",
"fastapi>=0.115.0",
"uvicorn>=0.31.0",
"websockets>=13.1",
"pydantic>=2.9.2",
"semantic-version>=2.10.0",
"anthropic>=0.39.0",
"tokenizers>=0.21.0",
"pygit2>=1.16.0",
"syntax-checker==0.4.0",
"psutil>=7.0.0",
"tree-sitter>=0.24.0",
"tree-sitter-bash>=0.23.3",
"mcp>=1.7.0",
"wcmatch>=10.1",
]
[project.urls]
Homepage = "https://github.com/rusiaaman/wcgw"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/wcgw", "src/wcgw_cli"]
[tool.hatch.build.targets.wheel.sources]
"src/wcgw" = "wcgw"
"src/wcgw_cli" = "wcgw_cli"
[project.scripts]
wcgw_local = "wcgw_cli:app"
wcgw = "wcgw:mcp_server"
wcgw_mcp = "wcgw:mcp_server"
[tool.uv]
default-groups = []
[dependency-groups]
types = [
"mypy>=1.11.2",
"types-toml>=0.10.8.20240310",
"types-pexpect>=4.9.0.20241208",
"types-psutil>=7.0.0.20250218",
"line-profiler>=4.2.0",
]
tests = [
"pytest>=8.0.0",
"pytest-cov>=4.1.0",
"pytest-asyncio>=0.25.3",
]
dev = [
"autoflake",
"ipython>=8.12.3",
"gunicorn>=23.0.0",
"line-profiler>=4.2.0",
]
[tool.pytest.ini_options]
addopts = "--cov=wcgw --cov-report=term-missing --cov-report=html"
testpaths = ["tests"]
```
--------------------------------------------------------------------------------
/src/wcgw/client/repo_ops/path_prob.py:
--------------------------------------------------------------------------------
```python
from typing import Dict, List, Tuple
import tokenizers # type: ignore[import-untyped]
class FastPathAnalyzer:
def __init__(self, model_path: str, vocab_path: str) -> None:
"""Initialize with vocabulary."""
# Load vocabulary and probabilities
self.vocab_probs: Dict[str, float] = {}
with open(vocab_path, "r") as f:
for line in f:
parts = line.strip().split()
if len(parts) == 2:
token, prob = parts
try:
self.vocab_probs[token] = float(prob)
except ValueError:
continue
self.encoder = tokenizers.Tokenizer.from_file(model_path)
def tokenize_batch(self, texts: List[str]) -> List[List[str]]:
"""Tokenize multiple texts at once."""
encodings = self.encoder.encode_batch(texts)
return [encoding.tokens for encoding in encodings]
def detokenize(self, tokens: List[str]) -> str:
"""Convert tokens back to text, handling special tokens."""
return self.encoder.decode(tokens) # type: ignore[no-any-return]
def calculate_path_probabilities_batch(
self, paths: List[str]
) -> List[Tuple[float, List[str], List[str]]]:
"""Calculate log probability for multiple paths at once."""
# Batch tokenize all paths
all_tokens = self.tokenize_batch(paths)
results = []
for tokens in all_tokens:
# Calculate sum of log probabilities for each path
log_prob_sum = 0.0
unknown_tokens = []
for token in tokens:
if token in self.vocab_probs:
log_prob_sum += self.vocab_probs[token]
else:
unknown_tokens.append(token)
results.append((log_prob_sum, tokens, unknown_tokens))
return results
def calculate_path_probability(
self, path: str
) -> Tuple[float, List[str], List[str]]:
"""Calculate log probability for a single path."""
return self.calculate_path_probabilities_batch([path])[0]
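# Illustrative usage sketch (the file names match the model/vocab files shipped under repo_ops):
# analyzer = FastPathAnalyzer("paths_tokens.model", "paths_model.vocab")
# log_prob, tokens, unknown = analyzer.calculate_path_probability("src/wcgw/client/tools.py")
# A higher (less negative) log_prob means the path looks more typical under the vocabulary;
# tokens missing from the vocabulary are returned separately in `unknown`.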
```
--------------------------------------------------------------------------------
/src/wcgw/client/schema_generator.py:
--------------------------------------------------------------------------------
```python
"""
Custom JSON schema generator to remove title fields from Pydantic models.
This module provides utilities to remove auto-generated title fields from JSON schemas,
making them more suitable for tool schemas where titles are not needed.
"""
import copy
from typing import Any, Dict
def recursive_purge_dict_key(d: Dict[str, Any], k: str) -> None:
"""
Remove a key from a dictionary recursively, but only from JSON schema metadata.
This function removes the specified key from dictionaries that appear to be
JSON schema objects (have "type" or "$ref" or are property definitions).
This prevents removing legitimate data fields that happen to have the same name.
Args:
d: The dictionary to clean
k: The key to remove (typically "title")
"""
if isinstance(d, dict):
# Only remove the key if this looks like a JSON schema object
# This includes objects with "type", "$ref", or if we're in a "properties" context
is_schema_object = (
"type" in d or
"$ref" in d or
any(schema_key in d for schema_key in ["properties", "items", "additionalProperties", "enum", "const", "anyOf", "allOf", "oneOf"])
)
if is_schema_object and k in d:
del d[k]
# Recursively process all values, regardless of key names
# This ensures we catch all nested structures
for key, value in d.items():
if isinstance(value, dict):
recursive_purge_dict_key(value, k)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
recursive_purge_dict_key(item, k)
def remove_titles_from_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Remove all 'title' keys from a JSON schema dictionary.
This function creates a copy of the schema and removes all title keys
recursively, making it suitable for use with APIs that don't need titles.
Args:
schema: The JSON schema dictionary to clean
Returns:
A new dictionary with all title keys removed
"""
schema_copy = copy.deepcopy(schema)
recursive_purge_dict_key(schema_copy, "title")
return schema_copy
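# Illustrative usage sketch (hypothetical pydantic model, relying on pydantic v2's model_json_schema):
# from pydantic import BaseModel
# class ToolArgs(BaseModel):
#     file_path: str
#     line_count: int
# clean_schema = remove_titles_from_schema(ToolArgs.model_json_schema())
# # pydantic adds "title" keys to the model and each property; they are stripped here recursively.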
```
--------------------------------------------------------------------------------
/src/wcgw_cli/openai_utils.py:
--------------------------------------------------------------------------------
```python
from typing import cast
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ParsedChatCompletionMessage,
)
from tokenizers import Tokenizer # type: ignore[import-untyped]
from wcgw.client.common import CostData, History
def get_input_cost(
cost_map: CostData, enc: Tokenizer, history: History
) -> tuple[float, int]:
input_tokens = 0
for msg in history:
content = msg["content"]
refusal = msg.get("refusal")
if isinstance(content, list):
for part in content:
if "text" in part:
input_tokens += len(enc.encode(part["text"]))
elif content is None:
if refusal is None:
raise ValueError("Expected content or refusal to be present")
input_tokens += len(enc.encode(str(refusal)))
elif not isinstance(content, str):
raise ValueError(f"Expected content to be string, got {type(content)}")
else:
input_tokens += len(enc.encode(content))
cost = input_tokens * cost_map.cost_per_1m_input_tokens / 1_000_000
return cost, input_tokens
def get_output_cost(
cost_map: CostData,
enc: Tokenizer,
item: ChatCompletionMessage | ChatCompletionMessageParam,
) -> tuple[float, int]:
if isinstance(item, ChatCompletionMessage):
content = item.content
if not isinstance(content, str):
raise ValueError(f"Expected content to be string, got {type(content)}")
else:
if not isinstance(item["content"], str):
raise ValueError(
f"Expected content to be string, got {type(item['content'])}"
)
content = item["content"]
if item["role"] == "tool":
return 0, 0
output_tokens = len(enc.encode(content))
if "tool_calls" in item:
item = cast(ChatCompletionAssistantMessageParam, item)
toolcalls = item["tool_calls"]
for tool_call in toolcalls or []:
output_tokens += len(enc.encode(tool_call["function"]["arguments"]))
elif isinstance(item, ParsedChatCompletionMessage):
if item.tool_calls:
for tool_callf in item.tool_calls:
output_tokens += len(enc.encode(tool_callf.function.arguments))
cost = output_tokens * cost_map.cost_per_1m_output_tokens / 1_000_000
return cost, output_tokens
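# Illustrative usage sketch (prices and variables below are placeholders, not real API pricing):
# from tokenizers import Tokenizer
# enc = Tokenizer.from_pretrained("Xenova/claude-tokenizer")
# cost_map = CostData(cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0)
# input_cost, n_input = get_input_cost(cost_map, enc, history)           # history: History
# output_cost, n_output = get_output_cost(cost_map, enc, assistant_msg)  # a ChatCompletionMessage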
```
--------------------------------------------------------------------------------
/tests/test_bash_parser.py:
--------------------------------------------------------------------------------
```python
"""
Tests for the bash statement parser.
"""
import pytest
from unittest.mock import patch
from wcgw.client.bash_state.parser.bash_statement_parser import BashStatementParser
def test_bash_statement_parser_basic():
"""Test basic statement parsing."""
parser = BashStatementParser()
# Test single statement
statements = parser.parse_string("echo hello")
assert len(statements) == 1
assert statements[0].text == "echo hello"
# Test command with newlines inside string
statements = parser.parse_string('echo "hello\nworld"')
assert len(statements) == 1
# Test command with && chain
statements = parser.parse_string("echo hello && echo world")
assert len(statements) == 1
# Test command with || chain
statements = parser.parse_string("echo hello || echo world")
assert len(statements) == 1
# Test command with pipe
statements = parser.parse_string("echo hello | grep hello")
assert len(statements) == 1
def test_bash_statement_parser_multiple():
"""Test multiple statement detection."""
parser = BashStatementParser()
# Test multiple statements on separate lines
statements = parser.parse_string("echo hello\necho world")
assert len(statements) == 2
# Test multiple statements with semicolons
statements = parser.parse_string("echo hello; echo world")
assert len(statements) == 2
# Test more complex case
statements = parser.parse_string("echo hello; echo world && echo again")
assert len(statements) == 2
# Test mixed separation
statements = parser.parse_string("echo a; echo b\necho c")
assert len(statements) == 3
def test_bash_statement_parser_complex():
"""Test complex statement handling."""
parser = BashStatementParser()
# Test subshell
statements = parser.parse_string("(echo hello; echo world)")
assert len(statements) == 1
# Test braces
statements = parser.parse_string("{ echo hello; echo world; }")
assert len(statements) == 1
# Test semicolons in strings
statements = parser.parse_string('echo "hello;world"')
assert len(statements) == 1
# Test escaped semicolons
statements = parser.parse_string('echo hello\\; echo world')
assert len(statements) == 1
# Test quoted semicolons
statements = parser.parse_string("echo 'hello;world'")
assert len(statements) == 1
```
--------------------------------------------------------------------------------
/tests/test_bash_parser_complex.py:
--------------------------------------------------------------------------------
```python
"""
Tests specifically for complex bash parsing scenarios.
"""
from wcgw.client.bash_state.parser.bash_statement_parser import BashStatementParser
def test_semicolon_lists():
"""Test parsing of semicolon-separated commands."""
parser = BashStatementParser()
# Simple case: two commands separated by semicolon
statements = parser.parse_string("echo a; echo b")
assert len(statements) == 2
assert statements[0].text.strip() == "echo a"
assert statements[1].text.strip() == "echo b"
# Multiple semicolons
statements = parser.parse_string("echo a; echo b; echo c")
assert len(statements) == 3
assert statements[0].text.strip() == "echo a"
assert statements[1].text.strip() == "echo b"
assert statements[2].text.strip() == "echo c"
# Semicolons with whitespace
statements = parser.parse_string("echo a ; echo b")
assert len(statements) == 2
assert statements[0].text.strip() == "echo a"
assert statements[1].text.strip() == "echo b"
def test_bash_command_with_semicolons_in_quotes():
"""Test that semicolons inside quotes don't split statements."""
parser = BashStatementParser()
# Semicolon in single quotes
statements = parser.parse_string("echo 'a;b'")
assert len(statements) == 1
# Semicolon in double quotes
statements = parser.parse_string('echo "a;b"')
assert len(statements) == 1
# Mixed quotes
statements = parser.parse_string("echo \"a;b\" ; echo 'c;d'")
assert len(statements) == 2
def test_complex_commands():
"""Test complex command scenarios."""
parser = BashStatementParser()
# Command with redirection and semicolon
statements = parser.parse_string("cat > file.txt << EOF\ntest\nEOF\n; echo done")
assert len(statements) == 2
# Command with subshell and semicolon
statements = parser.parse_string("(cd /tmp && echo 'in tmp'); echo 'outside'")
assert len(statements) == 2
# Command with braces and semicolon
statements = parser.parse_string("{ echo a; echo b; }; echo c")
assert len(statements) == 2
def test_command_chaining():
"""Test command chains are treated as a single statement."""
parser = BashStatementParser()
# AND chaining
statements = parser.parse_string("echo a && echo b")
assert len(statements) == 1
# OR chaining
statements = parser.parse_string("echo a || echo b")
assert len(statements) == 1
# Pipe chaining
statements = parser.parse_string("echo a | grep a")
assert len(statements) == 1
# Mixed chaining
statements = parser.parse_string("echo a && echo b || echo c")
assert len(statements) == 1
```
--------------------------------------------------------------------------------
/src/wcgw/client/memory.py:
--------------------------------------------------------------------------------
```python
import json
import os
import re
import shlex
from typing import Any, Callable, Optional, TypeVar
from ..types_ import ContextSave
def get_app_dir_xdg() -> str:
xdg_data_dir = os.environ.get("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
return os.path.join(xdg_data_dir, "wcgw")
def format_memory(task_memory: ContextSave, relevant_files: str) -> str:
memory_data = ""
if task_memory.project_root_path:
memory_data += (
f"# PROJECT ROOT = {shlex.quote(task_memory.project_root_path)}\n"
)
memory_data += task_memory.description
memory_data += (
"\n\n"
+ "# Relevant file paths\n"
+ ", ".join(map(shlex.quote, task_memory.relevant_file_globs))
)
memory_data += "\n\n# Relevant Files:\n" + relevant_files
return memory_data
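# Example of the text produced by format_memory (illustrative values):
#   # PROJECT ROOT = /home/user/myrepo
#   <task description>
#
#   # Relevant file paths
#   'src/**/*.py', README.md
#
#   # Relevant Files:
#   <contents passed in via relevant_files>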
def save_memory(
task_memory: ContextSave,
relevant_files: str,
bash_state_dict: Optional[dict[str, Any]] = None,
) -> str:
app_dir = get_app_dir_xdg()
memory_dir = os.path.join(app_dir, "memory")
os.makedirs(memory_dir, exist_ok=True)
task_id = task_memory.id
if not task_id:
raise Exception("Task id can not be empty")
memory_data = format_memory(task_memory, relevant_files)
memory_file_full = os.path.join(memory_dir, f"{task_id}.txt")
with open(memory_file_full, "w") as f:
f.write(memory_data)
# Save bash state if provided
if bash_state_dict is not None:
state_file = os.path.join(memory_dir, f"{task_id}_bash_state.json")
with open(state_file, "w") as f:
json.dump(bash_state_dict, f, indent=2)
return memory_file_full
T = TypeVar("T")
def load_memory(
task_id: str,
coding_max_tokens: Optional[int],
noncoding_max_tokens: Optional[int],
encoder: Callable[[str], list[T]],
decoder: Callable[[list[T]], str],
) -> tuple[str, str, Optional[dict[str, Any]]]:
app_dir = get_app_dir_xdg()
memory_dir = os.path.join(app_dir, "memory")
memory_file = os.path.join(memory_dir, f"{task_id}.txt")
with open(memory_file, "r") as f:
data = f.read()
# Memory files are considered non-code files for token limits
max_tokens = noncoding_max_tokens
if max_tokens:
toks = encoder(data)
if len(toks) > max_tokens:
toks = toks[: max(0, max_tokens - 10)]
data = decoder(toks)
data += "\n(... truncated)"
project_root_match = re.search(r"# PROJECT ROOT = \s*(.*?)\s*$", data, re.MULTILINE)
project_root_path = ""
if project_root_match:
matched_path = project_root_match.group(1)
parsed_ = shlex.split(matched_path)
if parsed_ and len(parsed_) == 1:
project_root_path = parsed_[0]
# Try to load bash state if exists
state_file = os.path.join(memory_dir, f"{task_id}_bash_state.json")
bash_state: Optional[dict[str, Any]] = None
if os.path.exists(state_file):
with open(state_file) as f:
bash_state = json.load(f)
return project_root_path, data, bash_state
```
--------------------------------------------------------------------------------
/src/wcgw/client/file_ops/extensions.py:
--------------------------------------------------------------------------------
```python
"""
File with definitions of known source code file extensions.
Used to determine the appropriate context length for files.
Supports selecting between coding_max_tokens and noncoding_max_tokens
based on file extensions.
"""
from typing import Dict, Optional, Set
# Set of file extensions considered to be source code
# Each extension should be listed without the dot (e.g., 'py' not '.py')
SOURCE_CODE_EXTENSIONS: Set[str] = {
# Python
'py', 'pyx', 'pyi', 'pyw',
# JavaScript and TypeScript
'js', 'jsx', 'ts', 'tsx', 'mjs', 'cjs',
# Web
'html', 'htm', 'xhtml', 'css', 'scss', 'sass', 'less',
# C and C++
'c', 'h', 'cpp', 'cxx', 'cc', 'hpp', 'hxx', 'hh', 'inl',
# C#
'cs', 'csx',
# Java
'java', 'scala', 'kt', 'kts', 'groovy',
# Go
'go', 'mod',
# Rust
'rs', 'rlib',
# Swift
'swift',
# Ruby
'rb', 'rake', 'gemspec',
# PHP
'php', 'phtml', 'phar', 'phps',
# Shell
'sh', 'bash', 'zsh', 'fish',
# PowerShell
'ps1', 'psm1', 'psd1',
# SQL
'sql', 'ddl', 'dml',
# Markup and config
'xml', 'json', 'yaml', 'yml', 'toml', 'ini', 'cfg', 'conf',
# Documentation
'md', 'markdown', 'rst', 'adoc', 'tex',
# Build and dependency files
'Makefile', 'Dockerfile', 'Jenkinsfile',
# Haskell
'hs', 'lhs',
# Lisp family
'lisp', 'cl', 'el', 'clj', 'cljs', 'edn', 'scm',
# Erlang and Elixir
'erl', 'hrl', 'ex', 'exs',
# Dart and Flutter
'dart',
# Objective-C
'm', 'mm',
}
# Context length limits based on file type (in tokens)
CONTEXT_LENGTH_LIMITS: Dict[str, int] = {
'source_code': 24000, # For known source code files
'default': 8000, # For all other files
}
def is_source_code_file(filename: str) -> bool:
"""
Determine if a file is a source code file based on its extension.
Args:
filename: The name of the file to check
Returns:
True if the file has a recognized source code extension, False otherwise
"""
# Extract extension (without the dot)
parts = filename.split('.')
if len(parts) > 1:
ext = parts[-1].lower()
return ext in SOURCE_CODE_EXTENSIONS
# Files without extensions (like 'Makefile', 'Dockerfile')
# Case-insensitive match for files without extensions
return filename.lower() in {ext.lower() for ext in SOURCE_CODE_EXTENSIONS}
def get_context_length_for_file(filename: str) -> int:
"""
Get the appropriate context length limit for a file based on its extension.
Args:
filename: The name of the file to check
Returns:
The context length limit in tokens
"""
if is_source_code_file(filename):
return CONTEXT_LENGTH_LIMITS['source_code']
return CONTEXT_LENGTH_LIMITS['default']
def select_max_tokens(filename: str, coding_max_tokens: Optional[int], noncoding_max_tokens: Optional[int]) -> Optional[int]:
"""
Select the appropriate max_tokens limit based on file type.
Args:
filename: The name of the file to check
coding_max_tokens: Maximum tokens for source code files
noncoding_max_tokens: Maximum tokens for non-source code files
Returns:
The appropriate max_tokens limit for the file
"""
if coding_max_tokens is None and noncoding_max_tokens is None:
return None
if is_source_code_file(filename):
return coding_max_tokens
return noncoding_max_tokens
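# Illustrative calls, using the CONTEXT_LENGTH_LIMITS values as the token limits:
# select_max_tokens("tools.py", 24000, 8000)    # -> 24000, 'py' is a known source extension
# select_max_tokens("notes.txt", 24000, 8000)   # -> 8000, not a recognized source file
# select_max_tokens("Dockerfile", 24000, 8000)  # -> 24000, extensionless names match case-insensitively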
```
--------------------------------------------------------------------------------
/tests/test_readfiles.py:
--------------------------------------------------------------------------------
```python
import os
import tempfile
from typing import Optional
from wcgw.types_ import ReadFiles
def test_readfiles_line_number_parsing():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp:
tmp.write("Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n")
tmp_path = tmp.name
try:
# Test with no line numbers
read_files = ReadFiles(file_paths=[tmp_path])
assert read_files.file_paths == [tmp_path]
assert read_files.start_line_nums == [None]
assert read_files.end_line_nums == [None]
# Test with start line only (e.g., file.py:2)
read_files = ReadFiles(file_paths=[f"{tmp_path}:2"])
assert read_files.file_paths == [tmp_path]
assert read_files.start_line_nums == [2]
assert read_files.end_line_nums == [None]
# Test with end line only (e.g., file.py:-3)
read_files = ReadFiles(file_paths=[f"{tmp_path}:-3"])
assert read_files.file_paths == [tmp_path]
assert read_files.start_line_nums == [None]
assert read_files.end_line_nums == [3]
# Test with start and end lines (e.g., file.py:2-4)
read_files = ReadFiles(file_paths=[f"{tmp_path}:2-4"])
assert read_files.file_paths == [tmp_path]
assert read_files.start_line_nums == [2]
assert read_files.end_line_nums == [4]
# Test with start line and beyond (e.g., file.py:5-)
read_files = ReadFiles(file_paths=[f"{tmp_path}:5-"])
assert read_files.file_paths == [tmp_path]
assert read_files.start_line_nums == [5]
assert read_files.end_line_nums == [None]
# Test with multiple files
read_files = ReadFiles(file_paths=[
tmp_path,
f"{tmp_path}:2-3",
f"{tmp_path}:1-"
])
assert read_files.file_paths == [tmp_path, tmp_path, tmp_path]
assert read_files.start_line_nums == [None, 2, 1]
assert read_files.end_line_nums == [None, 3, None]
# Test with invalid line numbers
read_files = ReadFiles(file_paths=[f"{tmp_path}:invalid-line"])
assert read_files.file_paths == [f"{tmp_path}:invalid-line"] # Should keep the whole path
assert read_files.start_line_nums == [None]
assert read_files.end_line_nums == [None]
# Test with files that legitimately contain colons
filename_with_colon = f"{tmp_path}:colon_in_name"
read_files = ReadFiles(file_paths=[filename_with_colon])
assert read_files.file_paths == [filename_with_colon] # Should keep the whole path
assert read_files.start_line_nums == [None]
assert read_files.end_line_nums == [None]
# Test with URLs that contain colons
url_path = "/path/to/http://example.com/file.txt"
read_files = ReadFiles(file_paths=[url_path])
assert read_files.file_paths == [url_path] # Should keep the whole path
assert read_files.start_line_nums == [None]
assert read_files.end_line_nums == [None]
# Test with URLs that contain colons followed by valid line numbers
url_path_with_line = "/path/to/http://example.com/file.txt:10-20"
read_files = ReadFiles(file_paths=[url_path_with_line])
assert read_files.file_paths == ["/path/to/http://example.com/file.txt"]
assert read_files.start_line_nums == [10]
assert read_files.end_line_nums == [20]
finally:
# Clean up: remove the temporary file
os.unlink(tmp_path)
if __name__ == "__main__":
test_readfiles_line_number_parsing()
print("All tests passed!")
```
--------------------------------------------------------------------------------
/src/wcgw/client/repo_ops/display_tree.py:
--------------------------------------------------------------------------------
```python
import io
from pathlib import Path
from typing import List, Set
class DirectoryTree:
def __init__(self, root: Path, max_files: int = 10):
"""
Initialize the DirectoryTree with a root path and maximum number of files to display
Args:
root: The root directory path to start from
max_files: Maximum number of files to display in unexpanded directories
"""
self.root = root
self.max_files = max_files
self.expanded_files: Set[Path] = set()
self.expanded_dirs: Set[Path] = set()
if not self.root.exists():
raise ValueError(f"Root path {root} does not exist")
if not self.root.is_dir():
raise ValueError(f"Root path {root} is not a directory")
def expand(self, rel_path: str) -> None:
"""
Expand a specific file in the tree
Args:
rel_path: Relative path from root to the file to expand
"""
abs_path = self.root / rel_path
if not abs_path.exists():
return
if not abs_path.is_file():
return
if not str(abs_path).startswith(str(self.root)):
return
self.expanded_files.add(abs_path)
# Add all parent directories to expanded dirs
current = abs_path.parent
while str(current) >= str(self.root):
if current not in self.expanded_dirs:
self.expanded_dirs.add(current)
if current == current.parent:
break
current = current.parent
def _list_directory(self, dir_path: Path) -> List[Path]:
"""List contents of a directory, sorted with directories first"""
contents = list(dir_path.iterdir())
return sorted(contents, key=lambda x: (not x.is_dir(), x.name.lower()))
def _count_hidden_items(
self, dir_path: Path, shown_items: List[Path]
) -> tuple[int, int]:
"""Count hidden files and directories in a directory"""
all_items = set(self._list_directory(dir_path))
shown_items_set = set(shown_items)
hidden_items = all_items - shown_items_set
hidden_files = sum(1 for p in hidden_items if p.is_file())
hidden_dirs = sum(1 for p in hidden_items if p.is_dir())
return hidden_files, hidden_dirs
def display(self) -> str:
"""Display the directory tree with expanded state"""
writer = io.StringIO()
def _display_recursive(
current_path: Path, indent: int = 0, depth: int = 0
) -> None:
# Print current directory name with a trailing slash for directories
if current_path == self.root:
writer.write(f"{current_path}/\n")
else:
writer.write(f"{' ' * indent}{current_path.name}/\n")
# Don't recurse beyond depth 1 unless path contains expanded files
if depth > 0 and current_path not in self.expanded_dirs:
return
# Get directory contents
contents = self._list_directory(current_path)
shown_items = []
for item in contents:
# Show items only if:
# 1. They are expanded files
# 2. They are parents of expanded items
should_show = item in self.expanded_files or item in self.expanded_dirs
if should_show:
shown_items.append(item)
if item.is_dir():
_display_recursive(item, indent + 2, depth + 1)
else:
writer.write(f"{' ' * (indent + 2)}{item.name}\n")
# Show hidden items count if any items were hidden
hidden_files, hidden_dirs = self._count_hidden_items(
current_path, shown_items
)
if hidden_files > 0 or hidden_dirs > 0:
writer.write(f"{' ' * (indent + 2)}...\n")
_display_recursive(self.root, depth=0)
return writer.getvalue()
```
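A short sketch of driving `DirectoryTree` directly; the project path and file names below are hypothetical.

```python
# Usage sketch for DirectoryTree; the paths are hypothetical.
from pathlib import Path
from wcgw.client.repo_ops.display_tree import DirectoryTree

tree = DirectoryTree(Path("/path/to/project"), max_files=10)
# expand() marks a file (and all of its parent directories) for display;
# everything else is summarized with a trailing "..." line.
tree.expand("src/app/main.py")
tree.expand("README.md")
print(tree.display())
```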
--------------------------------------------------------------------------------
/tests/test_file_range_tracking.py:
--------------------------------------------------------------------------------
```python
import os
import tempfile
from typing import Dict, List, Tuple
import pytest
from wcgw.client.bash_state.bash_state import BashState, FileWhitelistData
from wcgw.client.tools import Context, read_file, read_files
class MockConsole:
def print(self, msg: str, *args, **kwargs) -> None:
pass
def log(self, msg: str, *args, **kwargs) -> None:
pass
@pytest.fixture
def test_file():
"""Create a temporary file with 20 lines of content."""
with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
for i in range(1, 21):
f.write(f"Line {i}\n")
path = f.name
yield path
# Cleanup
os.unlink(path)
@pytest.fixture
def context():
"""Create a context with BashState for testing."""
with BashState(
console=MockConsole(),
working_dir="",
bash_command_mode=None,
file_edit_mode=None,
write_if_empty_mode=None,
mode=None,
use_screen=False,
) as bash_state:
return Context(bash_state=bash_state, console=MockConsole())
def test_read_file_tracks_line_ranges(test_file, context):
"""Test that read_file correctly returns line ranges."""
# Read lines 5-10
_, _, _, path, line_range = read_file(
test_file, coding_max_tokens=None, noncoding_max_tokens=None, context=context, start_line_num=5, end_line_num=10
)
# Check that the line range is correct
assert line_range == (5, 10)
def test_read_files_tracks_multiple_ranges(test_file, context):
"""Test that read_files correctly collects line ranges for multiple reads."""
# Create a second test file
with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
for i in range(1, 31):
f.write(f"Another line {i}\n")
second_file = f.name
try:
# Read different ranges from both files
_, file_ranges, _ = read_files(
file_paths=[test_file, second_file],
coding_max_tokens=None,
noncoding_max_tokens=None,
context=context,
start_line_nums=[5, 10],
end_line_nums=[10, 20],
)
# Check that the file ranges dictionary has both files with correct ranges
assert len(file_ranges) == 2
assert test_file in file_ranges
assert second_file in file_ranges
assert file_ranges[test_file] == [(5, 10)]
assert file_ranges[second_file] == [(10, 20)]
finally:
# Cleanup
os.unlink(second_file)
def test_whitelist_data_tracking(test_file):
"""Test that FileWhitelistData correctly tracks line ranges."""
# Create whitelist data with some initial ranges and a total of 20 lines
whitelist_data = FileWhitelistData(
file_hash="abc123", line_ranges_read=[(1, 5), (10, 15)], total_lines=20
)
# Add another range
whitelist_data.add_range(7, 9)
# Calculate percentage
percentage = whitelist_data.get_percentage_read()
# We've read lines 1-5, 7-9, and 10-15, so 14 out of 20 lines = 70%
assert percentage == 70.0
# Test is_read_enough
assert not whitelist_data.is_read_enough()
# Test get_unread_ranges
unread_ranges = whitelist_data.get_unread_ranges()
# We've read lines 1-5, 7-9, and 10-15, so we're missing 6 and 16-20
assert len(unread_ranges) == 2
assert (6, 6) in unread_ranges
assert (16, 20) in unread_ranges
# Add remaining lines
whitelist_data.add_range(6, 6)
whitelist_data.add_range(16, 20)
# Now we should have read everything
assert whitelist_data.is_read_enough()
assert len(whitelist_data.get_unread_ranges()) == 0
def test_bash_state_whitelist_for_overwrite(context, test_file):
"""Test that BashState correctly tracks file whitelist data."""
# Create a dictionary mapping the test file to a line range
file_paths_with_ranges: Dict[str, List[Tuple[int, int]]] = {test_file: [(1, 10)]}
# Add to whitelist
context.bash_state.add_to_whitelist_for_overwrite(file_paths_with_ranges)
# Check that the file was added to the whitelist
assert test_file in context.bash_state.whitelist_for_overwrite
# Check that the line range was stored correctly
whitelist_data = context.bash_state.whitelist_for_overwrite[test_file]
assert whitelist_data.line_ranges_read[0] == (1, 10)
# Add another range
context.bash_state.add_to_whitelist_for_overwrite({test_file: [(15, 20)]})
# Check that the new range was added
whitelist_data = context.bash_state.whitelist_for_overwrite[test_file]
assert len(whitelist_data.line_ranges_read) == 2
assert (15, 20) in whitelist_data.line_ranges_read
```
--------------------------------------------------------------------------------
/src/wcgw/client/repo_ops/file_stats.py:
--------------------------------------------------------------------------------
```python
import hashlib
import json
import os
import sys
from typing import Any, Callable, Dict, TypeVar, cast
T = TypeVar("T") # Type variable for generic functions
F = TypeVar("F", bound=Callable[..., Any]) # Type variable for decorated functions
class FileStats:
"""Track read, edit, and write counts for a single file."""
def __init__(self) -> None:
self.read_count: int = 0
self.edit_count: int = 0
self.write_count: int = 0
def increment_read(self) -> None:
"""Increment the read counter."""
self.read_count += 1
def increment_edit(self) -> None:
"""Increment the edit counter."""
self.edit_count += 1
def increment_write(self) -> None:
"""Increment the write counter."""
self.write_count += 1
def to_dict(self) -> Dict[str, int]:
"""Convert to a dictionary for serialization."""
return {
"read_count": self.read_count,
"edit_count": self.edit_count,
"write_count": self.write_count,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileStats":
"""Create from a serialized dictionary."""
stats = cls()
stats.read_count = data.get("read_count", 0)
stats.edit_count = data.get("edit_count", 0)
stats.write_count = data.get("write_count", 0)
return stats
class WorkspaceStats:
"""Track file operations statistics for an entire workspace."""
def __init__(self) -> None:
self.files: Dict[str, FileStats] = {} # filepath -> FileStats
def to_dict(self) -> Dict[str, Any]:
"""Convert to a dictionary for serialization."""
return {"files": {k: v.to_dict() for k, v in self.files.items()}}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WorkspaceStats":
"""Create from a serialized dictionary."""
stats = cls()
files_data = data.get("files", {})
stats.files = {k: FileStats.from_dict(v) for k, v in files_data.items()}
return stats
def safe_stats_operation(func: F) -> F:
"""
Decorator to safely perform stats operations without affecting core functionality.
If an exception occurs, it logs the error but allows the program to continue.
"""
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return func(*args, **kwargs)
except Exception as e:
# Log the error but continue with the operation
print(f"Warning: Stats tracking error - {e}", file=sys.stderr)
return None
# This is a workaround for proper typing with decorators
return cast(F, wrapper)
def get_stats_path(workspace_path: str) -> str:
"""
Get the path to the stats file for a workspace using a hash-based approach.
Args:
workspace_path: The full path of the workspace directory.
Returns:
The path to the stats file.
"""
# Normalize the path
workspace_path = os.path.normpath(os.path.expanduser(workspace_path))
# Get the basename of the workspace path for readability
workspace_name = os.path.basename(workspace_path)
if not workspace_name: # In case of root directory
workspace_name = "root"
# Create a hash of the full path
path_hash = hashlib.md5(workspace_path.encode()).hexdigest()
# Combine to create a unique identifier that's still somewhat readable
filename = f"{workspace_name}_{path_hash}.json"
# Create directory if it doesn't exist
xdg_data_dir = os.environ.get("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
stats_dir = os.path.join(xdg_data_dir, "wcgw/workspace_stats")
os.makedirs(stats_dir, exist_ok=True)
return os.path.join(stats_dir, filename)
@safe_stats_operation
def load_workspace_stats(workspace_path: str) -> WorkspaceStats:
"""
Load the stats for a workspace, or create empty stats if not exists.
Args:
workspace_path: The full path of the workspace directory.
Returns:
WorkspaceStats object containing file operation statistics.
"""
stats_path = get_stats_path(workspace_path)
if os.path.exists(stats_path):
try:
with open(stats_path, "r") as f:
return WorkspaceStats.from_dict(json.load(f))
except (json.JSONDecodeError, KeyError, ValueError):
# Handle corrupted file
return WorkspaceStats()
else:
return WorkspaceStats()
@safe_stats_operation
def save_workspace_stats(workspace_path: str, stats: WorkspaceStats) -> None:
"""
Save the stats for a workspace.
Args:
workspace_path: The full path of the workspace directory.
stats: WorkspaceStats object to save.
"""
stats_path = get_stats_path(workspace_path)
with open(stats_path, "w") as f:
json.dump(stats.to_dict(), f, indent=2)
```
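A small sketch of how these stats helpers fit together; the workspace and file paths are hypothetical.

```python
# Usage sketch for the workspace-stats helpers; the paths are hypothetical.
from wcgw.client.repo_ops.file_stats import FileStats, load_workspace_stats, save_workspace_stats

workspace = "/path/to/project"
stats = load_workspace_stats(workspace)  # empty WorkspaceStats if nothing was saved yet
file_stats = stats.files.setdefault("/path/to/project/src/app.py", FileStats())
file_stats.increment_read()
file_stats.increment_edit()
save_workspace_stats(workspace, stats)
# Data lands in $XDG_DATA_HOME/wcgw/workspace_stats/<basename>_<md5-of-path>.json
```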
--------------------------------------------------------------------------------
/src/wcgw/client/tool_prompts.py:
--------------------------------------------------------------------------------
```python
import os
from mcp.types import Tool, ToolAnnotations
from ..types_ import (
BashCommand,
ContextSave,
FileWriteOrEdit,
Initialize,
ReadFiles,
ReadImage,
)
from .schema_generator import remove_titles_from_schema
with open(os.path.join(os.path.dirname(__file__), "diff-instructions.txt")) as f:
diffinstructions = f.read()
TOOL_PROMPTS = [
Tool(
inputSchema=remove_titles_from_schema(Initialize.model_json_schema()),
name="Initialize",
description="""
- Always call this at the start of the conversation before using any of the shell tools from wcgw.
- Use `any_workspace_path` to initialize the shell in the appropriate project directory.
- If the user has mentioned a workspace, project root, or any other file or folder, use it to set `any_workspace_path`.
- If the user has mentioned any files, use `initial_files_to_read` to read them; use absolute paths only (~ allowed)
- By default use mode "wcgw"
- In "code-writer" mode, set the commands and globs the user asked to set; otherwise use 'all'.
- Use type="first_call" if it's the first call to this tool.
- Use type="user_asked_mode_change" if in a conversation user has asked to change mode.
- Use type="reset_shell" if in a conversation shell is not working after multiple tries.
- Use type="user_asked_change_workspace" if in a conversation user asked to change workspace
""",
annotations=ToolAnnotations(readOnlyHint=True, openWorldHint=False),
),
Tool(
inputSchema=remove_titles_from_schema(BashCommand.model_json_schema()),
name="BashCommand",
description="""
- Execute a bash command. This is stateful (beware with subsequent calls).
- Status of the command and the current working directory will always be returned at the end.
- The first or the last line might be `(...truncated)` if the output is too long.
- Always run `pwd` if you get any file or directory not found error to make sure you're not lost.
- Do not run background commands using "&"; run them through this tool with `is_background=true` instead.
- You must not use echo/cat to read/write files; use ReadFiles/FileWriteOrEdit instead.
- In order to check status of previous command, use `status_check` with empty command argument.
- Only one command is allowed to run at a time. You need to wait for any previous command to finish before running a new one.
- Programs don't hang easily, so the most likely explanation for no output is that the program is still running; check the status again.
- Do not send Ctrl-c before checking the status for up to 10 minutes, or however long is appropriate for the program to finish.
- Only run long-running commands in the background. Each background command is run in a new, non-reusable shell.
- On running a bg command you'll get a bg command id that you should use to get status or interact.
""",
annotations=ToolAnnotations(destructiveHint=True, openWorldHint=True),
),
Tool(
inputSchema=remove_titles_from_schema(ReadFiles.model_json_schema()),
name="ReadFiles",
description="""
- Read full file content of one or more files.
- Provide absolute paths only (~ allowed)
- Only if the task requires understanding of line numbers:
- You may extract a range of lines. E.g., `/path/to/file:1-10` for lines 1-10. You can drop start or end like `/path/to/file:1-` or `/path/to/file:-10`
""",
annotations=ToolAnnotations(readOnlyHint=True, openWorldHint=False),
),
Tool(
inputSchema=remove_titles_from_schema(ReadImage.model_json_schema()),
name="ReadImage",
description="Read an image from the shell.",
annotations=ToolAnnotations(readOnlyHint=True, openWorldHint=False),
),
Tool(
inputSchema=remove_titles_from_schema(FileWriteOrEdit.model_json_schema()),
name="FileWriteOrEdit",
description="""
- Writes or edits a file based on the percentage of changes.
- Use absolute path only (~ allowed).
- First write down percentage of lines that need to be replaced in the file (between 0-100) in percentage_to_change
- percentage_to_change should be low if mostly new code is to be added. It should be high if a lot of things are to be replaced.
- If percentage_to_change > 50, provide full file content in text_or_search_replace_blocks
- If percentage_to_change <= 50, text_or_search_replace_blocks should be search/replace blocks.
"""
+ diffinstructions,
annotations=ToolAnnotations(
destructiveHint=True, idempotentHint=True, openWorldHint=False
),
),
Tool(
inputSchema=remove_titles_from_schema(ContextSave.model_json_schema()),
name="ContextSave",
description="""
Saves provided description and file contents of all the relevant file paths or globs in a single text file.
- Provide a random 3-word unique id, or whatever the user provided.
- Leave project path as empty string if no project path""",
annotations=ToolAnnotations(readOnlyHint=True, openWorldHint=False),
),
]
```
--------------------------------------------------------------------------------
/src/wcgw/client/mcp_server/server.py:
--------------------------------------------------------------------------------
```python
import importlib.metadata
import logging
import os
from typing import Any, Optional
import mcp.server.stdio
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from pydantic import AnyUrl
from wcgw.client.modes import KTS
from wcgw.client.tool_prompts import TOOL_PROMPTS
from ...types_ import (
Initialize,
)
from ..bash_state.bash_state import CONFIG, BashState, get_tmpdir
from ..tools import (
Context,
default_enc,
get_tool_output,
parse_tool_by_name,
which_tool_name,
)
server: Server[Any] = Server("wcgw")
# Log only time stamp
logging.basicConfig(level=logging.INFO, format="%(asctime)s: %(message)s")
logger = logging.getLogger("wcgw")
class Console:
def print(self, msg: str, *args: Any, **kwargs: Any) -> None:
logger.info(msg)
def log(self, msg: str, *args: Any, **kwargs: Any) -> None:
logger.info(msg)
@server.list_resources() # type: ignore
async def handle_list_resources() -> list[types.Resource]:
return []
@server.read_resource() # type: ignore
async def handle_read_resource(uri: AnyUrl) -> str:
raise ValueError("No resources available")
PROMPTS = {
"KnowledgeTransfer": (
types.Prompt(
name="KnowledgeTransfer",
description="Prompt for invoking ContextSave tool in order to do a comprehensive knowledge transfer of a coding task. Prompts to save detailed error log and instructions.",
),
KTS,
)
}
@server.list_prompts() # type: ignore
async def handle_list_prompts() -> list[types.Prompt]:
return [x[0] for x in PROMPTS.values()]
@server.get_prompt() # type: ignore
async def handle_get_prompt(
name: str, arguments: dict[str, str] | None
) -> types.GetPromptResult:
assert BASH_STATE
messages = [
types.PromptMessage(
role="user",
content=types.TextContent(
type="text", text=PROMPTS[name][1][BASH_STATE.mode]
),
)
]
return types.GetPromptResult(messages=messages)
@server.list_tools() # type: ignore
async def handle_list_tools() -> list[types.Tool]:
"""
List available tools.
Each tool specifies its arguments using JSON Schema validation.
"""
return TOOL_PROMPTS
@server.call_tool() # type: ignore
async def handle_call_tool(
name: str, arguments: dict[str, Any] | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
global BASH_STATE
if not arguments:
raise ValueError("Missing arguments")
tool_type = which_tool_name(name)
tool_call = parse_tool_by_name(name, arguments)
try:
assert BASH_STATE
output_or_dones, _ = get_tool_output(
Context(BASH_STATE, BASH_STATE.console),
tool_call,
default_enc,
0.0,
lambda x, y: ("", 0),
24000, # coding_max_tokens
8000, # noncoding_max_tokens
)
except Exception as e:
output_or_dones = [f"GOT EXCEPTION while calling tool. Error: {e}"]
content: list[types.TextContent | types.ImageContent | types.EmbeddedResource] = []
for output_or_done in output_or_dones:
if isinstance(output_or_done, str):
if issubclass(tool_type, Initialize):
# Prepare the original hardcoded message
original_message = """
- Additional important note: as soon as you encounter "The user has chosen to disallow the tool call.", immediately stop doing everything and ask user for the reason.
Initialize call done.
"""
# If custom instructions exist, prepend them to the original message
if CUSTOM_INSTRUCTIONS:
output_or_done += f"\n{CUSTOM_INSTRUCTIONS}\n{original_message}"
else:
output_or_done += original_message
content.append(types.TextContent(type="text", text=output_or_done))
else:
content.append(
types.ImageContent(
type="image",
data=output_or_done.data,
mimeType=output_or_done.media_type,
)
)
return content
BASH_STATE = None
CUSTOM_INSTRUCTIONS = None
async def main(shell_path: str = "") -> None:
global BASH_STATE, CUSTOM_INSTRUCTIONS
CONFIG.update(3, 55, 5)
version = str(importlib.metadata.version("wcgw"))
# Read custom instructions from environment variable
CUSTOM_INSTRUCTIONS = os.getenv("WCGW_SERVER_INSTRUCTIONS")
# starting_dir is inside tmp dir
tmp_dir = get_tmpdir()
starting_dir = os.path.join(tmp_dir, "claude_playground")
with BashState(
Console(), starting_dir, None, None, None, None, True, None, None, shell_path or None
) as BASH_STATE:
BASH_STATE.console.log("wcgw version: " + version)
# Run the server using stdin/stdout streams
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
await server.run(
read_stream,
write_stream,
InitializationOptions(
server_name="wcgw",
server_version=version,
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
raise_exceptions=False,
)
```
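For reference, a minimal sketch of launching this server loop directly; in normal use it is wired up by the package's entry point.

```python
# Minimal sketch: run the stdio MCP server loop directly.
import asyncio

from wcgw.client.mcp_server.server import main

if __name__ == "__main__":
    asyncio.run(main())  # optionally pass a custom shell, e.g. asyncio.run(main("/bin/bash"))
```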
--------------------------------------------------------------------------------
/src/wcgw/client/bash_state/parser/bash_statement_parser.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Bash Statement Parser
This script parses bash scripts and identifies individual statements using tree-sitter.
It correctly handles multi-line strings, command chains with && and ||, and semicolon-separated statements.
"""
import sys
from dataclasses import dataclass
from typing import Any, List, Optional
import tree_sitter_bash
from tree_sitter import Language, Parser
@dataclass
class Statement:
"""A bash statement with its source code and position information."""
text: str
start_line: int
end_line: int
start_byte: int
end_byte: int
node_type: str
parent_type: Optional[str] = None
def __str__(self) -> str:
return self.text.strip()
class BashStatementParser:
def __init__(self) -> None:
# Use the precompiled bash language
self.language = Language(tree_sitter_bash.language())
self.parser = Parser(self.language)
def parse_file(self, file_path: str) -> List[Statement]:
"""Parse a bash script file and return a list of statements."""
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return self.parse_string(content)
def parse_string(self, content: str) -> List[Statement]:
"""Parse a string containing bash script and return a list of statements."""
tree = self.parser.parse(bytes(content, "utf-8"))
root_node = tree.root_node
# For debugging: Uncomment to print the tree structure
# self._print_tree(root_node, content)
statements: List[Statement] = []
self._extract_statements(root_node, content, statements, None)
# Post-process statements to handle multi-line statements correctly
return self._post_process_statements(statements, content)
def _print_tree(self, node: Any, content: str, indent: str = "") -> None:
"""Debug helper to print the entire syntax tree."""
node_text = content[node.start_byte : node.end_byte]
if len(node_text) > 40:
node_text = node_text[:37] + "..."
print(f"{indent}{node.type}: {repr(node_text)}")
for child in node.children:
self._print_tree(child, content, indent + " ")
def _extract_statements(
self,
node: Any,
content: str,
statements: List[Statement],
parent_type: Optional[str],
) -> None:
"""Recursively extract statements from the syntax tree."""
# Node types that represent bash statements
statement_node_types = {
# Basic statements
"command",
"variable_assignment",
"declaration_command",
"unset_command",
# Control flow statements
"for_statement",
"c_style_for_statement",
"while_statement",
"if_statement",
"case_statement",
# Function definition
"function_definition",
# Command chains and groups
"pipeline", # For command chains with | and |&
"list", # For command chains with && and ||
"compound_statement",
"subshell",
"redirected_statement",
}
# Create a Statement object for this node if it's a recognized statement type
if node.type in statement_node_types:
# Get the text of this statement
start_byte = node.start_byte
end_byte = node.end_byte
statement_text = content[start_byte:end_byte]
# Get line numbers
start_line = (
node.start_point[0] + 1
) # tree-sitter uses 0-indexed line numbers
end_line = node.end_point[0] + 1
statements.append(
Statement(
text=statement_text,
start_line=start_line,
end_line=end_line,
start_byte=start_byte,
end_byte=end_byte,
node_type=node.type,
parent_type=parent_type,
)
)
# Update parent type for children
parent_type = node.type
# Recursively process all children
for child in node.children:
self._extract_statements(child, content, statements, parent_type)
def _post_process_statements(
self, statements: List[Statement], content: str
) -> List[Statement]:
if not statements:
return []
# Filter out list statements that have been split
top_statements = []
for stmt in statements:
# Skip statements that are contained within others
is_contained = False
for other in statements:
if other is stmt:
continue
# Check if completely contained (except for lists we've split)
if other.node_type != "list" or ";" not in other.text:
if (
other.start_line <= stmt.start_line
and other.end_line >= stmt.end_line
and len(other.text) > len(stmt.text)
and stmt.text in other.text
):
is_contained = True
break
if not is_contained:
top_statements.append(stmt)
# Sort by position in file for consistent output
top_statements.sort(key=lambda s: (s.start_line, s.text))
return top_statements
def main() -> None:
if len(sys.argv) < 2:
print("Usage: python bash_statement_parser.py <bash_script_file>")
sys.exit(1)
parser = BashStatementParser()
statements = parser.parse_file(sys.argv[1])
print(f"Found {len(statements)} statements:")
for i, stmt in enumerate(statements, 1):
print(f"\n--- Statement {i} (Lines {stmt.start_line}-{stmt.end_line}) ---")
print(stmt)
if __name__ == "__main__":
main()
```
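Besides the CLI entry point above, the parser can be used programmatically; a short sketch (requires the `tree_sitter_bash` dependency used by the module).

```python
# Usage sketch: split a compound command line into individual bash statements.
from wcgw.client.bash_state.parser.bash_statement_parser import BashStatementParser

parser = BashStatementParser()
statements = parser.parse_string("cd /tmp && ls -la; echo 'done'")
for stmt in statements:
    print(f"{stmt.start_line}-{stmt.end_line} [{stmt.node_type}]: {stmt}")
```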
--------------------------------------------------------------------------------
/src/wcgw/client/file_ops/search_replace.py:
--------------------------------------------------------------------------------
```python
import re
from typing import Callable, Optional
from .diff_edit import FileEditInput, FileEditOutput, SearchReplaceMatchError
# Global regex patterns
SEARCH_MARKER = re.compile(r"^<<<<<<+\s*SEARCH>?\s*$")
DIVIDER_MARKER = re.compile(r"^======*\s*$")
REPLACE_MARKER = re.compile(r"^>>>>>>+\s*REPLACE\s*$")
class SearchReplaceSyntaxError(Exception):
def __init__(self, message: str):
message = f"""Got syntax error while parsing search replace blocks:
{message}
---
Make sure blocks are in correct sequence, and the markers are in separate lines:
<{"<<<<<< SEARCH"}
example old
=======
example new
>{">>>>>> REPLACE"}
"""
super().__init__(message)
def search_replace_edit(
lines: list[str], original_content: str, logger: Callable[[str], object]
) -> tuple[str, str]:
if not lines:
raise SearchReplaceSyntaxError("Error: No input to search replace edit")
original_lines = original_content.split("\n")
n_lines = len(lines)
i = 0
search_replace_blocks = list[tuple[list[str], list[str]]]()
while i < n_lines:
if SEARCH_MARKER.match(lines[i]):
line_num = i + 1
search_block = []
i += 1
while i < n_lines and not DIVIDER_MARKER.match(lines[i]):
if SEARCH_MARKER.match(lines[i]) or REPLACE_MARKER.match(lines[i]):
raise SearchReplaceSyntaxError(
f"Line {i + 1}: Found stray marker in SEARCH block: {lines[i]}"
)
search_block.append(lines[i])
i += 1
if i >= n_lines:
raise SearchReplaceSyntaxError(
f"Line {line_num}: Unclosed SEARCH block - missing ======= marker"
)
if not search_block:
raise SearchReplaceSyntaxError(
f"Line {line_num}: SEARCH block cannot be empty"
)
i += 1
replace_block = []
while i < n_lines and not REPLACE_MARKER.match(lines[i]):
if SEARCH_MARKER.match(lines[i]) or DIVIDER_MARKER.match(lines[i]):
raise SearchReplaceSyntaxError(
f"Line {i + 1}: Found stray marker in REPLACE block: {lines[i]}"
)
replace_block.append(lines[i])
i += 1
if i >= n_lines:
raise SearchReplaceSyntaxError(
f"Line {line_num}: Unclosed block - missing REPLACE marker"
)
i += 1
for line in search_block:
logger("> " + line)
logger("=======")
for line in replace_block:
logger("< " + line)
logger("\n\n\n\n")
search_replace_blocks.append((search_block, replace_block))
else:
if REPLACE_MARKER.match(lines[i]) or DIVIDER_MARKER.match(lines[i]):
raise SearchReplaceSyntaxError(
f"Line {i + 1}: Found stray marker outside block: {lines[i]}"
)
i += 1
if not search_replace_blocks:
raise SearchReplaceSyntaxError(
"No valid search replace blocks found, ensure your SEARCH/REPLACE blocks are formatted correctly"
)
edited_content, comments_ = edit_with_individual_fallback(
original_lines, search_replace_blocks
)
edited_file = "\n".join(edited_content)
if not comments_:
comments = "Edited successfully"
else:
comments = (
"Edited successfully. However, following warnings were generated while matching search blocks.\n"
+ "\n".join(comments_)
)
return edited_file, comments
def identify_first_differing_block(
best_matches: list[FileEditOutput],
) -> Optional[list[str]]:
"""
Identify the first search block that differs across multiple best matches.
Returns the search block content that first shows different matches.
"""
if not best_matches or len(best_matches) <= 1:
return None
# First, check if the number of blocks differs (shouldn't happen, but let's be safe)
block_counts = [len(match.edited_with_tolerances) for match in best_matches]
if not all(count == block_counts[0] for count in block_counts):
# If block counts differ, just return the first search block as problematic
return (
best_matches[0].orig_search_blocks[0]
if best_matches[0].orig_search_blocks
else None
)
# Go through each block position and see if the slices differ
for i in range(min(block_counts)):
slices = [match.edited_with_tolerances[i][0] for match in best_matches]
# Check if we have different slices for this block across matches
if any(s.start != slices[0].start or s.stop != slices[0].stop for s in slices):
# We found our differing block - return the search block content
if i < len(best_matches[0].orig_search_blocks):
return best_matches[0].orig_search_blocks[i]
else:
return None
# If we get here, we couldn't identify a specific differing block
return None
def edit_with_individual_fallback(
original_lines: list[str], search_replace_blocks: list[tuple[list[str], list[str]]]
) -> tuple[list[str], set[str]]:
outputs = FileEditInput(original_lines, 0, search_replace_blocks, 0).edit_file()
best_matches = FileEditOutput.get_best_match(outputs)
try:
edited_content, comments_ = best_matches[0].replace_or_throw(3)
except SearchReplaceMatchError:
if len(search_replace_blocks) > 1:
# Try one at a time
all_comments = set[str]()
running_lines = list(original_lines)
for block in search_replace_blocks:
running_lines, comments_ = edit_with_individual_fallback(
running_lines, [block]
)
all_comments |= comments_
return running_lines, all_comments
raise
if len(best_matches) > 1:
# Find the first block that differs across matches
first_diff_block = identify_first_differing_block(best_matches)
if first_diff_block is not None:
block_content = "\n".join(first_diff_block)
raise SearchReplaceMatchError(f"""
The following block matched more than once:
```
{block_content}
```
Consider adding more context before and after this block to make the match unique.
""")
else:
raise SearchReplaceMatchError("""
One of the blocks matched more than once
Consider adding more context before and after all the blocks to make the match unique.
""")
return edited_content, comments_
```
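A hedged sketch of calling `search_replace_edit` with a single SEARCH/REPLACE block; the file content is made up for illustration, and the matching behaviour ultimately comes from `diff_edit`.

```python
# Usage sketch for search_replace_edit; the file content here is made up.
from wcgw.client.file_ops.search_replace import search_replace_edit

original = "def greet():\n    print('hi')\n"
block_lines = [
    "<<<<<<< SEARCH",
    "    print('hi')",
    "=======",
    "    print('hello')",
    ">>>>>>> REPLACE",
]
edited, comments = search_replace_edit(block_lines, original, logger=lambda s: None)
print(edited)    # original content with the searched lines replaced
print(comments)  # "Edited successfully" when every block matched uniquely
```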
--------------------------------------------------------------------------------
/tests/test_bg_commands.py:
--------------------------------------------------------------------------------
```python
"""
Tests for background command execution feature.
"""
import tempfile
from typing import Generator
import pytest
from wcgw.client.bash_state.bash_state import BashState
from wcgw.client.tools import (
BashCommand,
Context,
Initialize,
default_enc,
get_tool_output,
)
from wcgw.types_ import (
Command,
Console,
SendSpecials,
StatusCheck,
)
class TestConsole(Console):
def __init__(self):
self.logs = []
self.prints = []
def log(self, msg: str) -> None:
self.logs.append(msg)
def print(self, msg: str) -> None:
self.prints.append(msg)
@pytest.fixture
def temp_dir() -> Generator[str, None, None]:
"""Provides a temporary directory for testing."""
with tempfile.TemporaryDirectory() as td:
yield td
@pytest.fixture
def context(temp_dir: str) -> Generator[Context, None, None]:
"""Provides a test context with temporary directory and handles cleanup."""
console = TestConsole()
bash_state = BashState(
console=console,
working_dir=temp_dir,
bash_command_mode=None,
file_edit_mode=None,
write_if_empty_mode=None,
mode=None,
use_screen=True,
)
ctx = Context(
bash_state=bash_state,
console=console,
)
# Initialize once for all tests
init_args = Initialize(
type="first_call",
any_workspace_path=temp_dir,
initial_files_to_read=[],
task_id_to_resume="",
mode_name="wcgw",
code_writer_config=None,
thread_id="",
)
get_tool_output(
ctx, init_args, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
yield ctx
# Cleanup after each test
try:
bash_state.sendintr()
bash_state.cleanup()
except Exception as e:
print(f"Error during cleanup: {e}")
def test_bg_command_basic(context: Context, temp_dir: str) -> None:
"""Test basic background command execution."""
# Start a background command
cmd = BashCommand(
action_json=Command(command="sleep 2", is_background=True),
wait_for_seconds=0.1,
thread_id=context.bash_state._current_thread_id,
)
outputs, _ = get_tool_output(
context, cmd, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
assert len(outputs) == 1
assert "bg_command_id" in outputs[0]
assert "status = still running" in outputs[0]
# Extract bg_command_id from output
bg_id = None
for line in outputs[0].split("\n"):
if "bg_command_id" in line:
bg_id = line.split("=")[1].strip()
break
assert bg_id is not None
assert len(context.bash_state.background_shells) == 1
def test_bg_command_status_check(context: Context, temp_dir: str) -> None:
"""Test checking status of background command."""
# Start a background command
cmd = BashCommand(
action_json=Command(command="sleep 1", is_background=True),
wait_for_seconds=0.1,
thread_id=context.bash_state._current_thread_id,
)
outputs, _ = get_tool_output(
context, cmd, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
# Extract bg_command_id
bg_id = None
for line in outputs[0].split("\n"):
if "bg_command_id" in line:
bg_id = line.split("=")[1].strip()
break
assert bg_id is not None
# Check status of background command
status_cmd = BashCommand(
action_json=StatusCheck(status_check=True, bg_command_id=bg_id),
thread_id=context.bash_state._current_thread_id,
)
outputs, _ = get_tool_output(
context, status_cmd, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
assert len(outputs) == 1
assert "status = process exited" in outputs[0]
def test_bg_command_invalid_id(context: Context, temp_dir: str) -> None:
"""Test error handling for invalid bg_command_id."""
# Try to check status with invalid bg_command_id
status_cmd = BashCommand(
action_json=StatusCheck(status_check=True, bg_command_id="invalid_id"),
thread_id=context.bash_state._current_thread_id,
)
try:
outputs, _ = get_tool_output(
context, status_cmd, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
assert False, "Expected exception for invalid bg_command_id"
except Exception as e:
assert "No shell found running with command id" in str(e)
def test_bg_command_interrupt(context: Context, temp_dir: str) -> None:
"""Test interrupting a background command."""
# Start a background command
cmd = BashCommand(
action_json=Command(command="sleep 5", is_background=True),
wait_for_seconds=0.1,
thread_id=context.bash_state._current_thread_id,
)
outputs, _ = get_tool_output(
context, cmd, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
# Extract bg_command_id
bg_id = None
for line in outputs[0].split("\n"):
if "bg_command_id" in line:
bg_id = line.split("=")[1].strip()
break
assert bg_id is not None
# Send Ctrl-C to background command
interrupt_cmd = BashCommand(
action_json=SendSpecials(send_specials=["Ctrl-c"], bg_command_id=bg_id),
thread_id=context.bash_state._current_thread_id,
)
outputs, _ = get_tool_output(
context, interrupt_cmd, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
assert len(outputs) == 1
assert "status = process exited" in outputs[0]
def test_multiple_bg_commands(context: Context, temp_dir: str) -> None:
"""Test running multiple background commands simultaneously."""
# Start first background command
cmd1 = BashCommand(
action_json=Command(command="sleep 2", is_background=True),
wait_for_seconds=0.1,
thread_id=context.bash_state._current_thread_id,
)
outputs1, _ = get_tool_output(
context, cmd1, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
# Start second background command
cmd2 = BashCommand(
action_json=Command(command="sleep 2", is_background=True),
wait_for_seconds=0.1,
thread_id=context.bash_state._current_thread_id,
)
outputs2, _ = get_tool_output(
context, cmd2, default_enc, 1.0, lambda x, y: ("", 0.0), 8000, 4000
)
# Verify both commands are running
assert len(context.bash_state.background_shells) == 2
assert "bg_command_id" in outputs1[0]
assert "bg_command_id" in outputs2[0]
# Extract both bg_command_ids
bg_ids = []
for output in [outputs1[0], outputs2[0]]:
for line in output.split("\n"):
if "bg_command_id" in line:
bg_ids.append(line.split("=")[1].strip())
break
assert len(bg_ids) == 2
assert bg_ids[0] != bg_ids[1]
```
--------------------------------------------------------------------------------
/tests/test_mcp_server.py:
--------------------------------------------------------------------------------
```python
import os
from unittest.mock import AsyncMock, Mock, patch
import pytest
from mcp.server.models import InitializationOptions
from mcp.types import (
GetPromptResult,
Prompt,
PromptMessage,
TextContent,
)
from mcp.types import Tool as ToolParam
from pydantic import ValidationError
from wcgw.client.bash_state.bash_state import CONFIG, BashState
from wcgw.client.mcp_server import server
from wcgw.client.mcp_server.server import (
Console,
handle_call_tool,
handle_get_prompt,
handle_list_prompts,
handle_list_resources,
handle_list_tools,
handle_read_resource,
main,
)
# Reset server.BASH_STATE before all tests
@pytest.fixture(scope="function", autouse=True)
def setup_bash_state():
"""Setup BashState for each test"""
# Update CONFIG immediately
CONFIG.update(3, 55, 5)
# Create new BashState with mode
home_dir = os.path.expanduser("~")
bash_state = BashState(Console(), home_dir, None, None, None, "wcgw", False, None)
server.BASH_STATE = bash_state
try:
yield server.BASH_STATE
finally:
try:
bash_state.cleanup()
except Exception as e:
print(f"Error during cleanup: {e}")
server.BASH_STATE = None
@pytest.mark.asyncio
async def test_handle_list_resources(setup_bash_state):
resources = await handle_list_resources()
assert isinstance(resources, list)
assert len(resources) == 0
@pytest.mark.asyncio
async def test_handle_read_resource(setup_bash_state):
with pytest.raises(ValueError, match="No resources available"):
await handle_read_resource("http://example.com")
@pytest.mark.asyncio
async def test_handle_list_prompts(setup_bash_state):
prompts = await handle_list_prompts()
assert isinstance(prompts, list)
assert len(prompts) > 0
assert isinstance(prompts[0], Prompt)
assert "KnowledgeTransfer" in [p.name for p in prompts]
# Test prompt structure
kt_prompt = next(p for p in prompts if p.name == "KnowledgeTransfer")
assert (
kt_prompt.description
== "Prompt for invoking ContextSave tool in order to do a comprehensive knowledge transfer of a coding task. Prompts to save detailed error log and instructions."
)
@pytest.mark.asyncio
async def test_handle_get_prompt(setup_bash_state):
# Test valid prompt
result = await handle_get_prompt("KnowledgeTransfer", None)
assert isinstance(result, GetPromptResult)
assert len(result.messages) == 1
assert isinstance(result.messages[0], PromptMessage)
assert result.messages[0].role == "user"
assert isinstance(result.messages[0].content, TextContent)
# Test invalid prompt
with pytest.raises(KeyError):
await handle_get_prompt("NonExistentPrompt", None)
# Test with arguments
result = await handle_get_prompt("KnowledgeTransfer", {"arg": "value"})
assert isinstance(result, GetPromptResult)
@pytest.mark.asyncio
async def test_handle_list_tools():
print("Running test_handle_list_tools")
tools = await handle_list_tools()
assert isinstance(tools, list)
assert len(tools) > 0
# Check all required tools are present
tool_names = {tool.name for tool in tools}
required_tools = {
"Initialize",
"BashCommand",
"ReadFiles",
"ReadImage",
"FileWriteOrEdit",
"ContextSave",
}
assert required_tools.issubset(tool_names), (
f"Missing tools: {required_tools - tool_names}"
)
# Test each tool's schema and description
for tool in tools:
assert isinstance(tool, ToolParam)
assert tool.inputSchema is not None
assert isinstance(tool.description, str)
assert len(tool.description.strip()) > 0
# Test specific tool properties based on tool type
if tool.name == "Initialize":
properties = tool.inputSchema["properties"]
assert "mode_name" in properties
assert properties["mode_name"]["enum"] == [
"wcgw",
"architect",
"code_writer",
]
assert "any_workspace_path" in properties
assert properties["any_workspace_path"]["type"] == "string"
assert "initial_files_to_read" in properties
assert properties["initial_files_to_read"]["type"] == "array"
elif tool.name == "BashCommand":
properties = tool.inputSchema["properties"]
assert "action_json" in properties
assert "wait_for_seconds" in properties
# Check type field has all the command types
type_properties = tool.inputSchema["$defs"]["ActionJsonSchema"][
"properties"
]
type_refs = set(type_properties)
required_types = {
"command",
"status_check",
"send_text",
"send_specials",
"send_ascii",
}
assert required_types.issubset(type_refs)
elif tool.name == "FileWriteOrEdit":
properties = tool.inputSchema["properties"]
assert "file_path" in properties
assert "text_or_search_replace_blocks" in properties
@pytest.mark.asyncio
async def test_handle_call_tool(setup_bash_state):
# Test missing arguments
with pytest.raises(ValueError, match="Missing arguments"):
await handle_call_tool("Initialize", None)
# Test Initialize tool with valid arguments
init_args = {
"any_workspace_path": "",
"initial_files_to_read": [],
"task_id_to_resume": "",
"mode_name": "wcgw",
"type": "first_call",
"thread_id": "",
}
result = await handle_call_tool("Initialize", init_args)
assert isinstance(result, list)
assert len(result) > 0
assert isinstance(result[0], TextContent)
assert "Initialize" in result[0].text
# Test JSON string argument handling
json_args = {
"action_json": {"command": "ls"},
"wait_for_seconds": None,
"thread_id": "",
}
result = await handle_call_tool("BashCommand", json_args)
assert isinstance(result, list)
# Test validation error handling
with pytest.raises(ValidationError):
invalid_args = {
"any_workspace_path": 123, # Invalid type
"initial_files_to_read": [],
"task_id_to_resume": "",
"mode_name": "wcgw",
}
await handle_call_tool("Initialize", invalid_args)
# Test tool exception handling
with patch(
"wcgw.client.mcp_server.server.get_tool_output",
side_effect=Exception("Test error"),
):
result = await handle_call_tool(
"BashCommand",
{
"action_json": {"command": "ls"},
"wait_for_seconds": None,
"thread_id": "",
},
)
assert "GOT EXCEPTION" in result[0].text
@pytest.mark.asyncio
async def test_handle_call_tool_image_response(setup_bash_state):
# Test handling of image content
mock_image_data = "fake_image_data"
mock_media_type = "image/png"
# Create a mock image object that matches the expected response
mock_image = Mock()
mock_image.data = mock_image_data
mock_image.media_type = mock_media_type
with patch(
"wcgw.client.mcp_server.server.get_tool_output",
return_value=([mock_image], None),
):
result = await handle_call_tool("ReadImage", {"file_path": "test.png"})
assert result[0].data == mock_image_data
assert result[0].mimeType == mock_media_type
@pytest.mark.asyncio
async def test_main(setup_bash_state):
CONFIG.update(3, 55, 5) # Ensure CONFIG is set before main()
# Mock the version function
with patch("importlib.metadata.version", return_value="1.0.0") as mock_version:
# Mock the stdio server
mock_read_stream = AsyncMock()
mock_write_stream = AsyncMock()
mock_context = AsyncMock()
mock_context.__aenter__.return_value = (mock_read_stream, mock_write_stream)
with patch("mcp.server.stdio.stdio_server", return_value=mock_context):
# Mock server.run to prevent actual server start
with patch("wcgw.client.mcp_server.server.server.run") as mock_run:
await main()
# Verify CONFIG update
assert CONFIG.timeout == 3
assert CONFIG.timeout_while_output == 55
assert CONFIG.output_wait_patience == 5
# Verify server run was called with correct initialization
mock_run.assert_called_once()
init_options = mock_run.call_args[0][2]
assert isinstance(init_options, InitializationOptions)
assert init_options.server_name == "wcgw"
assert init_options.server_version == "1.0.0"
```
--------------------------------------------------------------------------------
/src/wcgw/client/repo_ops/repo_context.py:
--------------------------------------------------------------------------------
```python
import os
from collections import deque
from pathlib import Path # Still needed for other parts
from typing import Optional
from pygit2 import GitError, Repository
from pygit2.enums import SortMode
from .display_tree import DirectoryTree
from .file_stats import load_workspace_stats
from .path_prob import FastPathAnalyzer
curr_folder = Path(__file__).parent
vocab_file = curr_folder / "paths_model.vocab"
model_file = curr_folder / "paths_tokens.model"
PATH_SCORER = FastPathAnalyzer(str(model_file), str(vocab_file))
def find_ancestor_with_git(path: Path) -> Optional[Repository]:
if path.is_file():
path = path.parent
try:
return Repository(str(path))
except GitError:
return None
MAX_ENTRIES_CHECK = 100_000
def get_all_files_max_depth(
abs_folder: str,
max_depth: int,
repo: Optional[Repository],
) -> list[str]:
"""BFS implementation using deque that maintains relative paths during traversal.
Returns the list of relative file paths found, capped by MAX_ENTRIES_CHECK."""
all_files = []
# Queue stores: (folder_path, depth, rel_path_prefix)
queue = deque([(abs_folder, 0, "")])
entries_check = 0
while queue and entries_check < MAX_ENTRIES_CHECK:
current_folder, depth, prefix = queue.popleft()
if depth > max_depth:
continue
try:
entries = list(os.scandir(current_folder))
except PermissionError:
continue
except OSError:
continue
# Split into files and folders with single scan
files = []
folders = []
for entry in entries:
entries_check += 1
try:
is_file = entry.is_file(follow_symlinks=False)
except OSError:
continue
name = entry.name
rel_path = f"{prefix}{name}" if prefix else name
if repo and repo.path_is_ignored(rel_path):
continue
if is_file:
files.append(rel_path)
else:
folders.append((entry.path, rel_path))
# Process files first (maintain priority)
chunk = files[: min(10_000, max(0, MAX_ENTRIES_CHECK - entries_check))]
all_files.extend(chunk)
# Add folders to queue for BFS traversal
for folder_path, folder_rel_path in folders:
next_prefix = f"{folder_rel_path}/"
queue.append((folder_path, depth + 1, next_prefix))
return all_files
def get_recent_git_files(repo: Repository, count: int = 10) -> list[str]:
"""
Get the most recently modified files from git history
Args:
repo: The git repository
count: Number of recent files to return
Returns:
List of relative paths to recently modified files
"""
# Track seen files to avoid duplicates
seen_files: set[str] = set()
recent_files: list[str] = []
try:
# Get the HEAD reference and walk through recent commits
head = repo.head
for commit in repo.walk(head.target, SortMode.TOPOLOGICAL | SortMode.TIME):
# Skip merge commits which have multiple parents
if len(commit.parents) > 1:
continue
# If we have a parent, get the diff between the commit and its parent
if commit.parents:
parent = commit.parents[0]
diff = repo.diff(parent, commit) # type: ignore[attr-defined]
else:
# For the first commit, get the diff against an empty tree
diff = commit.tree.diff_to_tree(context_lines=0)
# Process each changed file in the diff
for patch in diff:
file_path = patch.delta.new_file.path
# Skip if we've already seen this file or if the file was deleted
repo_path_parent = Path(repo.path).parent
if (
file_path in seen_files
or not (repo_path_parent / file_path).exists()
):
continue
seen_files.add(file_path)
recent_files.append(file_path)
# If we have enough files, stop
if len(recent_files) >= count:
return recent_files
except Exception:
# Handle git errors gracefully
pass
return recent_files
def calculate_dynamic_file_limit(total_files: int) -> int:
# Scale linearly, with minimum and maximum bounds
min_files = 50
max_files = 400
if total_files <= min_files:
return min_files
scale_factor = (max_files - min_files) / (30000 - min_files)
dynamic_limit = min_files + int((total_files - min_files) * scale_factor)
return min(max_files, dynamic_limit)
def get_repo_context(file_or_repo_path: str) -> tuple[str, Path]:
file_or_repo_path_ = Path(file_or_repo_path).absolute()
repo = find_ancestor_with_git(file_or_repo_path_)
recent_git_files: list[str] = []
# Determine the context directory
if repo is not None:
context_dir = Path(repo.path).parent
else:
if file_or_repo_path_.is_file():
context_dir = file_or_repo_path_.parent
else:
context_dir = file_or_repo_path_
# Load workspace stats from the context directory
workspace_stats = load_workspace_stats(str(context_dir))
# Get all files and calculate dynamic max files limit once
all_files = get_all_files_max_depth(str(context_dir), 10, repo)
# For Git repositories, get recent files
if repo is not None:
dynamic_max_files = calculate_dynamic_file_limit(len(all_files))
# Get recent git files - get at least 10 or 20% of dynamic_max_files, whichever is larger
recent_files_count = max(10, int(dynamic_max_files * 0.2))
recent_git_files = get_recent_git_files(repo, recent_files_count)
else:
# We don't want dynamic limit for non git folders like /tmp or ~
dynamic_max_files = 50
# Calculate probabilities in batch
path_scores = PATH_SCORER.calculate_path_probabilities_batch(all_files)
# Create list of (path, score) tuples and sort by score
path_with_scores = list(zip(all_files, (score[0] for score in path_scores)))
sorted_files = [
path for path, _ in sorted(path_with_scores, key=lambda x: x[1], reverse=True)
]
# Start with recent git files, then add other important files
top_files = []
# If we have workspace stats, prioritize the most active files first
active_files = []
if workspace_stats is not None:
# Get files with activity score (weighted count of operations)
scored_files = []
for file_path, file_stats in workspace_stats.files.items():
try:
# Convert to relative path if possible
if str(context_dir) in file_path:
rel_path = os.path.relpath(file_path, str(context_dir))
else:
rel_path = file_path
# Calculate activity score - weight reads more for this functionality
activity_score = (
file_stats.read_count * 2
+ (file_stats.edit_count)
+ (file_stats.write_count)
)
# Only include files that still exist
if rel_path in all_files or os.path.exists(file_path):
scored_files.append((rel_path, activity_score))
except (ValueError, OSError):
# Skip files that cause path resolution errors
continue
# Sort by activity score (highest first) and get top 5
active_files = [
f for f, _ in sorted(scored_files, key=lambda x: x[1], reverse=True)[:5]
]
# Add active files first
for file in active_files:
if file not in top_files and file in all_files:
top_files.append(file)
# Add recent git files next - these should be prioritized
for file in recent_git_files:
if file not in top_files and file in all_files:
top_files.append(file)
# Use statistical sorting for the remaining files, but respect dynamic_max_files limit
# and ensure we don't add duplicates
if len(top_files) < dynamic_max_files:
# Only add statistically important files that aren't already in top_files
for file in sorted_files:
if file not in top_files and len(top_files) < dynamic_max_files:
top_files.append(file)
directory_printer = DirectoryTree(context_dir, max_files=dynamic_max_files)
for file in top_files[:dynamic_max_files]:
directory_printer.expand(file)
return directory_printer.display(), context_dir
if __name__ == "__main__":
import cProfile
import pstats
import sys
from line_profiler import LineProfiler
folder = sys.argv[1]
# Profile using cProfile for overall function statistics
profiler = cProfile.Profile()
profiler.enable()
result = get_repo_context(folder)[0]
profiler.disable()
# Print cProfile stats
stats = pstats.Stats(profiler)
stats.sort_stats("cumulative")
print("\n=== Function-level profiling ===")
stats.print_stats(20) # Print top 20 functions
# Profile using line_profiler for line-by-line statistics
lp = LineProfiler()
lp_wrapper = lp(get_repo_context)
lp_wrapper(folder)
print("\n=== Line-by-line profiling ===")
lp.print_stats()
print("\n=== Result ===")
print(result)
```
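Two quick checks that follow from the code above: `calculate_dynamic_file_limit` is clamped to the [50, 400] range, and `get_repo_context` returns the rendered tree plus the resolved context directory. The project path below is hypothetical.

```python
# Quick checks for the helpers above; the project path is hypothetical.
from wcgw.client.repo_ops.repo_context import calculate_dynamic_file_limit, get_repo_context

print(calculate_dynamic_file_limit(10))      # 50  -> floor for small repositories
print(calculate_dynamic_file_limit(30_000))  # 400 -> cap reached around 30k files

tree_text, context_dir = get_repo_context("/path/to/project")
print(context_dir)
print(tree_text)
```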
--------------------------------------------------------------------------------
/src/wcgw/types_.py:
--------------------------------------------------------------------------------
```python
import os
from typing import Any, List, Literal, Optional, Protocol, Sequence, Union
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field, PrivateAttr
class NoExtraArgs(PydanticBaseModel):
class Config:
extra = "forbid"
BaseModel = NoExtraArgs
Modes = Literal["wcgw", "architect", "code_writer"]
class CodeWriterMode(BaseModel):
allowed_globs: Literal["all"] | list[str]
allowed_commands: Literal["all"] | list[str]
def model_post_init(self, _: Any) -> None:
# Patch frequently wrong output trading off accuracy
# in rare case there's a file named 'all' or a command named 'all'
if len(self.allowed_commands) == 1:
if self.allowed_commands[0] == "all":
self.allowed_commands = "all"
if len(self.allowed_globs) == 1:
if self.allowed_globs[0] == "all":
self.allowed_globs = "all"
def update_relative_globs(self, workspace_root: str) -> None:
"""Update globs if they're relative paths"""
if self.allowed_globs != "all":
self.allowed_globs = [
glob if os.path.isabs(glob) else os.path.join(workspace_root, glob)
for glob in self.allowed_globs
]
ModesConfig = Union[Literal["wcgw", "architect"], CodeWriterMode]
class Initialize(BaseModel):
type: Literal[
"first_call",
"user_asked_mode_change",
"reset_shell",
"user_asked_change_workspace",
]
any_workspace_path: str = Field(
description="Workspce to initialise in. Don't use ~ by default, instead use empty string"
)
initial_files_to_read: list[str]
task_id_to_resume: str
mode_name: Literal["wcgw", "architect", "code_writer"]
thread_id: str = Field(
description="Use the thread_id created in first_call, leave it as empty string if first_call"
)
code_writer_config: Optional[CodeWriterMode] = None
def model_post_init(self, __context: Any) -> None:
if self.mode_name == "code_writer":
assert self.code_writer_config is not None, (
"code_writer_config can't be null when the mode is code_writer"
)
if self.type != "first_call" and not self.thread_id:
raise ValueError(
"Thread id should be provided if type != 'first_call', including when resetting"
)
return super().model_post_init(__context)
@property
def mode(self) -> ModesConfig:
if self.mode_name == "wcgw":
return "wcgw"
if self.mode_name == "architect":
return "architect"
assert self.code_writer_config is not None, (
"code_writer_config can't be null when the mode is code_writer"
)
return self.code_writer_config
class Command(BaseModel):
command: str
type: Literal["command"] = "command"
is_background: bool = False
class StatusCheck(BaseModel):
status_check: Literal[True] = True
type: Literal["status_check"] = "status_check"
bg_command_id: str | None = None
class SendText(BaseModel):
send_text: str
type: Literal["send_text"] = "send_text"
bg_command_id: str | None = None
Specials = Literal[
"Enter", "Key-up", "Key-down", "Key-left", "Key-right", "Ctrl-c", "Ctrl-d"
]
class SendSpecials(BaseModel):
send_specials: Sequence[Specials]
type: Literal["send_specials"] = "send_specials"
bg_command_id: str | None = None
class SendAscii(BaseModel):
send_ascii: Sequence[int]
type: Literal["send_ascii"] = "send_ascii"
bg_command_id: str | None = None
class ActionJsonSchema(BaseModel):
type: Literal[
"command", "status_check", "send_text", "send_specials", "send_ascii"
] = Field(description="type of action.")
command: Optional[str] = Field(
default=None, description='Set only if type="command"'
)
status_check: Optional[Literal[True]] = Field(
default=None, description='Set only if type="status_check"'
)
send_text: Optional[str] = Field(
default=None, description='Set only if type="send_text"'
)
send_specials: Optional[Sequence[Specials]] = Field(
default=None, description='Set only if type="send_specials"'
)
send_ascii: Optional[Sequence[int]] = Field(
default=None, description='Set only if type="send_ascii"'
)
is_background: bool = Field(
default=False,
description='Set only if type="command" and running the command in background',
)
bg_command_id: str | None = Field(
default=None,
description='Set only if type!="command" and doing action on a running background command',
)
class BashCommandOverride(BaseModel):
action_json: ActionJsonSchema
wait_for_seconds: Optional[float] = None
thread_id: str
class BashCommand(BaseModel):
action_json: Command | StatusCheck | SendText | SendSpecials | SendAscii
wait_for_seconds: Optional[float] = None
thread_id: str
@staticmethod
def model_json_schema(*args, **kwargs) -> dict[str, Any]: # type: ignore
return BashCommandOverride.model_json_schema(*args, **kwargs)
class ReadImage(BaseModel):
file_path: str
class WriteIfEmpty(BaseModel):
file_path: str
file_content: str
class ReadFiles(BaseModel):
file_paths: list[str]
_start_line_nums: List[Optional[int]] = PrivateAttr(default_factory=lambda: [])
_end_line_nums: List[Optional[int]] = PrivateAttr(default_factory=lambda: [])
@property
def show_line_numbers_reason(self) -> str:
return "True"
@property
def start_line_nums(self) -> List[Optional[int]]:
"""Get the start line numbers."""
return self._start_line_nums
@property
def end_line_nums(self) -> List[Optional[int]]:
"""Get the end line numbers."""
return self._end_line_nums
def model_post_init(self, __context: Any) -> None:
# Parse file paths for line ranges and store them in private attributes
self._start_line_nums = []
self._end_line_nums = []
# Create new file_paths list without line ranges
clean_file_paths = []
for file_path in self.file_paths:
start_line_num = None
end_line_num = None
path_part = file_path
# Check if the path ends with a line range pattern
# We're looking for patterns at the very end of the path like:
# - file.py:10 (specific line)
# - file.py:10-20 (line range)
# - file.py:10- (from line 10 to end)
# - file.py:-20 (from start to line 20)
# Split by the last colon
if ":" in file_path:
parts = file_path.rsplit(":", 1)
if len(parts) == 2:
potential_path = parts[0]
line_spec = parts[1]
# Check if it's a valid line range format
if line_spec.isdigit():
# Format: file.py:10
try:
start_line_num = int(line_spec)
path_part = potential_path
except ValueError:
# Keep the original path if conversion fails
pass
elif "-" in line_spec:
# Could be file.py:10-20, file.py:10-, or file.py:-20
line_parts = line_spec.split("-", 1)
if not line_parts[0] and line_parts[1].isdigit():
# Format: file.py:-20
try:
end_line_num = int(line_parts[1])
path_part = potential_path
except ValueError:
# Keep original path
pass
elif line_parts[0].isdigit():
# Format: file.py:10-20 or file.py:10-
try:
start_line_num = int(line_parts[0])
if line_parts[1].isdigit():
# file.py:10-20
end_line_num = int(line_parts[1])
# In both cases, update the path
path_part = potential_path
except ValueError:
# Keep original path
pass
# Add clean path and corresponding line numbers
clean_file_paths.append(path_part)
self._start_line_nums.append(start_line_num)
self._end_line_nums.append(end_line_num)
# Update file_paths with clean paths
self.file_paths = clean_file_paths
return super().model_post_init(__context)
class FileEdit(BaseModel):
file_path: str
file_edit_using_search_replace_blocks: str
class FileWriteOrEdit(BaseModel):
# Field names should be in sorted order, otherwise their order gets changed in the LLM backend.
file_path: str = Field(description="#1: absolute file path")
percentage_to_change: int = Field(
description="#2: predict this percentage, calculated as number of existing lines that will have some diff divided by total existing lines."
)
text_or_search_replace_blocks: str = Field(
description="#3: content/edit blocks. Must be after #2 in the tool xml"
)
thread_id: str = Field(description="#4: thread_id")
class ContextSave(BaseModel):
id: str
project_root_path: str
description: str
relevant_file_globs: list[str]
class Console(Protocol):
def print(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
def log(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
class Mdata(PydanticBaseModel):
data: BashCommand | FileWriteOrEdit | str | ReadFiles | Initialize | ContextSave
```
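As a quick illustration of the line-range parsing implemented in `ReadFiles.model_post_init` above, here is a minimal sketch (not part of the repository, assuming the package is importable):
```python
from wcgw.types_ import ReadFiles

# Trailing ":line" / ":start-end" specs are stripped from the paths and
# stored separately in the private start/end line-number attributes.
rf = ReadFiles(file_paths=["/tmp/a.py:10-20", "/tmp/b.py:-5", "/tmp/c.py"])
assert rf.file_paths == ["/tmp/a.py", "/tmp/b.py", "/tmp/c.py"]
assert rf.start_line_nums == [10, None, None]
assert rf.end_line_nums == [20, 5, None]
```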
--------------------------------------------------------------------------------
/src/wcgw/client/modes.py:
--------------------------------------------------------------------------------
```python
from dataclasses import dataclass
from typing import Any, Literal, NamedTuple
from ..types_ import Modes, ModesConfig
class BashCommandMode(NamedTuple):
bash_mode: Literal["normal_mode", "restricted_mode"]
allowed_commands: Literal["all", "none"]
def serialize(self) -> dict[str, Any]:
return {"bash_mode": self.bash_mode, "allowed_commands": self.allowed_commands}
@classmethod
def deserialize(cls, data: dict[str, Any]) -> "BashCommandMode":
return cls(data["bash_mode"], data["allowed_commands"])
class FileEditMode(NamedTuple):
allowed_globs: Literal["all"] | list[str]
def serialize(self) -> dict[str, Any]:
return {"allowed_globs": self.allowed_globs}
@classmethod
def deserialize(cls, data: dict[str, Any]) -> "FileEditMode":
return cls(data["allowed_globs"])
class WriteIfEmptyMode(NamedTuple):
allowed_globs: Literal["all"] | list[str]
def serialize(self) -> dict[str, Any]:
return {"allowed_globs": self.allowed_globs}
@classmethod
def deserialize(cls, data: dict[str, Any]) -> "WriteIfEmptyMode":
return cls(data["allowed_globs"])
@dataclass
class ModeImpl:
bash_command_mode: BashCommandMode
file_edit_mode: FileEditMode
write_if_empty_mode: WriteIfEmptyMode
def code_writer_prompt(
allowed_file_edit_globs: Literal["all"] | list[str],
all_write_new_globs: Literal["all"] | list[str],
allowed_commands: Literal["all"] | list[str],
) -> str:
base = """
You are now running in "code_writer" mode.
"""
path_prompt = """
- You are allowed to edit files in the provided repository only.
"""
if allowed_file_edit_globs != "all":
if allowed_file_edit_globs:
path_prompt = f"""
- You are allowed to edit only files matching the following globs: {", ".join(allowed_file_edit_globs)}
"""
else:
path_prompt = """
- You are not allowed to edit files.
"""
base += path_prompt
path_prompt = """
- You are allowed to write files in the provided repository only.
"""
if all_write_new_globs != "all":
if all_write_new_globs:
path_prompt = f"""
- You are allowed to write only files matching the following globs: {", ".join(all_write_new_globs)}
"""
else:
path_prompt = """
- You are not allowed to write files.
"""
base += path_prompt
run_command_common = """
- Do not use Ctrl-c or interrupt commands without asking the user, because often the programs don't show any update but they still are running.
- Do not use echo/cat to write any file, always use FileWriteOrEdit tool to create/update files.
- Do not provide code snippets unless asked by the user, instead directly add/edit the code.
- You should use the provided bash execution, file reading and file writing tools to complete the objective.
- Do not use artifacts if you have access to the repository and are not asked by the user to provide artifacts/snippets. Directly create/update files using the wcgw tools.
"""
command_prompt = f"""
- You are only allowed to run commands for project setup, code writing, editing, updating, testing, running and debugging related to the project.
- Do not run anything that adds or removes packages, changes system configuration or environment.
{run_command_common}
"""
if allowed_commands != "all":
if allowed_commands:
command_prompt = f"""
- You are only allowed to run the following commands: {", ".join(allowed_commands)}
{run_command_common}
"""
else:
command_prompt = """
- You are not allowed to run any commands.
"""
base += command_prompt
return base
WCGW_PROMPT = """
# Instructions
- You should use the provided bash execution, file reading and file writing tools to complete the objective.
- Do not provide code snippets unless asked by the user, instead directly add/edit the code.
- Do not install new tools/packages before ensuring no such tool/package or an alternative already exists.
- Do not use artifacts if you have access to the repository and are not asked by the user to provide artifacts/snippets. Directly create/update files using the wcgw tools.
- Do not use Ctrl-c or interrupt commands without asking the user, because often the programs don't show any update but they still are running.
- Do not use echo/cat to write any file, always use FileWriteOrEdit tool to create/update files.
- You can share task summary directly without creating any file.
- Provide as many file paths as you need in ReadFiles in one go.
Additional instructions:
Always run `pwd` if you get any 'file or directory not found' error, to make sure you're not lost or to get the absolute cwd.
"""
ARCHITECT_PROMPT = """
# Instructions
You are now running in "architect" mode. This means
- You are not allowed to edit or update any file. You are not allowed to create any file.
- You are not allowed to run any commands that may change disk, system configuration, packages or environment. Only read-only commands are allowed.
- Only run commands that allow you to explore the repository, understand the system or read anything of relevance.
- Do not use Ctrl-c or interrupt commands without asking the user, because often the programs don't show any update but they still are running.
- You are not allowed to change directory (bash will run in -r mode)
- Share only snippets when any implementation is requested.
- Provide as many file paths as you need in ReadFiles in one go.
# Disallowed tools (important!)
- FileWriteOrEdit
# Response instructions
Respond only after doing the following:
- Read as many relevant files as possible.
- Be comprehensive in your understanding and search of relevant files.
- First understand the project by getting the folder structure (ignoring .git, node_modules, venv, etc.)
- Share minimal snippets highlighting the changes (avoid a large number of lines in the snippets, use ... comments)
"""
DEFAULT_MODES: dict[Modes, ModeImpl] = {
"wcgw": ModeImpl(
bash_command_mode=BashCommandMode("normal_mode", "all"),
write_if_empty_mode=WriteIfEmptyMode("all"),
file_edit_mode=FileEditMode("all"),
),
"architect": ModeImpl(
bash_command_mode=BashCommandMode("restricted_mode", "all"),
write_if_empty_mode=WriteIfEmptyMode([]),
file_edit_mode=FileEditMode([]),
),
"code_writer": ModeImpl(
bash_command_mode=BashCommandMode("normal_mode", "all"),
write_if_empty_mode=WriteIfEmptyMode("all"),
file_edit_mode=FileEditMode("all"),
),
}
def modes_to_state(
mode: ModesConfig,
) -> tuple[BashCommandMode, FileEditMode, WriteIfEmptyMode, Modes]:
# First get default mode config
if isinstance(mode, str):
mode_impl = DEFAULT_MODES[mode] # the str here is already one of the Modes literals
mode_name: Modes = mode
else:
# For CodeWriterMode, use code_writer as base and override
mode_impl = DEFAULT_MODES["code_writer"]
# Override with custom settings from CodeWriterMode
mode_impl = ModeImpl(
bash_command_mode=BashCommandMode(
mode_impl.bash_command_mode.bash_mode,
"all" if mode.allowed_commands else "none",
),
file_edit_mode=FileEditMode(mode.allowed_globs),
write_if_empty_mode=WriteIfEmptyMode(mode.allowed_globs),
)
mode_name = "code_writer"
return (
mode_impl.bash_command_mode,
mode_impl.file_edit_mode,
mode_impl.write_if_empty_mode,
mode_name,
)
WCGW_KT = """Use `ContextSave` tool to do a knowledge transfer of the task in hand.
Write detailed description in order to do a KT.
Save all information necessary for a person to understand the task and the problems.
Format the `description` field using Markdown with the following sections.
- "# Objective" section containing project and task objective.
- "# All user instructions" section should be provided containing all instructions user shared in the conversation.
- "# Current status of the task" should be provided containing only what is already achieved, not what's remaining.
- "# Pending issues with snippets" section containing snippets of pending errors, traceback, file snippets, commands, etc. But no comments or solutions.
- Be very verbose in the pending issues with snippets section, providing as much error context as possible.
- "# Build and development instructions" section containing instructions to build or run the project or run tests, or environment-related information. Only include what's known. Leave empty if unknown.
- Any other relevant sections following the above.
- After the tool completes successfully, tell me the task id and the file path the tool generated (important!)
- This tool marks end of your conversation, do not run any further tools after calling this.
Provide all relevant file paths in order to understand and solve the task. Err towards providing more file paths rather than fewer.
(Note to self: this conversation can then be resumed later asking "Resume wcgw task `<generated id>`" which should call Initialize tool)
"""
ARCHITECT_KT = """Use `ContextSave` tool to do a knowledge transfer of the task in hand.
Write detailed description in order to do a KT.
Save all information necessary for a person to understand the task and the problems.
Format the `description` field using Markdown with the following sections.
- "# Objective" section containing project and task objective.
- "# All user instructions" section should be provided containing all instructions user shared in the conversation.
- "# Designed plan" should be provided containing the designed plan as discussed.
- Any other relevant sections following the above.
- After the tool completes successfully, tell me the task id and the file path the tool generated (important!)
- This tool marks end of your conversation, do not run any further tools after calling this.
Provide all relevant file paths in order to understand and solve the task. Err towards providing more file paths rather than fewer.
(Note to self: this conversation can then be resumed later asking "Resume wcgw task `<generated id>`" which should call Initialize tool)
"""
KTS = {"wcgw": WCGW_KT, "architect": ARCHITECT_KT, "code_writer": WCGW_KT}
```
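A minimal sketch (not part of the repository) of how `modes_to_state` maps a mode config onto the permission tuples defined above:
```python
from wcgw.client.modes import modes_to_state
from wcgw.types_ import CodeWriterMode

# Built-in mode: architect restricts bash and disallows edits/writes.
bash_mode, file_edit, write_if_empty, name = modes_to_state("architect")
assert bash_mode.bash_mode == "restricted_mode" and name == "architect"
assert file_edit.allowed_globs == [] and write_if_empty.allowed_globs == []

# Custom code_writer config: globs carry over; any non-empty command list maps to "all".
cw = CodeWriterMode(allowed_globs=["src/**/*.py"], allowed_commands=["pytest"])
bash_mode, file_edit, write_if_empty, name = modes_to_state(cw)
assert bash_mode.allowed_commands == "all"
assert file_edit.allowed_globs == ["src/**/*.py"]
assert name == "code_writer"
```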
--------------------------------------------------------------------------------
/src/wcgw_cli/openai_client.py:
--------------------------------------------------------------------------------
```python
import base64
import json
import mimetypes
import os
import subprocess
import tempfile
import traceback
import uuid
from pathlib import Path
from typing import DefaultDict, Optional, cast
import openai
import petname # type: ignore[import-untyped]
import rich
import tokenizers # type: ignore[import-untyped]
from dotenv import load_dotenv
from openai import OpenAI
from openai.types.chat import (
ChatCompletionContentPartParam,
ChatCompletionMessageParam,
ChatCompletionUserMessageParam,
)
from pydantic import BaseModel
from typer import Typer
from wcgw.client.bash_state.bash_state import BashState
from wcgw.client.common import CostData, History, Models, discard_input
from wcgw.client.memory import load_memory
from wcgw.client.tool_prompts import TOOL_PROMPTS
from wcgw.client.tools import (
Context,
ImageData,
default_enc,
get_tool_output,
initialize,
which_tool,
which_tool_name,
)
from .openai_utils import get_input_cost, get_output_cost
class Config(BaseModel):
model: Models
cost_limit: float
cost_file: dict[Models, CostData]
cost_unit: str = "$"
def text_from_editor(console: rich.console.Console) -> str:
# First consume all the input till now
discard_input()
console.print("\n---------------------------------------\n# User message")
data = input()
if data:
return data
editor = os.environ.get("EDITOR", "vim")
with tempfile.NamedTemporaryFile(suffix=".tmp") as tf:
subprocess.run([editor, tf.name], check=True)
with open(tf.name, "r") as f:
data = f.read()
console.print(data)
return data
def save_history(history: History, session_id: str) -> None:
myid = str(history[1]["content"]).replace("/", "_").replace(" ", "_").lower()[:60]
myid += "_" + session_id
myid = myid + ".json"
mypath = Path(".wcgw") / myid
mypath.parent.mkdir(parents=True, exist_ok=True)
with open(mypath, "w") as f:
json.dump(history, f, indent=3)
def parse_user_message_special(msg: str) -> ChatCompletionUserMessageParam:
# Search for lines starting with `%` and treat them as special commands
parts: list[ChatCompletionContentPartParam] = []
for line in msg.split("\n"):
if line.startswith("%"):
args = line[1:].strip().split(" ")
command = args[0]
assert command == "image"
image_path = " ".join(args[1:])
with open(image_path, "rb") as f:
image_bytes = f.read()
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
image_type = mimetypes.guess_type(image_path)[0]
dataurl = f"data:{image_type};base64,{image_b64}"
parts.append(
{"type": "image_url", "image_url": {"url": dataurl, "detail": "auto"}}
)
else:
if len(parts) > 0 and parts[-1]["type"] == "text":
parts[-1]["text"] += "\n" + line
else:
parts.append({"type": "text", "text": line})
return {"role": "user", "content": parts}
app = Typer(pretty_exceptions_show_locals=False)
@app.command()
def loop(
first_message: Optional[str] = None,
limit: Optional[float] = None,
resume: Optional[str] = None,
) -> tuple[str, float]:
load_dotenv()
session_id = str(uuid.uuid4())[:6]
history: History = []
waiting_for_assistant = False
memory = None
if resume:
try:
_, memory, _ = load_memory(
resume,
24000, # coding_max_tokens
8000, # noncoding_max_tokens
lambda x: default_enc.encoder(x),
lambda x: default_enc.decoder(x),
)
except OSError:
if resume == "latest":
resume_path = sorted(Path(".wcgw").iterdir(), key=os.path.getmtime)[-1]
else:
resume_path = Path(resume)
if not resume_path.exists():
raise FileNotFoundError(f"File {resume} not found")
with resume_path.open() as f:
history = json.load(f)
if len(history) <= 2:
raise ValueError("Invalid history file")
first_message = ""
waiting_for_assistant = history[-1]["role"] != "assistant"
my_dir = os.path.dirname(__file__)
config = Config(
model=cast(Models, os.getenv("OPENAI_MODEL", "gpt-4o-2024-08-06").lower()),
cost_limit=0.1,
cost_unit="$",
cost_file={
"gpt-4o-2024-08-06": CostData(
cost_per_1m_input_tokens=5, cost_per_1m_output_tokens=15
),
},
)
if limit is not None:
config.cost_limit = limit
limit = config.cost_limit
enc = tokenizers.Tokenizer.from_pretrained("Xenova/gpt-4o")
tools = [
openai.pydantic_function_tool(
which_tool_name(tool.name), description=tool.description
)
for tool in TOOL_PROMPTS
if tool.name != "Initialize"
]
cost: float = 0
input_toks = 0
output_toks = 0
system_console = rich.console.Console(style="blue", highlight=False, markup=False)
error_console = rich.console.Console(style="red", highlight=False, markup=False)
user_console = rich.console.Console(
style="bright_black", highlight=False, markup=False
)
assistant_console = rich.console.Console(
style="white bold", highlight=False, markup=False
)
with BashState(
system_console, os.getcwd(), None, None, None, None, True, None
) as bash_state:
context = Context(bash_state, system_console)
system, context, _ = initialize(
"first_call",
context,
os.getcwd(),
[],
resume if (memory and resume) else "",
24000, # coding_max_tokens
8000, # noncoding_max_tokens
mode="wcgw",
thread_id="",
)
if not history:
history = [{"role": "system", "content": system}]
else:
if history[-1]["role"] == "tool":
waiting_for_assistant = True
client = OpenAI()
while True:
if cost > limit:
system_console.print(
f"\nCost limit exceeded. Current cost: {cost}, input tokens: {input_toks}, output tokens: {output_toks}"
)
break
if not waiting_for_assistant:
if first_message:
msg = first_message
first_message = ""
else:
msg = text_from_editor(user_console)
history.append(parse_user_message_special(msg))
else:
waiting_for_assistant = False
cost_, input_toks_ = get_input_cost(
config.cost_file[config.model], enc, history
)
cost += cost_
input_toks += input_toks_
stream = client.chat.completions.create(
messages=history,
model=config.model,
stream=True,
tools=tools,
)
system_console.print(
"\n---------------------------------------\n# Assistant response",
style="bold",
)
tool_call_args_by_id = DefaultDict[str, DefaultDict[int, str]](
lambda: DefaultDict(str)
)
_histories: History = []
item: ChatCompletionMessageParam
full_response: str = ""
image_histories: History = []
try:
for chunk in stream:
if chunk.choices[0].finish_reason == "tool_calls":
assert tool_call_args_by_id
item = {
"role": "assistant",
"content": full_response,
"tool_calls": [
{
"id": tool_call_id + str(toolindex),
"type": "function",
"function": {
"arguments": tool_args,
"name": type(which_tool(tool_args)).__name__,
},
}
for tool_call_id, toolcallargs in tool_call_args_by_id.items()
for toolindex, tool_args in toolcallargs.items()
],
}
cost_, output_toks_ = get_output_cost(
config.cost_file[config.model], enc, item
)
cost += cost_
system_console.print(
f"\n---------------------------------------\n# Assistant invoked tools: {[which_tool(tool['function']['arguments']) for tool in item['tool_calls']]}"
)
system_console.print(
f"\nTotal cost: {config.cost_unit}{cost:.3f}"
)
output_toks += output_toks_
_histories.append(item)
for tool_call_id, toolcallargs in tool_call_args_by_id.items():
for toolindex, tool_args in toolcallargs.items():
try:
output_or_dones, cost_ = get_tool_output(
context,
json.loads(tool_args),
enc,
limit - cost,
loop,
24000, # coding_max_tokens
8000, # noncoding_max_tokens
)
output_or_done = output_or_dones[0]
except Exception as e:
output_or_done = (
f"GOT EXCEPTION while calling tool. Error: {e}"
)
tb = traceback.format_exc()
error_console.print(output_or_done + "\n" + tb)
cost_ = 0
cost += cost_
system_console.print(
f"\nTotal cost: {config.cost_unit}{cost:.3f}"
)
output = output_or_done
if isinstance(output, ImageData):
randomId = petname.Generate(2, "-")
if not image_histories:
image_histories.extend(
[
{
"role": "assistant",
"content": f"Share images with ids: {randomId}",
},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": output.dataurl,
"detail": "auto",
},
}
],
},
]
)
else:
image_histories[0]["content"] += ", " + randomId
second_content = image_histories[1]["content"]
assert isinstance(second_content, list)
second_content.append(
{
"type": "image_url",
"image_url": {
"url": output.dataurl,
"detail": "auto",
},
}
)
item = {
"role": "tool",
"content": f"Ask user for image id: {randomId}",
"tool_call_id": tool_call_id + str(toolindex),
}
else:
item = {
"role": "tool",
"content": str(output),
"tool_call_id": tool_call_id + str(toolindex),
}
cost_, output_toks_ = get_output_cost(
config.cost_file[config.model], enc, item
)
cost += cost_
output_toks += output_toks_
_histories.append(item)
waiting_for_assistant = True
break
elif chunk.choices[0].finish_reason:
assistant_console.print("")
item = {
"role": "assistant",
"content": full_response,
}
cost_, output_toks_ = get_output_cost(
config.cost_file[config.model], enc, item
)
cost += cost_
output_toks += output_toks_
system_console.print(
f"\nTotal cost: {config.cost_unit}{cost:.3f}"
)
_histories.append(item)
break
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.function and tool_call.function.arguments:
tool_call_args_by_id[tool_call.id or ""][
tool_call.index
] += tool_call.function.arguments
chunk_str = chunk.choices[0].delta.content or ""
assistant_console.print(chunk_str, end="")
full_response += chunk_str
except KeyboardInterrupt:
waiting_for_assistant = False
input("Interrupted...enter to redo the current turn")
else:
history.extend(_histories)
history.extend(image_histories)
save_history(history, session_id)
return "Couldn't finish the task", cost
if __name__ == "__main__":
app()
```
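A minimal sketch (not part of the CLI, assuming its dependencies are installed) of `parse_user_message_special` above: plain lines are merged into a single text part, while a `%image <path>` line would become an `image_url` part (omitted here since it needs a real image file on disk):
```python
from wcgw_cli.openai_client import parse_user_message_special

msg = parse_user_message_special("fix the failing test\nthen rerun it")
assert msg["role"] == "user"
assert msg["content"] == [
    {"type": "text", "text": "fix the failing test\nthen rerun it"}
]
```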
--------------------------------------------------------------------------------
/tests/test_edit.py:
--------------------------------------------------------------------------------
```python
import os
import tempfile
from typing import Generator
import pytest
from wcgw.client.bash_state.bash_state import BashState
from wcgw.client.file_ops.diff_edit import SearchReplaceMatchError
from wcgw.client.file_ops.search_replace import SearchReplaceSyntaxError
from wcgw.client.tools import (
Context,
FileWriteOrEdit,
Initialize,
default_enc,
get_tool_output,
)
from wcgw.types_ import Console
class TestConsole(Console):
def __init__(self):
self.logs = []
self.prints = []
def log(self, msg: str) -> None:
self.logs.append(msg)
def print(self, msg: str) -> None:
self.prints.append(msg)
@pytest.fixture
def temp_dir() -> Generator[str, None, None]:
"""Provides a temporary directory for testing."""
with tempfile.TemporaryDirectory() as td:
yield td
@pytest.fixture
def context(temp_dir: str) -> Generator[Context, None, None]:
"""Provides a test context with temporary directory and handles cleanup."""
console = TestConsole()
bash_state = BashState(
console=console,
working_dir=temp_dir,
bash_command_mode=None,
file_edit_mode=None,
write_if_empty_mode=None,
mode=None,
use_screen=False,
)
ctx = Context(
bash_state=bash_state,
console=console,
)
yield ctx
# Cleanup after each test
bash_state.cleanup()
def test_file_edit(context: Context, temp_dir: str) -> None:
"""Test the FileWriteOrEdit tool."""
# First initialize
init_args = Initialize(
thread_id="",
any_workspace_path=temp_dir,
initial_files_to_read=[],
task_id_to_resume="",
mode_name="wcgw",
code_writer_config=None,
type="first_call",
)
get_tool_output(
context, init_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
# Create a test file
test_file = os.path.join(temp_dir, "test.py")
with open(test_file, "w") as f:
f.write("def hello():\n print('hello')\n")
# Test editing the file
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=10,
text_or_search_replace_blocks="""<<<<<<< SEARCH
def hello():
print('hello')
=======
def hello():
print('hello world')
>>>>>>> REPLACE""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
assert len(outputs) == 1
# Verify the change
with open(test_file) as f:
content = f.read()
assert "hello world" in content
# Test indentation match
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=100,
text_or_search_replace_blocks="""<<<<<<< SEARCH
def hello():
print('hello world')
=======
def hello():
print('ok')
>>>>>>> REPLACE""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
assert len(outputs) == 1
assert "Warning: matching without considering indentation" in outputs[0]
# Verify the change
with open(test_file) as f:
content = f.read()
assert "print('ok')" in content
# Test no match with partial
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=50,
text_or_search_replace_blocks="""<<<<<<< SEARCH
def hello():
print('no match')
=======
def hello():
print('no match replace')
>>>>>>> REPLACE""",
)
with pytest.raises(SearchReplaceMatchError) as e:
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
assert """def hello():
print('ok')""" in str(e)
with open(test_file) as f:
content = f.read()
assert "print('ok')" in content
# Test syntax error
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=10,
text_or_search_replace_blocks="""<<<<<<< SEARCH
def hello():
print('ok')
=======
def hello():
print('ok")
>>>>>>> REPLACE""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
assert len(outputs) == 1
assert "tree-sitter reported syntax errors" in outputs[0]
# Verify the change
with open(test_file) as f:
content = f.read()
assert "print('ok\")" in content
with pytest.raises(SearchReplaceSyntaxError) as e:
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=50,
text_or_search_replace_blocks="""<<<<<<< SEARCH
def hello():
print('ok')
=======
def hello():
print('ok")
>>>>>>> REPLACE
>>>>>>> REPLACE
""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
with pytest.raises(SearchReplaceSyntaxError) as e:
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=10,
text_or_search_replace_blocks="""<<<<<<< SEARCH
def hello():
print('ok')
=======
def hello():
print('ok")
""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
# Test multiple matches
with open(test_file, "w") as f:
f.write("""
def hello():
print('ok')
# Comment
def hello():
print('ok')
""")
with pytest.raises(SearchReplaceMatchError) as e:
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=1,
text_or_search_replace_blocks="""<<<<<<< SEARCH
def hello():
print('ok')
=======
def hello():
print('hello world')
>>>>>>> REPLACE
""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
# Context-based matching should pass even when a duplicate block is found
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=10,
text_or_search_replace_blocks="""<<<<<<< SEARCH
# Comment
=======
# New Comment
>>>>>>> REPLACE
<<<"""
+ """<<<< SEARCH
def hello():
print('ok')
=======
def hello():
print('hello world')
>>>>>>> REPLACE
""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
with open(test_file) as f:
content = f.read()
assert (
content
== """
def hello():
print('ok')
# New Comment
def hello():
print('hello world')
"""
)
import re
def fix_indentation(
matched_lines: list[str], searched_lines: list[str], replaced_lines: list[str]
) -> list[str]:
if not matched_lines or not searched_lines or not replaced_lines:
return replaced_lines
def get_indentation(line: str) -> str:
match = re.match(r"^(\s*)", line)
assert match
return match.group(0)
matched_indents = [get_indentation(line) for line in matched_lines if line.strip()]
searched_indents = [
get_indentation(line) for line in searched_lines if line.strip()
]
if len(matched_indents) != len(searched_indents):
return replaced_lines
diffs: list[int] = [
len(searched) - len(matched)
for matched, searched in zip(matched_indents, searched_indents)
]
if not diffs:
return replaced_lines
if not all(diff == diffs[0] for diff in diffs):
return replaced_lines
if diffs[0] == 0:
return replaced_lines
def adjust_indentation(line: str, diff: int) -> str:
if diff < 0:
# Need to add -diff spaces
return matched_indents[0][:-diff] + line
# Need to remove diff spaces
return line[diff:]
if diffs[0] > 0:
# Check if replaced_lines have enough leading spaces to remove
if not all(not line[: diffs[0]].strip() for line in replaced_lines):
return replaced_lines
return [adjust_indentation(line, diffs[0]) for line in replaced_lines]
def test_empty_inputs():
assert fix_indentation([], [" foo"], [" bar"]) == [" bar"]
assert fix_indentation([" foo"], [], [" bar"]) == [" bar"]
assert fix_indentation([" foo"], [" foo"], []) == []
def test_no_non_empty_lines_in_matched_or_searched():
# All lines in matched_lines/searched_lines are blank or just spaces
matched_lines = [" ", " "]
searched_lines = [" ", "\t "]
replaced_lines = [" Some text", " Another text"]
# Because matched_lines / searched_lines effectively have 0 non-empty lines,
# the function returns replaced_lines as is
assert (
fix_indentation(matched_lines, searched_lines, replaced_lines) == replaced_lines
)
def test_same_indentation_no_change():
# The non-empty lines have the same indentation => diff=0 => no changes
matched_lines = [" foo", " bar"]
searched_lines = [" baz", " qux"]
replaced_lines = [" spam", " ham"]
# Should return replaced_lines unchanged
assert (
fix_indentation(matched_lines, searched_lines, replaced_lines) == replaced_lines
)
def test_positive_indentation_difference():
# matched_lines have fewer spaces than searched_lines => diff > 0 => remove indentation from replaced_lines
matched_lines = [" foo", " bar"]
searched_lines = [" foo", " bar"]
replaced_lines = [" spam", " ham"]
# diff is 2 => remove 2 spaces from the start of each replaced line
expected = [" spam", " ham"]
assert fix_indentation(matched_lines, searched_lines, replaced_lines) == expected
def test_positive_indentation_not_enough_spaces():
# We want to remove 2 spaces, but replaced_lines do not have that many leading spaces
matched_lines = ["foo", "bar"]
searched_lines = [" foo", " bar"]
replaced_lines = [" spam", " ham"] # only 1 leading space
# The function should detect there's not enough indentation to remove => return replaced_lines unchanged
assert (
fix_indentation(matched_lines, searched_lines, replaced_lines) == replaced_lines
)
def test_negative_indentation_difference():
# matched_lines have more spaces than searched_lines => diff < 0 => add indentation to replaced_lines
matched_lines = [" foo", " bar"]
searched_lines = [" foo", " bar"]
replaced_lines = ["spam", "ham"]
# diff is -2 => add 2 spaces from matched_indents[0] to each line
# matched_indents[0] = ' ' => matched_indents[0][:-diff] => ' '[:2] => ' '
expected = [" spam", " ham"]
assert fix_indentation(matched_lines, searched_lines, replaced_lines) == expected
def test_different_number_of_non_empty_lines():
# matched_indents and searched_indents have different lengths => return replaced_lines
matched_lines = [
" foo",
" ",
" baz",
] # effectively 2 non-empty lines
searched_lines = [" foo", " bar", " baz"] # 3 non-empty lines
replaced_lines = [" spam", " ham"]
assert (
fix_indentation(matched_lines, searched_lines, replaced_lines) == replaced_lines
)
def test_inconsistent_indentation_difference():
# The diffs are not all the same => return replaced_lines
matched_lines = [" foo", " bar"]
searched_lines = [" foo", " bar"]
replaced_lines = ["spam", "ham"]
# For the first pair, diff = len(" ") - len(" ") = 2 - 4 = -2
# For the second pair, diff = len(" ") - len(" ") = 4 - 8 = -4
# Not all diffs are equal => should return replaced_lines
assert (
fix_indentation(matched_lines, searched_lines, replaced_lines) == replaced_lines
)
def test_realistic_fix_indentation_scenario():
matched_lines = [
" class Example:",
" def method(self):",
" print('hello')",
]
searched_lines = [
"class Example:",
" def method(self):",
" print('world')",
]
replaced_lines = [
"class Example:",
" def another_method(self):",
" print('world')",
]
expected = [
" class Example:",
" def another_method(self):",
" print('world')",
]
assert fix_indentation(matched_lines, searched_lines, replaced_lines) == expected
def test_realistic_nonfix_indentation_scenario():
matched_lines = [
" class Example:",
" def method(self):",
" print('hello')",
]
searched_lines = [
"class Example:",
" def method(self):",
" print('world')",
]
replaced_lines = [
"class Example:",
" def another_method(self):",
" print('world')",
]
assert (
fix_indentation(matched_lines, searched_lines, replaced_lines) == replaced_lines
)
def test_context_based_matching(context: Context, temp_dir: str) -> None:
"""Test using past and future context to uniquely identify search blocks."""
# First initialize
init_args = Initialize(
thread_id="",
any_workspace_path=temp_dir,
initial_files_to_read=[],
task_id_to_resume="",
mode_name="wcgw",
code_writer_config=None,
type="first_call",
)
get_tool_output(
context, init_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
# Create a test file with repeating pattern
test_file = os.path.join(temp_dir, "test_context.py")
with open(test_file, "w") as f:
f.write("A\nB\nC\nB\n")
# Test case 1: Using future context to uniquely identify a block
# The search "A" followed by "B" followed by "C" uniquely determines the first B
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=10,
text_or_search_replace_blocks="""<<<<<<< SEARCH
A
=======
A
>>>>>>> REPLACE
<<<<<<< SEARCH
B
=======
B_MODIFIED_FIRST
>>>>>>> REPLACE
<<<<<<< SEARCH
C
=======
C
>>>>>>> REPLACE""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
# Verify the change - first B should be modified
with open(test_file) as f:
content = f.read()
assert content == "A\nB_MODIFIED_FIRST\nC\nB\n"
# Test case 2: Using past context to uniquely identify a block
# Reset the file
with open(test_file, "w") as f:
f.write("A\nB\nC\nB\n")
# The search "C" followed by "B" uniquely determines the second B
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=10,
text_or_search_replace_blocks="""<<<<<<< SEARCH
C
=======
C
>>>>>>> REPLACE
<<<<<<< SEARCH
B
=======
B_MODIFIED_SECOND
>>>>>>> REPLACE""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
# Verify the change - second B should be modified
with open(test_file) as f:
content = f.read()
assert content == "A\nB\nC\nB_MODIFIED_SECOND\n"
def test_unordered(context: Context, temp_dir: str) -> None:
# First initialize
init_args = Initialize(
thread_id="",
any_workspace_path=temp_dir,
initial_files_to_read=[],
task_id_to_resume="",
mode_name="wcgw",
code_writer_config=None,
type="first_call",
)
get_tool_output(
context, init_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
# Create a test file with repeating pattern
test_file = os.path.join(temp_dir, "test_context.py")
with open(test_file, "w") as f:
f.write("A\nB\nC\nB\n")
# Test case: search/replace blocks are given out of file order
# The "C" block appears before the "A" block, yet both edits should still apply
edit_args = FileWriteOrEdit(
thread_id=context.bash_state.current_thread_id,
file_path=test_file,
percentage_to_change=10,
text_or_search_replace_blocks="""<<<<<<< SEARCH
C
=======
CPrime
>>>>>>> REPLACE
<<<<<<< SEARCH
A
=======
A_MODIFIED
>>>>>>> REPLACE
""",
)
outputs, _ = get_tool_output(
context, edit_args, default_enc, 1.0, lambda x, y: ("", 0.0), None, None
)
# Verify the changes - both A and C should be modified despite the block ordering
with open(test_file) as f:
content = f.read()
assert content == "A_MODIFIED\nB\nCPrime\nB\n"
```
--------------------------------------------------------------------------------
/src/wcgw/client/file_ops/diff_edit.py:
--------------------------------------------------------------------------------
```python
import re
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from typing import Callable, DefaultDict, Literal, Optional
TOLERANCE_TYPES = Literal["SILENT", "WARNING", "ERROR"]
class SearchReplaceMatchError(Exception):
def __init__(self, message: str):
message = f"""
{message}
---
Edit failed, no changes are applied. You'll have to reapply all search/replace blocks again.
Retry immediately with same "percentage_to_change" using search replace blocks fixing above error.
"""
super().__init__(message)
@dataclass
class Tolerance:
line_process: Callable[[str], str]
severity_cat: TOLERANCE_TYPES
score_multiplier: float
error_name: str
@dataclass
class TolerancesHit(Tolerance):
count: int
@dataclass
class FileEditOutput:
original_content: list[str]
orig_search_blocks: list[list[str]]
edited_with_tolerances: list[
tuple[slice, list[TolerancesHit], list[str]]
] # Need not be equal to orig_search_blocks when early exit
def replace_or_throw(
self,
max_errors: int,
) -> tuple[list[str], set[str]]:
new_lines = list[str]()
last_idx = 0
errors = []
warnings = set[str]()
info = set[str]()
score = 0.0
for (span, tolerances, replace_with), search_ in zip(
self.edited_with_tolerances, self.orig_search_blocks
):
for tol in tolerances:
score += tol.count * tol.score_multiplier
if tol.count > 0:
if tol.severity_cat == "WARNING":
warnings.add(tol.error_name)
elif tol.severity_cat == "ERROR":
search__ = "\n".join(search_)
errors.append(f"""
Got error while processing the following search block:
---
```
{search__}
```
---
Error:
{tol.error_name}
---
""")
else:
info.add(tol.error_name)
if len(errors) >= max_errors:
raise SearchReplaceMatchError("\n".join(errors))
if last_idx < span.start:
new_lines.extend(self.original_content[last_idx : span.start])
new_lines.extend(replace_with)
last_idx = span.stop
if last_idx < len(self.original_content):
new_lines.extend(self.original_content[last_idx:])
if errors:
raise SearchReplaceMatchError("\n".join(errors))
if score > 1000:
display = (list(warnings) + list(info))[:max_errors]
raise SearchReplaceMatchError(
"Too many warnings generated, not apply the edits\n"
+ "\n".join(display)
)
return new_lines, set(warnings)
@staticmethod
def get_best_match(
outputs: list["FileEditOutput"],
) -> list["FileEditOutput"]:
best_hits: list[FileEditOutput] = []
best_score = float("-inf")
assert outputs
for output in outputs:
hit_score = 0.0
for _, tols, _ in output.edited_with_tolerances:
for tol in tols:
hit_score += tol.count * tol.score_multiplier
if not best_hits:
best_hits.append(output)
best_score = hit_score
else:
if hit_score < best_score:
best_hits = [output]
best_score = hit_score
elif abs(hit_score - best_score) < 1e-3:
best_hits.append(output)
return best_hits
def line_process_max_space_tolerance(line: str) -> str:
line = line.strip()
return re.sub(r"\s", "", line)
REMOVE_INDENTATION = "Warning: matching after removing all spaces in lines."
DEFAULT_TOLERANCES = [
Tolerance(
line_process=str.rstrip,
severity_cat="SILENT",
score_multiplier=1,
error_name="",
),
Tolerance(
line_process=str.lstrip,
severity_cat="WARNING",
score_multiplier=10,
error_name="Warning: matching without considering indentation (leading spaces).",
),
Tolerance(
line_process=line_process_max_space_tolerance,
severity_cat="WARNING",
score_multiplier=50,
error_name=REMOVE_INDENTATION,
),
]
def fix_indentation(
matched_lines: list[str], searched_lines: list[str], replaced_lines: list[str]
) -> list[str]:
if not matched_lines or not searched_lines or not replaced_lines:
return replaced_lines
def get_indentation(line: str) -> str:
match = re.match(r"^(\s*)", line)
assert match
return match.group(0)
matched_indents = [get_indentation(line) for line in matched_lines if line.strip()]
searched_indents = [
get_indentation(line) for line in searched_lines if line.strip()
]
if len(matched_indents) != len(searched_indents):
return replaced_lines
diffs: list[int] = [
len(searched) - len(matched)
for matched, searched in zip(matched_indents, searched_indents)
]
if not all(diff == diffs[0] for diff in diffs):
return replaced_lines
if diffs[0] == 0:
return replaced_lines
# At this point we have same number of non-empty lines and the same indentation difference
# We can now adjust the indentation of the replaced lines
def adjust_indentation(line: str, diff: int) -> str:
if diff < 0:
return matched_indents[0][:-diff] + line
return line[diff:]
if diffs[0] > 0:
if not (all(not line[: diffs[0]].strip() for line in replaced_lines)):
return replaced_lines
return [adjust_indentation(line, diffs[0]) for line in replaced_lines]
def remove_leading_trailing_empty_lines(lines: list[str]) -> list[str]:
start = 0
end = len(lines) - 1
if end < start:
return lines
while not lines[start].strip():
start += 1
if start >= len(lines):
break
while not lines[end].strip():
end -= 1
if end < 0:
break
return lines[start : end + 1]
@dataclass
class FileEditInput:
file_lines: list[str]
file_line_offset: int
search_replace_blocks: list[tuple[list[str], list[str]]]
search_replace_offset: int
tolerances: list["Tolerance"] = field(default_factory=lambda: DEFAULT_TOLERANCES)
def edit_file(self) -> list[FileEditOutput]:
n_file_lines = len(self.file_lines)
n_blocks = len(self.search_replace_blocks)
# Boundary conditions
no_match_output = FileEditOutput(
original_content=self.file_lines,
orig_search_blocks=[x[0] for x in self.search_replace_blocks],
edited_with_tolerances=[
(
slice(0, 0),
[
TolerancesHit(
line_process=lambda x: x,
severity_cat="ERROR",
score_multiplier=float("inf"),
error_name="The blocks couldn't be matched, maybe the sequence of search blocks was incorrect?",
count=max(1, len(search_lines)),
)
for search_lines, _ in self.search_replace_blocks[
self.search_replace_offset :
]
],
[],
)
],
)
if (
self.file_line_offset >= n_file_lines
and self.search_replace_offset < n_blocks
):
return [no_match_output]
elif self.file_line_offset >= n_file_lines:
return [
FileEditOutput(
self.file_lines,
[x[0] for x in self.search_replace_blocks],
[(slice(0, 0), [], [])],
)
]
elif self.search_replace_offset >= n_blocks:
return [
FileEditOutput(
self.file_lines,
[x[0] for x in self.search_replace_blocks],
[(slice(0, 0), [], [])],
)
]
# search for first block
first_block = self.search_replace_blocks[self.search_replace_offset]
replace_by = first_block[1]
# Try exact match
matches = match_exact(self.file_lines, self.file_line_offset, first_block[0])
all_outputs = list[list[tuple[slice, list[TolerancesHit], list[str]]]]()
if not matches:
# Try tolerances
matches_with_tolerances = match_with_tolerance(
self.file_lines, self.file_line_offset, first_block[0], self.tolerances
)
if not matches_with_tolerances:
# Try with no empty lines
matches_with_tolerances = match_with_tolerance_empty_line(
self.file_lines,
self.file_line_offset,
first_block[0],
self.tolerances,
)
replace_by = remove_leading_trailing_empty_lines(first_block[1])
if not matches_with_tolerances:
# Report edit distance
sim_match, sim_sim, sim_context = (
find_least_edit_distance_substring(
self.file_lines, self.file_line_offset, first_block[0]
)
)
if sim_match:
matches_with_tolerances = [
(
sim_match,
[
TolerancesHit(
lambda x: x,
"ERROR",
float("inf"),
"Couldn't find match. Here's the latest snippet from the file which might be relevant for you to consider:\n```"
+ sim_context
+ "\n```",
int(len(first_block[0]) // sim_sim),
)
],
)
]
else:
matches_with_tolerances = [(match, []) for match in matches]
for match, tolerances in matches_with_tolerances:
if any(
tolerance.error_name == REMOVE_INDENTATION for tolerance in tolerances
):
replace_by = fix_indentation(
self.file_lines[match.start : match.stop],
first_block[0],
replace_by,
)
file_edit_input = FileEditInput(
self.file_lines,
match.stop,
self.search_replace_blocks,
self.search_replace_offset + 1,
self.tolerances,
)
if any(tolerance.severity_cat == "ERROR" for tolerance in tolerances):
# Exit early
all_outputs.append(
[
(match, tolerances, replace_by),
]
)
else:
remaining_output = file_edit_input.edit_file()
for rem_output in remaining_output:
all_outputs.append(
[
(match, tolerances, replace_by),
*rem_output.edited_with_tolerances,
]
)
if not all_outputs:
return [no_match_output]
return [
FileEditOutput(
self.file_lines, [x[0] for x in self.search_replace_blocks], output
)
for output in all_outputs
]
def find_contiguous_match(search_line_positions: list[set[int]]) -> list[slice]:
n_search_lines = len(search_line_positions)
def search_in_dictionary(search_offset: int, search_index: int) -> bool:
if search_offset >= n_search_lines:
return True
if search_index in search_line_positions[search_offset]:
return search_in_dictionary(search_offset + 1, search_index + 1)
return False
matched_slices = []
for index in search_line_positions[0]:
if search_in_dictionary(1, index + 1):
matched_slices.append(slice(index, index + n_search_lines, 1))
return matched_slices
def match_exact(
content: list[str], content_offset: int, search: list[str]
) -> list[slice]:
n_search_lines = len(search)
n_content = len(content) - content_offset
if n_search_lines > n_content:
return []
if n_search_lines == 0:
return []
if n_content == 0:
return []
content_positions = DefaultDict[str, set[int]](set)
for i in range(content_offset, n_content):
content_positions[content[i]].add(i)
search_line_positions = [content_positions[line] for line in search]
matched_slices = find_contiguous_match(search_line_positions)
return matched_slices
def match_with_tolerance(
content: list[str],
content_offset: int,
search: list[str],
tolerances: list[Tolerance],
) -> list[tuple[slice, list[TolerancesHit]]]:
n_search_lines = len(search)
n_content = len(content) - content_offset
if n_search_lines > n_content:
return []
if n_search_lines == 0:
return []
if n_content == 0:
return []
content_positions = DefaultDict[str, set[int]](set)
for i in range(content_offset, n_content):
content_positions[content[i]].add(i)
search_line_positions = [content_positions[line] for line in search]
tolerance_index_by_content_line: list[dict[int, int]] = [
{} for _ in range(len(search))
]
for tidx, tolerance in enumerate(tolerances):
content_positions = DefaultDict[str, set[int]](set)
for i in range(content_offset, n_content):
line = content[i]
content_positions[tolerance.line_process(line)].add(i)
for i, line in enumerate(search):
new_lines = content_positions[tolerance.line_process(line)]
new_indices = new_lines - search_line_positions[i]
search_line_positions[i].update(new_indices)
tolerance_index_by_content_line[i].update(
{idx: tidx for idx in new_indices}
)
matched_slices = find_contiguous_match(search_line_positions)
tolerances_counts: list[list[TolerancesHit]] = [
[
TolerancesHit(
line_process=tol.line_process,
severity_cat=tol.severity_cat,
score_multiplier=tol.score_multiplier,
count=0,
error_name=tol.error_name,
)
for tol in tolerances
]
for _ in range(len(matched_slices))
]
for sidx, slice in enumerate(matched_slices):
for search_idx, content_idx in enumerate(
range(slice.start, slice.stop, slice.step)
):
if content_idx in tolerance_index_by_content_line[search_idx]:
tolerances_counts[sidx][
tolerance_index_by_content_line[search_idx][content_idx]
].count += 1
return list(zip(matched_slices, tolerances_counts))
def match_with_tolerance_empty_line(
content: list[str],
content_offset: int,
search: list[str],
tolerances: list[Tolerance],
) -> list[tuple[slice, list[TolerancesHit]]]:
new_content = list[str]()
new_to_original = dict[int, int]()
for i in range(content_offset, len(content)):
line = content[i]
if line.strip():
new_to_original[len(new_content)] = i
new_content.append(line)
search = [line for line in search if line.strip()]
matches_with_tolerancs = match_with_tolerance(new_content, 0, search, tolerances)
new_matches_with_tolerances = list[tuple[slice, list[TolerancesHit]]]()
for matches, tolerance_counts in matches_with_tolerancs:
matches = slice(
new_to_original[matches.start], new_to_original[matches.stop - 1] + 1, 1
)
new_matches_with_tolerances.append((matches, tolerance_counts))
return new_matches_with_tolerances
def find_least_edit_distance_substring(
orig_content_lines: list[str], offset: int, find_lines: list[str]
) -> tuple[Optional[slice], float, str]:
# Prepare content lines, stripping whitespace and keeping track of original indices
content_lines = [
orig_content_lines[i].strip() for i in range(offset, len(orig_content_lines))
]
new_to_original_indices = {}
new_content_lines = []
for i, line in enumerate(content_lines):
if not line:
continue
new_content_lines.append(line)
new_to_original_indices[len(new_content_lines) - 1] = i
content_lines = new_content_lines
# Prepare find lines, removing empty lines
find_lines = [line.strip() for line in find_lines if line.strip()]
# Initialize variables for best match tracking
max_similarity = 0.0
min_edit_distance_lines = None
context_lines = []
# For each possible starting position in content
for i in range(max(1, len(content_lines) - len(find_lines) + 1)):
# Calculate similarity for the block starting at position i
block_similarity = 0.0
for j in range(len(find_lines)):
if (i + j) < len(content_lines):
# Use SequenceMatcher for more efficient similarity calculation
similarity = SequenceMatcher(
None, content_lines[i + j], find_lines[j]
).ratio()
block_similarity += similarity
# If this block is more similar than previous best
if block_similarity > max_similarity:
max_similarity = block_similarity
# Map back to original line indices
orig_start_index = new_to_original_indices[i]
orig_end_index = (
new_to_original_indices.get(
i + len(find_lines) - 1, len(orig_content_lines) - 1
)
+ 1
)
# Get the original lines
min_edit_distance_lines = slice(
orig_start_index + offset, orig_end_index + offset
)
# Get context (10 lines before and after)
context_lines = orig_content_lines[
max(0, orig_start_index - 10 + offset) : (orig_end_index + 10 + offset)
]
return (
min_edit_distance_lines,
max_similarity,
"\n".join(context_lines),
)
```
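A minimal sketch (not part of the repository) showing how the primitives above compose: `FileEditInput.edit_file` matches the search/replace blocks, `FileEditOutput.get_best_match` keeps the lowest-penalty candidates, and `replace_or_throw` materialises the edited lines:
```python
from wcgw.client.file_ops.diff_edit import FileEditInput, FileEditOutput

file_lines = ["def hello():", "    print('hello')", ""]
# One search/replace block as a (search_lines, replace_lines) tuple.
blocks = [(["def hello():", "    print('hello')"],
           ["def hello():", "    print('hello world')"])]

outputs = FileEditInput(file_lines, 0, blocks, 0).edit_file()
best = FileEditOutput.get_best_match(outputs)[0]
new_lines, warnings = best.replace_or_throw(max_errors=3)
assert new_lines == ["def hello():", "    print('hello world')", ""]
assert not warnings  # exact match, so no tolerance warnings
```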
--------------------------------------------------------------------------------
/src/wcgw_cli/anthropic_client.py:
--------------------------------------------------------------------------------
```python
import base64
import json
import mimetypes
import os
import subprocess
import tempfile
import traceback
import uuid
from pathlib import Path
from typing import Literal, Optional, cast
import rich
from anthropic import Anthropic, MessageStopEvent
from anthropic.types import (
ImageBlockParam,
MessageParam,
ModelParam,
RawMessageStartEvent,
TextBlockParam,
ToolParam,
ToolResultBlockParam,
ToolUseBlockParam,
)
from dotenv import load_dotenv
from pydantic import BaseModel, ValidationError
from typer import Typer
from wcgw.client.bash_state.bash_state import BashState
from wcgw.client.common import CostData, discard_input
from wcgw.client.memory import load_memory
from wcgw.client.tool_prompts import TOOL_PROMPTS
from wcgw.client.tools import (
Context,
ImageData,
default_enc,
get_tool_output,
initialize,
parse_tool_by_name,
)
class Config(BaseModel):
model: ModelParam
cost_limit: float
cost_file: dict[ModelParam, CostData]
cost_unit: str = "$"
History = list[MessageParam]
def text_from_editor(console: rich.console.Console) -> str:
# First consume all the input till now
discard_input()
console.print("\n---------------------------------------\n# User message")
data = input()
if data:
return data
editor = os.environ.get("EDITOR", "vim")
with tempfile.NamedTemporaryFile(suffix=".tmp") as tf:
subprocess.run([editor, tf.name], check=True)
with open(tf.name, "r") as f:
data = f.read()
console.print(data)
return data
def save_history(history: History, session_id: str) -> None:
myid = str(history[1]["content"]).replace("/", "_").replace(" ", "_").lower()[:60]
myid += "_" + session_id
myid = myid + ".json"
mypath = Path(".wcgw") / myid
mypath.parent.mkdir(parents=True, exist_ok=True)
with open(mypath, "w") as f:
json.dump(history, f, indent=3)
def parse_user_message_special(msg: str) -> MessageParam:
# Search for lines starting with `%` and treat them as special commands
parts: list[ImageBlockParam | TextBlockParam] = []
for line in msg.split("\n"):
if line.startswith("%"):
args = line[1:].strip().split(" ")
command = args[0]
assert command == "image"
image_path = " ".join(args[1:])
with open(image_path, "rb") as f:
image_bytes = f.read()
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
image_type = mimetypes.guess_type(image_path)[0]
parts.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": cast(
'Literal["image/jpeg", "image/png", "image/gif", "image/webp"]',
image_type or "image/png",
),
"data": image_b64,
},
}
)
else:
if len(parts) > 0 and parts[-1]["type"] == "text":
parts[-1]["text"] += "\n" + line
else:
parts.append({"type": "text", "text": line})
return {"role": "user", "content": parts}
app = Typer(pretty_exceptions_show_locals=False)
@app.command()
def loop(
first_message: Optional[str] = None,
limit: Optional[float] = None,
resume: Optional[str] = None,
) -> tuple[str, float]:
load_dotenv()
session_id = str(uuid.uuid4())[:6]
history: History = []
waiting_for_assistant = False
memory = None
if resume:
try:
_, memory, _ = load_memory(
resume,
24000, # coding_max_tokens
8000, # noncoding_max_tokens
lambda x: default_enc.encoder(x),
lambda x: default_enc.decoder(x),
)
except OSError:
if resume == "latest":
resume_path = sorted(Path(".wcgw").iterdir(), key=os.path.getmtime)[-1]
else:
resume_path = Path(resume)
if not resume_path.exists():
raise FileNotFoundError(f"File {resume} not found")
with resume_path.open() as f:
history = json.load(f)
if len(history) <= 2:
raise ValueError("Invalid history file")
first_message = ""
waiting_for_assistant = history[-1]["role"] != "assistant"
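# Editorial note (not part of the original file): `resume` is first tried as a memory
# id for load_memory; if that raises OSError it is treated as a saved chat file, where
# the special value "latest" picks the most recently modified file under .wcgw/.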
config = Config(
model="claude-3-5-sonnet-20241022",
cost_limit=0.1,
cost_unit="$",
cost_file={
# Claude 3.5 Haiku
"claude-3-5-haiku-latest": CostData(
cost_per_1m_input_tokens=0.80, cost_per_1m_output_tokens=4
),
"claude-3-5-haiku-20241022": CostData(
cost_per_1m_input_tokens=0.80, cost_per_1m_output_tokens=4
),
# Claude 3.5 Sonnet
"claude-3-5-sonnet-latest": CostData(
cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
),
"claude-3-5-sonnet-20241022": CostData(
cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
),
"claude-3-5-sonnet-20240620": CostData(
cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
),
# Claude 3 Opus
"claude-3-opus-latest": CostData(
cost_per_1m_input_tokens=15.0, cost_per_1m_output_tokens=75.0
),
"claude-3-opus-20240229": CostData(
cost_per_1m_input_tokens=15.0, cost_per_1m_output_tokens=75.0
),
# Legacy Models
"claude-3-haiku-20240307": CostData(
cost_per_1m_input_tokens=0.25, cost_per_1m_output_tokens=1.25
),
"claude-2.1": CostData(
cost_per_1m_input_tokens=8.0, cost_per_1m_output_tokens=24.0
),
"claude-2.0": CostData(
cost_per_1m_input_tokens=8.0, cost_per_1m_output_tokens=24.0
),
},
)
if limit is not None:
config.cost_limit = limit
limit = config.cost_limit
tools = [
ToolParam(
name=tool.name,
description=tool.description or "", # Ensure it's not None
input_schema=tool.inputSchema,
)
for tool in TOOL_PROMPTS
if tool.name != "Initialize"
]
system_console = rich.console.Console(style="blue", highlight=False, markup=False)
error_console = rich.console.Console(style="red", highlight=False, markup=False)
user_console = rich.console.Console(
style="bright_black", highlight=False, markup=False
)
assistant_console = rich.console.Console(
style="white bold", highlight=False, markup=False
)
with BashState(
system_console, os.getcwd(), None, None, None, None, True, None
) as bash_state:
context = Context(bash_state, system_console)
system, context, _ = initialize(
"first_call",
context,
os.getcwd(),
[],
resume if (memory and resume) else "",
24000, # coding_max_tokens
8000, # noncoding_max_tokens
mode="wcgw",
thread_id="",
)
if history:
if (
(last_msg := history[-1])["role"] == "user"
and isinstance((content := last_msg["content"]), dict)
and content["type"] == "tool_result"
):
waiting_for_assistant = True
client = Anthropic()
cost: float = 0
input_toks = 0
output_toks = 0
while True:
if cost > limit:
system_console.print(
f"\nCost limit exceeded. Current cost: {config.cost_unit}{cost:.4f}, "
f"input tokens: {input_toks}"
f"output tokens: {output_toks}"
)
break
else:
system_console.print(
f"\nTotal cost: {config.cost_unit}{cost:.4f}, input tokens: {input_toks}, output tokens: {output_toks}"
)
if not waiting_for_assistant:
if first_message:
msg = first_message
first_message = ""
else:
msg = text_from_editor(user_console)
history.append(parse_user_message_special(msg))
else:
waiting_for_assistant = False
stream = client.messages.stream(
model=config.model,
messages=history,
tools=tools,
max_tokens=8096,
system=system,
)
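# Editorial note (not part of the original file): the raw stream yields events roughly
# in the order message_start, then per content block content_block_start,
# content_block_delta..., content_block_stop, and finally message_delta/message_stop;
# the branches below handle the subset needed for text, tool calls and token accounting.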
system_console.print(
"\n---------------------------------------\n# Assistant response",
style="bold",
)
_histories: History = []
full_response: str = ""
tool_calls = []
tool_results: list[ToolResultBlockParam] = []
try:
with stream as stream_:
for chunk in stream_:
type_ = chunk.type
if isinstance(chunk, RawMessageStartEvent):
message_start = chunk.message
# Update cost based on token usage from the API response
input_tokens = message_start.usage.input_tokens
input_toks += input_tokens
cost += (
input_tokens
* config.cost_file[
config.model
].cost_per_1m_input_tokens
) / 1_000_000
elif isinstance(chunk, MessageStopEvent):
message_stop = chunk.message
# Update cost based on output tokens
output_tokens = message_stop.usage.output_tokens
output_toks += output_tokens
cost += (
output_tokens
* config.cost_file[
config.model
].cost_per_1m_output_tokens
) / 1_000_000
continue
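# Editorial note (not part of the original file, hypothetical numbers): with the
# claude-3-5-sonnet pricing configured above ($3.00 / 1M input, $15.00 / 1M output),
# a turn using 2,000 input and 500 output tokens adds
# 2000 * 3.0 / 1_000_000 + 500 * 15.0 / 1_000_000 = $0.0135 to `cost`.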
elif type_ == "content_block_start" and hasattr(
chunk, "content_block"
):
content_block = chunk.content_block
if (
hasattr(content_block, "type")
and content_block.type == "text"
and hasattr(content_block, "text")
):
chunk_str = content_block.text
assistant_console.print(chunk_str, end="")
full_response += chunk_str
elif content_block.type == "tool_use":
if (
hasattr(content_block, "input")
and hasattr(content_block, "name")
and hasattr(content_block, "id")
):
assert content_block.input == {}
tool_calls.append(
{
"name": str(content_block.name),
"input": str(""),
"done": False,
"id": str(content_block.id),
}
)
else:
error_console.log(
f"Ignoring unknown content block type {content_block.type}"
)
elif type_ == "content_block_delta" and hasattr(chunk, "delta"):
delta = chunk.delta
if hasattr(delta, "type"):
delta_type = str(delta.type)
if delta_type == "text_delta" and hasattr(
delta, "text"
):
chunk_str = delta.text
assistant_console.print(chunk_str, end="")
full_response += chunk_str
elif delta_type == "input_json_delta" and hasattr(
delta, "partial_json"
):
partial_json = delta.partial_json
if isinstance(tool_calls[-1]["input"], str):
tool_calls[-1]["input"] += partial_json
else:
error_console.log(
f"Ignoring unknown content block delta type {delta_type}"
)
else:
raise ValueError("Content block delta has no type")
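# Editorial note (not part of the original file): tool arguments arrive as a stream of
# input_json_delta fragments, e.g. '{"comm' + 'and": "ls' + ' -la"}' (hypothetical),
# which are concatenated here and only json.loads-ed once the matching
# content_block_stop event is handled below.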
elif type_ == "content_block_stop":
if tool_calls and not tool_calls[-1]["done"]:
tc = tool_calls[-1]
tool_name = str(tc["name"])
tool_input = str(tc["input"])
tool_id = str(tc["id"])
_histories.append(
{
"role": "assistant",
"content": [
ToolUseBlockParam(
id=tool_id,
name=tool_name,
input=json.loads(tool_input),
type="tool_use",
)
],
}
)
try:
tool_parsed = parse_tool_by_name(
tool_name, json.loads(tool_input)
)
except ValidationError:
error_msg = f"Error parsing tool {tool_name}\n{traceback.format_exc()}"
system_console.log(
f"Error parsing tool {tool_name}"
)
tool_results.append(
ToolResultBlockParam(
type="tool_result",
tool_use_id=str(tc["id"]),
content=error_msg,
is_error=True,
)
)
continue
system_console.print(
f"\n---------------------------------------\n# Assistant invoked tool: {tool_parsed}"
)
try:
output_or_dones, _ = get_tool_output(
context,
tool_parsed,
default_enc,
limit - cost,
loop,
24000, # coding_max_tokens
8000, # noncoding_max_tokens
)
except Exception as e:
output_or_dones = [
(
f"GOT EXCEPTION while calling tool. Error: {e}"
)
]
tb = traceback.format_exc()
error_console.print(
str(output_or_dones) + "\n" + tb
)
tool_results_content: list[
TextBlockParam | ImageBlockParam
] = []
for output in output_or_dones:
if isinstance(output, ImageData):
tool_results_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": output.media_type,
"data": output.data,
},
}
)
else:
tool_results_content.append(
{
"type": "text",
"text": output,
},
)
tool_results.append(
ToolResultBlockParam(
type="tool_result",
tool_use_id=str(tc["id"]),
content=tool_results_content,
)
)
else:
_histories.append(
{
"role": "assistant",
"content": full_response
if full_response.strip()
else "...",
}  # Anthropic rejects empty assistant message content, so fall back to "..."
)
except KeyboardInterrupt:
waiting_for_assistant = False
input("Interrupted...enter to redo the current turn")
else:
history.extend(_histories)
if tool_results:
history.append({"role": "user", "content": tool_results})
waiting_for_assistant = True
save_history(history, session_id)
return "Couldn't finish the task", cost
if __name__ == "__main__":
app()
```
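The loop above maintains the standard Anthropic tool-use turn structure: an assistant message carrying `tool_use` blocks, answered by a user message carrying matching `tool_result` blocks. A minimal sketch of that round trip with made-up ids and values, illustrative only and independent of the repository code:

```python
# Illustrative message shapes for one tool-use round trip (hypothetical ids and values).
history = [
    {"role": "user", "content": "List the files in the current directory."},
    {
        "role": "assistant",
        "content": [
            {
                "type": "tool_use",
                "id": "toolu_abc123",         # id assigned by the API
                "name": "BashCommand",        # example tool name (hypothetical here)
                "input": {"command": "ls"},   # hypothetical input schema
            }
        ],
    },
    {
        "role": "user",
        "content": [
            {
                "type": "tool_result",
                "tool_use_id": "toolu_abc123",  # must echo the tool_use id
                "content": [{"type": "text", "text": "README.md\nsrc\ntests"}],
            }
        ],
    },
]
```

Each batch of tool results is appended as a `user` message (see `history.append({"role": "user", "content": tool_results})` above), which is why `waiting_for_assistant` is set so the next iteration goes straight back to the model instead of prompting the user.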