# Directory Structure

```
├── .dockerignore
├── .github
│   ├── logo.png
│   └── workflows
│       ├── docker.yml
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── .python-version
├── Dockerfile
├── LICENSE
├── pyproject.toml
├── README.md
├── server.py
├── tests
│   ├── __init__.py
│   ├── test_calculus.py
│   ├── test_linalg.py
│   ├── test_relativity.py
│   ├── test_server.py
│   └── test_units.py
├── uv.lock
└── vars.py
```

# Files

--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------

```
3.12

```

--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------

```
__pycache__
*.pyc
*.pyo
*.pyd
.Python
.env
.git
.venv

```

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

```
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
.DS_Store

```

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

```markdown
<div align="center">
  <img src=".github/logo.png" alt="Sympy MCP Logo" width="400" />
</div>

# Symbolic Algebra MCP Server

Sympy-MCP is a Model Context Protocol server that allows LLMs to autonomously perform symbolic mathematics and computer algebra. It exposes numerous tools from SymPy's core functionality to MCP clients for manipulating mathematical expressions and equations.

## Why?

Language models are absolutely abysmal at symbolic manipulation. They hallucinate variables, make up random constants, permute terms and generally make a mess. But we have computer algebra systems specifically built for symbolic manipulation, so we can use tool-calling to orchestrate a sequence of transforms so that the symbolic kernel does all the heavy lifting.

While you can certainly have an LLM generate Mathematica or Python code, if you want to use the LLM as an agent or on-the-fly calculator, it's a better experience to use the MCP server and expose the symbolic tools directly.

The server exposes a subset of symbolic mathematics capabilities including algebraic equation solving, integration and differentiation, vector calculus, tensor calculus for general relativity, and both ordinary and partial differential equations. 

For example, you can ask it in natural language to solve a differential equation:

> Solve the damped harmonic oscillator with forcing term: the mass-spring-damper system described by the differential equation where m is mass, c is the damping coefficient, k is the spring constant, and F(t) is an external force.

$$ m\frac{d^2x}{dt^2} + c\frac{dx}{dt} + kx = F(t) $$

Or involving general relativity:

> Compute the trace of the Ricci tensor $R_{\mu\nu}$ using the inverse metric $g^{\mu\nu}$ for Anti-de Sitter spacetime to determine its constant scalar curvature $R$.

## Usage

You need [uv](https://docs.astral.sh/uv/getting-started/installation/) first.

- **Homebrew** : `brew install uv`
- **Curl** : `curl -LsSf https://astral.sh/uv/install.sh | sh`

Then you can install and run the server with the following commands:

```shell
# Setup the project
git clone https://github.com/sdiehl/sympy-mcp.git
cd sympy-mcp
uv sync

# Install the server to Claude Desktop
uv run mcp install server.py

# Run the server
uv run mcp run server.py
```

You should see the server available in the Claude Desktop app now. For other clients, see below.

If you want a completely standalone version that runs with a single command, you can use the following. *Note that this runs arbitrary code from GitHub, so be careful.*

```shell
uv run --with https://github.com/sdiehl/sympy-mcp/releases/download/0.1/sympy_mcp-0.1.0-py3-none-any.whl python server.py
```

If you want to do general relativity calculations, you need to install the [`einsteinpy`](https://github.com/einsteinpy/einsteinpy) library.

```shell
uv sync --group relativity
```
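
A quick way to confirm that the optional `relativity` group is available before calling the tensor tools is to check whether `einsteinpy` is importable. This is a minimal sketch using only the standard library; the helper name `has_relativity_support` is hypothetical and not part of the server.

```python
import importlib.util


def has_relativity_support() -> bool:
    """Return True if the optional einsteinpy package can be imported."""
    # find_spec returns None when the top-level module is not installed
    return importlib.util.find_spec("einsteinpy") is not None


if __name__ == "__main__":
    if has_relativity_support():
        print("einsteinpy found: tensor tools should be usable")
    else:
        print("einsteinpy missing: run `uv sync --group relativity`")
```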

## Available Tools

The sympy-mcp server provides the following tools for symbolic mathematics:

| Tool | Tool ID | Description |
|------|-------|-------------|
| Variable Introduction | `intro` | Introduces a variable with specified assumptions and stores it |
| Multiple Variables | `intro_many` | Introduces multiple variables with specified assumptions simultaneously |
| Expression Parser | `introduce_expression` | Parses an expression string using available local variables and stores it |
| LaTeX Printer | `print_latex_expression` | Prints a stored expression in LaTeX format, along with variable assumptions |
| Algebraic Solver | `solve_algebraically` | Solves an equation algebraically for a given variable over a given domain |
| Linear Solver | `solve_linear_system` | Solves a system of linear equations |
| Nonlinear Solver | `solve_nonlinear_system` | Solves a system of nonlinear equations |
| Function Variable | `introduce_function` | Introduces a function variable for use in differential equations |
| ODE Solver | `dsolve_ode` | Solves an ordinary differential equation |
| PDE Solver | `pdsolve_pde` | Solves a partial differential equation |
| Standard Metric | `create_predefined_metric` | Creates a predefined spacetime metric (e.g. Schwarzschild, Kerr, Minkowski) |
| Metric Search | `search_predefined_metrics` | Searches available predefined metrics |
| Tensor Calculator | `calculate_tensor` | Calculates tensors from a metric (Ricci, Einstein, Weyl tensors) |
| Custom Metric | `create_custom_metric` | Creates a custom metric tensor from provided components and symbols |
| Tensor LaTeX | `print_latex_tensor` | Prints a stored tensor expression in LaTeX format |
| Simplifier | `simplify_expression` | Simplifies a mathematical expression using SymPy's canonicalize function |
| Substitution | `substitute_expression` | Substitutes a variable with an expression in another expression |
| Integration | `integrate_expression` | Integrates an expression with respect to a variable |
| Differentiation | `differentiate_expression` | Differentiates an expression with respect to a variable |
| Coordinates | `create_coordinate_system` | Creates a 3D coordinate system for vector calculus operations |
| Vector Field | `create_vector_field` | Creates a vector field in the specified coordinate system |
| Curl | `calculate_curl` | Calculates the curl of a vector field |
| Divergence | `calculate_divergence` | Calculates the divergence of a vector field |
| Gradient | `calculate_gradient` | Calculates the gradient of a scalar field |
| Unit Converter | `convert_to_units` | Converts a quantity to given target units |
| Unit Simplifier | `quantity_simplify_units` | Simplifies a quantity with units |
| Matrix Creator | `create_matrix` | Creates a SymPy matrix from the provided data |
| Determinant | `matrix_determinant` | Calculates the determinant of a matrix |
| Matrix Inverse | `matrix_inverse` | Calculates the inverse of a matrix |
| Eigenvalues | `matrix_eigenvalues` | Calculates the eigenvalues of a matrix |
| Eigenvectors | `matrix_eigenvectors` | Calculates the eigenvectors of a matrix |

By default, variables are created with assumptions (similar to how the [symbols()](https://docs.sympy.org/latest/modules/core.html#sympy.core.symbol.symbols) function works in SymPy). Unless otherwise specified, the default assumption is that a variable is a complex, commutative term over the complex field $\mathbb{C}$.

| Property | Value |
|----------|-------|
| `commutative` | true |
| `complex` | true |
| `finite` | true |
| `infinite` | false |
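
In SymPy, assumptions are passed as keyword arguments, e.g. `Symbol("x", real=True, positive=False)`. The sketch below shows one plausible way the `pos_assumptions` and `neg_assumptions` lists used by `intro` could be merged with the defaults in the table above into such keyword arguments. The helper `build_assumptions` is hypothetical, and the precedence rule (explicit assumptions override defaults) is an assumption, not documented server behavior.

```python
# Defaults from the table above
DEFAULT_ASSUMPTIONS = {
    "commutative": True,
    "complex": True,
    "finite": True,
    "infinite": False,
}


def build_assumptions(pos: list[str], neg: list[str]) -> dict[str, bool]:
    """Merge default assumptions with explicit positive/negative ones.

    The result is suitable as keyword arguments, e.g. sympy.Symbol(name, **kw).
    """
    overlap = set(pos) & set(neg)
    if overlap:
        raise ValueError(f"Contradictory assumptions: {sorted(overlap)}")
    kwargs = dict(DEFAULT_ASSUMPTIONS)
    # Explicit assumptions override the defaults (assumed precedence)
    kwargs.update({name: True for name in pos})
    kwargs.update({name: False for name in neg})
    return kwargs


if __name__ == "__main__":
    # Assumptions for the mass m in the oscillator example
    print(build_assumptions(["real", "positive"], []))
```

With SymPy installed, `Symbol("m", **build_assumptions(["real", "positive"], []))` would then create a real, positive symbol.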

## Claude Desktop Setup

Normally the `mcp install` command automatically adds the server to the `claude_desktop_config.json` file. If it doesn't, locate the config file for your platform:

* macOS: `~/Library/Application Support/Claude/claude_desktop_config.json`
* Windows: `%APPDATA%\Claude\claude_desktop_config.json`

Add the following to the `mcpServers` object, replacing `/ABSOLUTE_PATH_TO_SYMPY_MCP/server.py` with the absolute path to the sympy-mcp `server.py` file.

```json
{
  "mcpServers": {
    "sympy-mcp": {
      "command": "/opt/homebrew/bin/uv",
      "args": [
        "run",
        "--with",
        "einsteinpy",
        "--with",
        "mcp[cli]",
        "--with",
        "pydantic",
        "--with",
        "sympy",
        "mcp",
        "run",
        "/ABSOLUTE_PATH_TO_SYMPY_MCP/server.py"
      ]
    }
  }
}
```
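
To avoid typos in the absolute path, you can generate the entry with a short script. This is a sketch using only the standard library; it assumes it is run from the sympy-mcp checkout and that `uv` lives at `/opt/homebrew/bin/uv` as in the example above (adjust `UV_PATH` otherwise). It prints the JSON rather than editing `claude_desktop_config.json`, so you can merge it by hand.

```python
import json
from pathlib import Path

# Assumption: Homebrew location of uv, as in the example config; change as needed
UV_PATH = "/opt/homebrew/bin/uv"


def sympy_mcp_entry(server_py: Path, uv: str = UV_PATH) -> dict:
    """Build the mcpServers entry shown in the README for a given server.py."""
    args = ["run"]
    # Same extra packages as the example configuration
    for pkg in ["einsteinpy", "mcp[cli]", "pydantic", "sympy"]:
        args += ["--with", pkg]
    args += ["mcp", "run", str(server_py.resolve())]
    return {"mcpServers": {"sympy-mcp": {"command": uv, "args": args}}}


if __name__ == "__main__":
    print(json.dumps(sympy_mcp_entry(Path("server.py")), indent=2))
```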

## Cursor Setup

In your `~/.cursor/mcp.json`, add the following, replacing `/ABSOLUTE_PATH_TO_SYMPY_MCP` with the absolute path to your sympy-mcp checkout (the directory containing `server.py`).

```json
{
  "mcpServers": {
    "sympy-mcp": {
      "command": "/opt/homebrew/bin/uv",
      "args": [
        "run",
        "--with",
        "einsteinpy",
        "--with",
        "mcp[cli]",
        "--with",
        "pydantic",
        "--with",
        "sympy",
        "mcp",
        "run",
        "/ABSOLUTE_PATH_TO_SYMPY_MCP/server.py"
      ]
    }
  }
}
```

## VS Code Setup

VS Code and VS Code Insiders now support MCPs in [agent mode](https://code.visualstudio.com/blogs/2025/04/07/agentMode). For VS Code, you may need to enable `Chat > Agent: Enable` in the settings.

1. **One-click Setup:**

[![Install in VS Code](https://img.shields.io/badge/VS_Code-Install_Server-0098FF?style=flat-square&logo=visualstudiocode&logoColor=white)](https://insiders.vscode.dev/redirect/mcp/install?name=sympy-mcp&config=%7B%22command%22%3A%22docker%22%2C%22args%22%3A%5B%22run%22%2C%22-i%22%2C%22-p%22%2C%228081%3A8081%22%2C%22--rm%22%2C%22ghcr.io%2Fsdiehl%2Fsympy-mcp%3Amain%22%5D%7D)

[![Install in VS Code Insiders](https://img.shields.io/badge/VS_Code_Insiders-Install_Server-24bfa5?style=flat-square&logo=visualstudiocode&logoColor=white)](https://insiders.vscode.dev/redirect/mcp/install?name=sympy-mcp&config=%7B%22command%22%3A%22docker%22%2C%22args%22%3A%5B%22run%22%2C%22-i%22%2C%22-p%22%2C%228081%3A8081%22%2C%22--rm%22%2C%22ghcr.io%2Fsdiehl%2Fsympy-mcp%3Amain%22%5D%7D&quality=insiders)

OR manually add the config to your `settings.json` (global):

```json
{
  "mcp": {
    "servers": {
      "sympy-mcp": {
        "command": "uv",
        "args": [
          "run",
          "--with",
          "einsteinpy",
          "--with",
          "mcp[cli]",
          "--with",
          "pydantic",
          "--with",
          "sympy",
          "mcp",
          "run",
          "/ABSOLUTE_PATH_TO_SYMPY_MCP/server.py"
        ]
      }
    }
  }
}
```

2. Click "Start" above the server config, switch to agent mode in the chat, and try commands like "integrate x^2" or "solve x^2 = 1" to get started.

## Cline Setup

To use with [Cline](https://cline.bot/), you need to manually run the MCP server first using the commands in the "Usage" section. Once the MCP server is running, open Cline and select "MCP Servers" at the top.

Then select "Remote Servers" and add the following:

- Server Name: `sympy-mcp`
- Server URL: `http://127.0.0.1:8081/sse`
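
Before adding the remote server in Cline, it can help to confirm that something is actually listening on port 8081. The sketch below only checks that a TCP connection succeeds; it does not speak the MCP or SSE protocol. The helper name `is_listening` is hypothetical.

```python
import socket


def is_listening(host: str = "127.0.0.1", port: int = 8081, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port can be opened."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host unreachable
        return False


if __name__ == "__main__":
    status = "up" if is_listening() else "not reachable"
    print(f"sympy-mcp on 127.0.0.1:8081 is {status}")
```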

## 5ire Setup

Another MCP client that supports multiple models (o3, o4-mini, DeepSeek-R1, etc.) on the backend is 5ire.

To set up with [5ire](https://github.com/nanbingxyz/5ire), open 5ire and go to Tools -> New and set the following configurations:

- Tool Key: `sympy-mcp`
- Name: SymPy MCP
- Command: `/opt/homebrew/bin/uv run --with einsteinpy --with mcp[cli] --with pydantic --with sympy mcp run /ABSOLUTE_PATH_TO/server.py`

Replace `/ABSOLUTE_PATH_TO/server.py` with the actual path to your sympy-mcp server.py file.

## Running in Container

You can build and run the server using Docker locally:

```bash
# Build the Docker image
docker build -t sympy-mcp .

# Run the Docker container
docker run -p 8081:8081 sympy-mcp
```

Alternatively, you can pull the pre-built image from GitHub Container Registry:

```bash
# Pull the latest image
docker pull ghcr.io/sdiehl/sympy-mcp:main

# Run the container
docker run -p 8081:8081 --rm ghcr.io/sdiehl/sympy-mcp:main
```

To configure Claude Desktop to launch the Docker container, edit your `claude_desktop_config.json` file:

```json
{
  "mcpServers": {
    "sympy-mcp": {
      "command": "docker",
      "args": [
        "run",
        "-i",
        "-p",
        "8081:8081",
        "--rm",
        "sympy-mcp"
      ]
    }
  }
}
```

Or to use the pre-built container from GitHub:

```json
{
  "mcpServers": {
    "sympy-mcp": {
      "command": "docker",
      "args": [
        "run",
        "-i",
        "-p",
        "8081:8081",
        "--rm",
        "ghcr.io/sdiehl/sympy-mcp:main"
      ]
    }
  }
}
```

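This configuration tells Claude Desktop to launch the Docker container when needed. If you use the locally built `sympy-mcp` image, make sure to build it (`docker build -t sympy-mcp .`) before starting Claude Desktop; the pre-built `ghcr.io` image is pulled automatically by `docker run` if it is not present.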

The other installation methods can also be adapted to work with Docker if you change the uv command to use the docker run command instead.

## Example Interaction 1: Differential Equations

Here's an example of how the tool-calling works in a conversation with an AI assistant using this MCP server:

**User**: Solve the damped harmonic oscillator with forcing term: the mass-spring-damper system described by the differential equation where m is mass, c is the damping coefficient, k is the spring constant, and F(t) is an external force.

$$ m\frac{d^2x}{dt^2} + c\frac{dx}{dt} + kx = F(t) $$

**Assistant**: (Internal tool call) `intro_many` to introduce variables:
```json
{
  "variables": [
    {"var_name": "t", "pos_assumptions": ["real"], "neg_assumptions": []},
    {"var_name": "m", "pos_assumptions": ["real", "positive"], "neg_assumptions": []},
    {"var_name": "c", "pos_assumptions": ["real", "positive"], "neg_assumptions": []},
    {"var_name": "k", "pos_assumptions": ["real", "positive"], "neg_assumptions": []}
  ]
}
```

**Assistant**: (Internal tool call) `introduce_function` to define x(t):
```json
{
  "func_name": "x"
}
```

**Assistant**: (Internal tool call) `introduce_expression` to create the differential equation:
```json
{
  "expr_str": "Eq(m*Derivative(x(t), t, t) + c*Derivative(x(t), t) + k*x(t), F(t))"
}
```

**Assistant**: (Internal tool call) `print_latex_expression` to display the equation nicely:
```json
{
  "expr_key": "expr_12"
}
```

Result:

$$
c \frac{d}{d t} x{\left(t \right)} + k x{\left(t \right)} + m \frac{d^{2}}{d t^{2}} x{\left(t \right)} = F{\left(t \right)}
$$

**Assistant**: (Internal tool call) `introduce_function` to define F(t):
```json
{
  "func_name": "F"
}
```

**Assistant**: (Internal tool call) `dsolve_ode` to solve the differential equation:
```json
{
  "expr_key": "expr_12",
  "func_name": "x"
}
```

Result:

$$
x{\left(t \right)} = C_{1} e^{\frac{t \left(- c + \sqrt{c^{2} - 4 k m}\right)}{2 m}} + C_{2} e^{- \frac{t \left(c + \sqrt{c^{2} - 4 k m}\right)}{2 m}} + \cdots
$$
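
The exponents in this solution are the roots of the characteristic equation of the homogeneous part, $m r^2 + c r + k = 0$, namely $r_{\pm} = \frac{-c \pm \sqrt{c^2 - 4km}}{2m}$; the terms elided by $\cdots$ correspond to the particular solution driven by $F(t)$. The standard-library sketch below computes these roots numerically and classifies the damping regime from the sign of the discriminant $c^2 - 4km$. The function names are illustrative and not part of the server.

```python
import cmath


def characteristic_roots(m: float, c: float, k: float) -> tuple[complex, complex]:
    """Roots of m*r**2 + c*r + k = 0, the exponents of the homogeneous solution."""
    disc = cmath.sqrt(c * c - 4 * k * m)
    return ((-c + disc) / (2 * m), (-c - disc) / (2 * m))


def damping_regime(m: float, c: float, k: float) -> str:
    """Classify the oscillator by the sign of the discriminant c**2 - 4*k*m."""
    disc = c * c - 4 * k * m
    if disc > 0:
        return "overdamped"
    if disc == 0:
        return "critically damped"
    return "underdamped"


if __name__ == "__main__":
    for m, c, k in [(1, 3, 2), (1, 2, 1), (1, 1, 1)]:
        r1, r2 = characteristic_roots(m, c, k)
        print(m, c, k, damping_regime(m, c, k), r1, r2)
```

For $m = 1$, $c = 3$, $k = 2$ the roots are $-1$ and $-2$, so both homogeneous terms decay without oscillating.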

## Example Interaction 2: General Relativity

**User**: Compute the trace of the Ricci tensor $R_{\mu\nu}$ using the inverse metric $g^{\mu\nu}$ for Anti-de Sitter spacetime to determine its constant scalar curvature $R$.

**Assistant**: (Internal tool call) `create_predefined_metric` to create the Anti-de Sitter metric:
```json
{
  "metric_name": "AntiDeSitter"
}
```

**Assistant**: (Internal tool call) `calculate_tensor` to calculate the Ricci tensor:
```json
{
  "metric_key": "metric_AntiDeSitter",
  "tensor_type": "RICCI_TENSOR",
  "simplify_result": true
}
```

**Assistant**: (Internal tool call) `print_latex_tensor` to display the Ricci tensor:
```json
{
  "tensor_key": "riccitensor_metric_AntiDeSitter"
}
```

Result:

$$
\left[\begin{matrix}3 & 0 & 0 & 0\\\\0 & - 3 \cos^{2}{\left(t \right)} & 0 & 0\\\\0 & 0 & - 3 \cos^{2}{\left(t \right)} \sinh^{2}{\left(\chi \right)} & 0\\\\0 & 0 & 0 & - 3 \sin^{2}{\left(\theta \right)} \cos^{2}{\left(t \right)} \sinh^{2}{\left(\chi \right)}\end{matrix}\right]
$$

**Assistant**: (Internal tool call) `calculate_tensor` to calculate the Ricci scalar:
```json
{
  "metric_key": "metric_AntiDeSitter",
  "tensor_type": "RICCI_SCALAR",
  "simplify_result": true
}
```

**Assistant**: (Internal tool call) `print_latex_tensor` to display the Ricci scalar:
```json
{
  "tensor_key": "ricciscalar_metric_AntiDeSitter"
}
```

Result:

$$
-12
$$
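
The value $-12$ can be checked by hand. For a diagonal metric the trace reduces to $R = g^{\mu\nu}R_{\mu\nu} = \sum_\mu R_{\mu\mu}/g_{\mu\mu}$. Assuming the predefined `AntiDeSitter` metric is $g = \mathrm{diag}(-1,\ \cos^2 t,\ \cos^2 t\,\sinh^2\chi,\ \cos^2 t\,\sinh^2\chi\,\sin^2\theta)$ (an assumption about einsteinpy's coordinates, consistent with the Ricci tensor above satisfying $R_{\mu\nu} = -3g_{\mu\nu}$), each of the four terms contributes $-3$, giving $R = -12$. The standard-library sketch below evaluates that sum at arbitrary points as a numerical sanity check.

```python
import math


def ads_ricci_scalar(t: float, chi: float, theta: float) -> float:
    """Trace R = sum_mu R_mu_mu / g_mu_mu for the diagonal AdS metric (assumed form)."""
    c2, sh2, s2 = math.cos(t) ** 2, math.sinh(chi) ** 2, math.sin(theta) ** 2
    # Assumed diagonal metric components g_mu_mu
    g = [-1.0, c2, c2 * sh2, c2 * sh2 * s2]
    # Diagonal Ricci tensor components, as printed in the example above
    ricci = [3.0, -3 * c2, -3 * c2 * sh2, -3 * s2 * c2 * sh2]
    # For a diagonal metric, g^{mu mu} = 1 / g_{mu mu}
    return sum(r / gm for r, gm in zip(ricci, g))


if __name__ == "__main__":
    print(ads_ricci_scalar(0.3, 1.2, 0.7))  # approximately -12.0
```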

## Security Disclaimer

This server runs on your computer and gives the language model access to run Python logic. Notably, it uses SymPy's `parse_expr` to parse mathematical expressions, which uses `eval` under the hood, effectively allowing arbitrary code execution. By running the server, you are trusting the code that Claude generates. Running in the Docker image is slightly safer, but it's still a good idea to review the code before running it.


## License

Copyright 2025 Stephen Diehl.

This project is licensed under the Apache 2.0 License. See the [LICENSE](LICENSE) file for details.

```

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------

```python
"""
Tests for the sympy-mcp server.
"""

```

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------

```toml
[project]
name = "sympy-mcp"
version = "0.1.0"
description = "An MCP server for symbolic manipulation of mathematical expressions"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.12"
dependencies = [
    "mcp[cli]>=1.9.0",
    "sympy>=1.14.0",
]

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
py-modules = ["server", "vars"]

[dependency-groups]
dev = [
    "black>=25.1.0",
    "pytest>=8.3.5",
    "ruff>=0.11.10",
]
relativity = [
    "einsteinpy>=0.4.0",
]

[tool.uv]
default-groups = []

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
```

--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------

```yaml
name: Test

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    
    steps:
    - uses: actions/checkout@v4
    
    - name: Set up Python 3.12
      uses: actions/setup-python@v5
      with:
        python-version: '3.12'
        
    - name: Install uv
      run: |
        curl -LsSf https://astral.sh/uv/install.sh | sh
        echo "$HOME/.local/bin" >> $GITHUB_PATH
        
    - name: Create virtual environment and install dependencies
      run: |
        uv venv
        . .venv/bin/activate
        uv sync --group dev --group relativity
        
    - name: Lint with Ruff
      run: |
        . .venv/bin/activate
        ruff check .
        
    - name: Test with pytest
      run: |
        . .venv/bin/activate
        pytest 
```

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------

```dockerfile
FROM python:3.12-slim

WORKDIR /app

# The installer requires curl (and certificates) to download the release archive
RUN apt-get update && apt-get install -y --no-install-recommends curl ca-certificates

# Download the latest installer
ADD https://astral.sh/uv/install.sh /uv-installer.sh

# Run the installer then remove it
RUN sh /uv-installer.sh && rm /uv-installer.sh

# Ensure the installed binary is on the `PATH`
ENV PATH="/root/.local/bin/:$PATH"

# Copy application code
COPY pyproject.toml .
COPY vars.py .
COPY server.py .

# Expose the default MCP port
EXPOSE 8081

# Add healthcheck
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
  CMD curl -f http://localhost:8081/healthcheck || exit 1

# Run the server with SSE transport
CMD ["uv", "run", "--with", "mcp[cli]", "--with", "sympy", "mcp", "run", "/app/server.py", "--transport", "sse"] 
```

--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------

```yaml
name: Build and Publish Python Package

on:
  release:
    types: [created]
  workflow_dispatch:

jobs:
  build-and-publish:
    runs-on: ubuntu-latest
    permissions:
      contents: write

    steps:
    - uses: actions/checkout@v4
    
    - name: Set up Python 3.12
      uses: actions/setup-python@v5
      with:
        python-version: '3.12'
        
    - name: Install uv
      run: |
        curl -LsSf https://astral.sh/uv/install.sh | sh
        echo "$HOME/.local/bin" >> $GITHUB_PATH
        
    - name: Build package
      run: |
        uv build
        
    - name: List built distributions
      run: |
        ls -l dist/
        
    - name: Upload to GitHub Release
      if: github.event_name == 'release'
      uses: softprops/action-gh-release@v1
      with:
        files: |
          dist/*.whl
          dist/*.tar.gz
      env:
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 
```

--------------------------------------------------------------------------------
/tests/test_linalg.py:
--------------------------------------------------------------------------------

```python
from server import (
    intro,
    introduce_expression,
    create_matrix,
    matrix_determinant,
    matrix_inverse,
    matrix_eigenvalues,
    matrix_eigenvectors,
    substitute_expression,
    print_latex_expression,
)


def test_matrix_creation():
    # Create a simple 2x2 matrix
    matrix_key = create_matrix([[1, 2], [3, 4]], "M")
    assert matrix_key == "M"


def test_determinant():
    # Create a matrix and calculate its determinant
    matrix_key = create_matrix([[1, 2], [3, 4]], "M")
    det_key = matrix_determinant(matrix_key)
    # Should be -2
    expr = print_latex_expression(det_key)
    assert expr == "-2"


def test_inverse():
    # Create a matrix and calculate its inverse
    matrix_key = create_matrix([[1, 2], [3, 4]], "M")
    inv_key = matrix_inverse(matrix_key)
    # Check result - don't check exact string as it may vary
    expr = print_latex_expression(inv_key)
    # The inverse of [[1, 2], [3, 4]] should have -2, 1, 3/2, -1/2 as elements
    assert "-2" in expr
    assert "1" in expr
    assert "\\frac{3}{2}" in expr


def test_eigenvalues():
    # Create a matrix and calculate its eigenvalues
    matrix_key = create_matrix([[3, 1], [1, 3]], "M")
    evals_key = matrix_eigenvalues(matrix_key)
    # Eigenvalues should be 2 and 4
    expr = print_latex_expression(evals_key)
    assert "2" in expr
    assert "4" in expr


def test_eigenvectors():
    # Create a matrix and calculate its eigenvectors
    matrix_key = create_matrix([[3, 1], [1, 3]], "M")
    evecs_key = matrix_eigenvectors(matrix_key)
    # Just check that the result is not an error
    expr = print_latex_expression(evecs_key)
    assert "Error" not in expr


def test_substitute():
    # Create variables and expressions
    intro("x", [], [])
    intro("y", [], [])
    expr1 = introduce_expression("x**2 + y**2")
    expr2 = introduce_expression("y + 1")
    # Substitute y = y + 1 in x^2 + y^2
    result_key = substitute_expression(expr1, "y", expr2)
    # Result should be x^2 + (y+1)^2 = x^2 + y^2 + 2y + 1
    expr = print_latex_expression(result_key)
    assert "x^{2}" in expr
    assert "y" in expr

```
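
The expected values hard-coded in these tests can be checked independently with exact rational arithmetic. This standard-library sketch (helper names hypothetical, not part of the server) computes the determinant, inverse, and eigenvalues of a 2x2 matrix; the eigenvalues come from the characteristic polynomial lambda**2 - tr(A)*lambda + det(A) = 0.

```python
import math
from fractions import Fraction

Matrix2 = list[list[Fraction]]


def det2(a: Matrix2) -> Fraction:
    """Determinant of a 2x2 matrix."""
    return a[0][0] * a[1][1] - a[0][1] * a[1][0]


def inv2(a: Matrix2) -> Matrix2:
    """Inverse of a 2x2 matrix via the adjugate formula."""
    d = det2(a)
    if d == 0:
        raise ValueError("matrix is singular")
    return [[a[1][1] / d, -a[0][1] / d], [-a[1][0] / d, a[0][0] / d]]


def eigvals2(a: Matrix2) -> tuple[float, float]:
    """Real eigenvalues from lambda**2 - tr(A)*lambda + det(A) = 0."""
    tr, d = a[0][0] + a[1][1], det2(a)
    disc = tr * tr - 4 * d
    if disc < 0:
        raise ValueError("complex eigenvalues not handled in this sketch")
    root = math.sqrt(disc)
    return ((tr - root) / 2, (tr + root) / 2)


if __name__ == "__main__":
    m = [[Fraction(1), Fraction(2)], [Fraction(3), Fraction(4)]]
    print(det2(m))  # -2
    print(inv2(m))
    s = [[Fraction(3), Fraction(1)], [Fraction(1), Fraction(3)]]
    print(eigvals2(s))  # (2.0, 4.0)
```

The inverse of `[[1, 2], [3, 4]]` comes out as `[[-2, 1], [3/2, -1/2]]`, matching the `-2`, `1`, and `\frac{3}{2}` substrings that `test_inverse` looks for.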

--------------------------------------------------------------------------------
/.github/workflows/docker.yml:
--------------------------------------------------------------------------------

```yaml
name: Docker

on:
  schedule:
    - cron: "27 0 * * *"
  push:
    branches: ["main"]
    tags: ["v*.*.*"]
  pull_request:
    branches: ["main"]

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

jobs:
  build:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      # Login against a Docker registry except on PR
      - name: Log into registry ${{ env.REGISTRY }}
        if: github.event_name != 'pull_request'
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      # Extract metadata (tags, labels) for Docker
      - name: Extract Docker metadata
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
          tags: |
            type=schedule
            type=ref,event=branch
            type=ref,event=tag
            type=ref,event=pr
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=semver,pattern={{major}}
            type=sha
            type=edge
            type=raw,value=latest,enable=${{ github.ref_type == 'tag' && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') }}

      # Build and push Docker image with Buildx
      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
          platforms: linux/amd64,linux/arm64

      # Test Docker image if not pushing (PR context)
      - name: Test Docker image
        if: github.event_name == 'pull_request'
        run: |
          docker images
          docker run --rm ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ github.sha }} uv run --with sympy python -c "import sympy; print(f'SymPy version: {sympy.__version__}')" 
```

--------------------------------------------------------------------------------
/tests/test_units.py:
--------------------------------------------------------------------------------

```python
import pytest
from server import (
    introduce_expression,
    convert_to_units,
    quantity_simplify_units,
    print_latex_expression,
    expressions,
    initialize_units,
    local_vars,
)
from vars import UnitSystem


@pytest.fixture(autouse=True)
def reset_globals():
    # Clear global dictionaries before each test
    expressions.clear()
    local_vars.clear()  # Clear local_vars to avoid cross-test pollution
    import server

    server.expression_counter = 0
    # Ensure units are properly initialized for each test
    initialize_units()
    yield


def test_convert_to_si():
    # speed_of_light in [meter, second]
    expr_key = introduce_expression("speed_of_light")
    result_key = convert_to_units(expr_key, ["meter", "second"], UnitSystem.SI)
    latex = print_latex_expression(result_key)
    assert "\\text{m}" in latex and "\\text{s}" in latex
    assert "299792458" in latex


def test_convert_to_impossible():
    # speed_of_light in [meter] (should return unchanged)
    expr_key = introduce_expression("speed_of_light")
    result_key = convert_to_units(expr_key, ["meter"], UnitSystem.SI)
    latex = print_latex_expression(result_key)
    assert "\\text{c}" in latex  # c is the symbol for speed of light


def test_convert_to_cgs_gauss():
    # ampere in [meter, gram, second] in cgs_gauss
    expr_key = introduce_expression("ampere")
    # First test with SI
    result_key_si = convert_to_units(
        expr_key, ["meter", "gram", "second"], UnitSystem.SI
    )
    latex_si = print_latex_expression(result_key_si)
    assert "\\text{A}" in latex_si  # A is the symbol for ampere

    # Then with CGS
    result_key_cgs = convert_to_units(
        expr_key, ["meter", "gram", "second"], UnitSystem.CGS
    )
    latex_cgs = print_latex_expression(result_key_cgs)
    # In CGS, ampere should be converted to a combination of base units
    # Either we'll see the units or we'll still see ampere if conversion failed
    assert (
        "\\text{g}" in latex_cgs or "\\text{m}" in latex_cgs or "\\text{A}" in latex_cgs
    )


def test_quantity_simplify():
    # meter/kilometer should simplify to 1/1000 or 0.001
    expr_key = introduce_expression("meter/kilometer")
    result_key = quantity_simplify_units(expr_key, UnitSystem.SI)
    latex = print_latex_expression(result_key)
    assert "0.001" in latex or "\\frac{1}{1000}" in latex or "10^{-3}" in latex

    # Also test .simplify() via sympy
    expr_key2 = introduce_expression("(meter/kilometer).simplify()")
    latex2 = print_latex_expression(expr_key2)
    assert "0.001" in latex2 or "\\frac{1}{1000}" in latex2 or "10^{-3}" in latex2


def test_convert_to_unknown_unit():
    expr_key = introduce_expression("meter")
    result = convert_to_units(expr_key, ["not_a_unit"], UnitSystem.SI)
    assert "Error" in result or "error" in result.lower()


def test_quantity_simplify_nonexistent_expr():
    result = quantity_simplify_units("nonexistent_key", UnitSystem.SI)
    assert "Error" in result or "error" in result.lower()


def test_convert_to_prefixed_units():
    # Test with prefixed units already applied in the expression
    # Create speed of light in femtometer/second directly
    expr_key = introduce_expression(
        "speed_of_light * (10**15)", expr_var_name="speed_of_light_in_fm_s"
    )
    latex = print_latex_expression(expr_key)
    assert ("299792458" in latex and "10^{15}" in latex) or "c" in latex

    # Test conversion from prefixed units
    expr_key = introduce_expression("1000*kilometer")
    result_key = convert_to_units(expr_key, ["meter"], UnitSystem.SI)
    latex = print_latex_expression(result_key)
    assert "1000000" in latex or "10^{6}" in latex

    # Test with a complex expression involving scaling
    expr_key = introduce_expression(
        "speed_of_light * 10**-9", expr_var_name="speed_in_nm_per_s"
    )
    latex = print_latex_expression(expr_key)
    # The output might be formatted as \frac{c}{1000000000} or similar
    assert "\\text{c}" in latex and (
        "10^{-9}" in latex or "1000000000" in latex or "\\frac" in latex
    )

```

--------------------------------------------------------------------------------
/vars.py:
--------------------------------------------------------------------------------

```python
from enum import Enum


class Assumption(Enum):
    ALGEBRAIC = "algebraic"
    COMMUTATIVE = "commutative"
    COMPLEX = "complex"
    EXTENDED_NEGATIVE = "extended_negative"
    EXTENDED_NONNEGATIVE = "extended_nonnegative"
    EXTENDED_NONPOSITIVE = "extended_nonpositive"
    EXTENDED_NONZERO = "extended_nonzero"
    EXTENDED_POSITIVE = "extended_positive"
    EXTENDED_REAL = "extended_real"
    FINITE = "finite"
    HERMITIAN = "hermitian"
    IMAGINARY = "imaginary"
    INFINITE = "infinite"
    INTEGER = "integer"
    IRATIONAL = "irrational"
    NEGATIVE = "negative"
    NONINTEGER = "noninteger"
    NONNEGATIVE = "nonnegative"
    NONPOSITIVE = "nonpositive"
    NONZERO = "nonzero"
    POSITIVE = "positive"
    RATIONAL = "rational"
    REAL = "real"
    TRANSCENDENTAL = "transcendental"
    ZERO = "zero"


class Domain(Enum):
    COMPLEX = "complex"
    REAL = "real"
    INTEGERS = "integers"
    NATURALS = "naturals"


class ODEHint(Enum):
    FACTORABLE = "factorable"
    NTH_ALGEBRAIC = "nth_algebraic"
    SEPARABLE = "separable"
    FIRST_EXACT = "1st_exact"
    FIRST_LINEAR = "1st_linear"
    BERNOULLI = "Bernoulli"
    FIRST_RATIONAL_RICCATI = "1st_rational_riccati"
    RICCATI_SPECIAL_MINUS2 = "Riccati_special_minus2"
    FIRST_HOMOGENEOUS_COEFF_BEST = "1st_homogeneous_coeff_best"
    FIRST_HOMOGENEOUS_COEFF_SUBS_INDEP_DIV_DEP = (
        "1st_homogeneous_coeff_subs_indep_div_dep"
    )
    FIRST_HOMOGENEOUS_COEFF_SUBS_DEP_DIV_INDEP = (
        "1st_homogeneous_coeff_subs_dep_div_indep"
    )
    ALMOST_LINEAR = "almost_linear"
    LINEAR_COEFFICIENTS = "linear_coefficients"
    SEPARABLE_REDUCED = "separable_reduced"
    FIRST_POWER_SERIES = "1st_power_series"
    LIE_GROUP = "lie_group"
    NTH_LINEAR_CONSTANT_COEFF_HOMOGENEOUS = "nth_linear_constant_coeff_homogeneous"
    NTH_LINEAR_EULER_EQ_HOMOGENEOUS = "nth_linear_euler_eq_homogeneous"
    NTH_LINEAR_CONSTANT_COEFF_UNDETERMINED_COEFFICIENTS = (
        "nth_linear_constant_coeff_undetermined_coefficients"
    )
    NTH_LINEAR_EULER_EQ_NONHOMOGENEOUS_UNDETERMINED_COEFFICIENTS = (
        "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients"
    )
    NTH_LINEAR_CONSTANT_COEFF_VARIATION_OF_PARAMETERS = (
        "nth_linear_constant_coeff_variation_of_parameters"
    )
    NTH_LINEAR_EULER_EQ_NONHOMOGENEOUS_VARIATION_OF_PARAMETERS = (
        "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters"
    )
    LIOUVILLE = "Liouville"
    SECOND_LINEAR_AIRY = "2nd_linear_airy"
    SECOND_LINEAR_BESSEL = "2nd_linear_bessel"
    SECOND_HYPERGEOMETRIC = "2nd_hypergeometric"
    SECOND_HYPERGEOMETRIC_INTEGRAL = "2nd_hypergeometric_Integral"
    NTH_ORDER_REDUCIBLE = "nth_order_reducible"
    SECOND_POWER_SERIES_ORDINARY = "2nd_power_series_ordinary"
    SECOND_POWER_SERIES_REGULAR = "2nd_power_series_regular"
    NTH_ALGEBRAIC_INTEGRAL = "nth_algebraic_Integral"
    SEPARABLE_INTEGRAL = "separable_Integral"
    FIRST_EXACT_INTEGRAL = "1st_exact_Integral"
    FIRST_LINEAR_INTEGRAL = "1st_linear_Integral"
    BERNOULLI_INTEGRAL = "Bernoulli_Integral"
    FIRST_HOMOGENEOUS_COEFF_SUBS_INDEP_DIV_DEP_INTEGRAL = (
        "1st_homogeneous_coeff_subs_indep_div_dep_Integral"
    )
    FIRST_HOMOGENEOUS_COEFF_SUBS_DEP_DIV_INDEP_INTEGRAL = (
        "1st_homogeneous_coeff_subs_dep_div_indep_Integral"
    )
    ALMOST_LINEAR_INTEGRAL = "almost_linear_Integral"
    LINEAR_COEFFICIENTS_INTEGRAL = "linear_coefficients_Integral"
    SEPARABLE_REDUCED_INTEGRAL = "separable_reduced_Integral"
    NTH_LINEAR_CONSTANT_COEFF_VARIATION_OF_PARAMETERS_INTEGRAL = (
        "nth_linear_constant_coeff_variation_of_parameters_Integral"
    )
    NTH_LINEAR_EULER_EQ_NONHOMOGENEOUS_VARIATION_OF_PARAMETERS_INTEGRAL = (
        "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters_Integral"
    )
    LIOUVILLE_INTEGRAL = "Liouville_Integral"
    SECOND_NONLINEAR_AUTONOMOUS_CONSERVED = "2nd_nonlinear_autonomous_conserved"
    SECOND_NONLINEAR_AUTONOMOUS_CONSERVED_INTEGRAL = (
        "2nd_nonlinear_autonomous_conserved_Integral"
    )


class PDEHint(Enum):
    FIRST_LINEAR_CONSTANT_COEFF_HOMOGENEOUS = "1st_linear_constant_coeff_homogeneous"
    FIRST_LINEAR_CONSTANT_COEFF = "1st_linear_constant_coeff"
    FIRST_LINEAR_CONSTANT_COEFF_INTEGRAL = "1st_linear_constant_coeff_Integral"
    FIRST_LINEAR_VARIABLE_COEFF = "1st_linear_variable_coeff"


class Metric(Enum):
    ALCUBIERRE_WARP = "AlcubierreWarp"
    BARRIOLA_VILEKIN = "BarriolaVilekin"
    BERTOTTI_KASNER = "BertottiKasner"
    BESSEL_GRAVITATIONAL_WAVE = "BesselGravitationalWave"
    C_METRIC = "CMetric"
    DAVIDSON = "Davidson"
    ANTI_DE_SITTER = "AntiDeSitter"
    ANTI_DE_SITTER_STATIC = "AntiDeSitterStatic"
    DE_SITTER = "DeSitter"
    ERNST = "Ernst"
    GODEL = "Godel"
    JANIS_NEWMAN_WINICOUR = "JanisNewmanWinicour"
    MINKOWSKI = "Minkowski"
    MINKOWSKI_CARTESIAN = "MinkowskiCartesian"
    MINKOWSKI_POLAR = "MinkowskiPolar"
    KERR = "Kerr"
    KERR_NEWMAN = "KerrNewman"
    REISSNER_NORDSTROM = "ReissnerNordstorm"
    SCHWARZSCHILD = "Schwarzschild"


class Tensor(Enum):
    RICCI_SCALAR = "RicciScalar"
    RICCI_TENSOR = "RicciTensor"
    RIEMANN_CURVATURE_TENSOR = "RiemannCurvatureTensor"
    SCHOUTEN_TENSOR = "SchoutenTensor"
    STRESS_ENERGY_MOMENTUM_TENSOR = "StressEnergyMomentumTensor"
    WEYL_TENSOR = "WeylTensor"
    EINSTEIN_TENSOR = "EinsteinTensor"


class UnitSystem(Enum):
    MKS = "MKS"
    MKSA = "MKSA"
    NATURAL = "natural"
    SI = "SI"
    CGS = "cgs"

```
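vars.py only declares names. How server.py turns them into SymPy objects is not shown in this excerpt. The sketch below assumes that `Assumption` values are passed to `Symbol` as keyword flags and that `Domain` maps onto SymPy's built-in sets. `make_symbol` and `DOMAIN_SETS` are hypothetical names, and the enums are trimmed copies so the block runs on its own.

```python
from enum import Enum

import sympy as sp
from sympy.core.facts import InconsistentAssumptions


# Trimmed copies of vars.Assumption and vars.Domain so the sketch runs standalone.
class Assumption(Enum):
    INTEGER = "integer"
    NONPOSITIVE = "nonpositive"
    POSITIVE = "positive"
    REAL = "real"


class Domain(Enum):
    COMPLEX = "complex"
    REAL = "real"
    INTEGERS = "integers"
    NATURALS = "naturals"


# Hypothetical mapping from Domain members to SymPy's built-in sets.
DOMAIN_SETS = {
    Domain.COMPLEX: sp.S.Complexes,
    Domain.REAL: sp.S.Reals,
    Domain.INTEGERS: sp.S.Integers,
    Domain.NATURALS: sp.S.Naturals,
}


def make_symbol(name, pos_assumptions, neg_assumptions):
    """Hypothetical helper: enum values become Symbol keyword flags (True or False)."""
    flags = {a.value: True for a in pos_assumptions}
    flags.update({a.value: False for a in neg_assumptions})
    return sp.Symbol(name, **flags)


y = make_symbol("y", [Assumption.REAL, Assumption.POSITIVE], [])
n = make_symbol("n", [Assumption.INTEGER], [Assumption.POSITIVE])

# SymPy raises InconsistentAssumptions when the flags contradict each other.
try:
    make_symbol("bad", [Assumption.POSITIVE, Assumption.NONPOSITIVE], [])
    rejected = False
except InconsistentAssumptions:
    rejected = True

# The same polynomial has roots in one domain and none in another.
x = sp.Symbol("x")
complex_roots = sp.solveset(x**2 + 1, x, DOMAIN_SETS[Domain.COMPLEX])
real_roots = sp.solveset(x**2 + 1, x, DOMAIN_SETS[Domain.REAL])
```

Passing the enum value straight through as a keyword keeps the enum and SymPy's assumption names in one-to-one correspondence. The trade-off is that a typo in a value only surfaces when `Symbol` is constructed.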

--------------------------------------------------------------------------------
/tests/test_relativity.py:
--------------------------------------------------------------------------------

```python
import pytest
from server import (
    create_predefined_metric,
    search_predefined_metrics,
    calculate_tensor,
    create_custom_metric,
    print_latex_tensor,
    local_vars,
    expressions,
    metrics,
    tensor_objects,
    EINSTEINPY_AVAILABLE,
)
from vars import Metric, Tensor


# Skip all tests if EinsteinPy is not available
pytestmark = pytest.mark.skipif(
    not EINSTEINPY_AVAILABLE, reason="EinsteinPy library is not available"
)


# Reset global server state before each test
@pytest.fixture(autouse=True)
def reset_globals():
    # Clear global dictionaries before each test
    local_vars.clear()
    expressions.clear()
    if EINSTEINPY_AVAILABLE:
        metrics.clear()
        tensor_objects.clear()
    # Reset the expression counter
    import server

    server.expression_counter = 0
    yield


class TestCreatePredefinedMetric:
    def test_create_schwarzschild_metric(self):
        # Test creating a Schwarzschild metric
        result = create_predefined_metric(Metric.SCHWARZSCHILD)
        assert result == "metric_Schwarzschild"
        assert result in metrics
        assert result in expressions

    def test_create_minkowski_metric(self):
        # Test creating a Minkowski metric
        result = create_predefined_metric(Metric.MINKOWSKI)
        assert result == "metric_Minkowski"
        assert result in metrics
        assert result in expressions

    def test_create_kerr_metric(self):
        # Test creating a Kerr metric
        result = create_predefined_metric(Metric.KERR)
        assert result == "metric_Kerr"
        assert result in metrics
        assert result in expressions

    def test_invalid_metric(self):
        # Try to create a metric that's in the enum but not implemented
        # For this test, we'll assume ALCUBIERRE_WARP is not implemented
        # in the provided metric_map
        result = create_predefined_metric(Metric.ALCUBIERRE_WARP)
        assert "Error" in result
        assert "not implemented" in result


class TestSearchPredefinedMetrics:
    def test_search_with_results(self):
        # Search for metrics containing "Sitter"
        result = search_predefined_metrics("Sitter")
        assert "Found metrics" in result
        assert "DeSitter" in result or "AntiDeSitter" in result

    def test_search_no_results(self):
        # Search for a term unlikely to match any metric
        result = search_predefined_metrics("XYZ123")
        assert "No metrics found" in result


class TestCalculateTensor:
    def test_calculate_ricci_tensor(self):
        # First create a metric
        metric_key = create_predefined_metric(Metric.SCHWARZSCHILD)

        # Calculate Ricci tensor
        result = calculate_tensor(metric_key, Tensor.RICCI_TENSOR)
        assert result == f"riccitensor_{metric_key}"
        assert result in expressions

    def test_calculate_ricci_scalar(self):
        # First create a metric
        metric_key = create_predefined_metric(Metric.SCHWARZSCHILD)

        # Calculate Ricci scalar
        result = calculate_tensor(metric_key, Tensor.RICCI_SCALAR)
        assert result == f"ricciscalar_{metric_key}"
        assert result in expressions

    def test_calculate_einstein_tensor(self):
        # First create a metric
        metric_key = create_predefined_metric(Metric.SCHWARZSCHILD)

        # Calculate Einstein tensor
        result = calculate_tensor(metric_key, Tensor.EINSTEIN_TENSOR)
        assert result == f"einsteintensor_{metric_key}"
        assert result in expressions

    def test_invalid_metric_key(self):
        result = calculate_tensor("nonexistent_metric", Tensor.RICCI_TENSOR)
        assert "Error" in result
        assert "not found" in result

    def test_invalid_tensor_type(self):
        # First create a metric
        metric_key = create_predefined_metric(Metric.SCHWARZSCHILD)

        # Pass an enum-like stand-in whose value is not a supported tensor type;
        # calculate_tensor should return an error string rather than a key
        class FakeTensor:
            value = "NonExistentTensor"

        result = calculate_tensor(metric_key, FakeTensor())
        assert "Error" in result
        # Check either for "not implemented" or the attribute error message
        assert "not implemented" in result or "has no attribute" in result


class TestCreateCustomMetric:
    def test_create_custom_metric(self):
        # Create a simple 2x2 diagonal metric with symbols t, r
        components = [["-1", "0"], ["0", "1"]]
        symbols = ["t", "r"]

        result = create_custom_metric(components, symbols)
        assert result == "metric_custom_0"
        assert result in metrics
        assert result in expressions

    def test_create_custom_minkowski(self):
        # Create a 4x4 Minkowski metric (-1, 1, 1, 1)
        components = [
            ["-1", "0", "0", "0"],
            ["0", "1", "0", "0"],
            ["0", "0", "1", "0"],
            ["0", "0", "0", "1"],
        ]
        symbols = ["t", "x", "y", "z"]

        result = create_custom_metric(components, symbols)
        assert result == "metric_custom_0"
        assert result in metrics
        assert result in expressions

    def test_create_custom_metric_with_expressions(self):
        # Create a metric with symbolic expressions
        components = [
            ["-1", "0", "0", "0"],
            ["0", "r**2", "0", "0"],
            ["0", "0", "r**2 * sin(theta)**2", "0"],
            ["0", "0", "0", "1"],
        ]
        symbols = ["t", "r", "theta", "phi"]

        result = create_custom_metric(components, symbols)
        assert result == "metric_custom_0"
        assert result in metrics
        assert result in expressions

    def test_invalid_components(self):
        # Test with invalid components (not a matrix)
        components = [["1", "0"], ["0"]]  # Missing element in second row
        symbols = ["t", "r"]

        result = create_custom_metric(components, symbols)
        assert "Error" in result


class TestPrintLatexTensor:
    def test_print_metric_latex(self):
        # Create a metric and print it in LaTeX
        metric_key = create_predefined_metric(Metric.MINKOWSKI)

        result = print_latex_tensor(metric_key)
        assert result  # Should return a non-empty string
        assert "\\begin{pmatrix}" in result or "\\left[" in result

    def test_print_tensor_latex(self):
        # Create a metric, calculate a tensor, and print it in LaTeX
        metric_key = create_predefined_metric(Metric.SCHWARZSCHILD)
        tensor_key = calculate_tensor(metric_key, Tensor.RICCI_TENSOR)

        result = print_latex_tensor(tensor_key)
        assert result  # Should return a non-empty string

    def test_nonexistent_tensor(self):
        result = print_latex_tensor("nonexistent_tensor")
        assert "Error" in result
        assert "not found" in result

```
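These tests are skipped when EinsteinPy is missing. As an independent cross-check, here is a plain-SymPy sketch of the curvature chain behind `Tensor.RICCI_TENSOR` and `Tensor.RICCI_SCALAR`: first the Christoffel symbols, then the Ricci tensor, then the Ricci scalar. The helper names are hypothetical and this is not EinsteinPy's API. The sketch assumes the common convention R_bd = R^a_bad. Under that convention a sphere of radius a has R = 2/a**2, and the Schwarzschild metric is Ricci-flat.

```python
import sympy as sp


def christoffel(g, coords):
    """Gamma[a][b][c] = 1/2 * g^{ad} * (d_b g_{dc} + d_c g_{db} - d_d g_{bc})."""
    n = len(coords)
    g_inv = g.inv()
    gamma = [[[sp.S.Zero] * n for _ in range(n)] for _ in range(n)]
    for a in range(n):
        for b in range(n):
            for c in range(n):
                total = sum(
                    g_inv[a, d]
                    * (
                        sp.diff(g[d, c], coords[b])
                        + sp.diff(g[d, b], coords[c])
                        - sp.diff(g[b, c], coords[d])
                    )
                    for d in range(n)
                )
                gamma[a][b][c] = sp.simplify(total / 2)
    return gamma


def ricci_tensor(g, coords):
    """R_{bd} = d_a G^a_{db} - d_d G^a_{ab} + G^a_{ae} G^e_{db} - G^a_{de} G^e_{ab}."""
    n = len(coords)
    G = christoffel(g, coords)
    ric = sp.zeros(n, n)
    for b in range(n):
        for d in range(n):
            total = sp.S.Zero
            for a in range(n):
                total += sp.diff(G[a][d][b], coords[a]) - sp.diff(G[a][a][b], coords[d])
                for e in range(n):
                    total += G[a][a][e] * G[e][d][b] - G[a][d][e] * G[e][a][b]
            ric[b, d] = sp.simplify(total)
    return ric


def ricci_scalar(g, coords):
    """R = g^{bd} R_{bd}."""
    ric = ricci_tensor(g, coords)
    g_inv = g.inv()
    n = len(coords)
    return sp.simplify(sum(g_inv[b, d] * ric[b, d] for b in range(n) for d in range(n)))


# 2-sphere of radius a in coordinates (theta, phi).
theta, phi, a = sp.symbols("theta phi a", positive=True)
sphere_R = ricci_scalar(sp.diag(a**2, a**2 * sp.sin(theta) ** 2), [theta, phi])

# Schwarzschild metric with Schwarzschild radius r_s, signature (-, +, +, +).
t, r, r_s = sp.symbols("t r r_s", positive=True)
f = 1 - r_s / r
schwarzschild = sp.diag(-f, 1 / f, r**2, r**2 * sp.sin(theta) ** 2)
schwarzschild_ric = ricci_tensor(schwarzschild, [t, r, theta, phi])
```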

--------------------------------------------------------------------------------
/tests/test_calculus.py:
--------------------------------------------------------------------------------

```python
import pytest
from server import (
    intro,
    introduce_expression,
    differentiate_expression,
    integrate_expression,
    create_coordinate_system,
    create_vector_field,
    calculate_curl,
    calculate_divergence,
    calculate_gradient,
    print_latex_expression,
    local_vars,
    expressions,
    coordinate_systems,
)
from vars import Assumption


# Reset global server state before each test
@pytest.fixture(autouse=True)
def reset_globals():
    # Clear global dictionaries before each test
    local_vars.clear()
    expressions.clear()
    coordinate_systems.clear()
    # Reset the expression counter
    import server

    server.expression_counter = 0
    yield


class TestDifferentiateExpressionTool:
    def test_differentiate_polynomial(self):
        # Introduce a variable
        intro("x", [Assumption.REAL], [])

        # Create an expression: x^3
        expr_key = introduce_expression("x**3")

        # First derivative
        first_deriv_key = differentiate_expression(expr_key, "x")
        first_deriv_latex = print_latex_expression(first_deriv_key)

        # Should be 3x^2
        assert "3" in first_deriv_latex
        assert "x^{2}" in first_deriv_latex

        # Second derivative
        second_deriv_key = differentiate_expression(expr_key, "x", 2)
        second_deriv_latex = print_latex_expression(second_deriv_key)

        # Should be 6x
        assert "6" in second_deriv_latex
        assert "x" in second_deriv_latex

        # Third derivative
        third_deriv_key = differentiate_expression(expr_key, "x", 3)
        third_deriv_latex = print_latex_expression(third_deriv_key)

        # Should be 6
        assert "6" in third_deriv_latex

    def test_differentiate_trigonometric(self):
        # Introduce a variable
        intro("x", [Assumption.REAL], [])

        # Create sin(x) expression
        sin_key = introduce_expression("sin(x)")

        # Derivative of sin(x) is cos(x)
        deriv_key = differentiate_expression(sin_key, "x")
        deriv_latex = print_latex_expression(deriv_key)

        assert "\\cos" in deriv_latex

    def test_nonexistent_expression(self):
        intro("x", [Assumption.REAL], [])
        result = differentiate_expression("nonexistent_key", "x")
        assert "error" in result.lower()

    def test_nonexistent_variable(self):
        intro("x", [Assumption.REAL], [])
        expr_key = introduce_expression("x**2")
        result = differentiate_expression(expr_key, "y")
        assert "error" in result.lower()


class TestIntegrateExpressionTool:
    def test_indefinite_integral_polynomial(self):
        # Introduce a variable
        intro("x", [Assumption.REAL], [])

        # Create expression: x^2
        expr_key = introduce_expression("x**2")

        # Integrate
        integral_key = integrate_expression(expr_key, "x")
        integral_latex = print_latex_expression(integral_key)

        # Should be x^3/3
        assert "x^{3}" in integral_latex
        assert "3" in integral_latex

    def test_indefinite_integral_trigonometric(self):
        # Introduce a variable
        intro("x", [Assumption.REAL], [])

        # Create expression: cos(x)
        expr_key = introduce_expression("cos(x)")

        # Integrate
        integral_key = integrate_expression(expr_key, "x")
        integral_latex = print_latex_expression(integral_key)

        # Should be sin(x)
        assert "\\sin" in integral_latex

    def test_nonexistent_expression(self):
        intro("x", [Assumption.REAL], [])
        result = integrate_expression("nonexistent_key", "x")
        assert "error" in result.lower()

    def test_nonexistent_variable(self):
        intro("x", [Assumption.REAL], [])
        expr_key = introduce_expression("x**2")
        result = integrate_expression(expr_key, "y")
        assert "error" in result.lower()


class TestVectorOperations:
    def test_create_coordinate_system(self):
        # Create coordinate system
        result = create_coordinate_system("R")
        assert result == "R"
        assert "R" in coordinate_systems

    def test_create_custom_coordinate_system(self):
        # Create coordinate system with custom names
        result = create_coordinate_system("C", ["rho", "phi", "z"])
        assert result == "C"
        assert "C" in coordinate_systems

    def test_create_vector_field(self):
        # Create coordinate system
        create_coordinate_system("R")

        # Introduce variables to represent components
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])
        intro("z", [Assumption.REAL], [])

        # Create vector field F = (y, -x, z)
        vector_field_key = create_vector_field("R", "y", "-x", "z")

        # The key might be an error message if the test is failing
        if "error" not in vector_field_key.lower():
            assert vector_field_key.startswith("vector_")
        else:
            assert False, f"Failed to create vector field: {vector_field_key}"

    def test_calculate_curl(self):
        # Create coordinate system
        create_coordinate_system("R")

        # Introduce variables
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])

        # Create a simple vector field for curl calculation
        vector_field_key = create_vector_field("R", "y", "-x", "0")

        # Check if vector field was created successfully
        if "error" in vector_field_key.lower():
            assert False, f"Failed to create vector field: {vector_field_key}"

        # Calculate curl
        curl_key = calculate_curl(vector_field_key)

        # Check if curl calculation was successful
        if "error" not in curl_key.lower():
            assert curl_key.startswith("vector_")
        else:
            assert False, f"Failed to calculate curl: {curl_key}"

    def test_calculate_divergence(self):
        # Create coordinate system
        create_coordinate_system("R")

        # Introduce variables
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])
        intro("z", [Assumption.REAL], [])

        # Build the field from free symbols x, y, z, not R's base scalars (R.x, R.y, R.z)
        vector_field_key = create_vector_field("R", "x", "y", "z")

        # Check if vector field was created successfully
        if "error" in vector_field_key.lower():
            assert False, f"Failed to create vector field: {vector_field_key}"

        # Calculate divergence - should be 0 because symbols have no dependency on coordinates
        div_key = calculate_divergence(vector_field_key)

        # Check if divergence calculation was successful
        if "error" in div_key.lower():
            assert False, f"Failed to calculate divergence: {div_key}"

        div_latex = print_latex_expression(div_key)

        # Check result - should be 0
        assert "0" in div_latex

    def test_calculate_gradient(self):
        # Create coordinate system
        create_coordinate_system("R")

        # Introduce variables
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])
        intro("z", [Assumption.REAL], [])

        # Create a simple scalar field
        scalar_field_key = introduce_expression("x**2 + y**2 + z**2")

        # Calculate gradient
        grad_key = calculate_gradient(scalar_field_key)

        # Check if gradient calculation was successful
        if "error" not in grad_key.lower():
            assert grad_key.startswith("vector_")
        else:
            assert False, f"Failed to calculate gradient: {grad_key}"

    def test_nonexistent_coordinate_system(self):
        result = create_vector_field("NonExistent", "x", "y", "z")
        assert "error" in result.lower()

    def test_nonexistent_vector_field(self):
        result = calculate_curl("nonexistent_key")
        assert "error" in result.lower()

```
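`test_calculate_divergence` expects 0 because its field is built from free symbols rather than coordinates. The sketch below contrasts fields built from `CoordSys3D` base scalars with one built from free symbols. It assumes the server's vector tools wrap `sympy.vector`; the relevant part of server.py is not shown here.

```python
import sympy as sp
from sympy.vector import CoordSys3D, curl, divergence, gradient

R = CoordSys3D("R")

# Fields built from R's base scalars (R.x, R.y, R.z) depend on the coordinates.
rotation = R.y * R.i - R.x * R.j
position = R.x * R.i + R.y * R.j + R.z * R.k

curl_rotation = curl(rotation)
div_position = divergence(position)
grad_r2 = gradient(R.x**2 + R.y**2 + R.z**2)

# Fields built from free symbols are constant with respect to R,
# so every partial derivative, and hence the divergence, is zero.
x, y, z = sp.symbols("x y z", real=True)
div_free = divergence(x * R.i + y * R.j + z * R.k)

print(curl_rotation, div_position, grad_r2, div_free, sep="\n")
```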

--------------------------------------------------------------------------------
/tests/test_server.py:
--------------------------------------------------------------------------------

```python
import pytest
from server import (
    intro,
    intro_many,
    introduce_expression,
    print_latex_expression,
    solve_algebraically,
    solve_linear_system,
    solve_nonlinear_system,
    introduce_function,
    dsolve_ode,
    pdsolve_pde,
    local_vars,
    expressions,
    functions,
    VariableDefinition,
)
from vars import Assumption, Domain, ODEHint


# Reset global server state before each test
@pytest.fixture(autouse=True)
def reset_globals():
    # Clear global dictionaries before each test
    local_vars.clear()
    expressions.clear()
    functions.clear()  # Also clear introduced functions
    # Reset the expression counter
    import server

    server.expression_counter = 0
    yield


class TestIntroTool:
    def test_intro_basic(self):
        # Test introducing a variable with no assumptions
        result = intro("x", [], [])
        assert result == "x"
        assert "x" in local_vars

    def test_intro_with_assumptions(self):
        # Test introducing a variable with assumptions
        result = intro("y", [Assumption.REAL, Assumption.POSITIVE], [])
        assert result == "y"
        assert "y" in local_vars
        # Check that the symbol has the correct assumptions
        assert local_vars["y"].is_real is True
        assert local_vars["y"].is_positive is True

    def test_intro_inconsistent_assumptions(self):
        # A single assumption is consistent and should be accepted
        result = intro("z", [Assumption.POSITIVE], [])
        assert result == "z"
        assert "z" in local_vars

        # Positive and nonpositive contradict each other, so intro must
        # report an error and leave the variable undefined
        result2 = intro(
            "inconsistent", [Assumption.POSITIVE, Assumption.NONPOSITIVE], []
        )
        assert "error" in result2.lower() or "inconsistent" in result2.lower()
        assert "inconsistent" not in local_vars


class TestIntroManyTool:
    def test_intro_many_basic(self):
        # Define variable definition objects using the VariableDefinition class
        var_defs = [
            VariableDefinition(
                var_name="a", pos_assumptions=["real"], neg_assumptions=[]
            ),
            VariableDefinition(
                var_name="b", pos_assumptions=["positive"], neg_assumptions=[]
            ),
        ]

        intro_many(var_defs)
        assert "a" in local_vars
        assert "b" in local_vars
        assert local_vars["a"].is_real is True
        assert local_vars["b"].is_positive is True

    def test_intro_many_invalid_assumption(self):
        # Create variable definition with an invalid assumption
        var_defs = [
            VariableDefinition(
                var_name="c", pos_assumptions=["invalid_assumption"], neg_assumptions=[]
            ),
        ]

        result = intro_many(var_defs)
        assert "error" in result.lower()


class TestIntroduceExpressionTool:
    def test_introduce_simple_expression(self):
        # First, introduce required variables
        intro("x", [], [])
        intro("y", [], [])

        # Then introduce an expression
        result = introduce_expression("x + y")
        assert result == "expr_0"
        assert "expr_0" in expressions
        assert str(expressions["expr_0"]) == "x + y"

    def test_introduce_equation(self):
        intro("x", [], [])

        result = introduce_expression("Eq(x**2, 4)")
        assert result == "expr_0"
        assert "expr_0" in expressions
        # Equation should be x**2 = 4

        assert expressions["expr_0"].lhs == local_vars["x"] ** 2
        assert expressions["expr_0"].rhs == 4

    def test_introduce_matrix(self):
        result = introduce_expression("Matrix(((1, 2), (3, 4)))")
        assert result == "expr_0"
        assert "expr_0" in expressions
        # Check matrix dimensions and values
        assert expressions["expr_0"].shape == (2, 2)
        assert expressions["expr_0"][0, 0] == 1
        assert expressions["expr_0"][1, 1] == 4


class TestPrintLatexExpressionTool:
    def test_print_latex_simple_expression(self):
        intro("x", [Assumption.REAL], [])
        expr_key = introduce_expression("x**2 + 5*x + 6")

        result = print_latex_expression(expr_key)
        assert "x^{2} + 5 x + 6" in result
        assert "real" in result.lower()

    def test_print_latex_nonexistent_expression(self):
        result = print_latex_expression("nonexistent_key")
        assert "error" in result.lower()


class TestSolveAlgebraicallyTool:
    def test_solve_quadratic(self):
        intro("x", [Assumption.REAL], [])
        expr_key = introduce_expression("Eq(x**2 - 5*x + 6, 0)")

        result = solve_algebraically(expr_key, "x")
        # Solution should contain the values 2 and 3
        assert "2" in result
        assert "3" in result

    def test_solve_with_domain(self):
        intro("x", [Assumption.REAL], [])
        # Pass x**2 + 1 as a bare expression, which is solved as x**2 + 1 = 0
        expr_key = introduce_expression("x**2 + 1")

        # In complex domain, should have solutions i and -i
        complex_result = solve_algebraically(expr_key, "x", Domain.COMPLEX)
        assert "i" in complex_result

        # In real domain, should have empty set
        real_result = solve_algebraically(expr_key, "x", Domain.REAL)
        assert "\\emptyset" in real_result

    def test_solve_invalid_domain(self):
        intro("x", [], [])
        introduce_expression("x**2 - 4")
        # We can't really test with an invalid Domain enum value easily,
        # so we'll skip this test since it's handled by type checking
        # If needed, could test with a mock Domain object that's not in the map

    def test_solve_nonexistent_expression(self):
        intro("x", [], [])
        result = solve_algebraically("nonexistent_key", "x")
        assert "error" in result.lower()

    def test_solve_nonexistent_variable(self):
        intro("x", [], [])
        expr_key = introduce_expression("x**2 - 4")
        result = solve_algebraically(expr_key, "y")
        assert "error" in result.lower()


class TestSolveLinearSystemTool:
    def test_simple_linear_system(self):
        # Create variables
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])

        # Create a system of linear equations: x + y = 10, 2x - y = 5
        eq1 = introduce_expression("Eq(x + y, 10)")
        eq2 = introduce_expression("Eq(2*x - y, 5)")

        # Solve the system
        result = solve_linear_system([eq1, eq2], ["x", "y"])

        # Check if solution contains the expected values (x=5, y=5)
        assert "5" in result

    def test_inconsistent_system(self):
        # Create variables
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])

        # Create an inconsistent system: x + y = 1, x + y = 2
        eq1 = introduce_expression("Eq(x + y, 1)")
        eq2 = introduce_expression("Eq(x + y, 2)")

        # Solve the system
        result = solve_linear_system([eq1, eq2], ["x", "y"])

        # Should be empty set
        assert "\\emptyset" in result

    def test_nonexistent_expression(self):
        intro("x", [], [])
        intro("y", [], [])
        result = solve_linear_system(["nonexistent_key"], ["x", "y"])
        assert "error" in result.lower()

    def test_nonexistent_variable(self):
        intro("x", [], [])
        expr_key = introduce_expression("x**2 - 4")
        result = solve_linear_system([expr_key], ["y"])
        assert "error" in result.lower()


class TestSolveNonlinearSystemTool:
    def test_simple_nonlinear_system(self):
        # Create variables
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])

        # Create a system of nonlinear equations: x^2 + y^2 = 25, x*y = 12
        eq1 = introduce_expression("Eq(x**2 + y**2, 25)")
        eq2 = introduce_expression("Eq(x*y, 12)")

        # Solve the system
        result = solve_nonlinear_system([eq1, eq2], ["x", "y"])

        # Expect four solutions: (3, 4), (4, 3), (-3, -4) and (-4, -3).
        # The output format can vary, so only check that 3 and 4 appear
        assert "3" in result
        assert "4" in result

    def test_with_domain(self):
        # Create variables - importantly, not specifying REAL assumption
        # because we want to test complex solutions
        intro("x", [], [])
        intro("y", [], [])

        # Create a system with complex solutions: x^2 + y^2 = -1, y = x
        # This has no real solutions but has complex solutions
        eq1 = introduce_expression("Eq(x**2 + y**2, -1)")
        eq2 = introduce_expression("Eq(y, x)")

        # In complex domain - should have solutions with imaginary parts
        complex_result = solve_nonlinear_system([eq1, eq2], ["x", "y"], Domain.COMPLEX)
        assert "i" in complex_result

        # In the real domain, only check that a result is returned: it may still
        # include complex solutions, and callers are responsible for filtering them
        real_result = solve_nonlinear_system([eq1, eq2], ["x", "y"], Domain.REAL)
        assert real_result  # Just verify we get some result

    def test_nonexistent_expression(self):
        intro("x", [], [])
        intro("y", [], [])
        result = solve_nonlinear_system(["nonexistent_key"], ["x", "y"])
        assert "error" in result.lower()

    def test_nonexistent_variable(self):
        intro("x", [], [])
        expr_key = introduce_expression("x**2 - 4")
        result = solve_nonlinear_system([expr_key], ["z"])
        assert "error" in result.lower()


class TestIntroduceFunctionTool:
    def test_introduce_function_basic(self):
        # Test introducing a function variable
        result = introduce_function("f")
        assert result == "f"
        assert "f" in functions
        assert str(functions["f"]) == "f"

    def test_function_usage_in_expression(self):
        # Introduce a variable and a function
        intro("x", [Assumption.REAL], [])
        introduce_function("f")

        # Create an expression using the function
        expr_key = introduce_expression("f(x)")

        assert expr_key == "expr_0"
        assert "expr_0" in expressions
        assert str(expressions["expr_0"]) == "f(x)"


class TestDsolveOdeTool:
    def test_simple_ode(self):
        # Introduce a variable and a function
        intro("x", [Assumption.REAL], [])
        introduce_function("f")

        # Create a differential equation: f''(x) + 9*f(x) = 0
        expr_key = introduce_expression("Derivative(f(x), x, x) + 9*f(x)")

        # Solve the ODE
        result = dsolve_ode(expr_key, "f")

        # The solution should include sin(3*x) and cos(3*x)
        assert "sin" in result
        assert "cos" in result
        assert "3 x" in result

    def test_ode_with_hint(self):
        # Introduce a variable and a function
        intro("x", [Assumption.REAL], [])
        introduce_function("f")

        # Create a first-order exact equation: sin(x)*cos(f(x)) + cos(x)*sin(f(x))*f'(x) = 0
        expr_key = introduce_expression(
            "sin(x)*cos(f(x)) + cos(x)*sin(f(x))*Derivative(f(x), x)"
        )

        # Solve with specific hint
        result = dsolve_ode(expr_key, "f", ODEHint.FIRST_EXACT)

        # The solution might contain acos instead of sin
        assert "acos" in result or "sin" in result

    def test_nonexistent_expression(self):
        introduce_function("f")
        result = dsolve_ode("nonexistent_key", "f")
        assert "error" in result.lower()

    def test_nonexistent_function(self):
        intro("x", [Assumption.REAL], [])
        introduce_function("f")
        expr_key = introduce_expression("Derivative(f(x), x) - f(x)")
        result = dsolve_ode(expr_key, "g")
        assert "error" in result.lower()


class TestPdsolvePdeTool:
    def test_simple_pde(self):
        # Introduce variables
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])

        # Introduce a function
        introduce_function("f")

        # Create a PDE: 1 + 2*(ux/u) + 3*(uy/u) = 0
        # where u = f(x, y), ux = u.diff(x), uy = u.diff(y)
        expr_key = introduce_expression(
            "Eq(1 + 2*Derivative(f(x, y), x)/f(x, y) + 3*Derivative(f(x, y), y)/f(x, y), 0)"
        )

        # Solve the PDE
        result = pdsolve_pde(expr_key, "f")

        # Solution should include e^ (LaTeX for exponential) and arbitrary function F
        assert "e^" in result
        assert "F" in result

    def test_nonexistent_expression(self):
        introduce_function("f")
        result = pdsolve_pde("nonexistent_key", "f")
        assert "error" in result.lower()

    def test_nonexistent_function(self):
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])
        introduce_function("f")
        expr_key = introduce_expression(
            "Derivative(f(x, y), x) + Derivative(f(x, y), y)"
        )
        result = pdsolve_pde(expr_key, "g")
        assert "error" in result.lower()

    def test_no_function_application(self):
        # Test with an expression that doesn't contain the function
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])
        introduce_function("f")
        expr_key = introduce_expression("x + y")
        result = pdsolve_pde(expr_key, "f")
        assert "error" in result.lower()
        assert "function cannot be automatically detected" in result.lower()

```
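The ODE tests above drive SymPy through the server's stored symbol table. Here is a minimal sketch of the same flow using SymPy directly. It assumes only that `parse_expr` receives the introduced names through `local_dict`. The dictionary names `names` and `expressions` are illustrative and are not the server's API.

```python
from sympy import Eq, Function, Symbol, dsolve, latex
from sympy.parsing.sympy_parser import parse_expr

# Illustrative session state, analogous to the server's local_vars/functions dicts.
names = {"x": Symbol("x", real=True), "f": Function("f")}
expressions = {}

# Parse with the introduced names so "x" resolves to the real symbol above.
expressions["expr_0"] = parse_expr(
    "Derivative(f(x), x, x) + 9*f(x)", local_dict=names
)

# A bare expression is interpreted as expression = 0, as dsolve_ode does.
solution = dsolve(Eq(expressions["expr_0"], 0))
solution_latex = latex(solution)
print(solution_latex)
```

Passing `local_dict` keeps `x` bound to the real-valued symbol. Without it, the parser would create a fresh `Symbol('x')` with no assumptions.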

--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------

```python
# A stateful MCP server that holds a sympy session, with a symbol table of variables
# that can be used in the tools API to define and manipulate expressions.

from mcp.server.fastmcp import FastMCP
import sympy
import argparse
import logging
from typing import List, Dict, Optional, Literal, Any, Union
from pydantic import BaseModel
from sympy.parsing.sympy_parser import parse_expr
from sympy.core.facts import InconsistentAssumptions
from vars import Assumption, Domain, ODEHint, PDEHint, Metric, Tensor, UnitSystem
from sympy import Eq, Function, dsolve, diff, integrate, simplify, Matrix
from sympy.solvers.pde import pdsolve
from sympy.vector import CoordSys3D, curl, divergence, gradient

from sympy.physics.units import convert_to
from sympy.physics.units import __dict__ as units_dict
from sympy.physics.units.systems import SI, MKS, MKSA, natural
from sympy.physics.units.systems.cgs import cgs_gauss

# Import common units
from sympy.physics.units import (
    meter,
    kilogram,
    second,
    ampere,
    kelvin,
    mole,
    candela,
    kilometer,
    millimeter,
    gram,
    joule,
    newton,
    pascal,
    watt,
    coulomb,
    volt,
    ohm,
    farad,
    henry,
    speed_of_light,
    gravitational_constant,
    planck,
    day,
    year,
    minute,
    hour,
)

try:
    from einsteinpy.symbolic import (
        MetricTensor,
        RicciTensor,
        RicciScalar,
        EinsteinTensor,
        WeylTensor,
        ChristoffelSymbols,
        StressEnergyMomentumTensor,
    )
    from einsteinpy.symbolic.predefined import (
        Schwarzschild,
        Minkowski,
        MinkowskiCartesian,
        KerrNewman,
        Kerr,
        AntiDeSitter,
        DeSitter,
        ReissnerNordstorm,
        find,
    )

    EINSTEINPY_AVAILABLE = True
except ImportError:
    EINSTEINPY_AVAILABLE = False

# Set up logging
logger = logging.getLogger(__name__)

# Create an MCP server
mcp = FastMCP(
    "sympy-mcp",
    dependencies=["sympy", "pydantic", "einsteinpy"],
    instructions="Provides access to the Sympy computer algebra system, which can perform symbolic manipulation of mathematical expressions.",
)

# Global store for sympy variables and expressions
local_vars: Dict[str, sympy.Symbol] = {}
functions: Dict[str, sympy.Function] = {}
expressions: Dict[str, sympy.Expr] = {}
metrics: Dict[str, Any] = {}
tensor_objects: Dict[str, Any] = {}
coordinate_systems: Dict[str, CoordSys3D] = {}
expression_counter = 0


# Pydantic model for defining a variable with assumptions
class VariableDefinition(BaseModel):
    var_name: str
    pos_assumptions: List[str] = []
    neg_assumptions: List[str] = []


# x, y = symbols('x, y', commutative=False)


# Introduce a single variable
@mcp.tool()
def intro(
    var_name: str, pos_assumptions: List[Assumption], neg_assumptions: List[Assumption]
) -> str:
    """Introduces a sympy variable with specified assumptions and stores it.

    Takes a variable name and a list of positive and negative assumptions.
    """
    kwargs_for_symbols = {}
    # Add assumptions
    for assumption_obj in pos_assumptions:
        kwargs_for_symbols[assumption_obj.value] = True

    for assumption_obj in neg_assumptions:
        kwargs_for_symbols[assumption_obj.value] = False

    try:
        var = sympy.symbols(var_name, **kwargs_for_symbols)
    except InconsistentAssumptions as e:
        return f"Error creating symbol '{var_name}': The provided assumptions {kwargs_for_symbols} are inconsistent according to SymPy. Details: {str(e)}"
    except Exception as e:
        return f"Error creating symbol '{var_name}': An unexpected error occurred. Assumptions attempted: {kwargs_for_symbols}. Details: {type(e).__name__} - {str(e)}"

    local_vars[var_name] = var
    return var_name


# Introduce multiple variables simultaneously
@mcp.tool()
def intro_many(variables: List[VariableDefinition]) -> str:
    """Introduces multiple sympy variables with specified assumptions and stores them.

    Takes a list of VariableDefinition objects for the 'variables' parameter.
    Each object in the list specifies:
    - var_name: The name of the variable (string).
    - pos_assumptions: A list of positive assumption strings (e.g., ["real", "positive"]).
    - neg_assumptions: A list of negative assumption strings (e.g., ["complex"]).

    The JSON payload for the 'variables' argument should be a direct list of these objects, for example:
    ```json
    [
        {
            "var_name": "x",
            "pos_assumptions": ["real", "positive"],
            "neg_assumptions": []
        },
        {
            "var_name": "y",
            "pos_assumptions": [],
            "neg_assumptions": ["commutative"]
        }
    ]
    ```

    The assumptions must be consistent. For example, a real symbol cannot be non-commutative or non-complex, because SymPy's "real" assumption implies both.

    Prefer this over intro() for multiple variables because it's more efficient.
    """
    var_keys = {}
    for var_def in variables:
        try:
            processed_pos_assumptions = [
                Assumption(a_str) for a_str in var_def.pos_assumptions
            ]
            processed_neg_assumptions = [
                Assumption(a_str) for a_str in var_def.neg_assumptions
            ]
        except ValueError as e:
            # Handle cases where a string doesn't match an Assumption enum member
            msg = (
                f"Error for variable '{var_def.var_name}': Invalid assumption string provided. {e}. "
                f"Ensure assumptions match valid enum values in 'vars.Assumption'."
            )
            logger.error(msg)
            return msg  # Or collect errors

        var_key = intro(
            var_def.var_name, processed_pos_assumptions, processed_neg_assumptions
        )
        var_keys[var_def.var_name] = var_key

    # Return the mapping of variable names to keys
    return str(var_keys)
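

# Sketch (hypothetical helper, not called by intro or intro_many): check a set of
# assumption names for consistency before creating the real symbol. It builds a
# throwaway Symbol so SymPy's fact system can report contradictions, such as
# real=True with complex=False (in SymPy, real implies complex).
# Example: _check_assumptions(["real"], ["complex"]) returns an error message.
def _check_assumptions(pos: list, neg: list):
    """Return None if the assumption names are consistent, else an error message."""
    from sympy import Symbol
    from sympy.core.facts import InconsistentAssumptions

    kwargs = {name: True for name in pos}
    kwargs.update({name: False for name in neg})
    try:
        Symbol("_probe", **kwargs)
    except InconsistentAssumptions as e:
        return f"Inconsistent assumptions {kwargs}: {e}"
    return None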


# XXX use local_vars {x : "expr_1", y : "expr_2"}
@mcp.tool()
def introduce_expression(
    expr_str: str, canonicalize: bool = True, expr_var_name: Optional[str] = None
) -> str:
    """Parses a sympy expression string using available local variables and stores it. Assigns it to either a temporary name (expr_0, expr_1, etc.) or a user-specified global name.

    Uses Sympy parse_expr to parse the expression string.

    Applies default Sympy canonicalization rules unless canonicalize is False.

    For equations (x^2 = 1), make the input string "Eq(x**2, 1)", not "x**2 == 1". Use ** for powers.

    Examples:

        {"expr_str": "Eq(x**2 + y**2, 1)"}
        {"expr_str": "Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11)))"}
        {"expr_str": "pi+e", "expr_var_name": "z"}
    """
    global expression_counter
    # Merge local_vars and functions dictionaries to make both available for parsing
    parse_dict = {**local_vars, **functions}
    parsed_expr = parse_expr(expr_str, local_dict=parse_dict, evaluate=canonicalize)
    if expr_var_name is None:
        expr_key = f"expr_{expression_counter}"
    else:
        expr_key = expr_var_name
    expressions[expr_key] = parsed_expr
    expression_counter += 1
    return expr_key


def introduce_equation(lhs_str: str, rhs_str: str) -> str:
    """Introduces an equation (lhs = rhs) using available local variables."""
    global expression_counter
    # Merge local_vars and functions dictionaries to make both available for parsing
    parse_dict = {**local_vars, **functions}
    lhs_expr = parse_expr(lhs_str, local_dict=parse_dict)
    rhs_expr = parse_expr(rhs_str, local_dict=parse_dict)
    eq_key = f"eq_{expression_counter}"
    expressions[eq_key] = Eq(lhs_expr, rhs_expr)
    expression_counter += 1
    return eq_key


@mcp.tool()
def print_latex_expression(expr_key: str) -> str:
    """Prints a stored expression in LaTeX format, along with variable assumptions."""
    if expr_key not in expressions:
        return f"Error: Expression key '{expr_key}' not found."

    expr = expressions[expr_key]

    # Handle dictionary objects (like eigenvalues)
    if isinstance(expr, dict):
        if all(isinstance(k, (sympy.Expr, int, float)) for k in expr.keys()):
            # Format as eigenvalues: {value: multiplicity, ...}
            parts = []
            for eigenval, multiplicity in expr.items():
                parts.append(
                    f"{sympy.latex(eigenval)} \\text{{ (multiplicity {multiplicity})}}"
                )
            return ", ".join(parts)
        else:
            # Generic dictionary
            return str(expr)

    # Handle list objects (like eigenvectors)
    elif isinstance(expr, list):
        # For eigenvectors format: [(eigenval, multiplicity, [eigenvectors]), ...]
        if all(isinstance(item, tuple) and len(item) == 3 for item in expr):
            parts = []
            for eigenval, multiplicity, eigenvects in expr:
                eigenvects_latex = [sympy.latex(v) for v in eigenvects]
                parts.append(
                    f"\\lambda = {sympy.latex(eigenval)} \\text{{ (multiplicity {multiplicity})}}:\n"
                    f"\\text{{Eigenvectors: }}[{', '.join(eigenvects_latex)}]"
                )
            return "\n".join(parts)
        else:
            # Try to convert each element to LaTeX
            try:
                return str([sympy.latex(item) for item in expr])
            except Exception as e:
                # Log the exception if there's a logger configured
                logger.debug(f"Error converting list items to LaTeX: {str(e)}")
                return str(expr)

    # Original behavior for sympy expressions
    latex_str = sympy.latex(expr)

    # Find variables in the expression and their assumptions
    try:
        variables_in_expr = expr.free_symbols
        assumption_descs = []
        for var_symbol in variables_in_expr:
            var_name = str(var_symbol)
            if var_name in local_vars:
                # Get assumptions directly from the symbol object
                current_assumptions = []
                # sympy stores assumptions in a private attribute _assumptions
                # and provides a way to query them via .is_commutative, .is_real etc.
                # We can iterate through known Assumption enum values
                for assumption_enum_member in Assumption:
                    if getattr(var_symbol, f"is_{assumption_enum_member.value}", False):
                        current_assumptions.append(assumption_enum_member.value)

                if current_assumptions:
                    assumption_descs.append(
                        f"{var_name} is {', '.join(current_assumptions)}"
                    )
                else:
                    assumption_descs.append(
                        f"{var_name} (no specific assumptions listed)"
                    )
            else:
                assumption_descs.append(f"{var_name} (undefined in local_vars)")

        if assumption_descs:
            return f"{latex_str} (where {'; '.join(assumption_descs)})"
        else:
            return latex_str
    except AttributeError:
        # If expr doesn't have free_symbols, just return the LaTeX
        return latex_str
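

# Sketch (hypothetical helper, not used above): Symbol.assumptions0 exposes every
# assumption SymPy knows for a symbol, including derived ones (positive implies
# real, for instance). It is an alternative to probing each is_<assumption> attribute.
def _true_assumptions(symbol) -> list:
    """Return the sorted names of assumptions known to be True for a symbol."""
    return sorted(name for name, value in symbol.assumptions0.items() if value)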


@mcp.tool()
def solve_algebraically(
    expr_key: str, solve_for_var_name: str, domain: Domain = Domain.COMPLEX
) -> str:
    """Solves an equation (expression = 0) algebraically for a given variable.

    Args:
        expr_key: The key of the expression (previously introduced) to be solved.
        solve_for_var_name: The name of the variable (previously introduced) to solve for.
        domain: The domain to solve in: Domain.COMPLEX, Domain.REAL, Domain.INTEGERS, or Domain.NATURALS (mapped to SymPy's Naturals0, which includes 0). Defaults to Domain.COMPLEX.

    Returns:
        A LaTeX string representing the set of solutions. Returns an error message string if issues occur.
    """
    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    expression_to_solve = expressions[expr_key]

    if solve_for_var_name not in local_vars:
        return f"Error: Variable '{solve_for_var_name}' not found in local_vars. Please introduce it first."

    variable_symbol = local_vars[solve_for_var_name]

    # Map domain enum to SymPy domain sets
    domain_map = {
        Domain.COMPLEX: sympy.S.Complexes,
        Domain.REAL: sympy.S.Reals,
        Domain.INTEGERS: sympy.S.Integers,
        Domain.NATURALS: sympy.S.Naturals0,
    }

    if domain not in domain_map:
        return "Error: Invalid domain. Choose from: Domain.COMPLEX, Domain.REAL, Domain.INTEGERS, or Domain.NATURALS."

    sympy_domain = domain_map[domain]

    try:
        # If the expression is an equation (Eq object), convert it to standard form
        if isinstance(expression_to_solve, sympy.Eq):
            expression_to_solve = expression_to_solve.lhs - expression_to_solve.rhs

        # Use solveset instead of solve
        solution_set = sympy.solveset(
            expression_to_solve, variable_symbol, domain=sympy_domain
        )

        # Convert the set to LaTeX format
        latex_output = sympy.latex(solution_set)
        return latex_output
    except NotImplementedError as e:
        return f"Error: SymPy could not solve the equation: {str(e)}. The equation may not have a closed-form solution or the algorithm is not implemented."
    except Exception as e:
        return f"An unexpected error occurred during solving: {str(e)}"
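

# Sketch (hypothetical helper, not wired into the solver tools): the three solver
# tools each rewrite an Eq(lhs, rhs) as lhs - rhs before solving, and this helper
# factors out that step. Non-equation expressions are already taken to mean expr = 0.
def _standard_form(expr):
    """Return lhs - rhs for an Eq, otherwise the expression unchanged."""
    from sympy import Eq

    if isinstance(expr, Eq):
        return expr.lhs - expr.rhs
    return expr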


@mcp.tool()
def solve_linear_system(
    expr_keys: List[str], var_names: List[str], domain: Domain = Domain.COMPLEX
) -> str:
    """Solves a system of linear equations using SymPy's linsolve.

    Args:
        expr_keys: The keys of the expressions (previously introduced) forming the system.
        var_names: The names of the variables to solve for.
        domain: Validated against Domain (COMPLEX, REAL, etc.) but not applied, because linsolve takes no domain argument. Defaults to Domain.COMPLEX.

    Returns:
        A LaTeX string representing the solution set. Returns an error message string if issues occur.
    """
    # Validate all expression keys exist
    system = []
    for expr_key in expr_keys:
        if expr_key not in expressions:
            return f"Error: Expression with key '{expr_key}' not found."

        expr = expressions[expr_key]
        # Convert equations to standard form
        if isinstance(expr, sympy.Eq):
            expr = expr.lhs - expr.rhs
        system.append(expr)

    # Validate all variables exist
    symbols = []
    for var_name in var_names:
        if var_name not in local_vars:
            return f"Error: Variable '{var_name}' not found in local_vars. Please introduce it first."
        symbols.append(local_vars[var_name])

    # Map domain enum to SymPy domain sets
    domain_map = {
        Domain.COMPLEX: sympy.S.Complexes,
        Domain.REAL: sympy.S.Reals,
        Domain.INTEGERS: sympy.S.Integers,
        Domain.NATURALS: sympy.S.Naturals0,
    }

    if domain not in domain_map:
        return "Error: Invalid domain. Choose from: Domain.COMPLEX, Domain.REAL, Domain.INTEGERS, or Domain.NATURALS."

    try:
        # linsolve takes no domain argument, so `domain` is validated above but
        # not applied to the solution set.
        solution_set = sympy.linsolve(system, symbols)

        # Convert the set to LaTeX format
        latex_output = sympy.latex(solution_set)
        return latex_output
    except NotImplementedError as e:
        return f"Error: SymPy could not solve the linear system: {str(e)}."
    except ValueError as e:
        return f"Error: Invalid system or arguments: {str(e)}."
    except Exception as e:
        return f"An unexpected error occurred during solving: {str(e)}"


@mcp.tool()
def solve_nonlinear_system(
    expr_keys: List[str], var_names: List[str], domain: Domain = Domain.COMPLEX
) -> str:
    """Solves a system of nonlinear equations using SymPy's nonlinsolve.

    Args:
        expr_keys: The keys of the expressions (previously introduced) forming the system.
        var_names: The names of the variables to solve for.
        domain: Validated against Domain (COMPLEX, REAL, etc.) but not applied, because nonlinsolve takes no domain argument. Defaults to Domain.COMPLEX.

    Returns:
        A LaTeX string representing the solution set. Returns an error message string if issues occur.
    """
    # Validate all expression keys exist
    system = []
    for expr_key in expr_keys:
        if expr_key not in expressions:
            return f"Error: Expression with key '{expr_key}' not found."

        expr = expressions[expr_key]
        # Convert equations to standard form
        if isinstance(expr, sympy.Eq):
            expr = expr.lhs - expr.rhs
        system.append(expr)

    # Validate all variables exist
    symbols = []
    for var_name in var_names:
        if var_name not in local_vars:
            return f"Error: Variable '{var_name}' not found in local_vars. Please introduce it first."
        symbols.append(local_vars[var_name])

    # Map domain enum to SymPy domain sets
    domain_map = {
        Domain.COMPLEX: sympy.S.Complexes,
        Domain.REAL: sympy.S.Reals,
        Domain.INTEGERS: sympy.S.Integers,
        Domain.NATURALS: sympy.S.Naturals0,
    }

    if domain not in domain_map:
        return "Error: Invalid domain. Choose from: Domain.COMPLEX, Domain.REAL, Domain.INTEGERS, or Domain.NATURALS."

    try:
        # Use SymPy's nonlinsolve
        solution_set = sympy.nonlinsolve(system, symbols)

        # Convert the set to LaTeX format
        latex_output = sympy.latex(solution_set)
        return latex_output
    except NotImplementedError as e:
        return f"Error: SymPy could not solve the nonlinear system: {str(e)}."
    except ValueError as e:
        return f"Error: Invalid system or arguments: {str(e)}."
    except Exception as e:
        return f"An unexpected error occurred during solving: {str(e)}"


@mcp.tool()
def introduce_function(func_name: str) -> str:
    """Introduces a SymPy function variable and stores it.

    Takes a function name and creates a SymPy Function object for use in defining differential equations.

    Example:
        {func_name: "f"} will create the function f(x), f(t), etc. that can be used in expressions

    Returns:
        The name of the created function.
    """
    func = Function(func_name)
    functions[func_name] = func
    return func_name


@mcp.tool()
def dsolve_ode(expr_key: str, func_name: str, hint: Optional[ODEHint] = None) -> str:
    """Solves an ordinary differential equation using SymPy's dsolve function.

    Args:
        expr_key: The key of the expression (previously introduced) containing the differential equation.
        func_name: The name of the function (previously introduced) to solve for.
        hint: Optional solving method from ODEHint enum. If None, SymPy will try to determine the best method.

    Example:
        # First introduce a variable and a function
        intro("x", [Assumption.REAL], [])
        introduce_function("f")

        # Create a second-order ODE: f''(x) + 9*f(x) = 0
        expr_key = introduce_expression("Derivative(f(x), x, x) + 9*f(x)")

        # Solve the ODE
        result = dsolve_ode(expr_key, "f")
        # Returns solution with sin(3*x) and cos(3*x) terms

    Returns:
        A LaTeX string representing the solution. Returns an error message string if issues occur.
    """
    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    if func_name not in functions:
        return f"Error: Function '{func_name}' not found. Please introduce it first using introduce_function."

    expression = expressions[expr_key]

    try:
        # Convert to equation form if it's not already
        if isinstance(expression, sympy.Eq):
            eq = expression
        else:
            eq = sympy.Eq(expression, 0)

        # Let SymPy handle function detection and apply the specified hint if provided
        if hint is not None:
            solution = dsolve(eq, hint=hint.value)
        else:
            solution = dsolve(eq)

        # Convert the solution to LaTeX format
        latex_output = sympy.latex(solution)
        return latex_output
    except ValueError as e:
        return f"Error: {str(e)}. This might be due to an invalid hint or an unsupported equation type."
    except NotImplementedError as e:
        return f"Error: Method not implemented: {str(e)}. Try a different hint or equation type."
    except Exception as e:
        return f"An unexpected error occurred: {str(e)}"


@mcp.tool()
def pdsolve_pde(expr_key: str, func_name: str, hint: Optional[PDEHint] = None) -> str:
    """Solves a partial differential equation using SymPy's pdsolve function.

    Args:
        expr_key: The key of the expression (previously introduced) containing the PDE.
                 If the expression is not an equation (Eq), it will be interpreted as
                 PDE = 0.
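        func_name: The name of the function (previously introduced) to solve for.
                   This should be a function of multiple variables.
        hint: Optional solving method from PDEHint enum. If None, SymPy will try to determine the best method.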

    Example:
        # First introduce variables and a function
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])
        introduce_function("f")

        # Create a PDE: 1 + 2*(ux/u) + 3*(uy/u) = 0
        expr_key = introduce_expression(
            "Eq(1 + 2*Derivative(f(x, y), x)/f(x, y) + 3*Derivative(f(x, y), y)/f(x, y), 0)"
        )

        # Solve the PDE
        result = pdsolve_pde(expr_key, "f")
        # Returns solution with exponential terms and arbitrary function

    Returns:
        A LaTeX string representing the solution. Returns an error message string if issues occur.
    """
    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    if func_name not in functions:
        return f"Error: Function '{func_name}' not found. Please introduce it first using introduce_function."

    expression = expressions[expr_key]

    try:
        # Handle both equation and non-equation expressions
        if isinstance(expression, sympy.Eq):
            eq = expression
        else:
            eq = sympy.Eq(expression, 0)

        # Let SymPy's pdsolve find the dependent variable itself
        if hint is not None:
            solution = pdsolve(eq, hint=hint.value)
        else:
            solution = pdsolve(eq)

        # Convert the solution to LaTeX format
        latex_output = sympy.latex(solution)
        return latex_output
    except ValueError as e:
        return f"Error: {str(e)}. This might be due to an unsupported equation type."
    except NotImplementedError as e:
        return f"Error: Method not implemented: {str(e)}. The PDE might not be solvable using the implemented methods."
    except Exception as e:
        return f"An unexpected error occurred: {str(e)}"


# Einstein relativity tools
if EINSTEINPY_AVAILABLE:

    @mcp.tool()
    def create_predefined_metric(metric_name: str) -> str:
        """Creates a predefined spacetime metric from einsteinpy.symbolic.predefined.

        Args:
            metric_name: The name of the metric to create (e.g., "AntiDeSitter", "Schwarzschild").

        Returns:
            A key for the stored metric object.
        """
        try:
            # Handle if metric_name is actually a Metric enum already
            if isinstance(metric_name, Metric):
                metric_enum = metric_name
            else:
                # First try direct mapping to enum value
                metric_enum = None

                # Try to match by enum value (the string in the enum definition)
                for metric in Metric:
                    if metric.value.lower() == metric_name.lower():
                        metric_enum = metric
                        break

                # If it didn't match any enum value, try to match by enum name
                if metric_enum is None:
                    try:
                        # Try exact name match
                        metric_enum = Metric[metric_name.upper()]
                    except KeyError:
                        # Try normalized name (remove spaces, underscores, etc.)
                        normalized_name = "".join(
                            c.upper() for c in metric_name if c.isalnum()
                        )
                        for m in Metric:
                            if (
                                "".join(c for c in m.name if c.isalnum())
                                == normalized_name
                            ):
                                metric_enum = m
                                break

            if metric_enum is None:
                return f"Error: Invalid metric name '{metric_name}'. Available metrics are: {', '.join(m.value for m in Metric)}"

            metric_map = {
                Metric.SCHWARZSCHILD: Schwarzschild,
                Metric.MINKOWSKI: Minkowski,
                Metric.MINKOWSKI_CARTESIAN: MinkowskiCartesian,
                Metric.KERR_NEWMAN: KerrNewman,
                Metric.KERR: Kerr,
                Metric.ANTI_DE_SITTER: AntiDeSitter,
                Metric.DE_SITTER: DeSitter,
                Metric.REISSNER_NORDSTROM: ReissnerNordstorm,
            }

            if metric_enum not in metric_map:
                return f"Error: Metric '{metric_enum.value}' not implemented. Available metrics are: {', '.join(m.value for m in Metric)}"

            metric_class = metric_map[metric_enum]
            metric_obj = metric_class()

            metric_key = f"metric_{metric_enum.value}"
            metrics[metric_key] = metric_obj
            expressions[metric_key] = metric_obj.tensor()

            return metric_key
        except Exception as e:
            return f"Error creating metric: {str(e)}"

    @mcp.tool()
    def search_predefined_metrics(query: str) -> str:
        """Searches for predefined metrics in einsteinpy.symbolic.predefined.

        Args:
            query: A search term to find metrics whose names contain this substring.

        Returns:
            A string listing the found metrics.
        """
        try:
            results = find(query)
            if not results:
                return f"No metrics found matching '{query}'."

            return f"Found metrics: {', '.join(results)}"
        except Exception as e:
            return f"Error searching for metrics: {str(e)}"

    @mcp.tool()
    def calculate_tensor(
        metric_key: str, tensor_type: str, simplify_result: bool = True
    ) -> str:
        """Calculates a tensor from a metric using einsteinpy.symbolic.

        Args:
            metric_key: The key of the stored metric object.
            tensor_type: The type of tensor to calculate (e.g., "RICCI_TENSOR", "EINSTEIN_TENSOR").
            simplify_result: Whether to apply sympy.simplify to the result. Currently this only affects the Ricci scalar; other tensors are stored unsimplified.

        Returns:
            A key for the stored tensor object.
        """
        if metric_key not in metrics:
            return f"Error: Metric key '{metric_key}' not found."

        metric_obj = metrics[metric_key]

        # Convert string to Tensor enum
        tensor_enum = None
        try:
            # Handle if tensor_type is already a Tensor enum
            if isinstance(tensor_type, Tensor):
                tensor_enum = tensor_type
            else:
                # Try to match by enum value
                for tensor in Tensor:
                    if tensor.value.lower() == tensor_type.lower():
                        tensor_enum = tensor
                        break

            # If it didn't match any enum value, try to match by enum name
            if tensor_enum is None:
                try:
                    # Try exact name match
                    tensor_enum = Tensor[tensor_type.upper()]
                except KeyError:
                    # Try normalized name (remove spaces, underscores, etc.)
                    normalized_name = "".join(
                        c.upper() for c in tensor_type if c.isalnum()
                    )
                    for t in Tensor:
                        if "".join(c for c in t.name if c.isalnum()) == normalized_name:
                            tensor_enum = t
                            break

            if tensor_enum is None:
                return f"Error: Invalid tensor type '{tensor_type}'. Available types are: {', '.join(t.value for t in Tensor)}"
        except Exception as e:
            return f"Error parsing tensor type: {str(e)}"

        from einsteinpy.symbolic import RiemannCurvatureTensor

        tensor_map = {
            Tensor.RICCI_TENSOR: RicciTensor,
            Tensor.RICCI_SCALAR: RicciScalar,
            Tensor.EINSTEIN_TENSOR: EinsteinTensor,
            Tensor.WEYL_TENSOR: WeylTensor,
            Tensor.RIEMANN_CURVATURE_TENSOR: RiemannCurvatureTensor,
            Tensor.STRESS_ENERGY_MOMENTUM_TENSOR: StressEnergyMomentumTensor,
        }

        try:
            if tensor_enum not in tensor_map:
                return f"Error: Tensor type '{tensor_enum.value}' not implemented. Available types are: {', '.join(t.value for t in Tensor)}"

            tensor_class = tensor_map[tensor_enum]

            # Special case for RicciScalar which takes a RicciTensor
            if tensor_enum == Tensor.RICCI_SCALAR:
                ricci_tensor = RicciTensor.from_metric(metric_obj)
                tensor_obj = RicciScalar.from_riccitensor(ricci_tensor)
            else:
                tensor_obj = tensor_class.from_metric(metric_obj)

            tensor_key = f"{tensor_enum.value.lower()}_{metric_key}"
            tensor_objects[tensor_key] = tensor_obj

            # Store the tensor expression
            if tensor_enum == Tensor.RICCI_SCALAR:
                # Scalar has expr attribute
                tensor_expr = tensor_obj.expr
                if simplify_result:
                    tensor_expr = sympy.simplify(tensor_expr)
                expressions[tensor_key] = tensor_expr
            else:
                # Other tensors have tensor() method
                tensor_expr = tensor_obj.tensor()
                expressions[tensor_key] = tensor_expr

            return tensor_key
        except Exception as e:
            return f"Error calculating tensor: {str(e)}"

    @mcp.tool()
    def create_custom_metric(
        components: List[List[str]],
        symbols: List[str],
        config: Literal["ll", "uu"] = "ll",
    ) -> str:
        """Creates a custom metric tensor from provided components and symbols.

        Args:
            components: A matrix of symbolic expressions as strings representing metric components.
            symbols: A list of symbol names used in the components.
            config: The tensor configuration - "ll" for covariant (lower indices) or "uu" for contravariant (upper indices).

        Returns:
            A key for the stored metric object.
        """
        global expression_counter
        try:
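            # Parse symbols one by one so a single coordinate still yields a
            # tuple (sympy.symbols("t") would return a bare Symbol, not a tuple)
            sympy_symbols = tuple(sympy.Symbol(name) for name in symbols)
            sympy_symbols_dict = {str(sym): sym for sym in sympy_symbols}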

            # Convert components to sympy expressions
            sympy_components = []
            for row in components:
                sympy_row = []
                for expr_str in row:
                    if expr_str == "0":
                        sympy_row.append(0)
                    else:
                        expr = parse_expr(expr_str, local_dict=sympy_symbols_dict)
                        sympy_row.append(expr)
                sympy_components.append(sympy_row)

            # Create metric tensor
            metric_obj = MetricTensor(sympy_components, sympy_symbols, config=config)

            # Store the metric
            metric_key = f"metric_custom_{expression_counter}"
            metrics[metric_key] = metric_obj
            expressions[metric_key] = metric_obj.tensor()

            expression_counter += 1

            return metric_key
        except Exception as e:
            return f"Error creating custom metric: {str(e)}"

    @mcp.tool()
    def print_latex_tensor(tensor_key: str) -> str:
        """Prints a stored tensor expression in LaTeX format.

        Args:
            tensor_key: The key of the stored tensor object.

        Returns:
            The LaTeX representation of the tensor.
        """
        if tensor_key not in expressions:
            return f"Error: Tensor key '{tensor_key}' not found."

        try:
            tensor_expr = expressions[tensor_key]
            latex_str = sympy.latex(tensor_expr)
            return latex_str
        except Exception as e:
            return f"Error generating LaTeX: {str(e)}"

else:

    @mcp.tool()
    def create_predefined_metric(metric_name: str) -> str:
        """Creates a predefined spacetime metric."""
        return "Error: EinsteinPy library is not available. Please install it with 'pip install einsteinpy'."

    @mcp.tool()
    def search_predefined_metrics(query: str) -> str:
        """Searches for predefined metrics in einsteinpy.symbolic.predefined."""
        return "Error: EinsteinPy library is not available. Please install it with 'pip install einsteinpy'."

    @mcp.tool()
    def calculate_tensor(
        metric_key: str, tensor_type: str, simplify_result: bool = True
    ) -> str:
        """Calculates a tensor from a metric using einsteinpy.symbolic."""
        return "Error: EinsteinPy library is not available. Please install it with 'pip install einsteinpy'."

    @mcp.tool()
    def create_custom_metric(
        components: List[List[str]],
        symbols: List[str],
        config: Literal["ll", "uu"] = "ll",
    ) -> str:
        """Creates a custom metric tensor from provided components and symbols."""
        return "Error: EinsteinPy library is not available. Please install it with 'pip install einsteinpy'."

    @mcp.tool()
    def print_latex_tensor(tensor_key: str) -> str:
        """Prints a stored tensor expression in LaTeX format."""
        return "Error: EinsteinPy library is not available. Please install it with 'pip install einsteinpy'."
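# Illustrative sketch (not an MCP tool and not called by the server): roughly
# what calculate_tensor obtains for the Christoffel symbols, written in plain
# SymPy so it runs without EinsteinPy. It assumes a covariant ("ll") metric
# given as a square sympy.Matrix and uses the standard formula
#   Gamma^a_{bc} = 1/2 g^{ad} (d_b g_{dc} + d_c g_{db} - d_d g_{bc}).
# The helper name _sketch_christoffel_symbols is hypothetical.
def _sketch_christoffel_symbols(metric, coords):
    """Return nested lists gamma[a][b][c] (Christoffel symbols, second kind)."""
    import sympy

    n = len(coords)
    g_inv = metric.inv()  # contravariant metric g^{ad}
    gamma = [[[sympy.S.Zero] * n for _ in range(n)] for _ in range(n)]
    for a in range(n):
        for b in range(n):
            for c in range(n):
                total = sum(
                    g_inv[a, d]
                    * (
                        sympy.diff(metric[d, c], coords[b])
                        + sympy.diff(metric[d, b], coords[c])
                        - sympy.diff(metric[b, c], coords[d])
                    )
                    for d in range(n)
                )
                gamma[a][b][c] = sympy.simplify(total / 2)
    return gamma


# Usage on the unit 2-sphere, g = diag(1, sin(theta)**2) in (theta, phi):
# gamma[0][1][1] is -sin(theta)*cos(theta) and gamma[1][0][1] is cot(theta).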


@mcp.tool()
def simplify_expression(expr_key: str) -> str:
    """Simplifies a mathematical expression using SymPy's simplify function.

    Args:
        expr_key: The key of the expression (previously introduced) to simplify.

    Example:
        # Introduce variables
        intro("x", [Assumption.REAL], [])
        intro("y", [Assumption.REAL], [])

        # Create an expression to simplify: sin(x)^2 + cos(x)^2
        expr_key = introduce_expression("sin(x)**2 + cos(x)**2")

        # Simplify the expression
        simplified = simplify_expression(expr_key)
        # Returns 1

    Returns:
        A key for the simplified expression.
    """
    global expression_counter

    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    try:
        original_expr = expressions[expr_key]
        simplified_expr = simplify(original_expr)

        result_key = f"expr_{expression_counter}"
        expressions[result_key] = simplified_expr
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error during simplification: {str(e)}"
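# Illustrative sketch (not an MCP tool): the keyed-store pattern shared by the
# tools in this module. Each tool looks an expression up by key, computes a new
# one and stores it under a fresh "expr_<n>" key, returning the key rather than
# the expression. _SketchExpressionStore is a hypothetical stand-in for the
# module-level `expressions` dict and `expression_counter` global.
class _SketchExpressionStore:
    def __init__(self):
        self.items = {}
        self.counter = 0

    def add(self, expr):
        """Store expr under the next sequential key and return that key."""
        key = f"expr_{self.counter}"
        self.items[key] = expr
        self.counter += 1
        return key

    def simplify(self, key):
        """Mirror simplify_expression: look up, simplify, store the result."""
        import sympy

        if key not in self.items:
            return f"Error: Expression with key '{key}' not found."
        return self.add(sympy.simplify(self.items[key]))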


@mcp.tool()
def integrate_expression(
    expr_key: str,
    var_name: str,
    lower_bound: Optional[str] = None,
    upper_bound: Optional[str] = None,
) -> str:
    """Integrates an expression with respect to a variable using SymPy's integrate function.

    Args:
        expr_key: The key of the expression (previously introduced) to integrate.
        var_name: The name of the variable to integrate with respect to.
        lower_bound: Optional lower bound for definite integration; must be given together with upper_bound.
        upper_bound: Optional upper bound for definite integration; must be given together with lower_bound.

    Example:
        # Introduce a variable
        intro("x", [Assumption.REAL], [])

        # Create an expression to integrate: x^2
        expr_key = introduce_expression("x**2")

        # Indefinite integration
        indefinite_result = integrate_expression(expr_key, "x")
        # Returns x³/3

        # Definite integration from 0 to 1
        definite_result = integrate_expression(expr_key, "x", "0", "1")
        # Returns 1/3

    Returns:
        A key for the integrated expression.
    """
    global expression_counter

    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    if var_name not in local_vars:
        return f"Error: Variable '{var_name}' not found. Please introduce it first."

    try:
        expr = expressions[expr_key]
        var = local_vars[var_name]

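        # Parse bounds if provided; a definite integral needs both of them
        if (lower_bound is None) != (upper_bound is None):
            return "Error: Provide both lower_bound and upper_bound, or neither."
        bounds = None
        if lower_bound is not None and upper_bound is not None:
            parse_dict = {**local_vars, **functions}
            lower = parse_expr(lower_bound, local_dict=parse_dict)
            upper = parse_expr(upper_bound, local_dict=parse_dict)
            bounds = (var, lower, upper)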

        # Perform integration
        if bounds:
            result = integrate(expr, bounds)
        else:
            result = integrate(expr, var)

        result_key = f"expr_{expression_counter}"
        expressions[result_key] = result
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error during integration: {str(e)}"
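# Illustrative sketch (not an MCP tool): how integrate_expression chooses between
# definite and indefinite integration. SymPy's integrate() accepts either a bare
# symbol (antiderivative, no constant added) or a (symbol, lower, upper) tuple.
# The helper name and the choice to reject a single bound are assumptions made
# for this sketch.
def _sketch_integrate(expr_str, var_name, lower=None, upper=None):
    import sympy
    from sympy.parsing.sympy_parser import parse_expr

    var = sympy.Symbol(var_name)
    expr = parse_expr(expr_str, local_dict={var_name: var})
    if (lower is None) != (upper is None):
        raise ValueError("give both bounds or neither")
    if lower is None:
        return sympy.integrate(expr, var)
    bounds = (var, sympy.sympify(lower), sympy.sympify(upper))
    return sympy.integrate(expr, bounds)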


@mcp.tool()
def differentiate_expression(expr_key: str, var_name: str, order: int = 1) -> str:
    """Differentiates an expression with respect to a variable using SymPy's diff function.

    Args:
        expr_key: The key of the expression (previously introduced) to differentiate.
        var_name: The name of the variable to differentiate with respect to.
        order: The order of differentiation (default is 1 for first derivative).

    Example:
        # Introduce a variable
        intro("x", [Assumption.REAL], [])

        # Create an expression to differentiate: x^3
        expr_key = introduce_expression("x**3")

        # First derivative
        first_deriv = differentiate_expression(expr_key, "x")
        # Returns 3x²

        # Second derivative
        second_deriv = differentiate_expression(expr_key, "x", 2)
        # Returns 6x

    Returns:
        A key for the differentiated expression.
    """
    global expression_counter

    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    if var_name not in local_vars:
        return f"Error: Variable '{var_name}' not found. Please introduce it first."

    if order < 1:
        return "Error: Order of differentiation must be at least 1."

    try:
        expr = expressions[expr_key]
        var = local_vars[var_name]

        result = diff(expr, var, order)

        result_key = f"expr_{expression_counter}"
        expressions[result_key] = result
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error during differentiation: {str(e)}"
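# Illustrative sketch (not an MCP tool): differentiate_expression relies on
# diff(expr, var, n) computing the n-th derivative in one call, which matches
# applying diff n times. The helper name is hypothetical.
def _sketch_nth_derivative(expr_str, var_name, order=1):
    import sympy
    from sympy.parsing.sympy_parser import parse_expr

    if order < 1:
        raise ValueError("order must be at least 1")
    var = sympy.Symbol(var_name)
    expr = parse_expr(expr_str, local_dict={var_name: var})
    direct = sympy.diff(expr, var, order)
    # Cross-check: repeated first derivatives give the same result
    repeated = expr
    for _ in range(order):
        repeated = sympy.diff(repeated, var)
    assert sympy.simplify(direct - repeated) == 0
    return direct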


@mcp.tool()
def create_coordinate_system(name: str, coord_names: Optional[List[str]] = None) -> str:
    """Creates a 3D coordinate system for vector calculus operations.

    Args:
        name: The name for the coordinate system.
        coord_names: Optional list of exactly 3 coordinate names. The system is
                    always Cartesian; custom names only rename the coordinates.
                    Each coordinate is registered as f"{name}_{coord}", so the
                    default names are name_x, name_y and name_z.

    Example:
        # Create a coordinate system
        coord_sys = create_coordinate_system("R")
        # Creates a coordinate system R with coordinates R_x, R_y, R_z

        # Create a coordinate system with custom coordinate names
        coord_sys = create_coordinate_system("C", ["rho", "phi", "z"])
        # Coordinates are C_rho, C_phi, C_z (Cartesian axes, renamed only)

    Returns:
        The name of the created coordinate system.
    """
    if name in coordinate_systems:
        return f"Warning: Coordinate system '{name}' already exists and was not overwritten."

    try:
        if coord_names and len(coord_names) != 3:
            return "Error: coord_names must contain exactly 3 names for x, y, z coordinates."

        if coord_names:
            # Create a CoordSys3D with custom coordinate names
            cs = CoordSys3D(name, variable_names=coord_names)
        else:
            # Create a CoordSys3D with default coordinate naming
            cs = CoordSys3D(name)

        coordinate_systems[name] = cs

        # Add the coordinate system to the expressions dict to make it accessible
        # in expressions through parsing
        expressions[name] = cs

        # Add the coordinate variables (the base scalars, not the unit vectors)
        # to local_vars so expressions can refer to them by name, e.g. R_x
        for i, base_scalar in enumerate(cs.base_scalars()):
            coord_var_name = (
                f"{name}_{['x', 'y', 'z'][i]}"
                if not coord_names
                else f"{name}_{coord_names[i]}"
            )
            local_vars[coord_var_name] = base_scalar

        return name
    except Exception as e:
        return f"Error creating coordinate system: {str(e)}"
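# Illustrative sketch (not an MCP tool): CoordSys3D exposes two families of
# objects. base_scalars() are the coordinate variables (R.x, R.y, R.z) used in
# scalar expressions such as "R_x**2 + R_y**2", while base_vectors() are the
# unit vectors (R.i, R.j, R.k) used to assemble vector fields. The
# "<system>_<coordinate>" naming mirrors create_coordinate_system; the helper
# name is hypothetical.
def _sketch_coordinate_names(name, coord_names=None):
    from sympy.vector import CoordSys3D

    if coord_names is not None and len(coord_names) != 3:
        raise ValueError("exactly 3 coordinate names are required")
    if coord_names:
        cs = CoordSys3D(name, variable_names=coord_names)
    else:
        cs = CoordSys3D(name)
    labels = coord_names or ["x", "y", "z"]
    # Map "<system>_<label>" to the matching coordinate variable
    scalars = {f"{name}_{label}": s for label, s in zip(labels, cs.base_scalars())}
    return cs, scalars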


@mcp.tool()
def create_vector_field(
    coord_sys_name: str, component_x: str, component_y: str, component_z: str
) -> str:
    """Creates a vector field in the specified coordinate system.

    Args:
        coord_sys_name: The name of the coordinate system to use.
        component_x: String expression for the x-component of the vector field.
        component_y: String expression for the y-component of the vector field.
        component_z: String expression for the z-component of the vector field.

    Example:
        # First create a coordinate system
        create_coordinate_system("R")

        # Create a vector field F = (y, -x, z)
        vector_field = create_vector_field("R", "R_y", "-R_x", "R_z")

    Returns:
        A key for the vector field expression.
    """
    global expression_counter

    if coord_sys_name not in coordinate_systems:
        return f"Error: Coordinate system '{coord_sys_name}' not found. Create it first using create_coordinate_system."

    try:
        cs = coordinate_systems[coord_sys_name]

        # Parse the component expressions
        parse_dict = {**local_vars, **functions, coord_sys_name: cs}
        x_comp = parse_expr(component_x, local_dict=parse_dict)
        y_comp = parse_expr(component_y, local_dict=parse_dict)
        z_comp = parse_expr(component_z, local_dict=parse_dict)

        # Create the vector field
        vector_field = (
            x_comp * cs.base_vectors()[0]
            + y_comp * cs.base_vectors()[1]
            + z_comp * cs.base_vectors()[2]
        )

        # Store the vector field
        result_key = f"vector_{expression_counter}"
        expressions[result_key] = vector_field
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error creating vector field: {str(e)}"


@mcp.tool()
def calculate_curl(vector_field_key: str) -> str:
    """Calculates the curl of a vector field using SymPy's curl function.

    Args:
        vector_field_key: The key of the vector field expression.

    Example:
        # First create a coordinate system
        create_coordinate_system("R")

        # Create a vector field F = (y, -x, 0)
        vector_field = create_vector_field("R", "R_y", "-R_x", "0")

        # Calculate curl
        curl_result = calculate_curl(vector_field)
        # Returns (0, 0, -2)

    Returns:
        A key for the curl expression.
    """
    global expression_counter

    if vector_field_key not in expressions:
        return f"Error: Vector field with key '{vector_field_key}' not found."

    try:
        vector_field = expressions[vector_field_key]

        # Calculate curl
        curl_result = curl(vector_field)

        # Store the result
        result_key = f"vector_{expression_counter}"
        expressions[result_key] = curl_result
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error calculating curl: {str(e)}"
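# Illustrative sketch (not an MCP tool): a vector field is built as
# F_x*i + F_y*j + F_z*k from the unit vectors, with components written in the
# base scalars, and sympy.vector.curl then differentiates it. This reproduces
# the docstring example F = (y, -x, 0), whose curl is (0, 0, -2). The helper
# name is hypothetical.
def _sketch_rotation_field_curl():
    from sympy.vector import CoordSys3D, curl

    R = CoordSys3D("R")
    field = R.y * R.i - R.x * R.j  # F = (y, -x, 0)
    return R, curl(field)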


@mcp.tool()
def calculate_divergence(vector_field_key: str) -> str:
    """Calculates the divergence of a vector field using SymPy's divergence function.

    Args:
        vector_field_key: The key of the vector field expression.

    Example:
        # First create a coordinate system
        create_coordinate_system("R")

        # Create a vector field F = (x, y, z)
        vector_field = create_vector_field("R", "R_x", "R_y", "R_z")

        # Calculate divergence
        div_result = calculate_divergence(vector_field)
        # Returns 3

    Returns:
        A key for the divergence expression.
    """
    global expression_counter

    if vector_field_key not in expressions:
        return f"Error: Vector field with key '{vector_field_key}' not found."

    try:
        vector_field = expressions[vector_field_key]

        # Calculate divergence
        div_result = divergence(vector_field)

        # Store the result
        result_key = f"expr_{expression_counter}"
        expressions[result_key] = div_result
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error calculating divergence: {str(e)}"
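# Illustrative sketch (not an MCP tool): divergence returns a scalar rather than
# a vector, which is why calculate_divergence stores its result under an
# "expr_" key. For the position field F = (x, y, z) the divergence is
# 1 + 1 + 1 = 3. The helper name is hypothetical.
def _sketch_position_field_divergence():
    from sympy.vector import CoordSys3D, divergence

    R = CoordSys3D("R")
    field = R.x * R.i + R.y * R.j + R.z * R.k  # F = (x, y, z)
    return divergence(field)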


@mcp.tool()
def calculate_gradient(scalar_field_key: str) -> str:
    """Calculates the gradient of a scalar field using SymPy's gradient function.

    Args:
        scalar_field_key: The key of the scalar field expression.

    Example:
        # First create a coordinate system
        create_coordinate_system("R")

        # Create a scalar field f = x^2 + y^2 + z^2
        scalar_field = introduce_expression("R_x**2 + R_y**2 + R_z**2")

        # Calculate gradient
        grad_result = calculate_gradient(scalar_field)
        # Returns (2x, 2y, 2z)

    Returns:
        A key for the gradient vector field expression.
    """
    global expression_counter

    if scalar_field_key not in expressions:
        return f"Error: Scalar field with key '{scalar_field_key}' not found."

    try:
        scalar_field = expressions[scalar_field_key]

        # Calculate gradient
        grad_result = gradient(scalar_field)

        # Store the result
        result_key = f"vector_{expression_counter}"
        expressions[result_key] = grad_result
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error calculating gradient: {str(e)}"
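# Illustrative sketch (not an MCP tool): the gradient turns a scalar field
# written in the base scalars into a vector along the unit vectors. For
# f = x**2 + y**2 + z**2 it gives 2x*i + 2y*j + 2z*k, matching the docstring
# example. The helper name is hypothetical.
def _sketch_radial_gradient():
    from sympy.vector import CoordSys3D, gradient

    R = CoordSys3D("R")
    f = R.x**2 + R.y**2 + R.z**2
    return R, gradient(f)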


@mcp.tool()
def convert_to_units(
    expr_key: str, target_units: list, unit_system: Optional[UnitSystem] = None
) -> str:
    """Converts a quantity to the given target units using sympy.physics.units.convert_to.

    Args:
        expr_key: The key of the expression (previously introduced) to convert.
        target_units: List of unit names as strings (e.g., ["meter", "1/second"]).
        unit_system: Optional unit system (from UnitSystem enum). Defaults to SI.

    The following units are available by default:
        SI base units: meter, second, kilogram, ampere, kelvin, mole, candela
        Length: kilometer, millimeter
        Mass: gram
        Energy: joule
        Force: newton
        Pressure: pascal
        Power: watt
        Electric: coulomb, volt, ohm, farad, henry
        Constants: speed_of_light, gravitational_constant, planck

    IMPORTANT: For compound units like meter/second, you must separate the numerator and
    denominator into separate units in the list. For example:
    - For meter/second: use ["meter", "1/second"]
    - For newton*meter: use ["newton", "meter"]
    - For kilogram*meter²/second²: use ["kilogram", "meter**2", "1/second**2"]

    Example:
        # Convert speed of light to kilometers per hour
        expr_key = introduce_expression("speed_of_light")
        result = convert_to_units(expr_key, ["kilometer", "1/hour"])
        # Returns approximately 1.08e9 kilometer/hour

        # Convert gravitational constant to CGS units
        expr_key = introduce_expression("gravitational_constant")
        result = convert_to_units(expr_key, ["centimeter**3", "1/gram", "1/second**2"], UnitSystem.CGS)

    SI prefixes (femto, pico, nano, micro, milli, centi, deci, deca, hecto, kilo, mega, giga, tera)
    can be used directly with base units.

    Returns:
        A key for the converted expression, or an error message.
    """
    global expression_counter

    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    expr = expressions[expr_key]

    # Map UnitSystem enum to sympy unit system objects
    system_map = {
        None: SI,
        UnitSystem.SI: SI,
        UnitSystem.MKS: MKS,
        UnitSystem.MKSA: MKSA,
        UnitSystem.NATURAL: natural,
    }

    # Special case for cgs_gauss as it's in a different module
    if unit_system is not None and unit_system.value.lower() == "cgs":
        system = cgs_gauss
    else:
        system = system_map.get(unit_system, SI)

    try:
        from sympy.physics.units import Quantity

        # Get unit objects directly from the units_dict
        target_unit_objs = []
        for unit_str in target_units:
            if unit_str in units_dict:
                target_unit_objs.append(units_dict[unit_str])
                continue

            # If not found directly, try to evaluate it as an expression such as
            # "1/second" or "meter**2", using units_dict as the local namespace
            try:
                unit_obj = parse_expr(unit_str, local_dict=units_dict)
            except Exception as e:
                return f"Error: Unit '{unit_str}' could not be parsed: {str(e)}"

            # Unknown names parse as plain Symbols, so an expression containing
            # no Quantity at all (e.g. "not_a_unit") is not a unit
            if not unit_obj.has(Quantity):
                return f"Error: Unit '{unit_str}' not found in sympy.physics.units."
            target_unit_objs.append(unit_obj)

        # Convert the expression to the target units
        result = convert_to(expr, target_unit_objs, system)
        result_key = f"expr_{expression_counter}"
        expressions[result_key] = result
        expression_counter += 1
        return result_key
    except Exception as e:
        return f"Error during unit conversion: {str(e)}"
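# Illustrative sketch (not an MCP tool): the core call behind convert_to_units.
# sympy.physics.units.convert_to rewrites a quantity in terms of the target
# units and solves for their exponents itself, so [kilometer, hour] is enough
# to obtain kilometer/hour. Because the SI scale factors are exact integers,
# the speed of light comes out as the exact rational 5396264244/5
# (about 1.08e9) times kilometer/hour. The helper name is hypothetical.
def _sketch_speed_of_light_in_km_per_hour():
    from sympy.physics.units import convert_to, hour, kilometer, speed_of_light

    return convert_to(speed_of_light, [kilometer, hour])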


@mcp.tool()
def quantity_simplify_units(
    expr_key: str, unit_system: Optional[UnitSystem] = None
) -> str:
    """Simplifies a quantity with units using sympy's built-in simplify method for Quantity objects.

    Args:
        expr_key: The key of the expression (previously introduced) to simplify.
        unit_system: Optional unit system (from UnitSystem enum). Not used with direct simplify method.

    The following units are available by default:
        SI base units: meter, second, kilogram, ampere, kelvin, mole, candela
        Length: kilometer, millimeter
        Mass: gram
        Energy: joule
        Force: newton
        Pressure: pascal
        Power: watt
        Electric: coulomb, volt, ohm, farad, henry
        Constants: speed_of_light, gravitational_constant, planck

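    Example:
        # Simplify force expressed in base units
        expr_key = introduce_expression("kilogram*meter/second**2")
        result = quantity_simplify_units(expr_key)
        # simplify() does not substitute derived unit names, so the result stays
        # in base units; use convert_to_units(expr_key, ["newton"]) to get newtons

        # Simplify a complex expression with mixed units
        expr_key = introduce_expression("joule/(kilogram*meter**2/second**2)")
        result = quantity_simplify_units(expr_key)
        # Differently named units are not necessarily cancelled against each other;
        # convert to base units with convert_to_units to reduce this ratio to 1

        # Simplify electrical power expression
        expr_key = introduce_expression("volt*ampere")
        result = quantity_simplify_units(expr_key)
        # To express the result as watt, use convert_to_units(expr_key, ["watt"])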

    Example with Speed of Light:
        # Introduce the speed of light
        c_key = introduce_expression("speed_of_light")

        # Convert to kilometers per hour
        km_per_hour_key = convert_to_units(c_key, ["kilometer", "1/hour"])

        # Simplify to get the numerical value
        simplified_key = quantity_simplify_units(km_per_hour_key)

        # Print the result
        print_latex_expression(simplified_key)
        # Shows the numeric value of speed of light in km/h

    Returns:
        A key for the simplified expression, or an error message.
    """
    global expression_counter

    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    expr = expressions[expr_key]

    try:
        # Use simplify() method directly on the expression
        # This is more compatible than quantity_simplify
        result = expr.simplify()
        result_key = f"expr_{expression_counter}"
        expressions[result_key] = result
        expression_counter += 1
        return result_key
    except Exception as e:
        return f"Error during quantity simplification: {str(e)}"
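# Illustrative sketch (not an MCP tool): naming the derived unit explicitly with
# convert_to is the direct way to rewrite kilogram*meter/second**2 as newton and
# volt*ampere as watt; simplify() is not documented to perform that
# substitution. The helper name is hypothetical.
def _sketch_express_in_derived_units():
    from sympy.physics.units import (
        ampere,
        convert_to,
        kilogram,
        meter,
        newton,
        second,
        volt,
        watt,
    )

    force = convert_to(kilogram * meter / second**2, newton)
    power = convert_to(volt * ampere, watt)
    return force, power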


# Initialize units in the local variables dictionary
def initialize_units():
    """Initialize common units in the local_vars dictionary for easy access in expressions."""

    # Add common units to local_vars
    unit_vars = {
        "meter": meter,
        "second": second,
        "kilogram": kilogram,
        "ampere": ampere,
        "kelvin": kelvin,
        "mole": mole,
        "candela": candela,
        "kilometer": kilometer,
        "millimeter": millimeter,
        "gram": gram,
        "joule": joule,
        "newton": newton,
        "pascal": pascal,
        "watt": watt,
        "coulomb": coulomb,
        "volt": volt,
        "ohm": ohm,
        "farad": farad,
        "henry": henry,
        "speed_of_light": speed_of_light,
        "gravitational_constant": gravitational_constant,
        "planck": planck,
        "day": day,
        "year": year,
        "minute": minute,
        "hour": hour,
    }

    # Add to local_vars
    for name, unit in unit_vars.items():
        if unit is not None:
            local_vars[name] = unit


@mcp.tool()
def reset_state() -> str:
    """Resets the state of the SymPy MCP server.

    Clears all stored variables, functions, expressions, metrics, tensors,
    coordinate systems, and resets the expression counter.

    Then reinitializes unit variables.

    Call this after all tool calls for a computation are done, to reset the state before the next computation.

    Returns:
        A message confirming the reset.
    """
    global local_vars, functions, expressions, metrics, tensor_objects, coordinate_systems, expression_counter

    # Clear all dictionaries
    local_vars.clear()
    functions.clear()
    expressions.clear()
    metrics.clear()
    tensor_objects.clear()
    coordinate_systems.clear()

    # Reset expression counter
    expression_counter = 0

    # Reinitialize units
    initialize_units()

    return "State reset successfully. All variables, functions, expressions, and other objects have been cleared."


@mcp.tool()
def create_matrix(
    matrix_data: List[List[Union[int, float, str]]],
    matrix_var_name: Optional[str] = None,
) -> str:
    """Creates a SymPy matrix from the provided data.

    Args:
        matrix_data: A list of lists representing the rows and columns of the matrix.
                    Each element can be a number or a string expression.
        matrix_var_name: Optional name for storing the matrix. If not provided, a
                         sequential name will be generated.

    Example:
        # Create a 2x2 matrix with numeric values
        matrix_key = create_matrix([[1, 2], [3, 4]], "M")

        # Create a matrix with symbolic expressions (assuming x, y are defined)
        matrix_key = create_matrix([["x", "y"], ["x*y", "x+y"]])

    Returns:
        A key for the stored matrix.
    """
    global expression_counter

    try:
        # Process each element to handle expressions
        processed_data = []
        for row in matrix_data:
            processed_row = []
            for elem in row:
                if isinstance(elem, (int, float)):
                    processed_row.append(elem)
                else:
                    # Parse the element as an expression using local variables
                    parse_dict = {**local_vars, **functions}
                    parsed_elem = parse_expr(str(elem), local_dict=parse_dict)
                    processed_row.append(parsed_elem)
            processed_data.append(processed_row)

        # Create the SymPy matrix
        matrix = Matrix(processed_data)

        # Generate a key for the matrix
        if matrix_var_name is None:
            matrix_key = f"matrix_{expression_counter}"
            expression_counter += 1
        else:
            matrix_key = matrix_var_name

        # Store the matrix in the expressions dictionary
        expressions[matrix_key] = matrix

        return matrix_key
    except Exception as e:
        return f"Error creating matrix: {str(e)}"
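# Illustrative sketch (not an MCP tool): how create_matrix turns mixed input into
# a SymPy Matrix. Numbers pass through unchanged and strings are parsed against
# a dictionary of known symbols (local_vars and functions in the real tool).
# The helper name and the `known` argument are hypothetical.
def _sketch_build_matrix(rows, known=None):
    from sympy import Matrix
    from sympy.parsing.sympy_parser import parse_expr

    known = known or {}
    parsed_rows = []
    for row in rows:
        parsed_row = []
        for elem in row:
            if isinstance(elem, (int, float)):
                parsed_row.append(elem)  # numbers pass through unchanged
            else:
                parsed_row.append(parse_expr(str(elem), local_dict=known))
        parsed_rows.append(parsed_row)
    return Matrix(parsed_rows)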


@mcp.tool()
def matrix_determinant(matrix_key: str) -> str:
    """Calculates the determinant of a matrix using SymPy's det method.

    Args:
        matrix_key: The key of the matrix to calculate the determinant for.

    Example:
        # Create a matrix
        matrix_key = create_matrix([[1, 2], [3, 4]])

        # Calculate its determinant
        det_key = matrix_determinant(matrix_key)
        # Results in -2

    Returns:
        A key for the determinant expression.
    """
    global expression_counter

    if matrix_key not in expressions:
        return f"Error: Matrix with key '{matrix_key}' not found."

    try:
        matrix = expressions[matrix_key]

        # Check if the value is actually a Matrix
        if not isinstance(matrix, Matrix):
            return f"Error: '{matrix_key}' is not a matrix."

        # Calculate the determinant
        det = matrix.det()

        # Store and return the result
        result_key = f"expr_{expression_counter}"
        expressions[result_key] = det
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error calculating determinant: {str(e)}"
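# Illustrative sketch (not an MCP tool): for the docstring's 2x2 example the
# determinant is ad - bc = 1*4 - 2*3 = -2, which Matrix.det() reproduces; the
# same call works on symbolic entries. The helper name is hypothetical.
def _sketch_determinants():
    from sympy import Matrix, symbols

    a, b, c, d = symbols("a b c d")
    numeric = Matrix([[1, 2], [3, 4]]).det()
    symbolic = Matrix([[a, b], [c, d]]).det()
    return numeric, symbolic, (a, b, c, d)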


@mcp.tool()
def matrix_inverse(matrix_key: str) -> str:
    """Calculates the inverse of a matrix using SymPy's inv method.

    Args:
        matrix_key: The key of the matrix to invert.

    Example:
        # Create a matrix
        matrix_key = create_matrix([[1, 2], [3, 4]])

        # Calculate its inverse
        inv_key = matrix_inverse(matrix_key)
        # Stored value is Matrix([[-2, 1], [3/2, -1/2]])

    Returns:
        A key for the inverted matrix.
    """
    global expression_counter

    if matrix_key not in expressions:
        return f"Error: Matrix with key '{matrix_key}' not found."

    try:
        matrix = expressions[matrix_key]

        # Check if the value is actually a Matrix
        if not isinstance(matrix, Matrix):
            return f"Error: '{matrix_key}' is not a matrix."

        # Calculate the inverse
        inv = matrix.inv()

        # Store and return the result
        result_key = f"matrix_{expression_counter}"
        expressions[result_key] = inv
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error calculating inverse: {str(e)}"
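# Illustrative sketch (not an MCP tool): Matrix.inv() works in exact rationals,
# so [[1, 2], [3, 4]] inverts to [[-2, 1], [3/2, -1/2]], while a singular matrix
# raises an error (a ValueError subclass) that matrix_inverse reports through
# its except branch. The helper name is hypothetical.
def _sketch_inverse_check(rows):
    from sympy import Matrix, eye

    m = Matrix(rows)
    inv = m.inv()  # raises for singular input such as [[1, 2], [2, 4]]
    assert m * inv == eye(m.rows)  # A * A^-1 must be the identity
    return inv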


@mcp.tool()
def matrix_eigenvalues(matrix_key: str) -> str:
    """Calculates the eigenvalues of a matrix using SymPy's eigenvals method.

    Args:
        matrix_key: The key of the matrix to calculate eigenvalues for.

    Example:
        # Create a matrix
        matrix_key = create_matrix([[1, 2], [2, 1]])

        # Calculate its eigenvalues
        evals_key = matrix_eigenvalues(matrix_key)
        # Stored value maps the eigenvalues 3 and -1 to multiplicity 1 each

    Returns:
        A key for the eigenvalues, stored as a dictionary mapping each eigenvalue to its algebraic multiplicity.
    """
    global expression_counter

    if matrix_key not in expressions:
        return f"Error: Matrix with key '{matrix_key}' not found."

    try:
        matrix = expressions[matrix_key]

        # Check if the value is actually a Matrix
        if not isinstance(matrix, Matrix):
            return f"Error: '{matrix_key}' is not a matrix."

        # Calculate the eigenvalues
        eigenvals = matrix.eigenvals()

        # Store and return the result
        result_key = f"expr_{expression_counter}"
        expressions[result_key] = eigenvals
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error calculating eigenvalues: {str(e)}"
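# Illustrative sketch (not an MCP tool): eigenvals() returns a dict mapping each
# eigenvalue to its algebraic multiplicity, i.e. the roots of the characteristic
# polynomial. For [[1, 2], [2, 1]] that polynomial is t**2 - 2*t - 3, with roots
# 3 and -1. The helper name is hypothetical.
def _sketch_eigenvalues_with_charpoly(rows):
    from sympy import Matrix, Symbol, roots

    m = Matrix(rows)
    t = Symbol("t")
    from_charpoly = roots(m.charpoly(t).as_expr(), t)  # {root: multiplicity}
    return m.eigenvals(), from_charpoly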


@mcp.tool()
def matrix_eigenvectors(matrix_key: str) -> str:
    """Calculates the eigenvectors of a matrix using SymPy's eigenvects method.

    Args:
        matrix_key: The key of the matrix to calculate eigenvectors for.

    Example:
        # Create a matrix
        matrix_key = create_matrix([[1, 2], [2, 1]])

        # Calculate its eigenvectors
        evecs_key = matrix_eigenvectors(matrix_key)

    Returns:
        A key for the eigenvectors, stored as a list of (eigenvalue, algebraic multiplicity, [eigenvectors]) tuples.
    """
    global expression_counter

    if matrix_key not in expressions:
        return f"Error: Matrix with key '{matrix_key}' not found."

    try:
        matrix = expressions[matrix_key]

        # Check if the value is actually a Matrix
        if not isinstance(matrix, Matrix):
            return f"Error: '{matrix_key}' is not a matrix."

        # Calculate the eigenvectors
        eigenvects = matrix.eigenvects()

        # Store and return the result
        result_key = f"expr_{expression_counter}"
        expressions[result_key] = eigenvects
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error calculating eigenvectors: {str(e)}"
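# Illustrative sketch (not an MCP tool): eigenvects() returns a list of
# (eigenvalue, algebraic multiplicity, [basis vectors]) tuples, and each vector
# v can be checked against the defining equation A*v = lambda*v. The helper
# name is hypothetical.
def _sketch_check_eigenvectors(rows):
    from sympy import Matrix

    m = Matrix(rows)
    checked = []
    for value, multiplicity, vectors in m.eigenvects():
        for v in vectors:
            assert m * v == value * v  # defining property of an eigenvector
            checked.append((value, multiplicity))
    return checked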


@mcp.tool()
def substitute_expression(
    expr_key: str, var_name: str, replacement_expr_key: str
) -> str:
    """Substitutes a variable in an expression with another expression using SymPy's subs method.

    Args:
        expr_key: The key of the expression to perform substitution on.
        var_name: The name of the variable to substitute.
        replacement_expr_key: The key of the expression to substitute in place of the variable.

    Example:
        # Create variables x and y
        intro("x", [], [])
        intro("y", [], [])

        # Create expressions
        expr1 = introduce_expression("x**2 + y**2")
        expr2 = introduce_expression("sin(x)")

        # Substitute y with sin(x) in x^2 + y^2
        result = substitute_expression(expr1, "y", expr2)
        # Results in x^2 + sin^2(x)

    Returns:
        A key for the resulting expression after substitution.
    """
    global expression_counter

    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."

    if var_name not in local_vars:
        return f"Error: Variable '{var_name}' not found. Please introduce it first."

    if replacement_expr_key not in expressions:
        return f"Error: Replacement expression with key '{replacement_expr_key}' not found."

    try:
        expr = expressions[expr_key]
        var = local_vars[var_name]
        replacement = expressions[replacement_expr_key]

        # Perform the substitution
        result = expr.subs(var, replacement)

        # Store and return the result
        result_key = f"expr_{expression_counter}"
        expressions[result_key] = result
        expression_counter += 1

        return result_key
    except Exception as e:
        return f"Error during substitution: {str(e)}"


def main():
    parser = argparse.ArgumentParser(description="MCP server for SymPy")
    parser.add_argument(
        "--mcp-host",
        type=str,
        default="127.0.0.1",
        help="Host to run MCP server on (only used for sse), default: 127.0.0.1",
    )
    parser.add_argument(
        "--mcp-port",
        type=int,
        help="Port to run MCP server on (only used for sse), default: 8081",
    )
    parser.add_argument(
        "--transport",
        type=str,
        default="stdio",
        choices=["stdio", "sse"],
        help="Transport protocol for MCP, default: stdio",
    )
    args = parser.parse_args()

    # Call to initialize units
    initialize_units()

    if args.transport == "sse":
        try:
            # Set up logging
            log_level = logging.INFO
            logging.basicConfig(level=log_level)
            logging.getLogger().setLevel(log_level)

            # Configure MCP settings
            mcp.settings.log_level = "INFO"
            mcp.settings.host = args.mcp_host or "127.0.0.1"
            mcp.settings.port = args.mcp_port or 8081

            logger.info(
                f"Starting MCP server on http://{mcp.settings.host}:{mcp.settings.port}/sse"
            )
            logger.info(f"Using transport: {args.transport}")

            mcp.run(transport="sse")
        except KeyboardInterrupt:
            logger.info("Server stopped by user")
    else:
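        # With stdio transport, stdout carries the JSON-RPC message stream,
        # so status messages must go to stderr to avoid corrupting it.
        import sys

        print("Starting MCP server with stdio transport", file=sys.stderr)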
        mcp.run()


if __name__ == "__main__":
    main()

```
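The eigenvector and substitution tools share one pattern. Each result is stored in a module-level `expressions` dict under a fresh `expr_N` key, and only that key is returned to the client. The sketch below reproduces that pattern outside the server so the behaviour of `subs` and the shape of `eigenvects()` output can be checked directly. It assumes only SymPy. `store` and `substitute` are hypothetical helper names and are not part of `server.py`. Note that the eigenvector tool stores a plain list of `(eigenvalue, multiplicity, basis_vectors)` tuples rather than a SymPy expression, so passing its key to `substitute_expression` would fail inside `.subs` and be reported as an error string.

```python
from typing import Any

from sympy import Matrix, Symbol, sin

# Hypothetical stand-ins for the server's module-level state.
expressions: dict[str, Any] = {}
local_vars: dict[str, Symbol] = {}
expression_counter = 0


def store(value: Any) -> str:
    """Store a value under the next expr_N key and return that key (hypothetical helper)."""
    global expression_counter
    key = f"expr_{expression_counter}"
    expressions[key] = value
    expression_counter += 1
    return key


def substitute(expr_key: str, var_name: str, replacement_key: str) -> str:
    """Sketch mirroring substitute_expression: replace one symbol via SymPy's subs."""
    if expr_key not in expressions:
        return f"Error: Expression with key '{expr_key}' not found."
    if var_name not in local_vars:
        return f"Error: Variable '{var_name}' not found. Please introduce it first."
    if replacement_key not in expressions:
        return f"Error: Replacement expression with key '{replacement_key}' not found."
    expr = expressions[expr_key]
    result = expr.subs(local_vars[var_name], expressions[replacement_key])
    return store(result)


# Introduce the variables, then the two expressions from the docstring example.
x, y = Symbol("x"), Symbol("y")
local_vars.update({"x": x, "y": y})

expr1 = store(x**2 + y**2)
expr2 = store(sin(x))
result_key = substitute(expr1, "y", expr2)
print(result_key, expressions[result_key])

# eigenvects() returns a list of (eigenvalue, algebraic multiplicity, [basis vectors]).
m = Matrix([[2, 0], [0, 3]])
for eigenvalue, multiplicity, basis in m.eigenvects():
    print(eigenvalue, multiplicity, [list(v) for v in basis])
```

Failed lookups return an error string and do not advance the counter, which matches how the server's tools behave before entering their `try` blocks.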