This is page 1 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/static/fonts/.gitignore:
--------------------------------------------------------------------------------

```
1 | *
2 | !.gitignore
```

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------

```yaml
 1 | repos:
 2 | - repo: https://github.com/astral-sh/ruff-pre-commit
 3 |   # Ruff version.
 4 |   rev: v0.9.10
 5 |   hooks:
 6 |     # Run the linter.
 7 |     - id: ruff
 8 |       types_or: [ python, pyi ]
 9 |       args: [ --fix ]
10 |     # Run the formatter.
11 |     - id: ruff-format
12 |       types_or: [ python, pyi ]
```

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

```
  1 | private.py
  2 | .DS_Store
  3 | local.env
  4 | experiments
  5 | test_data
  6 | training
  7 | wandb
  8 | notebooks
  9 | results
 10 | data
 11 | slices
 12 | 
 13 | # Byte-compiled / optimized / DLL files
 14 | __pycache__/
 15 | *.py[cod]
 16 | *$py.class
 17 | 
 18 | # C extensions
 19 | *.so
 20 | 
 21 | # Distribution / packaging
 22 | .Python
 23 | build/
 24 | develop-eggs/
 25 | dist/
 26 | downloads/
 27 | eggs/
 28 | .eggs/
 29 | lib/
 30 | lib64/
 31 | parts/
 32 | sdist/
 33 | var/
 34 | wheels/
 35 | share/python-wheels/
 36 | *.egg-info/
 37 | .installed.cfg
 38 | *.egg
 39 | MANIFEST
 40 | 
 41 | # PyInstaller
 42 | #  Usually these files are written by a python script from a template
 43 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 44 | *.manifest
 45 | *.spec
 46 | 
 47 | # Installer logs
 48 | pip-log.txt
 49 | pip-delete-this-directory.txt
 50 | 
 51 | # Unit test / coverage reports
 52 | htmlcov/
 53 | .tox/
 54 | .nox/
 55 | .coverage
 56 | .coverage.*
 57 | .cache
 58 | nosetests.xml
 59 | coverage.xml
 60 | *.cover
 61 | *.py,cover
 62 | .hypothesis/
 63 | .pytest_cache/
 64 | cover/
 65 | 
 66 | # Translations
 67 | *.mo
 68 | *.pot
 69 | 
 70 | # Django stuff:
 71 | *.log
 72 | local_settings.py
 73 | db.sqlite3
 74 | db.sqlite3-journal
 75 | 
 76 | # Flask stuff:
 77 | instance/
 78 | .webassets-cache
 79 | 
 80 | # Scrapy stuff:
 81 | .scrapy
 82 | 
 83 | # Sphinx documentation
 84 | docs/_build/
 85 | 
 86 | # PyBuilder
 87 | .pybuilder/
 88 | target/
 89 | 
 90 | # Jupyter Notebook
 91 | .ipynb_checkpoints
 92 | 
 93 | # IPython
 94 | profile_default/
 95 | ipython_config.py
 96 | 
 97 | # pyenv
 98 | #   For a library or package, you might want to ignore these files since the code is
 99 | #   intended to run in multiple environments; otherwise, check them in:
100 | # .python-version
101 | 
102 | # pipenv
103 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | #   install all needed dependencies.
107 | #Pipfile.lock
108 | 
109 | # poetry
110 | #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
111 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
112 | #   commonly ignored for libraries.
113 | #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
114 | #poetry.lock
115 | 
116 | # pdm
117 | #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
118 | #pdm.lock
119 | #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
120 | #   in version control.
121 | #   https://pdm.fming.dev/#use-with-ide
122 | .pdm.toml
123 | 
124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
125 | __pypackages__/
126 | 
127 | # Celery stuff
128 | celerybeat-schedule
129 | celerybeat.pid
130 | 
131 | # SageMath parsed files
132 | *.sage.py
133 | 
134 | # Environments
135 | .env
136 | .venv
137 | env/
138 | venv/
139 | ENV/
140 | env.bak/
141 | venv.bak/
142 | 
143 | # Spyder project settings
144 | .spyderproject
145 | .spyproject
146 | 
147 | # Rope project settings
148 | .ropeproject
149 | 
150 | # mkdocs documentation
151 | /site
152 | 
153 | # mypy
154 | .mypy_cache/
155 | .dmypy.json
156 | dmypy.json
157 | 
158 | # Pyre type checker
159 | .pyre/
160 | 
161 | # pytype static type analyzer
162 | .pytype/
163 | 
164 | # Cython debug symbols
165 | cython_debug/
166 | 
167 | # PyCharm
168 | #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
169 | #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
170 | #  and can be added to the global gitignore or merged into this file.  For a more nuclear
171 | #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
172 | .idea/
173 | 
```

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

```markdown
  1 | # Surya
  2 | 
  3 | Surya is a document OCR toolkit that does:
  4 | 
  5 | - OCR in 90+ languages that benchmarks favorably vs cloud services
  6 | - Line-level text detection in any language
  7 | - Layout analysis (table, image, header, etc detection)
  8 | - Reading order detection
  9 | - Table recognition (detecting rows/columns)
 10 | - LaTeX OCR
 11 | 
 12 | It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
 13 | 
 14 | For our managed API or on-prem document intelligence solution, check out [our platform here](https://datalab.to?utm_source=gh-surya).
 15 | 
 16 | 
 17 | |                            Detection                             |                                   OCR                                   |
 18 | |:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|
 19 | |  <img src="static/images/excerpt.png" width="500px"/>  |  <img src="static/images/excerpt_text.png" width="500px"/> |
 20 | 
 21 | |                               Layout                               |                               Reading Order                                |
 22 | |:------------------------------------------------------------------:|:--------------------------------------------------------------------------:|
 23 | | <img src="static/images/excerpt_layout.png" width="500px"/> | <img src="static/images/excerpt_reading.jpg" width="500px"/> |
 24 | 
 25 | |                       Table Recognition                       |                       LaTeX OCR                        |
 26 | |:-------------------------------------------------------------:|:------------------------------------------------------:|
 27 | | <img src="static/images/scanned_tablerec.png" width="500px"/> | <img src="static/images/latex_ocr.png" width="500px"/> |
 28 | 
 29 | 
 30 | Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.
 31 | 
 32 | ## Community
 33 | 
 34 | [Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development.
 35 | 
 36 | ## Examples
 37 | 
 38 | | Name             |              Detection              |                                      OCR |                                     Layout |                                       Order |                                    Table Rec |
 39 | |------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:|
 40 | | Japanese         | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) |
 41 | | Chinese          | [Image](static/images/chinese.jpg)  |  [Image](static/images/chinese_text.jpg) |  [Image](static/images/chinese_layout.jpg) |  [Image](static/images/chinese_reading.jpg) |                                              |
 42 | | Hindi            |  [Image](static/images/hindi.jpg)   |    [Image](static/images/hindi_text.jpg) |    [Image](static/images/hindi_layout.jpg) |    [Image](static/images/hindi_reading.jpg) |                                              |
 43 | | Arabic           |  [Image](static/images/arabic.jpg)  |   [Image](static/images/arabic_text.jpg) |   [Image](static/images/arabic_layout.jpg) |   [Image](static/images/arabic_reading.jpg) |                                              |
 44 | | Chinese + Hindi  | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) |                                              |
 45 | | Presentation     |   [Image](static/images/pres.png)   |     [Image](static/images/pres_text.jpg) |     [Image](static/images/pres_layout.jpg) |     [Image](static/images/pres_reading.jpg) |     [Image](static/images/pres_tablerec.png) |
 46 | | Scientific Paper |  [Image](static/images/paper.jpg)   |    [Image](static/images/paper_text.jpg) |    [Image](static/images/paper_layout.jpg) |    [Image](static/images/paper_reading.jpg) |    [Image](static/images/paper_tablerec.png) |
 47 | | Scanned Document | [Image](static/images/scanned.png)  |  [Image](static/images/scanned_text.jpg) |  [Image](static/images/scanned_layout.jpg) |  [Image](static/images/scanned_reading.jpg) |  [Image](static/images/scanned_tablerec.png) |
 48 | | New York Times   |   [Image](static/images/nyt.jpg)    |      [Image](static/images/nyt_text.jpg) |      [Image](static/images/nyt_layout.jpg) |        [Image](static/images/nyt_order.jpg) |                                              |
 49 | | Scanned Form     |  [Image](static/images/funsd.png)   |    [Image](static/images/funsd_text.jpg) |    [Image](static/images/funsd_layout.jpg) |    [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) |
 50 | | Textbook         | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) |   [Image](static/images/textbook_order.jpg) |                                              |
 51 | 
 52 | # Hosted API
 53 | 
 54 | There is a hosted API for all surya models available [here](https://www.datalab.to?utm_source=gh-surya):
 55 | 
 56 | - Works with PDFs, images, Word docs, and PowerPoints
 57 | - Consistent speed, with no latency spikes
 58 | - High reliability and uptime
 59 | 
 60 | # Commercial usage
 61 | 
 62 | Our model weights use a modified AI Pubs Open Rail-M license (free for research, personal use, and startups under $2M funding/revenue) and our code is GPL. For broader commercial licensing or to remove GPL requirements, visit our pricing page [here](https://www.datalab.to/pricing?utm_source=gh-surya).
 63 | 
 64 | 
 65 | # Installation
 66 | 
 67 | You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine.  See [here](https://pytorch.org/get-started/locally/) for more details.
 68 | 
 69 | Install with:
 70 | 
 71 | ```shell
 72 | pip install surya-ocr
 73 | ```
 74 | 
 75 | Model weights will automatically download the first time you run surya.
 76 | 
 77 | # Usage
 78 | 
 79 | - Inspect the settings in `surya/settings.py`.  You can override any settings with environment variables.
 80 | - Your torch device will be automatically detected, but you can override this.  For example, `TORCH_DEVICE=cuda`.
 81 | 
 82 | ## Interactive App
 83 | 
 84 | I've included a streamlit app that lets you interactively try Surya on images or PDF files.  Run it with:
 85 | 
 86 | ```shell
 87 | pip install streamlit pdftext
 88 | surya_gui
 89 | ```
 90 | 
 91 | ## OCR (text recognition)
 92 | 
 93 | This command will write out a json file with the detected text and bboxes:
 94 | 
 95 | ```shell
 96 | surya_ocr DATA_PATH
 97 | ```
 98 | 
 99 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
100 | - `--task_name` will specify which task to use for predicting the lines.  `ocr_with_boxes` is the default, which will format text and give you bboxes.  If you get bad performance, try `ocr_without_boxes`, which will give you potentially better performance but no bboxes.  For blocks like equations and paragraphs, try `block_without_boxes`.
101 | - `--images` will save images of the pages and detected text lines (optional)
102 | - `--output_dir` specifies the directory to save results to instead of the default
103 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
104 | - `--disable_math` - by default, surya will recognize math in text.  This can lead to false positives - you can disable this with this flag.
105 | 
106 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:
107 | 
108 | - `text_lines` - the detected text and bounding boxes for each line
109 |   - `text` - the text in the line
110 |   - `confidence` - the confidence of the model in the detected text (0-1)
111 |   - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  The points are in clockwise order from the top left.
112 |   - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
113 |   - `chars` - the individual characters in the line
114 |     - `text` - the text of the character
115 |     - `bbox` - the character bbox (same format as line bbox)
116 |     - `polygon` - the character polygon (same format as line polygon)
117 |     - `confidence` - the confidence of the model in the detected character (0-1)
118 |     - `bbox_valid` - if the character is a special token or math, the bbox may not be valid
119 |   - `words` - the individual words in the line (computed from the characters)
120 |     - `text` - the text of the word
121 |     - `bbox` - the word bbox (same format as line bbox)
122 |     - `polygon` - the word polygon (same format as line polygon)
123 |     - `confidence` - mean character confidence
124 |     - `bbox_valid` - if the word is a special token or math, the bbox may not be valid
125 | - `page` - the page number in the file
126 | - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.
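
Here is a minimal sketch of reading that output with the standard library (the `"myfile"` key is a placeholder - use the name of your own input file):

```python
import json

# Minimal sketch: the keys of results.json are input filenames without extensions.
with open("results.json") as f:
    results = json.load(f)

for page in results["myfile"]:  # "myfile" is a placeholder key
    for line in page["text_lines"]:
        print(round(line["confidence"], 2), line["bbox"], line["text"])
```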
127 | 
128 | **Performance tips**
129 | 
130 | Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `40MB` of VRAM, so very high batch sizes are possible.  The default is a batch size of `512`, which will use about 20GB of VRAM.  Adjusting the batch size can also help on CPU, depending on your core count - the default CPU batch size is `32`.
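
For example, to run OCR with a smaller batch size on a GPU with less VRAM (the value here is illustrative - size it to your hardware):

```shell
RECOGNITION_BATCH_SIZE=256 surya_ocr DATA_PATH
```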
131 | 
132 | ### From python
133 | 
134 | ```python
135 | from PIL import Image
136 | from surya.foundation import FoundationPredictor
137 | from surya.recognition import RecognitionPredictor
138 | from surya.detection import DetectionPredictor
139 | 
140 | image = Image.open(IMAGE_PATH)
141 | foundation_predictor = FoundationPredictor()
142 | recognition_predictor = RecognitionPredictor(foundation_predictor)
143 | detection_predictor = DetectionPredictor()
144 | 
145 | predictions = recognition_predictor([image], det_predictor=detection_predictor)
146 | ```
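
Continuing from the snippet above, a hedged sketch of using the predictions - this assumes the returned objects mirror the JSON fields described earlier (`text_lines` with `text`, `confidence`, and `bbox`):

```python
# Assumes each prediction exposes text_lines with text/confidence/bbox,
# mirroring the results.json fields described above - check your version.
for page in predictions:
    for line in page.text_lines:
        print(round(line.confidence, 2), line.bbox, line.text)
```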
147 | 
148 | 
149 | ## Text line detection
150 | 
151 | This command will write out a json file with the detected bboxes.
152 | 
153 | ```shell
154 | surya_detect DATA_PATH
155 | ```
156 | 
157 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
158 | - `--images` will save images of the pages and detected text lines (optional)
159 | - `--output_dir` specifies the directory to save results to instead of the default
160 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
161 | 
162 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:
163 | 
164 | - `bboxes` - detected bounding boxes for text
165 |   - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
166 |   - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  The points are in clockwise order from the top left.
167 |   - `confidence` - the confidence of the model in the detected text (0-1)
168 | - `vertical_lines` - vertical lines detected in the document
169 |   - `bbox` - the axis-aligned line coordinates.
170 | - `page` - the page number in the file
171 | - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.
172 | 
173 | **Performance tips**
174 | 
175 | Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `440MB` of VRAM, so very high batch sizes are possible.  The default is a batch size of `36`, which will use about 16GB of VRAM.  Adjusting the batch size can also help on CPU, depending on your core count - the default CPU batch size is `6`.
176 | 
177 | ### From python
178 | 
179 | ```python
180 | from PIL import Image
181 | from surya.detection import DetectionPredictor
182 | 
183 | image = Image.open(IMAGE_PATH)
184 | det_predictor = DetectionPredictor()
185 | 
186 | # predictions is a list of dicts, one per image
187 | predictions = det_predictor([image])
188 | ```
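
Continuing from the snippet above, a minimal sketch of reading the results - this assumes each result mirrors the fields described above (`bboxes` with `bbox`, `polygon`, and `confidence`, plus `image_bbox`):

```python
# Assumes each result exposes image_bbox plus bboxes with bbox/polygon/confidence,
# mirroring the results.json fields described above.
for result in predictions:
    print(result.image_bbox)
    for box in result.bboxes:
        print(box.confidence, box.bbox)
```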
189 | 
190 | ## Layout and reading order
191 | 
192 | This command will write out a json file with the detected layout and reading order.
193 | 
194 | ```shell
195 | surya_layout DATA_PATH
196 | ```
197 | 
198 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
199 | - `--images` will save images of the pages and detected text lines (optional)
200 | - `--output_dir` specifies the directory to save results to instead of the default
201 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
202 | 
203 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:
204 | 
205 | - `bboxes` - detected bounding boxes for text
206 |   - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
207 |   - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format.  The points are in clockwise order from the top left.
208 |   - `position` - the reading order of the box.
209 |   - `label` - the label for the bbox.  One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`.
210 |   - `top_k` - the top-k other potential labels for the box.  A dictionary with labels as keys and confidences as values.
211 | - `page` - the page number in the file
212 | - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.
213 | 
214 | **Performance tips**
215 | 
216 | Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `220MB` of VRAM, so very high batch sizes are possible.  The default is a batch size of `32`, which will use about 7GB of VRAM.  Adjusting the batch size can also help on CPU, depending on your core count - the default CPU batch size is `4`.
217 | 
218 | ### From python
219 | 
220 | ```python
221 | from PIL import Image
222 | from surya.foundation import FoundationPredictor
223 | from surya.layout import LayoutPredictor
224 | from surya.settings import settings
225 | 
226 | image = Image.open(IMAGE_PATH)
227 | layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))
228 | 
229 | # layout_predictions is a list of dicts, one per image
230 | layout_predictions = layout_predictor([image])
231 | ```
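
Continuing from the snippet above, a short sketch of walking the layout in reading order - this assumes each `LayoutBox` carries the `label`, `position`, and `bbox` fields described in the output format above:

```python
# Assumes LayoutBox exposes label, position (reading order), and bbox,
# matching the results.json fields described above.
for result in layout_predictions:
    for box in sorted(result.bboxes, key=lambda b: b.position):
        print(box.position, box.label, box.bbox)
```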
232 | 
233 | ## Table Recognition
234 | 
235 | This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes.  If you want to get cell positions and text, along with nice formatting, check out the [marker](https://www.github.com/VikParuchuri/marker) repo.  You can use the `TableConverter` to detect and extract tables in images and PDFs.  It supports output in json (with bboxes), markdown, and html.
236 | 
237 | ```shell
238 | surya_table DATA_PATH
239 | ```
240 | 
241 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
242 | - `--images` will save images of the pages and detected table cells + rows and columns (optional)
243 | - `--output_dir` specifies the directory to save results to instead of the default
244 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
245 | - `--detect_boxes` specifies if cells should be detected.  By default, they're pulled out of the PDF, but this is not always possible.
246 | - `--skip_table_detection` tells table recognition not to detect tables first.  Use this if your image is already cropped to a table.
247 | 
248 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  Each page dictionary contains:
249 | 
250 | - `rows` - detected table rows
251 |   - `bbox` - the bounding box of the table row
252 |   - `row_id` - the id of the row
253 |   - `is_header` - if it is a header row.
254 | - `cols` - detected table columns
255 |   - `bbox` - the bounding box of the table column
256 |   - `col_id`- the id of the column
257 |   - `is_header` - if it is a header column
258 | - `cells` - detected table cells
259 |   - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
260 |   - `text` - if text could be pulled out of the pdf, the text of this cell.
261 |   - `row_id` - the id of the row the cell belongs to.
262 |   - `col_id` - the id of the column the cell belongs to.
263 |   - `colspan` - the number of columns spanned by the cell.
264 |   - `rowspan` - the number of rows spanned by the cell.
265 |   - `is_header` - whether it is a header cell.
266 | - `page` - the page number in the file
267 | - `table_idx` - the index of the table on the page (sorted in vertical order)
268 | - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format.  (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.  All line bboxes will be contained within this bbox.
269 | 
270 | **Performance tips**
271 | 
272 | Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU.  Each batch item will use `150MB` of VRAM, so very high batch sizes are possible.  The default is a batch size of `64`, which will use about 10GB of VRAM.  Adjusting the batch size can also help on CPU, depending on your core count - the default CPU batch size is `8`.
273 | 
274 | ### From python
275 | 
276 | ```python
277 | from PIL import Image
278 | from surya.table_rec import TableRecPredictor
279 | 
280 | image = Image.open(IMAGE_PATH)
281 | table_rec_predictor = TableRecPredictor()
282 | 
283 | table_predictions = table_rec_predictor([image])
284 | ```
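
Continuing from the snippet above, a hedged sketch of reading the tables - this assumes the result objects mirror the `rows`/`cols`/`cells` fields described in the output format above:

```python
# Assumes each result exposes rows, cols, and cells with row/col ids and bboxes,
# mirroring the results.json fields described above.
for table in table_predictions:
    print(len(table.rows), "rows x", len(table.cols), "columns")
    for cell in table.cells:
        print(cell.row_id, cell.col_id, cell.bbox)
```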
285 | 
286 | ## LaTeX OCR
287 | 
288 | This command will write out a json file with the LaTeX of the equations.  You must pass in images that are already cropped to the equations.  You can do this by running the layout model, then cropping, if you want.
289 | 
290 | ```shell
291 | surya_latex_ocr DATA_PATH
292 | ```
293 | 
294 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
295 | - `--output_dir` specifies the directory to save results to instead of the default
296 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
297 | 
298 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions.  Each value will be a list of dictionaries, one per page of the input document.  See the OCR section above for the format of the output.
299 | 
300 | ### From python
301 | 
302 | ```python
303 | from PIL import Image
304 | from surya.texify import TexifyPredictor
305 | 
306 | image = Image.open(IMAGE_PATH)
307 | predictor = TexifyPredictor()
308 | 
309 | predictor([image])
310 | ```
311 | 
312 | ### Interactive app
313 | 
314 | You can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with:
315 | 
316 | ```shell
317 | pip install streamlit==1.40 streamlit-drawable-canvas-jsretry
318 | texify_gui
319 | ```
320 | 
321 | ## Compilation
322 | 
323 | The following models have support for compilation. You will need to set the following environment variables to enable compilation:
324 | 
325 | - Detection: `COMPILE_DETECTOR=true`
326 | - Layout: `COMPILE_LAYOUT=true`
327 | - Table recognition: `COMPILE_TABLE_REC=true`
328 | 
329 | Alternatively, you can also set `COMPILE_ALL=true` which will compile all models.
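
For example, compilation can be enabled for a single run by setting the variable in your shell (illustrative invocation):

```shell
COMPILE_DETECTOR=true surya_detect DATA_PATH
```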
330 | 
331 | Here are the speedups on an A10 GPU:
332 | 
333 | | Model             | Time per page (s) | Compiled time per page (s) | Speedup (%) |
334 | | ----------------- | ----------------- | -------------------------- | ----------- |
335 | | Detection         | 0.108808          | 0.10521                    | 3.306742151 |
336 | | Layout            | 0.27319           | 0.27063                    | 0.93707676  |
337 | | Table recognition | 0.0219            | 0.01938                    | 11.50684932 |
338 | 
339 | # Limitations
340 | 
341 | - This is specialized for document OCR.  It will likely not work on photos or other images.
342 | - It is for printed text, not handwriting (though it may work on some handwriting).
343 | - The text detection model has trained itself to ignore advertisements.
344 | - You can find language support for OCR in `surya/recognition/languages.py`.  Text detection, layout analysis, and reading order will work with any language.
345 | 
346 | ## Troubleshooting
347 | 
348 | If OCR isn't working properly:
349 | 
350 | - Try increasing resolution of the image so the text is bigger.  If the resolution is already very high, try decreasing it to no more than a `2048px` width.
351 | - Preprocessing the image (binarizing, deskewing, etc) can help with very old/blurry images.
352 | - You can adjust `DETECTOR_BLANK_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results.  `DETECTOR_BLANK_THRESHOLD` controls the space between lines - any prediction below this number will be considered blank space.  `DETECTOR_TEXT_THRESHOLD` controls how text is joined - any number above this is considered text.  `DETECTOR_TEXT_THRESHOLD` should always be higher than `DETECTOR_BLANK_THRESHOLD`, and both should be in the 0-1 range.  Looking at the heatmap from the debug output of the detector can tell you how to adjust these (if you see faint things that look like boxes, lower the thresholds, and if you see bboxes being joined together, raise the thresholds).
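
For example, if lines are being joined together, you could raise both thresholds slightly (the values below are illustrative starting points, not tuned defaults):

```shell
DETECTOR_BLANK_THRESHOLD=0.4 DETECTOR_TEXT_THRESHOLD=0.7 surya_detect DATA_PATH
```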
353 | 
354 | # Manual install
355 | 
356 | If you want to develop surya, you can install it manually:
357 | 
358 | - `git clone https://github.com/VikParuchuri/surya.git`
359 | - `cd surya`
360 | - `poetry install` - installs main and dev dependencies
361 | - `poetry shell` - activates the virtual environment
362 | 
363 | # Benchmarks
364 | 
365 | ## OCR
366 | 
367 | ![Benchmark chart tesseract](static/images/benchmark_rec_chart.png)
368 | 
369 | | Model     | Time per page (s) | Avg similarity (⬆) |
370 | |-----------|-------------------|--------------------|
371 | | surya     | .62               | 0.97               |
372 | | tesseract | .45               | 0.88               |
373 | 
374 | [Full language results](static/images/rec_acc_table.png)
375 | 
376 | Tesseract is CPU-based, and surya is CPU or GPU.  I tried to cost-match the resources used, so I used a 1xA6000 (48GB VRAM) for surya, and 28 CPU cores for Tesseract (same price on Lambda Labs/DigitalOcean).
377 | 
378 | ### Google Cloud Vision
379 | 
380 | I benchmarked OCR against Google Cloud Vision since it has similar language coverage to Surya.
381 | 
382 | ![Benchmark chart google cloud](static/images/gcloud_rec_bench.png)
383 | 
384 | [Full language results](static/images/gcloud_full_langs.png)
385 | 
386 | **Methodology**
387 | 
388 | I measured normalized sentence similarity (0-1, higher is better) based on a set of real-world and synthetic pdfs.  I sampled PDFs from common crawl, then filtered out the ones with bad OCR.  I couldn't find PDFs for some languages, so I also generated simple synthetic PDFs for those.
389 | 
390 | I used the reference line bboxes from the PDFs with both tesseract and surya, to just evaluate the OCR quality.
391 | 
392 | For Google Cloud, I aligned the output from Google Cloud with the ground truth.  I had to skip RTL languages since they didn't align well.
393 | 
394 | ## Text line detection
395 | 
396 | ![Benchmark chart](static/images/benchmark_chart_small.png)
397 | 
398 | | Model     | Time (s)   | Time per page (s)   | precision   |   recall |
399 | |-----------|------------|---------------------|-------------|----------|
400 | | surya     | 47.2285    | 0.094452            | 0.835857    | 0.960807 |
401 | | tesseract | 74.4546    | 0.290838            | 0.631498    | 0.997694 |
402 | 
403 | 
404 | Tesseract is CPU-based, and surya is CPU or GPU.  I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU.  This was the resource usage:
405 | 
406 | - tesseract - 32 CPU cores, or 8 workers using 4 cores each
407 | - surya - 36 batch size, for 16GB VRAM usage
408 | 
409 | **Methodology**
410 | 
411 | Surya predicts line-level bboxes, while tesseract and others predict word-level or character-level.  It's hard to find 100% correct datasets with line-level annotations. Merging bboxes can be noisy, so I chose not to use IoU as the metric for evaluation.
412 | 
413 | I instead used coverage, which calculates:
414 | 
415 | - Precision - how well the predicted bboxes cover ground truth bboxes
416 | - Recall - how well ground truth bboxes cover predicted bboxes
417 | 
418 | First calculate coverage for each bbox, then add a small penalty for double coverage, since we want the detection to have non-overlapping bboxes.  Anything with a coverage of 0.5 or higher is considered a match.
419 | 
420 | Then we calculate precision and recall for the whole dataset.
421 | 
422 | ## Layout analysis
423 | 
424 | | Layout Type   |   precision |   recall |
425 | |---------------|-------------|----------|
426 | | Image         |     0.91265 |  0.93976 |
427 | | List          |     0.80849 |  0.86792 |
428 | | Table         |     0.84957 |  0.96104 |
429 | | Text          |     0.93019 |  0.94571 |
430 | | Title         |     0.92102 |  0.95404 |
431 | 
432 | Time per image - .13 seconds on GPU (A10).
433 | 
434 | **Methodology**
435 | 
436 | I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/PubLayNet), which was not in the training data.  I had to align publaynet labels with the surya layout labels.  I was then able to find coverage for each layout type:
437 | 
438 | - Precision - how well the predicted bboxes cover ground truth bboxes
439 | - Recall - how well ground truth bboxes cover predicted bboxes
440 | 
441 | ## Reading Order
442 | 
443 | 88% mean accuracy, and .4 seconds per image on an A10 GPU.  See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check.
444 | 
445 | **Methodology**
446 | 
447 | I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data.  Unfortunately, this dataset is fairly noisy, and not all the labels are correct.  It was very hard to find a dataset annotated with reading order and also layout information.  I wanted to avoid using a cloud service for the ground truth.
448 | 
449 | The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct.
450 | 
451 | ## Table Recognition
452 | 
453 | | Model             |   Row Intersection |   Col Intersection |   Time Per Image |
454 | |-------------------|--------------------|--------------------|------------------|
455 | | Surya             |               1    |            0.98625 |          0.30202 |
456 | | Table transformer |               0.84 |            0.86857 |          0.08082 |
457 | 
458 | Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions.  This benchmark is mostly a sanity check - there is a more rigorous one in [marker](https://www.github.com/VikParuchuri/marker).
459 | 
460 | **Methodology**
461 | 
462 | The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM.  It has labeled rows and columns.  After table recognition is run, the predicted rows and columns are compared to the ground truth.  There is an additional penalty for predicting too many or too few rows/columns.
463 | 
464 | ## LaTeX OCR
465 | 
466 | | Method | edit ⬇   | time taken (s) ⬇ |
467 | |--------|----------|------------------|
468 | | texify | 0.122617 | 35.6345          |
469 | 
470 | This runs texify on a ground truth set of LaTeX, then computes the edit distance.  This is a bit noisy, since two LaTeX strings that render the same can contain different symbols.
471 | 
472 | ## Running your own benchmarks
473 | 
474 | You can benchmark the performance of surya on your machine.
475 | 
476 | - Follow the manual install instructions above.
477 | - `poetry install --group dev` - installs dev dependencies
478 | 
479 | **Text line detection**
480 | 
481 | This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).
482 | 
483 | ```shell
484 | python benchmark/detection.py --max_rows 256
485 | ```
486 | 
487 | - `--max_rows` controls how many images to process for the benchmark
488 | - `--debug` will render images and detected bboxes
489 | - `--pdf_path` will let you specify a pdf to benchmark instead of the default data
490 | - `--results_dir` will let you specify a directory to save results to instead of the default one
491 | 
492 | **Text recognition**
493 | 
494 | This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages).
495 | 
496 | ```shell
497 | python benchmark/recognition.py --tesseract
498 | ```
499 | 
500 | - `--max_rows` controls how many images to process for the benchmark
501 | - `--debug 2` will render images with detected text
502 | - `--results_dir` will let you specify a directory to save results to instead of the default one
503 | - `--tesseract` will run the benchmark with tesseract.  You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder.
504 | 
505 | - Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark.
506 | - Set `RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist` to use the historical document data for benchmarking.  This data comes from the [tapuscorpus](https://github.com/HTR-United/tapuscorpus).
507 | 
508 | **Layout analysis**
509 | 
510 | This will evaluate surya on the publaynet dataset.
511 | 
512 | ```shell
513 | python benchmark/layout.py
514 | ```
515 | 
516 | - `--max_rows` controls how many images to process for the benchmark
517 | - `--debug` will render images with detected text
518 | - `--results_dir` will let you specify a directory to save results to instead of the default one
519 | 
520 | **Reading Order**
521 | 
522 | ```shell
523 | python benchmark/ordering.py
524 | ```
525 | 
526 | - `--max_rows` controls how many images to process for the benchmark
527 | - `--debug` will render images with detected text
528 | - `--results_dir` will let you specify a directory to save results to instead of the default one
529 | 
530 | **Table Recognition**
531 | 
532 | ```shell
533 | python benchmark/table_recognition.py --max_rows 1024 --tatr
534 | ```
535 | 
536 | - `--max_rows` controls how many images to process for the benchmark
537 | - `--debug` will render images with detected text
538 | - `--results_dir` will let you specify a directory to save results to instead of the default one
539 | - `--tatr` specifies whether to also run table transformer
540 | 
541 | **LaTeX OCR**
542 | 
543 | ```shell
544 | python benchmark/texify.py --max_rows 128
545 | ```
546 | 
547 | - `--max_rows` controls how many images to process for the benchmark
548 | - `--results_dir` will let you specify a directory to save results to instead of the default one
549 | 
550 | # Training
551 | 
552 | Text detection was trained on 4x A6000s for 3 days.  It used a diverse set of images as training data.  It was trained from scratch using a modified efficientvit architecture for semantic segmentation.
553 | 
554 | Text recognition was trained on 4x A6000s for 2 weeks.  It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).
555 | 
556 | # Finetuning Surya OCR
557 | You can now take Surya OCR further by training it on your own data with our [finetuning script](/surya/scripts/finetune_ocr.py).
558 | It’s built on the Hugging Face Trainer and supports all of the [arguments](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) it provides, as well as integrations like torchrun and DeepSpeed.
559 | 
560 | To setup your dataset, follow the example dataset format [here](https://huggingface.co/datasets/datalab-to/ocr_finetune_example) and provide the path to your own dataset when launching the training script.
561 | ```bash
562 | # Tested on 1xH100 GPU
563 | # Set --pretrained_checkpoint_path to load from a custom checkpoint, otherwise
564 | # the default surya ocr weights will be loaded as the initialization
565 | python surya/scripts/finetune_ocr.py \
566 |   --output_dir $OUTPUT_DIR \
567 |   --dataset_name datalab-to/ocr_finetune_example \
568 |   --per_device_train_batch_size 64 \
569 |   --gradient_checkpointing true \
570 |   --max_sequence_length 1024
571 | ```
572 | 
573 | This is a minimal training script to get you started finetuning Surya. Our internal training stack includes character bounding box finetuning, sliding window attention with specialized attention masks, custom kernels, augmentations, and other optimizations that can push OCR accuracy well beyond standard finetuning. If you want to get the most out of your data, reach us at [email protected]!
574 | 
575 | # Thanks
576 | 
577 | This work would not have been possible without amazing open source AI work:
578 | 
579 | - [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA
580 | - [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT
581 | - [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman
582 | - [Donut](https://github.com/clovaai/donut) from Naver
583 | - [transformers](https://github.com/huggingface/transformers) from huggingface
584 | - [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model
585 | 
586 | Thank you to everyone who makes open source AI possible.
587 | 
588 | # Citation
589 | 
590 | If you use surya (or the associated models) in your work or research, please consider citing us using the following BibTeX entry:
591 | 
592 | ```bibtex
593 | @misc{paruchuri2025surya,
594 |   author       = {Vikas Paruchuri and Datalab Team},
595 |   title        = {Surya: A lightweight document OCR and analysis toolkit},
596 |   year         = {2025},
597 |   howpublished = {\url{https://github.com/VikParuchuri/surya}},
598 |   note         = {GitHub repository},
599 | }
600 | 
```

--------------------------------------------------------------------------------
/benchmark/utils/__init__.py:
--------------------------------------------------------------------------------

```python
1 | 
```

--------------------------------------------------------------------------------
/surya/__init__.py:
--------------------------------------------------------------------------------

```python
1 | 
```

--------------------------------------------------------------------------------
/surya/detection/model/__init__.py:
--------------------------------------------------------------------------------

```python
1 | 
```

--------------------------------------------------------------------------------
/surya/foundation/cache/__init__.py:
--------------------------------------------------------------------------------

```python
1 | 
```

--------------------------------------------------------------------------------
/surya/ocr_error/model/__init__.py:
--------------------------------------------------------------------------------

```python
1 | 
```

--------------------------------------------------------------------------------
/surya/scripts/__init__.py:
--------------------------------------------------------------------------------

```python
1 | 
```

--------------------------------------------------------------------------------
/surya/table_rec/model/__init__.py:
--------------------------------------------------------------------------------

```python
1 | 
```

--------------------------------------------------------------------------------
/surya/common/__init__.py:
--------------------------------------------------------------------------------

```python
1 | 
2 | 
3 | 
4 | 
```

--------------------------------------------------------------------------------
/ocr_text.py:
--------------------------------------------------------------------------------

```python
1 | from surya.scripts.ocr_text import ocr_text_cli
2 | 
3 | if __name__ == "__main__":
4 |     ocr_text_cli()
5 | 
```

--------------------------------------------------------------------------------
/ocr_latex.py:
--------------------------------------------------------------------------------

```python
1 | from surya.scripts.ocr_latex import ocr_latex_cli
2 | 
3 | if __name__ == "__main__":
4 |     ocr_latex_cli()
5 | 
```

--------------------------------------------------------------------------------
/texify_app.py:
--------------------------------------------------------------------------------

```python
1 | from surya.scripts.run_texify_app import texify_app_cli
2 | 
3 | if __name__ == "__main__":
4 |     texify_app_cli()
```

--------------------------------------------------------------------------------
/detect_layout.py:
--------------------------------------------------------------------------------

```python
1 | from surya.scripts.detect_layout import detect_layout_cli
2 | 
3 | if __name__ == "__main__":
4 |     detect_layout_cli()
5 | 
```

--------------------------------------------------------------------------------
/detect_text.py:
--------------------------------------------------------------------------------

```python
 1 | from surya.scripts.detect_text import detect_text_cli
 2 | 
 3 | if __name__ == "__main__":
 4 |     detect_text_cli()
 5 | 
 6 | 
 7 | 
 8 | 
 9 | 
10 | 
11 | 
12 | 
```

--------------------------------------------------------------------------------
/ocr_app.py:
--------------------------------------------------------------------------------

```python
1 | from surya.scripts.run_streamlit_app import streamlit_app_cli
2 | 
3 | if __name__ == "__main__":
4 |     streamlit_app_cli()
```

--------------------------------------------------------------------------------
/table_recognition.py:
--------------------------------------------------------------------------------

```python
1 | from surya.scripts.table_recognition import table_recognition_cli
2 | 
3 | if __name__ == "__main__":
4 |     table_recognition_cli()
```

--------------------------------------------------------------------------------
/surya/ocr_error/schema.py:
--------------------------------------------------------------------------------

```python
1 | from typing import List
2 | 
3 | from pydantic import BaseModel
4 | 
5 | 
6 | class OCRErrorDetectionResult(BaseModel):
7 |     texts: List[str]
8 |     labels: List[str]
9 | 
```

--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------

```
1 | [pytest]
2 | testpaths=tests
3 | pythonpath=.
4 | filterwarnings =
5 |     ignore::UserWarning
6 |     ignore::PendingDeprecationWarning
7 |     ignore::DeprecationWarning
```

--------------------------------------------------------------------------------
/surya/detection/schema.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import List, Optional, Any
 2 | 
 3 | from pydantic import BaseModel
 4 | 
 5 | from surya.common.polygon import PolygonBox
 6 | 
 7 | 
 8 | class TextDetectionResult(BaseModel):
 9 |     bboxes: List[PolygonBox]
10 |     heatmap: Optional[Any]
11 |     affinity_map: Optional[Any]
12 |     image_bbox: List[float]
13 | 
```

--------------------------------------------------------------------------------
/surya/scripts/run_texify_app.py:
--------------------------------------------------------------------------------

```python
1 | import subprocess
2 | import os
3 | 
4 | 
5 | def texify_app_cli():
6 |     cur_dir = os.path.dirname(os.path.abspath(__file__))
7 |     ocr_app_path = os.path.join(cur_dir, "texify_app.py")
8 |     cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"]
9 |     subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"})
```

--------------------------------------------------------------------------------
/surya/scripts/run_streamlit_app.py:
--------------------------------------------------------------------------------

```python
1 | import subprocess
2 | import os
3 | 
4 | 
5 | def streamlit_app_cli():
6 |     cur_dir = os.path.dirname(os.path.abspath(__file__))
7 |     ocr_app_path = os.path.join(cur_dir, "streamlit_app.py")
8 |     cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"]
9 |     subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"})
```

--------------------------------------------------------------------------------
/surya/common/surya/schema.py:
--------------------------------------------------------------------------------

```python
 1 | class TaskNames:
 2 |     block_without_boxes = "block_without_boxes"
 3 |     ocr_with_boxes = "ocr_with_boxes"
 4 |     ocr_without_boxes = "ocr_without_boxes"
 5 |     layout = "layout"
 6 |     table_structure = "table_structure"
 7 | 
 8 | 
 9 | TASK_NAMES = [
10 |     TaskNames.block_without_boxes,
11 |     TaskNames.ocr_with_boxes,
12 |     TaskNames.ocr_without_boxes,
13 |     TaskNames.layout,
14 |     TaskNames.table_structure,
15 | ]
16 | 
```

--------------------------------------------------------------------------------
/surya/layout/schema.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional, Dict, List
 2 | 
 3 | from pydantic import BaseModel
 4 | 
 5 | from surya.common.polygon import PolygonBox
 6 | 
 7 | 
 8 | class LayoutBox(PolygonBox):
 9 |     label: str
10 |     position: int
11 |     top_k: Optional[Dict[str, float]] = None
12 | 
13 | 
14 | class LayoutResult(BaseModel):
15 |     bboxes: List[LayoutBox]
16 |     image_bbox: List[float]
17 |     sliced: bool = False  # Whether the image was sliced and reconstructed
18 | 
```

--------------------------------------------------------------------------------
/surya/detection/parallel.py:
--------------------------------------------------------------------------------

```python
 1 | class FakeFuture:
 2 |     def __init__(self, func, *args, **kwargs):
 3 |         self._result = func(*args, **kwargs)
 4 | 
 5 |     def result(self):
 6 |         return self._result
 7 | 
 8 | class FakeExecutor:
 9 |     def __init__(self, **kwargs):
10 |         pass
11 | 
12 |     def __enter__(self):
13 |         return self
14 | 
15 |     def __exit__(self, *excinfo):
16 |         pass
17 | 
18 |     def submit(self, fn, *args, **kwargs):
19 |         return FakeFuture(fn, *args, **kwargs)
20 | 
```

--------------------------------------------------------------------------------
/tests/test_layout.py:
--------------------------------------------------------------------------------

```python
 1 | def test_layout_topk(layout_predictor, test_image):
 2 |     layout_results = layout_predictor([test_image])
 3 | 
 4 |     assert len(layout_results) == 1
 5 |     assert layout_results[0].image_bbox == [0, 0, 1024, 1024]
 6 | 
 7 |     bboxes = layout_results[0].bboxes
 8 |     assert len(bboxes) == 2
 9 | 
10 |     assert bboxes[0].label == "SectionHeader"
11 |     assert len(bboxes[0].top_k) == 5
12 | 
13 |     assert bboxes[1].label == "Text"
14 |     assert len(bboxes[1].top_k) == 5
15 | 
```

--------------------------------------------------------------------------------
/tests/test_foundation.py:
--------------------------------------------------------------------------------

```python
 1 | from surya.foundation import FoundationPredictor
 2 | 
 3 | 
 4 | def test_foundation_flash2():
 5 |     try:
 6 |         f = FoundationPredictor(None, None, None, "flash_attention_2")
 7 |         assert f.model.decoder.config._attn_implementation == "flash_attention_2"
 8 |         assert f.model.vision_encoder.config._attn_implementation == "flash_attention_2"
 9 |     except Exception as e:
10 |         assert False, (
11 |             f"FoundationPredictor with flash_attention_2 raised an exception: {e}"
12 |         )
13 | 
```

--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------

```yaml
 1 | name: Unit tests
 2 | 
 3 | on: [push]
 4 | 
 5 | jobs:
 6 |   build:
 7 |     runs-on: ${{ matrix.os }}
 8 |     strategy:
 9 |       matrix:
10 |         os: [t4_gpu, ubuntu-latest, windows-latest]
11 |     steps:
12 |       - uses: actions/checkout@v3
13 |       - name: Set up Python 3.11
14 |         uses: actions/setup-python@v4
15 |         with:
16 |           python-version: 3.11
17 |       - name: Install python dependencies
18 |         run: |
19 |           pip install poetry
20 |           poetry install
21 |       - name: Run tests
22 |         run: poetry run pytest
```

--------------------------------------------------------------------------------
/surya/layout/label.py:
--------------------------------------------------------------------------------

```python
 1 | LAYOUT_PRED_RELABEL = {
 2 |     "<page-header>": "PageHeader",
 3 |     "<page-footer>": "PageFooter",
 4 |     "<footnote>": "Footnote",
 5 |     "<image>": "Picture",
 6 |     "<figure>": "Figure",
 7 |     "<text>": "Text",
 8 |     "<caption>": "Caption",
 9 |     "<list-item>": "ListItem",
10 |     "<section-header>": "SectionHeader",
11 |     "<table>": "Table",
12 |     "<table-of-contents>": "TableOfContents",
13 |     "<form>": "Form",
14 |     "<equation-block>": "Equation",
15 |     "<code-block>": "Code",
16 |     "<complex-block>": "Figure",
17 | }
18 | 
```

--------------------------------------------------------------------------------
/tests/test_ocr_errors.py:
--------------------------------------------------------------------------------

```python
 1 | def test_garbled_text(ocr_error_predictor):
 2 |     text = """
 3 |     ; dh vksj ls mifLFkr vf/koDrk % Jh vfuy dqekj
 4 |     2. vfHk;qDr dh vksj ls mifLFkr vf/koDrk % Jh iznhi d
 5 |     """.strip()
 6 |     results = ocr_error_predictor([text])
 7 |     assert results.labels[0] == "bad"
 8 | 
 9 | 
10 | def test_good_text(ocr_error_predictor):
11 |     text = """
12 |     There are professions more harmful than industrial design, but only a very few of them.
13 |     """.strip()
14 |     results = ocr_error_predictor([text])
15 |     assert results.labels[0] == "good"
```

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------

```markdown
 1 | ---
 2 | name: Feature request
 3 | about: Suggest an idea for this project
 4 | title: "[FEAT]"
 5 | labels: enhancement
 6 | assignees: ''
 7 | 
 8 | ---
 9 | 
10 | ## ✨ Is your feature request related to a problem?
11 | 
12 | A clear and concise description of what the problem is. 
13 | 
14 | ## 💡 Describe the Solution You'd Like
15 | 
16 | A concise description of what you want to happen or how you envision it working.
17 | 
18 | ## 📋 Alternatives Considered
19 | 
20 | Any alternative solutions or workarounds you've tried.
21 | 
22 | ## 🧩 Additional Context
23 | 
24 | Any additional context, references, or related issues.
25 | 
```

--------------------------------------------------------------------------------
/surya/common/xla.py:
--------------------------------------------------------------------------------

```python
 1 | import math
 2 | from surya.settings import settings
 3 | 
 4 | if settings.TORCH_DEVICE_MODEL == "xla":
 5 |     import torch_xla.core.xla_model as xm
 6 | else:
 7 |     xm = None
 8 | 
 9 | 
10 | def get_nearest_pad(
11 |     length: int, pad_multiple: int = settings.FOUNDATION_PAD_TO_NEAREST
12 | ):
13 |     return math.ceil(length / pad_multiple) * pad_multiple
14 | 
15 | 
16 | def get_compile_args(device: str) -> dict:
17 |     if not settings.FOUNDATION_XLA:
18 |         return {}
19 | 
20 |     return {
21 |         "backend": "openxla",
22 |     }
23 | 
24 | 
25 | def mark_step():
26 |     if xm is not None:
27 |         xm.mark_step()
28 | 
```
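
`get_nearest_pad` rounds a sequence length up to a fixed multiple so XLA compiles against a small set of static shapes instead of recompiling per length. A short illustrative sketch (the default multiple comes from `settings.FOUNDATION_PAD_TO_NEAREST`; 128 is passed explicitly here):

```python
from surya.common.xla import get_nearest_pad

# Round lengths up to a multiple of 128 so every batch hits one of a few
# precompiled shapes.
for length in (1, 117, 128, 300):
    print(length, "->", get_nearest_pad(length, pad_multiple=128))
# 1 -> 128, 117 -> 128, 128 -> 128, 300 -> 384
```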

--------------------------------------------------------------------------------
/tests/test_latex_ocr.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import List
 2 | 
 3 | from PIL import Image, ImageDraw
 4 | 
 5 | from surya.common.surya.schema import TaskNames
 6 | from surya.recognition import OCRResult
 7 | 
 8 | 
 9 | def test_latex_ocr(recognition_predictor, test_image_latex):
10 |     width, height = test_image_latex.size
11 |     results: List[OCRResult] = recognition_predictor(
12 |         [test_image_latex], [TaskNames.block_without_boxes], bboxes=[[[0, 0, width, height]]]
13 |     )
14 |     text = results[0].text_lines[0].text
15 |     assert len(results) == 1
16 | 
17 |     assert text.startswith("<math")
18 |     assert text.endswith("</math>")
19 | 
```

--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------

```yaml
 1 | name: Python package
 2 | on:
 3 |   push:
 4 |     tags:
 5 |       - "v*.*.*"
 6 | jobs:
 7 |   build:
 8 |     runs-on: ubuntu-latest
 9 |     steps:
10 |       - uses: actions/checkout@v3
11 |       - name: Set up Python 3.11
12 |         uses: actions/setup-python@v4
13 |         with:
14 |           python-version: 3.11
15 |       - name: Install python dependencies
16 |         run: |
17 |           pip install poetry
18 |           poetry install
19 |       - name: Build package
20 |         run: |
21 |           poetry build
22 |       - name: Publish package
23 |         env:
24 |           PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
25 |         run: |
26 |           poetry config pypi-token.pypi "$PYPI_TOKEN"
27 |           poetry publish
28 | 
```

--------------------------------------------------------------------------------
/tests/test_detection.py:
--------------------------------------------------------------------------------

```python
 1 | def test_detection(detection_predictor, test_image):
 2 |     detection_results = detection_predictor([test_image])
 3 | 
 4 |     assert len(detection_results) == 1
 5 |     assert detection_results[0].image_bbox == [0, 0, 1024, 1024]
 6 | 
 7 |     bboxes = detection_results[0].bboxes
 8 |     assert len(bboxes) == 4
 9 | 
10 | 
11 | def test_detection_chunking(detection_predictor, test_image_tall):
12 |     detection_results = detection_predictor([test_image_tall])
13 | 
14 |     assert len(detection_results) == 1
15 |     assert detection_results[0].image_bbox == [0, 0, 4096, 4096]
16 | 
17 |     bboxes = detection_results[0].bboxes
18 |     assert len(bboxes) >= 3 # Sometimes merges into 3
19 |     assert abs(4000 - bboxes[1].polygon[0][0]) < 50
```

--------------------------------------------------------------------------------
/surya/common/load.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional, Any
 2 | 
 3 | import torch
 4 | 
 5 | from surya.settings import settings
 6 | 
 7 | 
 8 | class ModelLoader:
 9 |     def __init__(self, checkpoint: Optional[str] = None):
10 |         self.checkpoint = checkpoint
11 | 
12 |     def model(
13 |         self,
14 |         device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
15 |         dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE,
16 |         attention_implementation: Optional[str] = None,
17 |     ) -> Any:
18 |         raise NotImplementedError()
19 | 
20 |     def processor(
21 |         self,
22 |         device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
23 |         dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE,
24 |     ) -> Any:
25 |         raise NotImplementedError()
26 | 
```

--------------------------------------------------------------------------------
/surya/common/surya/processor/schema.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import TypedDict, Literal, List, Tuple
 2 | 
 3 | import torch
 4 | from PIL import Image
 5 | 
 6 | 
 7 | class TaskDict(TypedDict):
 8 |     datasets: List[str]
 9 |     img_size: Tuple[int, int]
10 | 
11 | 
12 | class TasksDict(TypedDict):
13 |     ocr_with_boxes: TaskDict
14 |     ocr_without_boxes: TaskDict
15 |     block_without_boxes: TaskDict
16 | 
17 | 
18 | class ProcessorInput(TypedDict):
19 |     type: Literal["image", "ocr", "text", "empty_output"]
20 | 
21 | 
22 | class ImageInput(ProcessorInput):
23 |     type: Literal["image"]
24 |     image: Image.Image
25 |     rotated: bool
26 | 
27 | 
28 | class TextInput(ProcessorInput):
29 |     type: Literal["text"]
30 |     text: str
31 |     math: bool
32 | 
33 | 
34 | class ProcessorOutput(TypedDict):
35 |     input_ids: List[int]
36 |     image_tiles: torch.Tensor | None
37 |     grid_thw: torch.Tensor | None
38 | 
```

--------------------------------------------------------------------------------
/surya/logging.py:
--------------------------------------------------------------------------------

```python
 1 | import logging
 2 | import warnings
 3 | from surya.settings import settings
 4 | 
 5 | 
 6 | def configure_logging():
 7 |     logger = get_logger()
 8 | 
 9 |     # Remove any existing handlers to prevent duplicates
10 |     for handler in logger.handlers[:]:
11 |         logger.removeHandler(handler)
12 | 
13 |     # Add our handler
14 |     handler = logging.StreamHandler()
15 |     formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
16 |     handler.setFormatter(formatter)
17 |     logger.addHandler(handler)
18 | 
19 |     # Prevent propagation to parent loggers to avoid double logging
20 |     logger.propagate = False
21 | 
22 |     logger.setLevel(settings.LOGLEVEL)
23 |     warnings.simplefilter(action="ignore", category=FutureWarning)
24 | 
25 | 
26 | def get_logger():
27 |     return logging.getLogger("surya")
28 | 
```
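
A minimal usage sketch (not in the repository): call `configure_logging()` once at startup, then fetch the shared logger anywhere.

```python
from surya.logging import configure_logging, get_logger

configure_logging()            # attach one stream handler at settings.LOGLEVEL
logger = get_logger()          # the shared "surya" logger
logger.info("processed %d pages", 3)
```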

--------------------------------------------------------------------------------
/surya/common/pretrained.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional
 2 | 
 3 | from transformers import PreTrainedModel
 4 | from transformers.utils import is_flash_attn_2_available
 5 | 
 6 | 
 7 | class SuryaPreTrainedModel(PreTrainedModel):
 8 |     # No-op if we pass attention, so we can set attention however we want in the config
 9 |     def _check_and_adjust_attn_implementation(
10 |         self, attn_implementation: Optional[str], **kwargs
11 |     ):
12 |         if attn_implementation is None:
13 |             try:
14 |                 self._sdpa_can_dispatch(True)
15 |                 attn_implementation = "sdpa"
16 |             except (ValueError, ImportError):
17 |                 attn_implementation = "eager"
18 | 
19 |             if self._supports_flash_attn and is_flash_attn_2_available():
20 |                 attn_implementation = "flash_attention_2"
21 | 
22 |         return attn_implementation
23 | 
```

--------------------------------------------------------------------------------
/surya/debug/fonts.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import List, Optional
 2 | import os
 3 | import requests
 4 | 
 5 | from surya.settings import settings
 6 | 
 7 | 
 8 | def get_font_path(langs: Optional[List[str]] = None) -> str:
 9 |     font_path = settings.RECOGNITION_RENDER_FONTS["all"]
10 |     if langs is not None:
11 |         for k in settings.RECOGNITION_RENDER_FONTS:
12 |             if k in langs and len(langs) == 1:
13 |                 font_path = settings.RECOGNITION_RENDER_FONTS[k]
14 |                 break
15 | 
16 |     if not os.path.exists(font_path):
17 |         os.makedirs(os.path.dirname(font_path), exist_ok=True)
18 |         font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}"
19 |         with requests.get(font_dl_path, stream=True) as r, open(font_path, 'wb') as f:
20 |             r.raise_for_status()
21 |             for chunk in r.iter_content(chunk_size=8192):
22 |                 f.write(chunk)
23 | 
24 |     return font_path
```
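
`get_font_path` lazily downloads the render font on first use and falls back to the "all" font when no language-specific font is configured. An illustrative sketch (the `"hi"` key is an assumption about what `settings.RECOGNITION_RENDER_FONTS` contains):

```python
from PIL import ImageFont

from surya.debug.fonts import get_font_path

# Downloads the font on the first call, then reuses the cached local file.
font_path = get_font_path(langs=["hi"])
font = ImageFont.truetype(font_path, size=24)
print(font_path)
```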

--------------------------------------------------------------------------------
/surya/recognition/schema.py:
--------------------------------------------------------------------------------

```python
 1 | import math
 2 | import numpy as np
 3 | from typing import Optional, List
 4 | 
 5 | from pydantic import BaseModel, field_validator
 6 | 
 7 | from surya.common.polygon import PolygonBox
 8 | 
 9 | 
10 | class BaseChar(PolygonBox):
11 |     text: str
12 |     confidence: Optional[float] = 0
13 | 
14 |     @field_validator("confidence", mode="before")
15 |     @classmethod
16 |     def validate_confidence(cls, v: float) -> float:
17 |         if v is None:
18 |             return 0
19 |         elif math.isnan(v) or np.isnan(v):
20 |             return 0
21 |         return v
22 | 
23 | 
24 | class TextChar(BaseChar):
25 |     bbox_valid: bool = True  # This is false when the given bbox is not valid
26 | 
27 | 
28 | class TextWord(BaseChar):
29 |     bbox_valid: bool = True
30 | 
31 | 
32 | class TextLine(BaseChar):
33 |     chars: List[TextChar]  # Individual characters in the line
34 |     original_text_good: bool = False
35 |     words: List[TextWord] | None = None
36 | 
37 | 
38 | class OCRResult(BaseModel):
39 |     text_lines: List[TextLine]
40 |     image_bbox: List[float]
41 | 
```
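
A small consumer sketch (not part of the repository) showing how the nested `OCRResult` -> `TextLine` -> `TextChar` schema is typically walked; the zero-confidence filter is only illustrative.

```python
from surya.recognition.schema import OCRResult


def dump_text(result: OCRResult) -> str:
    # Join line text, skipping lines whose characters all defaulted to 0 confidence.
    kept = []
    for line in result.text_lines:
        if line.chars and all(char.confidence == 0 for char in line.chars):
            continue
        kept.append(line.text)
    return "\n".join(kept)
```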

--------------------------------------------------------------------------------
/surya/table_rec/schema.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import List
 2 | 
 3 | from pydantic import BaseModel
 4 | 
 5 | from surya.common.polygon import PolygonBox
 6 | 
 7 | 
 8 | class TableCell(PolygonBox):
 9 |     row_id: int
10 |     colspan: int
11 |     within_row_id: int
12 |     cell_id: int
13 |     is_header: bool
14 |     rowspan: int | None = None
15 |     merge_up: bool = False
16 |     merge_down: bool = False
17 |     col_id: int | None = None
18 |     text_lines: List[dict] | None = None
19 | 
20 |     @property
21 |     def label(self):
22 |         return f'Cell {self.cell_id} {self.rowspan}/{self.colspan}'
23 | 
24 | 
25 | class TableRow(PolygonBox):
26 |     row_id: int
27 |     is_header: bool
28 | 
29 |     @property
30 |     def label(self):
31 |         return f'Row {self.row_id}'
32 | 
33 | 
34 | class TableCol(PolygonBox):
35 |     col_id: int
36 |     is_header: bool
37 | 
38 |     @property
39 |     def label(self):
40 |         return f'Column {self.col_id}'
41 | 
42 | 
43 | class TableResult(BaseModel):
44 |     cells: List[TableCell]
45 |     unmerged_cells: List[TableCell]
46 |     rows: List[TableRow]
47 |     cols: List[TableCol]
48 |     image_bbox: List[float]
49 | 
```
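
One way to consume a `TableResult` is to index the merged cells by `(row_id, col_id)`; a sketch under the assumption that `col_id` is populated for the cells you care about:

```python
from surya.table_rec.schema import TableResult


def to_grid(result: TableResult) -> list:
    # Row-major grid of merged cells; positions without a cell stay None.
    by_pos = {(c.row_id, c.col_id): c for c in result.cells if c.col_id is not None}
    return [
        [by_pos.get((row.row_id, col.col_id)) for col in result.cols]
        for row in result.rows
    ]
```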

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/output-bug-report.md:
--------------------------------------------------------------------------------

```markdown
 1 | ---
 2 | name: Output bug report
 3 | about: Create a report about poor output quality
 4 | title: "[BUG: Output]"
 5 | labels: 'bug: output'
 6 | assignees: ''
 7 | 
 8 | ---
 9 | 
10 | ## 📝 Describe the Output Issue
11 | 
12 | A clear and concise description of the incorrect or unexpected output.
13 | 
14 | ## 📄 Input Document
15 | 
16 | Attach the PDF or input file used.
17 | 
18 | ## 📤 Current Output
19 | 
20 | Paste the Markdown or HTML that Marker generated:
21 | 
22 | ````markdown
23 | Paste output here
 24 | ````
25 | 
26 | ## ✅ Expected Output
27 | 
28 | Describe or paste what you expected Marker to generate.
29 | 
30 | ## ⚙️ Environment
31 | 
32 | Please fill in all relevant details:
33 | 
34 | * **Marker version**:
35 | * **Surya version**:
36 | * **Python version**:
37 | * **PyTorch version**:
38 | * **Transformers version**:
39 | * **Operating System**:
40 | 
41 | ## 📟 Command or Code Used
42 | 
43 | Paste the **exact bash command** or **Python code** you used to run Marker:
44 | 
45 | <details>
46 | <summary>Click to expand</summary>
47 | 
48 | ```bash
49 | # or Python code block
50 | your_command_here --with-flags
51 | ```
52 | 
53 | </details>
54 | 
55 | ## 📎 Additional Context
56 | 
57 | Any other relevant info, configs, or assumptions.
58 | 
```

--------------------------------------------------------------------------------
/benchmark/utils/textract.py:
--------------------------------------------------------------------------------

```python
 1 | import os
 2 | from concurrent.futures import ThreadPoolExecutor
 3 | from tqdm import tqdm
 4 | import traceback
 5 | 
 6 | from surya.input.processing import slice_bboxes_from_image
 7 | from surya.recognition import RecognitionPredictor
 8 | 
 9 | def textract_ocr(extractor, img):
10 |     try:
11 |         document = extractor.detect_document_text(file_source=img)
12 |         return [line.text for line in document.lines]
13 |     except Exception:
14 |         traceback.print_exc()
15 |         return [None]
16 | 
17 | def textract_ocr_parallel(imgs, cpus=None):
18 |     from textractor import Textractor # Optional dependency
19 | 
20 |     extractor = Textractor(profile_name='default')
21 |     parallel_cores = min(len(imgs), RecognitionPredictor().get_batch_size())
22 |     if not cpus:
23 |         cpus = os.cpu_count()
24 |     parallel_cores = min(parallel_cores, cpus)
25 | 
26 |     with ThreadPoolExecutor(max_workers=parallel_cores) as executor:
27 |         textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc="Running textract OCR")
28 |         textract_text = list(textract_text)
29 |     return textract_text
```

--------------------------------------------------------------------------------
/surya/models.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Dict
 2 | 
 3 | import torch
 4 | 
 5 | from surya.common.predictor import BasePredictor
 6 | from surya.detection import DetectionPredictor
 7 | from surya.layout import LayoutPredictor
 8 | from surya.logging import configure_logging
 9 | from surya.ocr_error import OCRErrorPredictor
10 | from surya.foundation import FoundationPredictor
11 | from surya.recognition import RecognitionPredictor
12 | from surya.table_rec import TableRecPredictor
13 | from surya.settings import settings
14 | 
15 | configure_logging()
16 | 
17 | 
18 | def load_predictors(
19 |     device: str | torch.device | None = None, dtype: torch.dtype | str | None = None
20 | ) -> Dict[str, BasePredictor]:
21 |     return {
22 |         "layout": LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)),
23 |         "ocr_error": OCRErrorPredictor(device=device, dtype=dtype),
24 |         "recognition": RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT)),
25 |         "detection": DetectionPredictor(device=device, dtype=dtype),
26 |         "table_rec": TableRecPredictor(device=device, dtype=dtype),
27 |     }
28 | 
```

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/breaking-bug-report.md:
--------------------------------------------------------------------------------

```markdown
 1 | ---
 2 | name: Breaking bug report
 3 | about: Create a report about a breaking bug
 4 | title: "[BUG: Breaking]"
 5 | labels: 'bug: breaking'
 6 | assignees: ''
 7 | 
 8 | ---
 9 | 
10 | ## 🧨 Describe the Bug
11 | 
12 | A clear and concise description of the breaking issue (e.g., crash, OOM, exception, etc).
13 | 
14 | ## 📄 Input Document
15 | 
16 | Attach the PDF or input file that triggered the error.
17 | 
18 | ## 📤 Output Trace / Stack Trace
19 | 
20 | Paste the **complete** stack trace or error output, if available.
21 | 
22 | <details>
23 | <summary>Click to expand</summary>
24 | 
25 | ```
26 | Paste stack trace here
27 | ```
28 | 
29 | </details>
30 | 
31 | ## ⚙️ Environment
32 | 
33 | Please fill in all relevant details:
34 | 
35 | - **Marker version**: 
36 | - **Surya version**: 
37 | - **Python version**: 
38 | - **PyTorch version**: 
39 | - **Transformers version**: 
40 | - **Operating System** (incl. container info if relevant): 
41 | 
42 | ## ✅ Expected Behavior
43 | 
44 | What did you expect Marker to do?
45 | 
46 | ## 📟 Command or Code Used
47 | 
48 | Paste the **exact bash command** or **Python code** you used to run Marker:
49 | 
50 | <details>
51 | <summary>Click to expand</summary>
52 | 
53 | ```bash
54 | # or Python code block
55 | your_command_here --with-flags
56 | ```
57 | 
58 | </details>
59 | 
60 | ## 📎 Additional Context
61 | 
62 | Any other context that might help us debug this (e.g., CLI options, working directory, runtime settings).
63 | 
```

--------------------------------------------------------------------------------
/surya/detection/util.py:
--------------------------------------------------------------------------------

```python
 1 | import math
 2 | from PIL import ImageOps
 3 | 
 4 | from surya.settings import settings
 5 | 
 6 | 
 7 | def get_total_splits(image_size, height):
 8 |     img_height = list(image_size)[1]
 9 |     max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
10 |     if img_height > max_height:
11 |         num_splits = math.ceil(img_height / height)
12 |         return num_splits
13 |     return 1
14 | 
15 | 
16 | def split_image(img, height):
17 |     # This will not modify/return the original image - it will either crop, or copy the image
18 |     img_height = list(img.size)[1]
19 |     max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
20 |     if img_height > max_height:
21 |         num_splits = math.ceil(img_height / height)
22 |         splits = []
23 |         split_heights = []
24 |         for i in range(num_splits):
25 |             top = i * height
26 |             bottom = (i + 1) * height
27 |             if bottom > img_height:
28 |                 bottom = img_height
29 |             cropped = img.crop((0, top, img.size[0], bottom))
30 |             chunk_height = bottom - top
31 |             if chunk_height < height:
32 |                 cropped = ImageOps.pad(cropped, (img.size[0], height), color=255, centering=(0, 0))
33 |             splits.append(cropped)
34 |             split_heights.append(chunk_height)
35 |         return splits, split_heights
36 |     return [img.copy()], [img_height]
37 | 
```
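
A usage sketch (illustrative only) for the chunking helpers: tall pages are split into fixed-height chunks, and a short final chunk is padded back up to the chunk height.

```python
from PIL import Image

from surya.detection.util import get_total_splits, split_image
from surya.settings import settings

page = Image.new("RGB", (1024, 3000), "white")     # synthetic tall page
chunk_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT

print("splits:", get_total_splits(page.size, chunk_height))
chunks, heights = split_image(page, chunk_height)
for chunk, content_height in zip(chunks, heights):
    # chunk.size[1] is the padded height; content_height is the real crop height.
    print(chunk.size, content_height)
```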

--------------------------------------------------------------------------------
/benchmark/utils/scoring.py:
--------------------------------------------------------------------------------

```python
 1 | import math
 2 | from typing import List
 3 | 
 4 | from rapidfuzz import fuzz
 5 | 
 6 | 
 7 | def overlap_score(pred_lines: List[str], reference_lines: List[str]):
 8 |     line_scores = []
 9 |     line_weights = []
10 |     line_match = {}
11 |     for i, pred_line in enumerate(pred_lines):
12 |         max_score = 0
13 |         line_weight = 1
14 |         match = None
15 |         for j, ref_line in enumerate(reference_lines):
16 |             score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
17 |             if score > max_score:
18 |                 max_score = score
19 |                 line_weight = math.sqrt(len(ref_line))
20 |                 match = j
21 |         line_scores.append(max_score)
22 |         line_weights.append(line_weight)
23 |         line_match[i] = match
24 |     line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))]
25 | 
26 |     return line_scores, line_weights, line_match
27 | 
28 | 
29 | def overlap_score_exact(pred_lines: List[str], reference_lines: List[str]):
30 |     line_scores = []
31 |     line_weights = []
32 |     assert len(pred_lines) == len(reference_lines)
33 | 
34 |     for i, (pred_line, ref_line) in enumerate(zip(pred_lines, reference_lines)):
35 |         score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
36 |         weight = math.sqrt(len(ref_line))
37 |         line_scores.append(score * weight)
38 |         line_weights.append(weight)
39 | 
40 |     return line_scores, line_weights
41 | 
```
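
`overlap_score` returns per-line scores already multiplied by their weights, so a corpus-level score is just the weighted mean; a small sketch (not part of the benchmark code):

```python
from benchmark.utils.scoring import overlap_score

preds = ["The quick brown fox", "jumps over the dog"]
refs = ["The quick brown fox", "jumps over the lazy dog"]

line_scores, line_weights, line_match = overlap_score(preds, refs)
avg = sum(line_scores) / sum(line_weights)   # scores are pre-weighted
print(f"avg overlap: {avg:.3f}", line_match)
```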

--------------------------------------------------------------------------------
/.github/workflows/cla.yml:
--------------------------------------------------------------------------------

```yaml
 1 | name: "Surya CLA Assistant"
 2 | on:
 3 |   issue_comment:
 4 |     types: [created]
 5 |   pull_request_target:
 6 |     types: [opened,closed,synchronize]
 7 | 
 8 | # explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings
 9 | permissions:
10 |   actions: write
11 |   contents: write
12 |   pull-requests: write
13 |   statuses: write
14 | 
15 | jobs:
16 |   CLAAssistant:
17 |     runs-on: ubuntu-latest
18 |     steps:
19 |       - name: "Surya CLA Assistant"
20 |         if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
21 |         uses: contributor-assistant/[email protected]
22 |         env:
23 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
24 |           # the below token should have repo scope and must be manually added by you in the repository's secret
25 |           # This token is required only if you have configured to store the signatures in a remote repository/organization
26 |           PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
27 |         with:
28 |           path-to-signatures: 'signatures/version1/cla.json'
29 |           path-to-document: 'https://github.com/VikParuchuri/surya/blob/master/CLA.md'
30 |           # branch should not be protected
31 |           branch: 'master'
32 |           allowlist: VikParuchuri
```

--------------------------------------------------------------------------------
/.github/workflows/scripts.yml:
--------------------------------------------------------------------------------

```yaml
 1 | name: Test CLI scripts
 2 | 
 3 | on: [push]
 4 | 
 5 | jobs:
 6 |   build:
 7 |     runs-on: t4_gpu
 8 |     steps:
 9 |       - uses: actions/checkout@v3
10 |       - name: Set up Python 3.11
11 |         uses: actions/setup-python@v4
12 |         with:
13 |           python-version: 3.11
14 |       - name: Install python dependencies
15 |         run: |
16 |           pip install poetry
17 |           poetry install
18 |       - name: Download benchmark data
19 |         run: |
20 |           wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi"
21 |           unzip -o benchmark_data.zip
22 |       - name: Test detection
23 |         run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0
24 |       - name: Test OCR
25 |         env:
26 |           RECOGNITION_MAX_TOKENS: 25
27 |         run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0
28 |       - name: Test layout
29 |         run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0
30 |       - name: Test table
31 |         run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0
32 |       - name: Test texify
33 |         env:
34 |           TEXIFY_MAX_TOKENS: 25
35 |         run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0
36 |       - name: Test detection folder
37 |         run: poetry run surya_detect benchmark_data/pdfs --page_range 0
38 | 
```

--------------------------------------------------------------------------------
/surya/common/surya/encoder/config.py:
--------------------------------------------------------------------------------

```python
 1 | from transformers.configuration_utils import PretrainedConfig
 2 | from transformers.utils import logging
 3 | 
 4 | logger = logging.get_logger(__name__)
 5 | 
 6 | 
 7 | class SuryaEncoderConfig(PretrainedConfig):
 8 |     model_type = "qwen2_5_vl"
 9 |     base_config_key = "vision_config"
10 | 
11 |     attribute_map = {
12 |         "num_attention_heads": "num_heads",
13 |         "num_hidden_layers": "depth",
14 |     }
15 | 
16 |     def __init__(
17 |         self,
18 |         depth=8,
19 |         hidden_size=1280,
20 |         hidden_act="silu",
21 |         intermediate_size=3420,
22 |         num_heads=16,
23 |         in_channels=3,
24 |         patch_size=14,
25 |         spatial_merge_size=2,
26 |         spatial_patch_size=14,
27 |         temporal_patch_size=1,
28 |         tokens_per_second=4,
29 |         window_size=112,
30 |         out_hidden_size=1280,
31 |         fullatt_block_indexes=(3, 7),
32 |         initializer_range=0.02,
33 |         image_size=4096,
34 |         **kwargs,
35 |     ):
36 |         super().__init__(**kwargs)
37 | 
38 |         self.depth = depth
39 |         self.hidden_size = hidden_size
40 |         self.hidden_act = hidden_act
41 |         self.intermediate_size = intermediate_size
42 |         self.num_heads = num_heads
43 |         self.in_channels = in_channels
44 |         self.patch_size = patch_size
45 |         self.spatial_merge_size = spatial_merge_size
46 |         self.temporal_patch_size = temporal_patch_size
47 |         self.tokens_per_second = tokens_per_second
48 |         self.window_size = window_size
49 |         self.fullatt_block_indexes = fullatt_block_indexes
50 |         self.out_hidden_size = out_hidden_size
51 |         self.initializer_range = initializer_range
52 |         self.spatial_patch_size = spatial_patch_size
53 |         self.image_size = image_size
54 | 
```

--------------------------------------------------------------------------------
/surya/detection/model/config.py:
--------------------------------------------------------------------------------

```python
 1 | from transformers import PretrainedConfig
 2 | 
 3 | from surya.common.s3 import S3DownloaderMixin
 4 | 
 5 | 
 6 | class EfficientViTConfig(S3DownloaderMixin, PretrainedConfig):
 7 |     r"""
 8 |     Configuration for the EfficientViT text detection model."""
 9 | 
10 |     model_type = "efficientvit"
11 | 
12 |     def __init__(
13 |         self,
14 |         num_classes=2,
15 |         num_channels=3,
16 |         widths=(32, 64, 128, 256, 512),
17 |         head_dim=32,
18 |         num_stages=4,
19 |         depths=(1, 1, 1, 6, 6),
20 |         strides=(2, 2, 2, 2, 2),
21 |         hidden_sizes=(32, 64, 160, 256),
22 |         patch_size=(7, 7),
23 |         hidden_dropout_prob=0.0,
24 |         attention_probs_dropout_prob=0.0,
25 |         classifier_dropout_prob=0.0,
26 |         layer_norm_eps=1e-6,
27 |         decoder_layer_hidden_size=128,
28 |         decoder_hidden_size=512,
29 |         semantic_loss_ignore_index=255,
30 |         initializer_range=0.02,
31 |         **kwargs,
32 |     ):
33 |         super().__init__(**kwargs)
34 | 
35 |         self.num_classes = num_classes
36 |         self.widths = widths
37 |         self.head_dim = head_dim
38 | 
39 |         self.num_channels = num_channels
40 |         self.num_stages = num_stages
41 |         self.depths = depths
42 |         self.strides = strides
43 |         self.hidden_sizes = hidden_sizes
44 |         self.patch_size = patch_size
45 |         self.hidden_dropout_prob = hidden_dropout_prob
46 |         self.attention_probs_dropout_prob = attention_probs_dropout_prob
47 |         self.classifier_dropout_prob = classifier_dropout_prob
48 |         self.layer_norm_eps = layer_norm_eps
49 |         self.decoder_hidden_size = decoder_hidden_size
50 |         self.decoder_layer_hidden_size = decoder_layer_hidden_size
51 |         self.semantic_loss_ignore_index = semantic_loss_ignore_index
52 | 
53 |         self.initializer_range = initializer_range
```

--------------------------------------------------------------------------------
/tests/test_table_rec.py:
--------------------------------------------------------------------------------

```python
 1 | from PIL import Image, ImageDraw
 2 | 
 3 | def test_table_rec(table_rec_predictor):
 4 |     data = [
 5 |         ["Name", "Age", "City"],
 6 |         ["Alice", 25, "New York"],
 7 |         ["Bob", 30, "Los Angeles"],
 8 |         ["Charlie", 35, "Chicago"],
 9 |     ]
10 |     test_image = draw_table(data)
11 | 
12 |     results = table_rec_predictor([test_image])
13 |     assert len(results) == 1
14 |     assert results[0].image_bbox == [0, 0, test_image.size[0], test_image.size[1]]
15 | 
16 |     cells = results[0].cells
17 |     assert len(cells) == 12
18 |     for row_id in range(4):
19 |         for col_id in range(3):
20 |             cell = [c for c in cells if c.row_id == row_id and c.col_id == col_id]
21 |             assert len(cell) == 1, f"Missing cell at row {row_id}, col {col_id}"
22 | 
23 | def draw_table(data, cell_width=100, cell_height=40):
24 |     rows = len(data)
25 |     cols = len(data[0])
26 |     width = cols * cell_width
27 |     height = rows * cell_height
28 | 
29 |     image = Image.new('RGB', (width, height), 'white')
30 |     draw = ImageDraw.Draw(image)
31 | 
32 |     for i in range(rows + 1):
33 |         y = i * cell_height
34 |         draw.line([(0, y), (width, y)], fill='black', width=1)
35 | 
36 |     for i in range(cols + 1):
37 |         x = i * cell_width
38 |         draw.line([(x, 0), (x, height)], fill='black', width=1)
39 | 
40 |     for i in range(rows):
41 |         for j in range(cols):
42 |             text = str(data[i][j])
43 |             text_bbox = draw.textbbox((0, 0), text)
44 |             text_width = text_bbox[2] - text_bbox[0]
45 |             text_height = text_bbox[3] - text_bbox[1]
46 | 
47 |             x = j * cell_width + (cell_width - text_width) // 2
48 |             y = i * cell_height + (cell_height - text_height) // 2
49 | 
50 |             draw.text((x, y), text, fill='black')
51 | 
52 |     return image
```

--------------------------------------------------------------------------------
/benchmark/utils/bbox.py:
--------------------------------------------------------------------------------

```python
 1 | import fitz as pymupdf
 2 | from surya.common.util import rescale_bbox
 3 | 
 4 | 
 5 | def get_pdf_lines(pdf_path, img_sizes):
 6 |     doc = pymupdf.open(pdf_path)
 7 |     page_lines = []
 8 |     for idx, img_size in enumerate(img_sizes):
 9 |         page = doc[idx]
10 |         blocks = page.get_text("dict", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)["blocks"]
11 | 
12 |         line_boxes = []
13 |         for block_idx, block in enumerate(blocks):
14 |             for l in block["lines"]:
15 |                 line_boxes.append(list(l["bbox"]))
16 | 
17 |         page_box = page.bound()
18 |         pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1]
19 |         line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes]
20 |         page_lines.append(line_boxes)
21 | 
22 |     return page_lines
23 | 
24 | def merge_boxes(box1, box2):
25 |     return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3]))
26 | 
27 | 
28 | def join_lines(bboxes, max_gap=5):
29 |     to_merge = {}
30 |     for i, box1 in bboxes:
31 |         for z, box2 in bboxes[i + 1:]:
32 |             j = i + z + 1
33 |             if box1 == box2:
34 |                 continue
35 | 
36 |             if box1[0] <= box2[0] and box1[2] >= box2[2]:
37 |                 if abs(box1[1] - box2[3]) <= max_gap:
38 |                     if i not in to_merge:
39 |                         to_merge[i] = []
40 |                     to_merge[i].append(j)
41 | 
42 |     merged_boxes = set()
43 |     merged = []
44 |     for i, box in bboxes:
45 |         if i in merged_boxes:
46 |             continue
47 | 
48 |         if i in to_merge:
49 |             for j in to_merge[i]:
50 |                 box = merge_boxes(box, bboxes[j][1])
51 |                 merged_boxes.add(j)
52 | 
53 |         merged.append(box)
54 |     return merged
55 | 
```

--------------------------------------------------------------------------------
/surya/scripts/ocr_latex.py:
--------------------------------------------------------------------------------

```python
 1 | import os
 2 | 
 3 | import click
 4 | import json
 5 | import time
 6 | from collections import defaultdict
 7 | 
 8 | from surya.logging import configure_logging, get_logger
 9 | from surya.scripts.config import CLILoader
10 | from surya.foundation import FoundationPredictor
11 | from surya.recognition import RecognitionPredictor
12 | from surya.common.surya.schema import TaskNames
13 | 
14 | configure_logging()
15 | logger = get_logger()
16 | 
17 | 
18 | @click.command(help="OCR LaTeX equations.")
19 | @CLILoader.common_options
20 | def ocr_latex_cli(input_path: str, **kwargs):
21 |     loader = CLILoader(input_path, kwargs, highres=True)
22 | 
23 |     foundation_predictor = FoundationPredictor()
24 |     texify_predictor = RecognitionPredictor(foundation_predictor)
25 |     tasks = [TaskNames.block_without_boxes] * len(loader.images)
26 |     bboxes = [[[0, 0, image.width, image.height]] for image in loader.images]
27 | 
28 |     start = time.time()
29 |     predictions_by_image = texify_predictor(
30 |         loader.images,
31 |         tasks,
32 |         bboxes=bboxes,
33 |     )
34 | 
35 |     latex_predictions = [p.text_lines[0].text for p in predictions_by_image]
36 | 
37 |     if loader.debug:
38 |         logger.debug(f"OCR took {time.time() - start:.2f} seconds")
39 |         max_chars = max([len(latex) for latex in latex_predictions])
40 |         logger.debug(f"Max chars: {max_chars}")
41 | 
42 |     out_preds = defaultdict(list)
43 |     for name, pred, image in zip(loader.names, latex_predictions, loader.images):
44 |         out_pred = {
45 |             "equation": pred,
46 |             "page": len(out_preds[name]) + 1,
47 |         }
48 |         out_preds[name].append(out_pred)
49 | 
50 |     with open(
51 |         os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
52 |     ) as f:
53 |         json.dump(out_preds, f, ensure_ascii=False)
54 | 
55 |     logger.info(f"Wrote results to {loader.result_path}")
56 | 
```

--------------------------------------------------------------------------------
/surya/foundation/util.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import List, Tuple
 2 | import numpy as np
 3 | import torch
 4 | 
 5 | def detect_repeat_token(predicted_tokens: List[int], max_repeats: int = 40):
 6 |     if len(predicted_tokens) < max_repeats:
 7 |         return False
 8 | 
 9 |     # Detect repeats containing 1 or 2 tokens
10 |     last_n = predicted_tokens[-max_repeats:]
11 |     unique_tokens = len(set(last_n))
12 |     if unique_tokens > 5:
13 |         return False
14 | 
15 |     return last_n[-unique_tokens:] == last_n[-unique_tokens * 2 : -unique_tokens]
16 | 
17 | def prediction_to_polygon_batch(
18 |     pred: torch.Tensor,
19 |     img_sizes: List[Tuple[int, int]],
20 |     bbox_scaler,
21 |     skew_scaler,
22 |     skew_min=0.001,
23 | ):
24 |     img_sizes = torch.from_numpy(np.array(img_sizes, dtype=np.float32)).to(
25 |         pred.device
26 |     )
27 |     w_scale = (img_sizes[:, 1] / bbox_scaler)[:, None, None]
28 |     h_scale = (img_sizes[:, 0] / bbox_scaler)[:, None, None]
29 | 
30 |     cx = pred[:, :, 0]
31 |     cy = pred[:, :, 1]
32 |     width = pred[:, :, 2]
33 |     height = pred[:, :, 3]
34 | 
35 |     x1 = cx - width / 2
36 |     y1 = cy - height / 2
37 |     x2 = cx + width / 2
38 |     y2 = cy + height / 2
39 | 
40 |     skew_x = torch.floor((pred[:, :, 4] - skew_scaler) / 2)
41 |     skew_y = torch.floor((pred[:, :, 5] - skew_scaler) / 2)
42 | 
43 |     skew_x[torch.abs(skew_x) < skew_min] = 0
44 |     skew_y[torch.abs(skew_y) < skew_min] = 0
45 | 
46 |     polygons_flat = torch.stack(
47 |         [
48 |             x1 - skew_x,
49 |             y1 - skew_y,
50 |             x2 - skew_x,
51 |             y1 + skew_y,
52 |             x2 + skew_x,
53 |             y2 + skew_y,
54 |             x1 + skew_x,
55 |             y2 - skew_y,
56 |         ],
57 |         dim=2,
58 |     )
59 | 
60 |     batch_size, seq_len, _ = pred.shape
61 |     polygons = polygons_flat.view(batch_size, seq_len, 4, 2)
62 | 
63 |     polygons[:, :, :, 0] *= w_scale
64 |     polygons[:, :, :, 1] *= h_scale
65 | 
66 |     return polygons
```
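
`detect_repeat_token` flags degenerate decoding by checking whether the tail of the sequence is a short repeating cycle; a quick illustrative check:

```python
from surya.foundation.util import detect_repeat_token

healthy = list(range(100))                   # varied token ids
stuck = list(range(60)) + [7, 8] * 20        # decoder looping on two tokens

print(detect_repeat_token(healthy))  # False: 40 unique tokens in the tail
print(detect_repeat_token(stuck))    # True: tail is a repeating 2-token cycle
```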

--------------------------------------------------------------------------------
/surya/scripts/detect_text.py:
--------------------------------------------------------------------------------

```python
 1 | import click
 2 | import copy
 3 | import json
 4 | import time
 5 | from collections import defaultdict
 6 | 
 7 | from surya.detection import DetectionPredictor
 8 | from surya.debug.draw import draw_polys_on_image
 9 | from surya.logging import configure_logging, get_logger
10 | from surya.scripts.config import CLILoader
11 | import os
12 | 
13 | configure_logging()
14 | logger = get_logger()
15 | 
16 | 
17 | @click.command(help="Detect bboxes in an input file or folder (PDFs or image).")
18 | @CLILoader.common_options
19 | def detect_text_cli(input_path: str, **kwargs):
20 |     loader = CLILoader(input_path, kwargs)
21 | 
22 |     det_predictor = DetectionPredictor()
23 | 
24 |     start = time.time()
25 |     predictions = det_predictor(loader.images, include_maps=loader.debug)
26 |     end = time.time()
27 |     if loader.debug:
28 |         logger.debug(f"Detection took {end - start} seconds")
29 | 
30 |     if loader.save_images:
31 |         for idx, (image, pred, name) in enumerate(
32 |             zip(loader.images, predictions, loader.names)
33 |         ):
34 |             polygons = [p.polygon for p in pred.bboxes]
35 |             bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image))
36 |             bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_bbox.png"))
37 | 
38 |             if loader.debug:
39 |                 heatmap = pred.heatmap
40 |                 heatmap.save(os.path.join(loader.result_path, f"{name}_{idx}_heat.png"))
41 | 
42 |     predictions_by_page = defaultdict(list)
43 |     for idx, (pred, name, image) in enumerate(
44 |         zip(predictions, loader.names, loader.images)
45 |     ):
46 |         out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"])
47 |         out_pred["page"] = len(predictions_by_page[name]) + 1
48 |         predictions_by_page[name].append(out_pred)
49 | 
50 |     with open(
51 |         os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
52 |     ) as f:
53 |         json.dump(predictions_by_page, f, ensure_ascii=False)
54 | 
55 |     logger.info(f"Wrote results to {loader.result_path}")
56 | 
```

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------

```toml
 1 | [tool.poetry]
 2 | name = "surya-ocr"
 3 | version = "0.17.0"
 4 | description = "OCR, layout, reading order, and table recognition in 90+ languages"
 5 | authors = ["Vik Paruchuri <[email protected]>"]
 6 | readme = "README.md"
 7 | license = "GPL-3.0-or-later"
 8 | repository = "https://github.com/VikParuchuri/surya"
 9 | keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
10 | packages = [
11 |     {include = "surya"}
12 | ]
13 | 
14 | [tool.poetry.dependencies]
15 | python = "^3.10"
16 | transformers = ">=4.56.1"
17 | torch = "^2.7.0"
18 | pydantic = "^2.5.3"
19 | pydantic-settings = "^2.1.0"
20 | python-dotenv = "^1.0.0"
21 | pillow = "^10.2.0"
22 | pypdfium2 = "=4.30.0"
23 | filetype = "^1.2.0"
24 | click = "^8.1.8"
25 | platformdirs = "^4.3.6"
26 | opencv-python-headless = "==4.11.0.86"
27 | einops = "^0.8.1"
28 | pre-commit = "^4.2.0"
29 | 
30 | [tool.poetry.group.dev.dependencies]
31 | jupyter = "^1.0.0"
32 | pytesseract = "^0.3.10"
33 | pymupdf = "^1.23.8"
34 | datasets = "^2.16.1"
35 | rapidfuzz = "^3.6.1"
36 | streamlit = "^1.31.0"
37 | pytest = "^8.3.4"
38 | pdftext = "^0.5.1"
39 | tabulate = "^0.9.0"
40 | 
41 | [tool.poetry.scripts]
42 | surya_detect = "surya.scripts.detect_text:detect_text_cli"
43 | surya_ocr = "surya.scripts.ocr_text:ocr_text_cli"
44 | surya_layout = "surya.scripts.detect_layout:detect_layout_cli"
45 | surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli"
46 | surya_table = "surya.scripts.table_recognition:table_recognition_cli"
47 | surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli"
48 | texify_gui = "surya.scripts.run_texify_app:texify_app_cli"
49 | 
50 | [build-system]
51 | requires = ["poetry-core"]
52 | build-backend = "poetry.core.masonry.api"
53 | 
54 | [[tool.poetry.source]]
55 | name = "libtpu-releases"
56 | url = "https://storage.googleapis.com/libtpu-releases/index.html"
57 | priority = "supplemental"
58 | 
59 | [[tool.poetry.source]]
60 | name = "libtpu-wheels"
61 | url = "https://storage.googleapis.com/libtpu-wheels/index.html"
62 | priority = "supplemental"
63 | 
64 | [tool.poetry.group.xla]
65 | optional = true
66 | 
67 | [tool.poetry.group.xla.dependencies]
68 | torch-xla = {version = "^2.4.1", extras = ["tpu"]}
69 | 
```

--------------------------------------------------------------------------------
/.github/workflows/benchmarks.yml:
--------------------------------------------------------------------------------

```yaml
 1 | name: Integration test
 2 | 
 3 | on: [push]
 4 | 
 5 | env:
 6 |   PYTHONIOENCODING: "utf-8"
 7 | 
 8 | jobs:
 9 |   build:
10 |     runs-on: t4_gpu
11 |     steps:
12 |       - uses: actions/checkout@v3
13 |       - name: Set up Python 3.11
14 |         uses: actions/setup-python@v4
15 |         with:
16 |           python-version: 3.11
17 |       - name: Install python dependencies
18 |         run: |
19 |           pip install poetry
20 |           poetry install
21 |       - name: Run detection benchmark test
22 |         run: |
23 |           poetry run python benchmark/detection.py --max_rows 2
24 |           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
25 |       - name: Run recognition benchmark test
26 |         run: |
27 |           poetry run python benchmark/recognition.py --max_rows 2
28 |           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition
29 |       - name: Run layout benchmark test
30 |         run: |
31 |           poetry run python benchmark/layout.py --max_rows 5
32 |           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
33 |       - name: Run ordering benchmark
34 |         run: |
35 |           poetry run python benchmark/ordering.py --max_rows 5
36 |           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
37 |       - name: Run table recognition benchmark
38 |         run: |
39 |           poetry run python benchmark/table_recognition.py --max_rows 5
40 |           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition
41 |       - name: Run texify benchmark
42 |         run: |
43 |           poetry run python benchmark/texify.py --max_rows 5
44 |           poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify
```

--------------------------------------------------------------------------------
/surya/ocr_error/model/config.py:
--------------------------------------------------------------------------------

```python
 1 | from collections import OrderedDict
 2 | from typing import Mapping
 3 | 
 4 | from transformers.configuration_utils import PretrainedConfig
 5 | from transformers.onnx import OnnxConfig
 6 | 
 7 | from surya.common.s3 import S3DownloaderMixin
 8 | 
 9 | ID2LABEL = {
10 |     0: 'good',
11 |     1: 'bad'
12 | }
13 | 
14 | class DistilBertConfig(S3DownloaderMixin, PretrainedConfig):
15 |     model_type = "distilbert"
16 |     attribute_map = {
17 |         "hidden_size": "dim",
18 |         "num_attention_heads": "n_heads",
19 |         "num_hidden_layers": "n_layers",
20 |     }
21 | 
22 |     def __init__(
23 |         self,
24 |         vocab_size=30522,
25 |         max_position_embeddings=512,
26 |         sinusoidal_pos_embds=False,
27 |         n_layers=6,
28 |         n_heads=12,
29 |         dim=768,
30 |         hidden_dim=4 * 768,
31 |         dropout=0.1,
32 |         attention_dropout=0.1,
33 |         activation="gelu",
34 |         initializer_range=0.02,
35 |         qa_dropout=0.1,
36 |         seq_classif_dropout=0.2,
37 |         pad_token_id=0,
38 |         **kwargs,
39 |     ):
40 |         self.vocab_size = vocab_size
41 |         self.max_position_embeddings = max_position_embeddings
42 |         self.sinusoidal_pos_embds = sinusoidal_pos_embds
43 |         self.n_layers = n_layers
44 |         self.n_heads = n_heads
45 |         self.dim = dim
46 |         self.hidden_dim = hidden_dim
47 |         self.dropout = dropout
48 |         self.attention_dropout = attention_dropout
49 |         self.activation = activation
50 |         self.initializer_range = initializer_range
51 |         self.qa_dropout = qa_dropout
52 |         self.seq_classif_dropout = seq_classif_dropout
53 |         super().__init__(**kwargs, pad_token_id=pad_token_id)
54 | 
55 | 
56 | class DistilBertOnnxConfig(OnnxConfig):
57 |     @property
58 |     def inputs(self) -> Mapping[str, Mapping[int, str]]:
59 |         if self.task == "multiple-choice":
60 |             dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
61 |         else:
62 |             dynamic_axis = {0: "batch", 1: "sequence"}
63 |         return OrderedDict(
64 |             [
65 |                 ("input_ids", dynamic_axis),
66 |                 ("attention_mask", dynamic_axis),
67 |             ]
68 |         )
```

--------------------------------------------------------------------------------
/surya/debug/katex.js:
--------------------------------------------------------------------------------

```javascript
 1 | <style>
 2 |     .katex-display-container {
 3 |         display: inline-block;
 4 |         max-width: 100%;
 5 |         overflow-x: auto;
 6 |         max-height: 100%;
 7 |     }
 8 | 
 9 |     .katex-inline-container {
10 |         display: inline-block;
11 |         max-width: 100%;
12 |         overflow-x: auto;
13 |         max-height: 100%;
14 |     }
15 | </style>
16 | <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.js" onload="setTimeout(function() {renderMath()})" async></script>
17 | <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css">
18 | <script>
19 |     function htmlUnescape(escapedText) {
20 |       const htmlEntities = {
21 |         '&amp;': '&',
22 |         '&lt;': '<',
23 |         '&gt;': '>',
24 |         '&quot;': '"',
25 |         '&#39;': "'",
26 |         '&nbsp;': ' '
27 |       };
28 | 
29 |       return escapedText.replace(/&amp;|&lt;|&gt;|&quot;|&#39;|&nbsp;/g, match => htmlEntities[match]);
30 |     }
31 | 
32 |     const renderMath = (function() {
33 |     try {
34 |        const mathElements = document.querySelectorAll('math');
35 | 
36 |         mathElements.forEach(function(element) {
37 |           let mathContent = element.innerHTML.trim();
38 |           mathContent = htmlUnescape(mathContent);
39 |           const isDisplay = element.getAttribute('display') === 'block';
40 | 
41 |           const container = document.createElement('span');
42 |           container.className = isDisplay ? 'katex-display-container' : 'katex-inline-container';
43 |           element.parentNode.insertBefore(container, element);
44 | 
45 |           try {
46 |             katex.render(mathContent, container, {
47 |               displayMode: isDisplay,
48 |               throwOnError: false
49 |             });
50 | 
51 |           } catch (err) {
52 |             console.error('KaTeX rendering error:', err);
53 |             container.textContent = mathContent; // Fallback to raw text
54 |           }
55 | 
56 |           element.parentNode.removeChild(element);
57 |         });
58 | 
59 |         console.log('Math rendering complete with', mathElements.length, 'expressions');
60 |       } catch (err) {
61 |         console.error('Error in renderMath function:', err);
62 |       }
63 |     });
64 | </script>
```

--------------------------------------------------------------------------------
/surya/scripts/detect_layout.py:
--------------------------------------------------------------------------------

```python
 1 | import time
 2 | import click
 3 | import copy
 4 | import json
 5 | from collections import defaultdict
 6 | 
 7 | from surya.foundation import FoundationPredictor
 8 | from surya.layout import LayoutPredictor
 9 | from surya.debug.draw import draw_polys_on_image
10 | from surya.logging import configure_logging, get_logger
11 | from surya.scripts.config import CLILoader
12 | from surya.settings import settings
13 | import os
14 | 
15 | configure_logging()
16 | logger = get_logger()
17 | 
18 | 
19 | @click.command(help="Detect layout of an input file or folder (PDFs or image).")
20 | @CLILoader.common_options
21 | def detect_layout_cli(input_path: str, **kwargs):
22 |     loader = CLILoader(input_path, kwargs)
23 | 
24 |     foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
25 |     layout_predictor = LayoutPredictor(foundation_predictor)
26 | 
27 |     start = time.time()
28 |     layout_predictions = layout_predictor(loader.images)
29 | 
30 |     if loader.debug:
31 |         logger.debug(f"Layout took {time.time() - start} seconds")
32 | 
33 |     if loader.save_images:
34 |         for idx, (image, layout_pred, name) in enumerate(
35 |             zip(loader.images, layout_predictions, loader.names)
36 |         ):
37 |             polygons = [p.polygon for p in layout_pred.bboxes]
38 |             labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes]
39 |             bbox_image = draw_polys_on_image(
40 |                 polygons, copy.deepcopy(image), labels=labels
41 |             )
42 |             bbox_image.save(
43 |                 os.path.join(loader.result_path, f"{name}_{idx}_layout.png")
44 |             )
45 | 
46 |     predictions_by_page = defaultdict(list)
47 |     for idx, (pred, name, image) in enumerate(
48 |         zip(layout_predictions, loader.names, loader.images)
49 |     ):
50 |         out_pred = pred.model_dump()
51 |         out_pred["page"] = len(predictions_by_page[name]) + 1
52 |         predictions_by_page[name].append(out_pred)
53 | 
54 |     with open(
55 |         os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
56 |     ) as f:
57 |         json.dump(predictions_by_page, f, ensure_ascii=False)
58 | 
59 |     logger.info(f"Wrote results to {loader.result_path}")
60 | 
```

--------------------------------------------------------------------------------
/surya/ocr_error/loader.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional
 2 | 
 3 | import torch
 4 | 
 5 | from surya.common.load import ModelLoader
 6 | from surya.logging import get_logger
 7 | from surya.ocr_error.model.config import DistilBertConfig
 8 | from surya.ocr_error.model.encoder import DistilBertForSequenceClassification
 9 | from surya.ocr_error.tokenizer import DistilBertTokenizer
10 | from surya.settings import settings
11 | 
12 | logger = get_logger()
13 | 
14 | 
15 | class OCRErrorModelLoader(ModelLoader):
16 |     def __init__(self, checkpoint: Optional[str] = None):
17 |         super().__init__(checkpoint)
18 | 
19 |         if self.checkpoint is None:
20 |             self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT
21 | 
22 |     def model(
23 |         self,
24 |         device=settings.TORCH_DEVICE_MODEL,
25 |         dtype=settings.MODEL_DTYPE,
26 |         attention_implementation: Optional[str] = None,
27 |     ) -> DistilBertForSequenceClassification:
28 |         if device is None:
29 |             device = settings.TORCH_DEVICE_MODEL
30 |         if dtype is None:
31 |             dtype = settings.MODEL_DTYPE
32 | 
33 |         config = DistilBertConfig.from_pretrained(self.checkpoint)
34 |         model = (
35 |             DistilBertForSequenceClassification.from_pretrained(
36 |                 self.checkpoint,
37 |                 dtype=dtype,
38 |                 config=config,
39 |             )
40 |             .to(device)
41 |             .eval()
42 |         )
43 | 
44 |         if settings.COMPILE_ALL or settings.COMPILE_OCR_ERROR:
45 |             torch._dynamo.config.cache_size_limit = 1
46 |             torch._dynamo.config.suppress_errors = False
47 | 
48 |             logger.info(
49 |                 f"Compiling OCR error model {self.checkpoint} on device {device} with dtype {dtype}"
50 |             )
51 |             compile_args = {"backend": "openxla"} if device == "xla" else {}
52 |             model = torch.compile(model, **compile_args)
53 | 
54 |         return model
55 | 
56 |     def processor(
57 |         self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE
58 |     ) -> DistilBertTokenizer:
59 |         return DistilBertTokenizer.from_pretrained(self.checkpoint)
60 | 
```

--------------------------------------------------------------------------------
/benchmark/utils/verify_benchmark_scores.py:
--------------------------------------------------------------------------------

```python
 1 | import json
 2 | import click
 3 | 
 4 | 
 5 | def verify_layout(data):
 6 |     scores = data["metrics"]
 7 |     for layout_type, metrics in scores.items():
 8 |         if layout_type == "List":  # Skip lists since none appear early on
 9 |             continue
10 | 
11 |         if metrics["precision"] <= 0.6 or metrics["recall"] <= 0.6:
12 |             raise ValueError("Scores do not meet the required threshold")
13 | 
14 | 
15 | def verify_det(data):
16 |     scores = data["metrics"]["surya"]
17 |     if scores["precision"] <= 0.9 or scores["recall"] <= 0.9:
18 |         raise ValueError("Scores do not meet the required threshold")
19 | 
20 | 
21 | def verify_rec(data):
22 |     scores = data["surya"]
23 |     if scores["avg_score"] <= 0.9:
24 |         raise ValueError("Scores do not meet the required threshold")
25 | 
26 | 
27 | def verify_order(data):
28 |     score = data["mean_accuracy"]
29 |     if score < 0.75:
30 |         raise ValueError("Scores do not meet the required threshold")
31 | 
32 | 
33 | def verify_table_rec(data):
34 |     row_score = data["surya"]["mean_row_iou"]
35 |     col_score = data["surya"]["mean_col_iou"]
36 | 
37 |     if row_score < 0.75 or col_score < 0.75:
38 |         raise ValueError("Scores do not meet the required threshold")
39 | 
40 | 
41 | def verify_texify(data):
42 |     edit_dist = data["scores"]
43 |     if edit_dist > 0.2:
44 |         raise ValueError("Scores do not meet the required threshold")
45 | 
46 | 
47 | @click.command(help="Verify benchmark scores")
48 | @click.argument("file_path", type=str)
49 | @click.option(
50 |     "--bench_type", type=str, help="Type of benchmark to verify", default="detection"
51 | )
52 | def main(file_path, bench_type):
53 |     with open(file_path, "r") as file:
54 |         data = json.load(file)
55 | 
56 |     if bench_type == "detection":
57 |         verify_det(data)
58 |     elif bench_type == "recognition":
59 |         verify_rec(data)
60 |     elif bench_type == "layout":
61 |         verify_layout(data)
62 |     elif bench_type == "ordering":
63 |         verify_order(data)
64 |     elif bench_type == "table_recognition":
65 |         verify_table_rec(data)
66 |     elif bench_type == "texify":
67 |         verify_texify(data)
68 |     else:
69 |         raise ValueError("Invalid benchmark type")
70 | 
71 | 
72 | if __name__ == "__main__":
73 |     main()
74 | 
```
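
The verifiers can also be called directly on a loaded results dict; a small sketch with hypothetical scores that follow the layout schema read by `verify_layout`.

```python
from benchmark.utils.verify_benchmark_scores import verify_layout

# Hypothetical benchmark output; keys mirror data["metrics"] above.
data = {
    "metrics": {
        "Text": {"precision": 0.92, "recall": 0.88},
        "Table": {"precision": 0.81, "recall": 0.79},
        "List": {"precision": 0.0, "recall": 0.0},  # "List" is skipped by verify_layout
    }
}

verify_layout(data)  # raises ValueError if any precision or recall is <= 0.6
```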

--------------------------------------------------------------------------------
/surya/debug/draw.py:
--------------------------------------------------------------------------------

```python
 1 | from PIL import ImageDraw, ImageFont
 2 | 
 3 | from surya.debug.fonts import get_font_path
 4 | from surya.debug.text import get_text_size
 5 | 
 6 | 
 7 | def draw_bboxes_on_image(
 8 |     bboxes, image, labels=None, label_font_size=10, color: str | list = "red"
 9 | ):
10 |     polys = []
11 |     for bb in bboxes:
12 |         # Clockwise polygon
13 |         poly = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
14 |         polys.append(poly)
15 | 
16 |     return draw_polys_on_image(
17 |         polys, image, labels, label_font_size=label_font_size, color=color
18 |     )
19 | 
20 | 
21 | def draw_polys_on_image(
22 |     corners,
23 |     image,
24 |     labels=None,
25 |     box_padding=-1,
26 |     label_offset=1,
27 |     label_font_size=10,
28 |     color: str | list = "red",
29 | ):
30 |     draw = ImageDraw.Draw(image)
31 |     font_path = get_font_path()
32 |     label_font = ImageFont.truetype(font_path, label_font_size)
33 | 
34 |     for i in range(len(corners)):
35 |         poly = corners[i]
36 |         poly = [(int(p[0]), int(p[1])) for p in poly]
37 |         draw.polygon(
38 |             poly, outline=color[i] if isinstance(color, list) else color, width=1
39 |         )
40 | 
41 |         if labels is not None:
42 |             label = labels[i]
43 |             text_position = (
44 |                 min([p[0] for p in poly]) + label_offset,
45 |                 min([p[1] for p in poly]) + label_offset,
46 |             )
47 |             text_size = get_text_size(label, label_font)
48 |             box_position = (
49 |                 text_position[0] - box_padding + label_offset,
50 |                 text_position[1] - box_padding + label_offset,
51 |                 text_position[0] + text_size[0] + box_padding + label_offset,
52 |                 text_position[1] + text_size[1] + box_padding + label_offset,
53 |             )
54 |             try:
55 |                 draw.rectangle(box_position, fill="white")
56 |             except Exception as e:
57 |                 print(f"Error drawing rectangle at {box_position}: {e}")
58 |                 continue
59 |             draw.text(
60 |                 text_position,
61 |                 label,
62 |                 fill=color[i] if isinstance(color, list) else color,
63 |                 font=label_font,
64 |             )
65 | 
66 |     return image
67 | 
```
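
A quick debugging sketch (not part of the repository) that draws hypothetical axis-aligned boxes onto a blank image with `draw_bboxes_on_image`.

```python
from PIL import Image

from surya.debug.draw import draw_bboxes_on_image

image = Image.new("RGB", (400, 200), "white")
bboxes = [[10, 10, 150, 60], [170, 80, 390, 140]]  # (x1, y1, x2, y2) pixel coords
labels = ["Title", "Paragraph"]

# Boxes are converted to clockwise polygons and drawn with their labels.
annotated = draw_bboxes_on_image(bboxes, image, labels=labels, color="blue")
annotated.save("debug_boxes.png")
```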

--------------------------------------------------------------------------------
/surya/recognition/languages.py:
--------------------------------------------------------------------------------

```python
 1 | CODE_TO_LANGUAGE = {
 2 |     "_math": "Math",
 3 |     "af": "Afrikaans",
 4 |     "am": "Amharic",
 5 |     "ar": "Arabic",
 6 |     "as": "Assamese",
 7 |     "az": "Azerbaijani",
 8 |     "be": "Belarusian",
 9 |     "bg": "Bulgarian",
10 |     "bn": "Bengali",
11 |     "br": "Breton",
12 |     "bs": "Bosnian",
13 |     "ca": "Catalan",
14 |     "cs": "Czech",
15 |     "cy": "Welsh",
16 |     "da": "Danish",
17 |     "de": "German",
18 |     "el": "Greek",
19 |     "en": "English",
20 |     "eo": "Esperanto",
21 |     "es": "Spanish",
22 |     "et": "Estonian",
23 |     "eu": "Basque",
24 |     "fa": "Persian",
25 |     "fi": "Finnish",
26 |     "fr": "French",
27 |     "fy": "Western Frisian",
28 |     "ga": "Irish",
29 |     "gd": "Scottish Gaelic",
30 |     "gl": "Galician",
31 |     "gu": "Gujarati",
32 |     "ha": "Hausa",
33 |     "he": "Hebrew",
34 |     "hi": "Hindi",
35 |     "hr": "Croatian",
36 |     "hu": "Hungarian",
37 |     "hy": "Armenian",
38 |     "id": "Indonesian",
39 |     "is": "Icelandic",
40 |     "it": "Italian",
41 |     "ja": "Japanese",
42 |     "jv": "Javanese",
43 |     "ka": "Georgian",
44 |     "kk": "Kazakh",
45 |     "km": "Khmer",
46 |     "kn": "Kannada",
47 |     "ko": "Korean",
48 |     "ku": "Kurdish",
49 |     "ky": "Kyrgyz",
50 |     "la": "Latin",
51 |     "lo": "Lao",
52 |     "lt": "Lithuanian",
53 |     "lv": "Latvian",
54 |     "mg": "Malagasy",
55 |     "mk": "Macedonian",
56 |     "ml": "Malayalam",
57 |     "mn": "Mongolian",
58 |     "mr": "Marathi",
59 |     "ms": "Malay",
60 |     "my": "Burmese",
61 |     "ne": "Nepali",
62 |     "nl": "Dutch",
63 |     "no": "Norwegian",
64 |     "om": "Oromo",
65 |     "or": "Oriya",
66 |     "pa": "Punjabi",
67 |     "pl": "Polish",
68 |     "ps": "Pashto",
69 |     "pt": "Portuguese",
70 |     "ro": "Romanian",
71 |     "ru": "Russian",
72 |     "sa": "Sanskrit",
73 |     "sd": "Sindhi",
74 |     "si": "Sinhala",
75 |     "sk": "Slovak",
76 |     "sl": "Slovenian",
77 |     "so": "Somali",
78 |     "sq": "Albanian",
79 |     "sr": "Serbian",
80 |     "su": "Sundanese",
81 |     "sv": "Swedish",
82 |     "sw": "Swahili",
83 |     "ta": "Tamil",
84 |     "te": "Telugu",
85 |     "th": "Thai",
86 |     "tl": "Tagalog",
87 |     "tr": "Turkish",
88 |     "ug": "Uyghur",
89 |     "uk": "Ukrainian",
90 |     "ur": "Urdu",
91 |     "uz": "Uzbek",
92 |     "vi": "Vietnamese",
93 |     "xh": "Xhosa",
94 |     "yi": "Yiddish",
95 |     "zh": "Chinese",
96 | }
97 | 
98 | LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()}
99 | 
```
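
The two mappings are simple lookups; for example:

```python
from surya.recognition.languages import CODE_TO_LANGUAGE, LANGUAGE_TO_CODE

assert CODE_TO_LANGUAGE["en"] == "English"
assert LANGUAGE_TO_CODE["German"] == "de"
assert CODE_TO_LANGUAGE["_math"] == "Math"  # pseudo-language code for math content
```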

--------------------------------------------------------------------------------
/surya/common/surya/embedder/__init__.py:
--------------------------------------------------------------------------------

```python
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | 
 6 | class SimpleTokenEmbedder(nn.Module):
 7 |     def __init__(self, config):
 8 |         super().__init__()
 9 |         self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
10 |         self.bbox_embed = nn.ModuleList(
11 |             [
12 |                 nn.Embedding(
13 |                     config.bbox_size + config.special_token_count,
14 |                     config.bbox_embed_size,
15 |                 )
16 |                 for _ in range(6)
17 |             ]
18 |         )
19 |         self.max_bbox_embedding = config.bbox_size + config.special_token_count - 1
20 |         self.max_bbox_size = config.bbox_size
21 | 
22 |     def embed(
23 |         self,
24 |         input_tokens: torch.Tensor,
25 |         input_boxes: torch.Tensor | None,
26 |         embed_boxes: torch.Tensor,
27 |     ) -> torch.Tensor:
28 |         # Embed tokens
29 |         token_embeds = self.token_embed(input_tokens)
30 | 
31 |         # Optionally embed boxes
32 |         if input_boxes is not None and embed_boxes.any():  # Is none in prefill
33 |             input_boxes = input_boxes.to(torch.long)
34 |             bbox_loss_ignore_mask = (
35 |                 (input_boxes[:, :, 0] < 0) | (input_boxes[:, :, 0] > self.max_bbox_size)
36 |             ).unsqueeze(-1)
37 |             input_boxes = torch.clamp(input_boxes, 0, self.max_bbox_embedding)
38 | 
39 |             bbox_embeds = torch.sum(
40 |                 torch.stack(
41 |                     [
42 |                         self.bbox_embed[i](input_boxes[:, :, i])
43 |                         for i in range(len(self.bbox_embed))
44 |                     ],
45 |                     dim=-1,
46 |                 ),
47 |                 dim=-1,
48 |             )
49 | 
50 |             bbox_embeds = F.pad(
51 |                 bbox_embeds, (token_embeds.shape[-1] - bbox_embeds.shape[-1], 0)
52 |             )
53 |             embed_boxes = embed_boxes.unsqueeze(1).unsqueeze(1).expand_as(bbox_embeds)
54 |             bbox_loss_ignore_mask = bbox_loss_ignore_mask.expand_as(bbox_embeds)
55 | 
56 |             mask = embed_boxes & ~bbox_loss_ignore_mask
57 |             bbox_embeds *= mask.float()
58 | 
59 |             token_embeds = token_embeds + bbox_embeds
60 | 
61 |         return token_embeds
62 | 
```
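
A toy sketch (not part of the repository) of how `embed` combines token and bbox embeddings; the config values here are hypothetical stand-ins for what the real checkpoint provides.

```python
from types import SimpleNamespace

import torch

from surya.common.surya.embedder import SimpleTokenEmbedder

# Hypothetical, tiny config; real values come from the model checkpoint.
config = SimpleNamespace(
    vocab_size=100, hidden_size=32, bbox_size=1024,
    special_token_count=4, bbox_embed_size=8,
)
embedder = SimpleTokenEmbedder(config)

batch, seq = 2, 5
tokens = torch.randint(0, config.vocab_size, (batch, seq))
boxes = torch.randint(0, config.bbox_size, (batch, seq, 6))  # 6 bbox coordinates per token
embed_boxes = torch.tensor([True, False])  # per-sample switch for adding bbox embeddings

with torch.no_grad():
    out = embedder.embed(tokens, boxes, embed_boxes)
print(out.shape)  # torch.Size([2, 5, 32]); bbox embeddings are padded and added to token embeddings
```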

--------------------------------------------------------------------------------
/surya/detection/loader.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional
 2 | 
 3 | import torch
 4 | 
 5 | from surya.common.load import ModelLoader
 6 | from surya.detection.processor import SegformerImageProcessor
 7 | 
 8 | from surya.detection.model.config import EfficientViTConfig
 9 | from surya.detection.model.encoderdecoder import EfficientViTForSemanticSegmentation
10 | from surya.logging import get_logger
11 | from surya.settings import settings
12 | 
13 | logger = get_logger()
14 | 
15 | 
16 | class DetectionModelLoader(ModelLoader):
17 |     def __init__(self, checkpoint: Optional[str] = None):
18 |         super().__init__(checkpoint)
19 | 
20 |         if self.checkpoint is None:
21 |             self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
22 | 
23 |     def model(
24 |         self,
25 |         device: Optional[torch.device | str] = None,
26 |         dtype: Optional[torch.dtype | str] = None,
27 |         attention_implementation: Optional[str] = None,
28 |     ) -> EfficientViTForSemanticSegmentation:
29 |         if device is None:
30 |             device = settings.TORCH_DEVICE_MODEL
31 |         if dtype is None:
32 |             dtype = settings.MODEL_DTYPE
33 | 
34 |         config = EfficientViTConfig.from_pretrained(self.checkpoint)
35 |         model = EfficientViTForSemanticSegmentation.from_pretrained(
36 |             self.checkpoint,
37 |             dtype=dtype,
38 |             config=config,
39 |         )
40 |         model = model.to(device)
41 |         model = model.eval()
42 | 
43 |         if settings.COMPILE_ALL or settings.COMPILE_DETECTOR:
44 |             torch._dynamo.config.cache_size_limit = 1
45 |             torch._dynamo.config.suppress_errors = False
46 | 
47 |             logger.info(
48 |                 f"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}"
49 |             )
50 |             compile_args = {"backend": "openxla"} if device == "xla" else {}
51 |             model = torch.compile(model, **compile_args)
52 | 
53 |         logger.debug(
54 |             f"Loaded detection model {self.checkpoint} from {EfficientViTForSemanticSegmentation.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}"
55 |         )
56 |         return model
57 | 
58 |     def processor(
59 |         self,
60 |         device: Optional[torch.device | str] = None,
61 |         dtype: Optional[torch.dtype | str] = None,
62 |     ) -> SegformerImageProcessor:
63 |         return SegformerImageProcessor.from_pretrained(self.checkpoint)
64 | 
```

--------------------------------------------------------------------------------
/surya/scripts/hf_to_s3.py:
--------------------------------------------------------------------------------

```python
 1 | import json
 2 | import shutil
 3 | import datetime
 4 | from pathlib import Path
 5 | import boto3
 6 | 
 7 | from huggingface_hub import snapshot_download
 8 | 
 9 | import click
10 | from tqdm import tqdm
11 | 
12 | S3_API_URL = "https://1afbe4656a6b40d982ab5e730a39f6b9.r2.cloudflarestorage.com"
13 | 
14 | 
15 | # Example usage - python scripts/hf_to_s3.py <REPO_NAME> layout
16 | # This will upload to s3://layout/TODAYS_DATE
17 | @click.command(help="Uploads the data from huggingface to an S3 bucket")
18 | @click.argument("hf_repo_id", type=str)
19 | @click.argument("s3_path", type=str)
20 | @click.option("--bucket_name", type=str, default="datalab")
21 | @click.option("--revision_hash", type=str, default=None)
22 | @click.option("--access_key_id", type=str, default="<access_key_id>")
23 | @click.option("--access_key_secret", type=str, default="<access_key_secret>")
24 | @click.option("--suffix", type=str, default="")
25 | def main(
26 |     hf_repo_id: str,
27 |     s3_path: str,
28 |     bucket_name: str,
29 |     revision_hash: str,
30 |     access_key_id: str,
31 |     access_key_secret: str,
32 |     suffix: str,
33 | ):
34 |     curr_date = datetime.datetime.now().strftime("%Y_%m_%d")
35 |     s3_path = f"{s3_path}/{curr_date}"
36 |     if suffix:
37 |         s3_path = f"{s3_path}_{suffix}"
38 | 
39 |     download_folder = snapshot_download(repo_id=hf_repo_id, revision=revision_hash)
40 |     download_folder = Path(download_folder)
41 |     contained_files = list(download_folder.glob("*"))
42 |     contained_files = [f.name for f in contained_files]  # Just get the base name
43 |     manifest_file = download_folder / "manifest.json"
44 | 
45 |     with open(manifest_file, "w") as f:
46 |         json.dump({"files": contained_files}, f)
47 | 
48 |     # Upload the files to S3
49 |     s3_client = boto3.client(
50 |         service_name="s3",
51 |         endpoint_url=S3_API_URL,
52 |         aws_access_key_id=access_key_id,
53 |         aws_secret_access_key=access_key_secret,
54 |         region_name="auto",
55 |     )
56 | 
57 |     # Iterate through all files in the folder
58 |     for file_path in tqdm(
59 |         download_folder.glob("*"), desc="Uploading files", unit="file"
60 |     ):
61 |         s3_key = f"{s3_path}/{file_path.name}"
62 | 
63 |         try:
64 |             s3_client.upload_file(str(file_path), bucket_name, s3_key)
65 |         except Exception as e:
66 |             print(f"Error uploading {file_path}: {str(e)}")
67 | 
68 |     shutil.rmtree(download_folder)
69 | 
70 |     print(f"Uploaded files to {s3_path}")
71 | 
72 | 
73 | if __name__ == "__main__":
74 |     main()
75 | 
```

--------------------------------------------------------------------------------
/surya/ocr_error/__init__.py:
--------------------------------------------------------------------------------

```python
 1 | import math
 2 | from typing import List, Optional
 3 | 
 4 | from tqdm import tqdm
 5 | 
 6 | from surya.common.predictor import BasePredictor
 7 | from surya.ocr_error.loader import OCRErrorModelLoader
 8 | from surya.ocr_error.model.config import ID2LABEL
 9 | from surya.ocr_error.schema import OCRErrorDetectionResult
10 | from surya.settings import settings
11 | from surya.common.xla import mark_step
12 | 
13 | 
14 | class OCRErrorPredictor(BasePredictor):
15 |     model_loader_cls = OCRErrorModelLoader
16 |     batch_size = settings.OCR_ERROR_BATCH_SIZE
17 |     default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 64, "xla": 32}
18 | 
19 |     def __call__(self, texts: List[str], batch_size: Optional[int] = None):
20 |         return self.batch_ocr_error_detection(texts, batch_size)
21 | 
22 |     def batch_ocr_error_detection(
23 |         self, texts: List[str], batch_size: Optional[int] = None
24 |     ):
25 |         if batch_size is None:
26 |             batch_size = self.get_batch_size()
27 | 
28 |         num_batches = math.ceil(len(texts) / batch_size)
29 |         texts_processed = self.processor(
30 |             texts, padding="longest", truncation=True, return_tensors="pt"
31 |         )
32 |         predictions = []
33 |         for batch_idx in tqdm(
34 |             range(num_batches),
35 |             desc="Running OCR Error Detection",
36 |             disable=self.disable_tqdm,
37 |         ):
38 |             start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size
39 |             batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to(
40 |                 self.model.device
41 |             )
42 |             batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to(
43 |                 self.model.device
44 |             )
45 | 
46 |             # Pad to batch size
47 |             current_batch_size = batch_input_ids.shape[0]
48 |             if settings.OCR_ERROR_STATIC_CACHE:
49 |                 batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)
50 |                 batch_attention_mask = self.pad_to_batch_size(
51 |                     batch_attention_mask, batch_size
52 |                 )
53 | 
54 |             with settings.INFERENCE_MODE():
55 |                 pred = self.model(batch_input_ids, attention_mask=batch_attention_mask)
56 | 
57 |                 logits = pred.logits.argmax(dim=1).cpu().tolist()[:current_batch_size]
58 |                 predictions.extend(logits)
59 |             mark_step()
60 | 
61 |         return OCRErrorDetectionResult(
62 |             texts=texts, labels=[ID2LABEL[p] for p in predictions]
63 |         )
64 | 
```
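
A minimal usage sketch (not part of the repository); the example strings are hypothetical, and the labels come from `ID2LABEL` in the model config.

```python
from surya.ocr_error import OCRErrorPredictor

predictor = OCRErrorPredictor()  # loads the DistilBERT checkpoint

texts = [
    "This is a clean sentence extracted from a PDF text layer.",
    "Th1s l00ks l1ke a c0rrupted t3xt l@yer.",
]
result = predictor(texts)

# OCRErrorDetectionResult pairs each input text with a predicted label.
for text, label in zip(result.texts, result.labels):
    print(label, "|", text)
```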

--------------------------------------------------------------------------------
/surya/input/load.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import List
 2 | import PIL
 3 | 
 4 | from surya.input.processing import open_pdf, get_page_images
 5 | from surya.logging import get_logger
 6 | from surya.settings import settings
 7 | import os
 8 | import filetype
 9 | from PIL import Image
10 | import json
11 | 
12 | logger = get_logger()
13 | 
14 | 
15 | def get_name_from_path(path):
16 |     return os.path.basename(path).split(".")[0]
17 | 
18 | 
19 | def load_pdf(pdf_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI):
20 |     doc = open_pdf(pdf_path)
21 |     last_page = len(doc)
22 | 
23 |     if page_range:
24 |         assert all([0 <= page < last_page for page in page_range]), (
25 |             f"Invalid page range: {page_range}"
26 |         )
27 |     else:
28 |         page_range = list(range(last_page))
29 | 
30 |     images = get_page_images(doc, page_range, dpi=dpi)
31 |     doc.close()
32 |     names = [get_name_from_path(pdf_path) for _ in page_range]
33 |     return images, names
34 | 
35 | 
36 | def load_image(image_path):
37 |     image = Image.open(image_path).convert("RGB")
38 |     name = get_name_from_path(image_path)
39 |     return [image], [name]
40 | 
41 | 
42 | def load_from_file(
43 |     input_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI
44 | ):
45 |     input_type = filetype.guess(input_path)
46 |     if input_type and input_type.extension == "pdf":
47 |         return load_pdf(input_path, page_range, dpi=dpi)
48 |     else:
49 |         return load_image(input_path)
50 | 
51 | 
52 | def load_from_folder(
53 |     folder_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI
54 | ):
55 |     image_paths = [
56 |         os.path.join(folder_path, image_name)
57 |         for image_name in os.listdir(folder_path)
58 |         if not image_name.startswith(".")
59 |     ]
60 |     image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]
61 | 
62 |     images = []
63 |     names = []
64 |     for path in image_paths:
65 |         extension = filetype.guess(path)
66 |         if extension and extension.extension == "pdf":
67 |             image, name = load_pdf(path, page_range, dpi=dpi)
68 |             images.extend(image)
69 |             names.extend(name)
70 |         else:
71 |             try:
72 |                 image, name = load_image(path)
73 |                 images.extend(image)
74 |                 names.extend(name)
75 |             except PIL.UnidentifiedImageError:
76 |                 logger.warning(f"Could not load image {path}")
77 |                 continue
78 |     return images, names
79 | 
80 | 
81 | def load_lang_file(lang_path, names):
82 |     with open(lang_path, "r") as f:
83 |         lang_dict = json.load(f)
84 |     return [lang_dict[name].copy() for name in names]
85 | 
```
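
Usage sketch (not part of the repository); `paper.pdf` and `scans/` are hypothetical paths.

```python
from surya.input.load import load_from_file, load_from_folder

# A single PDF or image: returns one PIL image per (selected) page plus names.
images, names = load_from_file("paper.pdf", page_range=[0, 1, 2])
print(len(images), names)

# A folder may mix PDFs and images; unreadable images are skipped with a warning.
folder_images, folder_names = load_from_folder("scans/")
```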

--------------------------------------------------------------------------------
/surya/scripts/ocr_text.py:
--------------------------------------------------------------------------------

```python
 1 | import os
 2 | import click
 3 | import json
 4 | import time
 5 | from collections import defaultdict
 6 | 
 7 | from surya.common.surya.schema import TaskNames
 8 | from surya.detection import DetectionPredictor
 9 | from surya.debug.text import draw_text_on_image
10 | from surya.logging import configure_logging, get_logger
11 | from surya.foundation import FoundationPredictor
12 | from surya.recognition import RecognitionPredictor
13 | from surya.scripts.config import CLILoader
14 | 
15 | configure_logging()
16 | logger = get_logger()
17 | 
18 | 
19 | @click.command(help="OCR text.")
20 | @click.option("--task_name", type=str, default=TaskNames.ocr_with_boxes)
21 | @click.option(
22 |     "--disable_math", is_flag=True, default=False, help="Do not recognize math in OCR."
23 | )
24 | @CLILoader.common_options
25 | def ocr_text_cli(input_path: str, task_name: str, disable_math: bool, **kwargs):
26 |     loader = CLILoader(input_path, kwargs, highres=True)
27 |     task_names = [task_name] * len(loader.images)
28 | 
29 |     foundation_predictor = FoundationPredictor()
30 |     det_predictor = DetectionPredictor()
31 |     rec_predictor = RecognitionPredictor(foundation_predictor)
32 | 
33 |     start = time.time()
34 |     predictions_by_image = rec_predictor(
35 |         loader.images,
36 |         task_names=task_names,
37 |         det_predictor=det_predictor,
38 |         highres_images=loader.highres_images,
39 |         math_mode=not disable_math,
40 |     )
41 | 
42 |     if loader.debug:
43 |         logger.debug(f"OCR took {time.time() - start:.2f} seconds")
44 |         max_chars = max(
45 |             [len(line.text) for p in predictions_by_image for line in p.text_lines]
46 |         )
47 |         logger.debug(f"Max chars: {max_chars}")
48 | 
49 |     if loader.save_images:
50 |         for idx, (name, image, pred) in enumerate(
51 |             zip(loader.names, loader.images, predictions_by_image)
52 |         ):
53 |             bboxes = [line.bbox for line in pred.text_lines]
54 |             pred_text = [line.text for line in pred.text_lines]
55 |             page_image = draw_text_on_image(bboxes, pred_text, image.size)
56 |             page_image.save(os.path.join(loader.result_path, f"{name}_{idx}_text.png"))
57 | 
58 |     out_preds = defaultdict(list)
59 |     for name, pred, image in zip(loader.names, predictions_by_image, loader.images):
60 |         out_pred = pred.model_dump()
61 |         out_pred["page"] = len(out_preds[name]) + 1
62 |         out_preds[name].append(out_pred)
63 | 
64 |     with open(
65 |         os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
66 |     ) as f:
67 |         json.dump(out_preds, f, ensure_ascii=False)
68 | 
69 |     logger.info(f"Wrote results to {loader.result_path}")
70 | 
```
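
A programmatic sketch (not part of the repository) of the same OCR flow without the CLI loader; `page.png` is hypothetical, and the optional high-resolution pass is omitted.

```python
from PIL import Image

from surya.common.surya.schema import TaskNames
from surya.detection import DetectionPredictor
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor

image = Image.open("page.png").convert("RGB")  # hypothetical input

rec_predictor = RecognitionPredictor(FoundationPredictor())
det_predictor = DetectionPredictor()

[page] = rec_predictor(
    [image],
    task_names=[TaskNames.ocr_with_boxes],
    det_predictor=det_predictor,
    math_mode=True,
)
for line in page.text_lines:
    print(line.bbox, line.text)
```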

--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------

```python
 1 | import os
 2 | 
 3 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 4 | 
 5 | import pytest
 6 | from PIL import Image, ImageDraw
 7 | 
 8 | from surya.detection import DetectionPredictor
 9 | from surya.ocr_error import OCRErrorPredictor
10 | from surya.layout import LayoutPredictor
11 | from surya.recognition import RecognitionPredictor
12 | from surya.foundation import FoundationPredictor
13 | from surya.table_rec import TableRecPredictor
14 | from surya.settings import settings
15 | 
16 | @pytest.fixture(scope="session")
17 | def ocr_error_predictor() -> OCRErrorPredictor:
18 |     ocr_error_predictor = OCRErrorPredictor()
19 |     yield ocr_error_predictor
20 |     del ocr_error_predictor
21 | 
22 | 
23 | @pytest.fixture(scope="session")
24 | def layout_predictor() -> LayoutPredictor:
25 |     layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))
26 |     yield layout_predictor
27 |     del layout_predictor
28 | 
29 | 
30 | @pytest.fixture(scope="session")
31 | def detection_predictor() -> DetectionPredictor:
32 |     detection_predictor = DetectionPredictor()
33 |     yield detection_predictor
34 |     del detection_predictor
35 | 
36 | 
37 | @pytest.fixture(scope="session")
38 | def recognition_predictor() -> RecognitionPredictor:
39 |     recognition_predictor = RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT))
40 |     yield recognition_predictor
41 |     del recognition_predictor
42 | 
43 | 
44 | @pytest.fixture(scope="session")
45 | def table_rec_predictor() -> TableRecPredictor:
46 |     table_rec_predictor = TableRecPredictor()
47 |     yield table_rec_predictor
48 |     del table_rec_predictor
49 | 
50 | 
51 | @pytest.fixture()
52 | def test_image():
53 |     image = Image.new("RGB", (1024, 1024), "white")
54 |     draw = ImageDraw.Draw(image)
55 |     draw.text((10, 10), "Hello World", fill="black", font_size=72)
56 |     draw.text(
57 |         (10, 200),
58 |         "This is a sentence of text.\nNow it is a paragraph.\nA three-line one.",
59 |         fill="black",
60 |         font_size=24,
61 |     )
62 |     return image
63 | 
64 | 
65 | @pytest.fixture()
66 | def test_image_tall():
67 |     image = Image.new("RGB", (4096, 4096), "white")
68 |     draw = ImageDraw.Draw(image)
69 |     draw.text((10, 10), "Hello World", fill="black", font_size=72)
70 |     draw.text(
71 |         (4000, 4000),
72 |         "This is a sentence of text.\n\nNow it is a paragraph.\n\nA three-line one.",
73 |         fill="black",
74 |         font_size=24,
75 |     )
76 |     return image
77 | 
78 | @pytest.fixture()
79 | def test_image_latex():
80 |     assets_dir = os.path.join(os.path.dirname(__file__), "assets")
81 |     img_path = os.path.join(assets_dir, "test_latex.png")
82 |     image = Image.open(img_path).convert("RGB")
83 |     return image
```

--------------------------------------------------------------------------------
/surya/debug/render_html.py:
--------------------------------------------------------------------------------

```python
 1 | import html as htmllib
 2 | import os.path
 3 | import re
 4 | 
 5 | filepath = os.path.abspath(__file__)
 6 | 
 7 | def render_text_as_html(
 8 |         bboxes: list[list[int]],
 9 |         texts: list[str],
10 |         image_size: tuple[int, int],
11 |         base_font_size: int = 16,
12 |         scaler: int = 2
13 | ):
14 |     katex_path = os.path.join(os.path.dirname(filepath), "katex.js")
15 |     with open(katex_path, "r") as f:
16 |         katex_script = f.read()
17 | 
18 |     html_content = []
19 |     image_size = tuple([int(s * scaler) for s in image_size])
20 |     width, height = image_size
21 | 
22 | 
23 |     html_content.append(f"""
24 | <!DOCTYPE html>
25 | <html>
26 | <head>
27 |     <style>
28 |         body {{
29 |             margin: 0;
30 |             padding: 0;
31 |             width: {width}px;
32 |             height: {height}px;
33 |             position: relative;
34 |             overflow: hidden;
35 |             background: white;
36 |             color: black;
37 |         }}
38 |         .text-box {{
39 |             position: absolute;
40 |             overflow: hidden;
41 |             display: flex;
42 |             justify-content: left;
43 |             font-family: Arial, sans-serif;
44 |             white-space: pre-wrap;
45 |         }}
46 |         .vertical-text {{
47 |           writing-mode: vertical-rl;  /* Top to bottom, right to left */
48 |         }}
49 |     </style>
50 |     {katex_script}
51 | </head>
52 | <body>
53 | """)
54 | 
55 |     for i, (bbox, text) in enumerate(zip(bboxes, texts)):
56 |         bbox = bbox.copy()
57 |         bbox = [int(bb * scaler) for bb in bbox]
58 |         x1, y1, x2, y2 = bbox
59 |         width = x2 - x1
60 |         height = y2 - y1
61 |         min_dim = min(width, height)
62 | 
63 |         # Scale font size based on box height
64 |         font_size = min(int(min_dim * 0.75), base_font_size)
65 | 
66 |         # Create div with absolute positioning
67 |         div_style = (
68 |             f"left: {x1}px; "
69 |             f"top: {y1}px; "
70 |             f"width: {width}px; "
71 |             f"height: {height}px; "
72 |             f"font-size: {font_size}px;"
73 |         )
74 | 
75 |         class_ = "text-box"
76 |         if height > width * 2:
77 |             class_ += " vertical-text"
78 | 
79 |         # Determine if content is HTML/MathML or plain text
80 |         if "<" in text and ">" in text and re.search(r"<(html|math|div|sub|sup|i|u|mark|small|del|b|br|code)\b", text.lower()):
81 |             # Content is already HTML/MathML, include as-is
82 |             html_content.append(f'<span class="{class_}" id="box-{i}" style="{div_style}">{text}</span>')
83 |         else:
84 |             # Plain text, escape it
85 |             escaped_text = htmllib.escape(text)
86 |             html_content.append(f'<span class="{class_}" id="box-{i}" style="{div_style}">{escaped_text}</span>')
87 | 
88 |     html_content.append("</body></html>")
89 | 
90 |     return "\n".join(html_content), image_size
```
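
A small sketch (not part of the repository) showing the returned HTML and scaled size; the boxes and texts are hypothetical.

```python
from surya.debug.render_html import render_text_as_html

bboxes = [[10, 10, 300, 40], [10, 60, 300, 90]]  # hypothetical line boxes (x1, y1, x2, y2)
texts = ["Plain text line", '<math display="inline">x^2 + y^2 = z^2</math>']

html, scaled_size = render_text_as_html(bboxes, texts, image_size=(600, 400))
with open("debug_text.html", "w") as f:
    f.write(html)
print(scaled_size)  # (1200, 800) with the default scaler of 2
```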

--------------------------------------------------------------------------------
/surya/common/predictor.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional
 2 | import torch
 3 | import torch.nn.functional as F
 4 | 
 5 | from surya.common.load import ModelLoader
 6 | from surya.settings import settings
 7 | 
 8 | 
 9 | class BasePredictor:
10 |     model_loader_cls = ModelLoader
11 |     batch_size: Optional[int] = None
12 |     default_batch_sizes = {"cpu": 1, "mps": 1, "cuda": 1}
13 |     torch_dtype = settings.MODEL_DTYPE
14 | 
15 |     @property
16 |     def disable_tqdm(self) -> bool:
17 |         return self._disable_tqdm
18 | 
19 |     @disable_tqdm.setter
20 |     def disable_tqdm(self, value: bool) -> None:
21 |         self._disable_tqdm = bool(value)
22 | 
23 |     def __init__(
24 |         self,
25 |         checkpoint: Optional[str] = None,
26 |         device: torch.device | str | None = settings.TORCH_DEVICE_MODEL,
27 |         dtype: Optional[torch.dtype | str] = None,
28 |         attention_implementation: Optional[str] = None,
29 |     ):
30 |         if dtype is None:
31 |             dtype = self.torch_dtype
32 | 
33 |         self.model = None
34 |         self.processor = None
35 |         loader = self.model_loader_cls(checkpoint)
36 | 
37 |         self.model = loader.model(device, dtype, attention_implementation)
38 |         self.processor = loader.processor()
39 | 
40 |         self._disable_tqdm = settings.DISABLE_TQDM
41 | 
42 |     def to(self, device_dtype: torch.device | str | None = None):
43 |         model_moved = False
44 |         if hasattr(self, "model") and self.model:
45 |             self.model.to(device_dtype)
46 |             model_moved = True
47 |         if hasattr(self, "foundation_predictor") and self.foundation_predictor:
48 |             self.foundation_predictor.model.to(device_dtype)
49 |             model_moved = True
50 | 
51 |         if not model_moved:
52 |             raise ValueError("Model not loaded")
53 | 
54 |     def get_batch_size(self):
55 |         batch_size = self.batch_size
56 |         if batch_size is None:
57 |             batch_size = self.default_batch_sizes["cpu"]
58 |             if settings.TORCH_DEVICE_MODEL in self.default_batch_sizes:
59 |                 batch_size = self.default_batch_sizes[settings.TORCH_DEVICE_MODEL]
60 |         return batch_size
61 | 
62 |     @staticmethod
63 |     def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
64 |         current_batch_size = tensor.shape[0]
65 |         if current_batch_size >= batch_size:
66 |             return tensor
67 | 
68 |         if len(tensor.shape) == 1:
69 |             # If tensor is 1D, we need to pad it to the batch size
70 |             pad_size = batch_size - current_batch_size
71 |             return F.pad(tensor, (0, pad_size), mode="constant", value=0)
72 | 
73 |         pad_size = batch_size - current_batch_size
74 |         padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
75 | 
76 |         return F.pad(tensor, padding, mode="constant", value=0)
77 | 
78 |     def __call__(self, *args, **kwargs):
79 |         raise NotImplementedError()
80 | 
```
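
`pad_to_batch_size` right-pads the batch dimension with zeros so tensor shapes stay static (useful for compiled or XLA execution); a quick sketch:

```python
import torch

from surya.common.predictor import BasePredictor

x = torch.ones(3, 7)
padded = BasePredictor.pad_to_batch_size(x, batch_size=8)
print(padded.shape)      # torch.Size([8, 7])
print(padded[3:].sum())  # tensor(0.) -- the padded rows are zeros
```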

--------------------------------------------------------------------------------
/surya/scripts/config.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import List
 2 | 
 3 | import click
 4 | import os
 5 | from surya.input.load import load_from_folder, load_from_file
 6 | from surya.settings import settings
 7 | 
 8 | 
 9 | class CLILoader:
10 |     def __init__(self, filepath: str, cli_options: dict, highres: bool = False):
11 |         self.page_range = cli_options.get("page_range")
12 |         if self.page_range:
13 |             self.page_range = self.parse_range_str(self.page_range)
14 |         self.filepath = filepath
15 |         self.config = cli_options
16 |         self.save_images = cli_options.get("images", False)
17 |         self.debug = cli_options.get("debug", False)
18 |         self.output_dir = cli_options.get("output_dir")
19 | 
20 |         self.load(highres)
21 | 
22 |     @staticmethod
23 |     def common_options(fn):
24 |         fn = click.argument("input_path", type=click.Path(exists=True), required=True)(fn)
25 |         fn = click.option("--output_dir", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, "surya"), help="Directory to save output.")(fn)
26 |         fn = click.option("--page_range", type=str, default=None, help="Page range to convert, specify comma separated page numbers or ranges.  Example: 0,5-10,20")(fn)
27 |         fn = click.option("--images", is_flag=True, help="Save images of detected bboxes.", default=False)(fn)
28 |         fn = click.option('--debug', '-d', is_flag=True, help='Enable debug mode.', default=False)(fn)
29 |         return fn
30 | 
31 |     def load(self, highres: bool = False):
32 |         highres_images = None
33 |         if os.path.isdir(self.filepath):
34 |             images, names = load_from_folder(self.filepath, self.page_range)
35 |             folder_name = os.path.basename(self.filepath)
36 |             if highres:
37 |                 highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES)
38 |         else:
39 |             images, names = load_from_file(self.filepath, self.page_range)
40 |             folder_name = os.path.basename(self.filepath).split(".")[0]
41 |             if highres:
42 |                 highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES)
43 | 
44 | 
45 |         self.images = images
46 |         self.highres_images = highres_images
47 |         self.names = names
48 | 
49 |         self.result_path = os.path.abspath(os.path.join(self.output_dir, folder_name))
50 |         os.makedirs(self.result_path, exist_ok=True)
51 | 
52 |     @staticmethod
53 |     def parse_range_str(range_str: str) -> List[int]:
54 |         range_lst = range_str.split(",")
55 |         page_lst = []
56 |         for i in range_lst:
57 |             if "-" in i:
58 |                 start, end = i.split("-")
59 |                 page_lst += list(range(int(start), int(end) + 1))
60 |             else:
61 |                 page_lst.append(int(i))
62 |         page_lst = sorted(list(set(page_lst)))  # Deduplicate page numbers and sort in order
63 |         return page_lst
```
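
`parse_range_str` expands the `--page_range` syntax described in the option help; for example:

```python
from surya.scripts.config import CLILoader

print(CLILoader.parse_range_str("0,5-10,20"))
# [0, 5, 6, 7, 8, 9, 10, 20]
```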

--------------------------------------------------------------------------------
/tests/test_recognition.py:
--------------------------------------------------------------------------------

```python
 1 | import time
 2 | from PIL import ImageDraw, Image
 3 | from surya.recognition.util import clean_math_tags
 4 | 
 5 | 
 6 | def test_recognition(recognition_predictor, detection_predictor, test_image):
 7 |     recognition_results = recognition_predictor([test_image], None, detection_predictor)
 8 | 
 9 |     assert len(recognition_results) == 1
10 |     assert recognition_results[0].image_bbox == [0, 0, 1024, 1024]
11 | 
12 |     text_lines = recognition_results[0].text_lines
13 |     assert len(text_lines) == 4
14 |     assert "Hello World" in text_lines[0].text
15 | 
16 | 
17 | def test_recognition_input_text(recognition_predictor, detection_predictor, test_image):
18 |     start = time.time()
19 |     recognition_predictor([test_image], None, detection_predictor)
20 |     end = time.time() - start
21 | 
22 |     input_text = "a" * 400
23 |     start2 = time.time()
24 |     recognition_results = recognition_predictor(
25 |         [test_image], None, detection_predictor, input_text=[input_text]
26 |     )
27 |     end2 = time.time() - start2
28 | 
29 |     assert max([end, end2]) / min([end, end2]) < 1.5, (
30 |         "Input text should be truncated and not change inference time"
31 |     )
32 | 
33 |     assert len(recognition_results) == 1
34 |     assert recognition_results[0].image_bbox == [0, 0, 1024, 1024]
35 | 
36 |     text_lines = recognition_results[0].text_lines
37 |     assert len(text_lines) == 4
38 |     assert "Hello World" in text_lines[0].text
39 | 
40 | 
41 | def test_recognition_drop_repeats(recognition_predictor, detection_predictor):
42 |     image = Image.new("RGB", (1024, 128), "white")
43 |     draw = ImageDraw.Draw(image)
44 |     text = "a" * 80
45 |     draw.text((5, 5), text, fill="black", font_size=24)
46 | 
47 |     recognition_results = recognition_predictor(
48 |         [image], None, bboxes=[[[0, 0, 1024, 128]]], drop_repeated_text=True
49 |     )
50 |     assert len(recognition_results) == 1
51 |     result = recognition_results[0].text_lines
52 |     assert result[0].text == ""
53 | 
54 | 
55 | def test_recognition_clean_math():
56 |     math = """<math display="block">na_n^{1+2r} \\text{cov}(\\hat{f}_n^{(r)}(x), \\hat{f}_n^{(r)}(y)) = \\frac{1}{n} \\sum_{j=1}^n \\frac{a_n^{1+2r}}{a_j^{1+2r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_j}{a_j}\\right)\\right) <br>+ \\frac{a_n^{1+2r}}{n} \\sum_{\\substack{j \\neq k \\\\ 1 \\le j, k \\le n}} \\frac{1}{(a_j a_k)^{1+r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_k}{a_k}\\right)\\right) <br>=: I_1 + I_2.</math> (1.7)</math>'"""
57 |     clean_math = clean_math_tags(math)
58 | 
59 |     assert clean_math.count("</math>") == 1, "Should have one closing math tag"
60 |     assert "<br>" not in clean_math, "Should not have <br> tags in cleaned math"
61 | 
62 | 
63 | def test_recognition_clean_math_preserve_text():
64 |     text = """Hello, this is a sentence with <math display="inline">x^2 + y^2 = z^2</math> and some text after it, with a weird tag <hello> and <goodbye>."""
65 |     clean_text = clean_math_tags(text)
66 | 
67 |     assert clean_text == text
68 | 
```

--------------------------------------------------------------------------------
/surya/input/processing.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import List
  2 | 
  3 | import cv2
  4 | import numpy as np
  5 | import pypdfium2
  6 | from PIL import Image
  7 | 
  8 | from surya.logging import get_logger
  9 | from surya.settings import settings
 10 | 
 11 | logger = get_logger()
 12 | 
 13 | 
 14 | def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
 15 |     new_images = []
 16 |     for image in images:
 17 |         if image.mode != "RGB":
 18 |             image = image.convert("RGB")
 19 |         new_images.append(image)
 20 |     return new_images
 21 | 
 22 | 
 23 | def open_pdf(pdf_filepath):
 24 |     return pypdfium2.PdfDocument(pdf_filepath)
 25 | 
 26 | 
 27 | def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):
 28 |     images = [
 29 |         doc[i].render(scale=dpi / 72, draw_annots=False).to_pil() for i in indices
 30 |     ]
 31 |     images = [image.convert("RGB") for image in images]
 32 |     return images
 33 | 
 34 | 
 35 | def slice_bboxes_from_image(image: np.ndarray, bboxes):
 36 |     lines = []
 37 |     for bbox in bboxes:
 38 |         bbox = np.array(bbox, dtype=np.int32)
 39 |         bbox = np.clip(bbox, 0, None)  # Ensure no negative indices
 40 |         # Ensure bbox is within the image bounds
 41 |         if bbox[3] <= bbox[1]:
 42 |             bbox[3] = bbox[1] + 1
 43 | 
 44 |         if bbox[2] <= bbox[0]:
 45 |             bbox[2] = bbox[0] + 1
 46 | 
 47 |         bbox[2] = min(bbox[2], image.shape[1])
 48 |         bbox[3] = min(bbox[3], image.shape[0])
 49 | 
 50 |         line = image[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()
 51 |         if line.size == 0:
 52 |             logger.warning(f"Warning: found an empty line with bbox {bbox}")
 53 |         lines.append(line)
 54 |     return lines
 55 | 
 56 | 
 57 | def slice_polys_from_image(image: np.ndarray, polys):
 58 |     lines = []
 59 |     for idx, poly in enumerate(polys):
 60 |         lines.append(slice_and_pad_poly(image, poly))
 61 |     return lines
 62 | 
 63 | 
 64 | def slice_and_pad_poly(image_array: np.array, coordinates):
 65 |     # Draw polygon onto mask
 66 |     coordinates = [(corner[0], corner[1]) for corner in coordinates]
 67 |     bbox = [
 68 |         min([x[0] for x in coordinates]),
 69 |         min([x[1] for x in coordinates]),
 70 |         max([x[0] for x in coordinates]),
 71 |         max([x[1] for x in coordinates]),
 72 |     ]
 73 | 
 74 |     # We mask out anything not in the polygon
 75 |     cropped_polygon = image_array[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()
 76 |     height, width = cropped_polygon.shape[:2]
 77 | 
 78 |     coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]
 79 | 
 80 |     # Validate the cropped area
 81 |     if any(
 82 |         [
 83 |             bbox[3] <= bbox[1] or bbox[2] <= bbox[0],
 84 |             len(coordinates) < 3,
 85 |             height == 0,
 86 |             width == 0,
 87 |         ]
 88 |     ):
 89 |         return cropped_polygon
 90 | 
 91 |     # Pad the area outside the polygon with the pad value
 92 |     try:
 93 |         mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8)
 94 |         cv2.fillPoly(mask, [np.int32(coordinates)], 1)
 95 |         mask = np.stack([mask] * 3, axis=-1)
 96 | 
 97 |         cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
 98 |     except cv2.error as e:
 99 |         logger.warning(f"Warning: issue while processing polygon: {e}")
100 | 
101 |     return cropped_polygon
102 | 
```
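
A small sketch (not part of the repository) of the bbox slicing helper; the page array and boxes are hypothetical.

```python
import numpy as np

from surya.input.processing import slice_bboxes_from_image

page = np.zeros((200, 400, 3), dtype=np.uint8)  # hypothetical HxWx3 page
crops = slice_bboxes_from_image(page, [[10, 20, 110, 60], [390, 150, 500, 300]])

# Out-of-range boxes are clipped to the image; degenerate ones get a 1px extent.
print([crop.shape for crop in crops])  # [(40, 100, 3), (50, 10, 3)]
```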

--------------------------------------------------------------------------------
/surya/table_rec/loader.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional
 2 | 
 3 | import torch
 4 | 
 5 | from surya.common.load import ModelLoader
 6 | from surya.logging import get_logger
 7 | from surya.settings import settings
 8 | from surya.table_rec.model.config import (
 9 |     SuryaTableRecConfig,
10 |     SuryaTableRecDecoderConfig,
11 |     DonutSwinTableRecConfig,
12 | )
13 | from surya.table_rec.model.encoderdecoder import TableRecEncoderDecoderModel
14 | from surya.table_rec.processor import SuryaTableRecProcessor
15 | 
16 | logger = get_logger()
17 | 
18 | 
19 | class TableRecModelLoader(ModelLoader):
20 |     def __init__(self, checkpoint: Optional[str] = None):
21 |         super().__init__(checkpoint)
22 | 
23 |         if self.checkpoint is None:
24 |             self.checkpoint = settings.TABLE_REC_MODEL_CHECKPOINT
25 | 
26 |     def model(
27 |         self,
28 |         device=settings.TORCH_DEVICE_MODEL,
29 |         dtype=settings.MODEL_DTYPE,
30 |         attention_implementation: Optional[str] = None,
31 |     ) -> TableRecEncoderDecoderModel:
32 |         if device is None:
33 |             device = settings.TORCH_DEVICE_MODEL
34 |         if dtype is None:
35 |             dtype = settings.MODEL_DTYPE
36 | 
37 |         if device == "mps":
38 |             logger.warning(
39 |                 "`TableRecEncoderDecoderModel` is not compatible with mps backend. Defaulting to cpu instead"
40 |             )
41 |             device = "cpu"
42 |             dtype = "float32"
43 | 
44 |         config = SuryaTableRecConfig.from_pretrained(self.checkpoint)
45 |         decoder_config = config.decoder
46 |         decoder = SuryaTableRecDecoderConfig(**decoder_config)
47 |         config.decoder = decoder
48 | 
49 |         encoder_config = config.encoder
50 |         encoder = DonutSwinTableRecConfig(**encoder_config)
51 |         config.encoder = encoder
52 | 
53 |         model = TableRecEncoderDecoderModel.from_pretrained(
54 |             self.checkpoint, config=config, dtype=dtype
55 |         )
56 | 
57 |         model = model.to(device)
58 |         model = model.eval()
59 | 
60 |         if settings.COMPILE_ALL or settings.COMPILE_TABLE_REC:
61 |             torch.set_float32_matmul_precision("high")
62 |             torch._dynamo.config.cache_size_limit = 16
63 |             torch._dynamo.config.suppress_errors = False
64 | 
65 |             logger.info(
66 |                 f"Compiling table recognition model {self.checkpoint} on device {device} with dtype {dtype}"
67 |             )
68 |             compile_args = {"backend": "openxla"} if device == "xla" else {}
69 |             model.encoder = torch.compile(model.encoder, **compile_args)
70 |             model.decoder = torch.compile(model.decoder, **compile_args)
71 | 
72 |         logger.debug(
73 |             f"Loaded table recognition model {self.checkpoint} from {TableRecEncoderDecoderModel.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}"
74 |         )
75 |         return model
76 | 
77 |     def processor(
78 |         self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE
79 |     ) -> SuryaTableRecProcessor:
80 |         processor = SuryaTableRecProcessor(self.checkpoint)
81 | 
82 |         processor.token_pad_id = 0
83 |         processor.token_eos_id = 1
84 |         processor.token_bos_id = 1
85 |         processor.token_query_end_id = 4
86 |         return processor
87 | 
```

--------------------------------------------------------------------------------
/surya/common/surya/decoder/config.py:
--------------------------------------------------------------------------------

```python
 1 | from transformers.configuration_utils import PretrainedConfig
 2 | from transformers.modeling_rope_utils import rope_config_validation
 3 | from transformers.utils import logging
 4 | 
 5 | logger = logging.get_logger(__name__)
 6 | 
 7 | 
 8 | class SuryaDecoderConfig(PretrainedConfig):
 9 |     model_type = "qwen2"
10 |     keys_to_ignore_at_inference = ["past_key_values"]
11 | 
12 |     # Default tensor parallel plan for base model `Qwen2`
13 |     base_model_tp_plan = {
14 |         "layers.*.self_attn.q_proj": "colwise",
15 |         "layers.*.self_attn.k_proj": "colwise",
16 |         "layers.*.self_attn.v_proj": "colwise",
17 |         "layers.*.self_attn.o_proj": "rowwise",
18 |         "layers.*.mlp.gate_proj": "colwise",
19 |         "layers.*.mlp.up_proj": "colwise",
20 |         "layers.*.mlp.down_proj": "rowwise",
21 |     }
22 |     base_model_pp_plan = {
23 |         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
24 |         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
25 |         "norm": (["hidden_states"], ["hidden_states"]),
26 |     }
27 | 
28 |     def __init__(
29 |         self,
30 |         vocab_size=151936,
31 |         hidden_size=4096,
32 |         intermediate_size=22016,
33 |         num_hidden_layers=32,
34 |         num_attention_heads=32,
35 |         num_key_value_heads=32,
36 |         hidden_act="silu",
37 |         max_position_embeddings=32768,
38 |         initializer_range=0.02,
39 |         rms_norm_eps=1e-6,
40 |         use_cache=True,
41 |         tie_word_embeddings=False,
42 |         rope_theta=10000.0,
43 |         rope_scaling=None,
44 |         use_sliding_window=False,
45 |         sliding_window=4096,
46 |         max_window_layers=28,
47 |         attention_dropout=0.0,
48 |         **kwargs,
49 |     ):
50 |         self.vocab_size = vocab_size
51 |         self.max_position_embeddings = max_position_embeddings
52 |         self.hidden_size = hidden_size
53 |         self.intermediate_size = intermediate_size
54 |         self.num_hidden_layers = num_hidden_layers
55 |         self.num_attention_heads = num_attention_heads
56 |         self.use_sliding_window = False  # Disable sliding window
57 |         self.sliding_window = (
58 |             sliding_window  # we check `use_sliding_window` in the modeling code
59 |         )
60 |         self.max_window_layers = max_window_layers
61 | 
62 |         # for backward compatibility
63 |         if num_key_value_heads is None:
64 |             num_key_value_heads = num_attention_heads
65 | 
66 |         self.num_key_value_heads = num_key_value_heads
67 |         self.hidden_act = hidden_act
68 |         self.initializer_range = initializer_range
69 |         self.rms_norm_eps = rms_norm_eps
70 |         self.use_cache = use_cache
71 |         self.rope_theta = rope_theta
72 |         self.rope_scaling = rope_scaling
73 |         self.attention_dropout = attention_dropout
74 |         # Validate the correctness of rotary position embeddings parameters
75 |         # BC: if there is a 'type' field, move it to 'rope_type'.
76 |         if self.rope_scaling is not None and "type" in self.rope_scaling:
77 |             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
78 |         rope_config_validation(self)
79 | 
80 |         super().__init__(
81 |             tie_word_embeddings=tie_word_embeddings,
82 |             **kwargs,
83 |         )
84 | 
```

--------------------------------------------------------------------------------
/benchmark/ordering.py:
--------------------------------------------------------------------------------

```python
 1 | import collections
 2 | import json
 3 | 
 4 | import click
 5 | 
 6 | from surya.foundation import FoundationPredictor
 7 | from surya.input.processing import convert_if_not_rgb
 8 | from surya.layout import LayoutPredictor
 9 | from surya.common.polygon import PolygonBox
10 | from surya.settings import settings
11 | from benchmark.utils.metrics import rank_accuracy
12 | import os
13 | import time
14 | import datasets
15 | 
16 | 
17 | @click.command(help="Benchmark surya layout for reading order.")
18 | @click.option(
19 |     "--results_dir",
20 |     type=str,
21 |     help="Path to JSON file with benchmark results.",
22 |     default=os.path.join(settings.RESULT_DIR, "benchmark"),
23 | )
24 | @click.option(
25 |     "--max_rows",
26 |     type=int,
27 |     help="Maximum number of images to run benchmark on.",
28 |     default=None,
29 | )
30 | def main(results_dir: str, max_rows: int):
31 |     foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
32 |     layout_predictor = LayoutPredictor(foundation_predictor)
33 |     pathname = "order_bench"
34 |     # These have already been shuffled randomly, so sampling from the start is fine
35 |     split = "train"
36 |     if max_rows is not None:
37 |         split = f"train[:{max_rows}]"
38 |     dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
39 |     images = list(dataset["image"])
40 |     images = convert_if_not_rgb(images)
41 | 
42 |     start = time.time()
43 |     layout_predictions = layout_predictor(images)
44 |     surya_time = time.time() - start
45 | 
46 |     folder_name = os.path.basename(pathname).split(".")[0]
47 |     result_path = os.path.join(results_dir, folder_name)
48 |     os.makedirs(result_path, exist_ok=True)
49 | 
50 |     page_metrics = collections.OrderedDict()
51 |     mean_accuracy = 0
52 |     for idx, order_pred in enumerate(layout_predictions):
53 |         row = dataset[idx]
54 |         labels = row["labels"]
55 |         bboxes = row["bboxes"]
56 |         pred_positions = []
57 |         for label, bbox in zip(labels, bboxes):
58 |             max_intersection = 0
59 |             matching_idx = 0
60 |             for pred_box in order_pred.bboxes:
61 |                 intersection = pred_box.intersection_pct(PolygonBox(polygon=bbox))
62 |                 if intersection > max_intersection:
63 |                     max_intersection = intersection
64 |                     matching_idx = pred_box.position
65 |             pred_positions.append(matching_idx)
66 |         accuracy = rank_accuracy(pred_positions, labels)
67 |         mean_accuracy += accuracy
68 |         page_results = {"accuracy": accuracy, "box_count": len(labels)}
69 | 
70 |         page_metrics[idx] = page_results
71 | 
72 |     mean_accuracy /= len(layout_predictions)
73 | 
74 |     out_data = {
75 |         "time": surya_time,
76 |         "mean_accuracy": mean_accuracy,
77 |         "page_metrics": page_metrics,
78 |     }
79 | 
80 |     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
81 |         json.dump(out_data, f, indent=4)
82 | 
83 |     print(f"Mean accuracy is {mean_accuracy:.2f}.")
84 |     print(
85 |         f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total."
86 |     )
87 |     print("Mean accuracy is the % of correct ranking pairs.")
88 |     print(f"Wrote results to {result_path}")
89 | 
90 | 
91 | if __name__ == "__main__":
92 |     main()
93 | 
```
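
The accuracy printed at the end is pairwise ranking accuracy: the fraction of box pairs whose predicted reading order agrees with the labeled order. A minimal sketch of that metric, standing in for `benchmark.utils.metrics.rank_accuracy` (whose actual implementation is not shown on this page):

```python
from itertools import combinations


def pairwise_rank_accuracy(pred_positions, labels):
    # Fraction of (i, j) pairs where the predicted ordering matches the labeled ordering.
    pairs = list(combinations(range(len(labels)), 2))
    if not pairs:
        return 1.0
    correct = sum(
        (pred_positions[i] < pred_positions[j]) == (labels[i] < labels[j])
        for i, j in pairs
    )
    return correct / len(pairs)


print(pairwise_rank_accuracy([0, 1, 2, 3], [0, 1, 2, 3]))  # 1.0 -- perfect order
print(pairwise_rank_accuracy([0, 2, 1, 3], [0, 1, 2, 3]))  # ~0.83 -- one swapped pair
```

The benchmark itself is run as a script, e.g. something like `python benchmark/ordering.py --max_rows 16` from the repo root, which writes `results.json` under the results directory.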

--------------------------------------------------------------------------------
/surya/debug/text.py:
--------------------------------------------------------------------------------

```python
  1 | import re
  2 | from io import BytesIO
  3 | from typing import List, Tuple
  4 | from PIL import Image, ImageDraw, ImageFont
  5 | 
  6 | from surya.debug.fonts import get_font_path
  7 | from surya.debug.render_html import render_text_as_html
  8 | 
  9 | try:
 10 |     from playwright.sync_api import sync_playwright
 11 | 
 12 |     has_playwright = True
 13 | except ImportError:
 14 |     has_playwright = False
 15 | 
 16 | 
 17 | def strip_html_tags(html_text):
 18 |     pattern = re.compile(r"<[\w/][^>]*>")
 19 |     text_only = pattern.sub("", html_text)
 20 | 
 21 |     return text_only
 22 | 
 23 | 
 24 | def get_text_size(text, font):
 25 |     im = Image.new(mode="P", size=(0, 0))
 26 |     draw = ImageDraw.Draw(im)
 27 |     _, _, width, height = draw.textbbox((0, 0), text=text, font=font)
 28 |     return width, height
 29 | 
 30 | 
 31 | def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size):
 32 |     font = ImageFont.truetype(font_path, box_font_size)
 33 |     text_width, text_height = get_text_size(text, font)
 34 |     while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:
 35 |         box_font_size = box_font_size - 1
 36 |         font = ImageFont.truetype(font_path, box_font_size)
 37 |         text_width, text_height = get_text_size(text, font)
 38 | 
 39 |     # Calculate text position (centered in bbox)
 40 |     text_width, text_height = get_text_size(text, font)
 41 |     x = s_bbox[0]
 42 |     y = s_bbox[1] + (bbox_height - text_height) / 2
 43 | 
 44 |     draw.text((x, y), text, fill="black", font=font)
 45 | 
 46 | 
 47 | def draw_text_with_playwright(
 48 |     bboxes, texts: List[str], image_size: Tuple[int, int]
 49 | ) -> Image.Image:
 50 |     html_content, image_size = render_text_as_html(bboxes, texts, image_size)
 51 |     if not has_playwright:
 52 |         raise ImportError(
 53 |             "Playwright is not installed. Please install it using `pip install playwright`"
 54 |         )
 55 | 
 56 |     with sync_playwright() as p:
 57 |         browser = p.chromium.launch(headless=True)
 58 |         page = browser.new_page(
 59 |             viewport={"width": image_size[0], "height": image_size[1]}
 60 |         )
 61 |         page.set_content(html_content)
 62 |         page.wait_for_timeout(1000)
 63 |         body = page.query_selector("body")
 64 |         image = body.screenshot()
 65 |         browser.close()
 66 | 
 67 |     pil_img = Image.open(BytesIO(image))
 68 |     return pil_img
 69 | 
 70 | 
 71 | def draw_text_on_image(
 72 |     bboxes,
 73 |     texts,
 74 |     image_size: Tuple[int, int],
 75 |     font_path=None,
 76 |     max_font_size=60,
 77 |     res_upscale=2,
 78 | ) -> Image.Image:
 79 |     if has_playwright:
 80 |         return draw_text_with_playwright(bboxes, texts, image_size)
 81 | 
 82 |     texts = [strip_html_tags(text) for text in texts]
 83 |     if font_path is None:
 84 |         font_path = get_font_path()
 85 |     new_image_size = (image_size[0] * res_upscale, image_size[1] * res_upscale)
 86 |     image = Image.new("RGB", new_image_size, color="white")
 87 |     draw = ImageDraw.Draw(image)
 88 | 
 89 |     for bbox, text in zip(bboxes, texts):
 90 |         s_bbox = [int(coord * res_upscale) for coord in bbox]
 91 |         bbox_width = s_bbox[2] - s_bbox[0]
 92 |         bbox_height = s_bbox[3] - s_bbox[1]
 93 | 
 94 |         # Shrink the text to fit in the bbox if needed
 95 |         box_font_size = max(6, min(int(0.75 * bbox_height), max_font_size))
 96 |         render_text(
 97 |             draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size
 98 |         )
 99 | 
100 |     return image
101 | 
```
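
Both rendering paths share the same entry point, so the module can be exercised with a couple of hand-made pixel boxes (coordinates below are made up for illustration):

```python
from surya.debug.text import draw_text_on_image

bboxes = [[20, 20, 380, 60], [20, 80, 380, 160]]  # [x1, y1, x2, y2] in page pixels
texts = ["Header line", "Body text that should shrink to fit its box"]

rendered = draw_text_on_image(bboxes, texts, image_size=(400, 200))
rendered.save("debug_text.png")
# Uses the Playwright/HTML path when playwright is installed;
# otherwise falls back to PIL rendering with a bundled font.
```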

--------------------------------------------------------------------------------
/surya/recognition/postprocessing.py:
--------------------------------------------------------------------------------

```python
  1 | import re
  2 | from typing import List, Dict
  3 | 
  4 | from surya.recognition.schema import TextChar
  5 | 
  6 | 
  7 | def truncate_repetitions(text: str, min_len=15):
  8 |     # From nougat, with some cleanup
  9 |     if len(text) < 2 * min_len:
 10 |         return text
 11 | 
 12 |     # try to find a length at which the tail is repeating
 13 |     max_rep_len = None
 14 |     for rep_len in range(min_len, int(len(text) / 2)):
 15 |         # check if there is a repetition at the end
 16 |         same = True
 17 |         for i in range(0, rep_len):
 18 |             if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]:
 19 |                 same = False
 20 |                 break
 21 | 
 22 |         if same:
 23 |             max_rep_len = rep_len
 24 | 
 25 |     if max_rep_len is None:
 26 |         return text
 27 | 
 28 |     lcs = text[-max_rep_len:]
 29 | 
 30 |     # remove all but the last repetition
 31 |     text_to_truncate = text
 32 |     while text_to_truncate.endswith(lcs):
 33 |         text_to_truncate = text_to_truncate[:-max_rep_len]
 34 | 
 35 |     return text[: len(text_to_truncate)]
 36 | 
 37 | 
 38 | def extract_tags(proposed_tags: List[str]) -> List[str]:
 39 |     tags = []
 40 |     for tag in proposed_tags:
 41 |         tag_match = re.match(tag_pattern, tag)
 42 |         if not tag_match:
 43 |             continue
 44 | 
 45 |         if tag_match.group(1) != "/":
 46 |             continue
 47 | 
 48 |         tags.append(tag_match.group(2))
 49 |     return tags
 50 | 
 51 | 
 52 | tag_pattern = re.compile(r"<(/?)([a-z]+)([^>]*)>?", re.IGNORECASE)
 53 | 
 54 | 
 55 | def cleanup_math(line: str):
 56 |     matches = re.finditer(r"(<math[^>]*>)(.*?)</math>", line, re.DOTALL)
 57 |     result = line
 58 | 
 59 |     for match in matches:
 60 |         opening_tag = match.group(1)  # The opening <math> tag with attributes
 61 |         full_match = match.group(0)  # The entire <math>content</math> tag
 62 |         block_content = match.group(2)  # Just the content inside the tags
 63 | 
 64 |         clean_block = re.sub(r"<[^>]+>", "", block_content)
 65 | 
 66 |         if not re.search(r"[\\\_]", clean_block):
 67 |             result = result.replace(full_match, clean_block)
 68 |         else:
 69 |             result = result.replace(full_match, f"{opening_tag}{clean_block}</math>")
 70 | 
 71 |     return result
 72 | 
 73 | 
 74 | def fix_unbalanced_tags(
 75 |     text_chars: List[TextChar], special_tokens: Dict[str, list]
 76 | ) -> List[TextChar]:
 77 |     self_closing_tags = ["br"]
 78 | 
 79 |     open_tags = []
 80 | 
 81 |     format_tags = extract_tags(special_tokens["formatting"]) + extract_tags(
 82 |         special_tokens["math_external"]
 83 |     )
 84 | 
 85 |     for char in text_chars:
 86 |         if len(char.text) <= 1:
 87 |             continue
 88 | 
 89 |         tag_match = re.match(tag_pattern, char.text)
 90 |         if not tag_match:
 91 |             continue
 92 | 
 93 |         is_closing = tag_match.group(1) == "/"
 94 |         tag_name = tag_match.group(2).lower()
 95 | 
 96 |         if tag_name not in format_tags:
 97 |             continue
 98 | 
 99 |         if tag_name in self_closing_tags:
100 |             continue
101 | 
102 |         # XML-style self-closing syntax, e.g. <tag/>
103 |         if tag_match.group(3) and tag_match.group(3).strip().endswith("/"):
104 |             continue
105 | 
106 |         if is_closing:
107 |             if open_tags and open_tags[-1] == tag_name:
108 |                 open_tags.pop()
109 |         else:
110 |             open_tags.append(tag_name)
111 | 
112 |     for tag in open_tags:
113 |         text_chars.append(
114 |             TextChar(
115 |                 text=f"</{tag}>",
116 |                 confidence=0,
117 |                 polygon=[[0, 0], [1, 0], [1, 1], [0, 1]],
118 |                 bbox_valid=False,
119 |             )
120 |         )
121 |     return text_chars
122 | 
```
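
Two of these helpers are easiest to understand from their effect on small strings: `truncate_repetitions` collapses a degenerate repeating tail down to a single copy, and `cleanup_math` unwraps `<math>` blocks that contain no actual LaTeX. Illustrative inputs only:

```python
from surya.recognition.postprocessing import truncate_repetitions, cleanup_math

looped = "the same words " * 5  # a model stuck in a repetition loop
print(repr(truncate_repetitions(looped)))  # only one copy of the repeated tail survives

print(cleanup_math("Area: <math>x + 1</math>"))
# "Area: x + 1" -- no \ or _ inside, so the <math> wrapper is dropped

print(cleanup_math(r"Ratio: <math display='block'>\frac{a}{b}</math>"))
# real LaTeX (contains a backslash), so the <math> wrapper is kept
```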

--------------------------------------------------------------------------------
/surya/common/surya/config.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional
 2 | from transformers import PretrainedConfig
 3 | 
 4 | from surya.common.s3 import S3DownloaderMixin
 5 | from surya.common.surya.encoder.config import SuryaEncoderConfig
 6 | from surya.common.surya.decoder.config import SuryaDecoderConfig
 7 | 
 8 | 
 9 | class SuryaModelConfig(S3DownloaderMixin, PretrainedConfig):
10 |     model_type = "surya-multimodal-foundation"
11 |     is_composition = True
12 | 
13 |     def __init__(
14 |         self,
15 |         vocab_size=65536,
16 |         bbox_size=1025,
17 |         blank_bbox_token_id=1025,
18 |         bos_token_id=0,
19 |         eos_token_id=1,
20 |         pad_token_id=2,
21 |         image_token_id=3,
22 |         register_token_ids=(4, 5, 6, 7),
23 |         eoi_token_id=8,
24 |         beacon_token_id=9,
25 |         special_token_count=4,
26 |         max_sequence_length=1536,
27 |         special_ocr_tokens=None,
28 |         vision_encoder=None,
29 |         decoder=None,
30 |         tasks: dict | None = None,
31 |         bbox_embed_size: int = 64,
32 |         num_register_tokens: int = 4,
33 |         image_embed_encoding_size: int = 1024,
34 |         image_embed_encoding_multiplier: int = 256,
35 |         num_beacon_tokens: int = 1,
36 |         beacon_token_interval: int = 4096,
37 |         sliding_window: Optional[int] = None,
38 |         multi_output_distance: int = 4,
39 |         max_multi_out: int = 8,
40 |         **kwargs,
41 |     ):
42 |         super().__init__(**kwargs)
43 |         self.is_encoder_decoder = False
44 |         self.vocab_size = vocab_size
45 |         self.bbox_size = bbox_size
46 |         self.blank_bbox_token_id = blank_bbox_token_id
47 |         self.image_token_id = image_token_id
48 |         self.bos_token_id = bos_token_id
49 |         self.eos_token_id = eos_token_id
50 |         self.pad_token_id = pad_token_id
51 |         self.eoi_token_id = eoi_token_id
52 |         self.beacon_token_id = beacon_token_id
53 |         self.special_ocr_tokens = special_ocr_tokens
54 |         self.special_token_count = special_token_count  # pad, bos, etc. tokens
55 |         self.max_sequence_length = max_sequence_length
56 |         self.tasks = tasks
57 |         self.tie_word_embeddings = True
58 |         self.bbox_embed_size = bbox_embed_size
59 |         self.num_register_tokens = num_register_tokens
60 |         self.register_token_ids = register_token_ids
61 |         self.image_embed_encoding_size = image_embed_encoding_size
62 |         self.image_embed_encoding_multiplier = image_embed_encoding_multiplier
63 |         self.num_beacon_tokens = num_beacon_tokens
64 |         self.beacon_token_interval = beacon_token_interval
65 |         self.sliding_window = sliding_window
66 |         self.multi_output_distance = multi_output_distance
67 |         self.max_multi_out = max_multi_out
68 | 
69 |         if self.sliding_window is None:
70 |             self.sliding_window = self.max_sequence_length
71 | 
72 |         if isinstance(vision_encoder, dict):
73 |             vision_encoder = SuryaEncoderConfig(**vision_encoder)
74 |         elif vision_encoder is None:
75 |             vision_encoder = SuryaEncoderConfig()
76 |         self.vision_encoder = vision_encoder
77 | 
78 |         if isinstance(decoder, dict):
79 |             decoder = SuryaDecoderConfig(**decoder)
80 |         elif decoder is None:
81 |             decoder = SuryaDecoderConfig()
82 |         self.decoder = decoder
83 | 
84 |         self.hidden_size = self.decoder.hidden_size
85 | 
86 |         self.patch_size = self.vision_encoder.spatial_patch_size
87 |         self.merge_size = self.vision_encoder.spatial_merge_size
88 | 
```
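
Since this is a composition config, the notable behaviour is the normalization at the end of `__init__`: dict sub-configs are promoted to `SuryaEncoderConfig`/`SuryaDecoderConfig`, a `None` sliding window falls back to the max sequence length, and a few top-level fields are copied from the nested configs. A small sketch, assuming the decoder config accepts a `hidden_size` keyword as Qwen2-style configs do:

```python
from surya.common.surya.config import SuryaModelConfig

config = SuryaModelConfig(decoder={"hidden_size": 1536})  # dict is promoted to SuryaDecoderConfig

assert config.hidden_size == config.decoder.hidden_size      # copied from the decoder
assert config.sliding_window == config.max_sequence_length   # None falls back to max length
print(config.patch_size, config.merge_size)                  # derived from the vision encoder
```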

--------------------------------------------------------------------------------
/surya/table_rec/processor.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import List
 2 | 
 3 | import PIL
 4 | import torch
 5 | from transformers import ProcessorMixin
 6 | 
 7 | from surya.common.s3 import S3DownloaderMixin
 8 | from surya.common.donut.processor import SuryaEncoderImageProcessor
 9 | from surya.table_rec.shaper import LabelShaper
10 | from surya.settings import settings
11 | from surya.table_rec.model.config import BOX_DIM, SPECIAL_TOKENS
12 | 
13 | 
14 | class SuryaTableRecProcessor(S3DownloaderMixin, ProcessorMixin):
15 |     attributes = ["image_processor"]
16 |     image_processor_class = "AutoImageProcessor"
17 | 
18 |     def __init__(self, checkpoint, **kwargs):
19 |         image_processor = SuryaEncoderImageProcessor.from_pretrained(checkpoint)
20 |         image_processor.do_align_long_axis = False
21 |         image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE
22 |         self.image_processor = image_processor
23 |         super().__init__(image_processor)
24 | 
25 |         self.box_size = (BOX_DIM, BOX_DIM)
26 |         self.special_token_count = SPECIAL_TOKENS
27 |         self.shaper = LabelShaper()
28 | 
29 |     def resize_polygon(self, polygon, orig_size, new_size):
30 |         w_scaler = new_size[0] / orig_size[0]
31 |         h_scaler = new_size[1] / orig_size[1]
32 | 
33 |         for corner in polygon:
34 |             corner[0] = corner[0] * w_scaler
35 |             corner[1] = corner[1] * h_scaler
36 | 
37 |             if corner[0] < 0:
38 |                 corner[0] = 0
39 |             if corner[1] < 0:
40 |                 corner[1] = 0
41 |             if corner[0] > new_size[0]:
42 |                 corner[0] = new_size[0]
43 |             if corner[1] > new_size[1]:
44 |                 corner[1] = new_size[1]
45 | 
46 |         return polygon
47 | 
48 |     def __call__(
49 |         self,
50 |         images: List[PIL.Image.Image] | None,
51 |         query_items: List[dict],
52 |         columns: List[dict] | None = None,
53 |         convert_images: bool = True,
54 |         *args,
55 |         **kwargs,
56 |     ):
57 |         if convert_images:
58 |             assert len(images) == len(query_items)
59 |             assert len(images) > 0
60 | 
61 |             # Resize input query items
62 |             for image, query_item in zip(images, query_items):
63 |                 query_item["polygon"] = self.resize_polygon(query_item["polygon"], image.size, self.box_size)
64 | 
65 |         query_items = self.shaper.convert_polygons_to_bboxes(query_items)
66 |         query_labels = self.shaper.dict_to_labels(query_items)
67 | 
68 |         decoder_input_boxes = []
69 |         col_count = len(query_labels[0])
70 |         for label in query_labels:
71 |             decoder_input_boxes.append([
72 |                 [self.token_bos_id] * col_count,
73 |                 label,
74 |                 [self.token_query_end_id] * col_count
75 |             ])
76 | 
77 |         # Add columns to end of decoder input
78 |         if columns:
79 |             columns = self.shaper.convert_polygons_to_bboxes(columns)
80 |             column_labels = self.shaper.dict_to_labels(columns)
81 |             for decoder_box in decoder_input_boxes:
82 |                 decoder_box += column_labels
83 | 
84 |         input_boxes = torch.tensor(decoder_input_boxes, dtype=torch.long)
85 |         input_boxes_mask = torch.ones_like(input_boxes, dtype=torch.long)
86 | 
87 |         inputs = {
88 |             "input_ids": input_boxes,
89 |             "attention_mask": input_boxes_mask
90 |         }
91 |         if convert_images:
92 |             inputs["pixel_values"] = self.image_processor(images, *args, **kwargs)["pixel_values"]
93 |         return inputs
94 | 
```
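
`resize_polygon` does the coordinate bookkeeping here: scale each corner from the source image size into the model's `BOX_DIM` box space, then clamp to its bounds. A standalone sketch of the same transform (not the class method itself), with made-up sizes:

```python
def resize_polygon(polygon, orig_size, new_size):
    # Scale (x, y) corners from orig_size (w, h) into new_size, clamping to the target box.
    w_scale = new_size[0] / orig_size[0]
    h_scale = new_size[1] / orig_size[1]
    return [
        [min(max(x * w_scale, 0), new_size[0]), min(max(y * h_scale, 0), new_size[1])]
        for x, y in polygon
    ]


# A full-page polygon on a 1000x800 image mapped into a 1024x1024 box space:
# corners land on the box edges, and anything outside would be clamped.
print(resize_polygon([[0, 0], [1000, 0], [1000, 800], [0, 800]], (1000, 800), (1024, 1024)))
```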