This is page 1 of 4. Use http://codebase.md/datalab-to/surya?page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/static/fonts/.gitignore:
--------------------------------------------------------------------------------

```
*
!.gitignore
```

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------

```yaml
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
  # Ruff version.
  rev: v0.9.10
  hooks:
    # Run the linter.
    - id: ruff
      types_or: [ python, pyi ]
      args: [ --fix ]
    # Run the formatter.
    - id: ruff-format
      types_or: [ python, pyi ]
```

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

```
private.py
.DS_Store
local.env
experiments
test_data
training
wandb
notebooks
results
data
slices

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

```

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

```markdown
# Surya

Surya is a document OCR toolkit that does:

- OCR in 90+ languages that benchmarks favorably vs cloud services
- Line-level text detection in any language
- Layout analysis (table, image, header, etc detection)
- Reading order detection
- Table recognition (detecting rows/columns)
- LaTeX OCR

It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).

For our managed API or on-prem document intelligence solution, check out [our platform here](https://datalab.to?utm_source=gh-surya).


|                            Detection                             |                                   OCR                                   |
|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|
|  <img src="static/images/excerpt.png" width="500px"/>  |  <img src="static/images/excerpt_text.png" width="500px"/> |

|                               Layout                               |                               Reading Order                                |
|:------------------------------------------------------------------:|:--------------------------------------------------------------------------:|
| <img src="static/images/excerpt_layout.png" width="500px"/> | <img src="static/images/excerpt_reading.jpg" width="500px"/> |

|                       Table Recognition                       |                       LaTeX OCR                        |
|:-------------------------------------------------------------:|:------------------------------------------------------:|
| <img src="static/images/scanned_tablerec.png" width="500px"/> | <img src="static/images/latex_ocr.png" width="500px"/> |


Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.

## Community

[Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development.

## Examples

| Name             |              Detection              |                                      OCR |                                     Layout |                                       Order |                                    Table Rec |
|------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:|
| Japanese         | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) |
| Chinese          | [Image](static/images/chinese.jpg)  |  [Image](static/images/chinese_text.jpg) |  [Image](static/images/chinese_layout.jpg) |  [Image](static/images/chinese_reading.jpg) |                                              |
| Hindi            |  [Image](static/images/hindi.jpg)   |    [Image](static/images/hindi_text.jpg) |    [Image](static/images/hindi_layout.jpg) |    [Image](static/images/hindi_reading.jpg) |                                              |
| Arabic           |  [Image](static/images/arabic.jpg)  |   [Image](static/images/arabic_text.jpg) |   [Image](static/images/arabic_layout.jpg) |   [Image](static/images/arabic_reading.jpg) |                                              |
| Chinese + Hindi  | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) |                                              |
| Presentation     |   [Image](static/images/pres.png)   |     [Image](static/images/pres_text.jpg) |     [Image](static/images/pres_layout.jpg) |     [Image](static/images/pres_reading.jpg) |     [Image](static/images/pres_tablerec.png) |
| Scientific Paper |  [Image](static/images/paper.jpg)   |    [Image](static/images/paper_text.jpg) |    [Image](static/images/paper_layout.jpg) |    [Image](static/images/paper_reading.jpg) |    [Image](static/images/paper_tablerec.png) |
| Scanned Document | [Image](static/images/scanned.png)  |  [Image](static/images/scanned_text.jpg) |  [Image](static/images/scanned_layout.jpg) |  [Image](static/images/scanned_reading.jpg) |  [Image](static/images/scanned_tablerec.png) |
| New York Times   |   [Image](static/images/nyt.jpg)    |      [Image](static/images/nyt_text.jpg) |      [Image](static/images/nyt_layout.jpg) |        [Image](static/images/nyt_order.jpg) |                                              |
| Scanned Form     |  [Image](static/images/funsd.png)   |    [Image](static/images/funsd_text.jpg) |    [Image](static/images/funsd_layout.jpg) |    [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) |
| Textbook         | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) |   [Image](static/images/textbook_order.jpg) |                                              |

# Hosted API

There is a hosted API for all surya models available [here](https://www.datalab.to?utm_source=gh-surya):

- Works with PDF, images, word docs, and powerpoints
- Consistent speed, with no latency spikes
- High reliability and uptime

# Commercial usage

Our model weights use a modified AI Pubs Open Rail-M license (free for research, personal use, and startups under $2M funding/revenue) and our code is GPL. For broader commercial licensing or to remove GPL requirements, visit our pricing page [here](https://www.datalab.to/pricing?utm_source=gh-surya).


# Installation

You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine.  See [here](https://pytorch.org/get-started/locally/) for more details.

Install with:

```shell
pip install surya-ocr
```

Model weights will automatically download the first time you run surya.

# Usage

- Inspect the settings in `surya/settings.py`.  You can override any settings with environment variables.
- Your torch device will be automatically detected, but you can override this.  For example, `TORCH_DEVICE=cuda`.

## Interactive App

I've included a streamlit app that lets you interactively try Surya on images or PDF files.  Run it with:

```shell
pip install streamlit pdftext
surya_gui
```

## OCR (text recognition)

This command will write out a json file with the detected text and bboxes:

```shell
surya_ocr DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--task_name` will specify which task to use for predicting the lines.  `ocr_with_boxes` is the default, which will format text and give you bboxes.  If you get bad performance, try `ocr_without_boxes`, which will give you potentially better performance but no bboxes.  For blocks like equations and paragraphs, try `block_without_boxes`.
- `--images` will save images of the pages and detected text lines (optional)
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
- `--disable_math` - by default, surya will recognize math in text.  This can lead to false positives - you can disable this with this flag.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:

- `text_lines` - the detected text and bounding boxes for each line
  - `text` - the text in the line
  - `confidence` - the confidence of the model in the detected text (0-1)
  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  The points are in clockwise order from the top left.
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `chars` - the individual characters in the line
    - `text` - the text of the character
    - `bbox` - the character bbox (same format as line bbox)
    - `polygon` - the character polygon (same format as line polygon)
    - `confidence` - the confidence of the model in the detected character (0-1)
    - `bbox_valid` - if the character is a special token or math, the bbox may not be valid
  - `words` - the individual words in the line (computed from the characters)
    - `text` - the text of the word
    - `bbox` - the word bbox (same format as line bbox)
    - `polygon` - the word polygon (same format as line polygon)
    - `confidence` - mean character confidence
    - `bbox_valid` - if the word is a special token or math, the bbox may not be valid
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.
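
For reference, here is a minimal sketch (not part of surya itself) for loading a `results.json` file in this format and printing each line's text and bbox.  The file path and key names simply follow the description above.

```python
import json

# Load the OCR output written by `surya_ocr DATA_PATH`
with open("results.json") as f:
    results = json.load(f)

for filename, pages in results.items():
    for page in pages:
        print(f"{filename} page {page['page']}")
        for line in page["text_lines"]:
            # Each line has text, a confidence score, a polygon, and a bbox
            print(f"  {line['confidence']:.2f} {line['bbox']} {line['text']}")
```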

**Performance tips**

Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `40MB` of VRAM, so very high batch sizes are possible.  The default is a batch size of `512`, which will use about 20GB of VRAM.  Tuning it can help on CPU too, depending on your core count - the default CPU batch size is `32`.

### From python

```python
from PIL import Image
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor

image = Image.open(IMAGE_PATH)
foundation_predictor = FoundationPredictor()
recognition_predictor = RecognitionPredictor(foundation_predictor)
detection_predictor = DetectionPredictor()

predictions = recognition_predictor([image], det_predictor=detection_predictor)
```
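
Each entry in `predictions` is one `OCRResult` per input image.  A short sketch of pulling out the recognized lines - the attribute names are assumed to mirror the `results.json` fields described above:

```python
# predictions is one OCRResult per input image
for result in predictions:
    for line in result.text_lines:
        # `text` and `bbox` mirror the fields in the results.json format above
        print(line.text, line.bbox)
```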


## Text line detection

This command will write out a json file with the detected bboxes.

```shell
surya_detect DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected text lines (optional)
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:

- `bboxes` - detected bounding boxes for text
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  The points are in clockwise order from the top left.
  - `confidence` - the confidence of the model in the detected text (0-1)
- `vertical_lines` - vertical lines detected in the document
  - `bbox` - the axis-aligned line coordinates.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.

**Performance tips**

Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `440MB` of VRAM, so very high batch sizes are possible.  The default is a batch size of `36`, which will use about 16GB of VRAM.  Tuning it can help on CPU too, depending on your core count - the default CPU batch size is `6`.

### From python

```python
from PIL import Image
from surya.detection import DetectionPredictor

image = Image.open(IMAGE_PATH)
det_predictor = DetectionPredictor()

# predictions is a list of dicts, one per image
predictions = det_predictor([image])
```
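
As a quick way to inspect the output, the sketch below draws the detected polygons onto the page image with Pillow.  It assumes each result exposes `bboxes` with a `polygon` field, matching the JSON format described above.

```python
from PIL import Image, ImageDraw
from surya.detection import DetectionPredictor

image = Image.open(IMAGE_PATH).convert("RGB")
det_predictor = DetectionPredictor()
result = det_predictor([image])[0]

draw = ImageDraw.Draw(image)
for box in result.bboxes:
    # polygon is four (x, y) points in clockwise order from the top left
    draw.polygon([tuple(p) for p in box.polygon], outline="red")
image.save("detected_lines.png")
```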

## Layout and reading order

This command will write out a json file with the detected layout and reading order.

```shell
surya_layout DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected text lines (optional)
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:

- `bboxes` - detected bounding boxes for text
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  The points are in clockwise order from the top left.
  - `position` - the reading order of the box.
  - `label` - the label for the bbox.  One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`.
  - `top_k` - the top-k other potential labels for the box.  A dictionary with labels as keys and confidences as values.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.

**Performance tips**

Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `220MB` of VRAM, so very high batch sizes are possible.  The default is a batch size of `32`, which will use about 7GB of VRAM.  Tuning it can help on CPU too, depending on your core count - the default CPU batch size is `4`.

### From python

```python
from PIL import Image
from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.settings import settings

image = Image.open(IMAGE_PATH)
layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))

# layout_predictions is a list of dicts, one per image
layout_predictions = layout_predictor([image])
```
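
Once you have `layout_predictions`, the boxes can be walked in reading order or filtered by label.  A short sketch, assuming the `label`, `position`, and `bbox` attributes mirror the JSON output format described above:

```python
layout = layout_predictions[0]

# Walk the boxes in detected reading order
for box in sorted(layout.bboxes, key=lambda b: b.position):
    print(box.position, box.label, box.bbox)

# Keep only table regions, e.g. to crop them before running table recognition
table_boxes = [box.bbox for box in layout.bboxes if box.label == "Table"]
```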

## Table Recognition

This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes.  If you want to get cell positions and text, along with nice formatting, check out the [marker](https://www.github.com/VikParuchuri/marker) repo.  You can use the `TableConverter` to detect and extract tables in images and PDFs.  It supports output in json (with bboxes), markdown, and html.

```shell
surya_table DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected table cells + rows and columns (optional)
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
- `--detect_boxes` specifies if cells should be detected.  By default, they're pulled out of the PDF, but this is not always possible.
- `--skip_table_detection` tells table recognition not to detect tables first.  Use this if your image is already cropped to a table.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:

- `rows` - detected table rows
  - `bbox` - the bounding box of the table row
  - `row_id` - the id of the row
  - `is_header` - if it is a header row.
- `cols` - detected table columns
  - `bbox` - the bounding box of the table column
  - `col_id`- the id of the column
  - `is_header` - if it is a header column
- `cells` - detected table cells
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
  - `text` - if text could be pulled out of the pdf, the text of this cell.
  - `row_id` - the id of the row the cell belongs to.
  - `col_id` - the id of the column the cell belongs to.
  - `colspan` - the number of columns spanned by the cell.
  - `rowspan` - the number of rows spanned by the cell.
  - `is_header` - whether it is a header cell.
- `page` - the page number in the file
- `table_idx` - the index of the table on the page (sorted in vertical order)
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.

**Performance tips**

Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `150MB` of VRAM, so very high batch sizes are possible.  The default is a batch size of `64`, which will use about 10GB of VRAM.  Tuning it can help on CPU too, depending on your core count - the default CPU batch size is `8`.

### From python

```python
from PIL import Image
from surya.table_rec import TableRecPredictor

image = Image.open(IMAGE_PATH)
table_rec_predictor = TableRecPredictor()

table_predictions = table_rec_predictor([image])
```
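
To turn the flat cell list into rows of text, you can group cells by `row_id` and sort by `col_id`.  A minimal sketch using the fields described in the output format above (attribute names assumed to mirror the JSON keys):

```python
from collections import defaultdict

table = table_predictions[0]

rows = defaultdict(list)
for cell in table.cells:
    rows[cell.row_id].append(cell)

for row_id in sorted(rows):
    # Sort cells left to right; `text` may be empty if it couldn't be pulled from the PDF
    cells = sorted(rows[row_id], key=lambda c: c.col_id)
    print([cell.text for cell in cells])
```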

## LaTeX OCR

This command will write out a json file with the LaTeX of the equations.  You must pass in images that are already cropped to the equations.  You can do this by running the layout model, then cropping, if you want.

```shell
surya_latex_ocr DATA_PATH
```

- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--output_dir` specifies the directory to save results to instead of the default
- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  See the OCR section above for the format of the output.

### From python

```python
from PIL import Image
from surya.texify import TexifyPredictor

image = Image.open(IMAGE_PATH)
predictor = TexifyPredictor()

predictor([image])
```
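
Alternatively, the repo's tests exercise LaTeX OCR through the recognition predictor with the `block_without_boxes` task.  A hedged sketch based on `tests/test_latex_ocr.py` - exact constructor arguments may differ from your setup:

```python
from PIL import Image
from surya.common.surya.schema import TaskNames
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor

image = Image.open(IMAGE_PATH)
recognition_predictor = RecognitionPredictor(FoundationPredictor())

width, height = image.size
results = recognition_predictor(
    [image], [TaskNames.block_without_boxes], bboxes=[[[0, 0, width, height]]]
)
# The recognized equation comes back as <math>...</math> markup
print(results[0].text_lines[0].text)
```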

### Interactive app

You can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with:

```shell
pip install streamlit==1.40 streamlit-drawable-canvas-jsretry
texify_gui
```

## Compilation

The following models have support for compilation. You will need to set the following environment variables to enable compilation:

- Detection: `COMPILE_DETECTOR=true`
- Layout: `COMPILE_LAYOUT=true`
- Table recognition: `COMPILE_TABLE_REC=true`

Alternatively, set `COMPILE_ALL=true` to compile all models.

Here are the speedups on an A10 GPU:

| Model             | Time per page (s) | Compiled time per page (s) | Speedup (%) |
| ----------------- | ----------------- | -------------------------- | ----------- |
| Detection         | 0.108808          | 0.10521                    | 3.306742151 |
| Layout            | 0.27319           | 0.27063                    | 0.93707676  |
| Table recognition | 0.0219            | 0.01938                    | 11.50684932 |

# Limitations

- This is specialized for document OCR.  It will likely not work on photos or other images.
- It is for printed text, not handwriting (though it may work on some handwriting).
- The text detection model has trained itself to ignore advertisements.
- You can find language support for OCR in `surya/recognition/languages.py`.  Text detection, layout analysis, and reading order will work with any language.

## Troubleshooting

If OCR isn't working properly:

- Try increasing the resolution of the image so the text is bigger.  If the resolution is already very high, try decreasing it to no more than a `2048px` width.
- Preprocessing the image (binarizing, deskewing, etc) can help with very old/blurry images.
- You can adjust `DETECTOR_BLANK_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results.  `DETECTOR_BLANK_THRESHOLD` controls the space between lines - any prediction below this number will be considered blank space.  `DETECTOR_TEXT_THRESHOLD` controls how text is joined - any number above this is considered text.  `DETECTOR_TEXT_THRESHOLD` should always be higher than `DETECTOR_BLANK_THRESHOLD`, and both should be in the 0-1 range.  Looking at the heatmap from the debug output of the detector can tell you how to adjust these (if you see faint things that look like boxes, lower the thresholds, and if you see bboxes being joined together, raise the thresholds).

# Manual install

If you want to develop surya, you can install it manually:

- `git clone https://github.com/VikParuchuri/surya.git`
- `cd surya`
- `poetry install` - installs main and dev dependencies
- `poetry shell` - activates the virtual environment

# Benchmarks

## OCR

![Benchmark chart tesseract](static/images/benchmark_rec_chart.png)

| Model     | Time per page (s) | Avg similarity (⬆) |
|-----------|-------------------|--------------------|
| surya     | .62               | 0.97               |
| tesseract | .45               | 0.88               |

[Full language results](static/images/rec_acc_table.png)

Tesseract is CPU-based, and surya is CPU or GPU.  I tried to cost-match the resources used, so I used a 1xA6000 (48GB VRAM) for surya, and 28 CPU cores for Tesseract (same price on Lambda Labs/DigitalOcean).

### Google Cloud Vision

I benchmarked OCR against Google Cloud vision since it has similar language coverage to Surya.

![Benchmark chart google cloud](static/images/gcloud_rec_bench.png)

[Full language results](static/images/gcloud_full_langs.png)

**Methodology**

I measured normalized sentence similarity (0-1, higher is better) based on a set of real-world and synthetic pdfs.  I sampled PDFs from common crawl, then filtered out the ones with bad OCR.  I couldn't find PDFs for some languages, so I also generated simple synthetic PDFs for those.

I used the reference line bboxes from the PDFs with both tesseract and surya, to evaluate just the OCR quality.

For Google Cloud, I aligned the output from Google Cloud with the ground truth.  I had to skip RTL languages since they didn't align well.

## Text line detection

![Benchmark chart](static/images/benchmark_chart_small.png)

| Model     | Time (s)   | Time per page (s)   | precision   |   recall |
|-----------|------------|---------------------|-------------|----------|
| surya     | 47.2285    | 0.094452            | 0.835857    | 0.960807 |
| tesseract | 74.4546    | 0.290838            | 0.631498    | 0.997694 |


Tesseract is CPU-based, and surya is CPU or GPU.  I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU.  This was the resource usage:

- tesseract - 32 CPU cores, or 8 workers using 4 cores each
- surya - 36 batch size, for 16GB VRAM usage

**Methodology**

Surya predicts line-level bboxes, while tesseract and others predict word-level or character-level.  It's hard to find 100% correct datasets with line-level annotations. Merging bboxes can be noisy, so I chose not to use IoU as the metric for evaluation.

I instead used coverage, which calculates:

- Precision - how well the predicted bboxes cover ground truth bboxes
- Recall - how well ground truth bboxes cover predicted bboxes

First calculate coverage for each bbox, then add a small penalty for double coverage, since we want the detection to have non-overlapping bboxes.  Anything with a coverage of 0.5 or higher is considered a match.

Then we calculate precision and recall for the whole dataset.
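
To make the metric concrete, here is a simplified, hypothetical illustration of the coverage computation described above (it is not the actual benchmark code - the real version also applies the double-coverage penalty, and the function names here are made up):

```python
def intersection_area(a, b):
    # Axis-aligned overlap of two (x1, y1, x2, y2) boxes
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    return max(0, x2 - x1) * max(0, y2 - y1)


def coverage(box, other_boxes):
    # Fraction of `box` covered by the other boxes (overlaps are simply summed here)
    area = (box[2] - box[0]) * (box[3] - box[1])
    if area == 0:
        return 0.0
    covered = sum(intersection_area(box, other) for other in other_boxes)
    return min(covered / area, 1.0)


def precision_recall(pred_boxes, gt_boxes, threshold=0.5):
    # A box counts as matched when at least half of it is covered
    precision = sum(coverage(p, gt_boxes) >= threshold for p in pred_boxes) / max(len(pred_boxes), 1)
    recall = sum(coverage(g, pred_boxes) >= threshold for g in gt_boxes) / max(len(gt_boxes), 1)
    return precision, recall
```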

## Layout analysis

| Layout Type   |   precision |   recall |
|---------------|-------------|----------|
| Image         |     0.91265 |  0.93976 |
| List          |     0.80849 |  0.86792 |
| Table         |     0.84957 |  0.96104 |
| Text          |     0.93019 |  0.94571 |
| Title         |     0.92102 |  0.95404 |

Time per image - .13 seconds on GPU (A10).

**Methodology**

I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/PubLayNet), which was not in the training data.  I had to align publaynet labels with the surya layout labels.  I was then able to find coverage for each layout type:

- Precision - how well the predicted bboxes cover ground truth bboxes
- Recall - how well ground truth bboxes cover predicted bboxes

## Reading Order

88% mean accuracy, and .4 seconds per image on an A10 GPU.  See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check.

**Methodology**

I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data.  Unfortunately, this dataset is fairly noisy, and not all the labels are correct.  It was very hard to find a dataset annotated with reading order and also layout information.  I wanted to avoid using a cloud service for the ground truth.

The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct.
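
A minimal illustration of that pairwise metric (not the benchmark code itself - `pairwise_order_accuracy` is a hypothetical helper): given predicted and ground-truth orderings of the same boxes, count how many pairs agree.

```python
from itertools import combinations


def pairwise_order_accuracy(pred_order, gt_order):
    # pred_order / gt_order map a box id to its position in reading order
    pairs = list(combinations(gt_order.keys(), 2))
    if not pairs:
        return 1.0
    correct = sum(
        (pred_order[a] < pred_order[b]) == (gt_order[a] < gt_order[b])
        for a, b in pairs
    )
    return correct / len(pairs)


# Example: one swapped pair out of three -> ~67% accuracy
print(pairwise_order_accuracy({"a": 0, "b": 2, "c": 1}, {"a": 0, "b": 1, "c": 2}))
```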

## Table Recognition

| Model             |   Row Intersection |   Col Intersection |   Time Per Image |
|-------------------|--------------------|--------------------|------------------|
| Surya             |               1    |            0.98625 |          0.30202 |
| Table transformer |               0.84 |            0.86857 |          0.08082 |

Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions.  This benchmark is mostly a sanity check - there is a more rigorous one in [marker](https://www.github.com/VikParuchuri/marker).

**Methodology**

The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM.  It has labeled rows and columns.  After table recognition is run, the predicted rows and columns are compared to the ground truth.  There is an additional penalty for predicting too many or too few rows/columns.

## LaTeX OCR

| Method | edit ⬇   | time taken (s) ⬇ |
|--------|----------|------------------|
| texify | 0.122617 | 35.6345          |

This runs texify on a ground truth set of LaTeX, then computes the edit distance.  This is a bit noisy, since two LaTeX strings that render the same can have different symbols in them.
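
For reference, the edit metric is a normalized string edit distance between the predicted and ground-truth LaTeX.  A minimal sketch (these helpers are illustrative, not the benchmark code, which may normalize differently):

```python
def edit_distance(a: str, b: str) -> int:
    # Classic dynamic-programming Levenshtein distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = curr
    return prev[-1]


def normalized_edit_distance(pred: str, truth: str) -> float:
    # 0 means identical strings; 1 means completely different
    denom = max(len(pred), len(truth)) or 1
    return edit_distance(pred, truth) / denom
```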

## Running your own benchmarks

You can benchmark the performance of surya on your machine.

- Follow the manual install instructions above.
- `poetry install --group dev` - installs dev dependencies

**Text line detection**

This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).

```shell
python benchmark/detection.py --max_rows 256
```

- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images and detected bboxes
- `--pdf_path` will let you specify a pdf to benchmark instead of the default data
- `--results_dir` will let you specify a directory to save results to instead of the default one

**Text recognition**

This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages).

```shell
python benchmark/recognition.py --tesseract
```

- `--max_rows` controls how many images to process for the benchmark
- `--debug 2` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one
- `--tesseract` will run the benchmark with tesseract.  You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder.

- Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark.
- Set `RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist` to use the historical document data for benchmarking.  This data comes from the [tapuscorpus](https://github.com/HTR-United/tapuscorpus).

**Layout analysis**

This will evaluate surya on the publaynet dataset.

```shell
python benchmark/layout.py
```

- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one

**Reading Order**

```shell
python benchmark/ordering.py
```

- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one

**Table Recognition**

```shell
python benchmark/table_recognition.py --max_rows 1024 --tatr
```

- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one
- `--tatr` specifies whether to also run table transformer

**LaTeX OCR**

```shell
python benchmark/texify.py --max_rows 128
```

- `--max_rows` controls how many images to process for the benchmark
- `--results_dir` will let you specify a directory to save results to instead of the default one

# Training

Text detection was trained on 4x A6000s for 3 days.  It used a diverse set of images as training data.  It was trained from scratch using a modified efficientvit architecture for semantic segmentation.

Text recognition was trained on 4x A6000s for 2 weeks.  It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).

# Finetuning Surya OCR
You can now take Surya OCR further by training it on your own data with our [finetuning script](/surya/scripts/finetune_ocr.py).
It's built on the Hugging Face Trainer and supports all the [arguments](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) that the Hugging Face Trainer provides, as well as integrations like torchrun and deepspeed.

To set up your dataset, follow the example dataset format [here](https://huggingface.co/datasets/datalab-to/ocr_finetune_example) and provide the path to your own dataset when launching the training script.
```bash
# Tested on 1xH100 GPU
# Set --pretrained_checkpoint_path to load from a custom checkpoint, otherwise
# the default surya ocr weights will be loaded as the initialization
python surya/scripts/finetune_ocr.py \
  --output_dir $OUTPUT_DIR \
  --dataset_name datalab-to/ocr_finetune_example \
  --per_device_train_batch_size 64 \
  --gradient_checkpointing true \
  --max_sequence_length 1024
```

This is a minimal training script to get you started finetuning Surya. Our internal training stack includes character bounding box finetuning, sliding window attention with specialized attention masks, custom kernels, augmentations, and other optimizations that can push OCR accuracy well beyond standard finetuning. If you want to get the most out of your data, reach us at [email protected]!

# Thanks

This work would not have been possible without amazing open source AI work:

- [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA
- [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT
- [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman
- [Donut](https://github.com/clovaai/donut) from Naver
- [transformers](https://github.com/huggingface/transformers) from huggingface
- [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model

Thank you to everyone who makes open source AI possible.

# Citation

If you use surya (or the associated models) in your work or research, please consider citing us using the following BibTeX entry:

```bibtex
@misc{paruchuri2025surya,
  author       = {Vikas Paruchuri and Datalab Team},
  title        = {Surya: A lightweight document OCR and analysis toolkit},
  year         = {2025},
  howpublished = {\url{https://github.com/VikParuchuri/surya}},
  note         = {GitHub repository},
}

```

--------------------------------------------------------------------------------
/benchmark/utils/__init__.py:
--------------------------------------------------------------------------------

```python

```

--------------------------------------------------------------------------------
/surya/__init__.py:
--------------------------------------------------------------------------------

```python

```

--------------------------------------------------------------------------------
/surya/detection/model/__init__.py:
--------------------------------------------------------------------------------

```python

```

--------------------------------------------------------------------------------
/surya/foundation/cache/__init__.py:
--------------------------------------------------------------------------------

```python

```

--------------------------------------------------------------------------------
/surya/ocr_error/model/__init__.py:
--------------------------------------------------------------------------------

```python

```

--------------------------------------------------------------------------------
/surya/scripts/__init__.py:
--------------------------------------------------------------------------------

```python

```

--------------------------------------------------------------------------------
/surya/table_rec/model/__init__.py:
--------------------------------------------------------------------------------

```python

```

--------------------------------------------------------------------------------
/surya/common/__init__.py:
--------------------------------------------------------------------------------

```python




```

--------------------------------------------------------------------------------
/ocr_text.py:
--------------------------------------------------------------------------------

```python
from surya.scripts.ocr_text import ocr_text_cli

if __name__ == "__main__":
    ocr_text_cli()

```

--------------------------------------------------------------------------------
/ocr_latex.py:
--------------------------------------------------------------------------------

```python
from surya.scripts.ocr_latex import ocr_latex_cli

if __name__ == "__main__":
    ocr_latex_cli()

```

--------------------------------------------------------------------------------
/texify_app.py:
--------------------------------------------------------------------------------

```python
from surya.scripts.run_texify_app import texify_app_cli

if __name__ == "__main__":
    texify_app_cli()
```

--------------------------------------------------------------------------------
/detect_layout.py:
--------------------------------------------------------------------------------

```python
from surya.scripts.detect_layout import detect_layout_cli

if __name__ == "__main__":
    detect_layout_cli()

```

--------------------------------------------------------------------------------
/detect_text.py:
--------------------------------------------------------------------------------

```python
from surya.scripts.detect_text import detect_text_cli

if __name__ == "__main__":
    detect_text_cli()








```

--------------------------------------------------------------------------------
/ocr_app.py:
--------------------------------------------------------------------------------

```python
from surya.scripts.run_streamlit_app import streamlit_app_cli

if __name__ == "__main__":
    streamlit_app_cli()
```

--------------------------------------------------------------------------------
/table_recognition.py:
--------------------------------------------------------------------------------

```python
from surya.scripts.table_recognition import table_recognition_cli

if __name__ == "__main__":
    table_recognition_cli()
```

--------------------------------------------------------------------------------
/surya/ocr_error/schema.py:
--------------------------------------------------------------------------------

```python
from typing import List

from pydantic import BaseModel


class OCRErrorDetectionResult(BaseModel):
    texts: List[str]
    labels: List[str]

```

--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------

```
[pytest]
testpaths=tests
pythonpath=.
filterwarnings =
    ignore::UserWarning
    ignore::PendingDeprecationWarning
    ignore::DeprecationWarning
```

--------------------------------------------------------------------------------
/surya/detection/schema.py:
--------------------------------------------------------------------------------

```python
from typing import List, Optional, Any

from pydantic import BaseModel

from surya.common.polygon import PolygonBox


class TextDetectionResult(BaseModel):
    bboxes: List[PolygonBox]
    heatmap: Optional[Any]
    affinity_map: Optional[Any]
    image_bbox: List[float]

```

--------------------------------------------------------------------------------
/surya/scripts/run_texify_app.py:
--------------------------------------------------------------------------------

```python
import subprocess
import os


def texify_app_cli():
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    ocr_app_path = os.path.join(cur_dir, "texify_app.py")
    cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"]
    subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"})
```

--------------------------------------------------------------------------------
/surya/scripts/run_streamlit_app.py:
--------------------------------------------------------------------------------

```python
import subprocess
import os


def streamlit_app_cli():
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    ocr_app_path = os.path.join(cur_dir, "streamlit_app.py")
    cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"]
    subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"})
```

--------------------------------------------------------------------------------
/surya/common/surya/schema.py:
--------------------------------------------------------------------------------

```python
class TaskNames:
    block_without_boxes = "block_without_boxes"
    ocr_with_boxes = "ocr_with_boxes"
    ocr_without_boxes = "ocr_without_boxes"
    layout = "layout"
    table_structure = "table_structure"


TASK_NAMES = [
    TaskNames.block_without_boxes,
    TaskNames.ocr_with_boxes,
    TaskNames.ocr_without_boxes,
    TaskNames.layout,
    TaskNames.table_structure,
]

```

--------------------------------------------------------------------------------
/surya/layout/schema.py:
--------------------------------------------------------------------------------

```python
from typing import Optional, Dict, List

from pydantic import BaseModel

from surya.common.polygon import PolygonBox


class LayoutBox(PolygonBox):
    label: str
    position: int
    top_k: Optional[Dict[str, float]] = None


class LayoutResult(BaseModel):
    bboxes: List[LayoutBox]
    image_bbox: List[float]
    sliced: bool = False  # Whether the image was sliced and reconstructed

```

--------------------------------------------------------------------------------
/surya/detection/parallel.py:
--------------------------------------------------------------------------------

```python
class FakeFuture:
    def __init__(self, func, *args, **kwargs):
        self._result = func(*args, **kwargs)

    def result(self):
        return self._result

class FakeExecutor:
    def __init__(self, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *excinfo):
        pass

    def submit(self, fn, *args, **kwargs):
        return FakeFuture(fn, *args, **kwargs)

```

--------------------------------------------------------------------------------
/tests/test_layout.py:
--------------------------------------------------------------------------------

```python
def test_layout_topk(layout_predictor, test_image):
    layout_results = layout_predictor([test_image])

    assert len(layout_results) == 1
    assert layout_results[0].image_bbox == [0, 0, 1024, 1024]

    bboxes = layout_results[0].bboxes
    assert len(bboxes) == 2

    assert bboxes[0].label == "SectionHeader"
    assert len(bboxes[0].top_k) == 5

    assert bboxes[1].label == "Text"
    assert len(bboxes[1].top_k) == 5

```

--------------------------------------------------------------------------------
/tests/test_foundation.py:
--------------------------------------------------------------------------------

```python
from surya.foundation import FoundationPredictor


def test_foundation_flash2():
    try:
        f = FoundationPredictor(None, None, None, "flash_attention_2")
        assert f.model.decoder.config._attn_implementation == "flash_attention_2"
        assert f.model.vision_encoder.config._attn_implementation == "flash_attention_2"
    except Exception as e:
        assert False, (
            f"FoundationPredictor with flash_attention_2 raised an exception: {e}"
        )

```

--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------

```yaml
name: Unit tests

on: [push]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [t4_gpu, ubuntu-latest, windows-latest]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Run tests
        run: poetry run pytest
```

--------------------------------------------------------------------------------
/surya/layout/label.py:
--------------------------------------------------------------------------------

```python
LAYOUT_PRED_RELABEL = {
    "<page-header>": "PageHeader",
    "<page-footer>": "PageFooter",
    "<footnote>": "Footnote",
    "<image>": "Picture",
    "<figure>": "Figure",
    "<text>": "Text",
    "<caption>": "Caption",
    "<list-item>": "ListItem",
    "<section-header>": "SectionHeader",
    "<table>": "Table",
    "<table-of-contents>": "TableOfContents",
    "<form>": "Form",
    "<equation-block>": "Equation",
    "<code-block>": "Code",
    "<complex-block>": "Figure",
}

```

--------------------------------------------------------------------------------
/tests/test_ocr_errors.py:
--------------------------------------------------------------------------------

```python
def test_garbled_text(ocr_error_predictor):
    text = """"
    ; dh vksj ls mifLFkr vf/koDrk % Jh vfuy dqekj
    2. vfHk;qDr dh vksj ls mifLFkr vf/koDrk % Jh iznhi d
    """.strip()
    results = ocr_error_predictor([text])
    assert results.labels[0] == "bad"


def test_good_text(ocr_error_predictor):
    text = """"
    There are professions more harmful than industrial design, but only a very few of them.
    """.strip()
    results = ocr_error_predictor([text])
    assert results.labels[0] == "good"
```

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------

```markdown
---
name: Feature request
about: Suggest an idea for this project
title: "[FEAT]"
labels: enhancement
assignees: ''

---

## ✨ Is your feature request related to a problem?

A clear and concise description of what the problem is. 

## 💡 Describe the Solution You'd Like

A concise description of what you want to happen or how you envision it working.

## 📋 Alternatives Considered

Any alternative solutions or workarounds you've tried.

## 🧩 Additional Context

Any additional context, references, or related issues.

```

--------------------------------------------------------------------------------
/surya/common/xla.py:
--------------------------------------------------------------------------------

```python
import math
from surya.settings import settings

if settings.TORCH_DEVICE_MODEL == "xla":
    import torch_xla.core.xla_model as xm
else:
    xm = None


def get_nearest_pad(
    length: int, pad_multiple: int = settings.FOUNDATION_PAD_TO_NEAREST
):
    return math.ceil(length / pad_multiple) * pad_multiple


def get_compile_args(device: str) -> dict:
    if not settings.FOUNDATION_XLA:
        return {}

    return {
        "backend": "openxla",
    }


def mark_step():
    if xm is not None:
        xm.mark_step()

```

--------------------------------------------------------------------------------
/tests/test_latex_ocr.py:
--------------------------------------------------------------------------------

```python
from typing import List

from PIL import Image, ImageDraw

from surya.common.surya.schema import TaskNames
from surya.recognition import OCRResult


def test_latex_ocr(recognition_predictor, test_image_latex):
    width, height = test_image_latex.size
    results: List[OCRResult] = recognition_predictor(
        [test_image_latex], [TaskNames.block_without_boxes], bboxes=[[[0, 0, width, height]]]
    )
    text = results[0].text_lines[0].text
    assert len(results) == 1

    assert text.startswith("<math")
    assert text.endswith("</math>")

```

--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------

```yaml
name: Python package
on:
  push:
    tags:
      - "v*.*.*"
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Build package
        run: |
          poetry build
      - name: Publish package
        env:
          PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
        run: |
          poetry config pypi-token.pypi "$PYPI_TOKEN"
          poetry publish

```

--------------------------------------------------------------------------------
/tests/test_detection.py:
--------------------------------------------------------------------------------

```python
def test_detection(detection_predictor, test_image):
    detection_results = detection_predictor([test_image])

    assert len(detection_results) == 1
    assert detection_results[0].image_bbox == [0, 0, 1024, 1024]

    bboxes = detection_results[0].bboxes
    assert len(bboxes) == 4


def test_detection_chunking(detection_predictor, test_image_tall):
    detection_results = detection_predictor([test_image_tall])

    assert len(detection_results) == 1
    assert detection_results[0].image_bbox == [0, 0, 4096, 4096]

    bboxes = detection_results[0].bboxes
    assert len(bboxes) >= 3 # Sometimes merges into 3
    assert abs(4000 - bboxes[1].polygon[0][0]) < 50
```

--------------------------------------------------------------------------------
/surya/common/load.py:
--------------------------------------------------------------------------------

```python
from typing import Optional, Any

import torch

from surya.settings import settings


class ModelLoader:
    def __init__(self, checkpoint: Optional[str] = None):
        self.checkpoint = checkpoint

    def model(
        self,
        device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
        dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE,
        attention_implementation: Optional[str] = None,
    ) -> Any:
        raise NotImplementedError()

    def processor(
        self,
        device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
        dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE,
    ) -> Any:
        raise NotImplementedError()

```
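
A minimal sketch of how a concrete loader can subclass `ModelLoader`; the `torch.nn.Linear` model and whitespace tokenizer below are toy placeholders, not real surya components.

```python
# Sketch: a concrete loader overriding the two abstract hooks.
# The returned objects are illustrative placeholders only.
import torch

from surya.common.load import ModelLoader


class ToyLoader(ModelLoader):
    def model(self, device=None, dtype=None, attention_implementation=None):
        return torch.nn.Linear(4, 2).to(device or "cpu").eval()

    def processor(self, device=None, dtype=None):
        return lambda text: text.split()


loader = ToyLoader(checkpoint=None)
model = loader.model(device="cpu")
print(loader.processor()("hello world"))
```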

--------------------------------------------------------------------------------
/surya/common/surya/processor/schema.py:
--------------------------------------------------------------------------------

```python
from typing import TypedDict, Literal, List, Tuple

import torch
from PIL import Image


class TaskDict(TypedDict):
    datasets: List[str]
    img_size: Tuple[int, int]


class TasksDict(TypedDict):
    ocr_with_boxes: TaskDict
    ocr_without_boxes: TaskDict
    block_without_boxes: TaskDict


class ProcessorInput(TypedDict):
    type: Literal["image", "ocr", "text", "empty_output"]


class ImageInput(ProcessorInput):
    type: Literal["image"]
    image: Image.Image
    rotated: bool


class TextInput(ProcessorInput):
    type: Literal["text"]
    text: str
    math: bool


class ProcessorOutput(TypedDict):
    input_ids: List[int]
    image_tiles: torch.Tensor | None
    grid_thw: torch.Tensor | None

```
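
These are `TypedDict`s, so inputs are plain dicts; a small construction sketch with illustrative values follows.

```python
# Sketch: building processor inputs that match the TypedDict shapes above.
from PIL import Image

from surya.common.surya.processor.schema import ImageInput, TextInput

image_input: ImageInput = {
    "type": "image",
    "image": Image.new("RGB", (64, 64), "white"),
    "rotated": False,
}
text_input: TextInput = {"type": "text", "text": "E = mc^2", "math": True}
```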

--------------------------------------------------------------------------------
/surya/logging.py:
--------------------------------------------------------------------------------

```python
import logging
import warnings
from surya.settings import settings


def configure_logging():
    logger = get_logger()

    # Remove any existing handlers to prevent duplicates
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # Add our handler
    handler = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    # Prevent propagation to parent loggers to avoid double logging
    logger.propagate = False

    logger.setLevel(settings.LOGLEVEL)
    warnings.simplefilter(action="ignore", category=FutureWarning)


def get_logger():
    return logging.getLogger("surya")

```
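
Typical usage: `configure_logging` attaches a single stream handler to the shared `"surya"` logger, so calling it more than once does not duplicate output.

```python
from surya.logging import configure_logging, get_logger

configure_logging()
logger = get_logger()
logger.info("surya logging configured")  # level comes from settings.LOGLEVEL
```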

--------------------------------------------------------------------------------
/surya/common/pretrained.py:
--------------------------------------------------------------------------------

```python
from typing import Optional

from transformers import PreTrainedModel
from transformers.utils import is_flash_attn_2_available


class SuryaPreTrainedModel(PreTrainedModel):
    # No-op if we pass attention, so we can set attention however we want in the config
    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], **kwargs
    ):
        if attn_implementation is None:
            try:
                self._sdpa_can_dispatch(True)
                attn_implementation = "sdpa"
            except (ValueError, ImportError):
                attn_implementation = "eager"

            if self._supports_flash_attn and is_flash_attn_2_available():
                attn_implementation = "flash_attention_2"

        return attn_implementation

```

--------------------------------------------------------------------------------
/surya/debug/fonts.py:
--------------------------------------------------------------------------------

```python
from typing import List, Optional
import os
import requests

from surya.settings import settings


def get_font_path(langs: Optional[List[str]] = None) -> str:
    font_path = settings.RECOGNITION_RENDER_FONTS["all"]
    if langs is not None:
        for k in settings.RECOGNITION_RENDER_FONTS:
            if k in langs and len(langs) == 1:
                font_path = settings.RECOGNITION_RENDER_FONTS[k]
                break

    if not os.path.exists(font_path):
        os.makedirs(os.path.dirname(font_path), exist_ok=True)
        font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}"
        with requests.get(font_dl_path, stream=True) as r, open(font_path, 'wb') as f:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

    return font_path
```
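
A small sketch of using the helper with PIL; the first call downloads the font file into the configured fonts directory.

```python
# Sketch: load the default render font (downloaded on first use).
from PIL import ImageFont

from surya.debug.fonts import get_font_path

font_path = get_font_path()  # "all" font unless a single language is requested
font = ImageFont.truetype(font_path, 16)
```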

--------------------------------------------------------------------------------
/surya/recognition/schema.py:
--------------------------------------------------------------------------------

```python
import math
import numpy as np
from typing import Optional, List

from pydantic import BaseModel, field_validator

from surya.common.polygon import PolygonBox


class BaseChar(PolygonBox):
    text: str
    confidence: Optional[float] = 0

    @field_validator("confidence", mode="before")
    @classmethod
    def validate_confidence(cls, v: float) -> float:
        if v is None:
            return 0
        elif math.isnan(v) or np.isnan(v):
            return 0
        return v


class TextChar(BaseChar):
    bbox_valid: bool = True  # This is false when the given bbox is not valid


class TextWord(BaseChar):
    bbox_valid: bool = True


class TextLine(BaseChar):
    chars: List[TextChar]  # Individual characters in the line
    original_text_good: bool = False
    words: List[TextWord] | None = None


class OCRResult(BaseModel):
    text_lines: List[TextLine]
    image_bbox: List[float]

```
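
A construction sketch, assuming `PolygonBox` accepts a `polygon` field of four `[x, y]` corner points; the values are illustrative.

```python
# Sketch: assembling an OCRResult by hand (assumes PolygonBox(polygon=...)).
from surya.recognition.schema import TextChar, TextLine, OCRResult

poly = [[0, 0], [50, 0], [50, 12], [0, 12]]
char = TextChar(text="A", confidence=0.98, polygon=poly)
line = TextLine(text="A", confidence=0.98, polygon=poly, chars=[char])
result = OCRResult(text_lines=[line], image_bbox=[0, 0, 100, 100])
```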

--------------------------------------------------------------------------------
/surya/table_rec/schema.py:
--------------------------------------------------------------------------------

```python
from typing import List

from pydantic import BaseModel

from surya.common.polygon import PolygonBox


class TableCell(PolygonBox):
    row_id: int
    colspan: int
    within_row_id: int
    cell_id: int
    is_header: bool
    rowspan: int | None = None
    merge_up: bool = False
    merge_down: bool = False
    col_id: int | None = None
    text_lines: List[dict] | None = None

    @property
    def label(self):
        return f'Cell {self.cell_id} {self.rowspan}/{self.colspan}'


class TableRow(PolygonBox):
    row_id: int
    is_header: bool

    @property
    def label(self):
        return f'Row {self.row_id}'


class TableCol(PolygonBox):
    col_id: int
    is_header: bool

    @property
    def label(self):
        return f'Column {self.col_id}'


class TableResult(BaseModel):
    cells: List[TableCell]
    unmerged_cells: List[TableCell]
    rows: List[TableRow]
    cols: List[TableCol]
    image_bbox: List[float]

```

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/output-bug-report.md:
--------------------------------------------------------------------------------

```markdown
---
name: Output bug report
about: Create a report about poor output quality
title: "[BUG: Output]"
labels: 'bug: output'
assignees: ''

---

## 📝 Describe the Output Issue

A clear and concise description of the incorrect or unexpected output.

## 📄 Input Document

Attach the PDF or input file used.

## 📤 Current Output

Paste the Markdown or HTML that Marker generated:

````markdown
Paste output here
````

## ✅ Expected Output

Describe or paste what you expected Marker to generate.

## ⚙️ Environment

Please fill in all relevant details:

* **Marker version**:
* **Surya version**:
* **Python version**:
* **PyTorch version**:
* **Transformers version**:
* **Operating System**:

## 📟 Command or Code Used

Paste the **exact bash command** or **Python code** you used to run Marker:

<details>
<summary>Click to expand</summary>

```bash
# or Python code block
your_command_here --with-flags
```

</details>

## 📎 Additional Context

Any other relevant info, configs, or assumptions.

```

--------------------------------------------------------------------------------
/benchmark/utils/textract.py:
--------------------------------------------------------------------------------

```python
import os
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import traceback

from surya.input.processing import slice_bboxes_from_image
from surya.recognition import RecognitionPredictor

def textract_ocr(extractor, img):
    try:
        document = extractor.detect_document_text(file_source=img)
        return [line.text for line in document.lines]
    except Exception:
        traceback.print_exc()
        return [None]

def textract_ocr_parallel(imgs, cpus=None):
    from textractor import Textractor # Optional dependency

    extractor = Textractor(profile_name='default')
    parallel_cores = min(len(imgs), RecognitionPredictor().get_batch_size())
    if not cpus:
        cpus = os.cpu_count()
    parallel_cores = min(parallel_cores, cpus)

    with ThreadPoolExecutor(max_workers=parallel_cores) as executor:
        textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc="Running textract OCR")
        textract_text = list(textract_text)
    return textract_text
```

--------------------------------------------------------------------------------
/surya/models.py:
--------------------------------------------------------------------------------

```python
from typing import Dict

import torch

from surya.common.predictor import BasePredictor
from surya.detection import DetectionPredictor
from surya.layout import LayoutPredictor
from surya.logging import configure_logging
from surya.ocr_error import OCRErrorPredictor
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor
from surya.table_rec import TableRecPredictor
from surya.settings import settings

configure_logging()


def load_predictors(
    device: str | torch.device | None = None, dtype: torch.dtype | str | None = None
) -> Dict[str, BasePredictor]:
    return {
        "layout": LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)),
        "ocr_error": OCRErrorPredictor(device=device, dtype=dtype),
        "recognition": RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT)),
        "detection": DetectionPredictor(device=device, dtype=dtype),
        "table_rec": TableRecPredictor(device=device, dtype=dtype),
    }

```
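
A short sketch of loading every predictor once and reusing them across pages; model downloads and device placement happen inside `load_predictors`, and the blank page here is only a stand-in for real document images.

```python
# Sketch: load predictors once, then call them per page.
from PIL import Image

from surya.models import load_predictors

predictors = load_predictors()
page = Image.new("RGB", (1024, 1024), "white")

layout = predictors["layout"]([page])[0]
detections = predictors["detection"]([page])[0]
print(len(layout.bboxes), len(detections.bboxes))
```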

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/breaking-bug-report.md:
--------------------------------------------------------------------------------

```markdown
---
name: Breaking bug report
about: Create a report about a breaking bug
title: "[BUG: Breaking]"
labels: 'bug: breaking'
assignees: ''

---

## 🧨 Describe the Bug

A clear and concise description of the breaking issue (e.g., crash, OOM, exception, etc.).

## 📄 Input Document

Attach the PDF or input file that triggered the error.

## 📤 Output Trace / Stack Trace

Paste the **complete** stack trace or error output, if available.

<details>
<summary>Click to expand</summary>

```
Paste stack trace here
```

</details>

## ⚙️ Environment

Please fill in all relevant details:

- **Marker version**: 
- **Surya version**: 
- **Python version**: 
- **PyTorch version**: 
- **Transformers version**: 
- **Operating System** (incl. container info if relevant): 

## ✅ Expected Behavior

What did you expect Marker to do?

## 📟 Command or Code Used

Paste the **exact bash command** or **Python code** you used to run Marker:

<details>
<summary>Click to expand</summary>

```bash
# or Python code block
your_command_here --with-flags
```

</details>

## 📎 Additional Context

Any other context that might help us debug this (e.g., CLI options, working directory, runtime settings).

```

--------------------------------------------------------------------------------
/surya/detection/util.py:
--------------------------------------------------------------------------------

```python
import math
from PIL import ImageOps

from surya.settings import settings


def get_total_splits(image_size, height):
    img_height = list(image_size)[1]
    max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
    if img_height > max_height:
        num_splits = math.ceil(img_height / height)
        return num_splits
    return 1


def split_image(img, height):
    # This will not modify/return the original image - it will either crop, or copy the image
    img_height = list(img.size)[1]
    max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
    if img_height > max_height:
        num_splits = math.ceil(img_height / height)
        splits = []
        split_heights = []
        for i in range(num_splits):
            top = i * height
            bottom = (i + 1) * height
            if bottom > img_height:
                bottom = img_height
            cropped = img.crop((0, top, img.size[0], bottom))
            chunk_height = bottom - top
            if chunk_height < height:
                cropped = ImageOps.pad(cropped, (img.size[0], height), color=255, centering=(0, 0))
            splits.append(cropped)
            split_heights.append(chunk_height)
        return splits, split_heights
    return [img.copy()], [img_height]

```
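
A usage sketch; whether the image is actually split depends on `settings.DETECTOR_IMAGE_CHUNK_HEIGHT`, so the exact chunk count below is illustrative.

```python
# Sketch: chunk a tall page into fixed-height strips before detection.
from PIL import Image

from surya.detection.util import get_total_splits, split_image

tall_page = Image.new("RGB", (800, 3000), "white")
chunk_height = 1200

num_splits = get_total_splits(tall_page.size, chunk_height)
chunks, chunk_heights = split_image(tall_page, chunk_height)
# Short final chunks are padded with white up to the full chunk height.
print(num_splits, [c.size for c in chunks], chunk_heights)
```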

--------------------------------------------------------------------------------
/benchmark/utils/scoring.py:
--------------------------------------------------------------------------------

```python
import math
from typing import List

from rapidfuzz import fuzz


def overlap_score(pred_lines: List[str], reference_lines: List[str]):
    line_scores = []
    line_weights = []
    line_match = {}
    for i, pred_line in enumerate(pred_lines):
        max_score = 0
        line_weight = 1
        match = None
        for j, ref_line in enumerate(reference_lines):
            score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
            if score > max_score:
                max_score = score
                line_weight = math.sqrt(len(ref_line))
                match = j
        line_scores.append(max_score)
        line_weights.append(line_weight)
        line_match[i] = match
    line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))]

    return line_scores, line_weights, line_match


def overlap_score_exact(pred_lines: List[str], reference_lines: List[str]):
    line_scores = []
    line_weights = []
    assert len(pred_lines) == len(reference_lines)

    for i, (pred_line, ref_line) in enumerate(zip(pred_lines, reference_lines)):
        score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
        weight = math.sqrt(len(ref_line))
        line_scores.append(score * weight)
        line_weights.append(weight)

    return line_scores, line_weights

```
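
A small sketch of the fuzzy line-overlap scoring used by the benchmarks; the returned scores are already length-weighted, so the overall metric is the weighted mean.

```python
# Sketch: weighted fuzzy overlap between predicted and reference lines.
from benchmark.utils.scoring import overlap_score

pred = ["The quick brown fox", "jumps over the dog"]
ref = ["The quick brown fox", "jumps over the lazy dog"]

line_scores, line_weights, matches = overlap_score(pred, ref)
weighted_avg = sum(line_scores) / sum(line_weights)
print(f"weighted line score: {weighted_avg:.3f}", matches)
```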

--------------------------------------------------------------------------------
/.github/workflows/cla.yml:
--------------------------------------------------------------------------------

```yaml
name: "Surya CLA Assistant"
on:
  issue_comment:
    types: [created]
  pull_request_target:
    types: [opened,closed,synchronize]

# explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings
permissions:
  actions: write
  contents: write
  pull-requests: write
  statuses: write

jobs:
  CLAAssistant:
    runs-on: ubuntu-latest
    steps:
      - name: "Surya CLA Assistant"
        if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
        uses: contributor-assistant/[email protected]
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          # the below token should have repo scope and must be manually added by you in the repository's secret
          # This token is required only if you have configured to store the signatures in a remote repository/organization
          PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
        with:
          path-to-signatures: 'signatures/version1/cla.json'
          path-to-document: 'https://github.com/VikParuchuri/surya/blob/master/CLA.md'
          # branch should not be protected
          branch: 'master'
          allowlist: VikParuchuri
```

--------------------------------------------------------------------------------
/.github/workflows/scripts.yml:
--------------------------------------------------------------------------------

```yaml
name: Test CLI scripts

on: [push]

jobs:
  build:
    runs-on: t4_gpu
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Download benchmark data
        run: |
          wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi"
          unzip -o benchmark_data.zip
      - name: Test detection
        run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test OCR
        env:
          RECOGNITION_MAX_TOKENS: 25
        run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test layout
        run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test table
        run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test texify
        env:
          TEXIFY_MAX_TOKENS: 25
        run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test detection folder
        run: poetry run surya_detect benchmark_data/pdfs --page_range 0

```

--------------------------------------------------------------------------------
/surya/common/surya/encoder/config.py:
--------------------------------------------------------------------------------

```python
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class SuryaEncoderConfig(PretrainedConfig):
    model_type = "qwen2_5_vl"
    base_config_key = "vision_config"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "depth",
    }

    def __init__(
        self,
        depth=8,
        hidden_size=1280,
        hidden_act="silu",
        intermediate_size=3420,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        spatial_patch_size=14,
        temporal_patch_size=1,
        tokens_per_second=4,
        window_size=112,
        out_hidden_size=1280,
        fullatt_block_indexes=(3, 7),
        initializer_range=0.02,
        image_size=4096,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.tokens_per_second = tokens_per_second
        self.window_size = window_size
        self.fullatt_block_indexes = fullatt_block_indexes
        self.out_hidden_size = out_hidden_size
        self.initializer_range = initializer_range
        self.spatial_patch_size = spatial_patch_size
        self.image_size = image_size

```

--------------------------------------------------------------------------------
/surya/detection/model/config.py:
--------------------------------------------------------------------------------

```python
from transformers import PretrainedConfig

from surya.common.s3 import S3DownloaderMixin


class EfficientViTConfig(S3DownloaderMixin, PretrainedConfig):
    r"""
    ```"""

    model_type = "efficientvit"

    def __init__(
        self,
        num_classes=2,
        num_channels=3,
        widths=(32, 64, 128, 256, 512),
        head_dim=32,
        num_stages=4,
        depths=(1, 1, 1, 6, 6),
        strides=(2, 2, 2, 2, 2),
        hidden_sizes=(32, 64, 160, 256),
        patch_size=(7, 7),
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        classifier_dropout_prob=0.0,
        layer_norm_eps=1e-6,
        decoder_layer_hidden_size=128,
        decoder_hidden_size=512,
        semantic_loss_ignore_index=255,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.widths = widths
        self.head_dim = head_dim

        self.num_channels = num_channels
        self.num_stages = num_stages
        self.depths = depths
        self.strides = strides
        self.hidden_sizes = hidden_sizes
        self.patch_size = patch_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.classifier_dropout_prob = classifier_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_layer_hidden_size = decoder_layer_hidden_size
        self.semantic_loss_ignore_index = semantic_loss_ignore_index

        self.initializer_range = initializer_range
```

--------------------------------------------------------------------------------
/tests/test_table_rec.py:
--------------------------------------------------------------------------------

```python
from PIL import Image, ImageDraw

def test_table_rec(table_rec_predictor):
    data = [
        ["Name", "Age", "City"],
        ["Alice", 25, "New York"],
        ["Bob", 30, "Los Angeles"],
        ["Charlie", 35, "Chicago"],
    ]
    test_image = draw_table(data)

    results = table_rec_predictor([test_image])
    assert len(results) == 1
    assert results[0].image_bbox == [0, 0, test_image.size[0], test_image.size[1]]

    cells = results[0].cells
    assert len(cells) == 12
    for row_id in range(4):
        for col_id in range(3):
            cell = [c for c in cells if c.row_id == row_id and c.col_id == col_id]
            assert len(cell) == 1, f"Missing cell at row {row_id}, col {col_id}"

def draw_table(data, cell_width=100, cell_height=40):
    rows = len(data)
    cols = len(data[0])
    width = cols * cell_width
    height = rows * cell_height

    image = Image.new('RGB', (width, height), 'white')
    draw = ImageDraw.Draw(image)

    for i in range(rows + 1):
        y = i * cell_height
        draw.line([(0, y), (width, y)], fill='black', width=1)

    for i in range(cols + 1):
        x = i * cell_width
        draw.line([(x, 0), (x, height)], fill='black', width=1)

    for i in range(rows):
        for j in range(cols):
            text = str(data[i][j])
            text_bbox = draw.textbbox((0, 0), text)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            x = j * cell_width + (cell_width - text_width) // 2
            y = i * cell_height + (cell_height - text_height) // 2

            draw.text((x, y), text, fill='black')

    return image
```

--------------------------------------------------------------------------------
/benchmark/utils/bbox.py:
--------------------------------------------------------------------------------

```python
import fitz as pymupdf
from surya.common.util import rescale_bbox


def get_pdf_lines(pdf_path, img_sizes):
    doc = pymupdf.open(pdf_path)
    page_lines = []
    for idx, img_size in enumerate(img_sizes):
        page = doc[idx]
        blocks = page.get_text("dict", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)["blocks"]

        line_boxes = []
        for block_idx, block in enumerate(blocks):
            for l in block["lines"]:
                line_boxes.append(list(l["bbox"]))

        page_box = page.bound()
        pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1]
        line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes]
        page_lines.append(line_boxes)

    return page_lines

def merge_boxes(box1, box2):
    return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3]))


def join_lines(bboxes, max_gap=5):
    to_merge = {}
    for i, box1 in bboxes:
        for z, box2 in bboxes[i + 1:]:
            j = i + z + 1
            if box1 == box2:
                continue

            if box1[0] <= box2[0] and box1[2] >= box2[2]:
                if abs(box1[1] - box2[3]) <= max_gap:
                    if i not in to_merge:
                        to_merge[i] = []
                    to_merge[i].append(j)

    merged_boxes = set()
    merged = []
    for i, box in bboxes:
        if i in merged_boxes:
            continue

        if i in to_merge:
            for j in to_merge[i]:
                box = merge_boxes(box, bboxes[j][1])
                merged_boxes.add(j)

        merged.append(box)
    return merged

```

--------------------------------------------------------------------------------
/surya/scripts/ocr_latex.py:
--------------------------------------------------------------------------------

```python
import os

import click
import json
import time
from collections import defaultdict

from surya.logging import configure_logging, get_logger
from surya.scripts.config import CLILoader
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor
from surya.common.surya.schema import TaskNames

configure_logging()
logger = get_logger()


@click.command(help="OCR LaTeX equations.")
@CLILoader.common_options
def ocr_latex_cli(input_path: str, **kwargs):
    loader = CLILoader(input_path, kwargs, highres=True)

    foundation_predictor = FoundationPredictor()
    texify_predictor = RecognitionPredictor(foundation_predictor)
    tasks = [TaskNames.block_without_boxes] * len(loader.images)
    bboxes = [[[0, 0, image.width, image.height]] for image in loader.images]

    start = time.time()
    predictions_by_image = texify_predictor(
        loader.images,
        tasks,
        bboxes=bboxes,
    )

    latex_predictions = [p.text_lines[0].text for p in predictions_by_image]

    if loader.debug:
        logger.debug(f"OCR took {time.time() - start:.2f} seconds")
        max_chars = max([len(latex) for latex in latex_predictions])
        logger.debug(f"Max chars: {max_chars}")

    out_preds = defaultdict(list)
    for name, pred, image in zip(loader.names, latex_predictions, loader.images):
        out_pred = {
            "equation": pred,
            "page": len(out_preds[name]) + 1,
        }
        out_preds[name].append(out_pred)

    with open(
        os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
    ) as f:
        json.dump(out_preds, f, ensure_ascii=False)

    logger.info(f"Wrote results to {loader.result_path}")

```

--------------------------------------------------------------------------------
/surya/foundation/util.py:
--------------------------------------------------------------------------------

```python
from typing import List, Tuple
import numpy as np
import torch

def detect_repeat_token(predicted_tokens: List[int], max_repeats: int = 40):
    if len(predicted_tokens) < max_repeats:
        return False

    # Detect repeats containing 1 or 2 tokens
    last_n = predicted_tokens[-max_repeats:]
    unique_tokens = len(set(last_n))
    if unique_tokens > 5:
        return False

    return last_n[-unique_tokens:] == last_n[-unique_tokens * 2 : -unique_tokens]

def prediction_to_polygon_batch(
    pred: torch.Tensor,
    img_sizes: List[Tuple[int, int]],
    bbox_scaler,
    skew_scaler,
    skew_min=0.001,
):
    img_sizes = torch.from_numpy(np.array(img_sizes, dtype=np.float32)).to(
        pred.device
    )
    w_scale = (img_sizes[:, 1] / bbox_scaler)[:, None, None]
    h_scale = (img_sizes[:, 0] / bbox_scaler)[:, None, None]

    cx = pred[:, :, 0]
    cy = pred[:, :, 1]
    width = pred[:, :, 2]
    height = pred[:, :, 3]

    x1 = cx - width / 2
    y1 = cy - height / 2
    x2 = cx + width / 2
    y2 = cy + height / 2

    skew_x = torch.floor((pred[:, :, 4] - skew_scaler) / 2)
    skew_y = torch.floor((pred[:, :, 5] - skew_scaler) / 2)

    skew_x[torch.abs(skew_x) < skew_min] = 0
    skew_y[torch.abs(skew_y) < skew_min] = 0

    polygons_flat = torch.stack(
        [
            x1 - skew_x,
            y1 - skew_y,
            x2 - skew_x,
            y1 + skew_y,
            x2 + skew_x,
            y2 + skew_y,
            x1 + skew_x,
            y2 - skew_y,
        ],
        dim=2,
    )

    batch_size, seq_len, _ = pred.shape
    polygons = polygons_flat.view(batch_size, seq_len, 4, 2)

    polygons[:, :, :, 0] *= w_scale
    polygons[:, :, :, 1] *= h_scale

    return polygons
```
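
A quick sketch of the repeat detector, which flags decodes whose tail has collapsed into a short token cycle.

```python
# Sketch: detect a degenerate decode tail that cycles over 1-2 tokens.
from surya.foundation.util import detect_repeat_token

looping = [17, 23, 42] + [7, 9] * 30   # tail is a two-token cycle
healthy = list(range(100))             # no repetition in the tail

assert detect_repeat_token(looping, max_repeats=40)
assert not detect_repeat_token(healthy, max_repeats=40)
```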

--------------------------------------------------------------------------------
/surya/scripts/detect_text.py:
--------------------------------------------------------------------------------

```python
import click
import copy
import json
import time
from collections import defaultdict

from surya.detection import DetectionPredictor
from surya.debug.draw import draw_polys_on_image
from surya.logging import configure_logging, get_logger
from surya.scripts.config import CLILoader
import os

configure_logging()
logger = get_logger()


@click.command(help="Detect bboxes in an input file or folder (PDFs or image).")
@CLILoader.common_options
def detect_text_cli(input_path: str, **kwargs):
    loader = CLILoader(input_path, kwargs)

    det_predictor = DetectionPredictor()

    start = time.time()
    predictions = det_predictor(loader.images, include_maps=loader.debug)
    end = time.time()
    if loader.debug:
        logger.debug(f"Detection took {end - start} seconds")

    if loader.save_images:
        for idx, (image, pred, name) in enumerate(
            zip(loader.images, predictions, loader.names)
        ):
            polygons = [p.polygon for p in pred.bboxes]
            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image))
            bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_bbox.png"))

            if loader.debug:
                heatmap = pred.heatmap
                heatmap.save(os.path.join(loader.result_path, f"{name}_{idx}_heat.png"))

    predictions_by_page = defaultdict(list)
    for idx, (pred, name, image) in enumerate(
        zip(predictions, loader.names, loader.images)
    ):
        out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"])
        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    with open(
        os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
    ) as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    logger.info(f"Wrote results to {loader.result_path}")

```

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------

```toml
[tool.poetry]
name = "surya-ocr"
version = "0.17.0"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
license = "GPL-3.0-or-later"
repository = "https://github.com/VikParuchuri/surya"
keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
packages = [
    {include = "surya"}
]

[tool.poetry.dependencies]
python = "^3.10"
transformers = ">=4.56.1"
torch = "^2.7.0"
pydantic = "^2.5.3"
pydantic-settings = "^2.1.0"
python-dotenv = "^1.0.0"
pillow = "^10.2.0"
pypdfium2 = "=4.30.0"
filetype = "^1.2.0"
click = "^8.1.8"
platformdirs = "^4.3.6"
opencv-python-headless = "==4.11.0.86"
einops = "^0.8.1"
pre-commit = "^4.2.0"

[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
pytesseract = "^0.3.10"
pymupdf = "^1.23.8"
datasets = "^2.16.1"
rapidfuzz = "^3.6.1"
streamlit = "^1.31.0"
pytest = "^8.3.4"
pdftext = "^0.5.1"
tabulate = "^0.9.0"

[tool.poetry.scripts]
surya_detect = "surya.scripts.detect_text:detect_text_cli"
surya_ocr = "surya.scripts.ocr_text:ocr_text_cli"
surya_layout = "surya.scripts.detect_layout:detect_layout_cli"
surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli"
surya_table = "surya.scripts.table_recognition:table_recognition_cli"
surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli"
texify_gui = "surya.scripts.run_texify_app:texify_app_cli"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[[tool.poetry.source]]
name = "libtpu-releases"
url = "https://storage.googleapis.com/libtpu-releases/index.html"
priority = "supplemental"

[[tool.poetry.source]]
name = "libtpu-wheels"
url = "https://storage.googleapis.com/libtpu-wheels/index.html"
priority = "supplemental"

[tool.poetry.group.xla]
optional = true

[tool.poetry.group.xla.dependencies]
torch-xla = {version = "^2.4.1", extras = ["tpu"]}

```

--------------------------------------------------------------------------------
/.github/workflows/benchmarks.yml:
--------------------------------------------------------------------------------

```yaml
name: Integration test

on: [push]

env:
  PYTHONIOENCODING: "utf-8"

jobs:
  build:
    runs-on: t4_gpu
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Run detection benchmark test
        run: |
          poetry run python benchmark/detection.py --max_rows 2
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
      - name: Run recognition benchmark test
        run: |
          poetry run python benchmark/recognition.py --max_rows 2
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition
      - name: Run layout benchmark test
        run: |
          poetry run python benchmark/layout.py --max_rows 5
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
      - name: Run ordering benchmark
        run: |
          poetry run python benchmark/ordering.py --max_rows 5
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
      - name: Run table recognition benchmark
        run: |
          poetry run python benchmark/table_recognition.py --max_rows 5
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition
      - name: Run texify benchmark
        run: |
          poetry run python benchmark/texify.py --max_rows 5
          poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify
```

--------------------------------------------------------------------------------
/surya/ocr_error/model/config.py:
--------------------------------------------------------------------------------

```python
from collections import OrderedDict
from typing import Mapping

from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfig

from surya.common.s3 import S3DownloaderMixin

ID2LABEL = {
    0: 'good',
    1: 'bad'
}

class DistilBertConfig(S3DownloaderMixin, PretrainedConfig):
    model_type = "distilbert"
    attribute_map = {
        "hidden_size": "dim",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=512,
        sinusoidal_pos_embds=False,
        n_layers=6,
        n_heads=12,
        dim=768,
        hidden_dim=4 * 768,
        dropout=0.1,
        attention_dropout=0.1,
        activation="gelu",
        initializer_range=0.02,
        qa_dropout=0.1,
        seq_classif_dropout=0.2,
        pad_token_id=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.sinusoidal_pos_embds = sinusoidal_pos_embds
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation = activation
        self.initializer_range = initializer_range
        self.qa_dropout = qa_dropout
        self.seq_classif_dropout = seq_classif_dropout
        super().__init__(**kwargs, pad_token_id=pad_token_id)


class DistilBertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
            ]
        )
```

--------------------------------------------------------------------------------
/surya/debug/katex.js:
--------------------------------------------------------------------------------

```javascript
<style>
    .katex-display-container {
        display: inline-block;
        max-width: 100%;
        overflow-x: auto;
        max-height: 100%;
    }

    .katex-inline-container {
        display: inline-block;
        max-width: 100%;
        overflow-x: auto;
        max-height: 100%;
    }
</style>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.js" onload="setTimeout(function() {renderMath()})" async></script>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css">
<script>
    function htmlUnescape(escapedText) {
      const htmlEntities = {
        '&amp;': '&',
        '&lt;': '<',
        '&gt;': '>',
        '&quot;': '"',
        '&#39;': "'",
        '&nbsp;': ' '
      };

      return escapedText.replace(/&amp;|&lt;|&gt;|&quot;|&#39;|&nbsp;/g, match => htmlEntities[match]);
    }

    const renderMath = (function() {
    try {
       const mathElements = document.querySelectorAll('math');

        mathElements.forEach(function(element) {
          let mathContent = element.innerHTML.trim();
          mathContent = htmlUnescape(mathContent);
          const isDisplay = element.getAttribute('display') === 'block';

          const container = document.createElement('span');
          container.className = isDisplay ? 'katex-display-container' : 'katex-inline-container';
          element.parentNode.insertBefore(container, element);

          try {
            katex.render(mathContent, container, {
              displayMode: isDisplay,
              throwOnError: false
            });

          } catch (err) {
            console.error('KaTeX rendering error:', err);
            container.textContent = mathContent; // Fallback to raw text
          }

          element.parentNode.removeChild(element);
        });

        console.log('Math rendering complete with', mathElements.length, 'expressions');
      } catch (err) {
        console.error('Error in renderMath function:', err);
      }
    });
</script>
```

--------------------------------------------------------------------------------
/surya/scripts/detect_layout.py:
--------------------------------------------------------------------------------

```python
import time
import click
import copy
import json
from collections import defaultdict

from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.debug.draw import draw_polys_on_image
from surya.logging import configure_logging, get_logger
from surya.scripts.config import CLILoader
from surya.settings import settings
import os

configure_logging()
logger = get_logger()


@click.command(help="Detect layout of an input file or folder (PDFs or image).")
@CLILoader.common_options
def detect_layout_cli(input_path: str, **kwargs):
    loader = CLILoader(input_path, kwargs)

    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_predictor = LayoutPredictor(foundation_predictor)

    start = time.time()
    layout_predictions = layout_predictor(loader.images)

    if loader.debug:
        logger.debug(f"Layout took {time.time() - start} seconds")

    if loader.save_images:
        for idx, (image, layout_pred, name) in enumerate(
            zip(loader.images, layout_predictions, loader.names)
        ):
            polygons = [p.polygon for p in layout_pred.bboxes]
            labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes]
            bbox_image = draw_polys_on_image(
                polygons, copy.deepcopy(image), labels=labels
            )
            bbox_image.save(
                os.path.join(loader.result_path, f"{name}_{idx}_layout.png")
            )

    predictions_by_page = defaultdict(list)
    for idx, (pred, name, image) in enumerate(
        zip(layout_predictions, loader.names, loader.images)
    ):
        out_pred = pred.model_dump()
        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    with open(
        os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
    ) as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    logger.info(f"Wrote results to {loader.result_path}")

```

--------------------------------------------------------------------------------
/surya/ocr_error/loader.py:
--------------------------------------------------------------------------------

```python
from typing import Optional

import torch

from surya.common.load import ModelLoader
from surya.logging import get_logger
from surya.ocr_error.model.config import DistilBertConfig
from surya.ocr_error.model.encoder import DistilBertForSequenceClassification
from surya.ocr_error.tokenizer import DistilBertTokenizer
from surya.settings import settings

logger = get_logger()


class OCRErrorModelLoader(ModelLoader):
    def __init__(self, checkpoint: Optional[str] = None):
        super().__init__(checkpoint)

        if self.checkpoint is None:
            self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT

    def model(
        self,
        device=settings.TORCH_DEVICE_MODEL,
        dtype=settings.MODEL_DTYPE,
        attention_implementation: Optional[str] = None,
    ) -> DistilBertForSequenceClassification:
        if device is None:
            device = settings.TORCH_DEVICE_MODEL
        if dtype is None:
            dtype = settings.MODEL_DTYPE

        config = DistilBertConfig.from_pretrained(self.checkpoint)
        model = (
            DistilBertForSequenceClassification.from_pretrained(
                self.checkpoint,
                dtype=dtype,
                config=config,
            )
            .to(device)
            .eval()
        )

        if settings.COMPILE_ALL or settings.COMPILE_OCR_ERROR:
            torch._dynamo.config.cache_size_limit = 1
            torch._dynamo.config.suppress_errors = False

            logger.info(
                f"Compiling detection model {self.checkpoint} from {DistilBertForSequenceClassification.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}"
            )
            compile_args = {"backend": "openxla"} if device == "xla" else {}
            model = torch.compile(model, **compile_args)

        return model

    def processor(
        self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE
    ) -> DistilBertTokenizer:
        return DistilBertTokenizer.from_pretrained(self.checkpoint)

```

--------------------------------------------------------------------------------
/benchmark/utils/verify_benchmark_scores.py:
--------------------------------------------------------------------------------

```python
import json
import click


def verify_layout(data):
    scores = data["metrics"]
    for layout_type, metrics in scores.items():
        if layout_type == "List":  # Skip lists since none appear early on
            continue

        if metrics["precision"] <= 0.6 or metrics["recall"] <= 0.6:
            raise ValueError("Scores do not meet the required threshold")


def verify_det(data):
    scores = data["metrics"]["surya"]
    if scores["precision"] <= 0.9 or scores["recall"] <= 0.9:
        raise ValueError("Scores do not meet the required threshold")


def verify_rec(data):
    scores = data["surya"]
    if scores["avg_score"] <= 0.9:
        raise ValueError("Scores do not meet the required threshold")


def verify_order(data):
    score = data["mean_accuracy"]
    if score < 0.75:
        raise ValueError("Scores do not meet the required threshold")


def verify_table_rec(data):
    row_score = data["surya"]["mean_row_iou"]
    col_score = data["surya"]["mean_col_iou"]

    if row_score < 0.75 or col_score < 0.75:
        raise ValueError("Scores do not meet the required threshold")


def verify_texify(data):
    edit_dist = data["scores"]
    if edit_dist > 0.2:
        raise ValueError("Scores do not meet the required threshold")


@click.command(help="Verify benchmark scores")
@click.argument("file_path", type=str)
@click.option(
    "--bench_type", type=str, help="Type of benchmark to verify", default="detection"
)
def main(file_path, bench_type):
    with open(file_path, "r") as file:
        data = json.load(file)

    if bench_type == "detection":
        verify_det(data)
    elif bench_type == "recognition":
        verify_rec(data)
    elif bench_type == "layout":
        verify_layout(data)
    elif bench_type == "ordering":
        verify_order(data)
    elif bench_type == "table_recognition":
        verify_table_rec(data)
    elif bench_type == "texify":
        verify_texify(data)
    else:
        raise ValueError("Invalid benchmark type")


if __name__ == "__main__":
    main()

```

--------------------------------------------------------------------------------
/surya/debug/draw.py:
--------------------------------------------------------------------------------

```python
from PIL import ImageDraw, ImageFont

from surya.debug.fonts import get_font_path
from surya.debug.text import get_text_size


def draw_bboxes_on_image(
    bboxes, image, labels=None, label_font_size=10, color: str | list = "red"
):
    polys = []
    for bb in bboxes:
        # Clockwise polygon
        poly = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
        polys.append(poly)

    return draw_polys_on_image(
        polys, image, labels, label_font_size=label_font_size, color=color
    )


def draw_polys_on_image(
    corners,
    image,
    labels=None,
    box_padding=-1,
    label_offset=1,
    label_font_size=10,
    color: str | list = "red",
):
    draw = ImageDraw.Draw(image)
    font_path = get_font_path()
    label_font = ImageFont.truetype(font_path, label_font_size)

    for i in range(len(corners)):
        poly = corners[i]
        poly = [(int(p[0]), int(p[1])) for p in poly]
        draw.polygon(
            poly, outline=color[i] if isinstance(color, list) else color, width=1
        )

        if labels is not None:
            label = labels[i]
            text_position = (
                min([p[0] for p in poly]) + label_offset,
                min([p[1] for p in poly]) + label_offset,
            )
            text_size = get_text_size(label, label_font)
            box_position = (
                text_position[0] - box_padding + label_offset,
                text_position[1] - box_padding + label_offset,
                text_position[0] + text_size[0] + box_padding + label_offset,
                text_position[1] + text_size[1] + box_padding + label_offset,
            )
            try:
                draw.rectangle(box_position, fill="white")
            except Exception as e:
                print(f"Error drawing rectangle at {box_position}: {e}")
                continue
            draw.text(
                text_position,
                label,
                fill=color[i] if isinstance(color, list) else color,
                font=label_font,
            )

    return image

```
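
A debugging sketch; `draw_bboxes_on_image` converts axis-aligned boxes into clockwise polygons and fetches the render font on first use.

```python
# Sketch: render labeled boxes onto a blank page for visual debugging.
from PIL import Image

from surya.debug.draw import draw_bboxes_on_image

page = Image.new("RGB", (400, 200), "white")
boxes = [[10, 10, 150, 40], [10, 60, 220, 90]]
out = draw_bboxes_on_image(boxes, page, labels=["SectionHeader", "Text"])
out.save("debug_boxes.png")
```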

--------------------------------------------------------------------------------
/surya/recognition/languages.py:
--------------------------------------------------------------------------------

```python
CODE_TO_LANGUAGE = {
    "_math": "Math",
    "af": "Afrikaans",
    "am": "Amharic",
    "ar": "Arabic",
    "as": "Assamese",
    "az": "Azerbaijani",
    "be": "Belarusian",
    "bg": "Bulgarian",
    "bn": "Bengali",
    "br": "Breton",
    "bs": "Bosnian",
    "ca": "Catalan",
    "cs": "Czech",
    "cy": "Welsh",
    "da": "Danish",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "eo": "Esperanto",
    "es": "Spanish",
    "et": "Estonian",
    "eu": "Basque",
    "fa": "Persian",
    "fi": "Finnish",
    "fr": "French",
    "fy": "Western Frisian",
    "ga": "Irish",
    "gd": "Scottish Gaelic",
    "gl": "Galician",
    "gu": "Gujarati",
    "ha": "Hausa",
    "he": "Hebrew",
    "hi": "Hindi",
    "hr": "Croatian",
    "hu": "Hungarian",
    "hy": "Armenian",
    "id": "Indonesian",
    "is": "Icelandic",
    "it": "Italian",
    "ja": "Japanese",
    "jv": "Javanese",
    "ka": "Georgian",
    "kk": "Kazakh",
    "km": "Khmer",
    "kn": "Kannada",
    "ko": "Korean",
    "ku": "Kurdish",
    "ky": "Kyrgyz",
    "la": "Latin",
    "lo": "Lao",
    "lt": "Lithuanian",
    "lv": "Latvian",
    "mg": "Malagasy",
    "mk": "Macedonian",
    "ml": "Malayalam",
    "mn": "Mongolian",
    "mr": "Marathi",
    "ms": "Malay",
    "my": "Burmese",
    "ne": "Nepali",
    "nl": "Dutch",
    "no": "Norwegian",
    "om": "Oromo",
    "or": "Oriya",
    "pa": "Punjabi",
    "pl": "Polish",
    "ps": "Pashto",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sa": "Sanskrit",
    "sd": "Sindhi",
    "si": "Sinhala",
    "sk": "Slovak",
    "sl": "Slovenian",
    "so": "Somali",
    "sq": "Albanian",
    "sr": "Serbian",
    "su": "Sundanese",
    "sv": "Swedish",
    "sw": "Swahili",
    "ta": "Tamil",
    "te": "Telugu",
    "th": "Thai",
    "tl": "Tagalog",
    "tr": "Turkish",
    "ug": "Uyghur",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "uz": "Uzbek",
    "vi": "Vietnamese",
    "xh": "Xhosa",
    "yi": "Yiddish",
    "zh": "Chinese",
}

LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()}

```

--------------------------------------------------------------------------------
/surya/common/surya/embedder/__init__.py:
--------------------------------------------------------------------------------

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleTokenEmbedder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.bbox_embed = nn.ModuleList(
            [
                nn.Embedding(
                    config.bbox_size + config.special_token_count,
                    config.bbox_embed_size,
                )
                for _ in range(6)
            ]
        )
        self.max_bbox_embedding = config.bbox_size + config.special_token_count - 1
        self.max_bbox_size = config.bbox_size

    def embed(
        self,
        input_tokens: torch.Tensor,
        input_boxes: torch.Tensor | None,
        embed_boxes: torch.Tensor,
    ) -> torch.Tensor:
        # Embed tokens
        token_embeds = self.token_embed(input_tokens)

        # Optionally embed boxes
        if input_boxes is not None and embed_boxes.any():  # Is none in prefill
            input_boxes = input_boxes.to(torch.long)
            bbox_loss_ignore_mask = (
                (input_boxes[:, :, 0] < 0) | (input_boxes[:, :, 0] > self.max_bbox_size)
            ).unsqueeze(-1)
            input_boxes = torch.clamp(input_boxes, 0, self.max_bbox_embedding)

            bbox_embeds = torch.sum(
                torch.stack(
                    [
                        self.bbox_embed[i](input_boxes[:, :, i])
                        for i in range(len(self.bbox_embed))
                    ],
                    dim=-1,
                ),
                dim=-1,
            )

            bbox_embeds = F.pad(
                bbox_embeds, (token_embeds.shape[-1] - bbox_embeds.shape[-1], 0)
            )
            embed_boxes = embed_boxes.unsqueeze(1).unsqueeze(1).expand_as(bbox_embeds)
            bbox_loss_ignore_mask = bbox_loss_ignore_mask.expand_as(bbox_embeds)

            mask = embed_boxes & ~bbox_loss_ignore_mask
            bbox_embeds *= mask.float()

            token_embeds = token_embeds + bbox_embeds

        return token_embeds

```

--------------------------------------------------------------------------------
/surya/detection/loader.py:
--------------------------------------------------------------------------------

```python
from typing import Optional

import torch

from surya.common.load import ModelLoader
from surya.detection.processor import SegformerImageProcessor

from surya.detection.model.config import EfficientViTConfig
from surya.detection.model.encoderdecoder import EfficientViTForSemanticSegmentation
from surya.logging import get_logger
from surya.settings import settings

logger = get_logger()


class DetectionModelLoader(ModelLoader):
    def __init__(self, checkpoint: Optional[str] = None):
        super().__init__(checkpoint)

        if self.checkpoint is None:
            self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT

    def model(
        self,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype | str] = None,
        attention_implementation: Optional[str] = None,
    ) -> EfficientViTForSemanticSegmentation:
        if device is None:
            device = settings.TORCH_DEVICE_MODEL
        if dtype is None:
            dtype = settings.MODEL_DTYPE

        config = EfficientViTConfig.from_pretrained(self.checkpoint)
        model = EfficientViTForSemanticSegmentation.from_pretrained(
            self.checkpoint,
            dtype=dtype,
            config=config,
        )
        model = model.to(device)
        model = model.eval()

        if settings.COMPILE_ALL or settings.COMPILE_DETECTOR:
            torch._dynamo.config.cache_size_limit = 1
            torch._dynamo.config.suppress_errors = False

            logger.info(
                f"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}"
            )
            compile_args = {"backend": "openxla"} if device == "xla" else {}
            model = torch.compile(model, **compile_args)

        logger.debug(
            f"Loaded detection model {self.checkpoint} from {EfficientViTForSemanticSegmentation.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}"
        )
        return model

    def processor(
        self,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype | str] = None,
    ) -> SegformerImageProcessor:
        return SegformerImageProcessor.from_pretrained(self.checkpoint)

```
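
A minimal usage sketch for the loader above, assuming the default detector checkpoint from `settings` is available locally or can be downloaded; the variable names are illustrative only.

```python
from surya.detection.loader import DetectionModelLoader

# Falls back to settings.DETECTOR_MODEL_CHECKPOINT when no checkpoint is passed
loader = DetectionModelLoader()

# Device and dtype default to the settings values when omitted
model = loader.model()
processor = loader.processor()

print(type(model).__name__, type(processor).__name__)
```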

--------------------------------------------------------------------------------
/surya/scripts/hf_to_s3.py:
--------------------------------------------------------------------------------

```python
import json
import shutil
import datetime
from pathlib import Path
import boto3

from huggingface_hub import snapshot_download

import click
from tqdm import tqdm

S3_API_URL = "https://1afbe4656a6b40d982ab5e730a39f6b9.r2.cloudflarestorage.com"


# Example usage - python scripts/hf_to_s3.py <REPO_NAME> layout
# This will upload the snapshot to <bucket_name>/layout/TODAYS_DATE in the configured S3-compatible bucket
@click.command(help="Uploads the data from huggingface to an S3 bucket")
@click.argument("hf_repo_id", type=str)
@click.argument("s3_path", type=str)
@click.option("--bucket_name", type=str, default="datalab")
@click.option("--revision_hash", type=str, default=None)
@click.option("--access_key_id", type=str, default="<access_key_id>")
@click.option("--access_key_secret", type=str, default="<access_key_secret>")
@click.option("--suffix", type=str, default="")
def main(
    hf_repo_id: str,
    s3_path: str,
    bucket_name: str,
    revision_hash: str,
    access_key_id: str,
    access_key_secret: str,
    suffix: str,
):
    curr_date = datetime.datetime.now().strftime("%Y_%m_%d")
    s3_path = f"{s3_path}/{curr_date}"
    if suffix:
        s3_path = f"{s3_path}_{suffix}"

    download_folder = snapshot_download(repo_id=hf_repo_id, revision=revision_hash)
    download_folder = Path(download_folder)
    contained_files = list(download_folder.glob("*"))
    contained_files = [f.name for f in contained_files]  # Just get the base name
    manifest_file = download_folder / "manifest.json"

    with open(manifest_file, "w") as f:
        json.dump({"files": contained_files}, f)

    # Upload the files to S3
    s3_client = boto3.client(
        service_name="s3",
        endpoint_url=S3_API_URL,
        aws_access_key_id=access_key_id,
        aws_secret_access_key=access_key_secret,
        region_name="auto",
    )

    # Iterate through all files in the folder
    for file_path in tqdm(
        download_folder.glob("*"), desc="Uploading files", unit="file"
    ):
        s3_key = f"{s3_path}/{file_path.name}"

        try:
            s3_client.upload_file(str(file_path), bucket_name, s3_key)
        except Exception as e:
            print(f"Error uploading {file_path}: {str(e)}")

    shutil.rmtree(download_folder)

    print(f"Uploaded files to {s3_path}")


if __name__ == "__main__":
    main()

```

--------------------------------------------------------------------------------
/surya/ocr_error/__init__.py:
--------------------------------------------------------------------------------

```python
import math
from typing import List, Optional

from tqdm import tqdm

from surya.common.predictor import BasePredictor
from surya.ocr_error.loader import OCRErrorModelLoader
from surya.ocr_error.model.config import ID2LABEL
from surya.ocr_error.schema import OCRErrorDetectionResult
from surya.settings import settings
from surya.common.xla import mark_step


class OCRErrorPredictor(BasePredictor):
    model_loader_cls = OCRErrorModelLoader
    batch_size = settings.OCR_ERROR_BATCH_SIZE
    default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 64, "xla": 32}

    def __call__(self, texts: List[str], batch_size: Optional[int] = None):
        return self.batch_ocr_error_detection(texts, batch_size)

    def batch_ocr_error_detection(
        self, texts: List[str], batch_size: Optional[int] = None
    ):
        if batch_size is None:
            batch_size = self.get_batch_size()

        num_batches = math.ceil(len(texts) / batch_size)
        texts_processed = self.processor(
            texts, padding="longest", truncation=True, return_tensors="pt"
        )
        predictions = []
        for batch_idx in tqdm(
            range(num_batches),
            desc="Running OCR Error Detection",
            disable=self.disable_tqdm,
        ):
            start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size
            batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(
                self.model.device
            )
            batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(
                self.model.device
            )

            # Pad to batch size
            current_batch_size = batch_input_ids.shape[0]
            if settings.OCR_ERROR_STATIC_CACHE:
                batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)
                batch_attention_mask = self.pad_to_batch_size(
                    batch_attention_mask, batch_size
                )

            with settings.INFERENCE_MODE():
                pred = self.model(batch_input_ids, attention_mask=batch_attention_mask)

                logits = pred.logits.argmax(dim=1).cpu().tolist()[:current_batch_size]
                predictions.extend(logits)
            mark_step()

        return OCRErrorDetectionResult(
            texts=texts, labels=[ID2LABEL[p] for p in predictions]
        )

```
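
A hedged usage sketch for the predictor above; the input strings are made up, and the label strings come from `ID2LABEL` in the model config rather than from this example.

```python
from surya.ocr_error import OCRErrorPredictor

predictor = OCRErrorPredictor()

# Each input string is classified; result.labels lines up index-for-index with result.texts
result = predictor(["This is a normal sentence.", "Th1s l00ks l1ke c0rrupt3d 0CR 0utput"])

for text, label in zip(result.texts, result.labels):
    print(f"{label}: {text[:40]}")
```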

--------------------------------------------------------------------------------
/surya/input/load.py:
--------------------------------------------------------------------------------

```python
from typing import List
import PIL

from surya.input.processing import open_pdf, get_page_images
from surya.logging import get_logger
from surya.settings import settings
import os
import filetype
from PIL import Image
import json

logger = get_logger()


def get_name_from_path(path):
    return os.path.basename(path).split(".")[0]


def load_pdf(pdf_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI):
    doc = open_pdf(pdf_path)
    last_page = len(doc)

    if page_range:
        assert all([0 <= page < last_page for page in page_range]), (
            f"Invalid page range: {page_range}"
        )
    else:
        page_range = list(range(last_page))

    images = get_page_images(doc, page_range, dpi=dpi)
    doc.close()
    names = [get_name_from_path(pdf_path) for _ in page_range]
    return images, names


def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    name = get_name_from_path(image_path)
    return [image], [name]


def load_from_file(
    input_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI
):
    input_type = filetype.guess(input_path)
    if input_type and input_type.extension == "pdf":
        return load_pdf(input_path, page_range, dpi=dpi)
    else:
        return load_image(input_path)


def load_from_folder(
    folder_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI
):
    image_paths = [
        os.path.join(folder_path, image_name)
        for image_name in os.listdir(folder_path)
        if not image_name.startswith(".")
    ]
    image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]

    images = []
    names = []
    for path in image_paths:
        extension = filetype.guess(path)
        if extension and extension.extension == "pdf":
            image, name = load_pdf(path, page_range, dpi=dpi)
            images.extend(image)
            names.extend(name)
        else:
            try:
                image, name = load_image(path)
                images.extend(image)
                names.extend(name)
            except PIL.UnidentifiedImageError:
                logger.warning(f"Could not load image {path}")
                continue
    return images, names


def load_lang_file(lang_path, names):
    with open(lang_path, "r") as f:
        lang_dict = json.load(f)
    return [lang_dict[name].copy() for name in names]

```
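
A short usage sketch for the loaders above; `document.pdf` and `scans/` are placeholder paths.

```python
from surya.input.load import load_from_file, load_from_folder

# Load the first three pages of a PDF at the default DPI (page_range is zero-indexed)
images, names = load_from_file("document.pdf", page_range=[0, 1, 2])

# Or load every image/PDF in a directory; hidden files and subdirectories are skipped
folder_images, folder_names = load_from_folder("scans/")

print(len(images), names[0])
```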

--------------------------------------------------------------------------------
/surya/scripts/ocr_text.py:
--------------------------------------------------------------------------------

```python
import os
import click
import json
import time
from collections import defaultdict

from surya.common.surya.schema import TaskNames
from surya.detection import DetectionPredictor
from surya.debug.text import draw_text_on_image
from surya.logging import configure_logging, get_logger
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor
from surya.scripts.config import CLILoader

configure_logging()
logger = get_logger()


@click.command(help="OCR text.")
@click.option("--task_name", type=str, default=TaskNames.ocr_with_boxes)
@click.option(
    "--disable_math", is_flag=True, default=False, help="Do not recognize math in OCR."
)
@CLILoader.common_options
def ocr_text_cli(input_path: str, task_name: str, disable_math: bool, **kwargs):
    loader = CLILoader(input_path, kwargs, highres=True)
    task_names = [task_name] * len(loader.images)

    foundation_predictor = FoundationPredictor()
    det_predictor = DetectionPredictor()
    rec_predictor = RecognitionPredictor(foundation_predictor)

    start = time.time()
    predictions_by_image = rec_predictor(
        loader.images,
        task_names=task_names,
        det_predictor=det_predictor,
        highres_images=loader.highres_images,
        math_mode=not disable_math,
    )

    if loader.debug:
        logger.debug(f"OCR took {time.time() - start:.2f} seconds")
        max_chars = max(
            [len(line.text) for p in predictions_by_image for line in p.text_lines]
        )
        logger.debug(f"Max chars: {max_chars}")

    if loader.save_images:
        for idx, (name, image, pred) in enumerate(
            zip(loader.names, loader.images, predictions_by_image)
        ):
            bboxes = [line.bbox for line in pred.text_lines]
            pred_text = [line.text for line in pred.text_lines]
            page_image = draw_text_on_image(bboxes, pred_text, image.size)
            page_image.save(os.path.join(loader.result_path, f"{name}_{idx}_text.png"))

    out_preds = defaultdict(list)
    for name, pred, image in zip(loader.names, predictions_by_image, loader.images):
        out_pred = pred.model_dump()
        out_pred["page"] = len(out_preds[name]) + 1
        out_preds[name].append(out_pred)

    with open(
        os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
    ) as f:
        json.dump(out_preds, f, ensure_ascii=False)

    logger.info(f"Wrote results to {loader.result_path}")

```

--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------

```python
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import pytest
from PIL import Image, ImageDraw

from surya.detection import DetectionPredictor
from surya.ocr_error import OCRErrorPredictor
from surya.layout import LayoutPredictor
from surya.recognition import RecognitionPredictor
from surya.foundation import FoundationPredictor
from surya.table_rec import TableRecPredictor
from surya.settings import settings


@pytest.fixture(scope="session")
def ocr_error_predictor() -> OCRErrorPredictor:
    ocr_error_predictor = OCRErrorPredictor()
    yield ocr_error_predictor
    del ocr_error_predictor


@pytest.fixture(scope="session")
def layout_predictor() -> LayoutPredictor:
    layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))
    yield layout_predictor
    del layout_predictor


@pytest.fixture(scope="session")
def detection_predictor() -> DetectionPredictor:
    detection_predictor = DetectionPredictor()
    yield detection_predictor
    del detection_predictor


@pytest.fixture(scope="session")
def recognition_predictor() -> RecognitionPredictor:
    recognition_predictor = RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT))
    yield recognition_predictor
    del recognition_predictor


@pytest.fixture(scope="session")
def table_rec_predictor() -> TableRecPredictor:
    table_rec_predictor = TableRecPredictor()
    yield table_rec_predictor
    del table_rec_predictor


@pytest.fixture()
def test_image():
    image = Image.new("RGB", (1024, 1024), "white")
    draw = ImageDraw.Draw(image)
    draw.text((10, 10), "Hello World", fill="black", font_size=72)
    draw.text(
        (10, 200),
        "This is a sentence of text.\nNow it is a paragraph.\nA three-line one.",
        fill="black",
        font_size=24,
    )
    return image


@pytest.fixture()
def test_image_tall():
    image = Image.new("RGB", (4096, 4096), "white")
    draw = ImageDraw.Draw(image)
    draw.text((10, 10), "Hello World", fill="black", font_size=72)
    draw.text(
        (4000, 4000),
        "This is a sentence of text.\n\nNow it is a paragraph.\n\nA three-line one.",
        fill="black",
        font_size=24,
    )
    return image


@pytest.fixture()
def test_image_latex():
    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    img_path = os.path.join(assets_dir, "test_latex.png")
    image = Image.open(img_path).convert("RGB")
    return image
```

--------------------------------------------------------------------------------
/surya/debug/render_html.py:
--------------------------------------------------------------------------------

```python
import html as htmllib
import os.path
import re

filepath = os.path.abspath(__file__)

def render_text_as_html(
        bboxes: list[list[int]],
        texts: list[str],
        image_size: tuple[int, int],
        base_font_size: int = 16,
        scaler: int = 2
):
    katex_path = os.path.join(os.path.dirname(filepath), "katex.js")
    with open(katex_path, "r") as f:
        katex_script = f.read()

    html_content = []
    image_size = tuple([int(s * scaler) for s in image_size])
    width, height = image_size

    html_content.append(f"""
<!DOCTYPE html>
<html>
<head>
    <style>
        body {{
            margin: 0;
            padding: 0;
            width: {width}px;
            height: {height}px;
            position: relative;
            overflow: hidden;
            background: white;
            color: black;
        }}
        .text-box {{
            position: absolute;
            overflow: hidden;
            display: flex;
            justify-content: left;
            font-family: Arial, sans-serif;
            white-space: pre-wrap;
        }}
        .vertical-text {{
          writing-mode: vertical-rl;  /* Top to bottom, right to left */
        }}
    </style>
    {katex_script}
</head>
<body>
""")

    for i, (bbox, text) in enumerate(zip(bboxes, texts)):
        bbox = bbox.copy()
        bbox = [int(bb * scaler) for bb in bbox]
        x1, y1, x2, y2 = bbox
        width = x2 - x1
        height = y2 - y1
        min_dim = min(width, height)

        # Scale font size based on box height
        font_size = min(int(min_dim * 0.75), base_font_size)

        # Create div with absolute positioning
        div_style = (
            f"left: {x1}px; "
            f"top: {y1}px; "
            f"width: {width}px; "
            f"height: {height}px; "
            f"font-size: {font_size}px;"
        )

        class_ = "text-box"
        if height > width * 2:
            class_ += " vertical-text"

        # Determine if content is HTML/MathML or plain text
        if "<" in text and ">" in text and re.search(r"<(html|math|div|sub|sup|i|u|mark|small|del|b|br|code)\b", text.lower()):
            # Content is already HTML/MathML, include as-is
            html_content.append(f'<span class="{class_}" id="box-{i}" style="{div_style}">{text}</span>')
        else:
            # Plain text, escape it
            escaped_text = htmllib.escape(text)
            html_content.append(f'<span class="{class_}" id="box-{i}" style="{div_style}">{escaped_text}</span>')

    html_content.append("</body></html>")

    return "\n".join(html_content), image_size
```
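
A small sketch showing how `render_text_as_html` is called; the boxes and strings are made up.

```python
from surya.debug.render_html import render_text_as_html

# Two line-level boxes on an 800x600 page, in x1, y1, x2, y2 pixel coordinates
bboxes = [[50, 40, 400, 80], [50, 120, 600, 160]]
texts = ["Hello World", "A second <b>formatted</b> line"]

html, scaled_size = render_text_as_html(bboxes, texts, image_size=(800, 600))
print(scaled_size)  # (1600, 1200) with the default scaler of 2
```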

--------------------------------------------------------------------------------
/surya/common/predictor.py:
--------------------------------------------------------------------------------

```python
from typing import Optional
import torch
import torch.nn.functional as F

from surya.common.load import ModelLoader
from surya.settings import settings


class BasePredictor:
    model_loader_cls = ModelLoader
    batch_size: Optional[int] = None
    default_batch_sizes = {"cpu": 1, "mps": 1, "cuda": 1}
    torch_dtype = settings.MODEL_DTYPE

    @property
    def disable_tqdm(self) -> bool:
        return self._disable_tqdm

    @disable_tqdm.setter
    def disable_tqdm(self, value: bool) -> None:
        self._disable_tqdm = bool(value)

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
        dtype: Optional[torch.dtype | str] = None,
        attention_implementation: Optional[str] = None,
    ):
        if dtype is None:
            dtype = self.torch_dtype

        self.model = None
        self.processor = None
        loader = self.model_loader_cls(checkpoint)

        self.model = loader.model(device, dtype, attention_implementation)
        self.processor = loader.processor()

        self._disable_tqdm = settings.DISABLE_TQDM

    def to(self, device_dtype: torch.device | str | None = None):
        model_moved = False
        if hasattr(self, "model") and self.model:
            self.model.to(device_dtype)
            model_moved = True
        if hasattr(self, "foundation_predictor") and self.foundation_predictor:
            self.foundation_predictor.model.to(device_dtype)
            model_moved = True

        if not model_moved:
            raise ValueError("Model not loaded")

    def get_batch_size(self):
        batch_size = self.batch_size
        if batch_size is None:
            batch_size = self.default_batch_sizes["cpu"]
            if settings.TORCH_DEVICE_MODEL in self.default_batch_sizes:
                batch_size = self.default_batch_sizes[settings.TORCH_DEVICE_MODEL]
        return batch_size

    @staticmethod
    def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
        current_batch_size = tensor.shape[0]
        if current_batch_size >= batch_size:
            return tensor

        if len(tensor.shape) == 1:
            # If tensor is 1D, we need to pad it to the batch size
            pad_size = batch_size - current_batch_size
            return F.pad(tensor, (0, pad_size), mode="constant", value=0)

        pad_size = batch_size - current_batch_size
        padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)

        return F.pad(tensor, padding, mode="constant", value=0)

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

```
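
A quick sketch of the padding helper above, since its behaviour (grow only the leading batch dimension, zero-fill the new rows) is easy to misread.

```python
import torch

from surya.common.predictor import BasePredictor

x = torch.ones(3, 5)
padded = BasePredictor.pad_to_batch_size(x, batch_size=8)

print(padded.shape)      # torch.Size([8, 5])
print(padded[3:].sum())  # tensor(0.) - the added rows are zero-filled
```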

--------------------------------------------------------------------------------
/surya/scripts/config.py:
--------------------------------------------------------------------------------

```python
from typing import List

import click
import os
from surya.input.load import load_from_folder, load_from_file
from surya.settings import settings


class CLILoader:
    def __init__(self, filepath: str, cli_options: dict, highres: bool = False):
        self.page_range = cli_options.get("page_range")
        if self.page_range:
            self.page_range = self.parse_range_str(self.page_range)
        self.filepath = filepath
        self.config = cli_options
        self.save_images = cli_options.get("images", False)
        self.debug = cli_options.get("debug", False)
        self.output_dir = cli_options.get("output_dir")

        self.load(highres)

    @staticmethod
    def common_options(fn):
        fn = click.argument("input_path", type=click.Path(exists=True), required=True)(fn)
        fn = click.option("--output_dir", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, "surya"), help="Directory to save output.")(fn)
        fn = click.option("--page_range", type=str, default=None, help="Page range to convert, specify comma separated page numbers or ranges.  Example: 0,5-10,20")(fn)
        fn = click.option("--images", is_flag=True, help="Save images of detected bboxes.", default=False)(fn)
        fn = click.option("--debug", "-d", is_flag=True, help="Enable debug mode.", default=False)(fn)
        return fn

    def load(self, highres: bool = False):
        highres_images = None
        if os.path.isdir(self.filepath):
            images, names = load_from_folder(self.filepath, self.page_range)
            folder_name = os.path.basename(self.filepath)
            if highres:
                highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES)
        else:
            images, names = load_from_file(self.filepath, self.page_range)
            folder_name = os.path.basename(self.filepath).split(".")[0]
            if highres:
                highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES)


        self.images = images
        self.highres_images = highres_images
        self.names = names

        self.result_path = os.path.abspath(os.path.join(self.output_dir, folder_name))
        os.makedirs(self.result_path, exist_ok=True)

    @staticmethod
    def parse_range_str(range_str: str) -> List[int]:
        range_lst = range_str.split(",")
        page_lst = []
        for i in range_lst:
            if "-" in i:
                start, end = i.split("-")
                page_lst += list(range(int(start), int(end) + 1))
            else:
                page_lst.append(int(i))
        page_lst = sorted(list(set(page_lst)))  # Deduplicate page numbers and sort in order
        return page_lst
```
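
A sketch of the page-range parser above, which the CLI scripts use for `--page_range`.

```python
from surya.scripts.config import CLILoader

# Comma-separated pages and ranges expand to a sorted, de-duplicated list
print(CLILoader.parse_range_str("0,5-10,20"))
# [0, 5, 6, 7, 8, 9, 10, 20]
```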

--------------------------------------------------------------------------------
/tests/test_recognition.py:
--------------------------------------------------------------------------------

```python
import time
from PIL import ImageDraw, Image
from surya.recognition.util import clean_math_tags


def test_recognition(recognition_predictor, detection_predictor, test_image):
    recognition_results = recognition_predictor([test_image], None, detection_predictor)

    assert len(recognition_results) == 1
    assert recognition_results[0].image_bbox == [0, 0, 1024, 1024]

    text_lines = recognition_results[0].text_lines
    assert len(text_lines) == 4
    assert "Hello World" in text_lines[0].text


def test_recognition_input_text(recognition_predictor, detection_predictor, test_image):
    start = time.time()
    recognition_predictor([test_image], None, detection_predictor)
    end = time.time() - start

    input_text = "a" * 400
    start2 = time.time()
    recognition_results = recognition_predictor(
        [test_image], None, detection_predictor, input_text=[input_text]
    )
    end2 = time.time() - start2

    assert max([end, end2]) / min([end, end2]) < 1.5, (
        "Input text should be truncated and not change inference time"
    )

    assert len(recognition_results) == 1
    assert recognition_results[0].image_bbox == [0, 0, 1024, 1024]

    text_lines = recognition_results[0].text_lines
    assert len(text_lines) == 4
    assert "Hello World" in text_lines[0].text


def test_recognition_drop_repeats(recognition_predictor, detection_predictor):
    image = Image.new("RGB", (1024, 128), "white")
    draw = ImageDraw.Draw(image)
    text = "a" * 80
    draw.text((5, 5), text, fill="black", font_size=24)

    recognition_results = recognition_predictor(
        [image], None, bboxes=[[[0, 0, 1024, 128]]], drop_repeated_text=True
    )
    assert len(recognition_results) == 1
    result = recognition_results[0].text_lines
    assert result[0].text == ""


def test_recognition_clean_math():
    math = """<math display="block">na_n^{1+2r} \\text{cov}(\\hat{f}_n^{(r)}(x), \\hat{f}_n^{(r)}(y)) = \\frac{1}{n} \\sum_{j=1}^n \\frac{a_n^{1+2r}}{a_j^{1+2r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_j}{a_j}\\right)\\right) <br>+ \\frac{a_n^{1+2r}}{n} \\sum_{\\substack{j \\neq k \\\\ 1 \\le j, k \\le n}} \\frac{1}{(a_j a_k)^{1+r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_k}{a_k}\\right)\\right) <br>=: I_1 + I_2.</math> (1.7)</math>'"""
    clean_math = clean_math_tags(math)

    assert clean_math.count("</math>") == 1, "Should have one closing math tag"
    assert "<br>" not in clean_math, "Should not have <br> tags in cleaned math"


def test_recognition_clean_math_preserve_text():
    text = """Hello, this is a sentence with <math display="inline">x^2 + y^2 = z^2</math> and some text after it, with a weird tag <hello> and <goodbye>."""
    clean_text = clean_math_tags(text)

    assert clean_text == text

```

--------------------------------------------------------------------------------
/surya/input/processing.py:
--------------------------------------------------------------------------------

```python
from typing import List

import cv2
import numpy as np
import pypdfium2
from PIL import Image

from surya.logging import get_logger
from surya.settings import settings

logger = get_logger()


def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
    new_images = []
    for image in images:
        if image.mode != "RGB":
            image = image.convert("RGB")
        new_images.append(image)
    return new_images


def open_pdf(pdf_filepath):
    return pypdfium2.PdfDocument(pdf_filepath)


def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):
    images = [
        doc[i].render(scale=dpi / 72, draw_annots=False).to_pil() for i in indices
    ]
    images = [image.convert("RGB") for image in images]
    return images


def slice_bboxes_from_image(image: np.ndarray, bboxes):
    lines = []
    for bbox in bboxes:
        bbox = np.array(bbox, dtype=np.int32)
        bbox = np.clip(bbox, 0, None)  # Ensure no negative indices
        # Ensure bbox is within the image bounds
        if bbox[3] <= bbox[1]:
            bbox[3] = bbox[1] + 1

        if bbox[2] <= bbox[0]:
            bbox[2] = bbox[0] + 1

        bbox[2] = min(bbox[2], image.shape[1])
        bbox[3] = min(bbox[3], image.shape[0])

        line = image[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()
        if line.size == 0:
            logger.warning(f"Warning: found an empty line with bbox {bbox}")
        lines.append(line)
    return lines


def slice_polys_from_image(image: np.ndarray, polys):
    lines = []
    for idx, poly in enumerate(polys):
        lines.append(slice_and_pad_poly(image, poly))
    return lines


def slice_and_pad_poly(image_array: np.ndarray, coordinates):
    # Draw polygon onto mask
    coordinates = [(corner[0], corner[1]) for corner in coordinates]
    bbox = [
        min([x[0] for x in coordinates]),
        min([x[1] for x in coordinates]),
        max([x[0] for x in coordinates]),
        max([x[1] for x in coordinates]),
    ]

    # We mask out anything not in the polygon
    cropped_polygon = image_array[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()
    height, width = cropped_polygon.shape[:2]

    coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]

    # Validate the cropped area
    if any(
        [
            bbox[3] <= bbox[1] or bbox[2] <= bbox[0],
            len(coordinates) < 3,
            height == 0,
            width == 0,
        ]
    ):
        return cropped_polygon

    # Pad the area outside the polygon with the pad value
    try:
        mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8)
        cv2.fillPoly(mask, [np.int32(coordinates)], 1)
        mask = np.stack([mask] * 3, axis=-1)

        cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
    except cv2.error as e:
        logger.warning(f"Warning: issue while processing polygon: {e}")

    return cropped_polygon

```
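
A minimal sketch of `slice_bboxes_from_image`; the blank page and boxes are made up.

```python
import numpy as np

from surya.input.processing import slice_bboxes_from_image

# A blank white page as an (H, W, C) array and two line bboxes in x1, y1, x2, y2 order
page = np.full((600, 800, 3), 255, dtype=np.uint8)
bboxes = [[10, 10, 300, 50], [10, 60, 300, 100]]

lines = slice_bboxes_from_image(page, bboxes)
print([line.shape for line in lines])  # [(40, 290, 3), (40, 290, 3)]
```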

--------------------------------------------------------------------------------
/surya/table_rec/loader.py:
--------------------------------------------------------------------------------

```python
from typing import Optional

import torch

from surya.common.load import ModelLoader
from surya.logging import get_logger
from surya.settings import settings
from surya.table_rec.model.config import (
    SuryaTableRecConfig,
    SuryaTableRecDecoderConfig,
    DonutSwinTableRecConfig,
)
from surya.table_rec.model.encoderdecoder import TableRecEncoderDecoderModel
from surya.table_rec.processor import SuryaTableRecProcessor

logger = get_logger()


class TableRecModelLoader(ModelLoader):
    def __init__(self, checkpoint: Optional[str] = None):
        super().__init__(checkpoint)

        if self.checkpoint is None:
            self.checkpoint = settings.TABLE_REC_MODEL_CHECKPOINT

    def model(
        self,
        device=settings.TORCH_DEVICE_MODEL,
        dtype=settings.MODEL_DTYPE,
        attention_implementation: Optional[str] = None,
    ) -> TableRecEncoderDecoderModel:
        if device is None:
            device = settings.TORCH_DEVICE_MODEL
        if dtype is None:
            dtype = settings.MODEL_DTYPE

        if device == "mps":
            logger.warning(
                "`TableRecEncoderDecoderModel` is not compatible with mps backend. Defaulting to cpu instead"
            )
            device = "cpu"
            dtype = "float32"

        config = SuryaTableRecConfig.from_pretrained(self.checkpoint)
        decoder_config = config.decoder
        decoder = SuryaTableRecDecoderConfig(**decoder_config)
        config.decoder = decoder

        encoder_config = config.encoder
        encoder = DonutSwinTableRecConfig(**encoder_config)
        config.encoder = encoder

        model = TableRecEncoderDecoderModel.from_pretrained(
            self.checkpoint, config=config, dtype=dtype
        )

        model = model.to(device)
        model = model.eval()

        if settings.COMPILE_ALL or settings.COMPILE_TABLE_REC:
            torch.set_float32_matmul_precision("high")
            torch._dynamo.config.cache_size_limit = 16
            torch._dynamo.config.suppress_errors = False

            logger.info(
                f"Compiling table recognition model {self.checkpoint} on device {device} with dtype {dtype}"
            )
            compile_args = {"backend": "openxla"} if device == "xla" else {}
            model.encoder = torch.compile(model.encoder, **compile_args)
            model.decoder = torch.compile(model.decoder, **compile_args)

        logger.debug(
            f"Loaded table recognition model {self.checkpoint} from {TableRecEncoderDecoderModel.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}"
        )
        return model

    def processor(
        self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE
    ) -> SuryaTableRecProcessor:
        processor = SuryaTableRecProcessor(self.checkpoint)

        processor.token_pad_id = 0
        processor.token_eos_id = 1
        processor.token_bos_id = 1
        processor.token_query_end_id = 4
        return processor

```

--------------------------------------------------------------------------------
/surya/common/surya/decoder/config.py:
--------------------------------------------------------------------------------

```python
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

logger = logging.get_logger(__name__)


class SuryaDecoderConfig(PretrainedConfig):
    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = False  # Disable sliding window
        self.sliding_window = (
            sliding_window  # we check `use_sliding_window` in the modeling code
        )
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

```
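
A small sketch of the decoder config above; note that sliding-window attention is hard-disabled in `__init__` regardless of the flag passed in.

```python
from surya.common.surya.decoder.config import SuryaDecoderConfig

config = SuryaDecoderConfig(use_sliding_window=True)

print(config.use_sliding_window)   # False - always disabled by __init__
print(config.num_key_value_heads)  # 32 (falls back to num_attention_heads when None)
```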

--------------------------------------------------------------------------------
/benchmark/ordering.py:
--------------------------------------------------------------------------------

```python
import collections
import json

import click

from surya.foundation import FoundationPredictor
from surya.input.processing import convert_if_not_rgb
from surya.layout import LayoutPredictor
from surya.common.polygon import PolygonBox
from surya.settings import settings
from benchmark.utils.metrics import rank_accuracy
import os
import time
import datasets


@click.command(help="Benchmark surya layout for reading order.")
@click.option(
    "--results_dir",
    type=str,
    help="Path to JSON file with benchmark results.",
    default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
    "--max_rows",
    type=int,
    help="Maximum number of images to run benchmark on.",
    default=None,
)
def main(results_dir: str, max_rows: int):
    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_predictor = LayoutPredictor(foundation_predictor)
    pathname = "order_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    split = "train"
    if max_rows is not None:
        split = f"train[:{max_rows}]"
    dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
    images = list(dataset["image"])
    images = convert_if_not_rgb(images)

    start = time.time()
    layout_predictions = layout_predictor(images)
    surya_time = time.time() - start

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    page_metrics = collections.OrderedDict()
    mean_accuracy = 0
    for idx, order_pred in enumerate(layout_predictions):
        row = dataset[idx]
        labels = row["labels"]
        bboxes = row["bboxes"]
        pred_positions = []
        for label, bbox in zip(labels, bboxes):
            max_intersection = 0
            matching_idx = 0
            for pred_box in order_pred.bboxes:
                intersection = pred_box.intersection_pct(PolygonBox(polygon=bbox))
                if intersection > max_intersection:
                    max_intersection = intersection
                    matching_idx = pred_box.position
            pred_positions.append(matching_idx)
        accuracy = rank_accuracy(pred_positions, labels)
        mean_accuracy += accuracy
        page_results = {"accuracy": accuracy, "box_count": len(labels)}

        page_metrics[idx] = page_results

    mean_accuracy /= len(layout_predictions)

    out_data = {
        "time": surya_time,
        "mean_accuracy": mean_accuracy,
        "page_metrics": page_metrics,
    }

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)

    print(f"Mean accuracy is {mean_accuracy:.2f}.")
    print(
        f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total."
    )
    print("Mean accuracy is the % of correct ranking pairs.")
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()

```

--------------------------------------------------------------------------------
/surya/debug/text.py:
--------------------------------------------------------------------------------

```python
import re
from io import BytesIO
from typing import List, Tuple
from PIL import Image, ImageDraw, ImageFont

from surya.debug.fonts import get_font_path
from surya.debug.render_html import render_text_as_html

try:
    from playwright.sync_api import sync_playwright

    has_playwright = True
except ImportError:
    has_playwright = False


def strip_html_tags(html_text):
    pattern = re.compile(r"<[\w/][^>]*>")
    text_only = pattern.sub("", html_text)

    return text_only


def get_text_size(text, font):
    im = Image.new(mode="P", size=(0, 0))
    draw = ImageDraw.Draw(im)
    _, _, width, height = draw.textbbox((0, 0), text=text, font=font)
    return width, height


def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size):
    font = ImageFont.truetype(font_path, box_font_size)
    text_width, text_height = get_text_size(text, font)
    while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:
        box_font_size = box_font_size - 1
        font = ImageFont.truetype(font_path, box_font_size)
        text_width, text_height = get_text_size(text, font)

    # Calculate text position (centered in bbox)
    text_width, text_height = get_text_size(text, font)
    x = s_bbox[0]
    y = s_bbox[1] + (bbox_height - text_height) / 2

    draw.text((x, y), text, fill="black", font=font)


def draw_text_with_playwright(
    bboxes, texts: List[str], image_size: Tuple[int, int]
) -> Image.Image:
    html_content, image_size = render_text_as_html(bboxes, texts, image_size)
    if not has_playwright:
        raise ImportError(
            "Playwright is not installed. Please install it using `pip install playwright`"
        )

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        page = browser.new_page(
            viewport={"width": image_size[0], "height": image_size[1]}
        )
        page.set_content(html_content)
        page.wait_for_timeout(1000)
        body = page.query_selector("body")
        image = body.screenshot()
        browser.close()

    pil_img = Image.open(BytesIO(image))
    return pil_img


def draw_text_on_image(
    bboxes,
    texts,
    image_size: Tuple[int, int],
    font_path=None,
    max_font_size=60,
    res_upscale=2,
) -> Image.Image:
    if has_playwright:
        return draw_text_with_playwright(bboxes, texts, image_size)

    texts = [strip_html_tags(text) for text in texts]
    if font_path is None:
        font_path = get_font_path()
    new_image_size = (image_size[0] * res_upscale, image_size[1] * res_upscale)
    image = Image.new("RGB", new_image_size, color="white")
    draw = ImageDraw.Draw(image)

    for bbox, text in zip(bboxes, texts):
        s_bbox = [int(coord * res_upscale) for coord in bbox]
        bbox_width = s_bbox[2] - s_bbox[0]
        bbox_height = s_bbox[3] - s_bbox[1]

        # Shrink the text to fit in the bbox if needed
        box_font_size = max(6, min(int(0.75 * bbox_height), max_font_size))
        render_text(
            draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size
        )

    return image

```
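
A sketch of the debug renderer above; the boxes and the output path are placeholders. Playwright/KaTeX rendering is used when available, otherwise the PIL fallback strips HTML tags and draws plain text.

```python
from surya.debug.text import draw_text_on_image

bboxes = [[40, 40, 400, 90], [40, 120, 500, 170]]
texts = ["Hello World", "Rendered for debugging only"]

debug_image = draw_text_on_image(bboxes, texts, image_size=(800, 600))
debug_image.save("debug_text.png")
```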

--------------------------------------------------------------------------------
/surya/recognition/postprocessing.py:
--------------------------------------------------------------------------------

```python
import re
from typing import List, Dict

from surya.recognition.schema import TextChar


def truncate_repetitions(text: str, min_len=15):
    # From nougat, with some cleanup
    if len(text) < 2 * min_len:
        return text

    # try to find a length at which the tail is repeating
    max_rep_len = None
    for rep_len in range(min_len, int(len(text) / 2)):
        # check if there is a repetition at the end
        same = True
        for i in range(0, rep_len):
            if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]:
                same = False
                break

        if same:
            max_rep_len = rep_len

    if max_rep_len is None:
        return text

    lcs = text[-max_rep_len:]

    # remove all but the last repetition
    text_to_truncate = text
    while text_to_truncate.endswith(lcs):
        text_to_truncate = text_to_truncate[:-max_rep_len]

    return text[: len(text_to_truncate)]


def extract_tags(proposed_tags: List[str]) -> List[str]:
    tags = []
    for tag in proposed_tags:
        tag_match = re.match(tag_pattern, tag)
        if not tag_match:
            continue

        if not tag_match.group(1) == "/":
            continue

        tags.append(tag_match.group(2))
    return tags


tag_pattern = re.compile(r"<(/?)([a-z]+)([^>]*)>?", re.IGNORECASE)


def cleanup_math(line: str):
    matches = re.finditer(r"(<math[^>]*>)(.*?)</math>", line, re.DOTALL)
    result = line

    for match in matches:
        opening_tag = match.group(1)  # The opening <math> tag with attributes
        full_match = match.group(0)  # The entire <math>content</math> tag
        block_content = match.group(2)  # Just the content inside the tags

        clean_block = re.sub(r"<[^>]+>", "", block_content)

        if not re.search(r"[\\\_]", clean_block):
            result = result.replace(full_match, clean_block)
        else:
            result = result.replace(full_match, f"{opening_tag}{clean_block}</math>")

    return result


def fix_unbalanced_tags(
    text_chars: List[TextChar], special_tokens: Dict[str, list]
) -> List[TextChar]:
    self_closing_tags = ["br"]

    open_tags = []

    format_tags = extract_tags(special_tokens["formatting"]) + extract_tags(
        special_tokens["math_external"]
    )

    for char in text_chars:
        if len(char.text) <= 1:
            continue

        tag_match = re.match(tag_pattern, char.text)
        if not tag_match:
            continue

        is_closing = tag_match.group(1) == "/"
        tag_name = tag_match.group(2).lower()

        if tag_name not in format_tags:
            continue

        if tag_name in self_closing_tags:
            continue

        # Self-closing tags
        if tag_match.group(3) and tag_match.group(3).strip().endswith("/"):
            continue

        if is_closing:
            if open_tags and open_tags[-1] == tag_name:
                open_tags.pop()
        else:
            open_tags.append(tag_name)

    for tag in open_tags:
        text_chars.append(
            TextChar(
                text=f"</{tag}>",
                confidence=0,
                polygon=[[0, 0], [1, 0], [1, 1], [0, 1]],
                bbox_valid=False,
            )
        )
    return text_chars

```
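
Two small examples of `cleanup_math` from the module above; the input strings are made up.

```python
from surya.recognition.postprocessing import cleanup_math

# Math blocks with no LaTeX commands or subscripts are unwrapped to plain text...
print(cleanup_math('Total: <math display="inline">42</math>'))
# Total: 42

# ...while real LaTeX content keeps its <math> wrapper (inner tags like <br> are stripped)
print(cleanup_math('<math display="block">x_1 + x_2 <br> = y</math>'))
# <math display="block">x_1 + x_2  = y</math>
```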

--------------------------------------------------------------------------------
/surya/common/surya/config.py:
--------------------------------------------------------------------------------

```python
from typing import Optional
from transformers import PretrainedConfig

from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.encoder.config import SuryaEncoderConfig
from surya.common.surya.decoder.config import SuryaDecoderConfig


class SuryaModelConfig(S3DownloaderMixin, PretrainedConfig):
    model_type = "surya-multimodal-foundation"
    is_composition = True

    def __init__(
        self,
        vocab_size=65536,
        bbox_size=1025,
        blank_bbox_token_id=1025,
        bos_token_id=0,
        eos_token_id=1,
        pad_token_id=2,
        image_token_id=3,
        register_token_ids=(4, 5, 6, 7),
        eoi_token_id=8,
        beacon_token_id=9,
        special_token_count=4,
        max_sequence_length=1536,
        special_ocr_tokens=None,
        vision_encoder=None,
        decoder=None,
        tasks: dict | None = None,
        bbox_embed_size: int = 64,
        num_register_tokens: int = 4,
        image_embed_encoding_size: int = 1024,
        image_embed_encoding_multiplier: int = 256,
        num_beacon_tokens: int = 1,
        beacon_token_interval: int = 4096,
        sliding_window: Optional[int] = None,
        multi_output_distance: int = 4,
        max_multi_out: int = 8,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.is_encoder_decoder = False
        self.vocab_size = vocab_size
        self.bbox_size = bbox_size
        self.blank_bbox_token_id = blank_bbox_token_id
        self.image_token_id = image_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.eoi_token_id = eoi_token_id
        self.beacon_token_id = beacon_token_id
        self.special_ocr_tokens = special_ocr_tokens
        self.special_token_count = special_token_count  # pad, bos, etc, tokens
        self.max_sequence_length = max_sequence_length
        self.tasks = tasks
        self.tie_word_embeddings = True
        self.bbox_embed_size = bbox_embed_size
        self.num_register_tokens = num_register_tokens
        self.register_token_ids = register_token_ids
        self.image_embed_encoding_size = image_embed_encoding_size
        self.image_embed_encoding_multiplier = image_embed_encoding_multiplier
        self.num_beacon_tokens = num_beacon_tokens
        self.beacon_token_interval = beacon_token_interval
        self.sliding_window = sliding_window
        self.multi_output_distance = multi_output_distance
        self.max_multi_out = max_multi_out

        if self.sliding_window is None:
            self.sliding_window = self.max_sequence_length

        if isinstance(vision_encoder, dict):
            vision_encoder = SuryaEncoderConfig(**vision_encoder)
        elif vision_encoder is None:
            vision_encoder = SuryaEncoderConfig()
        self.vision_encoder = vision_encoder

        if isinstance(decoder, dict):
            decoder = SuryaDecoderConfig(**decoder)
        elif decoder is None:
            decoder = SuryaDecoderConfig()
        self.decoder = decoder

        self.hidden_size = self.decoder.hidden_size

        self.patch_size = self.vision_encoder.spatial_patch_size
        self.merge_size = self.vision_encoder.spatial_merge_size

```
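
A tiny sketch of instantiating the foundation config above with defaults; the nested encoder/decoder configs are built automatically when not supplied.

```python
from surya.common.surya.config import SuryaModelConfig

config = SuryaModelConfig()

print(config.vocab_size, config.bbox_size)                  # 65536 1025
print(config.hidden_size == config.decoder.hidden_size)     # True - mirrored from the decoder
print(config.sliding_window == config.max_sequence_length)  # True when no sliding window is set
```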

--------------------------------------------------------------------------------
/surya/table_rec/processor.py:
--------------------------------------------------------------------------------

```python
from typing import List

import PIL
import torch
from transformers import ProcessorMixin

from surya.common.s3 import S3DownloaderMixin
from surya.common.donut.processor import SuryaEncoderImageProcessor
from surya.table_rec.shaper import LabelShaper
from surya.settings import settings
from surya.table_rec.model.config import BOX_DIM, SPECIAL_TOKENS


class SuryaTableRecProcessor(S3DownloaderMixin, ProcessorMixin):
    attributes = ["image_processor"]
    image_processor_class = "AutoImageProcessor"

    def __init__(self, checkpoint, **kwargs):
        image_processor = SuryaEncoderImageProcessor.from_pretrained(checkpoint)
        image_processor.do_align_long_axis = False
        image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE
        self.image_processor = image_processor
        super().__init__(image_processor)

        self.box_size = (BOX_DIM, BOX_DIM)
        self.special_token_count = SPECIAL_TOKENS
        self.shaper = LabelShaper()

    def resize_polygon(self, polygon, orig_size, new_size):
        w_scaler = new_size[0] / orig_size[0]
        h_scaler = new_size[1] / orig_size[1]

        for corner in polygon:
            corner[0] = corner[0] * w_scaler
            corner[1] = corner[1] * h_scaler

            if corner[0] < 0:
                corner[0] = 0
            if corner[1] < 0:
                corner[1] = 0
            if corner[0] > new_size[0]:
                corner[0] = new_size[0]
            if corner[1] > new_size[1]:
                corner[1] = new_size[1]

        return polygon

    def __call__(
            self,
            images: List[PIL.Image.Image] | None,
            query_items: List[dict],
            columns: List[dict] | None = None,
            convert_images: bool = True,
            *args,
            **kwargs
    ):
        if convert_images:
            assert len(images) == len(query_items)
            assert len(images) > 0

            # Resize input query items
            for image, query_item in zip(images, query_items):
                query_item["polygon"] = self.resize_polygon(query_item["polygon"], image.size, self.box_size)

        query_items = self.shaper.convert_polygons_to_bboxes(query_items)
        query_labels = self.shaper.dict_to_labels(query_items)

        decoder_input_boxes = []
        col_count = len(query_labels[0])
        for label in query_labels:
            decoder_input_boxes.append([
                [self.token_bos_id] * col_count,
                label,
                [self.token_query_end_id] * col_count
            ])

        # Add columns to end of decoder input
        if columns:
            columns = self.shaper.convert_polygons_to_bboxes(columns)
            column_labels = self.shaper.dict_to_labels(columns)
            for decoder_box in decoder_input_boxes:
                decoder_box += column_labels

        input_boxes = torch.tensor(decoder_input_boxes, dtype=torch.long)
        input_boxes_mask = torch.ones_like(input_boxes, dtype=torch.long)

        inputs = {
            "input_ids": input_boxes,
            "attention_mask": input_boxes_mask
        }
        if convert_images:
            inputs["pixel_values"] = self.image_processor(images, *args, **kwargs)["pixel_values"]
        return inputs

```

--------------------------------------------------------------------------------
/benchmark/utils/tatr.py:
--------------------------------------------------------------------------------

```python
import torch
from transformers import AutoModelForObjectDetection
from surya.settings import settings
import numpy as np


class MaxResize(object):
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))

        return resized_image


def to_tensor(image):
    # Convert PIL Image to NumPy array
    np_image = np.array(image).astype(np.float32)

    # Rearrange dimensions to [C, H, W] format
    np_image = np_image.transpose((2, 0, 1))

    # Normalize to [0.0, 1.0]
    np_image /= 255.0

    return torch.from_numpy(np_image)


def normalize(tensor, mean, std):
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor


def structure_transform(image):
    image = MaxResize(1000)(image)
    tensor = to_tensor(image)
    normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return normalized_tensor


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
    return boxes


def outputs_to_objects(outputs, img_sizes, id2label):
    m = outputs.logits.softmax(-1).max(-1)
    batch_labels = list(m.indices.detach().cpu().numpy())
    batch_scores = list(m.values.detach().cpu().numpy())
    batch_bboxes = outputs['pred_boxes'].detach().cpu()

    batch_objects = []
    for i in range(len(img_sizes)):
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])]
        pred_scores = batch_scores[i]
        pred_labels = batch_labels[i]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if not class_label == 'no object':
                objects.append({
                    'label': class_label,
                    'score': float(score),
                    'bbox': [float(elem) for elem in bbox]}
                )

        rows = []
        cols = []
        for cell in objects:
            if cell["label"] == "table column":
                cols.append(cell)

            if cell["label"] == "table row":
                rows.append(cell)
        batch_objects.append({
            "rows": rows,
            "cols": cols
        })

    return batch_objects


def load_tatr():
    return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL)


def batch_inference_tatr(model, images, batch_size):
    device = model.device
    rows_cols = []
    for i in range(0, len(images), batch_size):
        batch_images = images[i:i + batch_size]
        pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        id2label = model.config.id2label
        id2label[len(model.config.id2label)] = "no object"
        rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
    return rows_cols
```
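
A quick sketch of the box conversion helpers above; the single normalized box is made up.

```python
import torch

from benchmark.utils.tatr import box_cxcywh_to_xyxy, rescale_bboxes

# One normalized center-format box (cx, cy, w, h) on a 1000x800 image
boxes = torch.tensor([[0.5, 0.5, 0.2, 0.1]])

print(box_cxcywh_to_xyxy(boxes))           # tensor([[0.4000, 0.4500, 0.6000, 0.5500]])
print(rescale_bboxes(boxes, (1000, 800)))  # tensor([[400., 360., 600., 440.]])
```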

--------------------------------------------------------------------------------
/benchmark/texify.py:
--------------------------------------------------------------------------------

```python
import os.path
import re
import time
from pathlib import Path
from typing import List

import click
import datasets
from tabulate import tabulate
from bs4 import BeautifulSoup

from surya.common.surya.schema import TaskNames
from surya.settings import settings
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor, OCRResult
import json
from rapidfuzz.distance import Levenshtein


def normalize_text(text):
    soup = BeautifulSoup(text, "html.parser")
    # Unwrap math tags
    for tag in soup.find_all():
        if tag.name == "math":
            tag.unwrap()
    text = soup.get_text()
    text = re.sub(r"\n", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def score_text(predictions, references):
    lev_dist = []
    for p, r in zip(predictions, references):
        p = normalize_text(p)
        r = normalize_text(r)
        lev_dist.append(Levenshtein.normalized_distance(p, r))

    return sum(lev_dist) / len(lev_dist)


def inference_texify(
    source_data, predictor: RecognitionPredictor, line_mode: bool = False
):
    images = [sd["image"] for sd in source_data]
    mode = TaskNames.ocr_with_boxes if line_mode else TaskNames.block_without_boxes
    tasks = [mode] * len(images)
    bboxes = [[[0, 0, image.width, image.height]] for image in images]
    texify_predictions: List[OCRResult] = predictor(images, tasks, bboxes=bboxes)
    out_data = [
        {
            "text": texify_predictions[i].text_lines[0].text,
            "equation": source_data[i]["equation"],
        }
        for i in range(len(texify_predictions))
    ]

    return out_data


@click.command(help="Benchmark the performance of texify.")
@click.option(
    "--ds_name",
    type=str,
    help="Path to dataset file with source images/equations.",
    default=settings.TEXIFY_BENCHMARK_DATASET,
)
@click.option(
    "--results_dir",
    type=str,
    help="Path to JSON file with benchmark results.",
    default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
    "--max_rows", type=int, help="Maximum number of images to benchmark.", default=None
)
@click.option(
    "--line_mode", is_flag=True, help="Use line mode for texify.", default=False
)
def main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool):
    foundation_predictor = FoundationPredictor()
    predictor = RecognitionPredictor(foundation_predictor)
    ds = datasets.load_dataset(ds_name, split="train")

    if max_rows:
        ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True)

    start = time.time()
    predictions = inference_texify(ds, predictor, line_mode)
    time_taken = time.time() - start

    text = [p["text"] for p in predictions]
    references = [p["equation"] for p in predictions]
    scores = score_text(text, references)

    write_data = {
        "scores": scores,
        "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)],
    }

    score_table = [["texify", write_data["scores"], time_taken]]
    score_headers = ["edit", "time taken (s)"]
    score_dirs = ["⬇", "⬇"]

    score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
    table = tabulate(score_table, headers=["Method", *score_headers])
    print()
    print(table)

    result_path = Path(results_dir) / "texify_bench"
    result_path.mkdir(parents=True, exist_ok=True)
    with open(result_path / "results.json", "w", encoding="utf-8") as f:
        json.dump(write_data, f, indent=4)


if __name__ == "__main__":
    main()

```
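
To make the metric concrete, here is a small sketch of `normalize_text` and `score_text` on a toy prediction/reference pair. The import path is an assumption based on the file location above; the behavior (unwrapping math tags, collapsing whitespace, then averaging normalized Levenshtein distances) is taken directly from the functions shown.

```python
# Assumed module path for the helpers defined above.
from benchmark.texify import normalize_text, score_text

prediction = '<math display="inline">x^2 + y^2</math> = z^2'
reference = "x^2 + y^2 = z^2"

# Math tags are unwrapped and whitespace is collapsed before comparison.
print(normalize_text(prediction))  # "x^2 + y^2 = z^2"

# Mean normalized Levenshtein distance over the pairs; 0.0 means a perfect match.
print(score_text([prediction], [reference]))  # 0.0
```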

--------------------------------------------------------------------------------
/surya/table_rec/model/encoder.py:
--------------------------------------------------------------------------------

```python
from typing import Optional, Union, Tuple

import torch
import torch.nn as nn

from surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder


class DonutSwinModel(DonutSwinPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        self.position_embeddings = None
        if hasattr(config, "encoder_length"):
            self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size))

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DonutSwinModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]

        if self.position_embeddings is not None:
            last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :]

        return DonutSwinModelOutput(
            last_hidden_state=last_hidden_state,
        )

```
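
The main addition relative to a stock Donut Swin encoder is the optional learned position embedding added to the final hidden state. The toy tensors below illustrate the slice-and-add step; the shapes are made up for illustration only.

```python
import torch
import torch.nn as nn

# Made-up sizes for illustration; the real values come from the model config.
encoder_length, hidden_size = 196, 256
position_embeddings = nn.Parameter(torch.zeros(1, encoder_length, hidden_size))

# A fake encoder output whose sequence length is shorter than encoder_length.
last_hidden_state = torch.randn(2, 144, hidden_size)

# The embedding table is truncated to the actual sequence length before adding,
# mirroring the end of the forward pass above.
last_hidden_state = last_hidden_state + position_embeddings[:, :last_hidden_state.size(1), :]
print(last_hidden_state.shape)  # torch.Size([2, 144, 256])
```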

--------------------------------------------------------------------------------
/surya/table_rec/model/encoderdecoder.py:
--------------------------------------------------------------------------------

```python
from dataclasses import dataclass
from typing import Optional, Union, Tuple, Dict

import torch
from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig

from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.s3 import S3DownloaderMixin
from surya.table_rec.model.decoder import SuryaTableRecDecoder
from surya.table_rec.model.encoder import DonutSwinModel
from transformers.utils import ModelOutput


@dataclass
class TableRecOutput(ModelOutput):
    box_property_logits: Dict[str, torch.FloatTensor]
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class TableRecEncoderDecoderModel(S3DownloaderMixin, SuryaPreTrainedModel):
    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_param_buffer_assignment = False

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        **kwargs,
    ):
        # initialize with config
        # make sure input & output embeddings is not tied
        config.tie_word_embeddings = False
        config.decoder.tie_word_embeddings = False
        super().__init__(config, **kwargs)

        if encoder is None:
            encoder = DonutSwinModel(config.encoder)

        if decoder is None:
            decoder = SuryaTableRecDecoder(
                config.decoder, attn_implementation=config._attn_implementation
            )

        self.encoder = encoder
        self.decoder = decoder

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

    def forward(
        self,
        decoder_input_ids: torch.LongTensor = None,
        decoder_cache_position: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]:
        kwargs_decoder = {
            argument[len("decoder_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        # Decode
        decoder_outputs = self.decoder(
            input_labels=decoder_input_ids,
            input_boxes_counts=None,
            cache_position=decoder_cache_position,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs,
            encoder_attention_mask=None,
            use_cache=use_cache,
            **kwargs_decoder,
        )

        return TableRecOutput(
            box_property_logits=decoder_outputs.box_property_logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
        )

    def resize_token_embeddings(self, *args, **kwargs):
        raise NotImplementedError(
            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        # apply decoder cache reordering here
        return self.decoder._reorder_cache(past_key_values, beam_idx)

```
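
The forward pass routes any `decoder_`-prefixed keyword argument to the wrapped decoder by stripping the prefix. A standalone sketch of that dict comprehension, with made-up keys and values:

```python
# Toy kwargs as they might arrive at TableRecEncoderDecoderModel.forward (values are placeholders).
kwargs = {
    "decoder_inputs_embeds": "tensor-A",
    "decoder_use_cache": True,
    "encoder_something": "ignored",
}

# Same comprehension as in forward(): keep decoder_* keys and drop the prefix.
kwargs_decoder = {
    key[len("decoder_"):]: value
    for key, value in kwargs.items()
    if key.startswith("decoder_")
}
print(kwargs_decoder)  # {'inputs_embeds': 'tensor-A', 'use_cache': True}
```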

--------------------------------------------------------------------------------
/signatures/version1/cla.json:
--------------------------------------------------------------------------------

```json
{
  "signedContributors": [
    {
      "name": "rishiraj",
      "id": 44090649,
      "comment_id": 2170578748,
      "created_at": "2024-06-15T19:31:20Z",
      "repoId": 741297064,
      "pullRequestNo": 135
    },
    {
      "name": "mmacvicar",
      "id": 59354,
      "comment_id": 2236493182,
      "created_at": "2024-07-18T13:17:43Z",
      "repoId": 741297064,
      "pullRequestNo": 152
    },
    {
      "name": "jimexist",
      "id": 622789,
      "comment_id": 2255151376,
      "created_at": "2024-07-29T07:23:55Z",
      "repoId": 741297064,
      "pullRequestNo": 160
    },
    {
      "name": "michaeldriscoll-avant",
      "id": 85255083,
      "comment_id": 2259143427,
      "created_at": "2024-07-30T20:21:33Z",
      "repoId": 741297064,
      "pullRequestNo": 161
    },
    {
      "name": "EdoardoPona",
      "id": 29152472,
      "comment_id": 2271115922,
      "created_at": "2024-08-06T11:58:00Z",
      "repoId": 741297064,
      "pullRequestNo": 167
    },
    {
      "name": "hidenori-endo",
      "id": 15546605,
      "comment_id": 2307217499,
      "created_at": "2024-08-23T14:31:17Z",
      "repoId": 741297064,
      "pullRequestNo": 182
    },
    {
      "name": "dobosevych",
      "id": 12053536,
      "comment_id": 2430376828,
      "created_at": "2024-10-22T21:48:34Z",
      "repoId": 741297064,
      "pullRequestNo": 220
    },
    {
      "name": "iammosespaulr",
      "id": 28682735,
      "comment_id": 2447941238,
      "created_at": "2024-10-30T17:55:23Z",
      "repoId": 741297064,
      "pullRequestNo": 235
    },
    {
      "name": "ArthurMor4is",
      "id": 42987302,
      "comment_id": 2515315717,
      "created_at": "2024-12-03T18:37:45Z",
      "repoId": 741297064,
      "pullRequestNo": 255
    },
    {
      "name": "tarun-menta",
      "id": 66506307,
      "comment_id": 2543457960,
      "created_at": "2024-12-15T05:43:33Z",
      "repoId": 741297064,
      "pullRequestNo": 261
    },
    {
      "name": "jonaskahn",
      "id": 4338500,
      "comment_id": 2556622097,
      "created_at": "2024-12-20T09:36:20Z",
      "repoId": 741297064,
      "pullRequestNo": 269
    },
    {
      "name": "kumsumit",
      "id": 95072784,
      "comment_id": 2574534622,
      "created_at": "2025-01-07T07:05:59Z",
      "repoId": 741297064,
      "pullRequestNo": 276
    },
    {
      "name": "kevinhu",
      "id": 6051736,
      "comment_id": 2614135351,
      "created_at": "2025-01-25T23:34:12Z",
      "repoId": 741297064,
      "pullRequestNo": 291
    },
    {
      "name": "zanussbaum",
      "id": 33707069,
      "comment_id": 3008673416,
      "created_at": "2025-06-26T14:20:46Z",
      "repoId": 741297064,
      "pullRequestNo": 403
    },
    {
      "name": "mebriki",
      "id": 35892987,
      "comment_id": 3154706976,
      "created_at": "2025-08-05T10:54:27Z",
      "repoId": 741297064,
      "pullRequestNo": 418
    },
    {
      "name": "starikovplusplus",
      "id": 56602036,
      "comment_id": 3168958011,
      "created_at": "2025-08-08T18:29:50Z",
      "repoId": 741297064,
      "pullRequestNo": 423
    },
    {
      "name": "sandy0kwon",
      "id": 78377296,
      "comment_id": 3207932260,
      "created_at": "2025-08-20T20:07:15Z",
      "repoId": 741297064,
      "pullRequestNo": 434
    },
    {
      "name": "n0kovo",
      "id": 16690056,
      "comment_id": 3208251881,
      "created_at": "2025-08-20T22:22:06Z",
      "repoId": 741297064,
      "pullRequestNo": 435
    },
    {
      "name": "davidxifeng",
      "id": 158052,
      "comment_id": 3249594859,
      "created_at": "2025-09-03T14:52:16Z",
      "repoId": 741297064,
      "pullRequestNo": 445
    },
    {
      "name": "u-ashish",
      "id": 14264791,
      "comment_id": 3258734182,
      "created_at": "2025-09-05T15:16:48Z",
      "repoId": 741297064,
      "pullRequestNo": 447
    },
    {
      "name": "Mohking1",
      "id": 63689545,
      "comment_id": 3314908963,
      "created_at": "2025-09-20T11:21:42Z",
      "repoId": 741297064,
      "pullRequestNo": 462
    },
    {
      "name": "wkpark",
      "id": 232347,
      "comment_id": 3330009557,
      "created_at": "2025-09-24T17:42:55Z",
      "repoId": 741297064,
      "pullRequestNo": 464
    }
  ]
}
```

--------------------------------------------------------------------------------
/surya/layout/__init__.py:
--------------------------------------------------------------------------------

```python
from typing import List

from PIL import Image

from surya.common.predictor import BasePredictor
from surya.layout.schema import LayoutBox, LayoutResult
from surya.settings import settings
from surya.foundation import FoundationPredictor, TaskNames
from surya.foundation.util import prediction_to_polygon_batch
from surya.input.processing import convert_if_not_rgb
from surya.layout.label import LAYOUT_PRED_RELABEL
from surya.common.util import clean_boxes


class LayoutPredictor(BasePredictor):
    batch_size = settings.LAYOUT_BATCH_SIZE
    default_batch_sizes = {"cpu": 4, "mps": 4, "cuda": 32, "xla": 16}

    # Override base init - Do not load model
    def __init__(self, foundation_predictor: FoundationPredictor):
        self.foundation_predictor = foundation_predictor
        self.processor = self.foundation_predictor.processor
        self.bbox_size = self.foundation_predictor.model.config.bbox_size
        self.tasks = self.foundation_predictor.tasks

    # Special handling for disable tqdm to pass into foundation predictor
    # Make sure they are kept in sync
    @property
    def disable_tqdm(self) -> bool:
        return super().disable_tqdm

    @disable_tqdm.setter
    def disable_tqdm(self, value: bool) -> None:
        self._disable_tqdm = bool(value)
        self.foundation_predictor.disable_tqdm = bool(value)

    def __call__(
        self, images: List[Image.Image], batch_size: int | None = None, top_k: int = 5
    ) -> List[LayoutResult]:
        assert all([isinstance(image, Image.Image) for image in images])
        if batch_size is None:
            batch_size = self.get_batch_size()

        if len(images) == 0:
            return []

        images = convert_if_not_rgb(images)
        images = [self.processor.image_processor(image) for image in images]

        predicted_tokens, batch_bboxes, scores, topk_scores = (
            self.foundation_predictor.prediction_loop(
                images=images,
                input_texts=["" for _ in range(len(images))],
                task_names=[TaskNames.layout for _ in range(len(images))],
                batch_size=batch_size,
                max_lookahead_tokens=0,  # Do not do MTP for layout
                top_k=top_k,
                max_sliding_window=576,
                max_tokens=500,
                tqdm_desc="Recognizing Layout"
            )
        )

        image_sizes = [img.shape for img in images]
        predicted_polygons = prediction_to_polygon_batch(
            batch_bboxes, image_sizes, self.bbox_size, self.bbox_size // 2
        )
        layout_results = []
        for image, image_tokens, image_polygons, image_scores, image_topk_scores in zip(
            images, predicted_tokens, predicted_polygons, scores, topk_scores
        ):
            layout_boxes = []
            for z, (tok, poly, score, tok_topk) in enumerate(
                zip(image_tokens, image_polygons, image_scores, image_topk_scores)
            ):
                if tok == self.processor.eos_token_id:
                    break

                predicted_label = self.processor.decode([tok], "layout")
                label = LAYOUT_PRED_RELABEL.get(predicted_label)
                if not label:
                    # Layout can sometimes return unknown labels from other objectives
                    continue

                top_k_dict = {}
                for k, v in tok_topk.items():
                    topk_label = self.processor.decode([k], "layout")
                    if topk_label in LAYOUT_PRED_RELABEL:
                        topk_label = LAYOUT_PRED_RELABEL[topk_label]
                    if not topk_label.strip():
                        continue
                    top_k_dict.update({topk_label: v})
                layout_boxes.append(
                    LayoutBox(
                        polygon=poly.tolist(),
                        label=label,
                        position=z,
                        top_k=top_k_dict,
                        confidence=score,
                    )
                )
            layout_boxes = clean_boxes(layout_boxes)
            layout_results.append(
                LayoutResult(
                    bboxes=layout_boxes,
                    image_bbox=[0, 0, image.shape[1], image.shape[0]],
                )  # Image is numpy array
            )

        assert len(layout_results) == len(images)
        return layout_results

```
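
A minimal sketch of calling the layout predictor directly on a local page image. The image path is a placeholder; the constructor and the fields read from the result (`bboxes`, `label`, `position`, `confidence`, `polygon`) follow the classes above and the table recognition CLI further below.

```python
from PIL import Image

from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.settings import settings

# Placeholder path to a page image.
image = Image.open("page.png").convert("RGB")

# Mirrors the construction used in the table recognition CLI below.
foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_predictor = LayoutPredictor(foundation_predictor)

result = layout_predictor([image])[0]
for box in result.bboxes:
    print(box.position, box.label, box.confidence, box.polygon)
```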

--------------------------------------------------------------------------------
/surya/scripts/table_recognition.py:
--------------------------------------------------------------------------------

```python
import os
import click
import copy
import json
from collections import defaultdict

from surya.logging import configure_logging, get_logger
from surya.scripts.config import CLILoader
from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.table_rec import TableRecPredictor
from surya.debug.draw import draw_bboxes_on_image
from surya.common.util import rescale_bbox, expand_bbox
from surya.settings import settings

configure_logging()
logger = get_logger()


@click.command(help="Detect layout of an input file or folder (PDFs or image).")
@CLILoader.common_options
@click.option(
    "--skip_table_detection",
    is_flag=True,
    help="Tables are already cropped, so don't re-detect tables.",
    default=False,
)
def table_recognition_cli(input_path: str, skip_table_detection: bool, **kwargs):
    loader = CLILoader(input_path, kwargs, highres=True)

    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_predictor = LayoutPredictor(foundation_predictor)
    table_rec_predictor = TableRecPredictor()

    pnums = []
    prev_name = None
    for i, name in enumerate(loader.names):
        if prev_name is None or prev_name != name:
            pnums.append(0)
        else:
            pnums.append(pnums[-1] + 1)

        prev_name = name

    layout_predictions = layout_predictor(loader.images)

    table_imgs = []
    table_counts = []

    for layout_pred, img, highres_img in zip(
        layout_predictions, loader.images, loader.highres_images
    ):
        # The table may already be cropped
        if skip_table_detection:
            table_imgs.append(highres_img)
            table_counts.append(1)
        else:
            # The bbox for the entire table
            bbox = [
                line.bbox
                for line in layout_pred.bboxes
                if line.label in ["Table", "TableOfContents"]
            ]
            # Number of tables per page
            table_counts.append(len(bbox))

            if len(bbox) == 0:
                continue

            page_table_imgs = []
            highres_bbox = []
            for bb in bbox:
                highres_bb = rescale_bbox(bb, img.size, highres_img.size)
                highres_bb = expand_bbox(highres_bb)
                page_table_imgs.append(highres_img.crop(highres_bb))
                highres_bbox.append(highres_bb)

            table_imgs.extend(page_table_imgs)

    table_preds = table_rec_predictor(table_imgs)

    img_idx = 0
    prev_count = 0
    table_predictions = defaultdict(list)
    for i in range(sum(table_counts)):
        while i >= prev_count + table_counts[img_idx]:
            prev_count += table_counts[img_idx]
            img_idx += 1

        pred = table_preds[i]
        orig_name = loader.names[img_idx]
        pnum = pnums[img_idx]
        table_img = table_imgs[i]

        out_pred = pred.model_dump()
        out_pred["page"] = pnum + 1
        table_idx = i - prev_count
        out_pred["table_idx"] = table_idx
        table_predictions[orig_name].append(out_pred)

        if loader.save_images:
            rows = [line.bbox for line in pred.rows]
            cols = [line.bbox for line in pred.cols]
            row_labels = [f"Row {line.row_id}" for line in pred.rows]
            col_labels = [f"Col {line.col_id}" for line in pred.cols]
            cells = [line.bbox for line in pred.cells]

            rc_image = copy.deepcopy(table_img)
            rc_image = draw_bboxes_on_image(
                rows, rc_image, labels=row_labels, label_font_size=20, color="blue"
            )
            rc_image = draw_bboxes_on_image(
                cols, rc_image, labels=col_labels, label_font_size=20, color="red"
            )
            rc_image.save(
                os.path.join(
                    loader.result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png"
                )
            )

            cell_image = copy.deepcopy(table_img)
            cell_image = draw_bboxes_on_image(cells, cell_image, color="green")
            cell_image.save(
                os.path.join(
                    loader.result_path,
                    f"{name}_page{pnum + 1}_table{table_idx}_cells.png",
                )
            )

    with open(
        os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
    ) as f:
        json.dump(table_predictions, f, ensure_ascii=False)

    logger.info(f"Wrote results to {loader.result_path}")

```
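
The same pipeline can be driven programmatically. A minimal sketch on a pre-cropped table image (the path is a placeholder), using only the calls and result fields that appear in the script above:

```python
from PIL import Image

from surya.table_rec import TableRecPredictor

# Placeholder: an image already cropped to a single table.
table_img = Image.open("table_crop.png").convert("RGB")

table_rec_predictor = TableRecPredictor()
pred = table_rec_predictor([table_img])[0]

# Rows, columns and cells each expose a bbox, as used for drawing above.
print([f"Row {row.row_id}: {row.bbox}" for row in pred.rows])
print([f"Col {col.col_id}: {col.bbox}" for col in pred.cols])
print(f"{len(pred.cells)} cells")
```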

--------------------------------------------------------------------------------
/CLA.md:
--------------------------------------------------------------------------------

```markdown
Surya Contributor Agreement

This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below. 

If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement.

1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project. 
2. With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution: 
   - you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty free, unrestricted license to exercise all rights under those copyrights. This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers; 
   - you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made) will be the sole owner of that derivative work; 
   - you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees; 
   - you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and 
   - you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of your contribution. 
3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to:
   - make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and
   - at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements. 
If you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed.
4. Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license. 
5. You covenant, represent, warrant and agree that: 
   - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA; 
   - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and 
   - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws.
You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA. 
6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply.
```

--------------------------------------------------------------------------------
/surya/scripts/texify_app.py:
--------------------------------------------------------------------------------

```python
import os
import re
from typing import List

from surya.recognition import RecognitionPredictor
from surya.foundation import FoundationPredictor
from surya.common.surya.schema import TaskNames

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = (
    "1"  # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS
)

import io

import pandas as pd
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import hashlib
import pypdfium2

from surya.settings import settings
from PIL import Image

MAX_WIDTH = 800
MAX_HEIGHT = 1000


def replace_fences(text):
    text = re.sub(r'<math display="block">(.*?)</math>', r"$$\1$$", text)
    text = re.sub(r"<math>(.*?)</math>", r"$\1$", text)
    text = re.sub(r'<math display="inline">(.*?)</math>', r"$\1$", text)
    return text


@st.cache_resource()
def load_predictor():
    foundation_predictor = FoundationPredictor()
    return RecognitionPredictor(foundation_predictor)


@st.cache_data()
def inference(pil_image: Image.Image, bbox: List[float]):
    input_img = pil_image.crop(bbox)
    bbox = [0, 0, input_img.width, input_img.height]
    model_output = predictor(
        [input_img], [TaskNames.block_without_boxes], bboxes=[[bbox]]
    )
    return model_output[0].text_lines[0].text


def open_pdf(pdf_file):
    stream = io.BytesIO(pdf_file.getvalue())
    return pypdfium2.PdfDocument(stream)


@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI_HIGHRES):
    doc = open_pdf(pdf_file)
    renderer = doc.render(
        pypdfium2.PdfBitmap.to_pil,
        page_indices=[page_num - 1],
        scale=dpi / 72,
    )
    png = list(renderer)[0]
    png_image = png.convert("RGB")
    doc.close()
    return png_image


@st.cache_data()
def page_counter(pdf_file):
    doc = open_pdf(pdf_file)
    doc_len = len(doc)
    doc.close()
    return doc_len


def resize_image(pil_image):
    if pil_image is None:
        return
    pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)


def get_canvas_hash(pil_image):
    return hashlib.md5(pil_image.tobytes()).hexdigest()


st.set_page_config(layout="wide")

top_message = """### LaTeX OCR

After the model loads, upload an image or a pdf, then draw a box around the equation or text you want to OCR by clicking and dragging. Surya will convert it to Markdown with LaTeX math on the right.
"""

st.markdown(top_message)
col1, col2 = st.columns([0.7, 0.3])

predictor = load_predictor()

in_file = st.sidebar.file_uploader(
    "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]
)
if in_file is None:
    st.stop()

filetype = in_file.type
page_count = None
if "pdf" in filetype:
    page_count = page_counter(in_file)
    page_number = st.sidebar.number_input(
        f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count
    )
    pil_image = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
else:
    pil_image = Image.open(in_file).convert("RGB")
    page_number = None

if pil_image is None:
    st.stop()

pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
canvas_hash = get_canvas_hash(pil_image)

with col1:
    # Create a canvas component
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.1)",  # Fixed fill color with some opacity
        stroke_width=1,
        stroke_color="#FFAA00",
        background_color="#FFF",
        background_image=pil_image,
        update_streamlit=True,
        height=pil_image.height,
        width=pil_image.width,
        drawing_mode="rect",
        point_display_radius=0,
        key=canvas_hash,
    )

if not canvas_result.json_data:
    st.stop()

objects = pd.json_normalize(
    canvas_result.json_data["objects"]
)  # need to convert obj to str because PyArrow
bbox_list = None
if objects.shape[0] > 0:
    boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]]
    boxes["right"] = boxes["left"] + boxes["width"]
    boxes["bottom"] = boxes["top"] + boxes["height"]
    bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist()

if bbox_list:
    with col2:
        texts = [inference(pil_image, bbox) for bbox in bbox_list]
        for idx, latex in enumerate(reversed(texts)):
            st.markdown(f"### {len(texts) - idx}")
            st.markdown(replace_fences(latex), unsafe_allow_html=True)
            st.code(latex)
            st.divider()

with col2:
    tips = """
    ### Usage tips
    - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple.
    """
    st.markdown(tips)

```
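
The fence replacement applied before rendering is a plain regex substitution. Below is a self-contained sketch of the same transformations, reimplemented here so it runs without Streamlit; the sample string is made up.

```python
import re


def replace_fences(text):
    # Same substitutions as in the app: block math -> $$...$$, inline math -> $...$
    text = re.sub(r'<math display="block">(.*?)</math>', r"$$\1$$", text)
    text = re.sub(r"<math>(.*?)</math>", r"$\1$", text)
    text = re.sub(r'<math display="inline">(.*?)</math>', r"$\1$", text)
    return text


sample = 'Area: <math display="block">\\pi r^2</math> and inline <math>r</math>.'
print(replace_fences(sample))
# Area: $$\pi r^2$$ and inline $r$.
```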

--------------------------------------------------------------------------------
/surya/scripts/finetune_ocr.py:
--------------------------------------------------------------------------------

```python
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Tuple
from datasets import load_dataset
import numpy as np
import torch
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    Trainer,
)

from surya.common.surya import SuryaModel
from surya.common.surya.processor import SuryaOCRProcessor
from surya.foundation import FoundationPredictor
from surya.common.surya.processor.schema import ImageInput, TextInput
from surya.common.surya.schema import TaskNames
from surya.common.util import get_top_scripts, SCRIPT_TOKEN_MAPPING

# Do not change these defaults
OCR_TASK_NAME = TaskNames.ocr_with_boxes
OCR_MAX_IMAGE_SIZE = (1024, 512)

# Simple wrapper for huggingface dataset
class SuryaOCRDataset(torch.utils.data.Dataset):
    def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments):
        super().__init__()
        self.hf_dataset = load_dataset(data_args.dataset_name, num_proc=data_args.num_loading_proc, split="train")
        self.processor = processor

    def __len__(self):
        return len(self.hf_dataset)

    def get_script_text(self, text: str) -> str:
        scripts = get_top_scripts(text)
        script_text = "".join(SCRIPT_TOKEN_MAPPING[script] for script in scripts)
        return script_text

    def __getitem__(self, index):
        try:
            data = self.hf_dataset[index]
            image = data["image"]
            image = image.convert("RGB")
            image = np.asarray(image, dtype=np.float32)
            image = self.processor.scale_to_fit(image, max_size=OCR_MAX_IMAGE_SIZE)

            # Add in script information
            gt_text = data["text"]
            gt_text = self.get_script_text(gt_text) + gt_text

            return_dict = {
                "task": TaskNames.ocr_with_boxes,
                "inputs": [
                    ImageInput(type="image", image=image, rotated=False),
                    # This empty TextInput **must be included** to match the original format
                    TextInput(type="text", text=""),
                    TextInput(type="text",text=gt_text),
                ],
            }
            return return_dict
        except Exception:
            import traceback; traceback.print_exc()
            return self.__getitem__((index + 1) % self.__len__())

class SuryaOCRDataCollator:
    def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments):
        self.processor = processor
        self.max_sequence_length = data_args.max_sequence_length

    def __call__(self, inputs):
        # Use right padding for training. Defaults to left for inference
        processed_batch = self.processor(inputs, padding_side="right")
        
        if self.max_sequence_length is not None:
            processed_batch["input_ids"] = processed_batch["input_ids"][:, :self.max_sequence_length]
            processed_batch["attention_mask"] = processed_batch["attention_mask"][:, :self.max_sequence_length]
            processed_batch["position_ids"] = processed_batch["position_ids"][:, :self.max_sequence_length]

        lm_labels = processed_batch["input_ids"].clone()
        skip_label_mask = (
            (lm_labels == self.processor.pad_token_id )
            | (lm_labels == self.processor.bos_token_id[TaskNames.ocr_with_boxes])
            | (lm_labels == self.processor.eoi_token_id)
            | (lm_labels == self.processor.image_token_id)
        )
        lm_labels[skip_label_mask] = -100
        processed_batch["labels"] = lm_labels

        return processed_batch

def load_model_and_processor(checkpoint_path: Optional[str] = None) -> Tuple[SuryaModel, SuryaOCRProcessor]:
    foundation_predictor = FoundationPredictor(checkpoint=checkpoint_path)
    return foundation_predictor.model, foundation_predictor.processor

@dataclass
class SuryaOCRModelArguments:
    pretrained_checkpoint_path: Optional[str] = field(default=None)

@dataclass
class SuryaOCRDataArguments:
    dataset_name: str = field(default="datalab-to/ocr_finetune_example")
    num_loading_proc: int = field(default=16)
    max_sequence_length: Optional[int] = field(default=None)

@dataclass
class SuryaOCRTrainingArguments(TrainingArguments):
    remove_unused_columns: bool = field(default=False)
    
def main():
    parser = HfArgumentParser((SuryaOCRModelArguments, SuryaOCRDataArguments, SuryaOCRTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model, processor = load_model_and_processor(model_args.pretrained_checkpoint_path)
    dataset = SuryaOCRDataset(processor, data_args)
    collator = SuryaOCRDataCollator(processor, data_args)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collator
    )

    trainer.train()

if __name__ == "__main__":
    main()
```
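
The collator's label masking can be illustrated with toy token ids: any position holding a pad, task BOS, end-of-image, or image token is set to -100 so the loss ignores it. A self-contained sketch with made-up ids (the real values come from `SuryaOCRProcessor`):

```python
import torch

# Made-up special token ids for illustration only.
PAD, BOS, EOI, IMG = 0, 1, 2, 3

input_ids = torch.tensor([[BOS, IMG, IMG, EOI, 17, 23, 42, PAD, PAD]])

labels = input_ids.clone()
skip_label_mask = (
    (labels == PAD) | (labels == BOS) | (labels == EOI) | (labels == IMG)
)
labels[skip_label_mask] = -100  # ignored by the cross-entropy loss

print(labels)  # tensor([[-100, -100, -100, -100, 17, 23, 42, -100, -100]])
```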