This is page 1 of 4. Use http://codebase.md/datalab-to/surya?page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmark │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── CITATION.cff ├── CLA.md ├── detect_layout.py ├── detect_text.py ├── LICENSE ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ ├── fonts │ │ └── .gitignore │ └── images │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── arabic.jpg │ ├── benchmark_chart_small.png │ ├── benchmark_chart.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chi_hind.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── chinese.jpg │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── excerpt.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── funsd.png │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── hindi.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── japanese.jpg │ ├── latex_ocr.png │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── nyt.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── paper.jpg │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── pres.png │ ├── rec_acc_table.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── scanned.png │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ ├── textbook_text.jpg │ └── textbook.jpg ├── surya │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── adetr │ │ │ └── decoder.py │ │ ├── donut │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder │ │ │ │ └── __init__.py │ │ │ ├── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation │ │ ├── __init__.py │ │ ├── cache │ │ │ ├── 
__init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input │ │ ├── load.py │ │ └── processing.py │ ├── layout │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests │ ├── assets │ │ └── test_latex.png │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ``` # Files -------------------------------------------------------------------------------- /static/fonts/.gitignore: -------------------------------------------------------------------------------- ``` * !.gitignore ``` -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- ```yaml repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.9.10 hooks: # Run the linter. - id: ruff types_or: [ python, pyi ] args: [ --fix ] # Run the formatter. - id: ruff-format types_or: [ python, pyi ] ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` private.py .DS_Store local.env experiments test_data training wandb notebooks results data slices # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- ```markdown # Surya Surya is a document OCR toolkit that does: - OCR in 90+ languages that benchmarks favorably vs cloud services - Line-level text detection in any language - Layout analysis (table, image, header, etc detection) - Reading order detection - Table recognition (detecting rows/columns) - LaTeX OCR It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details). For our managed API or on-prem document intelligence solution, check out [our platform here](https://datalab.to?utm_source=gh-surya). 
| Detection | OCR | |:----------------------------------------------------------------:|:-----------------------------------------------------------------------:| | <img src="static/images/excerpt.png" width="500px"/> | <img src="static/images/excerpt_text.png" width="500px"/> | | Layout | Reading Order | |:------------------------------------------------------------------:|:--------------------------------------------------------------------------:| | <img src="static/images/excerpt_layout.png" width="500px"/> | <img src="static/images/excerpt_reading.jpg" width="500px"/> | | Table Recognition | LaTeX OCR | |:-------------------------------------------------------------:|:------------------------------------------------------:| | <img src="static/images/scanned_tablerec.png" width="500px"/> | <img src="static/images/latex_ocr.png" width="500px"/> | Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision. ## Community [Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development. ## Examples | Name | Detection | OCR | Layout | Order | Table Rec | |------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:| | Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) | | Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) | | | Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) | | | Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) | | | Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) | | | Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) | [Image](static/images/pres_tablerec.png) | | Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) | [Image](static/images/paper_tablerec.png) | | Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) | [Image](static/images/scanned_tablerec.png) | | New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) | | | Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) | | Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | 
[Image](static/images/textbook_order.jpg) | | # Hosted API There is a hosted API for all surya models available [here](https://www.datalab.to?utm_source=gh-surya): - Works with PDF, images, word docs, and powerpoints - Consistent speed, with no latency spikes - High reliability and uptime # Commercial usage Our model weights use a modified AI Pubs Open Rail-M license (free for research, personal use, and startups under $2M funding/revenue) and our code is GPL. For broader commercial licensing or to remove GPL requirements, visit our pricing page [here](https://www.datalab.to/pricing?utm_source=gh-surya). # Installation You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine. See [here](https://pytorch.org/get-started/locally/) for more details. Install with: ```shell pip install surya-ocr ``` Model weights will automatically download the first time you run surya. # Usage - Inspect the settings in `surya/settings.py`. You can override any settings with environment variables. - Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. ## Interactive App I've included a streamlit app that lets you interactively try Surya on images or PDF files. Run it with: ```shell pip install streamlit pdftext surya_gui ``` ## OCR (text recognition) This command will write out a json file with the detected text and bboxes: ```shell surya_ocr DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--task_name` will specify which task to use for predicting the lines. `ocr_with_boxes` is the default, which will format text and give you bboxes. If you get bad performance, try `ocr_without_boxes`, which will give you potentially better performance but no bboxes. For blocks like equations and paragraphs, try `block_without_boxes`. - `--images` will save images of the pages and detected text lines (optional) - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. - `--disable_math` - by default, surya will recognize math in text. This can lead to false positives - you can disable this with this flag. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `text_lines` - the detected text and bounding boxes for each line - `text` - the text in the line - `confidence` - the confidence of the model in the detected text (0-1) - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. 
- `chars` - the individual characters in the line - `text` - the text of the character - `bbox` - the character bbox (same format as line bbox) - `polygon` - the character polygon (same format as line polygon) - `confidence` - the confidence of the model in the detected character (0-1) - `bbox_valid` - if the character is a special token or math, the bbox may not be valid - `words` - the individual words in the line (computed from the characters) - `text` - the text of the word - `bbox` - the word bbox (same format as line bbox) - `polygon` - the word polygon (same format as line polygon) - `confidence` - mean character confidence - `bbox_valid` - if the word is a special token or math, the bbox may not be valid - `page` - the page number in the file - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. **Performance tips** Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `40MB` of VRAM, so very high batch sizes are possible. The default is a batch size `512`, which will use about 20GB of VRAM. Depending on your CPU core count, it may help, too - the default CPU batch size is `32`. ### From python ```python from PIL import Image from surya.foundation import FoundationPredictor from surya.recognition import RecognitionPredictor from surya.detection import DetectionPredictor image = Image.open(IMAGE_PATH) foundation_predictor = FoundationPredictor() recognition_predictor = RecognitionPredictor(foundation_predictor) detection_predictor = DetectionPredictor() predictions = recognition_predictor([image], det_predictor=detection_predictor) ``` ## Text line detection This command will write out a json file with the detected bboxes. ```shell surya_detect DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected text lines (optional) - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `bboxes` - detected bounding boxes for text - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `confidence` - the confidence of the model in the detected text (0-1) - `vertical_lines` - vertical lines detected in the document - `bbox` - the axis-aligned line coordinates. - `page` - the page number in the file - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. **Performance tips** Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `440MB` of VRAM, so very high batch sizes are possible. 
The default is a batch size `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`. ### From python ```python from PIL import Image from surya.detection import DetectionPredictor image = Image.open(IMAGE_PATH) det_predictor = DetectionPredictor() # predictions is a list of dicts, one per image predictions = det_predictor([image]) ``` ## Layout and reading order This command will write out a json file with the detected layout and reading order. ```shell surya_layout DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected text lines (optional) - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `bboxes` - detected bounding boxes for text - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `position` - the reading order of the box. - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`. - `top_k` - the top-k other potential labels for the box. A dictionary with labels as keys and confidences as values. - `page` - the page number in the file - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. **Performance tips** Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `220MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 7GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`. ### From python ```python from PIL import Image from surya.foundation import FoundationPredictor from surya.layout import LayoutPredictor from surya.settings import settings image = Image.open(IMAGE_PATH) layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)) # layout_predictions is a list of dicts, one per image layout_predictions = layout_predictor([image]) ``` ## Table Recognition This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes. If you want to get cell positions and text, along with nice formatting, check out the [marker](https://www.github.com/VikParuchuri/marker) repo. You can use the `TableConverter` to detect and extract tables in images and PDFs. It supports output in json (with bboxes), markdown, and html. 
```shell surya_table DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected table cells + rows and columns (optional) - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. - `--detect_boxes` specifies if cells should be detected. By default, they're pulled out of the PDF, but this is not always possible. - `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: - `rows` - detected table rows - `bbox` - the bounding box of the table row - `row_id` - the id of the row - `is_header` - if it is a header row. - `cols` - detected table columns - `bbox` - the bounding box of the table column - `col_id`- the id of the column - `is_header` - if it is a header column - `cells` - detected table cells - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. - `text` - if text could be pulled out of the pdf, the text of this cell. - `row_id` - the id of the row the cell belongs to. - `col_id` - the id of the column the cell belongs to. - `colspan` - the number of columns spanned by the cell. - `rowspan` - the number of rows spanned by the cell. - `is_header` - whether it is a header cell. - `page` - the page number in the file - `table_idx` - the index of the table on the page (sorted in vertical order) - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. **Performance tips** Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default is a batch size `64`, which will use about 10GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `8`. ### From python ```python from PIL import Image from surya.table_rec import TableRecPredictor image = Image.open(IMAGE_PATH) table_rec_predictor = TableRecPredictor() table_predictions = table_rec_predictor([image]) ``` ## LaTeX OCR This command will write out a json file with the LaTeX of the equations. You must pass in images that are already cropped to the equations. You can do this by running the layout model, then cropping, if you want. ```shell surya_latex_ocr DATA_PATH ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--output_dir` specifies the directory to save results to instead of the default - `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. See the OCR section above for the format of the output. 
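If you want to automate the cropping step mentioned above, one option is to run the layout model first and crop out its equation blocks. The snippet below is a minimal sketch, not the official pipeline: it assumes equation regions are labeled `Equation` (that is the mapping in `surya/layout/label.py`; the layout docs above list `Formula`, so check what your version emits) and that each layout box exposes a `bbox` in (x1, y1, x2, y2) format, as in the JSON output. The resulting crops can then be passed to the predictor shown in the Python example below.

```python
from PIL import Image

from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.settings import settings

image = Image.open(IMAGE_PATH)

# Run layout analysis (same setup as the layout example above)
layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))
layout_result = layout_predictor([image])[0]

# Keep only equation blocks - "Equation" is assumed here, some versions may emit "Formula"
equation_boxes = [box.bbox for box in layout_result.bboxes if box.label == "Equation"]

# Crop each equation region; these crops can be fed to the LaTeX OCR predictor below
equation_images = [image.crop(tuple(bbox)) for bbox in equation_boxes]
```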
### From python ```python from PIL import Image from surya.texify import TexifyPredictor image = Image.open(IMAGE_PATH) predictor = TexifyPredictor() predictor([image]) ``` ### Interactive app You can also run a special interactive app that lets you select equations and OCR them (kind of like MathPix snip) with: ```shell pip install streamlit==1.40 streamlit-drawable-canvas-jsretry texify_gui ``` ## Compilation The following models have support for compilation. You will need to set the following environment variables to enable compilation: - Detection: `COMPILE_DETECTOR=true` - Layout: `COMPILE_LAYOUT=true` - Table recognition: `COMPILE_TABLE_REC=true` Alternatively, you can also set `COMPILE_ALL=true` which will compile all models. Here are the speedups on an A10 GPU: | Model | Time per page (s) | Compiled time per page (s) | Speedup (%) | | ----------------- | ----------------- | -------------------------- | ----------- | | Detection | 0.108808 | 0.10521 | 3.306742151 | | Layout | 0.27319 | 0.27063 | 0.93707676 | | Table recognition | 0.0219 | 0.01938 | 11.50684932 | # Limitations - This is specialized for document OCR. It will likely not work on photos or other images. - It is for printed text, not handwriting (though it may work on some handwriting). - The text detection model has trained itself to ignore advertisements. - You can find language support for OCR in `surya/recognition/languages.py`. Text detection, layout analysis, and reading order will work with any language. ## Troubleshooting If OCR isn't working properly: - Try increasing resolution of the image so the text is bigger. If the resolution is already very high, try decreasing it to no more than a `2048px` width. - Preprocessing the image (binarizing, deskewing, etc) can help with very old/blurry images. - You can adjust `DETECTOR_BLANK_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results. `DETECTOR_BLANK_THRESHOLD` controls the space between lines - any prediction below this number will be considered blank space. `DETECTOR_TEXT_THRESHOLD` controls how text is joined - any number above this is considered text. `DETECTOR_TEXT_THRESHOLD` should always be higher than `DETECTOR_BLANK_THRESHOLD`, and both should be in the 0-1 range. Looking at the heatmap from the debug output of the detector can tell you how to adjust these (if you see faint things that look like boxes, lower the thresholds, and if you see bboxes being joined together, raise the thresholds). # Manual install If you want to develop surya, you can install it manually: - `git clone https://github.com/VikParuchuri/surya.git` - `cd surya` - `poetry install` - installs main and dev dependencies - `poetry shell` - activates the virtual environment # Benchmarks ## OCR  | Model | Time per page (s) | Avg similarity (⬆) | |-----------|-------------------|--------------------| | surya | .62 | 0.97 | | tesseract | .45 | 0.88 | [Full language results](static/images/rec_acc_table.png) Tesseract is CPU-based, and surya is CPU or GPU. I tried to cost-match the resources used, so I used a 1xA6000 (48GB VRAM) for surya, and 28 CPU cores for Tesseract (same price on Lambda Labs/DigitalOcean). ### Google Cloud Vision I benchmarked OCR against Google Cloud vision since it has similar language coverage to Surya.  [Full language results](static/images/gcloud_full_langs.png) **Methodology** I measured normalized sentence similarity (0-1, higher is better) based on a set of real-world and synthetic pdfs. 
I sampled PDFs from common crawl, then filtered out the ones with bad OCR. I couldn't find PDFs for some languages, so I also generated simple synthetic PDFs for those. I used the reference line bboxes from the PDFs with both tesseract and surya, to just evaluate the OCR quality. For Google Cloud, I aligned the output from Google Cloud with the ground truth. I had to skip RTL languages since they didn't align well. ## Text line detection  | Model | Time (s) | Time per page (s) | precision | recall | |-----------|------------|---------------------|-------------|----------| | surya | 47.2285 | 0.094452 | 0.835857 | 0.960807 | | tesseract | 74.4546 | 0.290838 | 0.631498 | 0.997694 | Tesseract is CPU-based, and surya is CPU or GPU. I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU. This was the resource usage: - tesseract - 32 CPU cores, or 8 workers using 4 cores each - surya - 36 batch size, for 16GB VRAM usage **Methodology** Surya predicts line-level bboxes, while tesseract and others predict word-level or character-level. It's hard to find 100% correct datasets with line-level annotations. Merging bboxes can be noisy, so I chose not to use IoU as the metric for evaluation. I instead used coverage, which calculates: - Precision - how well the predicted bboxes cover ground truth bboxes - Recall - how well ground truth bboxes cover predicted bboxes First calculate coverage for each bbox, then add a small penalty for double coverage, since we want the detection to have non-overlapping bboxes. Anything with a coverage of 0.5 or higher is considered a match. Then we calculate precision and recall for the whole dataset. ## Layout analysis | Layout Type | precision | recall | |---------------|-------------|----------| | Image | 0.91265 | 0.93976 | | List | 0.80849 | 0.86792 | | Table | 0.84957 | 0.96104 | | Text | 0.93019 | 0.94571 | | Title | 0.92102 | 0.95404 | Time per image - .13 seconds on GPU (A10). **Methodology** I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/PubLayNet), which was not in the training data. I had to align publaynet labels with the surya layout labels. I was then able to find coverage for each layout type: - Precision - how well the predicted bboxes cover ground truth bboxes - Recall - how well ground truth bboxes cover predicted bboxes ## Reading Order 88% mean accuracy, and .4 seconds per image on an A10 GPU. See methodology for notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check. **Methodology** I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth. The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct. ## Table Recognition | Model | Row Intersection | Col Intersection | Time Per Image | |-------------------|--------------------|--------------------|------------------| | Surya | 1 | 0.98625 | 0.30202 | | Table transformer | 0.84 | 0.86857 | 0.08082 | Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions.
This benchmark is mostly a sanity check - there is a more rigorous one in [marker](https://www.github.com/VikParuchuri/marker) **Methodology** The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM. It has labeled rows and columns. After table recognition is run, the predicted rows and columns are compared to the ground truth. There is an additional penalty for predicting too many or too few rows/columns. ## LaTeX OCR | Method | edit ⬇ | time taken (s) ⬇ | |--------|----------|------------------| | texify | 0.122617 | 35.6345 | This inferences texify on a ground truth set of LaTeX, then does edit distance. This is a bit noisy, since 2 LaTeX strings that render the same can have different symbols in them. ## Running your own benchmarks You can benchmark the performance of surya on your machine. - Follow the manual install instructions above. - `poetry install --group dev` - installs dev dependencies **Text line detection** This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench). ```shell python benchmark/detection.py --max_rows 256 ``` - `--max_rows` controls how many images to process for the benchmark - `--debug` will render images and detected bboxes - `--pdf_path` will let you specify a pdf to benchmark instead of the default data - `--results_dir` will let you specify a directory to save results to instead of the default one **Text recognition** This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages). ```shell python benchmark/recognition.py --tesseract ``` - `--max_rows` controls how many images to process for the benchmark - `--debug 2` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one - `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder. - Set `RECOGNITION_BATCH_SIZE=864` to use the same batch size as the benchmark. - Set `RECOGNITION_BENCH_DATASET_NAME=vikp/rec_bench_hist` to use the historical document data for benchmarking. This data comes from the [tapuscorpus](https://github.com/HTR-United/tapuscorpus). **Layout analysis** This will evaluate surya on the publaynet dataset. 
```shell python benchmark/layout.py ``` - `--max_rows` controls how many images to process for the benchmark - `--debug` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one **Reading Order** ```shell python benchmark/ordering.py ``` - `--max_rows` controls how many images to process for the benchmark - `--debug` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one **Table Recognition** ```shell python benchmark/table_recognition.py --max_rows 1024 --tatr ``` - `--max_rows` controls how many images to process for the benchmark - `--debug` will render images with detected text - `--results_dir` will let you specify a directory to save results to instead of the default one - `--tatr` specifies whether to also run table transformer **LaTeX OCR** ```shell python benchmark/texify.py --max_rows 128 ``` - `--max_rows` controls how many images to process for the benchmark - `--results_dir` will let you specify a directory to save results to instead of the default one # Training Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation. Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes). # Finetuning Surya OCR You can now take Surya OCR further by training it on your own data with our [finetuning script](/surya/scripts/finetune_ocr.py). It’s built on Hugging Face Trainer, and supports all the [arguments](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) that the huggingface trainer provides, and integrations like torchrun, or deepspeed. To setup your dataset, follow the example dataset format [here](https://huggingface.co/datasets/datalab-to/ocr_finetune_example) and provide the path to your own dataset when launching the training script. ```bash # Tested on 1xH100 GPU # Set --pretrained_checkpoint_path to load from a custom checkpoint, otherwise # the default surya ocr weights will be loaded as the initialization python surya/scripts/finetune_ocr.py \ --output_dir $OUTPUT_DIR \ --dataset_name datalab-to/ocr_finetune_example \ --per_device_train_batch_size 64 \ --gradient_checkpointing true \ --max_sequence_length 1024 ``` This is a minimal training script to get you started finetuning Surya. Our internal training stack includes character bounding box finetuning, sliding window attention with specialized attention masks, custom kernels, augmentations, and other optimizations that can push OCR accuracy well beyond standard finetuning. If you want to get the most out of your data, reach us at [email protected]! # Thanks This work would not have been possible without amazing open source AI work: - [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA - [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT - [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman - [Donut](https://github.com/clovaai/donut) from Naver - [transformers](https://github.com/huggingface/transformers) from huggingface - [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model Thank you to everyone who makes open source AI possible. 
# Citation If you use surya (or the associated models) in your work or research, please consider citing us using the following BibTeX entry: ```bibtex @misc{paruchuri2025surya, author = {Vikas Paruchuri and Datalab Team}, title = {Surya: A lightweight document OCR and analysis toolkit}, year = {2025}, howpublished = {\url{https://github.com/VikParuchuri/surya}}, note = {GitHub repository}, } ``` -------------------------------------------------------------------------------- /benchmark/utils/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /surya/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /surya/detection/model/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /surya/foundation/cache/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /surya/ocr_error/model/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /surya/scripts/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /surya/table_rec/model/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /surya/common/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /ocr_text.py: -------------------------------------------------------------------------------- ```python from surya.scripts.ocr_text import ocr_text_cli if __name__ == "__main__": ocr_text_cli() ``` -------------------------------------------------------------------------------- /ocr_latex.py: -------------------------------------------------------------------------------- ```python from surya.scripts.ocr_latex import ocr_latex_cli if __name__ == "__main__": ocr_latex_cli() ``` -------------------------------------------------------------------------------- /texify_app.py: -------------------------------------------------------------------------------- ```python from surya.scripts.run_texify_app import texify_app_cli if __name__ == "__main__": texify_app_cli() ``` -------------------------------------------------------------------------------- /detect_layout.py: -------------------------------------------------------------------------------- ```python from surya.scripts.detect_layout import detect_layout_cli if __name__ == "__main__": detect_layout_cli() ``` -------------------------------------------------------------------------------- /detect_text.py: -------------------------------------------------------------------------------- ```python from surya.scripts.detect_text import detect_text_cli if __name__ == "__main__": detect_text_cli() ``` 
-------------------------------------------------------------------------------- /ocr_app.py: -------------------------------------------------------------------------------- ```python from surya.scripts.run_streamlit_app import streamlit_app_cli if __name__ == "__main__": streamlit_app_cli() ``` -------------------------------------------------------------------------------- /table_recognition.py: -------------------------------------------------------------------------------- ```python from surya.scripts.table_recognition import table_recognition_cli if __name__ == "__main__": table_recognition_cli() ``` -------------------------------------------------------------------------------- /surya/ocr_error/schema.py: -------------------------------------------------------------------------------- ```python from typing import List from pydantic import BaseModel class OCRErrorDetectionResult(BaseModel): texts: List[str] labels: List[str] ``` -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- ``` [pytest] testpaths=tests pythonpath=. filterwarnings = ignore::UserWarning ignore::PendingDeprecationWarning ignore::DeprecationWarning ``` -------------------------------------------------------------------------------- /surya/detection/schema.py: -------------------------------------------------------------------------------- ```python from typing import List, Optional, Any from pydantic import BaseModel from surya.common.polygon import PolygonBox class TextDetectionResult(BaseModel): bboxes: List[PolygonBox] heatmap: Optional[Any] affinity_map: Optional[Any] image_bbox: List[float] ``` -------------------------------------------------------------------------------- /surya/scripts/run_texify_app.py: -------------------------------------------------------------------------------- ```python import subprocess import os def texify_app_cli(): cur_dir = os.path.dirname(os.path.abspath(__file__)) ocr_app_path = os.path.join(cur_dir, "texify_app.py") cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"] subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) ``` -------------------------------------------------------------------------------- /surya/scripts/run_streamlit_app.py: -------------------------------------------------------------------------------- ```python import subprocess import os def streamlit_app_cli(): cur_dir = os.path.dirname(os.path.abspath(__file__)) ocr_app_path = os.path.join(cur_dir, "streamlit_app.py") cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"] subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) ``` -------------------------------------------------------------------------------- /surya/common/surya/schema.py: -------------------------------------------------------------------------------- ```python class TaskNames: block_without_boxes = "block_without_boxes" ocr_with_boxes = "ocr_with_boxes" ocr_without_boxes = "ocr_without_boxes" layout = "layout" table_structure = "table_structure" TASK_NAMES = [ TaskNames.block_without_boxes, TaskNames.ocr_with_boxes, TaskNames.ocr_without_boxes, TaskNames.layout, TaskNames.table_structure, ] ``` -------------------------------------------------------------------------------- /surya/layout/schema.py: -------------------------------------------------------------------------------- ```python from 
typing import Optional, Dict, List from pydantic import BaseModel from surya.common.polygon import PolygonBox class LayoutBox(PolygonBox): label: str position: int top_k: Optional[Dict[str, float]] = None class LayoutResult(BaseModel): bboxes: List[LayoutBox] image_bbox: List[float] sliced: bool = False # Whether the image was sliced and reconstructed ``` -------------------------------------------------------------------------------- /surya/detection/parallel.py: -------------------------------------------------------------------------------- ```python class FakeFuture: def __init__(self, func, *args, **kwargs): self._result = func(*args, **kwargs) def result(self): return self._result class FakeExecutor: def __init__(self, **kwargs): pass def __enter__(self): return self def __exit__(self, *excinfo): pass def submit(self, fn, *args, **kwargs): return FakeFuture(fn, *args, **kwargs) ``` -------------------------------------------------------------------------------- /tests/test_layout.py: -------------------------------------------------------------------------------- ```python def test_layout_topk(layout_predictor, test_image): layout_results = layout_predictor([test_image]) assert len(layout_results) == 1 assert layout_results[0].image_bbox == [0, 0, 1024, 1024] bboxes = layout_results[0].bboxes assert len(bboxes) == 2 assert bboxes[0].label == "SectionHeader" assert len(bboxes[0].top_k) == 5 assert bboxes[1].label == "Text" assert len(bboxes[1].top_k) == 5 ``` -------------------------------------------------------------------------------- /tests/test_foundation.py: -------------------------------------------------------------------------------- ```python from surya.foundation import FoundationPredictor def test_foundation_flash2(): try: f = FoundationPredictor(None, None, None, "flash_attention_2") assert f.model.decoder.config._attn_implementation == "flash_attention_2" assert f.model.vision_encoder.config._attn_implementation == "flash_attention_2" except Exception as e: assert False, ( f"FoundationPredictor with flash_attention_2 raised an exception: {e}" ) ``` -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- ```yaml name: Unit tests on: [push] jobs: build: runs-on: ${{ matrix.os }} strategy: matrix: os: [t4_gpu, ubuntu-latest, windows-latest] steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 - name: Install python dependencies run: | pip install poetry poetry install - name: Run tests run: poetry run pytest ``` -------------------------------------------------------------------------------- /surya/layout/label.py: -------------------------------------------------------------------------------- ```python LAYOUT_PRED_RELABEL = { "<page-header>": "PageHeader", "<page-footer>": "PageFooter", "<footnote>": "Footnote", "<image>": "Picture", "<figure>": "Figure", "<text>": "Text", "<caption>": "Caption", "<list-item>": "ListItem", "<section-header>": "SectionHeader", "<table>": "Table", "<table-of-contents>": "TableOfContents", "<form>": "Form", "<equation-block>": "Equation", "<code-block>": "Code", "<complex-block>": "Figure", } ``` -------------------------------------------------------------------------------- /tests/test_ocr_errors.py: -------------------------------------------------------------------------------- ```python def test_garbled_text(ocr_error_predictor): text = 
"""" ; dh vksj ls mifLFkr vf/koDrk % Jh vfuy dqekj 2. vfHk;qDr dh vksj ls mifLFkr vf/koDrk % Jh iznhi d """.strip() results = ocr_error_predictor([text]) assert results.labels[0] == "bad" def test_good_text(ocr_error_predictor): text = """" There are professions more harmful than industrial design, but only a very few of them. """.strip() results = ocr_error_predictor([text]) assert results.labels[0] == "good" ``` -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- ```markdown --- name: Feature request about: Suggest an idea for this project title: "[FEAT]" labels: enhancement assignees: '' --- ## ✨ Is your feature request related to a problem? A clear and concise description of what the problem is. ## 💡 Describe the Solution You'd Like A concise description of what you want to happen or how you envision it working. ## 📋 Alternatives Considered Any alternative solutions or workarounds you've tried. ## 🧩 Additional Context Any additional context, references, or related issues. ``` -------------------------------------------------------------------------------- /surya/common/xla.py: -------------------------------------------------------------------------------- ```python import math from surya.settings import settings if settings.TORCH_DEVICE_MODEL == "xla": import torch_xla.core.xla_model as xm else: xm = None def get_nearest_pad( length: int, pad_multiple: int = settings.FOUNDATION_PAD_TO_NEAREST ): return math.ceil(length / pad_multiple) * pad_multiple def get_compile_args(device: str) -> dict: if not settings.FOUNDATION_XLA: return {} return { "backend": "openxla", } def mark_step(): if xm is not None: xm.mark_step() ``` -------------------------------------------------------------------------------- /tests/test_latex_ocr.py: -------------------------------------------------------------------------------- ```python from typing import List from PIL import Image, ImageDraw from surya.common.surya.schema import TaskNames from surya.recognition import OCRResult def test_latex_ocr(recognition_predictor, test_image_latex): width, height = test_image_latex.size results: List[OCRResult] = recognition_predictor( [test_image_latex], [TaskNames.block_without_boxes], bboxes=[[[0, 0, width, height]]] ) text = results[0].text_lines[0].text assert len(results) == 1 assert text.startswith("<math") assert text.endswith("</math>") ``` -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- ```yaml name: Python package on: push: tags: - "v*.*.*" jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 - name: Install python dependencies run: | pip install poetry poetry install - name: Build package run: | poetry build - name: Publish package env: PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} run: | poetry config pypi-token.pypi "$PYPI_TOKEN" poetry publish ``` -------------------------------------------------------------------------------- /tests/test_detection.py: -------------------------------------------------------------------------------- ```python def test_detection(detection_predictor, test_image): detection_results = detection_predictor([test_image]) assert len(detection_results) == 1 assert 
detection_results[0].image_bbox == [0, 0, 1024, 1024] bboxes = detection_results[0].bboxes assert len(bboxes) == 4 def test_detection_chunking(detection_predictor, test_image_tall): detection_results = detection_predictor([test_image_tall]) assert len(detection_results) == 1 assert detection_results[0].image_bbox == [0, 0, 4096, 4096] bboxes = detection_results[0].bboxes assert len(bboxes) >= 3 # Sometimes merges into 3 assert abs(4000 - bboxes[1].polygon[0][0]) < 50 ``` -------------------------------------------------------------------------------- /surya/common/load.py: -------------------------------------------------------------------------------- ```python from typing import Optional, Any import torch from surya.settings import settings class ModelLoader: def __init__(self, checkpoint: Optional[str] = None): self.checkpoint = checkpoint def model( self, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE, attention_implementation: Optional[str] = None, ) -> Any: raise NotImplementedError() def processor( self, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE, ) -> Any: raise NotImplementedError() ``` -------------------------------------------------------------------------------- /surya/common/surya/processor/schema.py: -------------------------------------------------------------------------------- ```python from typing import TypedDict, Literal, List, Tuple import torch from PIL import Image class TaskDict(TypedDict): datasets: List[str] img_size: Tuple[int, int] class TasksDict(TypedDict): ocr_with_boxes: TaskDict ocr_without_boxes: TaskDict block_without_boxes: TaskDict class ProcessorInput(TypedDict): type: Literal["image", "ocr", "text", "empty_output"] class ImageInput(ProcessorInput): type: Literal["image"] image: Image.Image rotated: bool class TextInput(ProcessorInput): type: Literal["text"] text: str math: bool class ProcessorOutput(TypedDict): input_ids: List[int] image_tiles: torch.Tensor | None grid_thw: torch.Tensor | None ``` -------------------------------------------------------------------------------- /surya/logging.py: -------------------------------------------------------------------------------- ```python import logging import warnings from surya.settings import settings def configure_logging(): logger = get_logger() # Remove any existing handlers to prevent duplicates for handler in logger.handlers[:]: logger.removeHandler(handler) # Add our handler handler = logging.StreamHandler() formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) # Prevent propagation to parent loggers to avoid double logging logger.propagate = False logger.setLevel(settings.LOGLEVEL) warnings.simplefilter(action="ignore", category=FutureWarning) def get_logger(): return logging.getLogger("surya") ``` -------------------------------------------------------------------------------- /surya/common/pretrained.py: -------------------------------------------------------------------------------- ```python from typing import Optional from transformers import PreTrainedModel from transformers.utils import is_flash_attn_2_available class SuryaPreTrainedModel(PreTrainedModel): # No-op if we pass attention, so we can set attention however we want in the config def _check_and_adjust_attn_implementation( self, attn_implementation: Optional[str], **kwargs ): if 
attn_implementation is None: try: self._sdpa_can_dispatch(True) attn_implementation = "sdpa" except (ValueError, ImportError): attn_implementation = "eager" if self._supports_flash_attn and is_flash_attn_2_available(): attn_implementation = "flash_attention_2" return attn_implementation ``` -------------------------------------------------------------------------------- /surya/debug/fonts.py: -------------------------------------------------------------------------------- ```python from typing import List, Optional import os import requests from surya.settings import settings def get_font_path(langs: Optional[List[str]] = None) -> str: font_path = settings.RECOGNITION_RENDER_FONTS["all"] if langs is not None: for k in settings.RECOGNITION_RENDER_FONTS: if k in langs and len(langs) == 1: font_path = settings.RECOGNITION_RENDER_FONTS[k] break if not os.path.exists(font_path): os.makedirs(os.path.dirname(font_path), exist_ok=True) font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}" with requests.get(font_dl_path, stream=True) as r, open(font_path, 'wb') as f: r.raise_for_status() for chunk in r.iter_content(chunk_size=8192): f.write(chunk) return font_path ``` -------------------------------------------------------------------------------- /surya/recognition/schema.py: -------------------------------------------------------------------------------- ```python import math import numpy as np from typing import Optional, List from pydantic import BaseModel, field_validator from surya.common.polygon import PolygonBox class BaseChar(PolygonBox): text: str confidence: Optional[float] = 0 @field_validator("confidence", mode="before") @classmethod def validate_confidence(cls, v: float) -> float: if v is None: return 0 elif math.isnan(v) or np.isnan(v): return 0 return v class TextChar(BaseChar): bbox_valid: bool = True # This is false when the given bbox is not valid class TextWord(BaseChar): bbox_valid: bool = True class TextLine(BaseChar): chars: List[TextChar] # Individual characters in the line original_text_good: bool = False words: List[TextWord] | None = None class OCRResult(BaseModel): text_lines: List[TextLine] image_bbox: List[float] ``` -------------------------------------------------------------------------------- /surya/table_rec/schema.py: -------------------------------------------------------------------------------- ```python from typing import List from pydantic import BaseModel from surya.common.polygon import PolygonBox class TableCell(PolygonBox): row_id: int colspan: int within_row_id: int cell_id: int is_header: bool rowspan: int | None = None merge_up: bool = False merge_down: bool = False col_id: int | None = None text_lines: List[dict] | None = None @property def label(self): return f'Cell {self.cell_id} {self.rowspan}/{self.colspan}' class TableRow(PolygonBox): row_id: int is_header: bool @property def label(self): return f'Row {self.row_id}' class TableCol(PolygonBox): col_id: int is_header: bool @property def label(self): return f'Column {self.col_id}' class TableResult(BaseModel): cells: List[TableCell] unmerged_cells: List[TableCell] rows: List[TableRow] cols: List[TableCol] image_bbox: List[float] ``` -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/output-bug-report.md: -------------------------------------------------------------------------------- ```markdown --- name: Output bug report about: Create a report about poor output quality title: "[BUG: Output]" labels: 'bug: output' 
assignees: '' --- ## 📝 Describe the Output Issue A clear and concise description of the incorrect or unexpected output. ## 📄 Input Document Attach the PDF or input file used. ## 📤 Current Output Paste the Markdown or HTML that Marker generated: ````markdown Paste output here ````` ## ✅ Expected Output Describe or paste what you expected Marker to generate. ## ⚙️ Environment Please fill in all relevant details: * **Marker version**: * **Surya version**: * **Python version**: * **PyTorch version**: * **Transformers version**: * **Operating System**: ## 📟 Command or Code Used Paste the **exact bash command** or **Python code** you used to run Marker: <details> <summary>Click to expand</summary> ```bash # or Python code block your_command_here --with-flags ``` </details> ## 📎 Additional Context Any other relevant info, configs, or assumptions. ``` -------------------------------------------------------------------------------- /benchmark/utils/textract.py: -------------------------------------------------------------------------------- ```python import os from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm import traceback from surya.input.processing import slice_bboxes_from_image from surya.recognition import RecognitionPredictor def textract_ocr(extractor, img): try: document = extractor.detect_document_text(file_source=img) return [line.text for line in document.lines] except: traceback.print_exc() return [None] def textract_ocr_parallel(imgs, cpus=None): from textractor import Textractor # Optional dependency extractor = Textractor(profile_name='default') parallel_cores = min(len(imgs), RecognitionPredictor().get_batch_size()) if not cpus: cpus = os.cpu_count() parallel_cores = min(parallel_cores, cpus) with ThreadPoolExecutor(max_workers=parallel_cores) as executor: textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc="Running textract OCR") textract_text = list(textract_text) return textract_text ``` -------------------------------------------------------------------------------- /surya/models.py: -------------------------------------------------------------------------------- ```python from typing import Dict import torch from surya.common.predictor import BasePredictor from surya.detection import DetectionPredictor from surya.layout import LayoutPredictor from surya.logging import configure_logging from surya.ocr_error import OCRErrorPredictor from surya.foundation import FoundationPredictor from surya.recognition import RecognitionPredictor from surya.table_rec import TableRecPredictor from surya.settings import settings configure_logging() def load_predictors( device: str | torch.device | None = None, dtype: torch.dtype | str | None = None ) -> Dict[str, BasePredictor]: return { "layout": LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)), "ocr_error": OCRErrorPredictor(device=device, dtype=dtype), "recognition": RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT)), "detection": DetectionPredictor(device=device, dtype=dtype), "table_rec": TableRecPredictor(device=device, dtype=dtype), } ``` -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/breaking-bug-report.md: -------------------------------------------------------------------------------- ```markdown --- name: Breaking bug report about: Create a report about a breaking bug title: "[BUG: Breaking]" labels: 'bug: breaking' assignees: '' --- ## 
🧨 Describe the Bug A clear and concise description of the breaking issue (e.g., crash, OOM, exception, etc). ## 📄 Input Document Attach the PDF or input file that triggered the error. ## 📤 Output Trace / Stack Trace Paste the **complete** stack trace or error output, if available. <details> <summary>Click to expand</summary> ``` Paste stack trace here ``` </details> ## ⚙️ Environment Please fill in all relevant details: - **Marker version**: - **Surya version**: - **Python version**: - **PyTorch version**: - **Transformers version**: - **Operating System** (incl. container info if relevant): ## ✅ Expected Behavior What did you expect Marker to do? ## 📟 Command or Code Used Paste the **exact bash command** or **Python code** you used to run Marker: <details> <summary>Click to expand</summary> ```bash # or Python code block your_command_here --with-flags ``` </details> ## 📎 Additional Context Any other context that might help us debug this (e.g., CLI options, working directory, runtime settings). ``` -------------------------------------------------------------------------------- /surya/detection/util.py: -------------------------------------------------------------------------------- ```python import math from PIL import ImageOps from surya.settings import settings def get_total_splits(image_size, height): img_height = list(image_size)[1] max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT if img_height > max_height: num_splits = math.ceil(img_height / height) return num_splits return 1 def split_image(img, height): # This will not modify/return the original image - it will either crop, or copy the image img_height = list(img.size)[1] max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT if img_height > max_height: num_splits = math.ceil(img_height / height) splits = [] split_heights = [] for i in range(num_splits): top = i * height bottom = (i + 1) * height if bottom > img_height: bottom = img_height cropped = img.crop((0, top, img.size[0], bottom)) chunk_height = bottom - top if chunk_height < height: cropped = ImageOps.pad(cropped, (img.size[0], height), color=255, centering=(0, 0)) splits.append(cropped) split_heights.append(chunk_height) return splits, split_heights return [img.copy()], [img_height] ``` -------------------------------------------------------------------------------- /benchmark/utils/scoring.py: -------------------------------------------------------------------------------- ```python import math from typing import List from rapidfuzz import fuzz def overlap_score(pred_lines: List[str], reference_lines: List[str]): line_scores = [] line_weights = [] line_match = {} for i, pred_line in enumerate(pred_lines): max_score = 0 line_weight = 1 match = None for j, ref_line in enumerate(reference_lines): score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100 if score > max_score: max_score = score line_weight = math.sqrt(len(ref_line)) match = j line_scores.append(max_score) line_weights.append(line_weight) line_match[i] = match line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))] return line_scores, line_weights, line_match def overlap_score_exact(pred_lines: List[str], reference_lines: List[str]): line_scores = [] line_weights = [] assert len(pred_lines) == len(reference_lines) for i, (pred_line, ref_line) in enumerate(zip(pred_lines, reference_lines)): score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100 weight = math.sqrt(len(ref_line)) line_scores.append(score * weight) line_weights.append(weight) return line_scores, line_weights ``` 
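A minimal usage sketch for the scoring helpers above. This is not a file in the repository; it assumes the snippet is run from the repo root so that `benchmark.utils.scoring` is importable, and the weighted-average reduction is an assumption about how the benchmark scripts aggregate per-line scores.

```python
# Hypothetical usage of overlap_score from benchmark/utils/scoring.py (illustration only).
from benchmark.utils.scoring import overlap_score

pred_lines = ["Hello World", "This is a sentence of text."]
reference_lines = ["Hello World", "This is a sentence of text", "An unmatched reference line"]

# overlap_score returns per-line scores already multiplied by sqrt(len(best reference line)),
# the weights themselves, and a mapping from prediction index to matched reference index.
line_scores, line_weights, line_match = overlap_score(pred_lines, reference_lines)

# Dividing summed weighted scores by summed weights gives a length-weighted average in [0, 1].
avg_score = sum(line_scores) / max(sum(line_weights), 1e-6)
print(f"Weighted average overlap: {avg_score:.3f}")
print(f"Prediction-to-reference matches: {line_match}")
```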
-------------------------------------------------------------------------------- /.github/workflows/cla.yml: -------------------------------------------------------------------------------- ```yaml name: "Surya CLA Assistant" on: issue_comment: types: [created] pull_request_target: types: [opened,closed,synchronize] # explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings permissions: actions: write contents: write pull-requests: write statuses: write jobs: CLAAssistant: runs-on: ubuntu-latest steps: - name: "Surya CLA Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' uses: contributor-assistant/[email protected] env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # the below token should have repo scope and must be manually added by you in the repository's secret # This token is required only if you have configured to store the signatures in a remote repository/organization PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} with: path-to-signatures: 'signatures/version1/cla.json' path-to-document: 'https://github.com/VikParuchuri/surya/blob/master/CLA.md' # branch should not be protected branch: 'master' allowlist: VikParuchuri ``` -------------------------------------------------------------------------------- /.github/workflows/scripts.yml: -------------------------------------------------------------------------------- ```yaml name: Test CLI scripts on: [push] jobs: build: runs-on: t4_gpu steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 - name: Install python dependencies run: | pip install poetry poetry install - name: Download benchmark data run: | wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi" unzip -o benchmark_data.zip - name: Test detection run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test OCR env: RECOGNITION_MAX_TOKENS: 25 run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test layout run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test table run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test texify env: TEXIFY_MAX_TOKENS: 25 run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0 - name: Test detection folder run: poetry run surya_detect benchmark_data/pdfs --page_range 0 ``` -------------------------------------------------------------------------------- /surya/common/surya/encoder/config.py: -------------------------------------------------------------------------------- ```python from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) class SuryaEncoderConfig(PretrainedConfig): model_type = "qwen2_5_vl" base_config_key = "vision_config" attribute_map = { "num_attention_heads": "num_heads", "num_hidden_layers": "depth", } def __init__( self, depth=8, hidden_size=1280, hidden_act="silu", intermediate_size=3420, num_heads=16, in_channels=3, patch_size=14, spatial_merge_size=2, spatial_patch_size=14, temporal_patch_size=1, tokens_per_second=4, window_size=112, out_hidden_size=1280, fullatt_block_indexes=(3, 7), initializer_range=0.02, image_size=4096, **kwargs, ): 
super().__init__(**kwargs) self.depth = depth self.hidden_size = hidden_size self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.num_heads = num_heads self.in_channels = in_channels self.patch_size = patch_size self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size self.tokens_per_second = tokens_per_second self.window_size = window_size self.fullatt_block_indexes = fullatt_block_indexes self.out_hidden_size = out_hidden_size self.initializer_range = initializer_range self.spatial_patch_size = spatial_patch_size self.image_size = image_size ``` -------------------------------------------------------------------------------- /surya/detection/model/config.py: -------------------------------------------------------------------------------- ```python from transformers import PretrainedConfig from surya.common.s3 import S3DownloaderMixin class EfficientViTConfig(S3DownloaderMixin, PretrainedConfig): r""" ```""" model_type = "efficientvit" def __init__( self, num_classes=2, num_channels=3, widths=(32, 64, 128, 256, 512), head_dim=32, num_stages=4, depths=(1, 1, 1, 6, 6), strides=(2, 2, 2, 2, 2), hidden_sizes=(32, 64, 160, 256), patch_size=(7, 7), hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, classifier_dropout_prob=0.0, layer_norm_eps=1e-6, decoder_layer_hidden_size=128, decoder_hidden_size=512, semantic_loss_ignore_index=255, initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) self.num_classes = num_classes self.widths = widths self.head_dim = head_dim self.num_channels = num_channels self.num_stages = num_stages self.depths = depths self.strides = strides self.hidden_sizes = hidden_sizes self.patch_size = patch_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.classifier_dropout_prob = classifier_dropout_prob self.layer_norm_eps = layer_norm_eps self.decoder_hidden_size = decoder_hidden_size self.decoder_layer_hidden_size = decoder_layer_hidden_size self.semantic_loss_ignore_index = semantic_loss_ignore_index self.initializer_range = initializer_range ``` -------------------------------------------------------------------------------- /tests/test_table_rec.py: -------------------------------------------------------------------------------- ```python from PIL import Image, ImageDraw def test_table_rec(table_rec_predictor): data = [ ["Name", "Age", "City"], ["Alice", 25, "New York"], ["Bob", 30, "Los Angeles"], ["Charlie", 35, "Chicago"], ] test_image = draw_table(data) results = table_rec_predictor([test_image]) assert len(results) == 1 assert results[0].image_bbox == [0, 0, test_image.size[0], test_image.size[1]] cells = results[0].cells assert len(cells) == 12 for row_id in range(4): for col_id in range(3): cell = [c for c in cells if c.row_id == row_id and c.col_id == col_id] assert len(cell) == 1, f"Missing cell at row {row_id}, col {col_id}" def draw_table(data, cell_width=100, cell_height=40): rows = len(data) cols = len(data[0]) width = cols * cell_width height = rows * cell_height image = Image.new('RGB', (width, height), 'white') draw = ImageDraw.Draw(image) for i in range(rows + 1): y = i * cell_height draw.line([(0, y), (width, y)], fill='black', width=1) for i in range(cols + 1): x = i * cell_width draw.line([(x, 0), (x, height)], fill='black', width=1) for i in range(rows): for j in range(cols): text = str(data[i][j]) text_bbox = draw.textbbox((0, 0), text) text_width = text_bbox[2] - text_bbox[0] text_height = text_bbox[3] - 
text_bbox[1] x = j * cell_width + (cell_width - text_width) // 2 y = i * cell_height + (cell_height - text_height) // 2 draw.text((x, y), text, fill='black') return image ``` -------------------------------------------------------------------------------- /benchmark/utils/bbox.py: -------------------------------------------------------------------------------- ```python import fitz as pymupdf from surya.common.util import rescale_bbox def get_pdf_lines(pdf_path, img_sizes): doc = pymupdf.open(pdf_path) page_lines = [] for idx, img_size in enumerate(img_sizes): page = doc[idx] blocks = page.get_text("dict", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)["blocks"] line_boxes = [] for block_idx, block in enumerate(blocks): for l in block["lines"]: line_boxes.append(list(l["bbox"])) page_box = page.bound() pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1] line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes] page_lines.append(line_boxes) return page_lines def merge_boxes(box1, box2): return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3])) def join_lines(bboxes, max_gap=5): to_merge = {} for i, box1 in bboxes: for z, box2 in bboxes[i + 1:]: j = i + z + 1 if box1 == box2: continue if box1[0] <= box2[0] and box1[2] >= box2[2]: if abs(box1[1] - box2[3]) <= max_gap: if i not in to_merge: to_merge[i] = [] to_merge[i].append(j) merged_boxes = set() merged = [] for i, box in bboxes: if i in merged_boxes: continue if i in to_merge: for j in to_merge[i]: box = merge_boxes(box, bboxes[j][1]) merged_boxes.add(j) merged.append(box) return merged ``` -------------------------------------------------------------------------------- /surya/scripts/ocr_latex.py: -------------------------------------------------------------------------------- ```python import os import click import json import time from collections import defaultdict from surya.logging import configure_logging, get_logger from surya.scripts.config import CLILoader from surya.foundation import FoundationPredictor from surya.recognition import RecognitionPredictor from surya.common.surya.schema import TaskNames configure_logging() logger = get_logger() @click.command(help="OCR LaTeX equations.") @CLILoader.common_options def ocr_latex_cli(input_path: str, **kwargs): loader = CLILoader(input_path, kwargs, highres=True) foundation_predictor = FoundationPredictor() texify_predictor = RecognitionPredictor(foundation_predictor) tasks = [TaskNames.block_without_boxes] * len(loader.images) bboxes = [[[0, 0, image.width, image.height]] for image in loader.images] start = time.time() predictions_by_image = texify_predictor( loader.images, tasks, bboxes=bboxes, ) latex_predictions = [p.text_lines[0].text for p in predictions_by_image] if loader.debug: logger.debug(f"OCR took {time.time() - start:.2f} seconds") max_chars = max([len(latex) for latex in latex_predictions]) logger.debug(f"Max chars: {max_chars}") out_preds = defaultdict(list) for name, pred, image in zip(loader.names, latex_predictions, loader.images): out_pred = { "equation": pred, "page": len(out_preds[name]) + 1, } out_preds[name].append(out_pred) with open( os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" ) as f: json.dump(out_preds, f, ensure_ascii=False) logger.info(f"Wrote results to {loader.result_path}") ``` -------------------------------------------------------------------------------- 
/surya/foundation/util.py: -------------------------------------------------------------------------------- ```python from typing import List, Tuple import numpy as np import torch def detect_repeat_token(predicted_tokens: List[int], max_repeats: int = 40): if len(predicted_tokens) < max_repeats: return False # Detect repeats containing 1 or 2 tokens last_n = predicted_tokens[-max_repeats:] unique_tokens = len(set(last_n)) if unique_tokens > 5: return False return last_n[-unique_tokens:] == last_n[-unique_tokens * 2 : -unique_tokens] def prediction_to_polygon_batch( pred: torch.Tensor, img_sizes: List[Tuple[int, int]], bbox_scaler, skew_scaler, skew_min=0.001, ): img_sizes = torch.from_numpy(np.array(img_sizes, dtype=np.float32)).to( pred.device ) w_scale = (img_sizes[:, 1] / bbox_scaler)[:, None, None] h_scale = (img_sizes[:, 0] / bbox_scaler)[:, None, None] cx = pred[:, :, 0] cy = pred[:, :, 1] width = pred[:, :, 2] height = pred[:, :, 3] x1 = cx - width / 2 y1 = cy - height / 2 x2 = cx + width / 2 y2 = cy + height / 2 skew_x = torch.floor((pred[:, :, 4] - skew_scaler) / 2) skew_y = torch.floor((pred[:, :, 5] - skew_scaler) / 2) skew_x[torch.abs(skew_x) < skew_min] = 0 skew_y[torch.abs(skew_y) < skew_min] = 0 polygons_flat = torch.stack( [ x1 - skew_x, y1 - skew_y, x2 - skew_x, y1 + skew_y, x2 + skew_x, y2 + skew_y, x1 + skew_x, y2 - skew_y, ], dim=2, ) batch_size, seq_len, _ = pred.shape polygons = polygons_flat.view(batch_size, seq_len, 4, 2) polygons[:, :, :, 0] *= w_scale polygons[:, :, :, 1] *= h_scale return polygons ``` -------------------------------------------------------------------------------- /surya/scripts/detect_text.py: -------------------------------------------------------------------------------- ```python import click import copy import json import time from collections import defaultdict from surya.detection import DetectionPredictor from surya.debug.draw import draw_polys_on_image from surya.logging import configure_logging, get_logger from surya.scripts.config import CLILoader import os configure_logging() logger = get_logger() @click.command(help="Detect bboxes in an input file or folder (PDFs or image).") @CLILoader.common_options def detect_text_cli(input_path: str, **kwargs): loader = CLILoader(input_path, kwargs) det_predictor = DetectionPredictor() start = time.time() predictions = det_predictor(loader.images, include_maps=loader.debug) end = time.time() if loader.debug: logger.debug(f"Detection took {end - start} seconds") if loader.save_images: for idx, (image, pred, name) in enumerate( zip(loader.images, predictions, loader.names) ): polygons = [p.polygon for p in pred.bboxes] bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image)) bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_bbox.png")) if loader.debug: heatmap = pred.heatmap heatmap.save(os.path.join(loader.result_path, f"{name}_{idx}_heat.png")) predictions_by_page = defaultdict(list) for idx, (pred, name, image) in enumerate( zip(predictions, loader.names, loader.images) ): out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"]) out_pred["page"] = len(predictions_by_page[name]) + 1 predictions_by_page[name].append(out_pred) with open( os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" ) as f: json.dump(predictions_by_page, f, ensure_ascii=False) logger.info(f"Wrote results to {loader.result_path}") ``` -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- ```toml [tool.poetry] name = "surya-ocr" version = "0.17.0" description = "OCR, layout, reading order, and table recognition in 90+ languages" authors = ["Vik Paruchuri <[email protected]>"] readme = "README.md" license = "GPL-3.0-or-later" repository = "https://github.com/VikParuchuri/surya" keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"] packages = [ {include = "surya"} ] [tool.poetry.dependencies] python = "^3.10" transformers = ">=4.56.1" torch = "^2.7.0" pydantic = "^2.5.3" pydantic-settings = "^2.1.0" python-dotenv = "^1.0.0" pillow = "^10.2.0" pypdfium2 = "=4.30.0" filetype = "^1.2.0" click = "^8.1.8" platformdirs = "^4.3.6" opencv-python-headless = "==4.11.0.86" einops = "^0.8.1" pre-commit = "^4.2.0" [tool.poetry.group.dev.dependencies] jupyter = "^1.0.0" pytesseract = "^0.3.10" pymupdf = "^1.23.8" datasets = "^2.16.1" rapidfuzz = "^3.6.1" streamlit = "^1.31.0" pytest = "^8.3.4" pdftext = "^0.5.1" tabulate = "^0.9.0" [tool.poetry.scripts] surya_detect = "surya.scripts.detect_text:detect_text_cli" surya_ocr = "surya.scripts.ocr_text:ocr_text_cli" surya_layout = "surya.scripts.detect_layout:detect_layout_cli" surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli" surya_table = "surya.scripts.table_recognition:table_recognition_cli" surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli" texify_gui = "surya.scripts.run_texify_app:texify_app_cli" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [[tool.poetry.source]] name = "libtpu-releases" url = "https://storage.googleapis.com/libtpu-releases/index.html" priority = "supplemental" [[tool.poetry.source]] name = "libtpu-wheels" url = "https://storage.googleapis.com/libtpu-wheels/index.html" priority = "supplemental" [tool.poetry.group.xla] optional = true [tool.poetry.group.xla.dependencies] torch-xla = {version = "^2.4.1", extras = ["tpu"]} ``` -------------------------------------------------------------------------------- /.github/workflows/benchmarks.yml: -------------------------------------------------------------------------------- ```yaml name: Integration test on: [push] env: PYTHONIOENCODING: "utf-8" jobs: build: runs-on: t4_gpu steps: - uses: actions/checkout@v3 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: python-version: 3.11 - name: Install python dependencies run: | pip install poetry poetry install - name: Run detection benchmark test run: | poetry run python benchmark/detection.py --max_rows 2 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection - name: Run recognition benchmark test run: | poetry run python benchmark/recognition.py --max_rows 2 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition - name: Run layout benchmark test run: | poetry run python benchmark/layout.py --max_rows 5 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout - name: Run ordering benchmark run: | poetry run python benchmark/ordering.py --max_rows 5 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering - name: Run table recognition benchmark run: | poetry run python benchmark/table_recognition.py --max_rows 5 poetry run python benchmark/utils/verify_benchmark_scores.py 
results/benchmark/table_rec_bench/results.json --bench_type table_recognition - name: Run texify benchmark run: | poetry run python benchmark/texify.py --max_rows 5 poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify ``` -------------------------------------------------------------------------------- /surya/ocr_error/model/config.py: -------------------------------------------------------------------------------- ```python from collections import OrderedDict from typing import Mapping from transformers.configuration_utils import PretrainedConfig from transformers.onnx import OnnxConfig from surya.common.s3 import S3DownloaderMixin ID2LABEL = { 0: 'good', 1: 'bad' } class DistilBertConfig(S3DownloaderMixin, PretrainedConfig): model_type = "distilbert" attribute_map = { "hidden_size": "dim", "num_attention_heads": "n_heads", "num_hidden_layers": "n_layers", } def __init__( self, vocab_size=30522, max_position_embeddings=512, sinusoidal_pos_embds=False, n_layers=6, n_heads=12, dim=768, hidden_dim=4 * 768, dropout=0.1, attention_dropout=0.1, activation="gelu", initializer_range=0.02, qa_dropout=0.1, seq_classif_dropout=0.2, pad_token_id=0, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.sinusoidal_pos_embds = sinusoidal_pos_embds self.n_layers = n_layers self.n_heads = n_heads self.dim = dim self.hidden_dim = hidden_dim self.dropout = dropout self.attention_dropout = attention_dropout self.activation = activation self.initializer_range = initializer_range self.qa_dropout = qa_dropout self.seq_classif_dropout = seq_classif_dropout super().__init__(**kwargs, pad_token_id=pad_token_id) class DistilBertOnnxConfig(OnnxConfig): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: if self.task == "multiple-choice": dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} else: dynamic_axis = {0: "batch", 1: "sequence"} return OrderedDict( [ ("input_ids", dynamic_axis), ("attention_mask", dynamic_axis), ] ) ``` -------------------------------------------------------------------------------- /surya/debug/katex.js: -------------------------------------------------------------------------------- ```javascript <style> .katex-display-container { display: inline-block; max-width: 100%; overflow-x: auto; max-height: 100%; } .katex-inline-container { display: inline-block; max-width: 100%; overflow-x: auto; max-height: 100%; } </style> <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.js" onload="setTimeout(function() {renderMath()})" async></script> <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css"> <script> function htmlUnescape(escapedText) { const htmlEntities = { '&amp;': '&', '&lt;': '<', '&gt;': '>', '&quot;': '"', '&#39;': "'", '&nbsp;': ' ' }; return escapedText.replace(/&amp;|&lt;|&gt;|&quot;|&#39;|&nbsp;/g, match => htmlEntities[match]); } const renderMath = (function() { try { const mathElements = document.querySelectorAll('math'); mathElements.forEach(function(element) { let mathContent = element.innerHTML.trim(); mathContent = htmlUnescape(mathContent); const isDisplay = element.getAttribute('display') === 'block'; const container = document.createElement('span'); container.className = isDisplay ?
'katex-display-container' : 'katex-inline-container'; element.parentNode.insertBefore(container, element); try { katex.render(mathContent, container, { displayMode: isDisplay, throwOnError: false }); } catch (err) { console.error('KaTeX rendering error:', err); container.textContent = mathContent; // Fallback to raw text } element.parentNode.removeChild(element); }); console.log('Math rendering complete with', mathElements.length, 'expressions'); } catch (err) { console.error('Error in renderMath function:', err); } }); </script> ``` -------------------------------------------------------------------------------- /surya/scripts/detect_layout.py: -------------------------------------------------------------------------------- ```python import time import click import copy import json from collections import defaultdict from surya.foundation import FoundationPredictor from surya.layout import LayoutPredictor from surya.debug.draw import draw_polys_on_image from surya.logging import configure_logging, get_logger from surya.scripts.config import CLILoader from surya.settings import settings import os configure_logging() logger = get_logger() @click.command(help="Detect layout of an input file or folder (PDFs or image).") @CLILoader.common_options def detect_layout_cli(input_path: str, **kwargs): loader = CLILoader(input_path, kwargs) foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) layout_predictor = LayoutPredictor(foundation_predictor) start = time.time() layout_predictions = layout_predictor(loader.images) if loader.debug: logger.debug(f"Layout took {time.time() - start} seconds") if loader.save_images: for idx, (image, layout_pred, name) in enumerate( zip(loader.images, layout_predictions, loader.names) ): polygons = [p.polygon for p in layout_pred.bboxes] labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes] bbox_image = draw_polys_on_image( polygons, copy.deepcopy(image), labels=labels ) bbox_image.save( os.path.join(loader.result_path, f"{name}_{idx}_layout.png") ) predictions_by_page = defaultdict(list) for idx, (pred, name, image) in enumerate( zip(layout_predictions, loader.names, loader.images) ): out_pred = pred.model_dump() out_pred["page"] = len(predictions_by_page[name]) + 1 predictions_by_page[name].append(out_pred) with open( os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" ) as f: json.dump(predictions_by_page, f, ensure_ascii=False) logger.info(f"Wrote results to {loader.result_path}") ``` -------------------------------------------------------------------------------- /surya/ocr_error/loader.py: -------------------------------------------------------------------------------- ```python from typing import Optional import torch from surya.common.load import ModelLoader from surya.logging import get_logger from surya.ocr_error.model.config import DistilBertConfig from surya.ocr_error.model.encoder import DistilBertForSequenceClassification from surya.ocr_error.tokenizer import DistilBertTokenizer from surya.settings import settings logger = get_logger() class OCRErrorModelLoader(ModelLoader): def __init__(self, checkpoint: Optional[str] = None): super().__init__(checkpoint) if self.checkpoint is None: self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT def model( self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE, attention_implementation: Optional[str] = None, ) -> DistilBertForSequenceClassification: if device is None: device = settings.TORCH_DEVICE_MODEL if dtype is None: dtype = 
settings.MODEL_DTYPE config = DistilBertConfig.from_pretrained(self.checkpoint) model = ( DistilBertForSequenceClassification.from_pretrained( self.checkpoint, dtype=dtype, config=config, ) .to(device) .eval() ) if settings.COMPILE_ALL or settings.COMPILE_OCR_ERROR: torch._dynamo.config.cache_size_limit = 1 torch._dynamo.config.suppress_errors = False logger.info( f"Compiling detection model {self.checkpoint} from {DistilBertForSequenceClassification.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" ) compile_args = {"backend": "openxla"} if device == "xla" else {} model = torch.compile(model, **compile_args) return model def processor( self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE ) -> DistilBertTokenizer: return DistilBertTokenizer.from_pretrained(self.checkpoint) ``` -------------------------------------------------------------------------------- /benchmark/utils/verify_benchmark_scores.py: -------------------------------------------------------------------------------- ```python import json import click def verify_layout(data): scores = data["metrics"] for layout_type, metrics in scores.items(): if layout_type == "List": # Skip lists since none appear early on continue if metrics["precision"] <= 0.6 or metrics["recall"] <= 0.6: raise ValueError("Scores do not meet the required threshold") def verify_det(data): scores = data["metrics"]["surya"] if scores["precision"] <= 0.9 or scores["recall"] <= 0.9: raise ValueError("Scores do not meet the required threshold") def verify_rec(data): scores = data["surya"] if scores["avg_score"] <= 0.9: raise ValueError("Scores do not meet the required threshold") def verify_order(data): score = data["mean_accuracy"] if score < 0.75: raise ValueError("Scores do not meet the required threshold") def verify_table_rec(data): row_score = data["surya"]["mean_row_iou"] col_score = data["surya"]["mean_col_iou"] if row_score < 0.75 or col_score < 0.75: raise ValueError("Scores do not meet the required threshold") def verify_texify(data): edit_dist = data["scores"] if edit_dist > 0.2: raise ValueError("Scores do not meet the required threshold") @click.command(help="Verify benchmark scores") @click.argument("file_path", type=str) @click.option( "--bench_type", type=str, help="Type of benchmark to verify", default="detection" ) def main(file_path, bench_type): with open(file_path, "r") as file: data = json.load(file) if bench_type == "detection": verify_det(data) elif bench_type == "recognition": verify_rec(data) elif bench_type == "layout": verify_layout(data) elif bench_type == "ordering": verify_order(data) elif bench_type == "table_recognition": verify_table_rec(data) elif bench_type == "texify": verify_texify(data) else: raise ValueError("Invalid benchmark type") if __name__ == "__main__": main() ``` -------------------------------------------------------------------------------- /surya/debug/draw.py: -------------------------------------------------------------------------------- ```python from PIL import ImageDraw, ImageFont from surya.debug.fonts import get_font_path from surya.debug.text import get_text_size def draw_bboxes_on_image( bboxes, image, labels=None, label_font_size=10, color: str | list = "red" ): polys = [] for bb in bboxes: # Clockwise polygon poly = [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]] polys.append(poly) return draw_polys_on_image( polys, image, labels, label_font_size=label_font_size, color=color ) def draw_polys_on_image( corners, image, labels=None, 
box_padding=-1, label_offset=1, label_font_size=10, color: str | list = "red", ): draw = ImageDraw.Draw(image) font_path = get_font_path() label_font = ImageFont.truetype(font_path, label_font_size) for i in range(len(corners)): poly = corners[i] poly = [(int(p[0]), int(p[1])) for p in poly] draw.polygon( poly, outline=color[i] if isinstance(color, list) else color, width=1 ) if labels is not None: label = labels[i] text_position = ( min([p[0] for p in poly]) + label_offset, min([p[1] for p in poly]) + label_offset, ) text_size = get_text_size(label, label_font) box_position = ( text_position[0] - box_padding + label_offset, text_position[1] - box_padding + label_offset, text_position[0] + text_size[0] + box_padding + label_offset, text_position[1] + text_size[1] + box_padding + label_offset, ) try: draw.rectangle(box_position, fill="white") except Exception as e: print(f"Error drawing rectangle at {box_position}: {e}") continue draw.text( text_position, label, fill=color[i] if isinstance(color, list) else color, font=label_font, ) return image ``` -------------------------------------------------------------------------------- /surya/recognition/languages.py: -------------------------------------------------------------------------------- ```python CODE_TO_LANGUAGE = { "_math": "Math", "af": "Afrikaans", "am": "Amharic", "ar": "Arabic", "as": "Assamese", "az": "Azerbaijani", "be": "Belarusian", "bg": "Bulgarian", "bn": "Bengali", "br": "Breton", "bs": "Bosnian", "ca": "Catalan", "cs": "Czech", "cy": "Welsh", "da": "Danish", "de": "German", "el": "Greek", "en": "English", "eo": "Esperanto", "es": "Spanish", "et": "Estonian", "eu": "Basque", "fa": "Persian", "fi": "Finnish", "fr": "French", "fy": "Western Frisian", "ga": "Irish", "gd": "Scottish Gaelic", "gl": "Galician", "gu": "Gujarati", "ha": "Hausa", "he": "Hebrew", "hi": "Hindi", "hr": "Croatian", "hu": "Hungarian", "hy": "Armenian", "id": "Indonesian", "is": "Icelandic", "it": "Italian", "ja": "Japanese", "jv": "Javanese", "ka": "Georgian", "kk": "Kazakh", "km": "Khmer", "kn": "Kannada", "ko": "Korean", "ku": "Kurdish", "ky": "Kyrgyz", "la": "Latin", "lo": "Lao", "lt": "Lithuanian", "lv": "Latvian", "mg": "Malagasy", "mk": "Macedonian", "ml": "Malayalam", "mn": "Mongolian", "mr": "Marathi", "ms": "Malay", "my": "Burmese", "ne": "Nepali", "nl": "Dutch", "no": "Norwegian", "om": "Oromo", "or": "Oriya", "pa": "Punjabi", "pl": "Polish", "ps": "Pashto", "pt": "Portuguese", "ro": "Romanian", "ru": "Russian", "sa": "Sanskrit", "sd": "Sindhi", "si": "Sinhala", "sk": "Slovak", "sl": "Slovenian", "so": "Somali", "sq": "Albanian", "sr": "Serbian", "su": "Sundanese", "sv": "Swedish", "sw": "Swahili", "ta": "Tamil", "te": "Telugu", "th": "Thai", "tl": "Tagalog", "tr": "Turkish", "ug": "Uyghur", "uk": "Ukrainian", "ur": "Urdu", "uz": "Uzbek", "vi": "Vietnamese", "xh": "Xhosa", "yi": "Yiddish", "zh": "Chinese", } LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()} ``` -------------------------------------------------------------------------------- /surya/common/surya/embedder/__init__.py: -------------------------------------------------------------------------------- ```python import torch import torch.nn as nn import torch.nn.functional as F class SimpleTokenEmbedder(nn.Module): def __init__(self, config): super().__init__() self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size) self.bbox_embed = nn.ModuleList( [ nn.Embedding( config.bbox_size + config.special_token_count, config.bbox_embed_size, ) for _ in range(6) ] ) 
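# Added note (assumption): the six embedding tables line up with the six bbox fields used by the
# foundation model (center x/y, width, height, and two skew terms; compare prediction_to_polygon_batch
# in surya/foundation/util.py above). Their outputs are summed and added onto the token embeddings in embed() below.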
self.max_bbox_embedding = config.bbox_size + config.special_token_count - 1 self.max_bbox_size = config.bbox_size def embed( self, input_tokens: torch.Tensor, input_boxes: torch.Tensor | None, embed_boxes: torch.Tensor, ) -> torch.Tensor: # Embed tokens token_embeds = self.token_embed(input_tokens) # Optionally embed boxes if input_boxes is not None and embed_boxes.any(): # Is none in prefill input_boxes = input_boxes.to(torch.long) bbox_loss_ignore_mask = ( (input_boxes[:, :, 0] < 0) | (input_boxes[:, :, 0] > self.max_bbox_size) ).unsqueeze(-1) input_boxes = torch.clamp(input_boxes, 0, self.max_bbox_embedding) bbox_embeds = torch.sum( torch.stack( [ self.bbox_embed[i](input_boxes[:, :, i]) for i in range(len(self.bbox_embed)) ], dim=-1, ), dim=-1, ) bbox_embeds = F.pad( bbox_embeds, (token_embeds.shape[-1] - bbox_embeds.shape[-1], 0) ) embed_boxes = embed_boxes.unsqueeze(1).unsqueeze(1).expand_as(bbox_embeds) bbox_loss_ignore_mask = bbox_loss_ignore_mask.expand_as(bbox_embeds) mask = embed_boxes & ~bbox_loss_ignore_mask bbox_embeds *= mask.float() token_embeds = token_embeds + bbox_embeds return token_embeds ``` -------------------------------------------------------------------------------- /surya/detection/loader.py: -------------------------------------------------------------------------------- ```python from typing import Optional import torch from surya.common.load import ModelLoader from surya.detection.processor import SegformerImageProcessor from surya.detection.model.config import EfficientViTConfig from surya.detection.model.encoderdecoder import EfficientViTForSemanticSegmentation from surya.logging import get_logger from surya.settings import settings logger = get_logger() class DetectionModelLoader(ModelLoader): def __init__(self, checkpoint: Optional[str] = None): super().__init__(checkpoint) if self.checkpoint is None: self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT def model( self, device: Optional[torch.device | str] = None, dtype: Optional[torch.dtype | str] = None, attention_implementation: Optional[str] = None, ) -> EfficientViTForSemanticSegmentation: if device is None: device = settings.TORCH_DEVICE_MODEL if dtype is None: dtype = settings.MODEL_DTYPE config = EfficientViTConfig.from_pretrained(self.checkpoint) model = EfficientViTForSemanticSegmentation.from_pretrained( self.checkpoint, dtype=dtype, config=config, ) model = model.to(device) model = model.eval() if settings.COMPILE_ALL or settings.COMPILE_DETECTOR: torch._dynamo.config.cache_size_limit = 1 torch._dynamo.config.suppress_errors = False logger.info( f"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}" ) compile_args = {"backend": "openxla"} if device == "xla" else {} model = torch.compile(model, **compile_args) logger.debug( f"Loaded detection model {self.checkpoint} from {EfficientViTForSemanticSegmentation.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" ) return model def processor( self, device: Optional[torch.device | str] = None, dtype: Optional[torch.dtype | str] = None, ) -> SegformerImageProcessor: return SegformerImageProcessor.from_pretrained(self.checkpoint) ``` -------------------------------------------------------------------------------- /surya/scripts/hf_to_s3.py: -------------------------------------------------------------------------------- ```python import json import shutil import datetime from pathlib import Path import boto3 from huggingface_hub import snapshot_download import click from tqdm import tqdm S3_API_URL = 
"https://1afbe4656a6b40d982ab5e730a39f6b9.r2.cloudflarestorage.com" # Example usage - python scripts/hf_to_s3.py <REPO_NAME> layout # This will upload to s3://layout/TODAYS_DATE @click.command(help="Uploads the data from huggingface to an S3 bucket") @click.argument("hf_repo_id", type=str) @click.argument("s3_path", type=str) @click.option("--bucket_name", type=str, default="datalab") @click.option("--revision_hash", type=str, default=None) @click.option("--access_key_id", type=str, default="<access_key_id>") @click.option("--access_key_secret", type=str, default="<access_key_secret>") @click.option("--suffix", type=str, default="") def main( hf_repo_id: str, s3_path: str, bucket_name: str, revision_hash: str, access_key_id: str, access_key_secret: str, suffix: str, ): curr_date = datetime.datetime.now().strftime("%Y_%m_%d") s3_path = f"{s3_path}/{curr_date}" if suffix: s3_path = f"{s3_path}_{suffix}" download_folder = snapshot_download(repo_id=hf_repo_id, revision=revision_hash) download_folder = Path(download_folder) contained_files = list(download_folder.glob("*")) contained_files = [f.name for f in contained_files] # Just get the base name manifest_file = download_folder / "manifest.json" with open(manifest_file, "w") as f: json.dump({"files": contained_files}, f) # Upload the files to S3 s3_client = boto3.client( service_name="s3", endpoint_url=S3_API_URL, aws_access_key_id=access_key_id, aws_secret_access_key=access_key_secret, region_name="auto", ) # Iterate through all files in the folder for file_path in tqdm( download_folder.glob("*"), desc="Uploading files", unit="file" ): s3_key = f"{s3_path}/{file_path.name}" try: s3_client.upload_file(str(file_path), bucket_name, s3_key) except Exception as e: print(f"Error uploading {file_path}: {str(e)}") shutil.rmtree(download_folder) print(f"Uploaded files to {s3_path}") if __name__ == "__main__": main() ``` -------------------------------------------------------------------------------- /surya/ocr_error/__init__.py: -------------------------------------------------------------------------------- ```python import math from typing import List, Optional from tqdm import tqdm from surya.common.predictor import BasePredictor from surya.ocr_error.loader import OCRErrorModelLoader from surya.ocr_error.model.config import ID2LABEL from surya.ocr_error.schema import OCRErrorDetectionResult from surya.settings import settings from surya.common.xla import mark_step class OCRErrorPredictor(BasePredictor): model_loader_cls = OCRErrorModelLoader batch_size = settings.OCR_ERROR_BATCH_SIZE default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 64, "xla": 32} def __call__(self, texts: List[str], batch_size: Optional[int] = None): return self.batch_ocr_error_detection(texts, batch_size) def batch_ocr_error_detection( self, texts: List[str], batch_size: Optional[int] = None ): if batch_size is None: batch_size = self.get_batch_size() num_batches = math.ceil(len(texts) / batch_size) texts_processed = self.processor( texts, padding="longest", truncation=True, return_tensors="pt" ) predictions = [] for batch_idx in tqdm( range(num_batches), desc="Running OCR Error Detection", disable=self.disable_tqdm, ): start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to( self.model.device ) batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to( self.model.device ) # Pad to batch size current_batch_size = batch_input_ids.shape[0] if settings.OCR_ERROR_STATIC_CACHE: 
batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) batch_attention_mask = self.pad_to_batch_size( batch_attention_mask, batch_size ) with settings.INFERENCE_MODE(): pred = self.model(batch_input_ids, attention_mask=batch_attention_mask) logits = pred.logits.argmax(dim=1).cpu().tolist()[:current_batch_size] predictions.extend(logits) mark_step() return OCRErrorDetectionResult( texts=texts, labels=[ID2LABEL[p] for p in predictions] ) ``` -------------------------------------------------------------------------------- /surya/input/load.py: -------------------------------------------------------------------------------- ```python from typing import List import PIL from surya.input.processing import open_pdf, get_page_images from surya.logging import get_logger from surya.settings import settings import os import filetype from PIL import Image import json logger = get_logger() def get_name_from_path(path): return os.path.basename(path).split(".")[0] def load_pdf(pdf_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI): doc = open_pdf(pdf_path) last_page = len(doc) if page_range: assert all([0 <= page < last_page for page in page_range]), ( f"Invalid page range: {page_range}" ) else: page_range = list(range(last_page)) images = get_page_images(doc, page_range, dpi=dpi) doc.close() names = [get_name_from_path(pdf_path) for _ in page_range] return images, names def load_image(image_path): image = Image.open(image_path).convert("RGB") name = get_name_from_path(image_path) return [image], [name] def load_from_file( input_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI ): input_type = filetype.guess(input_path) if input_type and input_type.extension == "pdf": return load_pdf(input_path, page_range, dpi=dpi) else: return load_image(input_path) def load_from_folder( folder_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI ): image_paths = [ os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path) if not image_name.startswith(".") ] image_paths = [ip for ip in image_paths if not os.path.isdir(ip)] images = [] names = [] for path in image_paths: extension = filetype.guess(path) if extension and extension.extension == "pdf": image, name = load_pdf(path, page_range, dpi=dpi) images.extend(image) names.extend(name) else: try: image, name = load_image(path) images.extend(image) names.extend(name) except PIL.UnidentifiedImageError: logger.warning(f"Could not load image {path}") continue return images, names def load_lang_file(lang_path, names): with open(lang_path, "r") as f: lang_dict = json.load(f) return [lang_dict[name].copy() for name in names] ``` -------------------------------------------------------------------------------- /surya/scripts/ocr_text.py: -------------------------------------------------------------------------------- ```python import os import click import json import time from collections import defaultdict from surya.common.surya.schema import TaskNames from surya.detection import DetectionPredictor from surya.debug.text import draw_text_on_image from surya.logging import configure_logging, get_logger from surya.foundation import FoundationPredictor from surya.recognition import RecognitionPredictor from surya.scripts.config import CLILoader configure_logging() logger = get_logger() @click.command(help="OCR text.") @click.option("--task_name", type=str, default=TaskNames.ocr_with_boxes) @click.option( "--disable_math", is_flag=True, default=False, help="Do not recognize math in OCR." 
) @CLILoader.common_options def ocr_text_cli(input_path: str, task_name: str, disable_math: bool, **kwargs): loader = CLILoader(input_path, kwargs, highres=True) task_names = [task_name] * len(loader.images) foundation_predictor = FoundationPredictor() det_predictor = DetectionPredictor() rec_predictor = RecognitionPredictor(foundation_predictor) start = time.time() predictions_by_image = rec_predictor( loader.images, task_names=task_names, det_predictor=det_predictor, highres_images=loader.highres_images, math_mode=not disable_math, ) if loader.debug: logger.debug(f"OCR took {time.time() - start:.2f} seconds") max_chars = max( [len(line.text) for p in predictions_by_image for line in p.text_lines] ) logger.debug(f"Max chars: {max_chars}") if loader.save_images: for idx, (name, image, pred) in enumerate( zip(loader.names, loader.images, predictions_by_image) ): bboxes = [line.bbox for line in pred.text_lines] pred_text = [line.text for line in pred.text_lines] page_image = draw_text_on_image(bboxes, pred_text, image.size) page_image.save(os.path.join(loader.result_path, f"{name}_{idx}_text.png")) out_preds = defaultdict(list) for name, pred, image in zip(loader.names, predictions_by_image, loader.images): out_pred = pred.model_dump() out_pred["page"] = len(out_preds[name]) + 1 out_preds[name].append(out_pred) with open( os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" ) as f: json.dump(out_preds, f, ensure_ascii=False) logger.info(f"Wrote results to {loader.result_path}") ``` -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- ```python import os os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import pytest from PIL import Image, ImageDraw from surya.detection import DetectionPredictor from surya.ocr_error import OCRErrorPredictor from surya.layout import LayoutPredictor from surya.recognition import RecognitionPredictor from surya.foundation import FoundationPredictor from surya.table_rec import TableRecPredictor from surya.settings import settings @pytest.fixture(scope="session") def ocr_error_predictor() -> OCRErrorPredictor: ocr_error_predictor = OCRErrorPredictor() yield ocr_error_predictor del ocr_error_predictor @pytest.fixture(scope="session") def layout_predictor() -> LayoutPredictor: layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)) yield layout_predictor del layout_predictor @pytest.fixture(scope="session") def detection_predictor() -> DetectionPredictor: detection_predictor = DetectionPredictor() yield detection_predictor del detection_predictor @pytest.fixture(scope="session") def recognition_predictor() -> RecognitionPredictor: recognition_predictor = RecognitionPredictor(FoundationPredictor(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT)) yield recognition_predictor del recognition_predictor @pytest.fixture(scope="session") def table_rec_predictor() -> TableRecPredictor: table_rec_predictor = TableRecPredictor() yield table_rec_predictor del table_rec_predictor @pytest.fixture() def test_image(): image = Image.new("RGB", (1024, 1024), "white") draw = ImageDraw.Draw(image) draw.text((10, 10), "Hello World", fill="black", font_size=72) draw.text( (10, 200), "This is a sentence of text.\nNow it is a paragraph.\nA three-line one.", fill="black", font_size=24, ) return image @pytest.fixture() def test_image_tall(): image = Image.new("RGB", (4096, 4096), "white") draw = 
ImageDraw.Draw(image) draw.text((10, 10), "Hello World", fill="black", font_size=72) draw.text( (4000, 4000), "This is a sentence of text.\n\nNow it is a paragraph.\n\nA three-line one.", fill="black", font_size=24, ) return image @pytest.fixture() def test_image_latex(): assets_dir = os.path.join(os.path.dirname(__file__), "assets") img_path = os.path.join(assets_dir, "test_latex.png") image = Image.open(img_path).convert("RGB") return image ``` -------------------------------------------------------------------------------- /surya/debug/render_html.py: -------------------------------------------------------------------------------- ```python import html as htmllib import os.path import re filepath = os.path.abspath(__file__) def render_text_as_html( bboxes: list[list[int]], texts: list[str], image_size: tuple[int, int], base_font_size: int = 16, scaler: int = 2 ): katex_path = os.path.join(os.path.dirname(filepath), "katex.js") with open(katex_path, "r") as f: katex_script = f.read() html_content = [] image_size = tuple([int(s * scaler) for s in image_size]) width, height = image_size html_content.append(f""" <!DOCTYPE html> <html> <head> <style> body {{ margin: 0; padding: 0; width: {width}px; height: {height}px; position: relative; overflow: hidden; background: white; color: black; }} .text-box {{ position: absolute; overflow: hidden; display: flex; justify-content: left; font-family: Arial, sans-serif; white-space: pre-wrap; }} .vertical-text {{ writing-mode: vertical-rl; /* Top to bottom, right to left */ }} </style> {katex_script} </head> <body> """) for i, (bbox, text) in enumerate(zip(bboxes, texts)): bbox = bbox.copy() bbox = [int(bb * scaler) for bb in bbox] x1, y1, x2, y2 = bbox width = x2 - x1 height = y2 - y1 min_dim = min(width, height) # Scale font size based on box height font_size = min(int(min_dim * 0.75), base_font_size) # Create div with absolute positioning div_style = ( f"left: {x1}px; " f"top: {y1}px; " f"width: {width}px; " f"height: {height}px; " f"font-size: {font_size}px;" ) class_ = "text-box" if height > width * 2: class_ += " vertical-text" # Determine if content is HTML/MathML or plain text if "<" in text and ">" in text and re.search(r"<(html|math|div|sub|sup|i|u|mark|small|del|b|br|code)\b", text.lower()): # Content is already HTML/MathML, include as-is html_content.append(f'<span class="{class_}" id="box-{i}" style="{div_style}">{text}</span>') else: # Plain text, escape it escaped_text = htmllib.escape(text) html_content.append(f'<span class="{class_}" id="box-{i}" style="{div_style}">{escaped_text}</span>') html_content.append("</body></html>") return "\n".join(html_content), image_size ``` -------------------------------------------------------------------------------- /surya/common/predictor.py: -------------------------------------------------------------------------------- ```python from typing import Optional import torch import torch.nn.functional as F from surya.common.load import ModelLoader from surya.settings import settings class BasePredictor: model_loader_cls = ModelLoader batch_size: Optional[int] = None default_batch_sizes = {"cpu": 1, "mps": 1, "cuda": 1} torch_dtype = settings.MODEL_DTYPE @property def disable_tqdm(self) -> bool: return self._disable_tqdm @disable_tqdm.setter def disable_tqdm(self, value: bool) -> None: self._disable_tqdm = bool(value) def __init__( self, checkpoint: Optional[str] = None, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = None, attention_implementation: 
Optional[str] = None, ): if dtype is None: dtype = self.torch_dtype self.model = None self.processor = None loader = self.model_loader_cls(checkpoint) self.model = loader.model(device, dtype, attention_implementation) self.processor = loader.processor() self._disable_tqdm = settings.DISABLE_TQDM def to(self, device_dtype: torch.device | str | None = None): model_moved = False if hasattr(self, "model") and self.model: self.model.to(device_dtype) model_moved = True if hasattr(self, "foundation_predictor") and self.foundation_predictor: self.foundation_predictor.model.to(device_dtype) model_moved = True if not model_moved: raise ValueError("Model not loaded") def get_batch_size(self): batch_size = self.batch_size if batch_size is None: batch_size = self.default_batch_sizes["cpu"] if settings.TORCH_DEVICE_MODEL in self.default_batch_sizes: batch_size = self.default_batch_sizes[settings.TORCH_DEVICE_MODEL] return batch_size @staticmethod def pad_to_batch_size(tensor: torch.Tensor, batch_size: int): current_batch_size = tensor.shape[0] if current_batch_size >= batch_size: return tensor if len(tensor.shape) == 1: # If tensor is 1D, we need to pad it to the batch size pad_size = batch_size - current_batch_size return F.pad(tensor, (0, pad_size), mode="constant", value=0) pad_size = batch_size - current_batch_size padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) return F.pad(tensor, padding, mode="constant", value=0) def __call__(self, *args, **kwargs): raise NotImplementedError() ``` -------------------------------------------------------------------------------- /surya/scripts/config.py: -------------------------------------------------------------------------------- ```python from typing import List import click import os from surya.input.load import load_from_folder, load_from_file from surya.settings import settings class CLILoader: def __init__(self, filepath: str, cli_options: dict, highres: bool = False): self.page_range = cli_options.get("page_range") if self.page_range: self.page_range = self.parse_range_str(self.page_range) self.filepath = filepath self.config = cli_options self.save_images = cli_options.get("images", False) self.debug = cli_options.get("debug", False) self.output_dir = cli_options.get("output_dir") self.load(highres) @staticmethod def common_options(fn): fn = click.argument("input_path", type=click.Path(exists=True), required=True)(fn) fn = click.option("--output_dir", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, "surya"), help="Directory to save output.")(fn) fn = click.option("--page_range", type=str, default=None, help="Page range to convert, specify comma separated page numbers or ranges. 
Example: 0,5-10,20")(fn) fn = click.option("--images", is_flag=True, help="Save images of detected bboxes.", default=False)(fn) fn = click.option('--debug', '-d', is_flag=True, help='Enable debug mode.', default=False)(fn) return fn def load(self, highres: bool = False): highres_images = None if os.path.isdir(self.filepath): images, names = load_from_folder(self.filepath, self.page_range) folder_name = os.path.basename(self.filepath) if highres: highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) else: images, names = load_from_file(self.filepath, self.page_range) folder_name = os.path.basename(self.filepath).split(".")[0] if highres: highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) self.images = images self.highres_images = highres_images self.names = names self.result_path = os.path.abspath(os.path.join(self.output_dir, folder_name)) os.makedirs(self.result_path, exist_ok=True) @staticmethod def parse_range_str(range_str: str) -> List[int]: range_lst = range_str.split(",") page_lst = [] for i in range_lst: if "-" in i: start, end = i.split("-") page_lst += list(range(int(start), int(end) + 1)) else: page_lst.append(int(i)) page_lst = sorted(list(set(page_lst))) # Deduplicate page numbers and sort in order return page_lst ``` -------------------------------------------------------------------------------- /tests/test_recognition.py: -------------------------------------------------------------------------------- ```python import time from PIL import ImageDraw, Image from surya.recognition.util import clean_math_tags def test_recognition(recognition_predictor, detection_predictor, test_image): recognition_results = recognition_predictor([test_image], None, detection_predictor) assert len(recognition_results) == 1 assert recognition_results[0].image_bbox == [0, 0, 1024, 1024] text_lines = recognition_results[0].text_lines assert len(text_lines) == 4 assert "Hello World" in text_lines[0].text def test_recognition_input_text(recognition_predictor, detection_predictor, test_image): start = time.time() recognition_predictor([test_image], None, detection_predictor) end = time.time() - start input_text = "a" * 400 start2 = time.time() recognition_results = recognition_predictor( [test_image], None, detection_predictor, input_text=[input_text] ) end2 = time.time() - start2 assert max([end, end2]) / min([end, end2]) < 1.5, ( "Input text should be truncated and not change inference time" ) assert len(recognition_results) == 1 assert recognition_results[0].image_bbox == [0, 0, 1024, 1024] text_lines = recognition_results[0].text_lines assert len(text_lines) == 4 assert "Hello World" in text_lines[0].text def test_recognition_drop_repeats(recognition_predictor, detection_predictor): image = Image.new("RGB", (1024, 128), "white") draw = ImageDraw.Draw(image) text = "a" * 80 draw.text((5, 5), text, fill="black", font_size=24) recognition_results = recognition_predictor( [image], None, bboxes=[[[0, 0, 1024, 128]]], drop_repeated_text=True ) assert len(recognition_results) == 1 result = recognition_results[0].text_lines assert result[0].text == "" def test_recognition_clean_math(): math = """<math display="block">na_n^{1+2r} \\text{cov}(\\hat{f}_n^{(r)}(x), \\hat{f}_n^{(r)}(y)) = \\frac{1}{n} \\sum_{j=1}^n \\frac{a_n^{1+2r}}{a_j^{1+2r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_j}{a_j}\\right)\\right) <br>+ \\frac{a_n^{1+2r}}{n} \\sum_{\\substack{j \\neq k \\\\ 1 \\le j, k 
\\le n}} \\frac{1}{(a_j a_k)^{1+r}} \\text{cov}\\left(K^{(r)}\\left(\\frac{x-X_j}{a_j}\\right), K^{(r)}\\left(\\frac{y-X_k}{a_k}\\right)\\right) <br>=: I_1 + I_2.</math> (1.7)</math>'""" clean_math = clean_math_tags(math) assert clean_math.count("</math>") == 1, "Should have one closing math tag" assert "<br>" not in clean_math, "Should not have <br> tags in cleaned math" def test_recognition_clean_math_preserve_text(): text = """Hello, this is a sentence with <math display="inline">x^2 + y^2 = z^2</math> and some text after it, with a weird tag <hello> and <goodbye>.""" clean_text = clean_math_tags(text) assert clean_text == text ``` -------------------------------------------------------------------------------- /surya/input/processing.py: -------------------------------------------------------------------------------- ```python from typing import List import cv2 import numpy as np import pypdfium2 from PIL import Image from surya.logging import get_logger from surya.settings import settings logger = get_logger() def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]: new_images = [] for image in images: if image.mode != "RGB": image = image.convert("RGB") new_images.append(image) return new_images def open_pdf(pdf_filepath): return pypdfium2.PdfDocument(pdf_filepath) def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI): images = [ doc[i].render(scale=dpi / 72, draw_annots=False).to_pil() for i in indices ] images = [image.convert("RGB") for image in images] return images def slice_bboxes_from_image(image: np.ndarray, bboxes): lines = [] for bbox in bboxes: bbox = np.array(bbox, dtype=np.int32) bbox = np.clip(bbox, 0, None) # Ensure no negative indices # Ensure bbox is within the image bounds if bbox[3] <= bbox[1]: bbox[3] = bbox[1] + 1 if bbox[2] <= bbox[0]: bbox[2] = bbox[0] + 1 bbox[2] = min(bbox[2], image.shape[1]) bbox[3] = min(bbox[3], image.shape[0]) line = image[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy() if line.size == 0: logger.warning(f"Warning: found an empty line with bbox {bbox}") lines.append(line) return lines def slice_polys_from_image(image: np.ndarray, polys): lines = [] for idx, poly in enumerate(polys): lines.append(slice_and_pad_poly(image, poly)) return lines def slice_and_pad_poly(image_array: np.array, coordinates): # Draw polygon onto mask coordinates = [(corner[0], corner[1]) for corner in coordinates] bbox = [ min([x[0] for x in coordinates]), min([x[1] for x in coordinates]), max([x[0] for x in coordinates]), max([x[1] for x in coordinates]), ] # We mask out anything not in the polygon cropped_polygon = image_array[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy() height, width = cropped_polygon.shape[:2] coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates] # Validate the cropped area if any( [ bbox[3] <= bbox[1] or bbox[2] <= bbox[0], len(coordinates) < 3, height == 0, width == 0, ] ): return cropped_polygon # Pad the area outside the polygon with the pad value try: mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8) cv2.fillPoly(mask, [np.int32(coordinates)], 1) mask = np.stack([mask] * 3, axis=-1) cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE except cv2.error as e: logger.warning(f"Warning: issue while processing polygon: {e}") return cropped_polygon ``` -------------------------------------------------------------------------------- /surya/table_rec/loader.py: -------------------------------------------------------------------------------- ```python from typing import Optional import torch from 
surya.common.load import ModelLoader from surya.logging import get_logger from surya.settings import settings from surya.table_rec.model.config import ( SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, ) from surya.table_rec.model.encoderdecoder import TableRecEncoderDecoderModel from surya.table_rec.processor import SuryaTableRecProcessor logger = get_logger() class TableRecModelLoader(ModelLoader): def __init__(self, checkpoint: Optional[str] = None): super().__init__(checkpoint) if self.checkpoint is None: self.checkpoint = settings.TABLE_REC_MODEL_CHECKPOINT def model( self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE, attention_implementation: Optional[str] = None, ) -> TableRecEncoderDecoderModel: if device is None: device = settings.TORCH_DEVICE_MODEL if dtype is None: dtype = settings.MODEL_DTYPE if device == "mps": logger.warning( "`TableRecEncoderDecoderModel` is not compatible with mps backend. Defaulting to cpu instead" ) device = "cpu" dtype = "float32" config = SuryaTableRecConfig.from_pretrained(self.checkpoint) decoder_config = config.decoder decoder = SuryaTableRecDecoderConfig(**decoder_config) config.decoder = decoder encoder_config = config.encoder encoder = DonutSwinTableRecConfig(**encoder_config) config.encoder = encoder model = TableRecEncoderDecoderModel.from_pretrained( self.checkpoint, config=config, dtype=dtype ) model = model.to(device) model = model.eval() if settings.COMPILE_ALL or settings.COMPILE_TABLE_REC: torch.set_float32_matmul_precision("high") torch._dynamo.config.cache_size_limit = 16 torch._dynamo.config.suppress_errors = False logger.info( f"Compiling table recognition model {self.checkpoint} on device {device} with dtype {dtype}" ) compile_args = {"backend": "openxla"} if device == "xla" else {} model.encoder = torch.compile(model.encoder, **compile_args) model.decoder = torch.compile(model.decoder, **compile_args) logger.debug( f"Loaded table recognition model {self.checkpoint} from {TableRecEncoderDecoderModel.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" ) return model def processor( self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE ) -> SuryaTableRecProcessor: processor = SuryaTableRecProcessor(self.checkpoint) processor.token_pad_id = 0 processor.token_eos_id = 1 processor.token_bos_id = 1 processor.token_query_end_id = 4 return processor ``` -------------------------------------------------------------------------------- /surya/common/surya/decoder/config.py: -------------------------------------------------------------------------------- ```python from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging logger = logging.get_logger(__name__) class SuryaDecoderConfig(PretrainedConfig): model_type = "qwen2" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `Qwen2` base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } def __init__( self, vocab_size=151936, hidden_size=4096, 
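# Note: the parameter defaults below mirror the stock Qwen2 configuration from transformers (this class declares model_type = "qwen2"); when a Surya checkpoint is loaded via from_pretrained, the values stored in that checkpoint's config typically override these defaults.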
intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act="silu", max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.use_sliding_window = False # Disable sliding window self.sliding_window = ( sliding_window # we check `use_sliding_window` in the modeling code ) self.max_window_layers = max_window_layers # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_dropout = attention_dropout # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, move it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) ``` -------------------------------------------------------------------------------- /benchmark/ordering.py: -------------------------------------------------------------------------------- ```python import collections import json import click from surya.foundation import FoundationPredictor from surya.input.processing import convert_if_not_rgb from surya.layout import LayoutPredictor from surya.common.polygon import PolygonBox from surya.settings import settings from benchmark.utils.metrics import rank_accuracy import os import time import datasets @click.command(help="Benchmark surya layout for reading order.") @click.option( "--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"), ) @click.option( "--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=None, ) def main(results_dir: str, max_rows: int): foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) layout_predictor = LayoutPredictor(foundation_predictor) pathname = "order_bench" # These have already been shuffled randomly, so sampling from the start is fine split = "train" if max_rows is not None: split = f"train[:{max_rows}]" dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split) images = list(dataset["image"]) images = convert_if_not_rgb(images) start = time.time() layout_predictions = layout_predictor(images) surya_time = time.time() - start folder_name = os.path.basename(pathname).split(".")[0] result_path = os.path.join(results_dir, folder_name) os.makedirs(result_path, exist_ok=True) page_metrics = collections.OrderedDict() mean_accuracy = 0 for idx, order_pred in enumerate(layout_predictions): row = dataset[idx] labels = row["labels"] bboxes = row["bboxes"] pred_positions = [] for label, bbox in zip(labels, bboxes): max_intersection = 0 matching_idx = 0 for pred_box in order_pred.bboxes: intersection = 
pred_box.intersection_pct(PolygonBox(polygon=bbox)) if intersection > max_intersection: max_intersection = intersection matching_idx = pred_box.position pred_positions.append(matching_idx) accuracy = rank_accuracy(pred_positions, labels) mean_accuracy += accuracy page_results = {"accuracy": accuracy, "box_count": len(labels)} page_metrics[idx] = page_results mean_accuracy /= len(layout_predictions) out_data = { "time": surya_time, "mean_accuracy": mean_accuracy, "page_metrics": page_metrics, } with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(out_data, f, indent=4) print(f"Mean accuracy is {mean_accuracy:.2f}.") print( f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total." ) print("Mean accuracy is the % of correct ranking pairs.") print(f"Wrote results to {result_path}") if __name__ == "__main__": main() ``` -------------------------------------------------------------------------------- /surya/debug/text.py: -------------------------------------------------------------------------------- ```python import re from io import BytesIO from typing import List, Tuple from PIL import Image, ImageDraw, ImageFont from surya.debug.fonts import get_font_path from surya.debug.render_html import render_text_as_html try: from playwright.sync_api import sync_playwright has_playwright = True except ImportError: has_playwright = False def strip_html_tags(html_text): pattern = re.compile(r"<[\w/][^>]*>") text_only = pattern.sub("", html_text) return text_only def get_text_size(text, font): im = Image.new(mode="P", size=(0, 0)) draw = ImageDraw.Draw(im) _, _, width, height = draw.textbbox((0, 0), text=text, font=font) return width, height def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size): font = ImageFont.truetype(font_path, box_font_size) text_width, text_height = get_text_size(text, font) while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6: box_font_size = box_font_size - 1 font = ImageFont.truetype(font_path, box_font_size) text_width, text_height = get_text_size(text, font) # Calculate text position (centered in bbox) text_width, text_height = get_text_size(text, font) x = s_bbox[0] y = s_bbox[1] + (bbox_height - text_height) / 2 draw.text((x, y), text, fill="black", font=font) def draw_text_with_playwright( bboxes, texts: List[str], image_size: Tuple[int, int] ) -> Image.Image: html_content, image_size = render_text_as_html(bboxes, texts, image_size) if not has_playwright: raise ImportError( "Playwright is not installed. 
Please install it using `pip install playwright`" ) with sync_playwright() as p: browser = p.chromium.launch(headless=True) page = browser.new_page( viewport={"width": image_size[0], "height": image_size[1]} ) page.set_content(html_content) page.wait_for_timeout(1000) body = page.query_selector("body") image = body.screenshot() browser.close() pil_img = Image.open(BytesIO(image)) return pil_img def draw_text_on_image( bboxes, texts, image_size: Tuple[int, int], font_path=None, max_font_size=60, res_upscale=2, ) -> Image.Image: if has_playwright: return draw_text_with_playwright(bboxes, texts, image_size) texts = [strip_html_tags(text) for text in texts] if font_path is None: font_path = get_font_path() new_image_size = (image_size[0] * res_upscale, image_size[1] * res_upscale) image = Image.new("RGB", new_image_size, color="white") draw = ImageDraw.Draw(image) for bbox, text in zip(bboxes, texts): s_bbox = [int(coord * res_upscale) for coord in bbox] bbox_width = s_bbox[2] - s_bbox[0] bbox_height = s_bbox[3] - s_bbox[1] # Shrink the text to fit in the bbox if needed box_font_size = max(6, min(int(0.75 * bbox_height), max_font_size)) render_text( draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size ) return image ``` -------------------------------------------------------------------------------- /surya/recognition/postprocessing.py: -------------------------------------------------------------------------------- ```python import re from typing import List, Dict from surya.recognition.schema import TextChar def truncate_repetitions(text: str, min_len=15): # From nougat, with some cleanup if len(text) < 2 * min_len: return text # try to find a length at which the tail is repeating max_rep_len = None for rep_len in range(min_len, int(len(text) / 2)): # check if there is a repetition at the end same = True for i in range(0, rep_len): if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]: same = False break if same: max_rep_len = rep_len if max_rep_len is None: return text lcs = text[-max_rep_len:] # remove all but the last repetition text_to_truncate = text while text_to_truncate.endswith(lcs): text_to_truncate = text_to_truncate[:-max_rep_len] return text[: len(text_to_truncate)] def extract_tags(proposed_tags: List[str]) -> List[str]: tags = [] for tag in proposed_tags: tag_match = re.match(tag_pattern, tag) if not tag_match: continue if not tag_match.group(1) == "/": continue tags.append(tag_match.group(2)) return tags tag_pattern = re.compile(r"<(/?)([a-z]+)([^>]*)>?", re.IGNORECASE) def cleanup_math(line: str): matches = re.finditer(r"(<math[^>]*>)(.*?)</math>", line, re.DOTALL) result = line for match in matches: opening_tag = match.group(1) # The opening <math> tag with attributes full_match = match.group(0) # The entire <math>content</math> tag block_content = match.group(2) # Just the content inside the tags clean_block = re.sub(r"<[^>]+>", "", block_content) if not re.search(r"[\\\_]", clean_block): result = result.replace(full_match, clean_block) else: result = result.replace(full_match, f"{opening_tag}{clean_block}</math>") return result def fix_unbalanced_tags( text_chars: List[TextChar], special_tokens: Dict[str, list] ) -> List[TextChar]: self_closing_tags = ["br"] open_tags = [] format_tags = extract_tags(special_tokens["formatting"]) + extract_tags( special_tokens["math_external"] ) for char in text_chars: if len(char.text) <= 1: continue tag_match = re.match(tag_pattern, char.text) if not tag_match: continue is_closing = tag_match.group(1) == "/" 
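# A closing tag pops its match off the open-tag stack below; any formatting/math tags still left open after the loop get synthetic closing TextChars appended at the end of the function.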
tag_name = tag_match.group(2).lower() if tag_name not in format_tags: continue if tag_name in self_closing_tags: continue # Self-closing tags if tag_match.group(3) and tag_match.group(3).strip().endswith("/"): continue if is_closing: if open_tags and open_tags[-1] == tag_name: open_tags.pop() else: open_tags.append(tag_name) for tag in open_tags: text_chars.append( TextChar( text=f"</{tag}>", confidence=0, polygon=[[0, 0], [1, 0], [1, 1], [0, 1]], bbox_valid=False, ) ) return text_chars ``` -------------------------------------------------------------------------------- /surya/common/surya/config.py: -------------------------------------------------------------------------------- ```python from typing import Optional from transformers import PretrainedConfig from surya.common.s3 import S3DownloaderMixin from surya.common.surya.encoder.config import SuryaEncoderConfig from surya.common.surya.decoder.config import SuryaDecoderConfig class SuryaModelConfig(S3DownloaderMixin, PretrainedConfig): model_type = "surya-multimodal-foundation" is_composition = True def __init__( self, vocab_size=65536, bbox_size=1025, blank_bbox_token_id=1025, bos_token_id=0, eos_token_id=1, pad_token_id=2, image_token_id=3, register_token_ids=(4, 5, 6, 7), eoi_token_id=8, beacon_token_id=9, special_token_count=4, max_sequence_length=1536, special_ocr_tokens=None, vision_encoder=None, decoder=None, tasks: dict | None = None, bbox_embed_size: int = 64, num_register_tokens: int = 4, image_embed_encoding_size: int = 1024, image_embed_encoding_multiplier: int = 256, num_beacon_tokens: int = 1, beacon_token_interval: int = 4096, sliding_window: Optional[int] = None, multi_output_distance: int = 4, max_multi_out: int = 8, **kwargs, ): super().__init__(**kwargs) self.is_encoder_decoder = False self.vocab_size = vocab_size self.bbox_size = bbox_size self.blank_bbox_token_id = blank_bbox_token_id self.image_token_id = image_token_id self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.eoi_token_id = eoi_token_id self.beacon_token_id = beacon_token_id self.special_ocr_tokens = special_ocr_tokens self.special_token_count = special_token_count # pad, bos, etc, tokens self.max_sequence_length = max_sequence_length self.tasks = tasks self.tie_word_embeddings = True self.bbox_embed_size = bbox_embed_size self.num_register_tokens = num_register_tokens self.register_token_ids = register_token_ids self.image_embed_encoding_size = image_embed_encoding_size self.image_embed_encoding_multiplier = image_embed_encoding_multiplier self.num_beacon_tokens = num_beacon_tokens self.beacon_token_interval = beacon_token_interval self.sliding_window = sliding_window self.multi_output_distance = multi_output_distance self.max_multi_out = max_multi_out if self.sliding_window is None: self.sliding_window = self.max_sequence_length if isinstance(vision_encoder, dict): vision_encoder = SuryaEncoderConfig(**vision_encoder) elif vision_encoder is None: vision_encoder = SuryaEncoderConfig() self.vision_encoder = vision_encoder if isinstance(decoder, dict): decoder = SuryaDecoderConfig(**decoder) elif decoder is None: decoder = SuryaDecoderConfig() self.decoder = decoder self.hidden_size = self.decoder.hidden_size self.patch_size = self.vision_encoder.spatial_patch_size self.merge_size = self.vision_encoder.spatial_merge_size ``` -------------------------------------------------------------------------------- /surya/table_rec/processor.py: 
-------------------------------------------------------------------------------- ```python from typing import List import PIL import torch from transformers import ProcessorMixin from surya.common.s3 import S3DownloaderMixin from surya.common.donut.processor import SuryaEncoderImageProcessor from surya.table_rec.shaper import LabelShaper from surya.settings import settings from surya.table_rec.model.config import BOX_DIM, SPECIAL_TOKENS class SuryaTableRecProcessor(S3DownloaderMixin, ProcessorMixin): attributes = ["image_processor"] image_processor_class = "AutoImageProcessor" def __init__(self, checkpoint, **kwargs): image_processor = SuryaEncoderImageProcessor.from_pretrained(checkpoint) image_processor.do_align_long_axis = False image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE self.image_processor = image_processor super().__init__(image_processor) self.box_size = (BOX_DIM, BOX_DIM) self.special_token_count = SPECIAL_TOKENS self.shaper = LabelShaper() def resize_polygon(self, polygon, orig_size, new_size): w_scaler = new_size[0] / orig_size[0] h_scaler = new_size[1] / orig_size[1] for corner in polygon: corner[0] = corner[0] * w_scaler corner[1] = corner[1] * h_scaler if corner[0] < 0: corner[0] = 0 if corner[1] < 0: corner[1] = 0 if corner[0] > new_size[0]: corner[0] = new_size[0] if corner[1] > new_size[1]: corner[1] = new_size[1] return polygon def __call__( self, images: List[PIL.Image.Image] | None, query_items: List[dict], columns: List[dict] | None = None, convert_images: bool = True, *args, **kwargs ): if convert_images: assert len(images) == len(query_items) assert len(images) > 0 # Resize input query items for image, query_item in zip(images, query_items): query_item["polygon"] = self.resize_polygon(query_item["polygon"], image.size, self.box_size) query_items = self.shaper.convert_polygons_to_bboxes(query_items) query_labels = self.shaper.dict_to_labels(query_items) decoder_input_boxes = [] col_count = len(query_labels[0]) for label in query_labels: decoder_input_boxes.append([ [self.token_bos_id] * col_count, label, [self.token_query_end_id] * col_count ]) # Add columns to end of decoder input if columns: columns = self.shaper.convert_polygons_to_bboxes(columns) column_labels = self.shaper.dict_to_labels(columns) for decoder_box in decoder_input_boxes: decoder_box += column_labels input_boxes = torch.tensor(decoder_input_boxes, dtype=torch.long) input_boxes_mask = torch.ones_like(input_boxes, dtype=torch.long) inputs = { "input_ids": input_boxes, "attention_mask": input_boxes_mask } if convert_images: inputs["pixel_values"] = self.image_processor(images, *args, **kwargs)["pixel_values"] return inputs ``` -------------------------------------------------------------------------------- /benchmark/utils/tatr.py: -------------------------------------------------------------------------------- ```python import torch from transformers import AutoModelForObjectDetection from surya.settings import settings import numpy as np class MaxResize(object): def __init__(self, max_size=800): self.max_size = max_size def __call__(self, image): width, height = image.size current_max_size = max(width, height) scale = self.max_size / current_max_size resized_image = image.resize((int(round(scale * width)), int(round(scale * height)))) return resized_image def to_tensor(image): # Convert PIL Image to NumPy array np_image = np.array(image).astype(np.float32) # Rearrange dimensions to [C, H, W] format np_image = np_image.transpose((2, 0, 1)) # Normalize to [0.0, 1.0] np_image /= 255.0 
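# The array is now CHW float32 scaled to [0, 1]; torch.from_numpy wraps the NumPy buffer without copying, matching what torchvision's ToTensor would produce.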
return torch.from_numpy(np_image) def normalize(tensor, mean, std): for t, m, s in zip(tensor, mean, std): t.sub_(m).div_(s) return tensor def structure_transform(image): image = MaxResize(1000)(image) tensor = to_tensor(image) normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) return normalized_tensor def box_cxcywh_to_xyxy(x): x_c, y_c, w, h = x.unbind(-1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=1) def rescale_bboxes(out_bbox, size): width, height = size boxes = box_cxcywh_to_xyxy(out_bbox) boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32) return boxes def outputs_to_objects(outputs, img_sizes, id2label): m = outputs.logits.softmax(-1).max(-1) batch_labels = list(m.indices.detach().cpu().numpy()) batch_scores = list(m.values.detach().cpu().numpy()) batch_bboxes = outputs['pred_boxes'].detach().cpu() batch_objects = [] for i in range(len(img_sizes)): pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])] pred_scores = batch_scores[i] pred_labels = batch_labels[i] objects = [] for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes): class_label = id2label[int(label)] if not class_label == 'no object': objects.append({ 'label': class_label, 'score': float(score), 'bbox': [float(elem) for elem in bbox]} ) rows = [] cols = [] for cell in objects: if cell["label"] == "table column": cols.append(cell) if cell["label"] == "table row": rows.append(cell) batch_objects.append({ "rows": rows, "cols": cols }) return batch_objects def load_tatr(): return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL) def batch_inference_tatr(model, images, batch_size): device = model.device rows_cols = [] for i in range(0, len(images), batch_size): batch_images = images[i:i + batch_size] pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device) # forward pass with torch.no_grad(): outputs = model(pixel_values) id2label = model.config.id2label id2label[len(model.config.id2label)] = "no object" rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label)) return rows_cols ``` -------------------------------------------------------------------------------- /benchmark/texify.py: -------------------------------------------------------------------------------- ```python import os.path import re import time from pathlib import Path from typing import List import click import datasets from tabulate import tabulate from bs4 import BeautifulSoup from surya.common.surya.schema import TaskNames from surya.settings import settings from surya.foundation import FoundationPredictor from surya.recognition import RecognitionPredictor, OCRResult import json from rapidfuzz.distance import Levenshtein def normalize_text(text): soup = BeautifulSoup(text, "html.parser") # Unwrap math tags for tag in soup.find_all(): if tag.name == "math": tag.unwrap() text = soup.get_text() text = re.sub(r"\n", " ", text) text = re.sub(r"\s+", " ", text) return text.strip() def score_text(predictions, references): lev_dist = [] for p, r in zip(predictions, references): p = normalize_text(p) r = normalize_text(r) lev_dist.append(Levenshtein.normalized_distance(p, r)) return sum(lev_dist) / len(lev_dist) def inference_texify( source_data, predictor: RecognitionPredictor, line_mode: bool = False ): images = [sd["image"] for sd in 
source_data] mode = TaskNames.ocr_with_boxes if line_mode else TaskNames.block_without_boxes tasks = [mode] * len(images) bboxes = [[[0, 0, image.width, image.height]] for image in images] texify_predictions: List[OCRResult] = predictor(images, tasks, bboxes=bboxes) out_data = [ { "text": texify_predictions[i].text_lines[0].text, "equation": source_data[i]["equation"], } for i in range(len(texify_predictions)) ] return out_data @click.command(help="Benchmark the performance of texify.") @click.option( "--ds_name", type=str, help="Path to dataset file with source images/equations.", default=settings.TEXIFY_BENCHMARK_DATASET, ) @click.option( "--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"), ) @click.option( "--max_rows", type=int, help="Maximum number of images to benchmark.", default=None ) @click.option( "--line_mode", is_flag=True, help="Use line mode for texify.", default=False ) def main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool): foundation_predictor = FoundationPredictor() predictor = RecognitionPredictor(foundation_predictor) ds = datasets.load_dataset(ds_name, split="train") if max_rows: ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True) start = time.time() predictions = inference_texify(ds, predictor, line_mode) time_taken = time.time() - start text = [p["text"] for p in predictions] references = [p["equation"] for p in predictions] scores = score_text(text, references) write_data = { "scores": scores, "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)], } score_table = [["texify", write_data["scores"], time_taken]] score_headers = ["edit", "time taken (s)"] score_dirs = ["⬇", "⬇"] score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)] table = tabulate(score_table, headers=["Method", *score_headers]) print() print(table) result_path = Path(results_dir) / "texify_bench" result_path.mkdir(parents=True, exist_ok=True) with open(result_path / "results.json", "w", encoding="utf-8") as f: json.dump(write_data, f, indent=4) if __name__ == "__main__": main() ``` -------------------------------------------------------------------------------- /surya/table_rec/model/encoder.py: -------------------------------------------------------------------------------- ```python from typing import Optional, Union, Tuple import torch import torch.nn as nn from surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder class DonutSwinModel(DonutSwinPreTrainedModel): def __init__(self, config, add_pooling_layer=True, use_mask_token=False): super().__init__(config) self.config = config self.num_layers = len(config.depths) self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) self.position_embeddings = None if hasattr(config, "encoder_length"): self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size)) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embeddings.patch_embeddings def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, bool_masked_pos: Optional[torch.BoolTensor] = None, head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, return_dict: Optional[bool] = None, ) -> Union[Tuple, DonutSwinModelOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, len(self.config.depths)) embedding_output, input_dimensions = self.embeddings( pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding ) encoder_outputs = self.encoder( embedding_output, input_dimensions, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] if self.position_embeddings is not None: last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :] return DonutSwinModelOutput( last_hidden_state=last_hidden_state, ) ``` -------------------------------------------------------------------------------- /surya/table_rec/model/encoderdecoder.py: -------------------------------------------------------------------------------- ```python from dataclasses import dataclass from typing import Optional, Union, Tuple, Dict import torch from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig from surya.common.pretrained import SuryaPreTrainedModel from surya.common.s3 import S3DownloaderMixin from surya.table_rec.model.decoder import SuryaTableRecDecoder from surya.table_rec.model.encoder import DonutSwinModel from transformers.utils import ModelOutput @dataclass class TableRecOutput(ModelOutput): box_property_logits: Dict[str, torch.FloatTensor] decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None class TableRecEncoderDecoderModel(S3DownloaderMixin, SuryaPreTrainedModel): config_class = VisionEncoderDecoderConfig base_model_prefix = "vision_encoder_decoder" main_input_name = "pixel_values" supports_gradient_checkpointing = True _supports_param_buffer_assignment = False def __init__( self, config: Optional[PretrainedConfig] = None, encoder: Optional[PreTrainedModel] = None, decoder: Optional[PreTrainedModel] = None, **kwargs, ): # initialize with config # make sure input & output embeddings is not tied config.tie_word_embeddings = False config.decoder.tie_word_embeddings = False super().__init__(config, **kwargs) if 
encoder is None: encoder = DonutSwinModel(config.encoder) if decoder is None: decoder = SuryaTableRecDecoder( config.decoder, attn_implementation=config._attn_implementation ) self.encoder = encoder self.decoder = decoder # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced self.encoder.config = self.config.encoder self.decoder.config = self.config.decoder def get_encoder(self): return self.encoder def get_decoder(self): return self.decoder def get_output_embeddings(self): return self.decoder.get_output_embeddings() def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) def forward( self, decoder_input_ids: torch.LongTensor = None, decoder_cache_position: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]: kwargs_decoder = { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } # Decode decoder_outputs = self.decoder( input_labels=decoder_input_ids, input_boxes_counts=None, cache_position=decoder_cache_position, attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs, encoder_attention_mask=None, use_cache=use_cache, **kwargs_decoder, ) return TableRecOutput( box_property_logits=decoder_outputs.box_property_logits, decoder_hidden_states=decoder_outputs.hidden_states, ) def resize_token_embeddings(self, *args, **kwargs): raise NotImplementedError( "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" ) def _reorder_cache(self, past_key_values, beam_idx): # apply decoder cache reordering here return self.decoder._reorder_cache(past_key_values, beam_idx) ``` -------------------------------------------------------------------------------- /signatures/version1/cla.json: -------------------------------------------------------------------------------- ```json { "signedContributors": [ { "name": "rishiraj", "id": 44090649, "comment_id": 2170578748, "created_at": "2024-06-15T19:31:20Z", "repoId": 741297064, "pullRequestNo": 135 }, { "name": "mmacvicar", "id": 59354, "comment_id": 2236493182, "created_at": "2024-07-18T13:17:43Z", "repoId": 741297064, "pullRequestNo": 152 }, { "name": "jimexist", "id": 622789, "comment_id": 2255151376, "created_at": "2024-07-29T07:23:55Z", "repoId": 741297064, "pullRequestNo": 160 }, { "name": "michaeldriscoll-avant", "id": 85255083, "comment_id": 2259143427, "created_at": "2024-07-30T20:21:33Z", "repoId": 741297064, "pullRequestNo": 161 }, { "name": "EdoardoPona", "id": 29152472, "comment_id": 2271115922, "created_at": "2024-08-06T11:58:00Z", "repoId": 741297064, "pullRequestNo": 167 }, { "name": "hidenori-endo", "id": 15546605, "comment_id": 2307217499, "created_at": "2024-08-23T14:31:17Z", "repoId": 741297064, "pullRequestNo": 182 }, { "name": "dobosevych", "id": 12053536, "comment_id": 2430376828, "created_at": "2024-10-22T21:48:34Z", "repoId": 741297064, "pullRequestNo": 220 }, { "name": "iammosespaulr", "id": 28682735, "comment_id": 2447941238, "created_at": "2024-10-30T17:55:23Z", "repoId": 741297064, "pullRequestNo": 235 }, { "name": "ArthurMor4is", "id": 42987302, "comment_id": 
2515315717, "created_at": "2024-12-03T18:37:45Z", "repoId": 741297064, "pullRequestNo": 255 }, { "name": "tarun-menta", "id": 66506307, "comment_id": 2543457960, "created_at": "2024-12-15T05:43:33Z", "repoId": 741297064, "pullRequestNo": 261 }, { "name": "jonaskahn", "id": 4338500, "comment_id": 2556622097, "created_at": "2024-12-20T09:36:20Z", "repoId": 741297064, "pullRequestNo": 269 }, { "name": "kumsumit", "id": 95072784, "comment_id": 2574534622, "created_at": "2025-01-07T07:05:59Z", "repoId": 741297064, "pullRequestNo": 276 }, { "name": "kevinhu", "id": 6051736, "comment_id": 2614135351, "created_at": "2025-01-25T23:34:12Z", "repoId": 741297064, "pullRequestNo": 291 }, { "name": "zanussbaum", "id": 33707069, "comment_id": 3008673416, "created_at": "2025-06-26T14:20:46Z", "repoId": 741297064, "pullRequestNo": 403 }, { "name": "mebriki", "id": 35892987, "comment_id": 3154706976, "created_at": "2025-08-05T10:54:27Z", "repoId": 741297064, "pullRequestNo": 418 }, { "name": "starikovplusplus", "id": 56602036, "comment_id": 3168958011, "created_at": "2025-08-08T18:29:50Z", "repoId": 741297064, "pullRequestNo": 423 }, { "name": "sandy0kwon", "id": 78377296, "comment_id": 3207932260, "created_at": "2025-08-20T20:07:15Z", "repoId": 741297064, "pullRequestNo": 434 }, { "name": "n0kovo", "id": 16690056, "comment_id": 3208251881, "created_at": "2025-08-20T22:22:06Z", "repoId": 741297064, "pullRequestNo": 435 }, { "name": "davidxifeng", "id": 158052, "comment_id": 3249594859, "created_at": "2025-09-03T14:52:16Z", "repoId": 741297064, "pullRequestNo": 445 }, { "name": "u-ashish", "id": 14264791, "comment_id": 3258734182, "created_at": "2025-09-05T15:16:48Z", "repoId": 741297064, "pullRequestNo": 447 }, { "name": "Mohking1", "id": 63689545, "comment_id": 3314908963, "created_at": "2025-09-20T11:21:42Z", "repoId": 741297064, "pullRequestNo": 462 }, { "name": "wkpark", "id": 232347, "comment_id": 3330009557, "created_at": "2025-09-24T17:42:55Z", "repoId": 741297064, "pullRequestNo": 464 } ] } ``` -------------------------------------------------------------------------------- /surya/layout/__init__.py: -------------------------------------------------------------------------------- ```python from typing import List from PIL import Image from surya.common.predictor import BasePredictor from surya.layout.schema import LayoutBox, LayoutResult from surya.settings import settings from surya.foundation import FoundationPredictor, TaskNames from surya.foundation.util import prediction_to_polygon_batch from surya.input.processing import convert_if_not_rgb from surya.layout.label import LAYOUT_PRED_RELABEL from surya.common.util import clean_boxes class LayoutPredictor(BasePredictor): batch_size = settings.LAYOUT_BATCH_SIZE default_batch_sizes = {"cpu": 4, "mps": 4, "cuda": 32, "xla": 16} # Override base init - Do not load model def __init__(self, foundation_predictor: FoundationPredictor): self.foundation_predictor = foundation_predictor self.processor = self.foundation_predictor.processor self.bbox_size = self.foundation_predictor.model.config.bbox_size self.tasks = self.foundation_predictor.tasks # Special handling for disable tqdm to pass into foundation predictor # Make sure they are kept in sync @property def disable_tqdm(self) -> bool: return super().disable_tqdm @disable_tqdm.setter def disable_tqdm(self, value: bool) -> None: self._disable_tqdm = bool(value) self.foundation_predictor.disable_tqdm = bool(value) def __call__( self, images: List[Image.Image], batch_size: int | None = None, top_k: int = 5 
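# top_k controls how many alternative labels are retained per layout box (note that the prediction_loop call below currently passes a fixed top_k=5 to the foundation predictor).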
) -> List[LayoutResult]: assert all([isinstance(image, Image.Image) for image in images]) if batch_size is None: batch_size = self.get_batch_size() if len(images) == 0: return [] images = convert_if_not_rgb(images) images = [self.processor.image_processor(image) for image in images] predicted_tokens, batch_bboxes, scores, topk_scores = ( self.foundation_predictor.prediction_loop( images=images, input_texts=["" for _ in range(len(images))], task_names=[TaskNames.layout for _ in range(len(images))], batch_size=batch_size, max_lookahead_tokens=0, # Do not do MTP for layout top_k=5, max_sliding_window=576, max_tokens=500, tqdm_desc="Recognizing Layout" ) ) image_sizes = [img.shape for img in images] predicted_polygons = prediction_to_polygon_batch( batch_bboxes, image_sizes, self.bbox_size, self.bbox_size // 2 ) layout_results = [] for image, image_tokens, image_polygons, image_scores, image_topk_scores in zip( images, predicted_tokens, predicted_polygons, scores, topk_scores ): layout_boxes = [] for z, (tok, poly, score, tok_topk) in enumerate( zip(image_tokens, image_polygons, image_scores, image_topk_scores) ): if tok == self.processor.eos_token_id: break predicted_label = self.processor.decode([tok], "layout") label = LAYOUT_PRED_RELABEL.get(predicted_label) if not label: # Layout can sometimes return unknown labels from other objectives continue top_k_dict = {} for k, v in tok_topk.items(): topk_label = self.processor.decode([k], "layout") if topk_label in LAYOUT_PRED_RELABEL: topk_label = LAYOUT_PRED_RELABEL[topk_label] if not topk_label.strip(): continue top_k_dict.update({topk_label: v}) layout_boxes.append( LayoutBox( polygon=poly.tolist(), label=label, position=z, top_k=top_k_dict, confidence=score, ) ) layout_boxes = clean_boxes(layout_boxes) layout_results.append( LayoutResult( bboxes=layout_boxes, image_bbox=[0, 0, image.shape[1], image.shape[0]], ) # Image is numpy array ) assert len(layout_results) == len(images) return layout_results ``` -------------------------------------------------------------------------------- /surya/scripts/table_recognition.py: -------------------------------------------------------------------------------- ```python import os import click import copy import json from collections import defaultdict from surya.logging import configure_logging, get_logger from surya.scripts.config import CLILoader from surya.foundation import FoundationPredictor from surya.layout import LayoutPredictor from surya.table_rec import TableRecPredictor from surya.debug.draw import draw_bboxes_on_image from surya.common.util import rescale_bbox, expand_bbox from surya.settings import settings configure_logging() logger = get_logger() @click.command(help="Detect layout of an input file or folder (PDFs or image).") @CLILoader.common_options @click.option( "--skip_table_detection", is_flag=True, help="Tables are already cropped, so don't re-detect tables.", default=False, ) def table_recognition_cli(input_path: str, skip_table_detection: bool, **kwargs): loader = CLILoader(input_path, kwargs, highres=True) foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) layout_predictor = LayoutPredictor(foundation_predictor) table_rec_predictor = TableRecPredictor() pnums = [] prev_name = None for i, name in enumerate(loader.names): if prev_name is None or prev_name != name: pnums.append(0) else: pnums.append(pnums[-1] + 1) prev_name = name layout_predictions = layout_predictor(loader.images) table_imgs = [] table_counts = [] for layout_pred, img, 
highres_img in zip( layout_predictions, loader.images, loader.highres_images ): # The table may already be cropped if skip_table_detection: table_imgs.append(highres_img) table_counts.append(1) else: # The bbox for the entire table bbox = [ line.bbox for line in layout_pred.bboxes if line.label in ["Table", "TableOfContents"] ] # Number of tables per page table_counts.append(len(bbox)) if len(bbox) == 0: continue page_table_imgs = [] highres_bbox = [] for bb in bbox: highres_bb = rescale_bbox(bb, img.size, highres_img.size) highres_bb = expand_bbox(highres_bb) page_table_imgs.append(highres_img.crop(highres_bb)) highres_bbox.append(highres_bb) table_imgs.extend(page_table_imgs) table_preds = table_rec_predictor(table_imgs) img_idx = 0 prev_count = 0 table_predictions = defaultdict(list) for i in range(sum(table_counts)): while i >= prev_count + table_counts[img_idx]: prev_count += table_counts[img_idx] img_idx += 1 pred = table_preds[i] orig_name = loader.names[img_idx] pnum = pnums[img_idx] table_img = table_imgs[i] out_pred = pred.model_dump() out_pred["page"] = pnum + 1 table_idx = i - prev_count out_pred["table_idx"] = table_idx table_predictions[orig_name].append(out_pred) if loader.save_images: rows = [line.bbox for line in pred.rows] cols = [line.bbox for line in pred.cols] row_labels = [f"Row {line.row_id}" for line in pred.rows] col_labels = [f"Col {line.col_id}" for line in pred.cols] cells = [line.bbox for line in pred.cells] rc_image = copy.deepcopy(table_img) rc_image = draw_bboxes_on_image( rows, rc_image, labels=row_labels, label_font_size=20, color="blue" ) rc_image = draw_bboxes_on_image( cols, rc_image, labels=col_labels, label_font_size=20, color="red" ) rc_image.save( os.path.join( loader.result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png" ) ) cell_image = copy.deepcopy(table_img) cell_image = draw_bboxes_on_image(cells, cell_image, color="green") cell_image.save( os.path.join( loader.result_path, f"{name}_page{pnum + 1}_table{table_idx}_cells.png", ) ) with open( os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" ) as f: json.dump(table_predictions, f, ensure_ascii=False) logger.info(f"Wrote results to {loader.result_path}") ``` -------------------------------------------------------------------------------- /CLA.md: -------------------------------------------------------------------------------- ```markdown Surya Contributor Agreement This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below. If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement. 1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project. 2. 
With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution: - you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty free, unrestricted license to exercise all rights under those copyrights. This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers; - you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made) will be the sole owner of that derivative work; - you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees; - you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and - you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of your contribution. 3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to: - make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and - at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements. If you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed. 4. Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license. 5. You covenant, represent, warrant and agree that: - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA; - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws. You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA. 6.
This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply. ``` -------------------------------------------------------------------------------- /surya/scripts/texify_app.py: -------------------------------------------------------------------------------- ```python import os import re from typing import List from surya.recognition import RecognitionPredictor from surya.foundation import FoundationPredictor from surya.common.surya.schema import TaskNames os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = ( "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS ) import io import pandas as pd import streamlit as st from streamlit_drawable_canvas import st_canvas import hashlib import pypdfium2 from surya.settings import settings from PIL import Image MAX_WIDTH = 800 MAX_HEIGHT = 1000 def replace_fences(text): text = re.sub(r'<math display="block">(.*?)</math>', r"$$\1$$", text) text = re.sub(r"<math>(.*?)</math>", r"$\1$", text) text = re.sub(r'<math display="inline">(.*?)</math>', r"$\1$", text) return text @st.cache_resource() def load_predictor(): foundation_predictor = FoundationPredictor() return RecognitionPredictor(foundation_predictor) @st.cache_data() def inference(pil_image: Image.Image, bbox: List[float]): input_img = pil_image.crop(bbox) bbox = [0, 0, input_img.width, input_img.height] model_output = predictor( [input_img], [TaskNames.block_without_boxes], bboxes=[[bbox]] ) return model_output[0].text_lines[0].text def open_pdf(pdf_file): stream = io.BytesIO(pdf_file.getvalue()) return pypdfium2.PdfDocument(stream) @st.cache_data() def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI_HIGHRES): doc = open_pdf(pdf_file) renderer = doc.render( pypdfium2.PdfBitmap.to_pil, page_indices=[page_num - 1], scale=dpi / 72, ) png = list(renderer)[0] png_image = png.convert("RGB") doc.close() return png_image @st.cache_data() def page_counter(pdf_file): doc = open_pdf(pdf_file) doc_len = len(doc) doc.close() return doc_len def resize_image(pil_image): if pil_image is None: return pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) def get_canvas_hash(pil_image): return hashlib.md5(pil_image.tobytes()).hexdigest() st.set_page_config(layout="wide") top_message = """### LaTeX OCR After the model loads, upload an image or a pdf, then draw a box around the equation or text you want to OCR by clicking and dragging. Surya will convert it to Markdown with LaTeX math on the right. 
""" st.markdown(top_message) col1, col2 = st.columns([0.7, 0.3]) predictor = load_predictor() in_file = st.sidebar.file_uploader( "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"] ) if in_file is None: st.stop() if in_file is None: st.stop() filetype = in_file.type page_count = None if "pdf" in filetype: page_count = page_counter(in_file) page_number = st.sidebar.number_input( f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count ) pil_image = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES) else: pil_image = Image.open(in_file).convert("RGB") page_number = None if pil_image is None: st.stop() pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) canvas_hash = get_canvas_hash(pil_image) with col1: # Create a canvas component canvas_result = st_canvas( fill_color="rgba(255, 165, 0, 0.1)", # Fixed fill color with some opacity stroke_width=1, stroke_color="#FFAA00", background_color="#FFF", background_image=pil_image, update_streamlit=True, height=pil_image.height, width=pil_image.width, drawing_mode="rect", point_display_radius=0, key=canvas_hash, ) if not canvas_result.json_data: st.stop() objects = pd.json_normalize( canvas_result.json_data["objects"] ) # need to convert obj to str because PyArrow bbox_list = None if objects.shape[0] > 0: boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]] boxes["right"] = boxes["left"] + boxes["width"] boxes["bottom"] = boxes["top"] + boxes["height"] bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist() if bbox_list: with col2: texts = [inference(pil_image, bbox) for bbox in bbox_list] for idx, latex in enumerate(reversed(texts)): st.markdown(f"### {len(texts) - idx}") st.markdown(replace_fences(latex), unsafe_allow_html=True) st.code(latex) st.divider() with col2: tips = """ ### Usage tips - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple. 
""" st.markdown(tips) ``` -------------------------------------------------------------------------------- /surya/scripts/finetune_ocr.py: -------------------------------------------------------------------------------- ```python from __future__ import annotations from dataclasses import dataclass, field from typing import Optional, Tuple from datasets import load_dataset import numpy as np import torch from transformers import ( HfArgumentParser, TrainingArguments, Trainer, ) from surya.common.surya import SuryaModel from surya.common.surya.processor import SuryaOCRProcessor from surya.foundation import FoundationPredictor from surya.common.surya.processor.schema import ImageInput, TextInput from surya.common.surya.schema import TaskNames from surya.common.util import get_top_scripts, SCRIPT_TOKEN_MAPPING # Do not change these defaults OCR_TASK_NAME = TaskNames.ocr_with_boxes OCR_MAX_IMAGE_SIZE = (1024, 512) # Simple wrapper for huggingface dataset class SuryaOCRDataset(torch.utils.data.Dataset): def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments): super().__init__() self.hf_dataset = load_dataset(data_args.dataset_name, num_proc=data_args.num_loading_proc, split="train") self.processor = processor def __len__(self): return len(self.hf_dataset) def get_script_text(self, text: str) -> str: scripts = get_top_scripts(text) script_text = "".join(SCRIPT_TOKEN_MAPPING[script] for script in scripts) return script_text def __getitem__(self, index): try: data = self.hf_dataset[index] image = data["image"] image = image.convert("RGB") image = np.asarray(image, dtype=np.float32) image = self.processor.scale_to_fit(image, max_size=OCR_MAX_IMAGE_SIZE) # Add in script information gt_text = data["text"] gt_text = self.get_script_text(gt_text) + gt_text return_dict = { "task": TaskNames.ocr_with_boxes, "inputs": [ ImageInput(type="image", image=image, rotated=False), # This empty TextInput **must be included** to match the original format TextInput(type="text", text=""), TextInput(type="text",text=gt_text), ], } return return_dict except: import traceback; traceback.print_exc() return self.__getitem__((index + 1) % self.__len__()) class SuryaOCRDataCollator: def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments): self.processor = processor self.max_sequence_length = data_args.max_sequence_length def __call__(self, inputs): # Use right padding for training. 
Defaults to left for inference processed_batch = self.processor(inputs, padding_side="right") if self.max_sequence_length is not None: processed_batch["input_ids"] = processed_batch["input_ids"][:, :self.max_sequence_length] processed_batch["attention_mask"] = processed_batch["attention_mask"][:, :self.max_sequence_length] processed_batch["position_ids"] = processed_batch["position_ids"][:, :self.max_sequence_length] lm_labels = processed_batch["input_ids"].clone() skip_label_mask = ( (lm_labels == self.processor.pad_token_id ) | (lm_labels == self.processor.bos_token_id[TaskNames.ocr_with_boxes]) | (lm_labels == self.processor.eoi_token_id) | (lm_labels == self.processor.image_token_id) ) lm_labels[skip_label_mask] = -100 processed_batch["labels"] = lm_labels return processed_batch def load_model_and_processor(checkpoint_path: Optional[str] = None) -> Tuple[SuryaModel, SuryaOCRProcessor]: foundation_predictor = FoundationPredictor(checkpoint=checkpoint_path) return foundation_predictor.model, foundation_predictor.processor @dataclass class SuryaOCRModelArguments: pretrained_checkpoint_path: Optional[str] = field(default=None) @dataclass class SuryaOCRDataArguments: dataset_name: str = field(default="datalab-to/ocr_finetune_example") num_loading_proc: int = field(default=16) max_sequence_length: Optional[int] = field(default=None) @dataclass class SuryaOCRTrainingArguments(TrainingArguments): remove_unused_columns: bool = field(default=False) def main(): parser = HfArgumentParser((SuryaOCRModelArguments, SuryaOCRDataArguments, SuryaOCRTrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() model, processor = load_model_and_processor(model_args.pretrained_checkpoint_path) dataset = SuryaOCRDataset(processor, data_args) collator = SuryaOCRDataCollator(processor, data_args) trainer = Trainer( model=model, args=training_args, train_dataset=dataset, data_collator=collator ) trainer.train() if __name__ == "__main__": main() ```
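
Note on the fine-tuning data format: `SuryaOCRDataset` above reads only two columns from the Hugging Face dataset passed via `--dataset_name`, an `image` column (PIL images) and a `text` column (ground-truth strings); script tokens are added automatically by `get_script_text`. Below is a minimal sketch of building a compatible dataset. The rendered toy image and the `your-username/my_ocr_finetune_data` repo id are illustrative placeholders, not part of the repo.

```python
# Sketch: create a tiny dataset with the two columns SuryaOCRDataset reads ("image", "text").
from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage, ImageDraw

# Render a trivial text-line image so the example is self-contained
img = PILImage.new("RGB", (256, 48), "white")
ImageDraw.Draw(img).text((8, 14), "hello world", fill="black")

features = Features({"image": Image(), "text": Value("string")})
ds = Dataset.from_dict({"image": [img], "text": ["hello world"]}, features=features)

# ds.push_to_hub("your-username/my_ocr_finetune_data")  # placeholder repo id
# Then point the training script at it, e.g.:
#   python surya/scripts/finetune_ocr.py --dataset_name your-username/my_ocr_finetune_data --output_dir ./ocr_finetune
```

The CLI flags come from `HfArgumentParser`: `--dataset_name`, `--num_loading_proc`, `--max_sequence_length`, and `--pretrained_checkpoint_path` map to the dataclasses above, while standard `TrainingArguments` options such as `--output_dir` are also accepted.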