This is page 1 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmark │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── CITATION.cff ├── CLA.md ├── detect_layout.py ├── detect_text.py ├── LICENSE ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ ├── fonts │ │ └── .gitignore │ └── images │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── arabic.jpg │ ├── benchmark_chart_small.png │ ├── benchmark_chart.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chi_hind.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── chinese.jpg │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── excerpt.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── funsd.png │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── hindi.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── japanese.jpg │ ├── latex_ocr.png │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── nyt.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── paper.jpg │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── pres.png │ ├── rec_acc_table.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── scanned.png │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ ├── textbook_text.jpg │ └── textbook.jpg ├── surya │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── adetr │ │ │ └── decoder.py │ │ ├── donut │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder │ │ │ │ └── __init__.py │ │ │ ├── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation │ │ ├── __init__.py │ │ ├── cache 
│ │ │ ├── __init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input │ │ ├── load.py │ │ └── processing.py │ ├── layout │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests │ ├── assets │ │ └── test_latex.png │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ``` # Files -------------------------------------------------------------------------------- /static/fonts/.gitignore: -------------------------------------------------------------------------------- ``` 1 | * 2 | !.gitignore ``` -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- ```yaml 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.9.10 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | types_or: [ python, pyi ] 9 | args: [ --fix ] 10 | # Run the formatter. 11 | - id: ruff-format 12 | types_or: [ python, pyi ] ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` 1 | private.py 2 | .DS_Store 3 | local.env 4 | experiments 5 | test_data 6 | training 7 | wandb 8 | notebooks 9 | results 10 | data 11 | slices 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | cover/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | .pybuilder/ 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | # For a library or package, you might want to ignore these files since the code is 99 | # intended to run in multiple environments; otherwise, check them in: 100 | # .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # poetry 110 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 111 | # This is especially recommended for binary packages to ensure reproducibility, and is more 112 | # commonly ignored for libraries. 113 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 114 | #poetry.lock 115 | 116 | # pdm 117 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 118 | #pdm.lock 119 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 120 | # in version control. 121 | # https://pdm.fming.dev/#use-with-ide 122 | .pdm.toml 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
172 | .idea/ 173 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- ```markdown 1 | # Surya 2 | 3 | Surya is a document OCR toolkit that does: 4 | 5 | - OCR in 90+ languages that benchmarks favorably vs cloud services 6 | - Line-level text detection in any language 7 | - Layout analysis (table, image, header, etc detection) 8 | - Reading order detection 9 | - Table recognition (detecting rows/columns) 10 | - LaTeX OCR 11 | 12 | It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details). 13 | 14 | For our managed API or on-prem document intelligence solution, check out [our platform here](https://datalab.to?utm_source=gh-surya). 15 | 16 | 17 | | Detection | OCR | 18 | |:----------------------------------------------------------------:|:-----------------------------------------------------------------------:| 19 | | <img src="static/images/excerpt.png" width="500px"/> | <img src="static/images/excerpt_text.png" width="500px"/> | 20 | 21 | | Layout | Reading Order | 22 | |:------------------------------------------------------------------:|:--------------------------------------------------------------------------:| 23 | | <img src="static/images/excerpt_layout.png" width="500px"/> | <img src="static/images/excerpt_reading.jpg" width="500px"/> | 24 | 25 | | Table Recognition | LaTeX OCR | 26 | |:-------------------------------------------------------------:|:------------------------------------------------------:| 27 | | <img src="static/images/scanned_tablerec.png" width="500px"/> | <img src="static/images/latex_ocr.png" width="500px"/> | 28 | 29 | 30 | Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision. 31 | 32 | ## Community 33 | 34 | [Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development. 
35 | 36 | ## Examples 37 | 38 | | Name | Detection | OCR | Layout | Order | Table Rec | 39 | |------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:| 40 | | Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) | 41 | | Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) | | 42 | | Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) | | 43 | | Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) | | 44 | | Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) | | 45 | | Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) | [Image](static/images/pres_tablerec.png) | 46 | | Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) | [Image](static/images/paper_tablerec.png) | 47 | | Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) | [Image](static/images/scanned_tablerec.png) | 48 | | New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) | | 49 | | Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) | 50 | | Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) | | 51 | 52 | # Hosted API 53 | 54 | There is a hosted API for all surya models available [here](https://www.datalab.to?utm_source=gh-surya): 55 | 56 | - Works with PDF, images, word docs, and powerpoints 57 | - Consistent speed, with no latency spikes 58 | - High reliability and uptime 59 | 60 | # Commercial usage 61 | 62 | Our model weights use a modified AI Pubs Open Rail-M license (free for research, personal use, and startups under $2M funding/revenue) and our code is GPL. For broader commercial licensing or to remove GPL requirements, visit our pricing page [here](https://www.datalab.to/pricing?utm_source=gh-surya). 63 | 64 | 65 | # Installation 66 | 67 | You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine. See [here](https://pytorch.org/get-started/locally/) for more details. 
68 | 69 | Install with: 70 | 71 | ```shell 72 | pip install surya-ocr 73 | ``` 74 | 75 | Model weights will automatically download the first time you run surya. 76 | 77 | # Usage 78 | 79 | - Inspect the settings in `surya/settings.py`. You can override any settings with environment variables. 80 | - Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. 81 | 82 | ## Interactive App 83 | 84 | I've included a streamlit app that lets you interactively try Surya on images or PDF files. Run it with: 85 | 86 | ```shell 87 | pip install streamlit pdftext 88 | surya_gui 89 | ``` 90 | 91 | ## OCR (text recognition) 92 | 93 | This command will write out a json file with the detected text and bboxes: 94 | 95 | ```shell 96 | surya_ocr DATA_PATH 97 | ``` 98 | 99 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs 100 | - `--task_name` will specify which task to use for predicting the lines. `ocr_with_boxes` is the default, which will format text and give you bboxes. If you get bad performance, try `ocr_without_boxes`, which will give you potentially better performance but no bboxes. For blocks like equations and paragraphs, try `block_without_boxes`. 101 | - `--images` will save images of the pages and detected text lines (optional) 102 | - `--output_dir` specifies the directory to save results to instead of the default 103 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. 104 | - `--disable_math` - by default, surya will recognize math in text. This can lead to false positives - you can disable this with this flag. 105 | 106 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: 107 | 108 | - `text_lines` - the detected text and bounding boxes for each line 109 | - `text` - the text in the line 110 | - `confidence` - the confidence of the model in the detected text (0-1) 111 | - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. 112 | - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. 113 | - `chars` - the individual characters in the line 114 | - `text` - the text of the character 115 | - `bbox` - the character bbox (same format as line bbox) 116 | - `polygon` - the character polygon (same format as line polygon) 117 | - `confidence` - the confidence of the model in the detected character (0-1) 118 | - `bbox_valid` - if the character is a special token or math, the bbox may not be valid 119 | - `words` - the individual words in the line (computed from the characters) 120 | - `text` - the text of the word 121 | - `bbox` - the word bbox (same format as line bbox) 122 | - `polygon` - the word polygon (same format as line polygon) 123 | - `confidence` - mean character confidence 124 | - `bbox_valid` - if the word is a special token or math, the bbox may not be valid 125 | - `page` - the page number in the file 126 | - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. 
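If you want to consume `results.json` programmatically, here is a minimal sketch. It only relies on the field names documented above; the path to `results.json` is an assumption - point it at the file surya wrote for your run (in the default output directory, or the one passed via `--output_dir`).

```python
import json

# Assumed path - adjust to wherever your results.json was written
with open("results.json") as f:
    results = json.load(f)

# Keys are input filenames without extensions; each value is a list of page dicts
for doc_name, pages in results.items():
    for page in pages:
        print(f"{doc_name} - page {page['page']}")
        for line in page["text_lines"]:
            # Each line carries text, confidence, polygon, bbox, chars, and words
            print(f"  {line['confidence']:.2f} {line['bbox']} {line['text']}")
```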
127 | 128 | **Performance tips** 129 | 130 | Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `40MB` of VRAM, so very high batch sizes are possible. The default is a batch size `512`, which will use about 20GB of VRAM. Depending on your CPU core count, it may help, too - the default CPU batch size is `32`. 131 | 132 | ### From python 133 | 134 | ```python 135 | from PIL import Image 136 | from surya.foundation import FoundationPredictor 137 | from surya.recognition import RecognitionPredictor 138 | from surya.detection import DetectionPredictor 139 | 140 | image = Image.open(IMAGE_PATH) 141 | foundation_predictor = FoundationPredictor() 142 | recognition_predictor = RecognitionPredictor(foundation_predictor) 143 | detection_predictor = DetectionPredictor() 144 | 145 | predictions = recognition_predictor([image], det_predictor=detection_predictor) 146 | ``` 147 | 148 | 149 | ## Text line detection 150 | 151 | This command will write out a json file with the detected bboxes. 152 | 153 | ```shell 154 | surya_detect DATA_PATH 155 | ``` 156 | 157 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs 158 | - `--images` will save images of the pages and detected text lines (optional) 159 | - `--output_dir` specifies the directory to save results to instead of the default 160 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. 161 | 162 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: 163 | 164 | - `bboxes` - detected bounding boxes for text 165 | - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. 166 | - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. 167 | - `confidence` - the confidence of the model in the detected text (0-1) 168 | - `vertical_lines` - vertical lines detected in the document 169 | - `bbox` - the axis-aligned line coordinates. 170 | - `page` - the page number in the file 171 | - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. 172 | 173 | **Performance tips** 174 | 175 | Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `440MB` of VRAM, so very high batch sizes are possible. The default is a batch size `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`. 176 | 177 | ### From python 178 | 179 | ```python 180 | from PIL import Image 181 | from surya.detection import DetectionPredictor 182 | 183 | image = Image.open(IMAGE_PATH) 184 | det_predictor = DetectionPredictor() 185 | 186 | # predictions is a list of dicts, one per image 187 | predictions = det_predictor([image]) 188 | ``` 189 | 190 | ## Layout and reading order 191 | 192 | This command will write out a json file with the detected layout and reading order. 
193 | 194 | ```shell 195 | surya_layout DATA_PATH 196 | ``` 197 | 198 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs 199 | - `--images` will save images of the pages and detected text lines (optional) 200 | - `--output_dir` specifies the directory to save results to instead of the default 201 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. 202 | 203 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: 204 | 205 | - `bboxes` - detected bounding boxes for text 206 | - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. 207 | - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. 208 | - `position` - the reading order of the box. 209 | - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`. 210 | - `top_k` - the top-k other potential labels for the box. A dictionary with labels as keys and confidences as values. 211 | - `page` - the page number in the file 212 | - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. 213 | 214 | **Performance tips** 215 | 216 | Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `220MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 7GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`. 217 | 218 | ### From python 219 | 220 | ```python 221 | from PIL import Image 222 | from surya.foundation import FoundationPredictor 223 | from surya.layout import LayoutPredictor 224 | from surya.settings import settings 225 | 226 | image = Image.open(IMAGE_PATH) 227 | layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)) 228 | 229 | # layout_predictions is a list of dicts, one per image 230 | layout_predictions = layout_predictor([image]) 231 | ``` 232 | 233 | ## Table Recognition 234 | 235 | This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes. If you want to get cell positions and text, along with nice formatting, check out the [marker](https://www.github.com/VikParuchuri/marker) repo. You can use the `TableConverter` to detect and extract tables in images and PDFs. It supports output in json (with bboxes), markdown, and html. 
236 | 237 | ```shell 238 | surya_table DATA_PATH 239 | ``` 240 | 241 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs 242 | - `--images` will save images of the pages and detected table cells + rows and columns (optional) 243 | - `--output_dir` specifies the directory to save results to instead of the default 244 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. 245 | - `--detect_boxes` specifies if cells should be detected. By default, they're pulled out of the PDF, but this is not always possible. 246 | - `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table. 247 | 248 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: 249 | 250 | - `rows` - detected table rows 251 | - `bbox` - the bounding box of the table row 252 | - `row_id` - the id of the row 253 | - `is_header` - if it is a header row. 254 | - `cols` - detected table columns 255 | - `bbox` - the bounding box of the table column 256 | - `col_id`- the id of the column 257 | - `is_header` - if it is a header column 258 | - `cells` - detected table cells 259 | - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. 260 | - `text` - if text could be pulled out of the pdf, the text of this cell. 261 | - `row_id` - the id of the row the cell belongs to. 262 | - `col_id` - the id of the column the cell belongs to. 263 | - `colspan` - the number of columns spanned by the cell. 264 | - `rowspan` - the number of rows spanned by the cell. 265 | - `is_header` - whether it is a header cell. 266 | - `page` - the page number in the file 267 | - `table_idx` - the index of the table on the page (sorted in vertical order) 268 | - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. 269 | 270 | **Performance tips** 271 | 272 | Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default is a batch size `64`, which will use about 10GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `8`. 273 | 274 | ### From python 275 | 276 | ```python 277 | from PIL import Image 278 | from surya.table_rec import TableRecPredictor 279 | 280 | image = Image.open(IMAGE_PATH) 281 | table_rec_predictor = TableRecPredictor() 282 | 283 | table_predictions = table_rec_predictor([image]) 284 | ``` 285 | 286 | ## LaTeX OCR 287 | 288 | This command will write out a json file with the LaTeX of the equations. You must pass in images that are already cropped to the equations. You can do this by running the layout model, then cropping, if you want. 
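If you would rather do that layout-then-crop step in Python before OCRing the crops, here is a minimal sketch. It reuses the `LayoutPredictor` and `TexifyPredictor` APIs shown elsewhere in this README, and assumes equation regions are labeled `Equation` (per `surya/layout/label.py` - check the label names in your version).

```python
from PIL import Image
from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.settings import settings
from surya.texify import TexifyPredictor

image = Image.open(IMAGE_PATH)
layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))
texify_predictor = TexifyPredictor()

# Find equation regions with the layout model, crop them out of the page,
# then OCR each crop to LaTeX. The "Equation" label is an assumption - see above.
layout = layout_predictor([image])[0]
equation_crops = [image.crop(tuple(box.bbox)) for box in layout.bboxes if box.label == "Equation"]
latex_results = texify_predictor(equation_crops) if equation_crops else []
```

Either way, the CLI below expects images that are already cropped to equations.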
289 | 290 | ```shell 291 | surya_latex_ocr DATA_PATH 292 | ``` 293 | 294 | - `DATA_PATH` can be an image, pdf, or folder of images/pdfs 295 | - `--output_dir` specifies the directory to save results to instead of the default 296 | - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. 297 | 298 | The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. See the OCR section above for the format of the output. 299 | 300 | ### From python 301 | 302 | ```python 303 | from PIL import Image 304 | from surya.texify import TexifyPredictor 305 | 306 | image = Image.open(IMAGE_PATH) 307 | predictor = TexifyPredictor() 308 | 309 | predictor([image]) 310 | ``` 311 | 312 | ### Interactive app 313 | 314 | You can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with: 315 | 316 | ```shell 317 | pip install streamlit==1.40 streamlit-drawable-canvas-jsretry 318 | texify_gui 319 | ``` 320 | 321 | ## Compilation 322 | 323 | The following models have support for compilation. You will need to set the following environment variables to enable compilation: 324 | 325 | - Detection: `COMPILE_DETECTOR=true` 326 | - Layout: `COMPILE_LAYOUT=true` 327 | - Table recognition: `COMPILE_TABLE_REC=true` 328 | 329 | Alternatively, you can also set `COMPILE_ALL=true` which will compile all models. 330 | 331 | Here are the speedups on an A10 GPU: 332 | 333 | | Model | Time per page (s) | Compiled time per page (s) | Speedup (%) | 334 | | ----------------- | ----------------- | -------------------------- | ----------- | 335 | | Detection | 0.108808 | 0.10521 | 3.306742151 | 336 | | Layout | 0.27319 | 0.27063 | 0.93707676 | 337 | | Table recognition | 0.0219 | 0.01938 | 11.50684932 | 338 | 339 | # Limitations 340 | 341 | - This is specialized for document OCR. It will likely not work on photos or other images. 342 | - It is for printed text, not handwriting (though it may work on some handwriting). 343 | - The text detection model has trained itself to ignore advertisements. 344 | - You can find language support for OCR in `surya/recognition/languages.py`. Text detection, layout analysis, and reading order will work with any language. 345 | 346 | ## Troubleshooting 347 | 348 | If OCR isn't working properly: 349 | 350 | - Try increasing resolution of the image so the text is bigger. If the resolution is already very high, try decreasing it to no more than a `2048px` width. 351 | - Preprocessing the image (binarizing, deskewing, etc) can help with very old/blurry images. 352 | - You can adjust `DETECTOR_BLANK_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results. `DETECTOR_BLANK_THRESHOLD` controls the space between lines - any prediction below this number will be considered blank space. `DETECTOR_TEXT_THRESHOLD` controls how text is joined - any number above this is considered text. `DETECTOR_TEXT_THRESHOLD` should always be higher than `DETECTOR_BLANK_THRESHOLD`, and both should be in the 0-1 range. Looking at the heatmap from the debug output of the detector can tell you how to adjust these (if you see faint things that look like boxes, lower the thresholds, and if you see bboxes being joined together, raise the thresholds). 
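As a concrete illustration of that last point, one way to experiment with the detector thresholds is via environment variables, as described in the usage section. This is a minimal sketch - the values shown are arbitrary starting points rather than recommendations, and the overrides are set before surya is imported so its settings pick them up.

```python
import os

# Illustrative values - tune based on the debug heatmap, not these numbers
os.environ["DETECTOR_BLANK_THRESHOLD"] = "0.3"
os.environ["DETECTOR_TEXT_THRESHOLD"] = "0.5"

# Import after setting the env vars so the overrides take effect
from PIL import Image
from surya.detection import DetectionPredictor

det_predictor = DetectionPredictor()
predictions = det_predictor([Image.open(IMAGE_PATH)])
```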
353 | 354 | # Manual install 355 | 356 | If you want to develop surya, you can install it manually: 357 | 358 | - `git clone https://github.com/VikParuchuri/surya.git` 359 | - `cd surya` 360 | - `poetry install` - installs main and dev dependencies 361 | - `poetry shell` - activates the virtual environment 362 | 363 | # Benchmarks 364 | 365 | ## OCR 366 | 367 |  368 | 369 | | Model | Time per page (s) | Avg similarity (⬆) | 370 | |-----------|-------------------|--------------------| 371 | | surya | .62 | 0.97 | 372 | | tesseract | .45 | 0.88 | 373 | 374 | [Full language results](static/images/rec_acc_table.png) 375 | 376 | Tesseract is CPU-based, and surya is CPU or GPU. I tried to cost-match the resources used, so I used a 1xA6000 (48GB VRAM) for surya, and 28 CPU cores for Tesseract (same price on Lambda Labs/DigitalOcean). 377 | 378 | ### Google Cloud Vision 379 | 380 | I benchmarked OCR against Google Cloud vision since it has similar language coverage to Surya. 381 | 382 |  383 | 384 | [Full language results](static/images/gcloud_full_langs.png) 385 | 386 | **Methodology** 387 | 388 | I measured normalized sentence similarity (0-1, higher is better) based on a set of real-world and synthetic pdfs. I sampled PDFs from common crawl, then filtered out the ones with bad OCR. I couldn't find PDFs for some languages, so I also generated simple synthetic PDFs for those. 389 | 390 | I used the reference line bboxes from the PDFs with both tesseract and surya, to just evaluate the OCR quality. 391 | 392 | For Google Cloud, I aligned the output from Google Cloud with the ground truth. I had to skip RTL languages since they didn't align well. 393 | 394 | ## Text line detection 395 | 396 |  397 | 398 | | Model | Time (s) | Time per page (s) | precision | recall | 399 | |-----------|------------|---------------------|-------------|----------| 400 | | surya | 47.2285 | 0.094452 | 0.835857 | 0.960807 | 401 | | tesseract | 74.4546 | 0.290838 | 0.631498 | 0.997694 | 402 | 403 | 404 | Tesseract is CPU-based, and surya is CPU or GPU. I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU. This was the resource usage: 405 | 406 | - tesseract - 32 CPU cores, or 8 workers using 4 cores each 407 | - surya - 36 batch size, for 16GB VRAM usage 408 | 409 | **Methodology** 410 | 411 | Surya predicts line-level bboxes, while tesseract and others predict word-level or character-level. It's hard to find 100% correct datasets with line-level annotations. Merging bboxes can be noisy, so I chose not to use IoU as the metric for evaluation. 412 | 413 | I instead used coverage, which calculates: 414 | 415 | - Precision - how well the predicted bboxes cover ground truth bboxes 416 | - Recall - how well ground truth bboxes cover predicted bboxes 417 | 418 | First calculate coverage for each bbox, then add a small penalty for double coverage, since we want the detection to have non-overlapping bboxes. Anything with a coverage of 0.5 or higher is considered a match. 419 | 420 | Then we calculate precision and recall for the whole dataset. 421 | 422 | ## Layout analysis 423 | 424 | | Layout Type | precision | recall | 425 | |---------------|-------------|----------| 426 | | Image | 0.91265 | 0.93976 | 427 | | List | 0.80849 | 0.86792 | 428 | | Table | 0.84957 | 0.96104 | 429 | | Text | 0.93019 | 0.94571 | 430 | | Title | 0.92102 | 0.95404 | 431 | 432 | Time per image - .13 seconds on GPU (A10). 
433 | 434 | **Methodology** 435 | 436 | I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/PubLayNet), which was not in the training data. I had to align publaynet labels with the surya layout labels. I was then able to find coverage for each layout type: 437 | 438 | - Precision - how well the predicted bboxes cover ground truth bboxes 439 | - Recall - how well ground truth bboxes cover predicted bboxes 440 | 441 | ## Reading Order 442 | 443 | 88% mean accuracy, and .4 seconds per image on an A10 GPU. See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check. 444 | 445 | **Methodology** 446 | 447 | I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth. 448 | 449 | The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct. 450 | 451 | ## Table Recognition 452 | 453 | | Model             | Row Intersection   | Col Intersection   | Time Per Image   | 454 | |-------------------|--------------------|--------------------|------------------| 455 | | Surya             | 1                  | 0.98625            | 0.30202          | 456 | | Table transformer | 0.84               | 0.86857            | 0.08082          | 457 | 458 | Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions. This benchmark is mostly a sanity check - there is a more rigorous one in [marker](https://www.github.com/VikParuchuri/marker). 459 | 460 | **Methodology** 461 | 462 | The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM. It has labeled rows and columns. After table recognition is run, the predicted rows and columns are compared to the ground truth. There is an additional penalty for predicting too many or too few rows/columns. 463 | 464 | ## LaTeX OCR 465 | 466 | | Method | edit ⬇   | time taken (s) ⬇ | 467 | |--------|----------|------------------| 468 | | texify | 0.122617 | 35.6345          | 469 | 470 | This runs texify on a ground truth set of LaTeX, then computes the edit distance. This is a bit noisy, since two LaTeX strings that render the same can have different symbols in them. 471 | 472 | ## Running your own benchmarks 473 | 474 | You can benchmark the performance of surya on your machine. 475 | 476 | - Follow the manual install instructions above. 477 | - `poetry install --group dev` - installs dev dependencies 478 | 479 | **Text line detection** 480 | 481 | This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench). 482 | 483 | ```shell 484 | python benchmark/detection.py --max_rows 256 485 | ``` 486 | 487 | - `--max_rows` controls how many images to process for the benchmark 488 | - `--debug` will render images and detected bboxes 489 | - `--pdf_path` will let you specify a pdf to benchmark instead of the default data 490 | - `--results_dir` will let you specify a directory to save results to instead of the default one 491 | 492 | **Text recognition** 493 | 494 | This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages).
495 | 496 | ```shell 497 | python benchmark/recognition.py --tesseract 498 | ``` 499 | 500 | - `--max_rows` controls how many images to process for the benchmark 501 | - `--debug 2` will render images with detected text 502 | - `--results_dir` will let you specify a directory to save results to instead of the default one 503 | - `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder. 504 | 505 | - Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark. 506 | - Set `RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist` to use the historical document data for benchmarking. This data comes from the [tapuscorpus](https://github.com/HTR-United/tapuscorpus). 507 | 508 | **Layout analysis** 509 | 510 | This will evaluate surya on the publaynet dataset. 511 | 512 | ```shell 513 | python benchmark/layout.py 514 | ``` 515 | 516 | - `--max_rows` controls how many images to process for the benchmark 517 | - `--debug` will render images with detected text 518 | - `--results_dir` will let you specify a directory to save results to instead of the default one 519 | 520 | **Reading Order** 521 | 522 | ```shell 523 | python benchmark/ordering.py 524 | ``` 525 | 526 | - `--max_rows` controls how many images to process for the benchmark 527 | - `--debug` will render images with detected text 528 | - `--results_dir` will let you specify a directory to save results to instead of the default one 529 | 530 | **Table Recognition** 531 | 532 | ```shell 533 | python benchmark/table_recognition.py --max_rows 1024 --tatr 534 | ``` 535 | 536 | - `--max_rows` controls how many images to process for the benchmark 537 | - `--debug` will render images with detected text 538 | - `--results_dir` will let you specify a directory to save results to instead of the default one 539 | - `--tatr` specifies whether to also run table transformer 540 | 541 | **LaTeX OCR** 542 | 543 | ```shell 544 | python benchmark/texify.py --max_rows 128 545 | ``` 546 | 547 | - `--max_rows` controls how many images to process for the benchmark 548 | - `--results_dir` will let you specify a directory to save results to instead of the default one 549 | 550 | # Training 551 | 552 | Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation. 553 | 554 | Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes). 555 | 556 | # Finetuning Surya OCR 557 | You can now take Surya OCR further by training it on your own data with our [finetuning script](/surya/scripts/finetune_ocr.py). 558 | It’s built on Hugging Face Trainer, and supports all the [arguments](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) that the huggingface trainer provides, and integrations like torchrun, or deepspeed. 559 | 560 | To setup your dataset, follow the example dataset format [here](https://huggingface.co/datasets/datalab-to/ocr_finetune_example) and provide the path to your own dataset when launching the training script. 
561 | ```bash 562 | # Tested on 1xH100 GPU 563 | # Set --pretrained_checkpoint_path to load from a custom checkpoint, otherwise 564 | # the default surya ocr weights will be loaded as the initialization 565 | python surya/scripts/finetune_ocr.py \ 566 | --output_dir $OUTPUT_DIR \ 567 | --dataset_name datalab-to/ocr_finetune_example \ 568 | --per_device_train_batch_size 64 \ 569 | --gradient_checkpointing true \ 570 | --max_sequence_length 1024 571 | ``` 572 | 573 | This is a minimal training script to get you started finetuning Surya. Our internal training stack includes character bounding box finetuning, sliding window attention with specialized attention masks, custom kernels, augmentations, and other optimizations that can push OCR accuracy well beyond standard finetuning. If you want to get the most out of your data, reach us at [email protected]! 574 | 575 | # Thanks 576 | 577 | This work would not have been possible without amazing open source AI work: 578 | 579 | - [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA 580 | - [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT 581 | - [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman 582 | - [Donut](https://github.com/clovaai/donut) from Naver 583 | - [transformers](https://github.com/huggingface/transformers) from huggingface 584 | - [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model 585 | 586 | Thank you to everyone who makes open source AI possible. 587 | 588 | # Citation 589 | 590 | If you use surya (or the associated models) in your work or research, please consider citing us using the following BibTeX entry: 591 | 592 | ```bibtex 593 | @misc{paruchuri2025surya, 594 | author = {Vikas Paruchuri and Datalab Team}, 595 | title = {Surya: A lightweight document OCR and analysis toolkit}, 596 | year = {2025}, 597 | howpublished = {\url{https://github.com/VikParuchuri/surya}}, 598 | note = {GitHub repository}, 599 | } 600 | ``` -------------------------------------------------------------------------------- /benchmark/utils/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /surya/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /surya/detection/model/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /surya/foundation/cache/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /surya/ocr_error/model/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /surya/scripts/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /surya/table_rec/model/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` 
-------------------------------------------------------------------------------- /surya/common/__init__.py: -------------------------------------------------------------------------------- ```python 1 | 2 | 3 | 4 | ``` -------------------------------------------------------------------------------- /ocr_text.py: -------------------------------------------------------------------------------- ```python 1 | from surya.scripts.ocr_text import ocr_text_cli 2 | 3 | if __name__ == "__main__": 4 | ocr_text_cli() 5 | ``` -------------------------------------------------------------------------------- /ocr_latex.py: -------------------------------------------------------------------------------- ```python 1 | from surya.scripts.ocr_latex import ocr_latex_cli 2 | 3 | if __name__ == "__main__": 4 | ocr_latex_cli() 5 | ``` -------------------------------------------------------------------------------- /texify_app.py: -------------------------------------------------------------------------------- ```python 1 | from surya.scripts.run_texify_app import texify_app_cli 2 | 3 | if __name__ == "__main__": 4 | texify_app_cli() ``` -------------------------------------------------------------------------------- /detect_layout.py: -------------------------------------------------------------------------------- ```python 1 | from surya.scripts.detect_layout import detect_layout_cli 2 | 3 | if __name__ == "__main__": 4 | detect_layout_cli() 5 | ``` -------------------------------------------------------------------------------- /detect_text.py: -------------------------------------------------------------------------------- ```python 1 | from surya.scripts.detect_text import detect_text_cli 2 | 3 | if __name__ == "__main__": 4 | detect_text_cli() 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | ``` -------------------------------------------------------------------------------- /ocr_app.py: -------------------------------------------------------------------------------- ```python 1 | from surya.scripts.run_streamlit_app import streamlit_app_cli 2 | 3 | if __name__ == "__main__": 4 | streamlit_app_cli() ``` -------------------------------------------------------------------------------- /table_recognition.py: -------------------------------------------------------------------------------- ```python 1 | from surya.scripts.table_recognition import table_recognition_cli 2 | 3 | if __name__ == "__main__": 4 | table_recognition_cli() ``` -------------------------------------------------------------------------------- /surya/ocr_error/schema.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class OCRErrorDetectionResult(BaseModel): 7 | texts: List[str] 8 | labels: List[str] 9 | ``` -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- ``` 1 | [pytest] 2 | testpaths=tests 3 | pythonpath=. 
4 | filterwarnings = 5 | ignore::UserWarning 6 | ignore::PendingDeprecationWarning 7 | ignore::DeprecationWarning ``` -------------------------------------------------------------------------------- /surya/detection/schema.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List, Optional, Any 2 | 3 | from pydantic import BaseModel 4 | 5 | from surya.common.polygon import PolygonBox 6 | 7 | 8 | class TextDetectionResult(BaseModel): 9 | bboxes: List[PolygonBox] 10 | heatmap: Optional[Any] 11 | affinity_map: Optional[Any] 12 | image_bbox: List[float] 13 | ``` -------------------------------------------------------------------------------- /surya/scripts/run_texify_app.py: -------------------------------------------------------------------------------- ```python 1 | import subprocess 2 | import os 3 | 4 | 5 | def texify_app_cli(): 6 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 7 | ocr_app_path = os.path.join(cur_dir, "texify_app.py") 8 | cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"] 9 | subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) ``` -------------------------------------------------------------------------------- /surya/scripts/run_streamlit_app.py: -------------------------------------------------------------------------------- ```python 1 | import subprocess 2 | import os 3 | 4 | 5 | def streamlit_app_cli(): 6 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 7 | ocr_app_path = os.path.join(cur_dir, "streamlit_app.py") 8 | cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"] 9 | subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) ``` -------------------------------------------------------------------------------- /surya/common/surya/schema.py: -------------------------------------------------------------------------------- ```python 1 | class TaskNames: 2 | block_without_boxes = "block_without_boxes" 3 | ocr_with_boxes = "ocr_with_boxes" 4 | ocr_without_boxes = "ocr_without_boxes" 5 | layout = "layout" 6 | table_structure = "table_structure" 7 | 8 | 9 | TASK_NAMES = [ 10 | TaskNames.block_without_boxes, 11 | TaskNames.ocr_with_boxes, 12 | TaskNames.ocr_without_boxes, 13 | TaskNames.layout, 14 | TaskNames.table_structure, 15 | ] 16 | ``` -------------------------------------------------------------------------------- /surya/layout/schema.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional, Dict, List 2 | 3 | from pydantic import BaseModel 4 | 5 | from surya.common.polygon import PolygonBox 6 | 7 | 8 | class LayoutBox(PolygonBox): 9 | label: str 10 | position: int 11 | top_k: Optional[Dict[str, float]] = None 12 | 13 | 14 | class LayoutResult(BaseModel): 15 | bboxes: List[LayoutBox] 16 | image_bbox: List[float] 17 | sliced: bool = False # Whether the image was sliced and reconstructed 18 | ``` -------------------------------------------------------------------------------- /surya/detection/parallel.py: -------------------------------------------------------------------------------- ```python 1 | class FakeFuture: 2 | def __init__(self, func, *args, **kwargs): 3 | self._result = func(*args, **kwargs) 4 | 5 | def result(self): 6 | return self._result 7 | 8 | class FakeExecutor: 9 | def __init__(self, **kwargs): 10 | pass 11 | 12 | def __enter__(self): 13 | return self 14 | 15 | def __exit__(self, 
*excinfo): 16 | pass 17 | 18 | def submit(self, fn, *args, **kwargs): 19 | return FakeFuture(fn, *args, **kwargs) 20 | ``` -------------------------------------------------------------------------------- /tests/test_layout.py: -------------------------------------------------------------------------------- ```python 1 | def test_layout_topk(layout_predictor, test_image): 2 | layout_results = layout_predictor([test_image]) 3 | 4 | assert len(layout_results) == 1 5 | assert layout_results[0].image_bbox == [0, 0, 1024, 1024] 6 | 7 | bboxes = layout_results[0].bboxes 8 | assert len(bboxes) == 2 9 | 10 | assert bboxes[0].label == "SectionHeader" 11 | assert len(bboxes[0].top_k) == 5 12 | 13 | assert bboxes[1].label == "Text" 14 | assert len(bboxes[1].top_k) == 5 15 | ``` -------------------------------------------------------------------------------- /tests/test_foundation.py: -------------------------------------------------------------------------------- ```python 1 | from surya.foundation import FoundationPredictor 2 | 3 | 4 | def test_foundation_flash2(): 5 | try: 6 | f = FoundationPredictor(None, None, None, "flash_attention_2") 7 | assert f.model.decoder.config._attn_implementation == "flash_attention_2" 8 | assert f.model.vision_encoder.config._attn_implementation == "flash_attention_2" 9 | except Exception as e: 10 | assert False, ( 11 | f"FoundationPredictor with flash_attention_2 raised an exception: {e}" 12 | ) 13 | ``` -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- ```yaml 1 | name: Unit tests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ${{ matrix.os }} 8 | strategy: 9 | matrix: 10 | os: [t4_gpu, ubuntu-latest, windows-latest] 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python 3.11 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.11 17 | - name: Install python dependencies 18 | run: | 19 | pip install poetry 20 | poetry install 21 | - name: Run tests 22 | run: poetry run pytest ``` -------------------------------------------------------------------------------- /surya/layout/label.py: -------------------------------------------------------------------------------- ```python 1 | LAYOUT_PRED_RELABEL = { 2 | "<page-header>": "PageHeader", 3 | "<page-footer>": "PageFooter", 4 | "<footnote>": "Footnote", 5 | "<image>": "Picture", 6 | "<figure>": "Figure", 7 | "<text>": "Text", 8 | "<caption>": "Caption", 9 | "<list-item>": "ListItem", 10 | "<section-header>": "SectionHeader", 11 | "<table>": "Table", 12 | "<table-of-contents>": "TableOfContents", 13 | "<form>": "Form", 14 | "<equation-block>": "Equation", 15 | "<code-block>": "Code", 16 | "<complex-block>": "Figure", 17 | } 18 | ``` -------------------------------------------------------------------------------- /tests/test_ocr_errors.py: -------------------------------------------------------------------------------- ```python 1 | def test_garbled_text(ocr_error_predictor): 2 | text = """" 3 | ; dh vksj ls mifLFkr vf/koDrk % Jh vfuy dqekj 4 | 2. vfHk;qDr dh vksj ls mifLFkr vf/koDrk % Jh iznhi d 5 | """.strip() 6 | results = ocr_error_predictor([text]) 7 | assert results.labels[0] == "bad" 8 | 9 | 10 | def test_good_text(ocr_error_predictor): 11 | text = """" 12 | There are professions more harmful than industrial design, but only a very few of them. 
13 | """.strip() 14 | results = ocr_error_predictor([text]) 15 | assert results.labels[0] == "good" ``` -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEAT]" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## ✨ Is your feature request related to a problem? 11 | 12 | A clear and concise description of what the problem is. 13 | 14 | ## 💡 Describe the Solution You'd Like 15 | 16 | A concise description of what you want to happen or how you envision it working. 17 | 18 | ## 📋 Alternatives Considered 19 | 20 | Any alternative solutions or workarounds you've tried. 21 | 22 | ## 🧩 Additional Context 23 | 24 | Any additional context, references, or related issues. 25 | ``` -------------------------------------------------------------------------------- /surya/common/xla.py: -------------------------------------------------------------------------------- ```python 1 | import math 2 | from surya.settings import settings 3 | 4 | if settings.TORCH_DEVICE_MODEL == "xla": 5 | import torch_xla.core.xla_model as xm 6 | else: 7 | xm = None 8 | 9 | 10 | def get_nearest_pad( 11 | length: int, pad_multiple: int = settings.FOUNDATION_PAD_TO_NEAREST 12 | ): 13 | return math.ceil(length / pad_multiple) * pad_multiple 14 | 15 | 16 | def get_compile_args(device: str) -> dict: 17 | if not settings.FOUNDATION_XLA: 18 | return {} 19 | 20 | return { 21 | "backend": "openxla", 22 | } 23 | 24 | 25 | def mark_step(): 26 | if xm is not None: 27 | xm.mark_step() 28 | ``` -------------------------------------------------------------------------------- /tests/test_latex_ocr.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | 3 | from PIL import Image, ImageDraw 4 | 5 | from surya.common.surya.schema import TaskNames 6 | from surya.recognition import OCRResult 7 | 8 | 9 | def test_latex_ocr(recognition_predictor, test_image_latex): 10 | width, height = test_image_latex.size 11 | results: List[OCRResult] = recognition_predictor( 12 | [test_image_latex], [TaskNames.block_without_boxes], bboxes=[[[0, 0, width, height]]] 13 | ) 14 | text = results[0].text_lines[0].text 15 | assert len(results) == 1 16 | 17 | assert text.startswith("<math") 18 | assert text.endswith("</math>") 19 | ``` -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- ```yaml 1 | name: Python package 2 | on: 3 | push: 4 | tags: 5 | - "v*.*.*" 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | - name: Set up Python 3.11 12 | uses: actions/setup-python@v4 13 | with: 14 | python-version: 3.11 15 | - name: Install python dependencies 16 | run: | 17 | pip install poetry 18 | poetry install 19 | - name: Build package 20 | run: | 21 | poetry build 22 | - name: Publish package 23 | env: 24 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 25 | run: | 26 | poetry config pypi-token.pypi "$PYPI_TOKEN" 27 | poetry publish 28 | ``` -------------------------------------------------------------------------------- /tests/test_detection.py: -------------------------------------------------------------------------------- ```python 1 | def 
test_detection(detection_predictor, test_image): 2 | detection_results = detection_predictor([test_image]) 3 | 4 | assert len(detection_results) == 1 5 | assert detection_results[0].image_bbox == [0, 0, 1024, 1024] 6 | 7 | bboxes = detection_results[0].bboxes 8 | assert len(bboxes) == 4 9 | 10 | 11 | def test_detection_chunking(detection_predictor, test_image_tall): 12 | detection_results = detection_predictor([test_image_tall]) 13 | 14 | assert len(detection_results) == 1 15 | assert detection_results[0].image_bbox == [0, 0, 4096, 4096] 16 | 17 | bboxes = detection_results[0].bboxes 18 | assert len(bboxes) >= 3 # Sometimes merges into 3 19 | assert abs(4000 - bboxes[1].polygon[0][0]) < 50 ``` -------------------------------------------------------------------------------- /surya/common/load.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional, Any 2 | 3 | import torch 4 | 5 | from surya.settings import settings 6 | 7 | 8 | class ModelLoader: 9 | def __init__(self, checkpoint: Optional[str] = None): 10 | self.checkpoint = checkpoint 11 | 12 | def model( 13 | self, 14 | device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, 15 | dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE, 16 | attention_implementation: Optional[str] = None, 17 | ) -> Any: 18 | raise NotImplementedError() 19 | 20 | def processor( 21 | self, 22 | device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, 23 | dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE, 24 | ) -> Any: 25 | raise NotImplementedError() 26 | ``` -------------------------------------------------------------------------------- /surya/common/surya/processor/schema.py: -------------------------------------------------------------------------------- ```python 1 | from typing import TypedDict, Literal, List, Tuple 2 | 3 | import torch 4 | from PIL import Image 5 | 6 | 7 | class TaskDict(TypedDict): 8 | datasets: List[str] 9 | img_size: Tuple[int, int] 10 | 11 | 12 | class TasksDict(TypedDict): 13 | ocr_with_boxes: TaskDict 14 | ocr_without_boxes: TaskDict 15 | block_without_boxes: TaskDict 16 | 17 | 18 | class ProcessorInput(TypedDict): 19 | type: Literal["image", "ocr", "text", "empty_output"] 20 | 21 | 22 | class ImageInput(ProcessorInput): 23 | type: Literal["image"] 24 | image: Image.Image 25 | rotated: bool 26 | 27 | 28 | class TextInput(ProcessorInput): 29 | type: Literal["text"] 30 | text: str 31 | math: bool 32 | 33 | 34 | class ProcessorOutput(TypedDict): 35 | input_ids: List[int] 36 | image_tiles: torch.Tensor | None 37 | grid_thw: torch.Tensor | None 38 | ``` -------------------------------------------------------------------------------- /surya/logging.py: -------------------------------------------------------------------------------- ```python 1 | import logging 2 | import warnings 3 | from surya.settings import settings 4 | 5 | 6 | def configure_logging(): 7 | logger = get_logger() 8 | 9 | # Remove any existing handlers to prevent duplicates 10 | for handler in logger.handlers[:]: 11 | logger.removeHandler(handler) 12 | 13 | # Add our handler 14 | handler = logging.StreamHandler() 15 | formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s") 16 | handler.setFormatter(formatter) 17 | logger.addHandler(handler) 18 | 19 | # Prevent propagation to parent loggers to avoid double logging 20 | logger.propagate = False 21 | 22 | logger.setLevel(settings.LOGLEVEL) 23 | warnings.simplefilter(action="ignore", 
category=FutureWarning) 24 | 25 | 26 | def get_logger(): 27 | return logging.getLogger("surya") 28 | ``` -------------------------------------------------------------------------------- /surya/common/pretrained.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional 2 | 3 | from transformers import PreTrainedModel 4 | from transformers.utils import is_flash_attn_2_available 5 | 6 | 7 | class SuryaPreTrainedModel(PreTrainedModel): 8 | # No-op if we pass attention, so we can set attention however we want in the config 9 | def _check_and_adjust_attn_implementation( 10 | self, attn_implementation: Optional[str], **kwargs 11 | ): 12 | if attn_implementation is None: 13 | try: 14 | self._sdpa_can_dispatch(True) 15 | attn_implementation = "sdpa" 16 | except (ValueError, ImportError): 17 | attn_implementation = "eager" 18 | 19 | if self._supports_flash_attn and is_flash_attn_2_available(): 20 | attn_implementation = "flash_attention_2" 21 | 22 | return attn_implementation 23 | ``` -------------------------------------------------------------------------------- /surya/debug/fonts.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List, Optional 2 | import os 3 | import requests 4 | 5 | from surya.settings import settings 6 | 7 | 8 | def get_font_path(langs: Optional[List[str]] = None) -> str: 9 | font_path = settings.RECOGNITION_RENDER_FONTS["all"] 10 | if langs is not None: 11 | for k in settings.RECOGNITION_RENDER_FONTS: 12 | if k in langs and len(langs) == 1: 13 | font_path = settings.RECOGNITION_RENDER_FONTS[k] 14 | break 15 | 16 | if not os.path.exists(font_path): 17 | os.makedirs(os.path.dirname(font_path), exist_ok=True) 18 | font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}" 19 | with requests.get(font_dl_path, stream=True) as r, open(font_path, 'wb') as f: 20 | r.raise_for_status() 21 | for chunk in r.iter_content(chunk_size=8192): 22 | f.write(chunk) 23 | 24 | return font_path ``` -------------------------------------------------------------------------------- /surya/recognition/schema.py: -------------------------------------------------------------------------------- ```python 1 | import math 2 | import numpy as np 3 | from typing import Optional, List 4 | 5 | from pydantic import BaseModel, field_validator 6 | 7 | from surya.common.polygon import PolygonBox 8 | 9 | 10 | class BaseChar(PolygonBox): 11 | text: str 12 | confidence: Optional[float] = 0 13 | 14 | @field_validator("confidence", mode="before") 15 | @classmethod 16 | def validate_confidence(cls, v: float) -> float: 17 | if v is None: 18 | return 0 19 | elif math.isnan(v) or np.isnan(v): 20 | return 0 21 | return v 22 | 23 | 24 | class TextChar(BaseChar): 25 | bbox_valid: bool = True # This is false when the given bbox is not valid 26 | 27 | 28 | class TextWord(BaseChar): 29 | bbox_valid: bool = True 30 | 31 | 32 | class TextLine(BaseChar): 33 | chars: List[TextChar] # Individual characters in the line 34 | original_text_good: bool = False 35 | words: List[TextWord] | None = None 36 | 37 | 38 | class OCRResult(BaseModel): 39 | text_lines: List[TextLine] 40 | image_bbox: List[float] 41 | ``` -------------------------------------------------------------------------------- /surya/table_rec/schema.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | 3 | from pydantic import BaseModel 4 | 5 | 
from surya.common.polygon import PolygonBox 6 | 7 | 8 | class TableCell(PolygonBox): 9 | row_id: int 10 | colspan: int 11 | within_row_id: int 12 | cell_id: int 13 | is_header: bool 14 | rowspan: int | None = None 15 | merge_up: bool = False 16 | merge_down: bool = False 17 | col_id: int | None = None 18 | text_lines: List[dict] | None = None 19 | 20 | @property 21 | def label(self): 22 | return f'Cell {self.cell_id} {self.rowspan}/{self.colspan}' 23 | 24 | 25 | class TableRow(PolygonBox): 26 | row_id: int 27 | is_header: bool 28 | 29 | @property 30 | def label(self): 31 | return f'Row {self.row_id}' 32 | 33 | 34 | class TableCol(PolygonBox): 35 | col_id: int 36 | is_header: bool 37 | 38 | @property 39 | def label(self): 40 | return f'Column {self.col_id}' 41 | 42 | 43 | class TableResult(BaseModel): 44 | cells: List[TableCell] 45 | unmerged_cells: List[TableCell] 46 | rows: List[TableRow] 47 | cols: List[TableCol] 48 | image_bbox: List[float] 49 | ``` -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/output-bug-report.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | name: Output bug report 3 | about: Create a report about poor output quality 4 | title: "[BUG: Output]" 5 | labels: 'bug: output' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 📝 Describe the Output Issue 11 | 12 | A clear and concise description of the incorrect or unexpected output. 13 | 14 | ## 📄 Input Document 15 | 16 | Attach the PDF or input file used. 17 | 18 | ## 📤 Current Output 19 | 20 | Paste the Markdown or HTML that Marker generated: 21 | 22 | ````markdown 23 | Paste output here 24 | ````` 25 | 26 | ## ✅ Expected Output 27 | 28 | Describe or paste what you expected Marker to generate. 29 | 30 | ## ⚙️ Environment 31 | 32 | Please fill in all relevant details: 33 | 34 | * **Marker version**: 35 | * **Surya version**: 36 | * **Python version**: 37 | * **PyTorch version**: 38 | * **Transformers version**: 39 | * **Operating System**: 40 | 41 | ## 📟 Command or Code Used 42 | 43 | Paste the **exact bash command** or **Python code** you used to run Marker: 44 | 45 | <details> 46 | <summary>Click to expand</summary> 47 | 48 | ```bash 49 | # or Python code block 50 | your_command_here --with-flags 51 | ``` 52 | 53 | </details> 54 | 55 | ## 📎 Additional Context 56 | 57 | Any other relevant info, configs, or assumptions. 
58 | ``` -------------------------------------------------------------------------------- /benchmark/utils/textract.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor 3 | from tqdm import tqdm 4 | import traceback 5 | 6 | from surya.input.processing import slice_bboxes_from_image 7 | from surya.recognition import RecognitionPredictor 8 | 9 | def textract_ocr(extractor, img): 10 | try: 11 | document = extractor.detect_document_text(file_source=img) 12 | return [line.text for line in document.lines] 13 | except: 14 | traceback.print_exc() 15 | return [None] 16 | 17 | def textract_ocr_parallel(imgs, cpus=None): 18 | from textractor import Textractor # Optional dependency 19 | 20 | extractor = Textractor(profile_name='default') 21 | parallel_cores = min(len(imgs), RecognitionPredictor().get_batch_size()) 22 | if not cpus: 23 | cpus = os.cpu_count() 24 | parallel_cores = min(parallel_cores, cpus) 25 | 26 | with ThreadPoolExecutor(max_workers=parallel_cores) as executor: 27 | textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc="Running textract OCR") 28 | textract_text = list(textract_text) 29 | return textract_text ``` -------------------------------------------------------------------------------- /surya/models.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | from surya.common.predictor import BasePredictor 6 | from surya.detection import DetectionPredictor 7 | from surya.layout import LayoutPredictor 8 | from surya.logging import configure_logging 9 | from surya.ocr_error import OCRErrorPredictor 10 | from surya.foundation import FoundationPredictor 11 | from surya.recognition import RecognitionPredictor 12 | from surya.table_rec import TableRecPredictor 13 | from surya.settings import settings 14 | 15 | configure_logging() 16 | 17 | 18 | def load_predictors( 19 | device: str | torch.device | None = None, dtype: torch.dtype | str | None = None 20 | ) -> Dict[str, BasePredictor]: 21 | return { 22 | "layout": LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)), 23 | "ocr_error": OCRErrorPredictor(device=device, dtype=dtype), 24 | "recognition": RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT)), 25 | "detection": DetectionPredictor(device=device, dtype=dtype), 26 | "table_rec": TableRecPredictor(device=device, dtype=dtype), 27 | } 28 | ``` -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/breaking-bug-report.md: -------------------------------------------------------------------------------- ```markdown 1 | --- 2 | name: Breaking bug report 3 | about: Create a report about a breaking bug 4 | title: "[BUG: Breaking]" 5 | labels: 'bug: breaking' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🧨 Describe the Bug 11 | 12 | A clear and concise description of the breaking issue (e.g., crash, OOM, exception, etc). 13 | 14 | ## 📄 Input Document 15 | 16 | Attach the PDF or input file that triggered the error. 17 | 18 | ## 📤 Output Trace / Stack Trace 19 | 20 | Paste the **complete** stack trace or error output, if available. 
21 | 22 | <details> 23 | <summary>Click to expand</summary> 24 | 25 | ``` 26 | Paste stack trace here 27 | ``` 28 | 29 | </details> 30 | 31 | ## ⚙️ Environment 32 | 33 | Please fill in all relevant details: 34 | 35 | - **Marker version**: 36 | - **Surya version**: 37 | - **Python version**: 38 | - **PyTorch version**: 39 | - **Transformers version**: 40 | - **Operating System** (incl. container info if relevant): 41 | 42 | ## ✅ Expected Behavior 43 | 44 | What did you expect Marker to do? 45 | 46 | ## 📟 Command or Code Used 47 | 48 | Paste the **exact bash command** or **Python code** you used to run Marker: 49 | 50 | <details> 51 | <summary>Click to expand</summary> 52 | 53 | ```bash 54 | # or Python code block 55 | your_command_here --with-flags 56 | ``` 57 | 58 | </details> 59 | 60 | ## 📎 Additional Context 61 | 62 | Any other context that might help us debug this (e.g., CLI options, working directory, runtime settings). 63 | ``` -------------------------------------------------------------------------------- /surya/detection/util.py: -------------------------------------------------------------------------------- ```python 1 | import math 2 | from PIL import ImageOps 3 | 4 | from surya.settings import settings 5 | 6 | 7 | def get_total_splits(image_size, height): 8 | img_height = list(image_size)[1] 9 | max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT 10 | if img_height > max_height: 11 | num_splits = math.ceil(img_height / height) 12 | return num_splits 13 | return 1 14 | 15 | 16 | def split_image(img, height): 17 | # This will not modify/return the original image - it will either crop, or copy the image 18 | img_height = list(img.size)[1] 19 | max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT 20 | if img_height > max_height: 21 | num_splits = math.ceil(img_height / height) 22 | splits = [] 23 | split_heights = [] 24 | for i in range(num_splits): 25 | top = i * height 26 | bottom = (i + 1) * height 27 | if bottom > img_height: 28 | bottom = img_height 29 | cropped = img.crop((0, top, img.size[0], bottom)) 30 | chunk_height = bottom - top 31 | if chunk_height < height: 32 | cropped = ImageOps.pad(cropped, (img.size[0], height), color=255, centering=(0, 0)) 33 | splits.append(cropped) 34 | split_heights.append(chunk_height) 35 | return splits, split_heights 36 | return [img.copy()], [img_height] 37 | ``` -------------------------------------------------------------------------------- /benchmark/utils/scoring.py: -------------------------------------------------------------------------------- ```python 1 | import math 2 | from typing import List 3 | 4 | from rapidfuzz import fuzz 5 | 6 | 7 | def overlap_score(pred_lines: List[str], reference_lines: List[str]): 8 | line_scores = [] 9 | line_weights = [] 10 | line_match = {} 11 | for i, pred_line in enumerate(pred_lines): 12 | max_score = 0 13 | line_weight = 1 14 | match = None 15 | for j, ref_line in enumerate(reference_lines): 16 | score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100 17 | if score > max_score: 18 | max_score = score 19 | line_weight = math.sqrt(len(ref_line)) 20 | match = j 21 | line_scores.append(max_score) 22 | line_weights.append(line_weight) 23 | line_match[i] = match 24 | line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))] 25 | 26 | return line_scores, line_weights, line_match 27 | 28 | 29 | def overlap_score_exact(pred_lines: List[str], reference_lines: List[str]): 30 | line_scores = [] 31 | line_weights = [] 32 | assert len(pred_lines) == len(reference_lines) 33 | 34 | for 
i, (pred_line, ref_line) in enumerate(zip(pred_lines, reference_lines)): 35 | score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100 36 | weight = math.sqrt(len(ref_line)) 37 | line_scores.append(score * weight) 38 | line_weights.append(weight) 39 | 40 | return line_scores, line_weights 41 | ``` -------------------------------------------------------------------------------- /.github/workflows/cla.yml: -------------------------------------------------------------------------------- ```yaml 1 | name: "Surya CLA Assistant" 2 | on: 3 | issue_comment: 4 | types: [created] 5 | pull_request_target: 6 | types: [opened,closed,synchronize] 7 | 8 | # explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings 9 | permissions: 10 | actions: write 11 | contents: write 12 | pull-requests: write 13 | statuses: write 14 | 15 | jobs: 16 | CLAAssistant: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: "Surya CLA Assistant" 20 | if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' 21 | uses: contributor-assistant/[email protected] 22 | env: 23 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 24 | # the below token should have repo scope and must be manually added by you in the repository's secret 25 | # This token is required only if you have configured to store the signatures in a remote repository/organization 26 | PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} 27 | with: 28 | path-to-signatures: 'signatures/version1/cla.json' 29 | path-to-document: 'https://github.com/VikParuchuri/surya/blob/master/CLA.md' 30 | # branch should not be protected 31 | branch: 'master' 32 | allowlist: VikParuchuri ``` -------------------------------------------------------------------------------- /.github/workflows/scripts.yml: -------------------------------------------------------------------------------- ```yaml 1 | name: Test CLI scripts 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: t4_gpu 8 | steps: 9 | - uses: actions/checkout@v3 10 | - name: Set up Python 3.11 11 | uses: actions/setup-python@v4 12 | with: 13 | python-version: 3.11 14 | - name: Install python dependencies 15 | run: | 16 | pip install poetry 17 | poetry install 18 | - name: Download benchmark data 19 | run: | 20 | wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi" 21 | unzip -o benchmark_data.zip 22 | - name: Test detection 23 | run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0 24 | - name: Test OCR 25 | env: 26 | RECOGNITION_MAX_TOKENS: 25 27 | run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0 28 | - name: Test layout 29 | run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0 30 | - name: Test table 31 | run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0 32 | - name: Test texify 33 | env: 34 | TEXIFY_MAX_TOKENS: 25 35 | run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0 36 | - name: Test detection folder 37 | run: poetry run surya_detect benchmark_data/pdfs --page_range 0 38 | ``` -------------------------------------------------------------------------------- /surya/common/surya/encoder/config.py: -------------------------------------------------------------------------------- ```python 1 | from transformers.configuration_utils 
import PretrainedConfig 2 | from transformers.utils import logging 3 | 4 | logger = logging.get_logger(__name__) 5 | 6 | 7 | class SuryaEncoderConfig(PretrainedConfig): 8 | model_type = "qwen2_5_vl" 9 | base_config_key = "vision_config" 10 | 11 | attribute_map = { 12 | "num_attention_heads": "num_heads", 13 | "num_hidden_layers": "depth", 14 | } 15 | 16 | def __init__( 17 | self, 18 | depth=8, 19 | hidden_size=1280, 20 | hidden_act="silu", 21 | intermediate_size=3420, 22 | num_heads=16, 23 | in_channels=3, 24 | patch_size=14, 25 | spatial_merge_size=2, 26 | spatial_patch_size=14, 27 | temporal_patch_size=1, 28 | tokens_per_second=4, 29 | window_size=112, 30 | out_hidden_size=1280, 31 | fullatt_block_indexes=(3, 7), 32 | initializer_range=0.02, 33 | image_size=4096, 34 | **kwargs, 35 | ): 36 | super().__init__(**kwargs) 37 | 38 | self.depth = depth 39 | self.hidden_size = hidden_size 40 | self.hidden_act = hidden_act 41 | self.intermediate_size = intermediate_size 42 | self.num_heads = num_heads 43 | self.in_channels = in_channels 44 | self.patch_size = patch_size 45 | self.spatial_merge_size = spatial_merge_size 46 | self.temporal_patch_size = temporal_patch_size 47 | self.tokens_per_second = tokens_per_second 48 | self.window_size = window_size 49 | self.fullatt_block_indexes = fullatt_block_indexes 50 | self.out_hidden_size = out_hidden_size 51 | self.initializer_range = initializer_range 52 | self.spatial_patch_size = spatial_patch_size 53 | self.image_size = image_size 54 | ``` -------------------------------------------------------------------------------- /surya/detection/model/config.py: -------------------------------------------------------------------------------- ```python 1 | from transformers import PretrainedConfig 2 | 3 | from surya.common.s3 import S3DownloaderMixin 4 | 5 | 6 | class EfficientViTConfig(S3DownloaderMixin, PretrainedConfig): 7 | r""" 8 | ```""" 9 | 10 | model_type = "efficientvit" 11 | 12 | def __init__( 13 | self, 14 | num_classes=2, 15 | num_channels=3, 16 | widths=(32, 64, 128, 256, 512), 17 | head_dim=32, 18 | num_stages=4, 19 | depths=(1, 1, 1, 6, 6), 20 | strides=(2, 2, 2, 2, 2), 21 | hidden_sizes=(32, 64, 160, 256), 22 | patch_size=(7, 7), 23 | hidden_dropout_prob=0.0, 24 | attention_probs_dropout_prob=0.0, 25 | classifier_dropout_prob=0.0, 26 | layer_norm_eps=1e-6, 27 | decoder_layer_hidden_size=128, 28 | decoder_hidden_size=512, 29 | semantic_loss_ignore_index=255, 30 | initializer_range=0.02, 31 | **kwargs, 32 | ): 33 | super().__init__(**kwargs) 34 | 35 | self.num_classes = num_classes 36 | self.widths = widths 37 | self.head_dim = head_dim 38 | 39 | self.num_channels = num_channels 40 | self.num_stages = num_stages 41 | self.depths = depths 42 | self.strides = strides 43 | self.hidden_sizes = hidden_sizes 44 | self.patch_size = patch_size 45 | self.hidden_dropout_prob = hidden_dropout_prob 46 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 47 | self.classifier_dropout_prob = classifier_dropout_prob 48 | self.layer_norm_eps = layer_norm_eps 49 | self.decoder_hidden_size = decoder_hidden_size 50 | self.decoder_layer_hidden_size = decoder_layer_hidden_size 51 | self.semantic_loss_ignore_index = semantic_loss_ignore_index 52 | 53 | self.initializer_range = initializer_range ``` -------------------------------------------------------------------------------- /tests/test_table_rec.py: -------------------------------------------------------------------------------- ```python 1 | from PIL import Image, ImageDraw 2 | 3 | def 
test_table_rec(table_rec_predictor): 4 | data = [ 5 | ["Name", "Age", "City"], 6 | ["Alice", 25, "New York"], 7 | ["Bob", 30, "Los Angeles"], 8 | ["Charlie", 35, "Chicago"], 9 | ] 10 | test_image = draw_table(data) 11 | 12 | results = table_rec_predictor([test_image]) 13 | assert len(results) == 1 14 | assert results[0].image_bbox == [0, 0, test_image.size[0], test_image.size[1]] 15 | 16 | cells = results[0].cells 17 | assert len(cells) == 12 18 | for row_id in range(4): 19 | for col_id in range(3): 20 | cell = [c for c in cells if c.row_id == row_id and c.col_id == col_id] 21 | assert len(cell) == 1, f"Missing cell at row {row_id}, col {col_id}" 22 | 23 | def draw_table(data, cell_width=100, cell_height=40): 24 | rows = len(data) 25 | cols = len(data[0]) 26 | width = cols * cell_width 27 | height = rows * cell_height 28 | 29 | image = Image.new('RGB', (width, height), 'white') 30 | draw = ImageDraw.Draw(image) 31 | 32 | for i in range(rows + 1): 33 | y = i * cell_height 34 | draw.line([(0, y), (width, y)], fill='black', width=1) 35 | 36 | for i in range(cols + 1): 37 | x = i * cell_width 38 | draw.line([(x, 0), (x, height)], fill='black', width=1) 39 | 40 | for i in range(rows): 41 | for j in range(cols): 42 | text = str(data[i][j]) 43 | text_bbox = draw.textbbox((0, 0), text) 44 | text_width = text_bbox[2] - text_bbox[0] 45 | text_height = text_bbox[3] - text_bbox[1] 46 | 47 | x = j * cell_width + (cell_width - text_width) // 2 48 | y = i * cell_height + (cell_height - text_height) // 2 49 | 50 | draw.text((x, y), text, fill='black') 51 | 52 | return image ``` -------------------------------------------------------------------------------- /benchmark/utils/bbox.py: -------------------------------------------------------------------------------- ```python 1 | import fitz as pymupdf 2 | from surya.common.util import rescale_bbox 3 | 4 | 5 | def get_pdf_lines(pdf_path, img_sizes): 6 | doc = pymupdf.open(pdf_path) 7 | page_lines = [] 8 | for idx, img_size in enumerate(img_sizes): 9 | page = doc[idx] 10 | blocks = page.get_text("dict", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)["blocks"] 11 | 12 | line_boxes = [] 13 | for block_idx, block in enumerate(blocks): 14 | for l in block["lines"]: 15 | line_boxes.append(list(l["bbox"])) 16 | 17 | page_box = page.bound() 18 | pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1] 19 | line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes] 20 | page_lines.append(line_boxes) 21 | 22 | return page_lines 23 | 24 | def merge_boxes(box1, box2): 25 | return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3])) 26 | 27 | 28 | def join_lines(bboxes, max_gap=5): 29 | to_merge = {} 30 | for i, box1 in bboxes: 31 | for z, box2 in bboxes[i + 1:]: 32 | j = i + z + 1 33 | if box1 == box2: 34 | continue 35 | 36 | if box1[0] <= box2[0] and box1[2] >= box2[2]: 37 | if abs(box1[1] - box2[3]) <= max_gap: 38 | if i not in to_merge: 39 | to_merge[i] = [] 40 | to_merge[i].append(j) 41 | 42 | merged_boxes = set() 43 | merged = [] 44 | for i, box in bboxes: 45 | if i in merged_boxes: 46 | continue 47 | 48 | if i in to_merge: 49 | for j in to_merge[i]: 50 | box = merge_boxes(box, bboxes[j][1]) 51 | merged_boxes.add(j) 52 | 53 | merged.append(box) 54 | return merged 55 | ``` -------------------------------------------------------------------------------- /surya/scripts/ocr_latex.py: 
-------------------------------------------------------------------------------- ```python 1 | import os 2 | 3 | import click 4 | import json 5 | import time 6 | from collections import defaultdict 7 | 8 | from surya.logging import configure_logging, get_logger 9 | from surya.scripts.config import CLILoader 10 | from surya.foundation import FoundationPredictor 11 | from surya.recognition import RecognitionPredictor 12 | from surya.common.surya.schema import TaskNames 13 | 14 | configure_logging() 15 | logger = get_logger() 16 | 17 | 18 | @click.command(help="OCR LaTeX equations.") 19 | @CLILoader.common_options 20 | def ocr_latex_cli(input_path: str, **kwargs): 21 | loader = CLILoader(input_path, kwargs, highres=True) 22 | 23 | foundation_predictor = FoundationPredictor() 24 | texify_predictor = RecognitionPredictor(foundation_predictor) 25 | tasks = [TaskNames.block_without_boxes] * len(loader.images) 26 | bboxes = [[[0, 0, image.width, image.height]] for image in loader.images] 27 | 28 | start = time.time() 29 | predictions_by_image = texify_predictor( 30 | loader.images, 31 | tasks, 32 | bboxes=bboxes, 33 | ) 34 | 35 | latex_predictions = [p.text_lines[0].text for p in predictions_by_image] 36 | 37 | if loader.debug: 38 | logger.debug(f"OCR took {time.time() - start:.2f} seconds") 39 | max_chars = max([len(latex) for latex in latex_predictions]) 40 | logger.debug(f"Max chars: {max_chars}") 41 | 42 | out_preds = defaultdict(list) 43 | for name, pred, image in zip(loader.names, latex_predictions, loader.images): 44 | out_pred = { 45 | "equation": pred, 46 | "page": len(out_preds[name]) + 1, 47 | } 48 | out_preds[name].append(out_pred) 49 | 50 | with open( 51 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 52 | ) as f: 53 | json.dump(out_preds, f, ensure_ascii=False) 54 | 55 | logger.info(f"Wrote results to {loader.result_path}") 56 | ``` -------------------------------------------------------------------------------- /surya/foundation/util.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List, Tuple 2 | import numpy as np 3 | import torch 4 | 5 | def detect_repeat_token(predicted_tokens: List[int], max_repeats: int = 40): 6 | if len(predicted_tokens) < max_repeats: 7 | return False 8 | 9 | # Detect repeats containing 1 or 2 tokens 10 | last_n = predicted_tokens[-max_repeats:] 11 | unique_tokens = len(set(last_n)) 12 | if unique_tokens > 5: 13 | return False 14 | 15 | return last_n[-unique_tokens:] == last_n[-unique_tokens * 2 : -unique_tokens] 16 | 17 | def prediction_to_polygon_batch( 18 | pred: torch.Tensor, 19 | img_sizes: List[Tuple[int, int]], 20 | bbox_scaler, 21 | skew_scaler, 22 | skew_min=0.001, 23 | ): 24 | img_sizes = torch.from_numpy(np.array(img_sizes, dtype=np.float32)).to( 25 | pred.device 26 | ) 27 | w_scale = (img_sizes[:, 1] / bbox_scaler)[:, None, None] 28 | h_scale = (img_sizes[:, 0] / bbox_scaler)[:, None, None] 29 | 30 | cx = pred[:, :, 0] 31 | cy = pred[:, :, 1] 32 | width = pred[:, :, 2] 33 | height = pred[:, :, 3] 34 | 35 | x1 = cx - width / 2 36 | y1 = cy - height / 2 37 | x2 = cx + width / 2 38 | y2 = cy + height / 2 39 | 40 | skew_x = torch.floor((pred[:, :, 4] - skew_scaler) / 2) 41 | skew_y = torch.floor((pred[:, :, 5] - skew_scaler) / 2) 42 | 43 | skew_x[torch.abs(skew_x) < skew_min] = 0 44 | skew_y[torch.abs(skew_y) < skew_min] = 0 45 | 46 | polygons_flat = torch.stack( 47 | [ 48 | x1 - skew_x, 49 | y1 - skew_y, 50 | x2 - skew_x, 51 | y1 + skew_y, 52 | x2 + 
skew_x, 53 | y2 + skew_y, 54 | x1 + skew_x, 55 | y2 - skew_y, 56 | ], 57 | dim=2, 58 | ) 59 | 60 | batch_size, seq_len, _ = pred.shape 61 | polygons = polygons_flat.view(batch_size, seq_len, 4, 2) 62 | 63 | polygons[:, :, :, 0] *= w_scale 64 | polygons[:, :, :, 1] *= h_scale 65 | 66 | return polygons ``` -------------------------------------------------------------------------------- /surya/scripts/detect_text.py: -------------------------------------------------------------------------------- ```python 1 | import click 2 | import copy 3 | import json 4 | import time 5 | from collections import defaultdict 6 | 7 | from surya.detection import DetectionPredictor 8 | from surya.debug.draw import draw_polys_on_image 9 | from surya.logging import configure_logging, get_logger 10 | from surya.scripts.config import CLILoader 11 | import os 12 | 13 | configure_logging() 14 | logger = get_logger() 15 | 16 | 17 | @click.command(help="Detect bboxes in an input file or folder (PDFs or image).") 18 | @CLILoader.common_options 19 | def detect_text_cli(input_path: str, **kwargs): 20 | loader = CLILoader(input_path, kwargs) 21 | 22 | det_predictor = DetectionPredictor() 23 | 24 | start = time.time() 25 | predictions = det_predictor(loader.images, include_maps=loader.debug) 26 | end = time.time() 27 | if loader.debug: 28 | logger.debug(f"Detection took {end - start} seconds") 29 | 30 | if loader.save_images: 31 | for idx, (image, pred, name) in enumerate( 32 | zip(loader.images, predictions, loader.names) 33 | ): 34 | polygons = [p.polygon for p in pred.bboxes] 35 | bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image)) 36 | bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_bbox.png")) 37 | 38 | if loader.debug: 39 | heatmap = pred.heatmap 40 | heatmap.save(os.path.join(loader.result_path, f"{name}_{idx}_heat.png")) 41 | 42 | predictions_by_page = defaultdict(list) 43 | for idx, (pred, name, image) in enumerate( 44 | zip(predictions, loader.names, loader.images) 45 | ): 46 | out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"]) 47 | out_pred["page"] = len(predictions_by_page[name]) + 1 48 | predictions_by_page[name].append(out_pred) 49 | 50 | with open( 51 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 52 | ) as f: 53 | json.dump(predictions_by_page, f, ensure_ascii=False) 54 | 55 | logger.info(f"Wrote results to {loader.result_path}") 56 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- ```toml 1 | [tool.poetry] 2 | name = "surya-ocr" 3 | version = "0.17.0" 4 | description = "OCR, layout, reading order, and table recognition in 90+ languages" 5 | authors = ["Vik Paruchuri <[email protected]>"] 6 | readme = "README.md" 7 | license = "GPL-3.0-or-later" 8 | repository = "https://github.com/VikParuchuri/surya" 9 | keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"] 10 | packages = [ 11 | {include = "surya"} 12 | ] 13 | 14 | [tool.poetry.dependencies] 15 | python = "^3.10" 16 | transformers = ">=4.56.1" 17 | torch = "^2.7.0" 18 | pydantic = "^2.5.3" 19 | pydantic-settings = "^2.1.0" 20 | python-dotenv = "^1.0.0" 21 | pillow = "^10.2.0" 22 | pypdfium2 = "=4.30.0" 23 | filetype = "^1.2.0" 24 | click = "^8.1.8" 25 | platformdirs = "^4.3.6" 26 | opencv-python-headless = "==4.11.0.86" 27 | einops = "^0.8.1" 28 | pre-commit = "^4.2.0" 29 | 30 | [tool.poetry.group.dev.dependencies] 31 | jupyter = 
"^1.0.0" 32 | pytesseract = "^0.3.10" 33 | pymupdf = "^1.23.8" 34 | datasets = "^2.16.1" 35 | rapidfuzz = "^3.6.1" 36 | streamlit = "^1.31.0" 37 | pytest = "^8.3.4" 38 | pdftext = "^0.5.1" 39 | tabulate = "^0.9.0" 40 | 41 | [tool.poetry.scripts] 42 | surya_detect = "surya.scripts.detect_text:detect_text_cli" 43 | surya_ocr = "surya.scripts.ocr_text:ocr_text_cli" 44 | surya_layout = "surya.scripts.detect_layout:detect_layout_cli" 45 | surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli" 46 | surya_table = "surya.scripts.table_recognition:table_recognition_cli" 47 | surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli" 48 | texify_gui = "surya.scripts.run_texify_app:texify_app_cli" 49 | 50 | [build-system] 51 | requires = ["poetry-core"] 52 | build-backend = "poetry.core.masonry.api" 53 | 54 | [[tool.poetry.source]] 55 | name = "libtpu-releases" 56 | url = "https://storage.googleapis.com/libtpu-releases/index.html" 57 | priority = "supplemental" 58 | 59 | [[tool.poetry.source]] 60 | name = "libtpu-wheels" 61 | url = "https://storage.googleapis.com/libtpu-wheels/index.html" 62 | priority = "supplemental" 63 | 64 | [tool.poetry.group.xla] 65 | optional = true 66 | 67 | [tool.poetry.group.xla.dependencies] 68 | torch-xla = {version = "^2.4.1", extras = ["tpu"]} 69 | ``` -------------------------------------------------------------------------------- /.github/workflows/benchmarks.yml: -------------------------------------------------------------------------------- ```yaml 1 | name: Integration test 2 | 3 | on: [push] 4 | 5 | env: 6 | PYTHONIOENCODING: "utf-8" 7 | 8 | jobs: 9 | build: 10 | runs-on: t4_gpu 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python 3.11 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.11 17 | - name: Install python dependencies 18 | run: | 19 | pip install poetry 20 | poetry install 21 | - name: Run detection benchmark test 22 | run: | 23 | poetry run python benchmark/detection.py --max_rows 2 24 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection 25 | - name: Run recognition benchmark test 26 | run: | 27 | poetry run python benchmark/recognition.py --max_rows 2 28 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition 29 | - name: Run layout benchmark test 30 | run: | 31 | poetry run python benchmark/layout.py --max_rows 5 32 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout 33 | - name: Run ordering benchmark 34 | run: | 35 | poetry run python benchmark/ordering.py --max_rows 5 36 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering 37 | - name: Run table recognition benchmark 38 | run: | 39 | poetry run python benchmark/table_recognition.py --max_rows 5 40 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition 41 | - name: Run texify benchmark 42 | run: | 43 | poetry run python benchmark/texify.py --max_rows 5 44 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify ``` -------------------------------------------------------------------------------- /surya/ocr_error/model/config.py: -------------------------------------------------------------------------------- 
```python 1 | from collections import OrderedDict 2 | from typing import Mapping 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.onnx import OnnxConfig 6 | 7 | from surya.common.s3 import S3DownloaderMixin 8 | 9 | ID2LABEL = { 10 | 0: 'good', 11 | 1: 'bad' 12 | } 13 | 14 | class DistilBertConfig(S3DownloaderMixin, PretrainedConfig): 15 | model_type = "distilbert" 16 | attribute_map = { 17 | "hidden_size": "dim", 18 | "num_attention_heads": "n_heads", 19 | "num_hidden_layers": "n_layers", 20 | } 21 | 22 | def __init__( 23 | self, 24 | vocab_size=30522, 25 | max_position_embeddings=512, 26 | sinusoidal_pos_embds=False, 27 | n_layers=6, 28 | n_heads=12, 29 | dim=768, 30 | hidden_dim=4 * 768, 31 | dropout=0.1, 32 | attention_dropout=0.1, 33 | activation="gelu", 34 | initializer_range=0.02, 35 | qa_dropout=0.1, 36 | seq_classif_dropout=0.2, 37 | pad_token_id=0, 38 | **kwargs, 39 | ): 40 | self.vocab_size = vocab_size 41 | self.max_position_embeddings = max_position_embeddings 42 | self.sinusoidal_pos_embds = sinusoidal_pos_embds 43 | self.n_layers = n_layers 44 | self.n_heads = n_heads 45 | self.dim = dim 46 | self.hidden_dim = hidden_dim 47 | self.dropout = dropout 48 | self.attention_dropout = attention_dropout 49 | self.activation = activation 50 | self.initializer_range = initializer_range 51 | self.qa_dropout = qa_dropout 52 | self.seq_classif_dropout = seq_classif_dropout 53 | super().__init__(**kwargs, pad_token_id=pad_token_id) 54 | 55 | 56 | class DistilBertOnnxConfig(OnnxConfig): 57 | @property 58 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 59 | if self.task == "multiple-choice": 60 | dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} 61 | else: 62 | dynamic_axis = {0: "batch", 1: "sequence"} 63 | return OrderedDict( 64 | [ 65 | ("input_ids", dynamic_axis), 66 | ("attention_mask", dynamic_axis), 67 | ] 68 | ) ``` -------------------------------------------------------------------------------- /surya/debug/katex.js: -------------------------------------------------------------------------------- ```javascript 1 | <style> 2 | .katex-display-container { 3 | display: inline-block; 4 | max-width: 100%; 5 | overflow-x: auto; 6 | max-height: 100%; 7 | } 8 | 9 | .katex-inline-container { 10 | display: inline-block; 11 | max-width: 100%; 12 | overflow-x: auto; 13 | max-height: 100%; 14 | } 15 | </style> 16 | <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.js" onload="setTimeout(function() {renderMath()})" async></script> 17 | <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css"> 18 | <script> 19 | function htmlUnescape(escapedText) { 20 | const htmlEntities = { 21 | '&amp;': '&', 22 | '&lt;': '<', 23 | '&gt;': '>', 24 | '&quot;': '"', 25 | '&#39;': "'", 26 | '&nbsp;': ' ' 27 | }; 28 | 29 | return escapedText.replace(/&amp;|&lt;|&gt;|&quot;|&#39;|&nbsp;/g, match => htmlEntities[match]); 30 | } 31 | 32 | const renderMath = (function() { 33 | try { 34 | const mathElements = document.querySelectorAll('math'); 35 | 36 | mathElements.forEach(function(element) { 37 | let mathContent = element.innerHTML.trim(); 38 | mathContent = htmlUnescape(mathContent); 39 | const isDisplay = element.getAttribute('display') === 'block'; 40 | 41 | const container = document.createElement('span'); 42 | container.className = isDisplay ? 
'katex-display-container' : 'katex-inline-container'; 43 | element.parentNode.insertBefore(container, element); 44 | 45 | try { 46 | katex.render(mathContent, container, { 47 | displayMode: isDisplay, 48 | throwOnError: false 49 | }); 50 | 51 | } catch (err) { 52 | console.error('KaTeX rendering error:', err); 53 | container.textContent = mathContent; // Fallback to raw text 54 | } 55 | 56 | element.parentNode.removeChild(element); 57 | }); 58 | 59 | console.log('Math rendering complete with', mathElements.length, 'expressions'); 60 | } catch (err) { 61 | console.error('Error in renderMath function:', err); 62 | } 63 | }); 64 | </script> ``` -------------------------------------------------------------------------------- /surya/scripts/detect_layout.py: -------------------------------------------------------------------------------- ```python 1 | import time 2 | import click 3 | import copy 4 | import json 5 | from collections import defaultdict 6 | 7 | from surya.foundation import FoundationPredictor 8 | from surya.layout import LayoutPredictor 9 | from surya.debug.draw import draw_polys_on_image 10 | from surya.logging import configure_logging, get_logger 11 | from surya.scripts.config import CLILoader 12 | from surya.settings import settings 13 | import os 14 | 15 | configure_logging() 16 | logger = get_logger() 17 | 18 | 19 | @click.command(help="Detect layout of an input file or folder (PDFs or image).") 20 | @CLILoader.common_options 21 | def detect_layout_cli(input_path: str, **kwargs): 22 | loader = CLILoader(input_path, kwargs) 23 | 24 | foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) 25 | layout_predictor = LayoutPredictor(foundation_predictor) 26 | 27 | start = time.time() 28 | layout_predictions = layout_predictor(loader.images) 29 | 30 | if loader.debug: 31 | logger.debug(f"Layout took {time.time() - start} seconds") 32 | 33 | if loader.save_images: 34 | for idx, (image, layout_pred, name) in enumerate( 35 | zip(loader.images, layout_predictions, loader.names) 36 | ): 37 | polygons = [p.polygon for p in layout_pred.bboxes] 38 | labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes] 39 | bbox_image = draw_polys_on_image( 40 | polygons, copy.deepcopy(image), labels=labels 41 | ) 42 | bbox_image.save( 43 | os.path.join(loader.result_path, f"{name}_{idx}_layout.png") 44 | ) 45 | 46 | predictions_by_page = defaultdict(list) 47 | for idx, (pred, name, image) in enumerate( 48 | zip(layout_predictions, loader.names, loader.images) 49 | ): 50 | out_pred = pred.model_dump() 51 | out_pred["page"] = len(predictions_by_page[name]) + 1 52 | predictions_by_page[name].append(out_pred) 53 | 54 | with open( 55 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 56 | ) as f: 57 | json.dump(predictions_by_page, f, ensure_ascii=False) 58 | 59 | logger.info(f"Wrote results to {loader.result_path}") 60 | ``` -------------------------------------------------------------------------------- /surya/ocr_error/loader.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from surya.common.load import ModelLoader 6 | from surya.logging import get_logger 7 | from surya.ocr_error.model.config import DistilBertConfig 8 | from surya.ocr_error.model.encoder import DistilBertForSequenceClassification 9 | from surya.ocr_error.tokenizer import DistilBertTokenizer 10 | from surya.settings import settings 11 | 12 | logger = get_logger() 13 | 14 
| 15 | class OCRErrorModelLoader(ModelLoader): 16 | def __init__(self, checkpoint: Optional[str] = None): 17 | super().__init__(checkpoint) 18 | 19 | if self.checkpoint is None: 20 | self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT 21 | 22 | def model( 23 | self, 24 | device=settings.TORCH_DEVICE_MODEL, 25 | dtype=settings.MODEL_DTYPE, 26 | attention_implementation: Optional[str] = None, 27 | ) -> DistilBertForSequenceClassification: 28 | if device is None: 29 | device = settings.TORCH_DEVICE_MODEL 30 | if dtype is None: 31 | dtype = settings.MODEL_DTYPE 32 | 33 | config = DistilBertConfig.from_pretrained(self.checkpoint) 34 | model = ( 35 | DistilBertForSequenceClassification.from_pretrained( 36 | self.checkpoint, 37 | dtype=dtype, 38 | config=config, 39 | ) 40 | .to(device) 41 | .eval() 42 | ) 43 | 44 | if settings.COMPILE_ALL or settings.COMPILE_OCR_ERROR: 45 | torch._dynamo.config.cache_size_limit = 1 46 | torch._dynamo.config.suppress_errors = False 47 | 48 | logger.info( 49 | f"Compiling detection model {self.checkpoint} from {DistilBertForSequenceClassification.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" 50 | ) 51 | compile_args = {"backend": "openxla"} if device == "xla" else {} 52 | model = torch.compile(model, **compile_args) 53 | 54 | return model 55 | 56 | def processor( 57 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE 58 | ) -> DistilBertTokenizer: 59 | return DistilBertTokenizer.from_pretrained(self.checkpoint) 60 | ``` -------------------------------------------------------------------------------- /benchmark/utils/verify_benchmark_scores.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | import click 3 | 4 | 5 | def verify_layout(data): 6 | scores = data["metrics"] 7 | for layout_type, metrics in scores.items(): 8 | if layout_type == "List": # Skip lists since none appear early on 9 | continue 10 | 11 | if metrics["precision"] <= 0.6 or metrics["recall"] <= 0.6: 12 | raise ValueError("Scores do not meet the required threshold") 13 | 14 | 15 | def verify_det(data): 16 | scores = data["metrics"]["surya"] 17 | if scores["precision"] <= 0.9 or scores["recall"] <= 0.9: 18 | raise ValueError("Scores do not meet the required threshold") 19 | 20 | 21 | def verify_rec(data): 22 | scores = data["surya"] 23 | if scores["avg_score"] <= 0.9: 24 | raise ValueError("Scores do not meet the required threshold") 25 | 26 | 27 | def verify_order(data): 28 | score = data["mean_accuracy"] 29 | if score < 0.75: 30 | raise ValueError("Scores do not meet the required threshold") 31 | 32 | 33 | def verify_table_rec(data): 34 | row_score = data["surya"]["mean_row_iou"] 35 | col_score = data["surya"]["mean_col_iou"] 36 | 37 | if row_score < 0.75 or col_score < 0.75: 38 | raise ValueError("Scores do not meet the required threshold") 39 | 40 | 41 | def verify_texify(data): 42 | edit_dist = data["scores"] 43 | if edit_dist > 0.2: 44 | raise ValueError("Scores do not meet the required threshold") 45 | 46 | 47 | @click.command(help="Verify benchmark scores") 48 | @click.argument("file_path", type=str) 49 | @click.option( 50 | "--bench_type", type=str, help="Type of benchmark to verify", default="detection" 51 | ) 52 | def main(file_path, bench_type): 53 | with open(file_path, "r") as file: 54 | data = json.load(file) 55 | 56 | if bench_type == "detection": 57 | verify_det(data) 58 | elif bench_type == "recognition": 59 | verify_rec(data) 60 | elif bench_type == "layout": 61 | 
verify_layout(data) 62 | elif bench_type == "ordering": 63 | verify_order(data) 64 | elif bench_type == "table_recognition": 65 | verify_table_rec(data) 66 | elif bench_type == "texify": 67 | verify_texify(data) 68 | else: 69 | raise ValueError("Invalid benchmark type") 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | ``` -------------------------------------------------------------------------------- /surya/debug/draw.py: -------------------------------------------------------------------------------- ```python 1 | from PIL import ImageDraw, ImageFont 2 | 3 | from surya.debug.fonts import get_font_path 4 | from surya.debug.text import get_text_size 5 | 6 | 7 | def draw_bboxes_on_image( 8 | bboxes, image, labels=None, label_font_size=10, color: str | list = "red" 9 | ): 10 | polys = [] 11 | for bb in bboxes: 12 | # Clockwise polygon 13 | poly = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]] 14 | polys.append(poly) 15 | 16 | return draw_polys_on_image( 17 | polys, image, labels, label_font_size=label_font_size, color=color 18 | ) 19 | 20 | 21 | def draw_polys_on_image( 22 | corners, 23 | image, 24 | labels=None, 25 | box_padding=-1, 26 | label_offset=1, 27 | label_font_size=10, 28 | color: str | list = "red", 29 | ): 30 | draw = ImageDraw.Draw(image) 31 | font_path = get_font_path() 32 | label_font = ImageFont.truetype(font_path, label_font_size) 33 | 34 | for i in range(len(corners)): 35 | poly = corners[i] 36 | poly = [(int(p[0]), int(p[1])) for p in poly] 37 | draw.polygon( 38 | poly, outline=color[i] if isinstance(color, list) else color, width=1 39 | ) 40 | 41 | if labels is not None: 42 | label = labels[i] 43 | text_position = ( 44 | min([p[0] for p in poly]) + label_offset, 45 | min([p[1] for p in poly]) + label_offset, 46 | ) 47 | text_size = get_text_size(label, label_font) 48 | box_position = ( 49 | text_position[0] - box_padding + label_offset, 50 | text_position[1] - box_padding + label_offset, 51 | text_position[0] + text_size[0] + box_padding + label_offset, 52 | text_position[1] + text_size[1] + box_padding + label_offset, 53 | ) 54 | try: 55 | draw.rectangle(box_position, fill="white") 56 | except Exception as e: 57 | print(f"Error drawing rectangle at {box_position}: {e}") 58 | continue 59 | draw.text( 60 | text_position, 61 | label, 62 | fill=color[i] if isinstance(color, list) else color, 63 | font=label_font, 64 | ) 65 | 66 | return image 67 | ``` -------------------------------------------------------------------------------- /surya/recognition/languages.py: -------------------------------------------------------------------------------- ```python 1 | CODE_TO_LANGUAGE = { 2 | "_math": "Math", 3 | "af": "Afrikaans", 4 | "am": "Amharic", 5 | "ar": "Arabic", 6 | "as": "Assamese", 7 | "az": "Azerbaijani", 8 | "be": "Belarusian", 9 | "bg": "Bulgarian", 10 | "bn": "Bengali", 11 | "br": "Breton", 12 | "bs": "Bosnian", 13 | "ca": "Catalan", 14 | "cs": "Czech", 15 | "cy": "Welsh", 16 | "da": "Danish", 17 | "de": "German", 18 | "el": "Greek", 19 | "en": "English", 20 | "eo": "Esperanto", 21 | "es": "Spanish", 22 | "et": "Estonian", 23 | "eu": "Basque", 24 | "fa": "Persian", 25 | "fi": "Finnish", 26 | "fr": "French", 27 | "fy": "Western Frisian", 28 | "ga": "Irish", 29 | "gd": "Scottish Gaelic", 30 | "gl": "Galician", 31 | "gu": "Gujarati", 32 | "ha": "Hausa", 33 | "he": "Hebrew", 34 | "hi": "Hindi", 35 | "hr": "Croatian", 36 | "hu": "Hungarian", 37 | "hy": "Armenian", 38 | "id": "Indonesian", 39 | "is": "Icelandic", 40 | "it": "Italian", 41 | "ja": 
"Japanese", 42 | "jv": "Javanese", 43 | "ka": "Georgian", 44 | "kk": "Kazakh", 45 | "km": "Khmer", 46 | "kn": "Kannada", 47 | "ko": "Korean", 48 | "ku": "Kurdish", 49 | "ky": "Kyrgyz", 50 | "la": "Latin", 51 | "lo": "Lao", 52 | "lt": "Lithuanian", 53 | "lv": "Latvian", 54 | "mg": "Malagasy", 55 | "mk": "Macedonian", 56 | "ml": "Malayalam", 57 | "mn": "Mongolian", 58 | "mr": "Marathi", 59 | "ms": "Malay", 60 | "my": "Burmese", 61 | "ne": "Nepali", 62 | "nl": "Dutch", 63 | "no": "Norwegian", 64 | "om": "Oromo", 65 | "or": "Oriya", 66 | "pa": "Punjabi", 67 | "pl": "Polish", 68 | "ps": "Pashto", 69 | "pt": "Portuguese", 70 | "ro": "Romanian", 71 | "ru": "Russian", 72 | "sa": "Sanskrit", 73 | "sd": "Sindhi", 74 | "si": "Sinhala", 75 | "sk": "Slovak", 76 | "sl": "Slovenian", 77 | "so": "Somali", 78 | "sq": "Albanian", 79 | "sr": "Serbian", 80 | "su": "Sundanese", 81 | "sv": "Swedish", 82 | "sw": "Swahili", 83 | "ta": "Tamil", 84 | "te": "Telugu", 85 | "th": "Thai", 86 | "tl": "Tagalog", 87 | "tr": "Turkish", 88 | "ug": "Uyghur", 89 | "uk": "Ukrainian", 90 | "ur": "Urdu", 91 | "uz": "Uzbek", 92 | "vi": "Vietnamese", 93 | "xh": "Xhosa", 94 | "yi": "Yiddish", 95 | "zh": "Chinese", 96 | } 97 | 98 | LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()} 99 | ``` -------------------------------------------------------------------------------- /surya/common/surya/embedder/__init__.py: -------------------------------------------------------------------------------- ```python 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SimpleTokenEmbedder(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size) 10 | self.bbox_embed = nn.ModuleList( 11 | [ 12 | nn.Embedding( 13 | config.bbox_size + config.special_token_count, 14 | config.bbox_embed_size, 15 | ) 16 | for _ in range(6) 17 | ] 18 | ) 19 | self.max_bbox_embedding = config.bbox_size + config.special_token_count - 1 20 | self.max_bbox_size = config.bbox_size 21 | 22 | def embed( 23 | self, 24 | input_tokens: torch.Tensor, 25 | input_boxes: torch.Tensor | None, 26 | embed_boxes: torch.Tensor, 27 | ) -> torch.Tensor: 28 | # Embed tokens 29 | token_embeds = self.token_embed(input_tokens) 30 | 31 | # Optionally embed boxes 32 | if input_boxes is not None and embed_boxes.any(): # Is none in prefill 33 | input_boxes = input_boxes.to(torch.long) 34 | bbox_loss_ignore_mask = ( 35 | (input_boxes[:, :, 0] < 0) | (input_boxes[:, :, 0] > self.max_bbox_size) 36 | ).unsqueeze(-1) 37 | input_boxes = torch.clamp(input_boxes, 0, self.max_bbox_embedding) 38 | 39 | bbox_embeds = torch.sum( 40 | torch.stack( 41 | [ 42 | self.bbox_embed[i](input_boxes[:, :, i]) 43 | for i in range(len(self.bbox_embed)) 44 | ], 45 | dim=-1, 46 | ), 47 | dim=-1, 48 | ) 49 | 50 | bbox_embeds = F.pad( 51 | bbox_embeds, (token_embeds.shape[-1] - bbox_embeds.shape[-1], 0) 52 | ) 53 | embed_boxes = embed_boxes.unsqueeze(1).unsqueeze(1).expand_as(bbox_embeds) 54 | bbox_loss_ignore_mask = bbox_loss_ignore_mask.expand_as(bbox_embeds) 55 | 56 | mask = embed_boxes & ~bbox_loss_ignore_mask 57 | bbox_embeds *= mask.float() 58 | 59 | token_embeds = token_embeds + bbox_embeds 60 | 61 | return token_embeds 62 | ``` -------------------------------------------------------------------------------- /surya/detection/loader.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | 
from surya.common.load import ModelLoader 6 | from surya.detection.processor import SegformerImageProcessor 7 | 8 | from surya.detection.model.config import EfficientViTConfig 9 | from surya.detection.model.encoderdecoder import EfficientViTForSemanticSegmentation 10 | from surya.logging import get_logger 11 | from surya.settings import settings 12 | 13 | logger = get_logger() 14 | 15 | 16 | class DetectionModelLoader(ModelLoader): 17 | def __init__(self, checkpoint: Optional[str] = None): 18 | super().__init__(checkpoint) 19 | 20 | if self.checkpoint is None: 21 | self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT 22 | 23 | def model( 24 | self, 25 | device: Optional[torch.device | str] = None, 26 | dtype: Optional[torch.dtype | str] = None, 27 | attention_implementation: Optional[str] = None, 28 | ) -> EfficientViTForSemanticSegmentation: 29 | if device is None: 30 | device = settings.TORCH_DEVICE_MODEL 31 | if dtype is None: 32 | dtype = settings.MODEL_DTYPE 33 | 34 | config = EfficientViTConfig.from_pretrained(self.checkpoint) 35 | model = EfficientViTForSemanticSegmentation.from_pretrained( 36 | self.checkpoint, 37 | dtype=dtype, 38 | config=config, 39 | ) 40 | model = model.to(device) 41 | model = model.eval() 42 | 43 | if settings.COMPILE_ALL or settings.COMPILE_DETECTOR: 44 | torch._dynamo.config.cache_size_limit = 1 45 | torch._dynamo.config.suppress_errors = False 46 | 47 | logger.info( 48 | f"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}" 49 | ) 50 | compile_args = {"backend": "openxla"} if device == "xla" else {} 51 | model = torch.compile(model, **compile_args) 52 | 53 | logger.debug( 54 | f"Loaded detection model {self.checkpoint} from {EfficientViTForSemanticSegmentation.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" 55 | ) 56 | return model 57 | 58 | def processor( 59 | self, 60 | device: Optional[torch.device | str] = None, 61 | dtype: Optional[torch.dtype | str] = None, 62 | ) -> SegformerImageProcessor: 63 | return SegformerImageProcessor.from_pretrained(self.checkpoint) 64 | ``` -------------------------------------------------------------------------------- /surya/scripts/hf_to_s3.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | import shutil 3 | import datetime 4 | from pathlib import Path 5 | import boto3 6 | 7 | from huggingface_hub import snapshot_download 8 | 9 | import click 10 | from tqdm import tqdm 11 | 12 | S3_API_URL = "https://1afbe4656a6b40d982ab5e730a39f6b9.r2.cloudflarestorage.com" 13 | 14 | 15 | # Example usage - python scripts/hf_to_s3.py <REPO_NAME> layout 16 | # This will upload to s3://layout/TODAYS_DATE 17 | @click.command(help="Uploads the data from huggingface to an S3 bucket") 18 | @click.argument("hf_repo_id", type=str) 19 | @click.argument("s3_path", type=str) 20 | @click.option("--bucket_name", type=str, default="datalab") 21 | @click.option("--revision_hash", type=str, default=None) 22 | @click.option("--access_key_id", type=str, default="<access_key_id>") 23 | @click.option("--access_key_secret", type=str, default="<access_key_secret>") 24 | @click.option("--suffix", type=str, default="") 25 | def main( 26 | hf_repo_id: str, 27 | s3_path: str, 28 | bucket_name: str, 29 | revision_hash: str, 30 | access_key_id: str, 31 | access_key_secret: str, 32 | suffix: str, 33 | ): 34 | curr_date = datetime.datetime.now().strftime("%Y_%m_%d") 35 | s3_path = f"{s3_path}/{curr_date}" 36 | if suffix: 37 | s3_path = 
f"{s3_path}_{suffix}" 38 | 39 | download_folder = snapshot_download(repo_id=hf_repo_id, revision=revision_hash) 40 | download_folder = Path(download_folder) 41 | contained_files = list(download_folder.glob("*")) 42 | contained_files = [f.name for f in contained_files] # Just get the base name 43 | manifest_file = download_folder / "manifest.json" 44 | 45 | with open(manifest_file, "w") as f: 46 | json.dump({"files": contained_files}, f) 47 | 48 | # Upload the files to S3 49 | s3_client = boto3.client( 50 | service_name="s3", 51 | endpoint_url=S3_API_URL, 52 | aws_access_key_id=access_key_id, 53 | aws_secret_access_key=access_key_secret, 54 | region_name="auto", 55 | ) 56 | 57 | # Iterate through all files in the folder 58 | for file_path in tqdm( 59 | download_folder.glob("*"), desc="Uploading files", unit="file" 60 | ): 61 | s3_key = f"{s3_path}/{file_path.name}" 62 | 63 | try: 64 | s3_client.upload_file(str(file_path), bucket_name, s3_key) 65 | except Exception as e: 66 | print(f"Error uploading {file_path}: {str(e)}") 67 | 68 | shutil.rmtree(download_folder) 69 | 70 | print(f"Uploaded files to {s3_path}") 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | ``` -------------------------------------------------------------------------------- /surya/ocr_error/__init__.py: -------------------------------------------------------------------------------- ```python 1 | import math 2 | from typing import List, Optional 3 | 4 | from tqdm import tqdm 5 | 6 | from surya.common.predictor import BasePredictor 7 | from surya.ocr_error.loader import OCRErrorModelLoader 8 | from surya.ocr_error.model.config import ID2LABEL 9 | from surya.ocr_error.schema import OCRErrorDetectionResult 10 | from surya.settings import settings 11 | from surya.common.xla import mark_step 12 | 13 | 14 | class OCRErrorPredictor(BasePredictor): 15 | model_loader_cls = OCRErrorModelLoader 16 | batch_size = settings.OCR_ERROR_BATCH_SIZE 17 | default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 64, "xla": 32} 18 | 19 | def __call__(self, texts: List[str], batch_size: Optional[int] = None): 20 | return self.batch_ocr_error_detection(texts, batch_size) 21 | 22 | def batch_ocr_error_detection( 23 | self, texts: List[str], batch_size: Optional[int] = None 24 | ): 25 | if batch_size is None: 26 | batch_size = self.get_batch_size() 27 | 28 | num_batches = math.ceil(len(texts) / batch_size) 29 | texts_processed = self.processor( 30 | texts, padding="longest", truncation=True, return_tensors="pt" 31 | ) 32 | predictions = [] 33 | for batch_idx in tqdm( 34 | range(num_batches), 35 | desc="Running OCR Error Detection", 36 | disable=self.disable_tqdm, 37 | ): 38 | start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size 39 | batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to( 40 | self.model.device 41 | ) 42 | batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to( 43 | self.model.device 44 | ) 45 | 46 | # Pad to batch size 47 | current_batch_size = batch_input_ids.shape[0] 48 | if settings.OCR_ERROR_STATIC_CACHE: 49 | batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) 50 | batch_attention_mask = self.pad_to_batch_size( 51 | batch_attention_mask, batch_size 52 | ) 53 | 54 | with settings.INFERENCE_MODE(): 55 | pred = self.model(batch_input_ids, attention_mask=batch_attention_mask) 56 | 57 | logits = pred.logits.argmax(dim=1).cpu().tolist()[:current_batch_size] 58 | predictions.extend(logits) 59 | mark_step() 60 | 61 | return OCRErrorDetectionResult( 62 | 
texts=texts, labels=[ID2LABEL[p] for p in predictions] 63 | ) 64 | ``` -------------------------------------------------------------------------------- /surya/input/load.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | import PIL 3 | 4 | from surya.input.processing import open_pdf, get_page_images 5 | from surya.logging import get_logger 6 | from surya.settings import settings 7 | import os 8 | import filetype 9 | from PIL import Image 10 | import json 11 | 12 | logger = get_logger() 13 | 14 | 15 | def get_name_from_path(path): 16 | return os.path.basename(path).split(".")[0] 17 | 18 | 19 | def load_pdf(pdf_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI): 20 | doc = open_pdf(pdf_path) 21 | last_page = len(doc) 22 | 23 | if page_range: 24 | assert all([0 <= page < last_page for page in page_range]), ( 25 | f"Invalid page range: {page_range}" 26 | ) 27 | else: 28 | page_range = list(range(last_page)) 29 | 30 | images = get_page_images(doc, page_range, dpi=dpi) 31 | doc.close() 32 | names = [get_name_from_path(pdf_path) for _ in page_range] 33 | return images, names 34 | 35 | 36 | def load_image(image_path): 37 | image = Image.open(image_path).convert("RGB") 38 | name = get_name_from_path(image_path) 39 | return [image], [name] 40 | 41 | 42 | def load_from_file( 43 | input_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI 44 | ): 45 | input_type = filetype.guess(input_path) 46 | if input_type and input_type.extension == "pdf": 47 | return load_pdf(input_path, page_range, dpi=dpi) 48 | else: 49 | return load_image(input_path) 50 | 51 | 52 | def load_from_folder( 53 | folder_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI 54 | ): 55 | image_paths = [ 56 | os.path.join(folder_path, image_name) 57 | for image_name in os.listdir(folder_path) 58 | if not image_name.startswith(".") 59 | ] 60 | image_paths = [ip for ip in image_paths if not os.path.isdir(ip)] 61 | 62 | images = [] 63 | names = [] 64 | for path in image_paths: 65 | extension = filetype.guess(path) 66 | if extension and extension.extension == "pdf": 67 | image, name = load_pdf(path, page_range, dpi=dpi) 68 | images.extend(image) 69 | names.extend(name) 70 | else: 71 | try: 72 | image, name = load_image(path) 73 | images.extend(image) 74 | names.extend(name) 75 | except PIL.UnidentifiedImageError: 76 | logger.warning(f"Could not load image {path}") 77 | continue 78 | return images, names 79 | 80 | 81 | def load_lang_file(lang_path, names): 82 | with open(lang_path, "r") as f: 83 | lang_dict = json.load(f) 84 | return [lang_dict[name].copy() for name in names] 85 | ``` -------------------------------------------------------------------------------- /surya/scripts/ocr_text.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | import click 3 | import json 4 | import time 5 | from collections import defaultdict 6 | 7 | from surya.common.surya.schema import TaskNames 8 | from surya.detection import DetectionPredictor 9 | from surya.debug.text import draw_text_on_image 10 | from surya.logging import configure_logging, get_logger 11 | from surya.foundation import FoundationPredictor 12 | from surya.recognition import RecognitionPredictor 13 | from surya.scripts.config import CLILoader 14 | 15 | configure_logging() 16 | logger = get_logger() 17 | 18 | 19 | @click.command(help="OCR text.") 20 | @click.option("--task_name", type=str, 
default=TaskNames.ocr_with_boxes) 21 | @click.option( 22 | "--disable_math", is_flag=True, default=False, help="Do not recognize math in OCR." 23 | ) 24 | @CLILoader.common_options 25 | def ocr_text_cli(input_path: str, task_name: str, disable_math: bool, **kwargs): 26 | loader = CLILoader(input_path, kwargs, highres=True) 27 | task_names = [task_name] * len(loader.images) 28 | 29 | foundation_predictor = FoundationPredictor() 30 | det_predictor = DetectionPredictor() 31 | rec_predictor = RecognitionPredictor(foundation_predictor) 32 | 33 | start = time.time() 34 | predictions_by_image = rec_predictor( 35 | loader.images, 36 | task_names=task_names, 37 | det_predictor=det_predictor, 38 | highres_images=loader.highres_images, 39 | math_mode=not disable_math, 40 | ) 41 | 42 | if loader.debug: 43 | logger.debug(f"OCR took {time.time() - start:.2f} seconds") 44 | max_chars = max( 45 | [len(line.text) for p in predictions_by_image for line in p.text_lines] 46 | ) 47 | logger.debug(f"Max chars: {max_chars}") 48 | 49 | if loader.save_images: 50 | for idx, (name, image, pred) in enumerate( 51 | zip(loader.names, loader.images, predictions_by_image) 52 | ): 53 | bboxes = [line.bbox for line in pred.text_lines] 54 | pred_text = [line.text for line in pred.text_lines] 55 | page_image = draw_text_on_image(bboxes, pred_text, image.size) 56 | page_image.save(os.path.join(loader.result_path, f"{name}_{idx}_text.png")) 57 | 58 | out_preds = defaultdict(list) 59 | for name, pred, image in zip(loader.names, predictions_by_image, loader.images): 60 | out_pred = pred.model_dump() 61 | out_pred["page"] = len(out_preds[name]) + 1 62 | out_preds[name].append(out_pred) 63 | 64 | with open( 65 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 66 | ) as f: 67 | json.dump(out_preds, f, ensure_ascii=False) 68 | 69 | logger.info(f"Wrote results to {loader.result_path}") 70 | ``` -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | 3 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 4 | 5 | import pytest 6 | from PIL import Image, ImageDraw 7 | 8 | from surya.detection import DetectionPredictor 9 | from surya.ocr_error import OCRErrorPredictor 10 | from surya.layout import LayoutPredictor 11 | from surya.recognition import RecognitionPredictor 12 | from surya.foundation import FoundationPredictor 13 | from surya.table_rec import TableRecPredictor 14 | from surya.settings import settings 15 | 16 | @pytest.fixture(scope="session") 17 | def ocr_error_predictor() -> OCRErrorPredictor: 18 | ocr_error_predictor = OCRErrorPredictor() 19 | yield ocr_error_predictor 20 | del ocr_error_predictor 21 | 22 | 23 | @pytest.fixture(scope="session") 24 | def layout_predictor() -> LayoutPredictor: 25 | layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)) 26 | yield layout_predictor 27 | del layout_predictor 28 | 29 | 30 | @pytest.fixture(scope="session") 31 | def detection_predictor() -> DetectionPredictor: 32 | detection_predictor = DetectionPredictor() 33 | yield detection_predictor 34 | del detection_predictor 35 | 36 | 37 | @pytest.fixture(scope="session") 38 | def recognition_predictor() -> RecognitionPredictor: 39 | recognition_predictor = RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT)) 40 | yield recognition_predictor 41 | del recognition_predictor 42 | 43 | 44 
| @pytest.fixture(scope="session") 45 | def table_rec_predictor() -> TableRecPredictor: 46 | table_rec_predictor = TableRecPredictor() 47 | yield table_rec_predictor 48 | del table_rec_predictor 49 | 50 | 51 | @pytest.fixture() 52 | def test_image(): 53 | image = Image.new("RGB", (1024, 1024), "white") 54 | draw = ImageDraw.Draw(image) 55 | draw.text((10, 10), "Hello World", fill="black", font_size=72) 56 | draw.text( 57 | (10, 200), 58 | "This is a sentence of text.\nNow it is a paragraph.\nA three-line one.", 59 | fill="black", 60 | font_size=24, 61 | ) 62 | return image 63 | 64 | 65 | @pytest.fixture() 66 | def test_image_tall(): 67 | image = Image.new("RGB", (4096, 4096), "white") 68 | draw = ImageDraw.Draw(image) 69 | draw.text((10, 10), "Hello World", fill="black", font_size=72) 70 | draw.text( 71 | (4000, 4000), 72 | "This is a sentence of text.\n\nNow it is a paragraph.\n\nA three-line one.", 73 | fill="black", 74 | font_size=24, 75 | ) 76 | return image 77 | 78 | @pytest.fixture() 79 | def test_image_latex(): 80 | assets_dir = os.path.join(os.path.dirname(__file__), "assets") 81 | img_path = os.path.join(assets_dir, "test_latex.png") 82 | image = Image.open(img_path).convert("RGB") 83 | return image ``` -------------------------------------------------------------------------------- /surya/debug/render_html.py: -------------------------------------------------------------------------------- ```python 1 | import html as htmllib 2 | import os.path 3 | import re 4 | 5 | filepath = os.path.abspath(__file__) 6 | 7 | def render_text_as_html( 8 | bboxes: list[list[int]], 9 | texts: list[str], 10 | image_size: tuple[int, int], 11 | base_font_size: int = 16, 12 | scaler: int = 2 13 | ): 14 | katex_path = os.path.join(os.path.dirname(filepath), "katex.js") 15 | with open(katex_path, "r") as f: 16 | katex_script = f.read() 17 | 18 | html_content = [] 19 | image_size = tuple([int(s * scaler) for s in image_size]) 20 | width, height = image_size 21 | 22 | 23 | html_content.append(f""" 24 | <!DOCTYPE html> 25 | <html> 26 | <head> 27 | <style> 28 | body {{ 29 | margin: 0; 30 | padding: 0; 31 | width: {width}px; 32 | height: {height}px; 33 | position: relative; 34 | overflow: hidden; 35 | background: white; 36 | color: black; 37 | }} 38 | .text-box {{ 39 | position: absolute; 40 | overflow: hidden; 41 | display: flex; 42 | justify-content: left; 43 | font-family: Arial, sans-serif; 44 | white-space: pre-wrap; 45 | }} 46 | .vertical-text {{ 47 | writing-mode: vertical-rl; /* Top to bottom, right to left */ 48 | }} 49 | </style> 50 | {katex_script} 51 | </head> 52 | <body> 53 | """) 54 | 55 | for i, (bbox, text) in enumerate(zip(bboxes, texts)): 56 | bbox = bbox.copy() 57 | bbox = [int(bb * scaler) for bb in bbox] 58 | x1, y1, x2, y2 = bbox 59 | width = x2 - x1 60 | height = y2 - y1 61 | min_dim = min(width, height) 62 | 63 | # Scale font size based on box height 64 | font_size = min(int(min_dim * 0.75), base_font_size) 65 | 66 | # Create div with absolute positioning 67 | div_style = ( 68 | f"left: {x1}px; " 69 | f"top: {y1}px; " 70 | f"width: {width}px; " 71 | f"height: {height}px; " 72 | f"font-size: {font_size}px;" 73 | ) 74 | 75 | class_ = "text-box" 76 | if height > width * 2: 77 | class_ += " vertical-text" 78 | 79 | # Determine if content is HTML/MathML or plain text 80 | if "<" in text and ">" in text and re.search(r"<(html|math|div|sub|sup|i|u|mark|small|del|b|br|code)\b", text.lower()): 81 | # Content is already HTML/MathML, include as-is 82 | html_content.append(f'<span class="{class_}" 
id="box-{i}" style="{div_style}">{text}</span>') 83 | else: 84 | # Plain text, escape it 85 | escaped_text = htmllib.escape(text) 86 | html_content.append(f'<span class="{class_}" id="box-{i}" style="{div_style}">{escaped_text}</span>') 87 | 88 | html_content.append("</body></html>") 89 | 90 | return "\n".join(html_content), image_size ``` -------------------------------------------------------------------------------- /surya/common/predictor.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from surya.common.load import ModelLoader 6 | from surya.settings import settings 7 | 8 | 9 | class BasePredictor: 10 | model_loader_cls = ModelLoader 11 | batch_size: Optional[int] = None 12 | default_batch_sizes = {"cpu": 1, "mps": 1, "cuda": 1} 13 | torch_dtype = settings.MODEL_DTYPE 14 | 15 | @property 16 | def disable_tqdm(self) -> bool: 17 | return self._disable_tqdm 18 | 19 | @disable_tqdm.setter 20 | def disable_tqdm(self, value: bool) -> None: 21 | self._disable_tqdm = bool(value) 22 | 23 | def __init__( 24 | self, 25 | checkpoint: Optional[str] = None, 26 | device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, 27 | dtype: Optional[torch.dtype | str] = None, 28 | attention_implementation: Optional[str] = None, 29 | ): 30 | if dtype is None: 31 | dtype = self.torch_dtype 32 | 33 | self.model = None 34 | self.processor = None 35 | loader = self.model_loader_cls(checkpoint) 36 | 37 | self.model = loader.model(device, dtype, attention_implementation) 38 | self.processor = loader.processor() 39 | 40 | self._disable_tqdm = settings.DISABLE_TQDM 41 | 42 | def to(self, device_dtype: torch.device | str | None = None): 43 | model_moved = False 44 | if hasattr(self, "model") and self.model: 45 | self.model.to(device_dtype) 46 | model_moved = True 47 | if hasattr(self, "foundation_predictor") and self.foundation_predictor: 48 | self.foundation_predictor.model.to(device_dtype) 49 | model_moved = True 50 | 51 | if not model_moved: 52 | raise ValueError("Model not loaded") 53 | 54 | def get_batch_size(self): 55 | batch_size = self.batch_size 56 | if batch_size is None: 57 | batch_size = self.default_batch_sizes["cpu"] 58 | if settings.TORCH_DEVICE_MODEL in self.default_batch_sizes: 59 | batch_size = self.default_batch_sizes[settings.TORCH_DEVICE_MODEL] 60 | return batch_size 61 | 62 | @staticmethod 63 | def pad_to_batch_size(tensor: torch.Tensor, batch_size: int): 64 | current_batch_size = tensor.shape[0] 65 | if current_batch_size >= batch_size: 66 | return tensor 67 | 68 | if len(tensor.shape) == 1: 69 | # If tensor is 1D, we need to pad it to the batch size 70 | pad_size = batch_size - current_batch_size 71 | return F.pad(tensor, (0, pad_size), mode="constant", value=0) 72 | 73 | pad_size = batch_size - current_batch_size 74 | padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) 75 | 76 | return F.pad(tensor, padding, mode="constant", value=0) 77 | 78 | def __call__(self, *args, **kwargs): 79 | raise NotImplementedError() 80 | ``` -------------------------------------------------------------------------------- /surya/scripts/config.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | 3 | import click 4 | import os 5 | from surya.input.load import load_from_folder, load_from_file 6 | from surya.settings import settings 7 | 8 | 9 | class CLILoader: 10 | def __init__(self, filepath: str, 
cli_options: dict, highres: bool = False): 11 | self.page_range = cli_options.get("page_range") 12 | if self.page_range: 13 | self.page_range = self.parse_range_str(self.page_range) 14 | self.filepath = filepath 15 | self.config = cli_options 16 | self.save_images = cli_options.get("images", False) 17 | self.debug = cli_options.get("debug", False) 18 | self.output_dir = cli_options.get("output_dir") 19 | 20 | self.load(highres) 21 | 22 | @staticmethod 23 | def common_options(fn): 24 | fn = click.argument("input_path", type=click.Path(exists=True), required=True)(fn) 25 | fn = click.option("--output_dir", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, "surya"), help="Directory to save output.")(fn) 26 | fn = click.option("--page_range", type=str, default=None, help="Page range to convert, specify comma separated page numbers or ranges. Example: 0,5-10,20")(fn) 27 | fn = click.option("--images", is_flag=True, help="Save images of detected bboxes.", default=False)(fn) 28 | fn = click.option('--debug', '-d', is_flag=True, help='Enable debug mode.', default=False)(fn) 29 | return fn 30 | 31 | def load(self, highres: bool = False): 32 | highres_images = None 33 | if os.path.isdir(self.filepath): 34 | images, names = load_from_folder(self.filepath, self.page_range) 35 | folder_name = os.path.basename(self.filepath) 36 | if highres: 37 | highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) 38 | else: 39 | images, names = load_from_file(self.filepath, self.page_range) 40 | folder_name = os.path.basename(self.filepath).split(".")[0] 41 | if highres: 42 | highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) 43 | 44 | 45 | self.images = images 46 | self.highres_images = highres_images 47 | self.names = names 48 | 49 | self.result_path = os.path.abspath(os.path.join(self.output_dir, folder_name)) 50 | os.makedirs(self.result_path, exist_ok=True) 51 | 52 | @staticmethod 53 | def parse_range_str(range_str: str) -> List[int]: 54 | range_lst = range_str.split(",") 55 | page_lst = [] 56 | for i in range_lst: 57 | if "-" in i: 58 | start, end = i.split("-") 59 | page_lst += list(range(int(start), int(end) + 1)) 60 | else: 61 | page_lst.append(int(i)) 62 | page_lst = sorted(list(set(page_lst))) # Deduplicate page numbers and sort in order 63 | return page_lst ``` -------------------------------------------------------------------------------- /tests/test_recognition.py: -------------------------------------------------------------------------------- ```python 1 | import time 2 | from PIL import ImageDraw, Image 3 | from surya.recognition.util import clean_math_tags 4 | 5 | 6 | def test_recognition(recognition_predictor, detection_predictor, test_image): 7 | recognition_results = recognition_predictor([test_image], None, detection_predictor) 8 | 9 | assert len(recognition_results) == 1 10 | assert recognition_results[0].image_bbox == [0, 0, 1024, 1024] 11 | 12 | text_lines = recognition_results[0].text_lines 13 | assert len(text_lines) == 4 14 | assert "Hello World" in text_lines[0].text 15 | 16 | 17 | def test_recognition_input_text(recognition_predictor, detection_predictor, test_image): 18 | start = time.time() 19 | recognition_predictor([test_image], None, detection_predictor) 20 | end = time.time() - start 21 | 22 | input_text = "a" * 400 23 | start2 = time.time() 24 | recognition_results = recognition_predictor( 25 | [test_image], None, detection_predictor, 
input_text=[input_text] 26 | ) 27 | end2 = time.time() - start2 28 | 29 | assert max([end, end2]) / min([end, end2]) < 1.5, ( 30 | "Input text should be truncated and not change inference time" 31 | ) 32 | 33 | assert len(recognition_results) == 1 34 | assert recognition_results[0].image_bbox == [0, 0, 1024, 1024] 35 | 36 | text_lines = recognition_results[0].text_lines 37 | assert len(text_lines) == 4 38 | assert "Hello World" in text_lines[0].text 39 | 40 | 41 | def test_recognition_drop_repeats(recognition_predictor, detection_predictor): 42 | image = Image.new("RGB", (1024, 128), "white") 43 | draw = ImageDraw.Draw(image) 44 | text = "a" * 80 45 | draw.text((5, 5), text, fill="black", font_size=24) 46 | 47 | recognition_results = recognition_predictor( 48 | [image], None, bboxes=[[[0, 0, 1024, 128]]], drop_repeated_text=True 49 | ) 50 | assert len(recognition_results) == 1 51 | result = recognition_results[0].text_lines 52 | assert result[0].text == "" 53 | 54 | 55 | def test_recognition_clean_math(): 56 | math = """<math display="block">na_n^{1+2r} \\text{cov}(\\hat{f}_n^{(r)}(x), \\hat{f}_n^{(r)}(y)) = \\frac{1}{n} \\sum_{j=1}^n \\frac{a_n^{1+2r}}{a_j^{1+2r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_j}{a_j}\\right)\\right) <br>+ \\frac{a_n^{1+2r}}{n} \\sum_{\\substack{j \\neq k \\\\ 1 \\le j, k \\le n}} \\frac{1}{(a_j a_k)^{1+r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_k}{a_k}\\right)\\right) <br>=: I_1 + I_2.</math> (1.7)</math>'""" 57 | clean_math = clean_math_tags(math) 58 | 59 | assert clean_math.count("</math>") == 1, "Should have one closing math tag" 60 | assert "<br>" not in clean_math, "Should not have <br> tags in cleaned math" 61 | 62 | 63 | def test_recognition_clean_math_preserve_text(): 64 | text = """Hello, this is a sentence with <math display="inline">x^2 + y^2 = z^2</math> and some text after it, with a weird tag <hello> and <goodbye>.""" 65 | clean_text = clean_math_tags(text) 66 | 67 | assert clean_text == text 68 | ``` -------------------------------------------------------------------------------- /surya/input/processing.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | 3 | import cv2 4 | import numpy as np 5 | import pypdfium2 6 | from PIL import Image 7 | 8 | from surya.logging import get_logger 9 | from surya.settings import settings 10 | 11 | logger = get_logger() 12 | 13 | 14 | def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]: 15 | new_images = [] 16 | for image in images: 17 | if image.mode != "RGB": 18 | image = image.convert("RGB") 19 | new_images.append(image) 20 | return new_images 21 | 22 | 23 | def open_pdf(pdf_filepath): 24 | return pypdfium2.PdfDocument(pdf_filepath) 25 | 26 | 27 | def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI): 28 | images = [ 29 | doc[i].render(scale=dpi / 72, draw_annots=False).to_pil() for i in indices 30 | ] 31 | images = [image.convert("RGB") for image in images] 32 | return images 33 | 34 | 35 | def slice_bboxes_from_image(image: np.ndarray, bboxes): 36 | lines = [] 37 | for bbox in bboxes: 38 | bbox = np.array(bbox, dtype=np.int32) 39 | bbox = np.clip(bbox, 0, None) # Ensure no negative indices 40 | # Ensure bbox is within the image bounds 41 | if bbox[3] <= bbox[1]: 42 | bbox[3] = bbox[1] + 1 43 | 44 | if bbox[2] <= bbox[0]: 45 | bbox[2] = bbox[0] + 1 46 | 47 | bbox[2] = min(bbox[2], image.shape[1]) 48 | bbox[3] = 
min(bbox[3], image.shape[0]) 49 | 50 | line = image[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy() 51 | if line.size == 0: 52 | logger.warning(f"Warning: found an empty line with bbox {bbox}") 53 | lines.append(line) 54 | return lines 55 | 56 | 57 | def slice_polys_from_image(image: np.ndarray, polys): 58 | lines = [] 59 | for idx, poly in enumerate(polys): 60 | lines.append(slice_and_pad_poly(image, poly)) 61 | return lines 62 | 63 | 64 | def slice_and_pad_poly(image_array: np.array, coordinates): 65 | # Draw polygon onto mask 66 | coordinates = [(corner[0], corner[1]) for corner in coordinates] 67 | bbox = [ 68 | min([x[0] for x in coordinates]), 69 | min([x[1] for x in coordinates]), 70 | max([x[0] for x in coordinates]), 71 | max([x[1] for x in coordinates]), 72 | ] 73 | 74 | # We mask out anything not in the polygon 75 | cropped_polygon = image_array[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy() 76 | height, width = cropped_polygon.shape[:2] 77 | 78 | coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates] 79 | 80 | # Validate the cropped area 81 | if any( 82 | [ 83 | bbox[3] <= bbox[1] or bbox[2] <= bbox[0], 84 | len(coordinates) < 3, 85 | height == 0, 86 | width == 0, 87 | ] 88 | ): 89 | return cropped_polygon 90 | 91 | # Pad the area outside the polygon with the pad value 92 | try: 93 | mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8) 94 | cv2.fillPoly(mask, [np.int32(coordinates)], 1) 95 | mask = np.stack([mask] * 3, axis=-1) 96 | 97 | cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE 98 | except cv2.error as e: 99 | logger.warning(f"Warning: issue while processing polygon: {e}") 100 | 101 | return cropped_polygon 102 | ``` -------------------------------------------------------------------------------- /surya/table_rec/loader.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from surya.common.load import ModelLoader 6 | from surya.logging import get_logger 7 | from surya.settings import settings 8 | from surya.table_rec.model.config import ( 9 | SuryaTableRecConfig, 10 | SuryaTableRecDecoderConfig, 11 | DonutSwinTableRecConfig, 12 | ) 13 | from surya.table_rec.model.encoderdecoder import TableRecEncoderDecoderModel 14 | from surya.table_rec.processor import SuryaTableRecProcessor 15 | 16 | logger = get_logger() 17 | 18 | 19 | class TableRecModelLoader(ModelLoader): 20 | def __init__(self, checkpoint: Optional[str] = None): 21 | super().__init__(checkpoint) 22 | 23 | if self.checkpoint is None: 24 | self.checkpoint = settings.TABLE_REC_MODEL_CHECKPOINT 25 | 26 | def model( 27 | self, 28 | device=settings.TORCH_DEVICE_MODEL, 29 | dtype=settings.MODEL_DTYPE, 30 | attention_implementation: Optional[str] = None, 31 | ) -> TableRecEncoderDecoderModel: 32 | if device is None: 33 | device = settings.TORCH_DEVICE_MODEL 34 | if dtype is None: 35 | dtype = settings.MODEL_DTYPE 36 | 37 | if device == "mps": 38 | logger.warning( 39 | "`TableRecEncoderDecoderModel` is not compatible with mps backend. 
Defaulting to cpu instead" 40 | ) 41 | device = "cpu" 42 | dtype = "float32" 43 | 44 | config = SuryaTableRecConfig.from_pretrained(self.checkpoint) 45 | decoder_config = config.decoder 46 | decoder = SuryaTableRecDecoderConfig(**decoder_config) 47 | config.decoder = decoder 48 | 49 | encoder_config = config.encoder 50 | encoder = DonutSwinTableRecConfig(**encoder_config) 51 | config.encoder = encoder 52 | 53 | model = TableRecEncoderDecoderModel.from_pretrained( 54 | self.checkpoint, config=config, dtype=dtype 55 | ) 56 | 57 | model = model.to(device) 58 | model = model.eval() 59 | 60 | if settings.COMPILE_ALL or settings.COMPILE_TABLE_REC: 61 | torch.set_float32_matmul_precision("high") 62 | torch._dynamo.config.cache_size_limit = 16 63 | torch._dynamo.config.suppress_errors = False 64 | 65 | logger.info( 66 | f"Compiling table recognition model {self.checkpoint} on device {device} with dtype {dtype}" 67 | ) 68 | compile_args = {"backend": "openxla"} if device == "xla" else {} 69 | model.encoder = torch.compile(model.encoder, **compile_args) 70 | model.decoder = torch.compile(model.decoder, **compile_args) 71 | 72 | logger.debug( 73 | f"Loaded table recognition model {self.checkpoint} from {TableRecEncoderDecoderModel.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" 74 | ) 75 | return model 76 | 77 | def processor( 78 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE 79 | ) -> SuryaTableRecProcessor: 80 | processor = SuryaTableRecProcessor(self.checkpoint) 81 | 82 | processor.token_pad_id = 0 83 | processor.token_eos_id = 1 84 | processor.token_bos_id = 1 85 | processor.token_query_end_id = 4 86 | return processor 87 | ``` -------------------------------------------------------------------------------- /surya/common/surya/decoder/config.py: -------------------------------------------------------------------------------- ```python 1 | from transformers.configuration_utils import PretrainedConfig 2 | from transformers.modeling_rope_utils import rope_config_validation 3 | from transformers.utils import logging 4 | 5 | logger = logging.get_logger(__name__) 6 | 7 | 8 | class SuryaDecoderConfig(PretrainedConfig): 9 | model_type = "qwen2" 10 | keys_to_ignore_at_inference = ["past_key_values"] 11 | 12 | # Default tensor parallel plan for base model `Qwen2` 13 | base_model_tp_plan = { 14 | "layers.*.self_attn.q_proj": "colwise", 15 | "layers.*.self_attn.k_proj": "colwise", 16 | "layers.*.self_attn.v_proj": "colwise", 17 | "layers.*.self_attn.o_proj": "rowwise", 18 | "layers.*.mlp.gate_proj": "colwise", 19 | "layers.*.mlp.up_proj": "colwise", 20 | "layers.*.mlp.down_proj": "rowwise", 21 | } 22 | base_model_pp_plan = { 23 | "embed_tokens": (["input_ids"], ["inputs_embeds"]), 24 | "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), 25 | "norm": (["hidden_states"], ["hidden_states"]), 26 | } 27 | 28 | def __init__( 29 | self, 30 | vocab_size=151936, 31 | hidden_size=4096, 32 | intermediate_size=22016, 33 | num_hidden_layers=32, 34 | num_attention_heads=32, 35 | num_key_value_heads=32, 36 | hidden_act="silu", 37 | max_position_embeddings=32768, 38 | initializer_range=0.02, 39 | rms_norm_eps=1e-6, 40 | use_cache=True, 41 | tie_word_embeddings=False, 42 | rope_theta=10000.0, 43 | rope_scaling=None, 44 | use_sliding_window=False, 45 | sliding_window=4096, 46 | max_window_layers=28, 47 | attention_dropout=0.0, 48 | **kwargs, 49 | ): 50 | self.vocab_size = vocab_size 51 | self.max_position_embeddings = max_position_embeddings 52 | 
self.hidden_size = hidden_size 53 | self.intermediate_size = intermediate_size 54 | self.num_hidden_layers = num_hidden_layers 55 | self.num_attention_heads = num_attention_heads 56 | self.use_sliding_window = False # Disable sliding window 57 | self.sliding_window = ( 58 | sliding_window # we check `use_sliding_window` in the modeling code 59 | ) 60 | self.max_window_layers = max_window_layers 61 | 62 | # for backward compatibility 63 | if num_key_value_heads is None: 64 | num_key_value_heads = num_attention_heads 65 | 66 | self.num_key_value_heads = num_key_value_heads 67 | self.hidden_act = hidden_act 68 | self.initializer_range = initializer_range 69 | self.rms_norm_eps = rms_norm_eps 70 | self.use_cache = use_cache 71 | self.rope_theta = rope_theta 72 | self.rope_scaling = rope_scaling 73 | self.attention_dropout = attention_dropout 74 | # Validate the correctness of rotary position embeddings parameters 75 | # BC: if there is a 'type' field, move it to 'rope_type'. 76 | if self.rope_scaling is not None and "type" in self.rope_scaling: 77 | self.rope_scaling["rope_type"] = self.rope_scaling["type"] 78 | rope_config_validation(self) 79 | 80 | super().__init__( 81 | tie_word_embeddings=tie_word_embeddings, 82 | **kwargs, 83 | ) 84 | ``` -------------------------------------------------------------------------------- /benchmark/ordering.py: -------------------------------------------------------------------------------- ```python 1 | import collections 2 | import json 3 | 4 | import click 5 | 6 | from surya.foundation import FoundationPredictor 7 | from surya.input.processing import convert_if_not_rgb 8 | from surya.layout import LayoutPredictor 9 | from surya.common.polygon import PolygonBox 10 | from surya.settings import settings 11 | from benchmark.utils.metrics import rank_accuracy 12 | import os 13 | import time 14 | import datasets 15 | 16 | 17 | @click.command(help="Benchmark surya layout for reading order.") 18 | @click.option( 19 | "--results_dir", 20 | type=str, 21 | help="Path to JSON file with benchmark results.", 22 | default=os.path.join(settings.RESULT_DIR, "benchmark"), 23 | ) 24 | @click.option( 25 | "--max_rows", 26 | type=int, 27 | help="Maximum number of images to run benchmark on.", 28 | default=None, 29 | ) 30 | def main(results_dir: str, max_rows: int): 31 | foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) 32 | layout_predictor = LayoutPredictor(foundation_predictor) 33 | pathname = "order_bench" 34 | # These have already been shuffled randomly, so sampling from the start is fine 35 | split = "train" 36 | if max_rows is not None: 37 | split = f"train[:{max_rows}]" 38 | dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split) 39 | images = list(dataset["image"]) 40 | images = convert_if_not_rgb(images) 41 | 42 | start = time.time() 43 | layout_predictions = layout_predictor(images) 44 | surya_time = time.time() - start 45 | 46 | folder_name = os.path.basename(pathname).split(".")[0] 47 | result_path = os.path.join(results_dir, folder_name) 48 | os.makedirs(result_path, exist_ok=True) 49 | 50 | page_metrics = collections.OrderedDict() 51 | mean_accuracy = 0 52 | for idx, order_pred in enumerate(layout_predictions): 53 | row = dataset[idx] 54 | labels = row["labels"] 55 | bboxes = row["bboxes"] 56 | pred_positions = [] 57 | for label, bbox in zip(labels, bboxes): 58 | max_intersection = 0 59 | matching_idx = 0 60 | for pred_box in order_pred.bboxes: 61 | intersection = 
pred_box.intersection_pct(PolygonBox(polygon=bbox)) 62 | if intersection > max_intersection: 63 | max_intersection = intersection 64 | matching_idx = pred_box.position 65 | pred_positions.append(matching_idx) 66 | accuracy = rank_accuracy(pred_positions, labels) 67 | mean_accuracy += accuracy 68 | page_results = {"accuracy": accuracy, "box_count": len(labels)} 69 | 70 | page_metrics[idx] = page_results 71 | 72 | mean_accuracy /= len(layout_predictions) 73 | 74 | out_data = { 75 | "time": surya_time, 76 | "mean_accuracy": mean_accuracy, 77 | "page_metrics": page_metrics, 78 | } 79 | 80 | with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: 81 | json.dump(out_data, f, indent=4) 82 | 83 | print(f"Mean accuracy is {mean_accuracy:.2f}.") 84 | print( 85 | f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total." 86 | ) 87 | print("Mean accuracy is the % of correct ranking pairs.") 88 | print(f"Wrote results to {result_path}") 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | ``` -------------------------------------------------------------------------------- /surya/debug/text.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | from io import BytesIO 3 | from typing import List, Tuple 4 | from PIL import Image, ImageDraw, ImageFont 5 | 6 | from surya.debug.fonts import get_font_path 7 | from surya.debug.render_html import render_text_as_html 8 | 9 | try: 10 | from playwright.sync_api import sync_playwright 11 | 12 | has_playwright = True 13 | except ImportError: 14 | has_playwright = False 15 | 16 | 17 | def strip_html_tags(html_text): 18 | pattern = re.compile(r"<[\w/][^>]*>") 19 | text_only = pattern.sub("", html_text) 20 | 21 | return text_only 22 | 23 | 24 | def get_text_size(text, font): 25 | im = Image.new(mode="P", size=(0, 0)) 26 | draw = ImageDraw.Draw(im) 27 | _, _, width, height = draw.textbbox((0, 0), text=text, font=font) 28 | return width, height 29 | 30 | 31 | def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size): 32 | font = ImageFont.truetype(font_path, box_font_size) 33 | text_width, text_height = get_text_size(text, font) 34 | while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6: 35 | box_font_size = box_font_size - 1 36 | font = ImageFont.truetype(font_path, box_font_size) 37 | text_width, text_height = get_text_size(text, font) 38 | 39 | # Calculate text position (centered in bbox) 40 | text_width, text_height = get_text_size(text, font) 41 | x = s_bbox[0] 42 | y = s_bbox[1] + (bbox_height - text_height) / 2 43 | 44 | draw.text((x, y), text, fill="black", font=font) 45 | 46 | 47 | def draw_text_with_playwright( 48 | bboxes, texts: List[str], image_size: Tuple[int, int] 49 | ) -> Image.Image: 50 | html_content, image_size = render_text_as_html(bboxes, texts, image_size) 51 | if not has_playwright: 52 | raise ImportError( 53 | "Playwright is not installed. 
Please install it using `pip install playwright`" 54 | ) 55 | 56 | with sync_playwright() as p: 57 | browser = p.chromium.launch(headless=True) 58 | page = browser.new_page( 59 | viewport={"width": image_size[0], "height": image_size[1]} 60 | ) 61 | page.set_content(html_content) 62 | page.wait_for_timeout(1000) 63 | body = page.query_selector("body") 64 | image = body.screenshot() 65 | browser.close() 66 | 67 | pil_img = Image.open(BytesIO(image)) 68 | return pil_img 69 | 70 | 71 | def draw_text_on_image( 72 | bboxes, 73 | texts, 74 | image_size: Tuple[int, int], 75 | font_path=None, 76 | max_font_size=60, 77 | res_upscale=2, 78 | ) -> Image.Image: 79 | if has_playwright: 80 | return draw_text_with_playwright(bboxes, texts, image_size) 81 | 82 | texts = [strip_html_tags(text) for text in texts] 83 | if font_path is None: 84 | font_path = get_font_path() 85 | new_image_size = (image_size[0] * res_upscale, image_size[1] * res_upscale) 86 | image = Image.new("RGB", new_image_size, color="white") 87 | draw = ImageDraw.Draw(image) 88 | 89 | for bbox, text in zip(bboxes, texts): 90 | s_bbox = [int(coord * res_upscale) for coord in bbox] 91 | bbox_width = s_bbox[2] - s_bbox[0] 92 | bbox_height = s_bbox[3] - s_bbox[1] 93 | 94 | # Shrink the text to fit in the bbox if needed 95 | box_font_size = max(6, min(int(0.75 * bbox_height), max_font_size)) 96 | render_text( 97 | draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size 98 | ) 99 | 100 | return image 101 | ``` -------------------------------------------------------------------------------- /surya/recognition/postprocessing.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | from typing import List, Dict 3 | 4 | from surya.recognition.schema import TextChar 5 | 6 | 7 | def truncate_repetitions(text: str, min_len=15): 8 | # From nougat, with some cleanup 9 | if len(text) < 2 * min_len: 10 | return text 11 | 12 | # try to find a length at which the tail is repeating 13 | max_rep_len = None 14 | for rep_len in range(min_len, int(len(text) / 2)): 15 | # check if there is a repetition at the end 16 | same = True 17 | for i in range(0, rep_len): 18 | if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]: 19 | same = False 20 | break 21 | 22 | if same: 23 | max_rep_len = rep_len 24 | 25 | if max_rep_len is None: 26 | return text 27 | 28 | lcs = text[-max_rep_len:] 29 | 30 | # remove all but the last repetition 31 | text_to_truncate = text 32 | while text_to_truncate.endswith(lcs): 33 | text_to_truncate = text_to_truncate[:-max_rep_len] 34 | 35 | return text[: len(text_to_truncate)] 36 | 37 | 38 | def extract_tags(proposed_tags: List[str]) -> List[str]: 39 | tags = [] 40 | for tag in proposed_tags: 41 | tag_match = re.match(tag_pattern, tag) 42 | if not tag_match: 43 | continue 44 | 45 | if not tag_match.group(1) == "/": 46 | continue 47 | 48 | tags.append(tag_match.group(2)) 49 | return tags 50 | 51 | 52 | tag_pattern = re.compile(r"<(/?)([a-z]+)([^>]*)>?", re.IGNORECASE) 53 | 54 | 55 | def cleanup_math(line: str): 56 | matches = re.finditer(r"(<math[^>]*>)(.*?)</math>", line, re.DOTALL) 57 | result = line 58 | 59 | for match in matches: 60 | opening_tag = match.group(1) # The opening <math> tag with attributes 61 | full_match = match.group(0) # The entire <math>content</math> tag 62 | block_content = match.group(2) # Just the content inside the tags 63 | 64 | clean_block = re.sub(r"<[^>]+>", "", block_content) 65 | 66 | if not re.search(r"[\\\_]", clean_block): 67 | 
result = result.replace(full_match, clean_block) 68 | else: 69 | result = result.replace(full_match, f"{opening_tag}{clean_block}</math>") 70 | 71 | return result 72 | 73 | 74 | def fix_unbalanced_tags( 75 | text_chars: List[TextChar], special_tokens: Dict[str, list] 76 | ) -> List[TextChar]: 77 | self_closing_tags = ["br"] 78 | 79 | open_tags = [] 80 | 81 | format_tags = extract_tags(special_tokens["formatting"]) + extract_tags( 82 | special_tokens["math_external"] 83 | ) 84 | 85 | for char in text_chars: 86 | if len(char.text) <= 1: 87 | continue 88 | 89 | tag_match = re.match(tag_pattern, char.text) 90 | if not tag_match: 91 | continue 92 | 93 | is_closing = tag_match.group(1) == "/" 94 | tag_name = tag_match.group(2).lower() 95 | 96 | if tag_name not in format_tags: 97 | continue 98 | 99 | if tag_name in self_closing_tags: 100 | continue 101 | 102 | # Self-closing tags 103 | if tag_match.group(3) and tag_match.group(3).strip().endswith("/"): 104 | continue 105 | 106 | if is_closing: 107 | if open_tags and open_tags[-1] == tag_name: 108 | open_tags.pop() 109 | else: 110 | open_tags.append(tag_name) 111 | 112 | for tag in open_tags: 113 | text_chars.append( 114 | TextChar( 115 | text=f"</{tag}>", 116 | confidence=0, 117 | polygon=[[0, 0], [1, 0], [1, 1], [0, 1]], 118 | bbox_valid=False, 119 | ) 120 | ) 121 | return text_chars 122 | ``` -------------------------------------------------------------------------------- /surya/common/surya/config.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional 2 | from transformers import PretrainedConfig 3 | 4 | from surya.common.s3 import S3DownloaderMixin 5 | from surya.common.surya.encoder.config import SuryaEncoderConfig 6 | from surya.common.surya.decoder.config import SuryaDecoderConfig 7 | 8 | 9 | class SuryaModelConfig(S3DownloaderMixin, PretrainedConfig): 10 | model_type = "surya-multimodal-foundation" 11 | is_composition = True 12 | 13 | def __init__( 14 | self, 15 | vocab_size=65536, 16 | bbox_size=1025, 17 | blank_bbox_token_id=1025, 18 | bos_token_id=0, 19 | eos_token_id=1, 20 | pad_token_id=2, 21 | image_token_id=3, 22 | register_token_ids=(4, 5, 6, 7), 23 | eoi_token_id=8, 24 | beacon_token_id=9, 25 | special_token_count=4, 26 | max_sequence_length=1536, 27 | special_ocr_tokens=None, 28 | vision_encoder=None, 29 | decoder=None, 30 | tasks: dict | None = None, 31 | bbox_embed_size: int = 64, 32 | num_register_tokens: int = 4, 33 | image_embed_encoding_size: int = 1024, 34 | image_embed_encoding_multiplier: int = 256, 35 | num_beacon_tokens: int = 1, 36 | beacon_token_interval: int = 4096, 37 | sliding_window: Optional[int] = None, 38 | multi_output_distance: int = 4, 39 | max_multi_out: int = 8, 40 | **kwargs, 41 | ): 42 | super().__init__(**kwargs) 43 | self.is_encoder_decoder = False 44 | self.vocab_size = vocab_size 45 | self.bbox_size = bbox_size 46 | self.blank_bbox_token_id = blank_bbox_token_id 47 | self.image_token_id = image_token_id 48 | self.bos_token_id = bos_token_id 49 | self.eos_token_id = eos_token_id 50 | self.pad_token_id = pad_token_id 51 | self.eoi_token_id = eoi_token_id 52 | self.beacon_token_id = beacon_token_id 53 | self.special_ocr_tokens = special_ocr_tokens 54 | self.special_token_count = special_token_count # pad, bos, etc, tokens 55 | self.max_sequence_length = max_sequence_length 56 | self.tasks = tasks 57 | self.tie_word_embeddings = True 58 | self.bbox_embed_size = bbox_embed_size 59 | self.num_register_tokens = num_register_tokens 60 
| self.register_token_ids = register_token_ids 61 | self.image_embed_encoding_size = image_embed_encoding_size 62 | self.image_embed_encoding_multiplier = image_embed_encoding_multiplier 63 | self.num_beacon_tokens = num_beacon_tokens 64 | self.beacon_token_interval = beacon_token_interval 65 | self.sliding_window = sliding_window 66 | self.multi_output_distance = multi_output_distance 67 | self.max_multi_out = max_multi_out 68 | 69 | if self.sliding_window is None: 70 | self.sliding_window = self.max_sequence_length 71 | 72 | if isinstance(vision_encoder, dict): 73 | vision_encoder = SuryaEncoderConfig(**vision_encoder) 74 | elif vision_encoder is None: 75 | vision_encoder = SuryaEncoderConfig() 76 | self.vision_encoder = vision_encoder 77 | 78 | if isinstance(decoder, dict): 79 | decoder = SuryaDecoderConfig(**decoder) 80 | elif decoder is None: 81 | decoder = SuryaDecoderConfig() 82 | self.decoder = decoder 83 | 84 | self.hidden_size = self.decoder.hidden_size 85 | 86 | self.patch_size = self.vision_encoder.spatial_patch_size 87 | self.merge_size = self.vision_encoder.spatial_merge_size 88 | ``` -------------------------------------------------------------------------------- /surya/table_rec/processor.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | 3 | import PIL 4 | import torch 5 | from transformers import ProcessorMixin 6 | 7 | from surya.common.s3 import S3DownloaderMixin 8 | from surya.common.donut.processor import SuryaEncoderImageProcessor 9 | from surya.table_rec.shaper import LabelShaper 10 | from surya.settings import settings 11 | from surya.table_rec.model.config import BOX_DIM, SPECIAL_TOKENS 12 | 13 | 14 | class SuryaTableRecProcessor(S3DownloaderMixin, ProcessorMixin): 15 | attributes = ["image_processor"] 16 | image_processor_class = "AutoImageProcessor" 17 | 18 | def __init__(self, checkpoint, **kwargs): 19 | image_processor = SuryaEncoderImageProcessor.from_pretrained(checkpoint) 20 | image_processor.do_align_long_axis = False 21 | image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE 22 | self.image_processor = image_processor 23 | super().__init__(image_processor) 24 | 25 | self.box_size = (BOX_DIM, BOX_DIM) 26 | self.special_token_count = SPECIAL_TOKENS 27 | self.shaper = LabelShaper() 28 | 29 | def resize_polygon(self, polygon, orig_size, new_size): 30 | w_scaler = new_size[0] / orig_size[0] 31 | h_scaler = new_size[1] / orig_size[1] 32 | 33 | for corner in polygon: 34 | corner[0] = corner[0] * w_scaler 35 | corner[1] = corner[1] * h_scaler 36 | 37 | if corner[0] < 0: 38 | corner[0] = 0 39 | if corner[1] < 0: 40 | corner[1] = 0 41 | if corner[0] > new_size[0]: 42 | corner[0] = new_size[0] 43 | if corner[1] > new_size[1]: 44 | corner[1] = new_size[1] 45 | 46 | return polygon 47 | 48 | def __call__( 49 | self, 50 | images: List[PIL.Image.Image] | None, 51 | query_items: List[dict], 52 | columns: List[dict] | None = None, 53 | convert_images: bool = True, 54 | *args, 55 | **kwargs 56 | ): 57 | if convert_images: 58 | assert len(images) == len(query_items) 59 | assert len(images) > 0 60 | 61 | # Resize input query items 62 | for image, query_item in zip(images, query_items): 63 | query_item["polygon"] = self.resize_polygon(query_item["polygon"], image.size, self.box_size) 64 | 65 | query_items = self.shaper.convert_polygons_to_bboxes(query_items) 66 | query_labels = self.shaper.dict_to_labels(query_items) 67 | 68 | decoder_input_boxes = [] 69 | col_count = len(query_labels[0]) 70 | for 
label in query_labels: 71 | decoder_input_boxes.append([ 72 | [self.token_bos_id] * col_count, 73 | label, 74 | [self.token_query_end_id] * col_count 75 | ]) 76 | 77 | # Add columns to end of decoder input 78 | if columns: 79 | columns = self.shaper.convert_polygons_to_bboxes(columns) 80 | column_labels = self.shaper.dict_to_labels(columns) 81 | for decoder_box in decoder_input_boxes: 82 | decoder_box += column_labels 83 | 84 | input_boxes = torch.tensor(decoder_input_boxes, dtype=torch.long) 85 | input_boxes_mask = torch.ones_like(input_boxes, dtype=torch.long) 86 | 87 | inputs = { 88 | "input_ids": input_boxes, 89 | "attention_mask": input_boxes_mask 90 | } 91 | if convert_images: 92 | inputs["pixel_values"] = self.image_processor(images, *args, **kwargs)["pixel_values"] 93 | return inputs 94 | ```