This is page 2 of 4. Use http://codebase.md/datalab-to/surya?page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/benchmark/utils/tesseract.py:
--------------------------------------------------------------------------------

```python
from typing import List, Optional

import numpy as np
from tqdm import tqdm

from surya.input.processing import slice_bboxes_from_image
from surya.settings import settings
import os
from concurrent.futures import ProcessPoolExecutor
from surya.recognition.languages import CODE_TO_LANGUAGE
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor


def surya_lang_to_tesseract(code: str) -> Optional[str]:
    lang_str = CODE_TO_LANGUAGE[code]
    try:
        tess_lang = TESS_LANGUAGE_TO_CODE[lang_str]
    except KeyError:
        return None
    return tess_lang


def tesseract_ocr(img, bboxes, lang: str):
    import pytesseract
    line_imgs = slice_bboxes_from_image(img, bboxes)
    config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
    lines = []
    for line_img in line_imgs:
        line = pytesseract.image_to_string(line_img, lang=lang, config=config)
        lines.append(line)
    return lines


def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
    tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size())
    if not cpus:
        cpus = os.cpu_count()
    tess_parallel_cores = min(tess_parallel_cores, cpus)

    # Tesseract uses up to 4 processes per instance
    # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images
    tess_parallel = max(tess_parallel_cores // 2, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR")
        tess_text = list(tess_text)
    return tess_text


def tesseract_bboxes(img):
    import pytesseract
    from pytesseract import Output
    arr_img = np.asarray(img, dtype=np.uint8)
    ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT)

    bboxes = []
    n_boxes = len(ocr['level'])
    for i in range(n_boxes):
        # It is possible to merge by line here with line number, but it gives bad results.
        _, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i]
        bbox = (x, y, x + w, y + h)
        bboxes.append(bbox)

    return bboxes


def tesseract_parallel(imgs):
    tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size())
    cpus = os.cpu_count()
    tess_parallel_cores = min(tess_parallel_cores, cpus)

    # Tesseract uses 4 threads per instance
    tess_parallel = max(tess_parallel_cores // 4, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection")
        tess_bboxes = list(tess_bboxes)
    return tess_bboxes


TESS_CODE_TO_LANGUAGE = {
    "afr": "Afrikaans",
    "amh": "Amharic",
    "ara": "Arabic",
    "asm": "Assamese",
    "aze": "Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bod": "Tibetan",
    "bos": "Bosnian",
    "bre": "Breton",
    "bul": "Bulgarian",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "ces": "Czech",
    "chi_sim": "Chinese",
    "chr": "Cherokee",
    "cym": "Welsh",
    "dan": "Danish",
    "deu": "German",
    "dzo": "Dzongkha",
    "ell": "Greek",
    "eng": "English",
    "epo": "Esperanto",
    "est": "Estonian",
    "eus": "Basque",
    "fas": "Persian",
    "fin": "Finnish",
    "fra": "French",
    "fry": "Western Frisian",
    "guj": "Gujarati",
    "gla": "Scottish Gaelic",
    "gle": "Irish",
    "glg": "Galician",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hrv": "Croatian",
    "hun": "Hungarian",
    "hye": "Armenian",
    "iku": "Inuktitut",
    "ind": "Indonesian",
    "isl": "Icelandic",
    "ita": "Italian",
    "jav": "Javanese",
    "jpn": "Japanese",
    "kan": "Kannada",
    "kat": "Georgian",
    "kaz": "Kazakh",
    "khm": "Khmer",
    "kir": "Kyrgyz",
    "kor": "Korean",
    "lao": "Lao",
    "lat": "Latin",
    "lav": "Latvian",
    "lit": "Lithuanian",
    "mal": "Malayalam",
    "mar": "Marathi",
    "mkd": "Macedonian",
    "mlt": "Maltese",
    "mon": "Mongolian",
    "msa": "Malay",
    "mya": "Burmese",
    "nep": "Nepali",
    "nld": "Dutch",
    "nor": "Norwegian",
    "ori": "Oriya",
    "pan": "Punjabi",
    "pol": "Polish",
    "por": "Portuguese",
    "pus": "Pashto",
    "ron": "Romanian",
    "rus": "Russian",
    "san": "Sanskrit",
    "sin": "Sinhala",
    "slk": "Slovak",
    "slv": "Slovenian",
    "snd": "Sindhi",
    "spa": "Spanish",
    "sqi": "Albanian",
    "srp": "Serbian",
    "swa": "Swahili",
    "swe": "Swedish",
    "syr": "Syriac",
    "tam": "Tamil",
    "tel": "Telugu",
    "tgk": "Tajik",
    "tha": "Thai",
    "tir": "Tigrinya",
    "tur": "Turkish",
    "uig": "Uyghur",
    "ukr": "Ukrainian",
    "urd": "Urdu",
    "uzb": "Uzbek",
    "vie": "Vietnamese",
    "yid": "Yiddish"
}

TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()}

```
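
For orientation, a small usage sketch of the helpers above; a minimal sketch assuming pytesseract is installed, tesseract language data is available at `settings.TESSDATA_PREFIX`, and `page.png` is a placeholder image path.

```python
# Usage sketch (assumptions: pytesseract installed, TESSDATA_PREFIX configured,
# and "page.png" is a placeholder image path).
from PIL import Image

from benchmark.utils.tesseract import surya_lang_to_tesseract, tesseract_bboxes, tesseract_ocr

img = Image.open("page.png").convert("RGB")

tess_lang = surya_lang_to_tesseract("en")   # surya code -> tesseract code (e.g. "eng"), or None if unsupported
bboxes = tesseract_bboxes(img)              # (x1, y1, x2, y2) boxes from tesseract's layout analysis
if tess_lang is not None:
    lines = tesseract_ocr(img, bboxes, tess_lang)
    print(lines[:3])
```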

--------------------------------------------------------------------------------
/benchmark/layout.py:
--------------------------------------------------------------------------------

```python
import collections
import copy
import json

import click

from benchmark.utils.metrics import precision_recall
from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.input.processing import convert_if_not_rgb
from surya.debug.draw import draw_bboxes_on_image
from surya.settings import settings
import os
import time
from tabulate import tabulate
import datasets


@click.command(help="Benchmark surya layout model.")
@click.option(
    "--results_dir",
    type=str,
    help="Path to JSON file with OCR results.",
    default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
    "--max_rows",
    type=int,
    help="Maximum number of images to run benchmark on.",
    default=100,
)
@click.option("--debug", is_flag=True, help="Run in debug mode.", default=False)
def main(results_dir: str, max_rows: int, debug: bool):
    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_predictor = LayoutPredictor(foundation_predictor)

    pathname = "layout_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    dataset = datasets.load_dataset(
        settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]"
    )
    images = list(dataset["image"])
    images = convert_if_not_rgb(images)

    if settings.LAYOUT_STATIC_CACHE:
        layout_predictor(images[:1])

    start = time.time()
    layout_predictions = layout_predictor(images)
    surya_time = time.time() - start

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    label_alignment = {  # First is publaynet, second is surya
        "Image": [["Figure"], ["Picture", "Figure"]],
        "Table": [["Table"], ["Table", "Form", "TableOfContents"]],
        "Text": [
            ["Text"],
            [
                "Text",
                "Formula",
                "Footnote",
                "Caption",
                "TextInlineMath",
                "Code",
                "Handwriting",
            ],
        ],
        "List": [["List"], ["ListItem"]],
        "Title": [["Title"], ["SectionHeader", "Title"]],
    }

    page_metrics = collections.OrderedDict()
    for idx, pred in enumerate(layout_predictions):
        row = dataset[idx]
        all_correct_bboxes = []
        page_results = {}
        for label_name in label_alignment:
            correct_cats, surya_cats = label_alignment[label_name]
            correct_bboxes = [
                b
                for b, category in zip(row["bboxes"], row["labels"])
                if category in correct_cats
            ]
            all_correct_bboxes.extend(correct_bboxes)
            pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats]

            metrics = precision_recall(
                pred_bboxes, correct_bboxes, penalize_double=False
            )
            weight = len(correct_bboxes)
            metrics["weight"] = weight
            page_results[label_name] = metrics

        page_metrics[idx] = page_results

        if debug:
            bbox_image = draw_bboxes_on_image(
                all_correct_bboxes, copy.deepcopy(images[idx])
            )
            bbox_image.save(os.path.join(result_path, f"{idx}_layout.png"))

    mean_metrics = collections.defaultdict(dict)
    layout_types = sorted(page_metrics[0].keys())
    metric_types = sorted(page_metrics[0][layout_types[0]].keys())
    metric_types.remove("weight")
    for label in layout_types:
        for m in metric_types:
            metric = []
            total = 0
            for page in page_metrics:
                metric.append(
                    page_metrics[page][label][m] * page_metrics[page][label]["weight"]
                )
                total += page_metrics[page][label]["weight"]

            value = sum(metric)
            if value > 0:
                value /= total
            mean_metrics[label][m] = value

    out_data = {
        "time": surya_time,
        "metrics": mean_metrics,
        "page_metrics": page_metrics,
    }

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)

    table_headers = [
        "Layout Type",
    ] + metric_types
    table_data = []
    for layout_type in layout_types:
        table_data.append(
            [
                layout_type,
            ]
            + [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types]
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print(
        f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total."
    )
    print(
        "Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold."
    )
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()

```
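
The benchmark is a click command, so it can also be driven in-process with click's test runner; a minimal sketch, assuming the layout benchmark dataset and model checkpoint are reachable from the current environment.

```python
# Sketch: run the layout benchmark programmatically instead of from the shell.
from click.testing import CliRunner

from benchmark.layout import main

runner = CliRunner()
result = runner.invoke(main, ["--max_rows", "10", "--debug"])
print(result.exit_code)
print(result.output)  # the tabulated metrics printed by the benchmark
```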

--------------------------------------------------------------------------------
/surya/foundation/loader.py:
--------------------------------------------------------------------------------

```python
from typing import Optional

import torch
from transformers.utils import is_flash_attn_2_available

from surya.common.load import ModelLoader
from surya.common.surya.config import SuryaModelConfig
from surya.common.surya import SuryaModel, SuryaXLAModel
from surya.common.surya.processor import SuryaOCRProcessor
from surya.common.surya.processor.tokenizer import SuryaOCRTokenizer
from surya.common.util import is_flash_attn_2_supported
from surya.common.xla import get_compile_args
from surya.logging import get_logger
from surya.settings import settings

logger = get_logger()


class FoundationModelLoader(ModelLoader):
    def __init__(self, checkpoint: Optional[str] = None):
        super().__init__(checkpoint)

        if self.checkpoint is None:
            self.checkpoint = settings.FOUNDATION_MODEL_CHECKPOINT

    def model(
        self,
        device=settings.TORCH_DEVICE_MODEL,
        dtype=None,
        attention_implementation: Optional[str] = None,
    ) -> SuryaModel:
        if device is None:
            device = settings.TORCH_DEVICE_MODEL
        if dtype is None:
            # See https://github.com/pytorch/pytorch/issues/118122 - T4 (device version 7.5) will return true since it supports
            # emulated bf16, but falls back to very slow kernels, especially for SDPA
            dtype = settings.MODEL_DTYPE_BFLOAT
            if device == "cuda" and not torch.cuda.is_bf16_supported(
                including_emulation=False
            ):
                # If the device is cuda, we check if bf16 is supported, and if not, we use float16
                dtype = settings.MODEL_DTYPE
        elif dtype == torch.float16:
            dtype = torch.bfloat16  # Model weights in bfloat16

        config = SuryaModelConfig.from_pretrained(self.checkpoint)

        if attention_implementation is not None:
            config.decoder._attn_implementation = attention_implementation
            config.vision_encoder._attn_implementation = attention_implementation
        elif is_flash_attn_2_available() and is_flash_attn_2_supported(device):
            config.decoder._attn_implementation = "flash_attention_2"
            config.vision_encoder._attn_implementation = "flash_attention_2"
        elif device == "xla":
            config.decoder._attn_implementation = "sdpa"
            config.vision_encoder._attn_implementation = "sdpa"
        else:
            config.decoder._attn_implementation = "sdpa"
            config.vision_encoder._attn_implementation = "sdpa"

        model_cls = SuryaModel
        if device == "xla":
            model_cls = SuryaXLAModel

        config._attn_implementation_autoset = True
        config.vision_encoder._attn_implementation_autoset = True
        config.decoder._attn_implementation_autoset = True

        model = model_cls.from_pretrained(
            self.checkpoint, dtype=dtype, config=config, ignore_mismatched_sizes=True
        ).to(device)
        model = model.eval()

        if settings.COMPILE_ALL or settings.COMPILE_FOUNDATION:
            torch._dynamo.config.cache_size_limit = 1000
            torch._dynamo.config.suppress_errors = True
            torch._dynamo.config.specialize_int = False
            torch._dynamo.config.allow_unspec_int_on_nn_module = True
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.recompile_limit = 32

            logger.info(
                f"Compiling foundation model {self.checkpoint} on device {device} with dtype {dtype}"
            )
            compile_args = get_compile_args(device)
            model.vision_encoder = torch.compile(model.vision_encoder, **compile_args)
            model.decoder = torch.compile(model.decoder, **compile_args)

        logger.debug(
            f"Loaded recognition model {self.checkpoint} from {SuryaModel.get_local_path(self.checkpoint)} onto device {model.device} with dtype {dtype}, using decoder attention mechanism {model.config.decoder._attn_implementation}, encoder attention mechanism {model.config.vision_encoder._attn_implementation}."
        )
        return model

    def processor(
        self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE_BFLOAT
    ) -> SuryaOCRProcessor:
        config: SuryaModelConfig = SuryaModelConfig.from_pretrained(self.checkpoint)

        ocr_tokenizer = SuryaOCRTokenizer(
            special_tokens=config.special_ocr_tokens, model_checkpoint=self.checkpoint
        )

        processor = SuryaOCRProcessor(
            ocr_tokenizer=ocr_tokenizer,
            blank_bbox_token_id=config.blank_bbox_token_id,
            num_register_tokens=config.num_register_tokens,
            sequence_length=None,
            patch_size=config.vision_encoder.patch_size,
            merge_size=config.vision_encoder.spatial_merge_size,
            model_device=device,
            num_beacon_tokens=config.num_beacon_tokens,
            beacon_token_interval=config.beacon_token_interval,
        )

        return processor

```
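
A minimal sketch of how the loader is consumed, based only on the constructor and methods above; the predictor wiring at the end mirrors the pattern used in `benchmark/layout.py`.

```python
# Sketch: build the foundation model and processor from the loader directly.
from surya.foundation import FoundationPredictor
from surya.foundation.loader import FoundationModelLoader
from surya.layout import LayoutPredictor
from surya.settings import settings

loader = FoundationModelLoader()   # falls back to settings.FOUNDATION_MODEL_CHECKPOINT
model = loader.model()             # device, dtype and attention implementation resolved from settings
processor = loader.processor()

# Or let the higher-level predictors own the loading, as benchmark/layout.py does:
foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_predictor = LayoutPredictor(foundation_predictor)
```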

--------------------------------------------------------------------------------
/surya/table_rec/shaper.py:
--------------------------------------------------------------------------------

```python
import math
from typing import List, Dict
import numpy as np

from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM


class LabelShaper:
    def __init__(self):
        self.property_keys = [k for (k, kcount, mode) in BOX_PROPERTIES]

    def dict_to_labels(self, label_components: List[dict]):
        if len(label_components) == 0:
            return []

        out_list = []
        for (k, kcount, mode) in BOX_PROPERTIES:
            for label_component in label_components:
                if k not in label_component:
                    raise ValueError(f"Missing key {k} in label component {label_component}")

                if mode == "classification":
                    assert isinstance(label_component[k], int)
                elif mode == "regression":
                    assert (isinstance(label_component[k], (int, float)) and kcount == 1) or len(label_component[k]) == kcount
                else:
                    raise ValueError(f"Invalid mode {k['mode']} for key {k}")

        for label_component in label_components:
            bbox = label_component["bbox"]
            for i in range(len(bbox)):
                if bbox[i] < 0:
                    bbox[i] = 0
                if bbox[i] > BOX_DIM:
                    bbox[i] = BOX_DIM

            vector = []
            for (k, kcount, mode) in BOX_PROPERTIES:
                item = label_component[k]
                if isinstance(item, (list, tuple)):
                    vector += list(item)
                elif isinstance(item, (float, int)):
                    if mode == "classification":
                        # Shift up for model
                        item += SPECIAL_TOKENS
                    vector.append(item)
                else:
                    raise ValueError(f"Invalid item {item} for key {k}")

            out_list.append(vector)

        return out_list

    def component_idx(self, key):
        idx = 0
        for (k, kcount, mode) in BOX_PROPERTIES:
            if mode == "regression":
                incr = kcount
            elif mode == "classification":
                incr = 1
            else:
                raise ValueError(f"Invalid mode {mode} for key {k}")
            if k == key:
                return (idx, idx + incr)
            idx += incr
        raise ValueError(f"Key {key} not found in properties")

    def get_box_property(self, key, add_special_tokens=True):
        for (k, kcount, mode) in BOX_PROPERTIES:
            if k == key:
                # Add special token count
                if mode == "classification" and add_special_tokens:
                    kcount += SPECIAL_TOKENS
                return (k, kcount, mode)
        raise ValueError(f"Key {key} not found in properties")

    def component_idx_dict(self):
        idx_dict = {}
        for (k, kcount, mode) in BOX_PROPERTIES:
            idx_dict[k] = self.component_idx(k)
        return idx_dict

    def convert_polygons_to_bboxes(self, label_components: List[Dict]):
        for i, label_component in enumerate(label_components):
            poly = label_component["polygon"]
            poly = np.clip(poly, 0, BOX_DIM)

            (x1, y1), (x2, y2), (x3, y3), (x4, y4) = poly
            cx = (x1 + x2 + x3 + x4) / 4
            cy = (y1 + y2 + y3 + y4) / 4
            width = (x2 + x3) / 2 - (x1 + x4) / 2
            height = (y3 + y4) / 2 - (y2 + y1) / 2
            bottom_avg_x = (x3 + x4) / 2
            top_avg_x = (x1 + x2) / 2
            right_avg_y = (y2 + y3) / 2
            left_avg_y = (y1 + y4) / 2

            x_skew = bottom_avg_x - top_avg_x
            y_skew = right_avg_y - left_avg_y
            x_skew += BOX_DIM // 2 # Shift up into positive space
            y_skew += BOX_DIM // 2 # Shift up into positive space
            new_poly = [
                cx,
                cy,
                width,
                height,
                x_skew,
                y_skew
            ]
            label_component["bbox"] = new_poly

        return label_components

    def convert_bbox_to_polygon(self, box, skew_scaler=BOX_DIM // 2, skew_min=.001):
        cx = box[0]
        cy = box[1]
        width = box[2]
        height = box[3]
        x1 = cx - width / 2
        y1 = cy - height / 2
        x2 = cx + width / 2
        y2 = cy + height / 2
        skew_x = math.floor((box[4] - skew_scaler) / 2)
        skew_y = math.floor((box[5] - skew_scaler) / 2)

        # Ensures we don't get slightly warped boxes
        # Note that the values are later scaled, so this is in 1/1024 space
        if abs(skew_x) < skew_min:
            skew_x = 0

        if abs(skew_y) < skew_min:
            skew_y = 0

        polygon = [x1 - skew_x, y1 - skew_y, x2 - skew_x, y1 + skew_y, x2 + skew_x, y2 + skew_y, x1 + skew_x,
                   y2 - skew_y]
        poly = []
        for i in range(4):
            poly.append([
                polygon[2 * i],
                polygon[2 * i + 1]
            ])
        return poly




```
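
A small round-trip sketch of the six-value box encoding above (cx, cy, width, height, x_skew, y_skew, with the skews shifted by BOX_DIM // 2); the rectangle is made up for illustration.

```python
# Sketch: encode a (made-up) quadrilateral label, then decode a box back to corner points.
from surya.table_rec.shaper import LabelShaper

shaper = LabelShaper()

components = [{"polygon": [(100, 200), (400, 200), (400, 260), (100, 260)]}]
components = shaper.convert_polygons_to_bboxes(components)
print(components[0]["bbox"])  # [cx, cy, width, height, shifted x_skew, shifted y_skew]

corners = shaper.convert_bbox_to_polygon(components[0]["bbox"])
print(corners)  # four [x, y] points approximating the original rectangle
```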

--------------------------------------------------------------------------------
/benchmark/detection.py:
--------------------------------------------------------------------------------

```python
import collections
import copy
import json

import click

from benchmark.utils.bbox import get_pdf_lines
from benchmark.utils.metrics import precision_recall
from benchmark.utils.tesseract import tesseract_parallel
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
from surya.debug.draw import draw_polys_on_image
from surya.common.util import rescale_bbox
from surya.settings import settings
from surya.detection import DetectionPredictor

import os
import time
from tabulate import tabulate
import datasets


@click.command(help="Benchmark detection model.")
@click.option("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
@click.option("--tesseract", is_flag=True, help="Run tesseract as well.", default=False)
def main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool):
    det_predictor = DetectionPredictor()

    if pdf_path is not None:
        pathname = pdf_path
        doc = open_pdf(pdf_path)
        page_count = len(doc)
        page_indices = list(range(page_count))
        page_indices = page_indices[:max_rows]

        images = get_page_images(doc, page_indices)
        doc.close()

        image_sizes = [img.size for img in images]
        correct_boxes = get_pdf_lines(pdf_path, image_sizes)
    else:
        pathname = "det_bench"
        # These have already been shuffled randomly, so sampling from the start is fine
        dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
        images = list(dataset["image"])
        images = convert_if_not_rgb(images)
        correct_boxes = []
        for i, boxes in enumerate(dataset["bboxes"]):
            img_size = images[i].size
            # 1000,1000 is bbox size for doclaynet
            correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])

    if settings.DETECTOR_STATIC_CACHE:
        # Run through one batch to compile the model
        det_predictor(images[:1])

    start = time.time()
    predictions = det_predictor(images)
    surya_time = time.time() - start

    if tesseract:
        start = time.time()
        tess_predictions = tesseract_parallel(images)
        tess_time = time.time() - start
    else:
        tess_predictions = [None] * len(images)
        tess_time = None

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    page_metrics = collections.OrderedDict()
    for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
        surya_boxes = [s.bbox for s in sb.bboxes]
        surya_polys = [s.polygon for s in sb.bboxes]

        surya_metrics = precision_recall(surya_boxes, cb)
        if tb is not None:
            tess_metrics = precision_recall(tb, cb)
        else:
            tess_metrics = None

        page_metrics[idx] = {
            "surya": surya_metrics,
            "tesseract": tess_metrics
        }

        if debug:
            bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
            bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))

    mean_metrics = {}
    metric_types = sorted(page_metrics[0]["surya"].keys())
    models = ["surya"]
    if tesseract:
        models.append("tesseract")

    for k in models:
        for m in metric_types:
            metric = []
            for page in page_metrics:
                metric.append(page_metrics[page][k][m])
            if k not in mean_metrics:
                mean_metrics[k] = {}
            mean_metrics[k][m] = sum(metric) / len(metric)

    out_data = {
        "times": {
            "surya": surya_time,
            "tesseract": tess_time
        },
        "metrics": mean_metrics,
        "page_metrics": page_metrics
    }

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)

    table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
    table_data = [
        ["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
    ]
    if tesseract:
        table_data.append(
            ["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types]
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.  There is a precision penalty for multiple boxes overlapping reference lines.")
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()

```

--------------------------------------------------------------------------------
/surya/detection/heatmap.py:
--------------------------------------------------------------------------------

```python
from typing import List

import cv2
import numpy as np
from PIL import Image

from surya.common.util import clean_boxes
from surya.detection import TextDetectionResult
from surya.common.polygon import PolygonBox
from surya.settings import settings


def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7):
    # Find average intensity of top 10% pixels
    flat_map = linemap.ravel()
    top_10_count = int(len(flat_map) * 0.9)
    avg_intensity = np.mean(np.partition(flat_map, top_10_count)[top_10_count:])
    scaling_factor = np.clip(avg_intensity / typical_top10_avg, 0, 1) ** (1 / 2)

    low_text = np.clip(low_text * scaling_factor, 0.1, 0.6)
    text_threshold = np.clip(text_threshold * scaling_factor, 0.15, 0.8)

    return text_threshold, low_text


def detect_boxes(linemap, text_threshold, low_text):
    # From CRAFT - https://github.com/clovaai/CRAFT-pytorch
    # Modified to return boxes and for speed, accuracy
    img_h, img_w = linemap.shape

    text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text)

    text_score_comb = (linemap > low_text).astype(np.uint8)
    label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(
        text_score_comb, connectivity=4
    )

    det = []
    confidences = []
    max_confidence = 0

    for k in range(1, label_count):
        # size filtering
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10:
            continue

        # make segmentation map
        x, y, w, h = stats[
            k,
            [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
        ]

        try:
            niter = int(np.sqrt(min(w, h)))
        except ValueError:
            niter = 0

        buffer = 1
        sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer)
        ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer)

        mask = labels[sy:ey, sx:ex] == k
        selected_linemap = linemap[sy:ey, sx:ex][mask]
        if selected_linemap.size == 0:
            continue

        line_max = np.max(selected_linemap)

        # thresholding
        if line_max < text_threshold:
            continue

        segmap = mask.astype(np.uint8)

        ksize = buffer + niter
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (ksize, ksize))
        selected_segmap = cv2.dilate(segmap, kernel)

        # make box
        y_inds, x_inds = np.nonzero(selected_segmap)
        x_inds += sx
        y_inds += sy
        np_contours = np.column_stack((x_inds, y_inds))
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        # align diamond-shape
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            left, right = np_contours[:, 0].min(), np_contours[:, 0].max()
            top, bottom = np_contours[:, 1].min(), np_contours[:, 1].max()
            box = np.array(
                [[left, top], [right, top], [right, bottom], [left, bottom]],
                dtype=np.float32,
            )

        # make clock-wise order
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4 - startidx, 0)

        max_confidence = max(max_confidence, line_max)

        confidences.append(line_max)
        det.append(box)

    if max_confidence > 0:
        confidences = [c / max_confidence for c in confidences]
    return det, confidences


def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:
    if text_threshold is None:
        text_threshold = settings.DETECTOR_TEXT_THRESHOLD
    if low_text is None:
        low_text = settings.DETECTOR_BLANK_THRESHOLD

    if textmap.dtype != np.float32:
        textmap = textmap.astype(np.float32)

    boxes, confidences = detect_boxes(textmap, text_threshold, low_text)
    # From point form to box form
    return [
        PolygonBox(polygon=box, confidence=confidence)
        for box, confidence in zip(boxes, confidences)
    ]


def get_and_clean_boxes(
    textmap, processor_size, image_size, text_threshold=None, low_text=None
) -> List[PolygonBox]:
    bboxes = get_detected_boxes(textmap, text_threshold, low_text)
    for bbox in bboxes:
        bbox.rescale(processor_size, image_size)
        bbox.fit_to_bounds([0, 0, image_size[0], image_size[1]])

    bboxes = clean_boxes(bboxes)
    return bboxes


def parallel_get_boxes(preds, orig_sizes, include_maps=False):
    heatmap, affinity_map = preds
    heat_img, aff_img = None, None

    if include_maps:
        heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
        aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
    for box in bboxes:
        # Skip for vertical boxes
        if box.height < 3 * box.width:
            box.expand(x_margin=0, y_margin=settings.DETECTOR_BOX_Y_EXPAND_MARGIN)
            box.fit_to_bounds(
                [0, 0, orig_sizes[0], orig_sizes[1]]
            )  # Fix any bad expands

    result = TextDetectionResult(
        bboxes=bboxes,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]],
    )
    return result

```
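
A self-contained sketch that runs the box extraction above on a synthetic heatmap instead of real model output; the bright rectangle stands in for a text-line activation, and thresholds come from the `surya.settings` defaults.

```python
# Sketch: extract polygon boxes from a synthetic text heatmap.
import numpy as np

from surya.detection.heatmap import get_detected_boxes

heatmap = np.zeros((256, 256), dtype=np.float32)
heatmap[40:60, 30:200] = 0.9  # fake activation for a single text line

for box in get_detected_boxes(heatmap):
    print(box.polygon, box.confidence)
```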

--------------------------------------------------------------------------------
/surya/recognition/util.py:
--------------------------------------------------------------------------------

```python
import re
from typing import List, Tuple

import numpy
import torch

from surya.common.polygon import PolygonBox
from surya.recognition.schema import TextLine, TextWord, TextChar

MATH_SYMBOLS = ["+", "-", "*", "=", "^", "_", "\\", "{", "}"]


def unwrap_math(text: str) -> str:
    if len(text) > 50:
        return text

    # Detected as math, but does not contain LaTeX commands
    if (
        re.match(r'^\s*<math(?:\s+display="inline")?.*?</math>\s*$', text, re.DOTALL)
        and text.count("<math") == 1
        and not any([symb in text for symb in MATH_SYMBOLS])
    ):
        # Remove math tags
        text = re.sub(r"<math.*?>", "", text)
        text = re.sub(r"</math>", "", text)

    return text


MATH_BLOCK = re.compile(r"(<math\b[^>]*>)(.*?)</math>", flags=re.I | re.S)
STRIP_TAGS = re.compile(r"</?(?:br|u|del|mark|i|b|sup|sub)\b[^>]*>", flags=re.I | re.S)
DEFAULT_TAGS_TO_FILTER = ["p", "li", "ul", "ol", "table", "td", "tr", "th", "tbody", "pre"]

def filter_blacklist_tags(text_chars: List[TextChar], tags_to_filter: List[str] = None) -> List[TextChar]:
    filtered_chars = []
    char_buffer = []
    in_tag = False
    if tags_to_filter is None:
        tags_to_filter = DEFAULT_TAGS_TO_FILTER

    for text_char in text_chars:
        char = text_char.text

        if char.startswith("<") or in_tag:
            in_tag = True
            char_buffer.append(text_char)
            if char.endswith(">"):
                full_tag = ''.join(c.text for c in char_buffer)
                inner = full_tag[1:-1].strip()  # remove < >
                inner = inner.strip("/")  # remove '/'
                
                # Possible that it is just an empty <>
                if not inner:
                    filtered_chars.extend(char_buffer)
                    in_tag = False
                    char_buffer = []
                    continue
                
                tag_name_candidate = inner.split()[0]   # remove any attributes
                if tag_name_candidate in tags_to_filter:
                    # Discard tag
                    pass
                else:
                    # Keep tag
                    filtered_chars.extend(char_buffer)

                in_tag = False
                char_buffer = []
        else:
            filtered_chars.append(text_char)

    # Flush buffer if we never reached a tag close
    if char_buffer:
        filtered_chars.extend(char_buffer)

    return filtered_chars


def clean_math_tags(html: str) -> str:
    # strip unwanted tags inside every well‑formed <math>…</math>
    def _inner(m):
        inner = STRIP_TAGS.sub("", m.group(2))
        return f"{m.group(1)}{inner}</math>" if inner.strip() else ""

    cleaned = MATH_BLOCK.sub(_inner, html)

    # drop only orphan *closing* </math> tags
    depth = 0
    parts = []
    for token in re.split(r"(</?math[^>]*>)", cleaned, flags=re.I):
        if token.lower().startswith("<math"):
            depth += 1
            parts.append(token)
        elif token.lower() == "</math>":
            if depth:  # keep it only if it matches an open
                depth -= 1
                parts.append(token)
            # else: skip orphan closing tag
        else:
            parts.append(token)
    return "".join(parts)


def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25):
    # Sorts in reading order.  Not 100% accurate, this should only
    # be used as a starting point for more advanced sorting.
    vertical_groups = {}
    for line in lines:
        bbox_y = line.bbox[1] if isinstance(line, TextLine) else line["bbox"][1]
        # Quantize the y coordinate so lines within the tolerance share a group
        group_key = round(bbox_y / tolerance) * tolerance
        if group_key not in vertical_groups:
            vertical_groups[group_key] = []
        vertical_groups[group_key].append(line)

    # Sort each group horizontally and flatten the groups into a single list
    sorted_lines = []
    for _, group in sorted(vertical_groups.items()):
        sorted_group = sorted(
            group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0]
        )
        sorted_lines.extend(sorted_group)

    return sorted_lines


def clean_close_polygons(bboxes: List[List[List[int]]], thresh: float = 0.1):
    if len(bboxes) < 2:
        return bboxes

    new_bboxes = [bboxes[0]]
    for i in range(1, len(bboxes)):
        close = True
        prev_bbox = bboxes[i - 1]
        bbox = bboxes[i]
        for j in range(4):
            if (
                abs(bbox[j][0] - prev_bbox[j][0]) > thresh
                or abs(bbox[j][1] - prev_bbox[j][1]) > thresh
            ):
                close = False
                break

        if not close:
            new_bboxes.append(bboxes[i])

    return new_bboxes


def words_from_chars(chars: List[TextChar], line_box: PolygonBox):
    words = []
    word = None
    for i, char in enumerate(chars):
        if not char.bbox_valid:
            if word:
                words.append(word)
                word = None
            continue

        if not word:
            word = TextWord(**char.model_dump())

            # Fit bounds to line if first word
            if i == 0:
                word.merge_left(line_box)

        elif not char.text.strip():
            if word:
                words.append(word)
            word = None
        else:
            # Merge bboxes
            word.merge(char)
            word.text = word.text + char.text

            if i == len(chars) - 1:
                word.merge_right(line_box)
    if word:
        words.append(word)

    return words
```
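
A quick sketch of the two pure string helpers on toy inputs; the expected outputs follow directly from the regexes defined above.

```python
# Sketch: math tag cleanup helpers on toy strings.
from surya.recognition.util import unwrap_math, clean_math_tags

# Wrapped as math but with no LaTeX command characters, so the tags are stripped.
print(unwrap_math('<math display="inline">hello world</math>'))
# -> 'hello world'

# Formatting tags inside <math> are removed and the orphan closing </math> is dropped.
print(clean_math_tags('<math display="inline"><i>x</i>+1</math> stray </math>'))
# -> '<math display="inline">x+1</math> stray '
```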

--------------------------------------------------------------------------------
/surya/common/s3.py:
--------------------------------------------------------------------------------

```python
import json
import os
import shutil
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import requests
from tqdm import tqdm

from surya.logging import get_logger
from surya.settings import settings

logger = get_logger()

# Lock file expiration time in seconds (10 minutes)
LOCK_EXPIRATION = 600


def join_urls(url1: str, url2: str):
    url1 = url1.rstrip("/")
    url2 = url2.lstrip("/")
    return f"{url1}/{url2}"


def get_model_name(pretrained_model_name_or_path: str):
    return pretrained_model_name_or_path.split("/")[0]


def download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024):
    local_path = Path(local_path)
    try:
        response = requests.get(remote_path, stream=True, allow_redirects=True)
        response.raise_for_status()  # Raise an exception for bad status codes

        # Get file size from headers for progress bar
        total_size = int(response.headers.get('content-length', 0))
        
        # Create progress bar with file name and size info
        filename = local_path.name
        pbar = tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
            desc=f"Downloading {filename}",
            miniters=1
        )

        with open(local_path, "wb") as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    pbar.update(len(chunk))
        
        pbar.close()
        return local_path
    except Exception as e:
        if local_path.exists():
            local_path.unlink()
        logger.error(f"Download error for file {remote_path}: {str(e)}")
        raise


def check_manifest(local_dir: str):
    local_dir = Path(local_dir)
    manifest_path = local_dir / "manifest.json"
    if not os.path.exists(manifest_path):
        return False

    try:
        with open(manifest_path, "r") as f:
            manifest = json.load(f)
        for file in manifest["files"]:
            if not os.path.exists(local_dir / file):
                return False
    except Exception:
        return False

    return True


def download_directory(remote_path: str, local_dir: str):
    model_name = get_model_name(remote_path)
    s3_url = join_urls(settings.S3_BASE_URL, remote_path)
    # Check to see if it's already downloaded
    model_exists = check_manifest(local_dir)
    if model_exists:
        return

    # Use tempfile.TemporaryDirectory to automatically clean up
    with tempfile.TemporaryDirectory() as temp_dir:
        # Download the manifest file
        manifest_file = join_urls(s3_url, "manifest.json")
        manifest_path = os.path.join(temp_dir, "manifest.json")
        download_file(manifest_file, manifest_path)

        # List and download all files
        with open(manifest_path, "r") as f:
            manifest = json.load(f)

        pbar = tqdm(
            desc=f"Downloading {model_name} model to {local_dir}",
            total=len(manifest["files"]),
        )

        with ThreadPoolExecutor(
            max_workers=settings.PARALLEL_DOWNLOAD_WORKERS
        ) as executor:
            futures = []
            for file in manifest["files"]:
                remote_file = join_urls(s3_url, file)
                local_file = os.path.join(temp_dir, file)
                futures.append(executor.submit(download_file, remote_file, local_file))

            for future in futures:
                future.result()
                pbar.update(1)

        pbar.close()

        # Move all files to new directory
        for file in os.listdir(temp_dir):
            shutil.move(os.path.join(temp_dir, file), local_dir)


class S3DownloaderMixin:
    s3_prefix = "s3://"

    @classmethod
    def get_local_path(cls, pretrained_model_name_or_path) -> str:
        if pretrained_model_name_or_path.startswith(cls.s3_prefix):
            pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
                cls.s3_prefix, ""
            )
            cache_dir = settings.MODEL_CACHE_DIR
            local_path = os.path.join(cache_dir, pretrained_model_name_or_path)
            os.makedirs(local_path, exist_ok=True)
        else:
            local_path = ""
        return local_path

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Allow loading models directly from the hub, or using s3
        if not pretrained_model_name_or_path.startswith(cls.s3_prefix):
            return super().from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )

        local_path = cls.get_local_path(pretrained_model_name_or_path)
        pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
            cls.s3_prefix, ""
        )

        # Retry logic for downloading the model folder
        retries = 3
        delay = 5
        attempt = 0
        success = False
        while not success and attempt < retries:
            try:
                download_directory(pretrained_model_name_or_path, local_path)
                success = True  # If download succeeded
            except Exception as e:
                logger.error(
                    f"Error downloading model from {pretrained_model_name_or_path}. Attempt {attempt + 1} of {retries}. Error: {e}"
                )
                attempt += 1
                if attempt < retries:
                    logger.info(f"Retrying in {delay} seconds...")
                    time.sleep(delay)  # Wait before retrying
                else:
                    logger.error(
                        f"Failed to download {pretrained_model_name_or_path} after {retries} attempts."
                    )
                    raise e  # Reraise exception after max retries

        return super().from_pretrained(local_path, *args, **kwargs)

```
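
A sketch of how the mixin composes with a Hugging Face `from_pretrained` class: `s3://`-prefixed checkpoints are mirrored into `settings.MODEL_CACHE_DIR` and then loaded from the local copy, while other checkpoints pass straight through. The `DemoConfig` class and checkpoint names below are hypothetical, shown only to illustrate the mixin ordering.

```python
# Sketch: compose the mixin with a transformers class (DemoConfig is hypothetical).
from transformers import PretrainedConfig

from surya.common.s3 import S3DownloaderMixin


class DemoConfig(S3DownloaderMixin, PretrainedConfig):
    # The mixin must come first in the MRO so its from_pretrained runs before transformers'.
    pass


# Hub or local checkpoints are handled by PretrainedConfig.from_pretrained as usual:
# cfg = DemoConfig.from_pretrained("some-org/some-checkpoint")  # hypothetical name

# s3:// checkpoints are downloaded (with retries) into settings.MODEL_CACHE_DIR first:
# cfg = DemoConfig.from_pretrained("s3://some-model/2024_01_01")  # hypothetical name
```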

--------------------------------------------------------------------------------
/surya/detection/__init__.py:
--------------------------------------------------------------------------------

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List, Generator, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from PIL import Image
from tqdm import tqdm

from surya.common.predictor import BasePredictor
from surya.common.xla import mark_step

from surya.detection.loader import DetectionModelLoader
from surya.detection.parallel import FakeExecutor
from surya.detection.util import get_total_splits, split_image
from surya.detection.schema import TextDetectionResult
from surya.settings import settings
from surya.detection.heatmap import parallel_get_boxes


class DetectionPredictor(BasePredictor):
    model_loader_cls = DetectionModelLoader
    batch_size = settings.DETECTOR_BATCH_SIZE
    default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 36, "xla": 18}

    def __call__(
        self, images: List[Image.Image], batch_size=None, include_maps=False
    ) -> List[TextDetectionResult]:
        detection_generator = self.batch_detection(
            images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE
        )

        postprocessing_futures = []
        max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
        parallelize = (
            not settings.IN_STREAMLIT
            and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
        )
        executor = ThreadPoolExecutor if parallelize else FakeExecutor
        with executor(max_workers=max_workers) as e:
            for preds, orig_sizes in detection_generator:
                for pred, orig_size in zip(preds, orig_sizes):
                    postprocessing_futures.append(
                        e.submit(parallel_get_boxes, pred, orig_size, include_maps)
                    )

        return [future.result() for future in postprocessing_futures]

    def prepare_image(self, img):
        new_size = (self.processor.size["width"], self.processor.size["height"])

        # This double resize is actually necessary for downstream accuracy
        img.thumbnail(new_size, Image.Resampling.LANCZOS)
        img = img.resize(
            new_size, Image.Resampling.LANCZOS
        )  # Stretch smaller dimension to fit new size

        img = np.asarray(img, dtype=np.uint8)
        img = self.processor(img)["pixel_values"][0]
        img = torch.from_numpy(img)
        return img

    def batch_detection(
        self, images: List, batch_size=None, static_cache=False
    ) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
        assert all([isinstance(image, Image.Image) for image in images])
        if batch_size is None:
            batch_size = self.get_batch_size()
        heatmap_count = self.model.config.num_labels

        orig_sizes = [image.size for image in images]
        splits_per_image = [
            get_total_splits(size, self.processor.size["height"]) for size in orig_sizes
        ]

        batches = []
        current_batch_size = 0
        current_batch = []
        for i in range(len(images)):
            if current_batch_size + splits_per_image[i] > batch_size:
                if len(current_batch) > 0:
                    batches.append(current_batch)
                current_batch = []
                current_batch_size = 0
            current_batch.append(i)
            current_batch_size += splits_per_image[i]

        if len(current_batch) > 0:
            batches.append(current_batch)

        for batch_idx in tqdm(
            range(len(batches)), desc="Detecting bboxes", disable=self.disable_tqdm
        ):
            batch_image_idxs = batches[batch_idx]
            batch_images = [images[j].convert("RGB") for j in batch_image_idxs]

            split_index = []
            split_heights = []
            image_splits = []
            for image_idx, image in enumerate(batch_images):
                image_parts, split_height = split_image(
                    image, self.processor.size["height"]
                )
                image_splits.extend(image_parts)
                split_index.extend([image_idx] * len(image_parts))
                split_heights.extend(split_height)

            image_splits = [self.prepare_image(image) for image in image_splits]
            # Batch images in dim 0
            batch = torch.stack(image_splits, dim=0).to(self.model.dtype)
            if static_cache:
                batch = self.pad_to_batch_size(batch, batch_size)

            with settings.INFERENCE_MODE():
                pred = self.model(
                    pixel_values=batch.to(self.model.device)
                )  # Moving the to device here fixes issues with xla recompilation

            logits = pred.logits
            correct_shape = [
                self.processor.size["height"],
                self.processor.size["width"],
            ]
            current_shape = list(logits.shape[2:])
            if current_shape != correct_shape:
                logits = F.interpolate(
                    logits, size=correct_shape, mode="bilinear", align_corners=False
                )
            mark_step()

            logits = logits.to(torch.float32).cpu().numpy()
            preds = []
            for i, (idx, height) in enumerate(zip(split_index, split_heights)):
                # If our current prediction length is below the image idx, that means we have a new image
                # Otherwise, we need to add to the current image
                if len(preds) <= idx:
                    preds.append([logits[i][k] for k in range(heatmap_count)])
                else:
                    heatmaps = preds[idx]
                    pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                    if height < self.processor.size["height"]:
                        # Cut off padding to get original height
                        pred_heatmaps = [
                            pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps
                        ]

                    for k in range(heatmap_count):
                        heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                    preds[idx] = heatmaps

            yield preds, [orig_sizes[j] for j in batch_image_idxs]

        torch.cuda.empty_cache()

```
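
A minimal end-to-end sketch of the predictor's call interface, following the same pattern as `benchmark/detection.py`; the image path is a placeholder.

```python
# Sketch: detect text lines on a single image.
from PIL import Image

from surya.detection import DetectionPredictor

det_predictor = DetectionPredictor()           # loads model + processor via DetectionModelLoader
image = Image.open("page.png").convert("RGB")  # placeholder path

results = det_predictor([image], include_maps=False)
for line in results[0].bboxes:
    print(line.bbox, line.confidence)          # axis-aligned box and normalized confidence per line
```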

--------------------------------------------------------------------------------
/benchmark/utils/metrics.py:
--------------------------------------------------------------------------------

```python
from functools import partial
from itertools import repeat

import numpy as np
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def box_area(box):
    return (box[2] - box[0]) * (box[3] - box[1])


def calculate_iou(box1, box2, box1_only=False):
    intersection = intersection_area(box1, box2)
    union = box_area(box1)
    if not box1_only:
        union += box_area(box2) - intersection

    if union == 0:
        return 0
    return intersection / union


def match_boxes(preds, references):
    num_actual = len(references)
    num_predicted = len(preds)

    iou_matrix = np.zeros((num_actual, num_predicted))
    for i, actual in enumerate(references):
        for j, pred in enumerate(preds):
            iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True)

    sorted_indices = np.argsort(iou_matrix, axis=None)[::-1]
    sorted_ious = iou_matrix.flatten()[sorted_indices]
    actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape)

    assigned_actual = set()
    assigned_pred = set()

    matches = []
    for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious):
        i, j = idx
        if i not in assigned_actual and j not in assigned_pred:
            iou_val = iou_matrix[i, j]
            if iou_val > .95: # Account for rounding on box edges
                iou_val = 1.0
            matches.append((i, j, iou_val))
            assigned_actual.add(i)
            assigned_pred.add(j)

    unassigned_actual = set(range(num_actual)) - assigned_actual
    unassigned_pred = set(range(num_predicted)) - assigned_pred
    matches.extend([(i, None, -1.0) for i in unassigned_actual])
    matches.extend([(None, j, 0.0) for j in unassigned_pred])

    return matches

def penalized_iou_score(preds, references):
    matches = match_boxes(preds, references)
    iou = sum([match[2] for match in matches]) / len(matches)
    return iou

def intersection_pixels(box1, box2):
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    if x_right < x_left or y_bottom < y_top:
        return set()

    x_left, x_right = int(x_left), int(x_right)
    y_top, y_bottom = int(y_top), int(y_bottom)

    coords = np.meshgrid(np.arange(x_left, x_right), np.arange(y_top, y_bottom))
    pixels = set(zip(coords[0].flat, coords[1].flat))

    return pixels


def calculate_coverage(box, other_boxes, penalize_double=False):
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    if box_area == 0:
        return 0

    # find total coverage of the box
    covered_pixels = set()
    double_coverage = list()
    for other_box in other_boxes:
        ia = intersection_pixels(box, other_box)
        double_coverage.append(list(covered_pixels.intersection(ia)))
        covered_pixels = covered_pixels.union(ia)

    # Penalize double coverage - having multiple bboxes overlapping the same pixels
    double_coverage_penalty = sum(len(pixels) for pixels in double_coverage)
    if not penalize_double:
        double_coverage_penalty = 0
    covered_pixels_count = max(0, len(covered_pixels) - double_coverage_penalty)
    return covered_pixels_count / box_area


def intersection_area(box1, box2):
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    return (x_right - x_left) * (y_bottom - y_top)


def calculate_coverage_fast(box, other_boxes, penalize_double=False):
    box = np.array(box)
    other_boxes = np.array(other_boxes)

    # Calculate box area
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    if box_area == 0:
        return 0

    x_left = np.maximum(box[0], other_boxes[:, 0])
    y_top = np.maximum(box[1], other_boxes[:, 1])
    x_right = np.minimum(box[2], other_boxes[:, 2])
    y_bottom = np.minimum(box[3], other_boxes[:, 3])

    widths = np.maximum(0, x_right - x_left)
    heights = np.maximum(0, y_bottom - y_top)
    intersect_areas = widths * heights

    total_intersect = np.sum(intersect_areas)

    return min(1.0, total_intersect / box_area)


def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
    if len(references) == 0:
        return {
            "precision": 1,
            "recall": 1,
        }

    if len(preds) == 0:
        return {
            "precision": 0,
            "recall": 0,
        }

    # If we're not penalizing double coverage, we can use a faster calculation
    coverage_func = calculate_coverage_fast
    if penalize_double:
        coverage_func = calculate_coverage

    with ThreadPoolExecutor(max_workers=workers) as executor:
        precision_func = partial(coverage_func, penalize_double=penalize_double)
        precision_iou = executor.map(precision_func, preds, repeat(references))
        reference_iou = executor.map(coverage_func, references, repeat(preds))

    precision_classes = [1 if i > threshold else 0 for i in precision_iou]
    precision = sum(precision_classes) / len(precision_classes)

    recall_classes = [1 if i > threshold else 0 for i in reference_iou]
    recall = sum(recall_classes) / len(recall_classes)

    return {
        "precision": precision,
        "recall": recall,
    }


def mean_coverage(preds, references):
    coverages = []

    for box1 in references:
        coverage = calculate_coverage(box1, preds)
        coverages.append(coverage)

    for box2 in preds:
        coverage = calculate_coverage(box2, references)
        coverages.append(coverage)

    # Calculate the average coverage over all comparisons
    if len(coverages) == 0:
        return 0
    coverage = sum(coverages) / len(coverages)
    return {"coverage": coverage}


def rank_accuracy(preds, references):
    # Preds and references need to be aligned so each position refers to the same bbox
    pairs = []
    for i, pred in enumerate(preds):
        for j, pred2 in enumerate(preds):
            if i == j:
                continue
            pairs.append((i, j, pred > pred2))

    # Find how many of the prediction rankings are correct
    correct = 0
    for i, ref in enumerate(references):
        for j, ref2 in enumerate(references):
            if (i, j, ref > ref2) in pairs:
                correct += 1

    return correct / len(pairs)
```
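
As a quick sanity check, the metrics above can be exercised on a couple of hand-made boxes. This is a minimal sketch, assuming only that the module is importable as `benchmark.utils.metrics` (the path shown in the header); the box coordinates are arbitrary examples.

```python
# Minimal sketch exercising the metrics above; box values are arbitrary examples.
from benchmark.utils.metrics import penalized_iou_score, precision_recall, rank_accuracy

references = [[0, 0, 10, 10], [20, 0, 30, 10]]                # ground truth (x1, y1, x2, y2)
preds = [[0, 0, 10, 10], [21, 1, 31, 11], [50, 50, 60, 60]]   # one extra spurious box

print(penalized_iou_score(preds, references))                  # < 1.0: the extra box is penalized
print(precision_recall(preds, references, penalize_double=False))  # e.g. precision 2/3, recall 1.0
print(rank_accuracy([0.1, 0.9, 0.5], [1, 3, 2]))               # 1.0: predicted ordering matches
```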

--------------------------------------------------------------------------------
/surya/common/donut/processor.py:
--------------------------------------------------------------------------------

```python
from typing import Dict, Union, Optional, List, Iterable

import cv2
from torch import TensorType
from transformers import ImageProcessingMixin
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import pad, normalize
from transformers.image_utils import (
    ImageInput,
    ChannelDimension,
    make_list_of_images,
    get_image_size,
)
import numpy as np
from PIL import Image
import PIL
from transformers.utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD

from surya.common.s3 import S3DownloaderMixin
from surya.settings import settings


class SuryaEncoderImageProcessor(S3DownloaderMixin, ImageProcessingMixin):
    def __init__(
        self,
        *args,
        max_size=None,
        align_long_axis=False,
        rescale_factor: Union[int, float] = 1 / 255,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
        self.max_size = max_size
        self.do_align_long_axis = align_long_axis
        self.resample = Image.Resampling.BILINEAR
        self.rescale_factor = rescale_factor
        self.image_mean = (
            image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        )
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images."""
        return self.preprocess(images, **kwargs)

    @classmethod
    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
        max_width, max_height = size["width"], size["height"]

        resized_image = cv2.resize(
            image, (max_width, max_height), interpolation=interpolation
        )
        resized_image = resized_image.transpose(2, 0, 1)

        return resized_image

    def process_inner(self, images: List[np.ndarray]):
        assert images[0].shape[2] == 3  # RGB input images, channel dim last

        if self.do_align_long_axis:
            # Rotate if the bbox is wider than it is tall
            images = [
                SuryaEncoderImageProcessor.align_long_axis(
                    image, size=self.max_size, input_data_format=ChannelDimension.LAST
                )
                for image in images
            ]

            # Verify that the image is wider than it is tall
            for img in images:
                assert img.shape[1] >= img.shape[0]

        # This also applies the right channel dim format, to channel x height x width
        images = [
            SuryaEncoderImageProcessor.numpy_resize(img, self.max_size, self.resample)
            for img in images
        ]
        assert images[0].shape[0] == 3  # RGB input images, channel dim first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Pads with 255 (whitespace)
        # Pad to max size to improve performance
        max_size = self.max_size
        images = [
            SuryaEncoderImageProcessor.pad_image(
                image=image,
                size=max_size,
                input_data_format=ChannelDimension.FIRST,
                pad_value=settings.RECOGNITION_PAD_VALUE,
            )
            for image in images
        ]

        # Rescale and normalize
        for idx in range(len(images)):
            images[idx] = (images[idx].astype(np.float64) * self.rescale_factor).astype(
                np.float32
            )

        images = [
            SuryaEncoderImageProcessor.normalize(
                img,
                mean=self.image_mean,
                std=self.image_std,
                input_data_format=ChannelDimension.FIRST,
            )
            for img in images
        ]

        return images

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        images = make_list_of_images(images)

        # Convert to numpy for later processing steps
        images = [np.array(img) for img in images]
        images = self.process_inner(images)

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @classmethod
    def pad_image(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_value: float = 0.0,
    ) -> np.ndarray:
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_width = output_width - input_width
        delta_height = output_height - input_height

        assert delta_width >= 0 and delta_height >= 0

        pad_top = delta_height // 2
        pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(
            image,
            padding,
            data_format=data_format,
            input_data_format=input_data_format,
            constant_values=pad_value,
        )

    @classmethod
    def align_long_axis(
        cls, image: np.ndarray, size: Dict[str, int], **kwargs
    ) -> np.ndarray:
        input_height, input_width = image.shape[:2]
        output_height, output_width = size["height"], size["width"]

        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            image = np.rot90(image, 3)

        return image

    @classmethod
    def normalize(
        cls,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        return normalize(
            image,
            mean=mean,
            std=std,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

```
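
For orientation, here is a rough sketch of driving the processor directly. The `max_size` value below is an arbitrary example; in practice the size comes from the checkpoint's preprocessor config rather than being hard-coded.

```python
# Rough sketch of using SuryaEncoderImageProcessor directly; max_size is an
# arbitrary example value, not the size the real checkpoints use.
from PIL import Image
from surya.common.donut.processor import SuryaEncoderImageProcessor

processor = SuryaEncoderImageProcessor(max_size={"height": 256, "width": 256})
image = Image.new("RGB", (640, 480), color="white")  # stand-in for a document page

batch = processor([image], return_tensors="pt")
print(batch["pixel_values"].shape)  # (1, 3, 256, 256) after resize, pad, rescale, normalize
```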

--------------------------------------------------------------------------------
/benchmark/table_recognition.py:
--------------------------------------------------------------------------------

```python
import click
import collections
import json

from surya.debug.draw import draw_bboxes_on_image
from tabulate import tabulate

from surya.input.processing import convert_if_not_rgb
from surya.table_rec import TableRecPredictor
from surya.settings import settings
from benchmark.utils.metrics import penalized_iou_score
from benchmark.utils.tatr import load_tatr, batch_inference_tatr
import os
import time
import datasets


@click.command(help="Benchmark table rec dataset")
@click.option(
    "--results_dir",
    type=str,
    help="Path to JSON file with benchmark results.",
    default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
    "--max_rows",
    type=int,
    help="Maximum number of images to run benchmark on.",
    default=512,
)
@click.option("--tatr", is_flag=True, help="Run table transformer.", default=False)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
def main(results_dir: str, max_rows: int, tatr: bool, debug: bool):
    table_rec_predictor = TableRecPredictor()

    pathname = "table_rec_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    split = "train"
    if max_rows is not None:
        split = f"train[:{max_rows}]"
    dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)
    images = list(dataset["image"])
    images = convert_if_not_rgb(images)

    if settings.TABLE_REC_STATIC_CACHE:
        # Run through one batch to compile the model
        table_rec_predictor(images[:1])

    start = time.time()
    table_rec_predictions = table_rec_predictor(images)
    surya_time = time.time() - start

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    page_metrics = collections.OrderedDict()
    mean_col_iou = 0
    mean_row_iou = 0
    for idx, (pred, image) in enumerate(zip(table_rec_predictions, images)):
        row = dataset[idx]
        pred_row_boxes = [p.bbox for p in pred.rows]
        pred_col_bboxes = [p.bbox for p in pred.cols]
        actual_row_bboxes = [r["bbox"] for r in row["rows"]]
        actual_col_bboxes = [c["bbox"] for c in row["columns"]]
        row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
        col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
        page_results = {
            "row_score": row_score,
            "col_score": col_score,
            "row_count": len(actual_row_bboxes),
            "col_count": len(actual_col_bboxes),
        }

        mean_col_iou += col_score
        mean_row_iou += row_score

        page_metrics[idx] = page_results

        if debug:
            # Save debug images
            draw_img = image.copy()
            draw_bboxes_on_image(
                pred_row_boxes,
                draw_img,
                [f"Row {i}" for i in range(len(pred_row_boxes))],
            )
            draw_bboxes_on_image(
                pred_col_bboxes,
                draw_img,
                [f"Col {i}" for i in range(len(pred_col_bboxes))],
                color="blue",
            )
            draw_img.save(os.path.join(result_path, f"{idx}_bbox.png"))

            actual_draw_image = image.copy()
            draw_bboxes_on_image(
                actual_row_bboxes,
                actual_draw_image,
                [f"Row {i}" for i in range(len(actual_row_bboxes))],
            )
            draw_bboxes_on_image(
                actual_col_bboxes,
                actual_draw_image,
                [f"Col {i}" for i in range(len(actual_col_bboxes))],
                color="blue",
            )
            actual_draw_image.save(os.path.join(result_path, f"{idx}_actual.png"))

    mean_col_iou /= len(table_rec_predictions)
    mean_row_iou /= len(table_rec_predictions)

    out_data = {
        "surya": {
            "time": surya_time,
            "mean_row_iou": mean_row_iou,
            "mean_col_iou": mean_col_iou,
            "page_metrics": page_metrics,
        }
    }

    if tatr:
        tatr_model = load_tatr()
        start = time.time()
        tatr_predictions = batch_inference_tatr(tatr_model, images, 1)
        tatr_time = time.time() - start

        page_metrics = collections.OrderedDict()
        mean_col_iou = 0
        mean_row_iou = 0
        for idx, pred in enumerate(tatr_predictions):
            row = dataset[idx]
            pred_row_boxes = [p["bbox"] for p in pred["rows"]]
            pred_col_bboxes = [p["bbox"] for p in pred["cols"]]
            actual_row_bboxes = [r["bbox"] for r in row["rows"]]
            actual_col_bboxes = [c["bbox"] for c in row["columns"]]
            row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
            col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
            page_results = {
                "row_score": row_score,
                "col_score": col_score,
                "row_count": len(actual_row_bboxes),
                "col_count": len(actual_col_bboxes),
            }

            mean_col_iou += col_score
            mean_row_iou += row_score

            page_metrics[idx] = page_results

        mean_col_iou /= len(tatr_predictions)
        mean_row_iou /= len(tatr_predictions)

        out_data["tatr"] = {
            "time": tatr_time,
            "mean_row_iou": mean_row_iou,
            "mean_col_iou": mean_col_iou,
            "page_metrics": page_metrics,
        }

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)

    table = [
        ["Model", "Row Intersection", "Col Intersection", "Time Per Image"],
        [
            "Surya",
            f"{out_data['surya']['mean_row_iou']:.2f}",
            f"{out_data['surya']['mean_col_iou']:.5f}",
            f"{surya_time / len(images):.5f}",
        ],
    ]

    if tatr:
        table.append(
            [
                "Table transformer",
                f"{out_data['tatr']['mean_row_iou']:.2f}",
                f"{out_data['tatr']['mean_col_iou']:.5f}",
                f"{tatr_time / len(images):.5f}",
            ]
        )

    print(tabulate(table, headers="firstrow", tablefmt="github"))

    print(
        "Intersection is the average of the intersection % between each actual row/column, and the predictions.  With penalties for too many/few predictions."
    )
    print(
        "Note that table transformers is unbatched, since the example code in the repo is unbatched."
    )
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()

```
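
Since `main` is a click command, it can also be invoked programmatically with an argument list. A minimal sketch, assuming the repo root is on the Python path and that the benchmark dataset and model checkpoints can be downloaded in the current environment:

```python
# Sketch of invoking the benchmark programmatically instead of via the CLI.
from benchmark.table_recognition import main

# click commands accept an argument list directly; standalone_mode=False keeps
# click from calling sys.exit() once the command finishes.
main(args=["--max_rows", "16", "--debug"], standalone_mode=False)
```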

--------------------------------------------------------------------------------
/surya/table_rec/model/decoder.py:
--------------------------------------------------------------------------------

```python
from typing import Optional, Tuple, Union

import torch
from torch import nn

from surya.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel
from surya.table_rec.model.config import TableRecModelOutput
from surya.table_rec.shaper import LabelShaper
from surya.settings import settings


class LabelEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Bboxes
        self.w_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.h_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.cx_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.cy_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.xskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.yskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size)

        self.x1_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.y1_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.x2_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.y2_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.x3_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.y3_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.x4_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
        self.y4_embed = nn.Embedding(config.vocab_size, config.box_embed_size)

        # Get indexes for passed in tensor
        shaper = LabelShaper()
        self.component_idxs = shaper.component_idx_dict()
        merge_count = shaper.get_box_property("merges")[1] + config.special_token_count
        category_count = shaper.get_box_property("category")[1] + config.special_token_count

        # Other box properties
        self.category_embed = nn.Embedding(category_count, config.property_embed_size)
        self.merge_embed = nn.Embedding(merge_count, config.property_embed_size)
        self.colspan_embed = nn.Embedding(config.vocab_size, config.property_embed_size)

        self.config = config

    def forward(self, boxes: torch.LongTensor, *args):
        # Need to keep *args for compatibility with common decoder
        boxes = boxes.to(torch.long).clamp(0, self.config.vocab_size)

        boxes_unbound = boxes.to(torch.long).unbind(dim=-1)
        cx, cy, w, h, xskew, yskew = boxes_unbound[self.component_idxs["bbox"][0]:self.component_idxs["bbox"][1]]
        category = boxes_unbound[self.component_idxs["category"][0]:self.component_idxs["category"][1]][0]
        merges = boxes_unbound[self.component_idxs["merges"][0]:self.component_idxs["merges"][1]][0]
        colspan = boxes_unbound[self.component_idxs["colspan"][0]:self.component_idxs["colspan"][1]][0]

        xskew_actual = ((xskew - self.config.bbox_size // 2) / 2).to(torch.long)
        yskew_actual = ((yskew - self.config.bbox_size // 2) / 2).to(torch.long)

        x1 = (cx - w // 2 - xskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
        y1 = (cy - h // 2 - yskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
        x3 = (cx + w // 2 + xskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
        y3 = (cy + h // 2 + yskew_actual).clamp(0, self.config.bbox_size).to(torch.long)

        size_embeds = self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)
        skew_embeds = self.xskew_embed(xskew) + self.yskew_embed(yskew)
        corner_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x3_embed(x3) + self.y3_embed(y3)
        box_embeds = size_embeds + skew_embeds + corner_embeds

        property_embeds = self.category_embed(category) + self.merge_embed(merges) + self.colspan_embed(colspan)

        # Cat bbox and property embeddings
        embedded = torch.cat([box_embeds, property_embeds], dim=-1)
        return embedded


class SuryaTableRecDecoder(SuryaADETRDecoderPreTrainedModel):
    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        super().__init__(config)
        embed_tokens = LabelEmbedding(config)
        self.model = SuryaADETRDecoderModel(
            config,
            embedder=embed_tokens,
            static_cache=settings.TABLE_REC_STATIC_CACHE,
            max_boxes=settings.TABLE_REC_MAX_BOXES
        )
        self.vocab_size = config.vocab_size

        shaper = LabelShaper()
        property_heads = {}
        for k in shaper.property_keys:
            _, kcount, mode = shaper.get_box_property(k)
            property_heads[k] = nn.Linear(config.hidden_size, kcount, bias=False)

        self.box_property_heads = nn.ModuleDict(property_heads)
        self.pre_output_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        prefill: bool = False,
        **kwargs
    ) -> Union[Tuple, TableRecModelOutput]:
        outputs = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
            prefill=prefill,
        )

        hidden_states = self.pre_output_norm(outputs[0])
        box_property_logits = {}
        for key in self.box_property_heads:
            box_property_logits[key] = self.box_property_heads[key](hidden_states)

        bbox_logits = nn.functional.sigmoid(box_property_logits["bbox"])
        box_property_logits["bbox"] = bbox_logits

        return TableRecModelOutput(
            box_property_logits=box_property_logits,
            hidden_states=hidden_states,
        )
```

--------------------------------------------------------------------------------
/surya/common/polygon.py:
--------------------------------------------------------------------------------

```python
import copy
from typing import List, Optional

import numpy as np
from pydantic import BaseModel, field_validator, computed_field
import numbers


class PolygonBox(BaseModel):
    polygon: List[List[float]]
    confidence: Optional[float] = None

    @field_validator("polygon", mode="before")
    @classmethod
    def convert_bbox_to_polygon(cls, value):
        if isinstance(value, (list, tuple)) and len(value) == 4:
            if all(isinstance(x, numbers.Number) for x in value):
                value = [float(v) for v in value]
                x_min, y_min, x_max, y_max = value
                polygon = [
                    [x_min, y_min],
                    [x_max, y_min],
                    [x_max, y_max],
                    [x_min, y_max],
                ]
                return polygon
            elif all(
                isinstance(point, (list, tuple)) and len(point) == 2 for point in value
            ):
                value = [[float(v) for v in point] for point in value]
                return value
        elif isinstance(value, np.ndarray):
            if value.shape == (4, 2):
                return value.tolist()

        raise ValueError(
            f"Input must be either a bbox [x_min, y_min, x_max, y_max] or a polygon with 4 corners [(x,y), (x,y), (x,y), (x,y)].  All values must be numeric. You passed {value} of type {type(value)}.  The first value is of type {type(value[0])}."
        )

    @property
    def height(self):
        return self.bbox[3] - self.bbox[1]

    @property
    def width(self):
        return self.bbox[2] - self.bbox[0]

    @property
    def area(self):
        return self.width * self.height

    @computed_field
    @property
    def bbox(self) -> List[float]:
        x_coords = [point[0] for point in self.polygon]
        y_coords = [point[1] for point in self.polygon]
        return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]

    def rescale(self, processor_size, image_size):
        # Point is in x, y format
        page_width, page_height = processor_size

        img_width, img_height = image_size
        width_scaler = img_width / page_width
        height_scaler = img_height / page_height

        for corner in self.polygon:
            corner[0] = int(corner[0] * width_scaler)
            corner[1] = int(corner[1] * height_scaler)

    def round(self, divisor):
        for corner in self.polygon:
            corner[0] = int(corner[0] / divisor) * divisor
            corner[1] = int(corner[1] / divisor) * divisor

    def fit_to_bounds(self, bounds):
        new_corners = copy.deepcopy(self.polygon)
        for corner in new_corners:
            corner[0] = max(min(corner[0], bounds[2]), bounds[0])
            corner[1] = max(min(corner[1], bounds[3]), bounds[1])
        self.polygon = new_corners

    def merge(self, other):
        x1 = min(self.bbox[0], other.bbox[0])
        y1 = min(self.bbox[1], other.bbox[1])
        x2 = max(self.bbox[2], other.bbox[2])
        y2 = max(self.bbox[3], other.bbox[3])
        self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

    def merge_left(self, other):
        x1 = min(self.bbox[0], other.bbox[0])
        self.polygon[0][0] = x1
        self.polygon[3][0] = x1

    def merge_right(self, other):
        x2 = max(self.bbox[2], other.bbox[2])
        self.polygon[1][0] = x2
        self.polygon[2][0] = x2

    def expand(self, x_margin: float, y_margin: float):
        new_polygon = []
        x_margin = x_margin * self.width
        y_margin = y_margin * self.height
        for idx, poly in enumerate(self.polygon):
            if idx == 0:
                new_polygon.append([int(poly[0] - x_margin), int(poly[1] - y_margin)])
            elif idx == 1:
                new_polygon.append([int(poly[0] + x_margin), int(poly[1] - y_margin)])
            elif idx == 2:
                new_polygon.append([int(poly[0] + x_margin), int(poly[1] + y_margin)])
            elif idx == 3:
                new_polygon.append([int(poly[0] - x_margin), int(poly[1] + y_margin)])
        self.polygon = new_polygon

    def intersection_polygon(self, other) -> List[List[float]]:
        new_poly = []
        for i in range(4):
            if i == 0:
                new_corner = [
                    max(self.polygon[0][0], other.polygon[0][0]),
                    max(self.polygon[0][1], other.polygon[0][1]),
                ]
            elif i == 1:
                new_corner = [
                    min(self.polygon[1][0], other.polygon[1][0]),
                    max(self.polygon[1][1], other.polygon[1][1]),
                ]
            elif i == 2:
                new_corner = [
                    min(self.polygon[2][0], other.polygon[2][0]),
                    min(self.polygon[2][1], other.polygon[2][1]),
                ]
            elif i == 3:
                new_corner = [
                    max(self.polygon[3][0], other.polygon[3][0]),
                    min(self.polygon[3][1], other.polygon[3][1]),
                ]
            new_poly.append(new_corner)

        return new_poly

    def intersection_area(self, other, x_margin=0, y_margin=0):
        x_overlap = self.x_overlap(other, x_margin)
        y_overlap = self.y_overlap(other, y_margin)
        return x_overlap * y_overlap

    def x_overlap(self, other, x_margin=0):
        return max(
            0,
            min(self.bbox[2] + x_margin, other.bbox[2] + x_margin)
            - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin),
        )

    def y_overlap(self, other, y_margin=0):
        return max(
            0,
            min(self.bbox[3] + y_margin, other.bbox[3] + y_margin)
            - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin),
        )

    def intersection_pct(self, other, x_margin=0, y_margin=0):
        assert 0 <= x_margin <= 1
        assert 0 <= y_margin <= 1
        if self.area == 0:
            return 0

        if x_margin:
            x_margin = int(min(self.width, other.width) * x_margin)
        if y_margin:
            y_margin = int(min(self.height, other.height) * y_margin)

        intersection = self.intersection_area(other, x_margin, y_margin)
        return intersection / self.area

    def shift(self, x_shift: float | None = None, y_shift: float | None = None):
        if x_shift is not None:
            for corner in self.polygon:
                corner[0] += x_shift
        if y_shift is not None:
            for corner in self.polygon:
                corner[1] += y_shift

    def clamp(self, bbox: List[float]):
        for corner in self.polygon:
            corner[0] = max(min(corner[0], bbox[2]), bbox[0])
            corner[1] = max(min(corner[1], bbox[3]), bbox[1])

    @property
    def center(self):
        return [(self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2]

    def distance(self, other):
        center = self.center
        other_center = other.center

        return (
            (center[0] - other_center[0]) ** 2 + (center[1] - other_center[1]) ** 2
        ) ** 0.5

    def __hash__(self):
        return hash(tuple(self.bbox))

```
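
A minimal sketch of `PolygonBox` behavior, using only what is defined above; the coordinates are arbitrary examples.

```python
# Minimal sketch of PolygonBox behavior; coordinates are arbitrary examples.
from surya.common.polygon import PolygonBox

# A bbox [x_min, y_min, x_max, y_max] is auto-converted to a 4-corner polygon
box = PolygonBox(polygon=[10, 10, 110, 60])
other = PolygonBox(polygon=[[60, 10], [160, 10], [160, 60], [60, 60]])

print(box.bbox)                         # [10.0, 10.0, 110.0, 60.0]
print(box.width, box.height, box.area)  # 100.0 50.0 5000.0
print(box.intersection_pct(other))      # 0.5 - half of box overlaps other

box.merge(other)
print(box.bbox)                         # [10.0, 10.0, 160.0, 60.0]
```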

--------------------------------------------------------------------------------
/surya/settings.py:
--------------------------------------------------------------------------------

```python
import os
from typing import Callable, Dict, Optional

import torch
from dotenv import find_dotenv
from pydantic import computed_field
from pydantic_settings import BaseSettings
from pathlib import Path
from platformdirs import user_cache_dir


class Settings(BaseSettings):
    # General
    TORCH_DEVICE: Optional[str] = None
    IMAGE_DPI: int = 96  # Used for detection, layout, reading order
    IMAGE_DPI_HIGHRES: int = 192  # Used for OCR, table rec
    IN_STREAMLIT: bool = False  # Whether we're running in streamlit
    FLATTEN_PDF: bool = True  # Flatten PDFs by merging form fields before processing
    DISABLE_TQDM: bool = False  # Disable tqdm progress bars
    S3_BASE_URL: str = "https://models.datalab.to"
    PARALLEL_DOWNLOAD_WORKERS: int = (
        10  # Number of workers for parallel model downloads
    )
    MODEL_CACHE_DIR: str = str(Path(user_cache_dir("datalab")) / "models")
    LOGLEVEL: str = "INFO"  # Logging level

    # Paths
    DATA_DIR: str = "data"
    RESULT_DIR: str = "results"
    BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts")

    @computed_field
    def TORCH_DEVICE_MODEL(self) -> str:
        if self.TORCH_DEVICE is not None:
            return self.TORCH_DEVICE

        if torch.cuda.is_available():
            return "cuda"

        if torch.backends.mps.is_available():
            return "mps"

        try:
            import torch_xla

            if len(torch_xla.devices()) > 0:
                return "xla"
        except Exception:
            pass

        return "cpu"

    # Text detection
    DETECTOR_BATCH_SIZE: Optional[int] = None  # Defaults to 2 for CPU/MPS, 32 otherwise
    DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_05_07"
    DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench"
    DETECTOR_IMAGE_CHUNK_HEIGHT: int = (
        1400  # Height at which to slice images vertically
    )
    DETECTOR_TEXT_THRESHOLD: float = (
        0.6  # Threshold for text detection (above this is considered text)
    )
    DETECTOR_BLANK_THRESHOLD: float = (
        0.35  # Threshold for blank space (below this is considered blank)
    )
    DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(
        8, os.cpu_count()
    )  # Number of workers for postprocessing
    DETECTOR_MIN_PARALLEL_THRESH: int = (
        3  # Minimum number of images before we parallelize
    )
    DETECTOR_BOX_Y_EXPAND_MARGIN: float = (
        0.05  # Margin by which to expand detected boxes vertically
    )
    COMPILE_DETECTOR: bool = False

    # Text recognition
    FOUNDATION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23"
    FOUNDATION_MODEL_QUANTIZE: bool = False
    FOUNDATION_MAX_TOKENS: Optional[int] = None
    FOUNDATION_CHUNK_SIZE: Optional[int] = None
    FOUNDATION_PAD_TO_NEAREST: int = 256
    COMPILE_FOUNDATION: bool = False
    FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE: float = 0.9

    RECOGNITION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23"
    RECOGNITION_BATCH_SIZE: Optional[int] = (
        None  # Defaults to 8 for CPU/MPS, 256 otherwise
    )
    RECOGNITION_RENDER_FONTS: Dict[str, str] = {
        "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"),
        "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
        "ja": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
        "ko": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
    }
    RECOGNITION_FONT_DL_BASE: str = (
        "https://github.com/satbyy/go-noto-universal/releases/download/v7.0"
    )
    RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench"
    RECOGNITION_PAD_VALUE: int = 255  # Should be 0 or 255

    # Layout
    LAYOUT_MODEL_CHECKPOINT: str = "s3://layout/2025_09_23"
    LAYOUT_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
    LAYOUT_SLICE_MIN: Dict = {
        "height": 1500,
        "width": 1500,
    }  # When to start slicing images
    LAYOUT_SLICE_SIZE: Dict = {"height": 1200, "width": 1200}  # Size of slices
    LAYOUT_BATCH_SIZE: Optional[int] = None
    LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench"
    LAYOUT_MAX_BOXES: int = 100
    COMPILE_LAYOUT: bool = False
    ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"

    # Table Rec
    TABLE_REC_MODEL_CHECKPOINT: str = "s3://table_recognition/2025_02_18"
    TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
    TABLE_REC_MAX_BOXES: int = 150
    TABLE_REC_BATCH_SIZE: Optional[int] = None
    TABLE_REC_BENCH_DATASET_NAME: str = "datalab-to/fintabnet_bench"
    COMPILE_TABLE_REC: bool = False

    # Texify
    TEXIFY_BENCHMARK_DATASET: str = "datalab-to/texify_bench"

    # OCR Error Detection
    OCR_ERROR_MODEL_CHECKPOINT: str = "s3://ocr_error_detection/2025_02_18"
    OCR_ERROR_BATCH_SIZE: Optional[int] = None
    COMPILE_OCR_ERROR: bool = False

    # Tesseract (for benchmarks only)
    TESSDATA_PREFIX: Optional[str] = None

    COMPILE_ALL: bool = False

    @computed_field
    def DETECTOR_STATIC_CACHE(self) -> bool:
        return (
            self.COMPILE_ALL
            or self.COMPILE_DETECTOR
            or self.TORCH_DEVICE_MODEL == "xla"
        )  # We need to static cache and pad to batch size for XLA, since it will recompile otherwise

    @computed_field
    def LAYOUT_STATIC_CACHE(self) -> bool:
        return (
            self.COMPILE_ALL or self.COMPILE_LAYOUT or self.TORCH_DEVICE_MODEL == "xla"
        )

    @computed_field
    def FOUNDATION_XLA(self) -> bool:
        return (
            self.TORCH_DEVICE_MODEL == "xla"
        )  # We need to static cache and pad to batch size for XLA, since it will recompile otherwise

    @computed_field
    def FOUNDATION_STATIC_CACHE(self) -> bool:
        return (
            self.COMPILE_ALL
            or self.COMPILE_FOUNDATION
            or self.TORCH_DEVICE_MODEL == "xla"
        )  # We need to static cache and pad to batch size for XLA, since it will recompile otherwise

    @computed_field
    def TABLE_REC_STATIC_CACHE(self) -> bool:
        return (
            self.COMPILE_ALL
            or self.COMPILE_TABLE_REC
            or self.TORCH_DEVICE_MODEL == "xla"
        )

    @computed_field
    def OCR_ERROR_STATIC_CACHE(self) -> bool:
        return (
            self.COMPILE_ALL
            or self.COMPILE_OCR_ERROR
            or self.TORCH_DEVICE_MODEL == "xla"
        )

    @computed_field
    def MODEL_DTYPE(self) -> torch.dtype:
        if self.TORCH_DEVICE_MODEL == "cpu":
            return torch.float32
        if self.TORCH_DEVICE_MODEL == "xla":
            return torch.bfloat16
        return torch.float16

    @computed_field
    def MODEL_DTYPE_BFLOAT(self) -> torch.dtype:
        if self.TORCH_DEVICE_MODEL == "cpu":
            return torch.float32
        return torch.bfloat16

    @computed_field
    def INFERENCE_MODE(self) -> Callable:
        if self.TORCH_DEVICE_MODEL == "xla":
            return torch.no_grad
        return torch.inference_mode

    class Config:
        env_file = find_dotenv("local.env")
        extra = "ignore"


settings = Settings()

```
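
Because `Settings` is a pydantic `BaseSettings`, any field above can be overridden through environment variables (or a `local.env` file). A small sketch, assuming the variables are set before the settings object is instantiated:

```python
# Sketch of overriding Settings via environment variables; pydantic BaseSettings
# reads matching variable names (and local.env, per the Config class above).
import os

os.environ["TORCH_DEVICE"] = "cpu"
os.environ["COMPILE_DETECTOR"] = "true"

from surya.settings import Settings

settings = Settings()  # instantiate after setting the variables to pick up the overrides
print(settings.TORCH_DEVICE_MODEL)     # "cpu"
print(settings.DETECTOR_STATIC_CACHE)  # True, because COMPILE_DETECTOR is enabled
```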

--------------------------------------------------------------------------------
/surya/foundation/cache/static_ops.py:
--------------------------------------------------------------------------------

```python
from typing import Any, Dict, List, Optional, Tuple
import torch
from transformers import PretrainedConfig

from surya.foundation.cache.dynamic_ops import DynamicOpsCache

"""
Special cache class for the surya foundation model that supports - 
1) Static shape
2) A custom sliding window, where image tokens stay in cache, and text tokens are popped
3) Continuous batching - merging etc
4) Attention mask management - To match with what's currently in the cache

Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079
"""


class StaticOpsCache(DynamicOpsCache):
    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int,
        max_cache_len: int,
        text_sliding_window: int,
        device: str | torch.device,
        dtype: torch.dtype,
    ):
        self.text_sliding_window = text_sliding_window
        self.num_layers = config.num_hidden_layers
        self.max_batch_size = batch_size
        self.max_cache_len = max_cache_len
        self.head_dim = (
            getattr(config, "head_dim", None)
            or config.hidden_size // config.num_attention_heads
        )
        self._dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None
            else config.num_key_value_heads
        )

        # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        cache_shape = (
            self.max_batch_size,
            self.num_key_value_heads,
            self.max_cache_len,
            self.head_dim,
        )
        device = torch.device(device) if device is not None else None
        for _ in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(
                cache_shape, dtype=self._dtype, device=device
            )
            new_layer_value_cache = torch.zeros(
                cache_shape, dtype=self._dtype, device=device
            )
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

        self.attention_mask = torch.zeros(
            (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long
        )
        self.text_token_counts = [
            torch.zeros(self.max_batch_size, dtype=torch.long, device=device)
            for _ in range(self.num_layers)
        ]

        self.dtype = dtype
        self.device = device

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        prefill = cache_kwargs.get("prefill", False)
        update_fn = self._prefill_update if prefill else self._decode_update
        return update_fn(
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
            key_states,
            value_states,
            self.text_token_counts[layer_idx],
            cache_kwargs,
        )

    def _prefill_update(
        self,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        text_token_counts: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ):
        cache_idxs: torch.Tensor = cache_kwargs.get("cache_idxs", None)
        text_lengths: List[int] = cache_kwargs.get("text_lengths", None)
        assert cache_idxs is not None, "cache_idxs must be specified during prefill"
        assert text_lengths is not None, "text_lengths must be specified during prefill"

        cache_idx_length = len(cache_idxs)
        full_batch = len(cache_idxs) == self.max_batch_size

        # Insert key and value states at the end of the cache
        new_tokens = key_states.shape[2]

        # Direct right-aligned assignment
        if full_batch:
            key_cache[:, :, -new_tokens:] = key_states
            value_cache[:, :, -new_tokens:] = value_states
        else:
            key_cache[cache_idxs, :, -new_tokens:] = key_states[:cache_idx_length]
            value_cache[cache_idxs, :, -new_tokens:] = value_states[:cache_idx_length]

        return key_states, value_states

    # """
    # Matches the logic of the decode update, but needs to be called before the updates
    # since some parts of the model depend on the attention mask
    # """
    def decode_attention_mask_update(
        self, num_valid_tokens: torch.Tensor, cache_idxs: List[int]
    ):
        max_valid_tokens = num_valid_tokens.max().item()
        if max_valid_tokens == 0:
            # If no valid tokens, we don't need to update the attention mask
            return

        # Shift the attention mask to the left by max_valid_tokens
        self.attention_mask = self.attention_mask.roll(-1 * max_valid_tokens, dims=1)
        self.attention_mask[:, -max_valid_tokens:] = (
            1  # Full attention to all new tokens
        )

    # Mirrors the logic from _prefill_update
    def prefill_attention_mask_update(
        self,
        attention_mask: torch.Tensor,
        merge_idxs: torch.Tensor,
        valid_batch_size: torch.Tensor,
        text_lengths: List[int],
    ):
        # Set from -(image_length + text_length) to end to 1 for each batch element
        seq_len = attention_mask.shape[1]
        self.attention_mask[merge_idxs] = (
            0  # Reset the attention mask for the current batch elements
        )
        self.attention_mask[merge_idxs, -seq_len:] = attention_mask[:valid_batch_size]

    def _decode_update(
        self,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        text_token_counts: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Naive, always assumes we'll roll by a fixed amount
        # Needs left padding with beacons to work properly

        num_valid_tokens: torch.Tensor = cache_kwargs.get(
            "num_valid_tokens"
        )  # shape: (B,)
        assert num_valid_tokens is not None, (
            "`num_valid_tokens` must be provided in `cache_kwargs`"
        )
        # (B, H, L, D)

        valid_tokens = key_states.shape[2]

        key_cache.copy_(torch.roll(key_cache, -valid_tokens, dims=2))
        value_cache.copy_(torch.roll(value_cache, -valid_tokens, dims=2))

        key_cache[:, :, -valid_tokens:, :] = key_states
        value_cache[:, :, -valid_tokens:, :] = value_states

        # In-place edit - Mutates
        text_token_counts += num_valid_tokens
        text_token_counts.clamp_(max=self.text_sliding_window)
        return key_cache, value_cache

    # The attention mask managed by our kv cache automatically masks the tokens
    # in the cache, so we can return full length for HF to use in other places
    # This is mainly utilized in the cache_positions creation
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self.max_cache_len

```
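
A tiny sketch of the rolling attention-mask update, using an arbitrary small `PretrainedConfig` rather than the real foundation model configuration:

```python
# Tiny sketch of the rolling attention-mask behavior; the config values here are
# arbitrary small numbers, not the real foundation model configuration.
import torch
from transformers import PretrainedConfig
from surya.foundation.cache.static_ops import StaticOpsCache

config = PretrainedConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4)
cache = StaticOpsCache(
    config, batch_size=2, max_cache_len=16, text_sliding_window=8,
    device="cpu", dtype=torch.float32,
)

print(cache.get_seq_length())  # 16 - always reports the full static cache length
cache.decode_attention_mask_update(num_valid_tokens=torch.tensor([1, 1]), cache_idxs=[0, 1])
print(cache.attention_mask[:, -1])  # tensor([1, 1]) - the newest slot is now attended to
```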

--------------------------------------------------------------------------------
/surya/table_rec/model/config.py:
--------------------------------------------------------------------------------

```python
from dataclasses import dataclass
from typing import Dict

import torch
from transformers import PretrainedConfig
from transformers.utils import ModelOutput

from surya.common.s3 import S3DownloaderMixin
from surya.settings import settings

BOX_DIM = 1024
SPECIAL_TOKENS = 5
MAX_BOXES = 150

MERGE_KEYS = {
    "none": 0,
    "merge_up": 1,
    "merge_down": 2,
    "merge_both": 3
}
MERGE_VALUES = [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]]

ID_TO_CATEGORY = {
    0: 'Blank',
    1: 'Table-row',
    2: 'Table-column',
    3: 'Table-cell',
    4: 'Table'
}
CATEGORY_TO_ID = {v: k for k, v in ID_TO_CATEGORY.items()}

ID_TO_HEADER = {
    0: "None",
    1: "Header"
}
HEADER_TO_ID = {v: k for k, v in ID_TO_HEADER.items()}

BOX_PROPERTIES = [
    ("bbox", 6, "regression"),
    ("category", len(ID_TO_CATEGORY), "classification"),
    ("merges", len(MERGE_KEYS), "classification"),
    ("colspan", 1, "regression"),
    ("is_header", len(ID_TO_HEADER), "classification")
]


@dataclass
class TableRecModelOutput(ModelOutput):
    box_property_logits: Dict[str, torch.Tensor]
    hidden_states: torch.Tensor | None = None


class SuryaTableRecConfig(S3DownloaderMixin, PretrainedConfig):
    model_type = "vision-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if "encoder" in kwargs:
            encoder_config = kwargs.pop("encoder")
            decoder_config = kwargs.pop("decoder")
        else:
            encoder_config = DonutSwinTableRecConfig()
            decoder_config = SuryaTableRecDecoderConfig()

        self.encoder = encoder_config
        self.decoder = decoder_config
        self.is_encoder_decoder = True

        if isinstance(decoder_config, dict):
            self.decoder_start_token_id = decoder_config["bos_token_id"]
            self.pad_token_id = decoder_config["pad_token_id"]
            self.eos_token_id = decoder_config["eos_token_id"]
        else:
            self.decoder_start_token_id = decoder_config.bos_token_id
            self.pad_token_id = decoder_config.pad_token_id
            self.eos_token_id = decoder_config.eos_token_id


class DonutSwinTableRecConfig(PretrainedConfig):
    model_type = "donut-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=[2, 2, 12, 2],
        num_heads=[4, 8, 16, 32],
        num_kv_heads=[4, 8, 16, 32],
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=False,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=1024,
        use_positional_embeddings=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.encoder_length = encoder_length
        self.use_positional_embeddings = use_positional_embeddings


class SuryaTableRecDecoderConfig(PretrainedConfig):
    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=6,
        vocab_size=BOX_DIM + 1,
        bbox_size=BOX_DIM,
        hidden_size=512,
        property_embed_size=64,
        box_embed_size=512 - 64,
        intermediate_size=4 * 512,
        encoder_hidden_size=1024,
        num_attention_heads=8,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=1,
        pause_token_id=2,
        query_end_token_id=4,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=tuple(range(10)),
        encoder_cross_attn_layers=tuple(range(10)),
        self_attn_layers=tuple(range(10)),
        global_attn_layers=tuple(range(10)),
        attention_dropout=0.0,
        num_key_value_heads=4,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0, # How many n-token-ahead heads to add
        causal=True,
        layer_norm_eps=1e-5,
        dropout=0.0,
        special_token_count=SPECIAL_TOKENS,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width if lru_width is not None else hidden_size
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal
        self.encoder_cross_attn_layers = encoder_cross_attn_layers
        self.layer_norm_eps = layer_norm_eps
        self.dropout = dropout
        self.bbox_size = bbox_size
        self.pause_token_id = pause_token_id
        self.box_properties = BOX_PROPERTIES
        self.property_embed_size = property_embed_size
        self.box_embed_size = box_embed_size
        self.special_token_count = special_token_count
        self.query_end_token_id = query_end_token_id
        self.double_residual_flow = False

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        return (self.block_types * 100)[: self.num_hidden_layers]
```
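A minimal sketch (not part of the repository) of how `SuryaTableRecDecoderConfig.layers_block_type` tiles `block_types` across the decoder layers, assuming the default values shown above:

```python
# Hypothetical standalone illustration of the layers_block_type property.
block_types = ("attention",)
num_hidden_layers = 6

layers_block_type = (list(block_types) * 100)[:num_hidden_layers]
print(layers_block_type)
# ['attention', 'attention', 'attention', 'attention', 'attention', 'attention']
```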

--------------------------------------------------------------------------------
/surya/common/surya/flash_attn_utils.py:
--------------------------------------------------------------------------------

```python
from typing import Optional
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
from flash_attn.bert_padding import index_first_axis as _index_first_axis
from flash_attn.bert_padding import pad_input

def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
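
# Illustrative example (comment added for exposition, not in the original file):
# for attention_mask = [[1, 1, 0, 0],
#                       [1, 1, 1, 0]]
# the helper returns indices = [0, 1, 4, 5, 6], cu_seqlens = [0, 2, 5],
# and max_seqlen_in_batch = 3.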

def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    query_length: int,
    indices_k,
    cu_seqlens_k,
    max_seqlen_in_batch_k
):
    """
    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.

    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
    tensors for query, key, value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        query_length (`int`):
            Target length.
        indices_k (`torch.Tensor`):
            The indices of non-masked tokens from the flattened key/value sequence, as produced by `_get_unpad_data`.
        cu_seqlens_k (`torch.Tensor`):
            The cumulative key/value sequence lengths. Shape: (batch_size + 1,).
        max_seqlen_in_batch_k (`int`):
            Maximum key/value sequence length in the batch.

    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = _index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
    value_layer = _index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
    )
    if query_length == kv_seq_len:
        query_layer = _index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        raise NotImplementedError()

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )

def flash_attn_prefill(
    module: torch.nn.Module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    dropout: float,
    scaling: float,
    query_length: int,
    batch_size: int,
    indices_k: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_k: int,
    **kwargs
):
    """
    Wrapper for flash attention during the prefill stage
    query_states must have shape (batch_size, num_heads, seq_len, head_dim)
    key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)

    This is the opposite of what is required by flash attention, but keeps parity with the HF convention

    query_length, batch_size, indices_k, cu_seqlens_k, and max_seqlen_in_batch_k should come from the flash attention kwargs
    """
    query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2)
    q_flash, k_flash, v_flash, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
        query_states, key_states, value_states, query_length, indices_k, cu_seqlens_k, max_seqlen_in_batch_k
    )
    cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

    # Returning None for attn_weights to match other attention interfaces
    flash_attn_out = _flash_attn_varlen_func(
        q_flash,
        k_flash,
        v_flash,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout,
        softmax_scale=scaling,
        causal=module.is_causal,
    )
    return pad_input(flash_attn_out, indices_q, batch_size, query_length), None

# NOTE: Does not support dropout; the argument is accepted via kwargs to keep the interface compatible
# This function is an order of magnitude faster than the prefill variant or the HF interface
def flash_attn_decode(
    module: torch.nn.Module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    scaling: float,
    **kwargs,
):
    """
    Wrapper for flash attention during the decode stage
    
    query_states must have shape (batch_size, num_heads, seq_len, head_dim), where seq_len is 1 during the decode stage
    key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)

    This is the opposite of what is required by flash attention, but keeps parity with the HF convention

    This function computes the left pad and cache seqlens to pass into FA2. For example - 
    Given an attention_mask shaped (batch_size=2, seq_len=8), where 0 = padding, 1 = real token
    attention_mask =
    tensor([
        [0, 0, 1, 1, 1, 0, 0, 0],  # ← batch 0
        [0, 1, 1, 1, 1, 1, 1, 0],  # ← batch 1
    ])
    cache_leftpad = tensor([2, 1], dtype=torch.int32)
    cache_seqlens = tensor([5, 7], dtype=torch.int32)
    These values allow FlashAttention to use a static cache layout with efficient slicing during decoding.
    """
    query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2)

    cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1).to(torch.int32)
    cache_seqlens = (attention_mask * torch.arange(attention_mask.size(1), device=attention_mask.device)).argmax(dim=1).to(torch.int32) + 1

    # Returning None for attn_weights to match other attention interfaces
    return _flash_attn_with_kvcache(
        q=query_states,
        k_cache=key_states,
        v_cache=value_states,
        cache_leftpad=cache_leftpad,
        cache_seqlens=cache_seqlens,
        causal=module.is_causal,
        softmax_scale=scaling,
    ), None
```
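A small self-contained sketch (not from the repository) reproducing the `cache_leftpad` / `cache_seqlens` example from the `flash_attn_decode` docstring; the `.long()` cast is added only so the sketch runs on CPU:

```python
import torch

attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 0, 0, 0],  # batch 0
    [0, 1, 1, 1, 1, 1, 1, 0],  # batch 1
])

# Number of leading padding tokens per batch row
cache_leftpad = (attention_mask == 0).long().cumprod(dim=1).sum(dim=1).to(torch.int32)
# Index of the last real token per batch row, plus one
cache_seqlens = (
    attention_mask * torch.arange(attention_mask.size(1))
).argmax(dim=1).to(torch.int32) + 1

print(cache_leftpad)  # tensor([2, 1], dtype=torch.int32)
print(cache_seqlens)  # tensor([5, 7], dtype=torch.int32)
```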

--------------------------------------------------------------------------------
/surya/common/util.py:
--------------------------------------------------------------------------------

```python
import copy
from typing import List
import torch
from functools import lru_cache

import torch.nn.functional as F

from surya.common.polygon import PolygonBox


def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
    new_boxes = []
    for box_obj in boxes:
        xs = [point[0] for point in box_obj.polygon]
        ys = [point[1] for point in box_obj.polygon]
        if max(xs) == min(xs) or max(ys) == min(ys):
            continue

        box = box_obj.bbox
        contained = False
        for other_box_obj in boxes:
            if other_box_obj.polygon == box_obj.polygon:
                continue

            other_box = other_box_obj.bbox
            if box == other_box:
                continue
            if (
                box[0] >= other_box[0]
                and box[1] >= other_box[1]
                and box[2] <= other_box[2]
                and box[3] <= other_box[3]
            ):
                contained = True
                break
        if not contained:
            new_boxes.append(box_obj)
    return new_boxes


def rescale_bbox(bbox, processor_size, image_size):
    page_width, page_height = processor_size

    img_width, img_height = image_size
    width_scaler = img_width / page_width
    height_scaler = img_height / page_height

    new_bbox = copy.deepcopy(bbox)
    new_bbox[0] = int(new_bbox[0] * width_scaler)
    new_bbox[1] = int(new_bbox[1] * height_scaler)
    new_bbox[2] = int(new_bbox[2] * width_scaler)
    new_bbox[3] = int(new_bbox[3] * height_scaler)
    return new_bbox


def expand_bbox(bbox, expansion_factor=0.01):
    expansion_low = 1 - expansion_factor
    expansion_high = 1 + expansion_factor
    return [
        bbox[0] * expansion_low,
        bbox[1] * expansion_low,
        bbox[2] * expansion_high,
        bbox[3] * expansion_high,
    ]
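
# Illustrative example (comment added for exposition, not in the original file):
#   rescale_bbox([10, 10, 20, 20], (100, 100), (200, 200)) -> [20, 20, 40, 40]
#   expand_bbox([20, 20, 40, 40], expansion_factor=0.01)    ≈ [19.8, 19.8, 40.4, 40.4]
# expand_bbox scales the raw coordinates, so the expansion grows with the
# distance from the origin rather than with the box size.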

SCRIPT_TOKEN_MAPPING = {
    "latin": "<SCRIPT-LATIN>",
    "punctuation": "<SCRIPT-PUNCTUATION>",
    "cyrillic": "<SCRIPT-CYRILLIC>",
    "arabic": "<SCRIPT-ARABIC>",
    "chinese": "<SCRIPT-CHINESE>",
    "japanese": "<SCRIPT-JAPANESE>",
    "korean": "<SCRIPT-KOREAN>",
    "symbols": "<SCRIPT-SYMBOLS>",
    "greek": "<SCRIPT-GREEK>",
    "armenian": "<SCRIPT-ARMENIAN>",
    "hebrew": "<SCRIPT-HEBREW>",
    "devanagari": "<SCRIPT-DEVANAGARI>",
    "bengali": "<SCRIPT-BENGALI>",
    "gurmukhi": "<SCRIPT-GURMUKHI>",
    "gujarati": "<SCRIPT-GUJARATI>",
    "oriya": "<SCRIPT-ORIYA>",
    "tamil": "<SCRIPT-TAMIL>",
    "telugu": "<SCRIPT-TELUGU>",
    "kannada": "<SCRIPT-KANNADA>",
    "malayalam": "<SCRIPT-MALAYALAM>",
    "sinhala": "<SCRIPT-SINHALA>",
    "thai": "<SCRIPT-THAI>",
    "lao": "<SCRIPT-LAO>",
    "myanmar": "<SCRIPT-MYANMAR>",
    "georgian": "<SCRIPT-GEORGIAN>",
    "ethiopic": "<SCRIPT-ETHIOPIC>",
    "khmer": "<SCRIPT-KHMER>",
    "mongolian": "<SCRIPT-MONGOLIAN>",
    "math": "<SCRIPT-MATH>",
}

@lru_cache(maxsize=1)
def script_ranges():
    script_categories = {
        # Latin-based scripts (used by English, French, German, etc.)
        "latin": [
            (0x0041, 0x005A),  # Latin uppercase A-Z
            (0x0061, 0x007A),  # Latin lowercase a-z
            (0x0080, 0x00FF),  # Latin-1 Supplement
            (0x0100, 0x017F),  # Latin Extended-A
            (0x0180, 0x024F),  # Latin Extended-B
            (0x0250, 0x02AF),  # IPA Extensions
            (0x02B0, 0x02FF),  # Spacing Modifier Letters
            (0x0300, 0x036F),  # Combining Diacritical Marks
            (0x1E00, 0x1EFF),  # Latin Extended Additional
            (0x2C60, 0x2C7F),  # Latin Extended-C
            (0xA720, 0xA7FF),  # Latin Extended-D
        ],
        # Punctuation, universal characters, and general symbols
        "punctuation": [
            (0x0020, 0x0020),  # Space
            (0x0021, 0x002F),  # Basic punctuation and symbols
            (0x0030, 0x0039),  # Digits 0-9
            (0x003A, 0x0040),  # More punctuation and symbols
            (0x005B, 0x0060),  # More punctuation and symbols
            (0x007B, 0x007F),  # More punctuation and symbols
            (0x2000, 0x206F),  # General Punctuation
        ],
        # Cyrillic scripts (used by Russian, Ukrainian, etc.)
        "cyrillic": [
            (0x0400, 0x04FF),  # Cyrillic
            (0x0500, 0x052F),  # Cyrillic Supplement
        ],
        # Arabic scripts
        "arabic": [
            (0x0600, 0x06FF),  # Arabic
            (0x0750, 0x077F),  # Arabic Supplement
            (0x08A0, 0x08FF),  # Arabic Extended-A
        ],
        # Chinese characters
        "chinese": [
            (0x4E00, 0x9FFF),  # Common CJK Unified Ideographs
            (0x3400, 0x4DBF),  # CJK Extension A
            (0x20000, 0x2A6DF),  # CJK Extension B
        ],
        # Japanese-specific scripts (excluding shared CJK)
        "japanese": [
            (0x3040, 0x30FF),  # Hiragana and Katakana
        ],
        # Korean-specific scripts
        "korean": [
            (0x1100, 0x11FF),  # Hangul Jamo
            (0x3130, 0x318F),  # Hangul Compatibility Jamo
            (0xAC00, 0xD7AF),  # Hangul Syllables
        ],
        # Various mathematical and technical symbols
        "symbols": [
            (0x2070, 0x209F),  # Superscripts and Subscripts
            (0x20A0, 0x20CF),  # Currency Symbols
            (0x2100, 0x214F),  # Letterlike Symbols
            (0x2150, 0x218F),  # Number Forms
            (0x2190, 0x21FF),  # Arrows
            (0x2200, 0x22FF),  # Mathematical Operators
            (0x2300, 0x23FF),  # Miscellaneous Technical
            (0x2500, 0x257F),  # Box Drawing
            (0x2580, 0x259F),  # Block Elements
            (0x25A0, 0x25FF),  # Geometric Shapes
            (0x2600, 0x26FF),  # Miscellaneous Symbols
            (0x2700, 0x27BF),  # Dingbats
            (0x27C0, 0x27EF),  # Miscellaneous Mathematical Symbols-A
            (0x2980, 0x29FF),  # Miscellaneous Mathematical Symbols-B
            (0x2A00, 0x2AFF),  # Supplemental Mathematical Operators
            (0x1D400, 0x1D7FF),  # Mathematical Alphanumeric Symbols
        ],
        # Individual scripts for languages with unique writing systems
        "greek": [(0x0370, 0x03FF)],  # Greek and Coptic
        "armenian": [(0x0530, 0x058F)],  # Armenian
        "hebrew": [(0x0590, 0x05FF)],  # Hebrew
        "devanagari": [(0x0900, 0x097F)],  # Devanagari (Hindi, Sanskrit)
        "bengali": [(0x0980, 0x09FF)],  # Bengali
        "gurmukhi": [(0x0A00, 0x0A7F)],  # Gurmukhi (Punjabi)
        "gujarati": [(0x0A80, 0x0AFF)],  # Gujarati
        "oriya": [(0x0B00, 0x0B7F)],  # Oriya
        "tamil": [(0x0B80, 0x0BFF)],  # Tamil
        "telugu": [(0x0C00, 0x0C7F)],  # Telugu
        "kannada": [(0x0C80, 0x0CFF)],  # Kannada
        "malayalam": [(0x0D00, 0x0D7F)],  # Malayalam
        "sinhala": [(0x0D80, 0x0DFF)],  # Sinhala
        "thai": [(0x0E00, 0x0E7F)],  # Thai
        "lao": [(0x0E80, 0x0EFF)],  # Lao
        "myanmar": [(0x1000, 0x109F)],  # Myanmar
        "georgian": [(0x10A0, 0x10FF)],  # Georgian
        "ethiopic": [(0x1200, 0x137F)],  # Ethiopic
        "khmer": [(0x1780, 0x17FF)],  # Khmer
        "mongolian": [(0x1800, 0x18AF)],  # Mongolian
    }

    # Convert to a flat structure with character ranges
    flat_ranges = {}
    for category, ranges in script_categories.items():
        # Create a set of all characters in this category
        char_set = set()
        for start, end in ranges:
            char_set.update(range(start, end + 1))

        # Store the set in flat_ranges
        flat_ranges[category] = char_set

    return script_categories, flat_ranges

def get_top_scripts(text: str, max_scripts: int = 5):
    script_categories, flat_ranges = script_ranges()
    char_count = {category: 0 for category in script_categories.keys()}
    for char in text:
        for category, char_set in flat_ranges.items():
            if ord(char) in char_set:
                char_count[category] += 1
                break

    top_scripts = sorted(char_count.items(), key=lambda x: x[1], reverse=True)
    top_scripts = [ts[0] for ts in top_scripts if ts[1] > 0]
    if "<math" in text:
        top_scripts.insert(0, "math")

    return top_scripts[:max_scripts]

def is_flash_attn_2_supported(device: str | torch.device) -> bool:
    if not torch.cuda.is_available():
        return False

    if "cuda" not in str(device):
        return False

    # Check CUDA version >= 12.0
    cuda_version_str = torch.version.cuda
    if cuda_version_str is None:
        return False
    cuda_version = tuple(map(int, cuda_version_str.split(".")))
    if cuda_version < (12, 0):
        return False

    # Check GPU compute capability (Ampere, Ada, Hopper GPUs)
    major, minor = torch.cuda.get_device_capability()
    compute_capability = major + minor / 10
    if compute_capability < 8.0:
        return False

    return True


def pad_to_batch_size_repeat(tensor: torch.Tensor, batch_size: int):
    current_batch_size = tensor.shape[0]
    if current_batch_size >= batch_size:
        return tensor

    pad_size = batch_size - current_batch_size
    if pad_size < 0:
        return tensor

    # Repeat the last row pad_size times
    last_row = tensor[-1:].repeat(pad_size, 1, 1)

    # Concatenate original tensor with repeated last rows
    return torch.cat([tensor, last_row], dim=0)


def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
    current_batch_size = tensor.shape[0]
    if current_batch_size >= batch_size:
        return tensor

    pad_size = batch_size - current_batch_size
    padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)

    return F.pad(tensor, padding, mode="constant", value=0)

```
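A short usage sketch (not part of the repository) for the script-detection helpers defined above, assuming the `surya` package is importable:

```python
from surya.common.util import SCRIPT_TOKEN_MAPPING, get_top_scripts

scripts = get_top_scripts("Привет, world! 你好")
print(scripts)
# ['cyrillic', 'latin', 'punctuation', 'chinese']  (ordered by character count)
print([SCRIPT_TOKEN_MAPPING[s] for s in scripts])
# ['<SCRIPT-CYRILLIC>', '<SCRIPT-LATIN>', '<SCRIPT-PUNCTUATION>', '<SCRIPT-CHINESE>']
```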

--------------------------------------------------------------------------------
/surya/scripts/streamlit_app.py:
--------------------------------------------------------------------------------

```python
import io
import tempfile
from typing import List

import pypdfium2
import streamlit as st

from surya.common.surya.schema import TaskNames
from surya.models import load_predictors

from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image

from surya.debug.text import draw_text_on_image
from PIL import Image, ImageDraw
from surya.table_rec import TableResult
from surya.detection import TextDetectionResult
from surya.recognition import OCRResult
from surya.layout import LayoutResult
from surya.settings import settings
from surya.common.util import rescale_bbox, expand_bbox


@st.cache_resource()
def load_predictors_cached():
    return load_predictors()


def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
    from pdftext.extraction import plain_text_output

    with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
        f.write(pdf_file.getvalue())
        f.seek(0)

        # Sample the text from the middle of the PDF
        page_middle = page_count // 2
        page_range = range(
            max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count)
        )
        text = plain_text_output(f.name, page_range=page_range)

    sample_gap = len(text) // max_samples
    if len(text) == 0 or sample_gap == 0:
        return "This PDF has no text or very little text", ["no text"]

    if sample_gap < sample_len:
        sample_gap = sample_len

    # Split the text into samples for the model
    samples = []
    for i in range(0, len(text), sample_gap):
        samples.append(text[i : i + sample_len])

    results = predictors["ocr_error"](samples)
    label = "This PDF has good text."
    if results.labels.count("bad") / len(results.labels) > 0.2:
        label = "This PDF may have garbled or bad OCR text."
    return label, results.labels


def text_detection(img) -> (Image.Image, TextDetectionResult):
    text_pred = predictors["detection"]([img])[0]
    text_polygons = [p.polygon for p in text_pred.bboxes]
    det_img = draw_polys_on_image(text_polygons, img.copy())
    return det_img, text_pred


def layout_detection(img) -> (Image.Image, LayoutResult):
    pred = predictors["layout"]([img])[0]
    polygons = [p.polygon for p in pred.bboxes]
    labels = [
        f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}" for p in pred.bboxes
    ]
    layout_img = draw_polys_on_image(
        polygons, img.copy(), labels=labels, label_font_size=18
    )
    return layout_img, pred


def table_recognition(
    img, highres_img, skip_table_detection: bool
) -> (Image.Image, List[TableResult]):
    if skip_table_detection:
        layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
        table_imgs = [highres_img]
    else:
        _, layout_pred = layout_detection(img)
        layout_tables_lowres = [
            line.bbox
            for line in layout_pred.bboxes
            if line.label in ["Table", "TableOfContents"]
        ]
        table_imgs = []
        layout_tables = []
        for tb in layout_tables_lowres:
            highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
            # Slightly expand the box
            highres_bbox = expand_bbox(highres_bbox)
            table_imgs.append(highres_img.crop(highres_bbox))
            layout_tables.append(highres_bbox)

    table_preds = predictors["table_rec"](table_imgs)
    table_img = img.copy()

    for results, table_bbox in zip(table_preds, layout_tables):
        adjusted_bboxes = []
        labels = []
        colors = []

        for item in results.cells:
            adjusted_bboxes.append(
                [
                    (item.bbox[0] + table_bbox[0]),
                    (item.bbox[1] + table_bbox[1]),
                    (item.bbox[2] + table_bbox[0]),
                    (item.bbox[3] + table_bbox[1]),
                ]
            )
            labels.append(item.label)
            if "Row" in item.label:
                colors.append("blue")
            else:
                colors.append("red")
        table_img = draw_bboxes_on_image(
            adjusted_bboxes,
            highres_img,
            labels=labels,
            label_font_size=18,
            color=colors,
        )
    return table_img, table_preds


# Function for OCR
def ocr(
    img: Image.Image,
    highres_img: Image.Image,
    skip_text_detection: bool = False,
    recognize_math: bool = True,
    with_bboxes: bool = True,
) -> (Image.Image, OCRResult, Image.Image):
    if skip_text_detection:
        img = highres_img
        bboxes = [[[0, 0, img.width, img.height]]]
    else:
        bboxes = None

    if with_bboxes:
        tasks = [TaskNames.ocr_with_boxes]
    else:
        tasks = [TaskNames.ocr_without_boxes]

    img_pred = predictors["recognition"](
        [img],
        task_names=tasks,
        bboxes=bboxes,
        det_predictor=predictors["detection"],
        highres_images=[highres_img],
        math_mode=recognize_math,
        return_words=True,
    )[0]

    bboxes = [line.bbox for line in img_pred.text_lines]
    text = [line.text for line in img_pred.text_lines]
    rec_img = draw_text_on_image(bboxes, text, img.size)

    word_boxes = []
    for line in img_pred.text_lines:
        if line.words:
            word_boxes.extend([word.bbox for word in line.words])

    box_img = img.copy()
    draw = ImageDraw.Draw(box_img)
    for word_box in word_boxes:
        draw.rectangle(word_box, outline="red", width=2)

    return rec_img, img_pred, box_img


def open_pdf(pdf_file):
    stream = io.BytesIO(pdf_file.getvalue())
    return pypdfium2.PdfDocument(stream)


@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
    doc = open_pdf(pdf_file)
    renderer = doc.render(
        pypdfium2.PdfBitmap.to_pil,
        page_indices=[page_num - 1],
        scale=dpi / 72,
    )
    png = list(renderer)[0]
    png_image = png.convert("RGB")
    doc.close()
    return png_image


@st.cache_data()
def page_counter(pdf_file):
    doc = open_pdf(pdf_file)
    doc_len = len(doc)
    doc.close()
    return doc_len


st.set_page_config(layout="wide")
col1, col2 = st.columns([0.5, 0.5])

predictors = load_predictors_cached()

st.markdown("""
# Surya OCR Demo

This app will let you try surya, a multilingual OCR toolkit.

Notes:

- This works best on documents with printed text.
- For OCR, the formatting (math, italics, etc) will not show up in the image preview, but it will show up in the returned text lines.
- If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease).

Find the project [here](https://github.com/VikParuchuri/surya).
""")

in_file = st.sidebar.file_uploader(
    "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]
)

if in_file is None:
    st.stop()

filetype = in_file.type
page_count = None
if "pdf" in filetype:
    page_count = page_counter(in_file)
    page_number = st.sidebar.number_input(
        f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count
    )

    pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
    pil_image_highres = get_page_image(
        in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES
    )
else:
    pil_image = Image.open(in_file).convert("RGB")
    pil_image_highres = pil_image
    page_number = None

run_text_det = st.sidebar.button("Run Text Detection")
run_text_rec = st.sidebar.button("Run OCR")
run_layout_det = st.sidebar.button("Run Layout Analysis")
run_table_rec = st.sidebar.button("Run Table Rec")
run_ocr_errors = st.sidebar.button("Run bad PDF text detection")
use_pdf_boxes = st.sidebar.checkbox(
    "PDF table boxes",
    value=True,
    help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.",
)
skip_table_detection = st.sidebar.checkbox(
    "Skip table detection",
    value=False,
    help="Table recognition only: Skip table detection and treat the whole image/page as a table.",
)
skip_text_detection = st.sidebar.checkbox(
    "Skip text detection",
    value=False,
    help="OCR only: Skip text detection and treat the whole image as a single line.",
)
recognize_math = st.sidebar.checkbox(
    "Recognize math in OCR",
    value=True,
    help="Enable math mode in OCR - this will recognize math.",
)
ocr_with_boxes = st.sidebar.checkbox(
    "OCR with boxes",
    value=True,
    help="Enable OCR with boxes - this will predict character-level boxes.",
)

if pil_image is None:
    st.stop()

# Run Text Detection
if run_text_det:
    det_img, text_pred = text_detection(pil_image)
    with col1:
        st.image(det_img, caption="Detected Text", use_container_width=True)
        st.json(
            text_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True
        )


# Run layout
if run_layout_det:
    layout_img, pred = layout_detection(pil_image)
    with col1:
        st.image(layout_img, caption="Detected Layout", use_container_width=True)
        st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True)

# Run OCR
if run_text_rec:
    rec_img, pred, box_img = ocr(
        pil_image,
        pil_image_highres,
        skip_text_detection,
        recognize_math,
        with_bboxes=ocr_with_boxes,
    )
    with col1:
        st.image(rec_img, caption="OCR Result", use_container_width=True)
        json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
        with json_tab:
            st.json(pred.model_dump(), expanded=False)
        with text_tab:
            st.text("\n".join([p.text for p in pred.text_lines]))

        st.image(
            box_img,
            caption="OCR with Word Boxes (for debugging)",
            use_container_width=True,
        )


if run_table_rec:
    table_img, pred = table_recognition(
        pil_image, pil_image_highres, skip_table_detection
    )
    with col1:
        st.image(table_img, caption="Table Recognition", use_container_width=True)
        st.json([p.model_dump() for p in pred], expanded=True)

if run_ocr_errors:
    if "pdf" not in filetype:
        st.error("This feature only works with PDFs.")
    label, results = ocr_errors(in_file, page_count)
    with col1:
        st.write(label)
        st.json(results)

with col2:
    st.image(pil_image, caption="Uploaded Image", use_container_width=True)

```
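A minimal sketch (not part of the repository) of the low-resolution-to-high-resolution box mapping used in `table_recognition` above; the image sizes and bbox are made up:

```python
from surya.common.util import expand_bbox, rescale_bbox

lowres_size = (1000, 1414)   # (width, height) of the layout input image
highres_size = (2000, 2828)  # (width, height) of the higher-DPI render
table_bbox_lowres = [100, 200, 600, 500]

# Map the layout box onto the high-resolution page, then pad it slightly before cropping
highres_bbox = rescale_bbox(table_bbox_lowres, lowres_size, highres_size)
highres_bbox = expand_bbox(highres_bbox)
print(highres_bbox)  # approximately [198, 396, 1212, 1010]
```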

--------------------------------------------------------------------------------
/benchmark/recognition.py:
--------------------------------------------------------------------------------

```python
import re
import unicodedata
from collections import defaultdict

import click

from benchmark.utils.scoring import overlap_score, overlap_score_exact
from surya.input.processing import convert_if_not_rgb
from surya.debug.text import draw_text_on_image
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor
from surya.settings import settings
from surya.recognition.languages import CODE_TO_LANGUAGE
from benchmark.utils.tesseract import (
    tesseract_ocr_parallel,
    surya_lang_to_tesseract,
    TESS_CODE_TO_LANGUAGE,
)
from benchmark.utils.textract import textract_ocr_parallel
import os
import datasets
import json
import time
from tabulate import tabulate

KEY_LANGUAGES = [
    "Chinese",
    "Spanish",
    "English",
    "Arabic",
    "Hindi",
    "Bengali",
    "Russian",
    "Japanese",
]


def list_in(lst: str | list, lst2: list):
    if isinstance(lst, str):
        lst = [lst]
    return any([item in lst for item in lst2])


def standardize_bullets(text):
    patterns = [
        r"•\s+",
        r"·\s+",
        r"○\s+",
        r"◦\s+",
        r"▪\s+",
        r"▫\s+",
        r"➢\s+",
        r"➤\s+",
        r"★\s+",
        r"✓\s+",
        r"✗\s+",
        r"✦\s+",
        r"\\bullet\s+",
    ]

    combined_pattern = "|".join(patterns)
    text = re.sub(combined_pattern, "*", text)

    return text


def normalize_text(text: str) -> str:
    # Remove HTML tags
    text = re.sub(r"<[^>]+>", "", text)
    # Remove LaTeX tags
    text = re.sub(r"\\[a-zA-Z]+", "", text)
    text = standardize_bullets(text)
    text = unicodedata.normalize("NFKC", text)
    return text.strip().lower().replace(",", ".")


@click.command(help="Benchmark recognition model.")
@click.option(
    "--results_dir",
    type=str,
    help="Path to JSON file with OCR results.",
    default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
    "--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None
)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
@click.option(
    "--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False
)
@click.option(
    "--textract", is_flag=True, help="Run benchmarks on textract.", default=False
)
@click.option(
    "--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28
)
@click.option(
    "--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28
)
@click.option(
    "--languages",
    type=str,
    help="Comma-separated list of languages to benchmark.",
    default=None,
)
@click.option(
    "--print_results",
    is_flag=True,
)
def main(
    results_dir: str,
    max_rows: int,
    debug: int,
    tesseract: bool,
    textract: bool,
    tess_cpus: int,
    textract_cpus: int,
    languages: str | None,
    print_results: bool,
):
    foundation_predictor = FoundationPredictor()
    rec_predictor = RecognitionPredictor(foundation_predictor)

    split = "train"
    dataset = datasets.load_dataset(
        settings.RECOGNITION_BENCH_DATASET_NAME, split=split
    )

    if languages:
        languages = languages.split(",")
        dataset = dataset.filter(
            lambda x: list_in(x["language"], languages), num_proc=4
        )

    if max_rows and max_rows < len(dataset):
        dataset = dataset.shuffle(seed=1).select(range(max_rows))

    images = list(dataset["image"])
    images = convert_if_not_rgb(images)
    bboxes = list(dataset["bboxes"])
    line_text = list(dataset["text"])
    languages = list(dataset["language"])

    print(f"Loaded {len(images)} images. Running OCR...")

    start = time.time()
    predictions_by_image = rec_predictor(images, None, bboxes=bboxes)
    surya_time = time.time() - start

    lang_list = []
    for lang in languages:
        if not isinstance(lang, list):
            lang_list.append([lang])
        else:
            lang_list.append(lang)

    surya_scores = defaultdict(list)
    img_surya_scores = []
    outputs = []
    for idx, (pred, ref_text, langs) in enumerate(
        zip(predictions_by_image, line_text, lang_list)
    ):
        pred_text = [line.text for line in pred.text_lines]

        score_ref_text = [normalize_text(line) for line in ref_text]
        score_pred_text = [normalize_text(text) for text in pred_text]
        image_scores, image_weights = overlap_score_exact(
            score_pred_text, score_ref_text
        )
        normalized_scores = [
            score / max(1, weight) for score, weight in zip(image_scores, image_weights)
        ]
        image_score = sum(image_scores) / max(1, sum(image_weights))

        img_surya_scores.append(image_score)
        for lang in langs:
            surya_scores[CODE_TO_LANGUAGE[lang]].append(image_score)

        assert len(pred_text) == len(ref_text) == len(bboxes[idx])
        if debug:
            for j, (pred_line, ref_line, score, bbox) in enumerate(
                zip(pred_text, ref_text, normalized_scores, bboxes[idx])
            ):
                image_slice = images[idx].crop(bbox)

                outputs.append(
                    {
                        "image": image_slice,
                        "bbox": bbox,
                        "score": score,
                        "pred": pred_line,
                        "ref": ref_line,
                        "langs": ",".join(langs),
                    }
                )

    if debug:
        out_ds = datasets.Dataset.from_list(outputs)
        out_ds.push_to_hub("datalab-to/rec_bench_outputs", private=True)

    flat_surya_scores = [score for lang in surya_scores for score in surya_scores[lang]]
    benchmark_stats = {
        "surya": {
            "avg_score": sum(flat_surya_scores) / max(1, len(flat_surya_scores)),
            "lang_scores": {
                lang: sum(scores) / max(1, len(scores))
                for lang, scores in surya_scores.items()
            },
            "time_per_img": surya_time / max(1, len(images)),
        }
    }

    result_path = os.path.join(results_dir, "rec_bench")
    os.makedirs(result_path, exist_ok=True)

    with open(os.path.join(result_path, "surya_scores.json"), "w+") as f:
        json.dump(surya_scores, f)

    if tesseract:
        tess_valid = []
        tess_langs = []
        for idx, lang in enumerate(lang_list):
            # Tesseract does not support all languages
            tess_lang = surya_lang_to_tesseract(lang[0])
            if tess_lang is None:
                continue

            tess_valid.append(idx)
            tess_langs.append(tess_lang)

        tess_imgs = [images[i] for i in tess_valid]
        tess_bboxes = [bboxes[i] for i in tess_valid]
        tess_reference = [line_text[i] for i in tess_valid]
        start = time.time()
        tess_predictions = tesseract_ocr_parallel(
            tess_imgs, tess_bboxes, tess_langs, cpus=tess_cpus
        )
        tesseract_time = time.time() - start

        tess_scores = defaultdict(list)
        for idx, (pred, ref_text, lang) in enumerate(
            zip(tess_predictions, tess_reference, tess_langs)
        ):
            image_scores, image_weights, _ = overlap_score(pred, ref_text)
            image_score = sum(image_scores) / max(1, sum(image_weights))
            tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)

        flat_tess_scores = [
            score for lang in tess_scores for score in tess_scores[lang]
        ]
        benchmark_stats["tesseract"] = {
            "avg_score": sum(flat_tess_scores) / len(flat_tess_scores),
            "lang_scores": {
                lang: sum(scores) / len(scores) for lang, scores in tess_scores.items()
            },
            "time_per_img": tesseract_time / len(tess_imgs),
        }

        with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
            json.dump(tess_scores, f)

    if textract:
        start = time.time()
        textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus)
        textract_time = time.time() - start

        textract_scores = defaultdict(list)
        for idx, (pred, ref_text, langs) in enumerate(
            zip(textract_predictions, line_text, lang_list)
        ):
            image_scores, image_weights, _ = overlap_score(pred, ref_text)
            image_score = sum(image_scores) / max(1, sum(image_weights))

            for lang in langs:
                textract_scores[CODE_TO_LANGUAGE[lang]].append(image_score)

        flat_textract_scores = [
            score for lang in textract_scores for score in textract_scores[lang]
        ]
        benchmark_stats["textract"] = {
            "avg_score": sum(flat_textract_scores) / len(flat_textract_scores),
            "lang_scores": {
                lang: sum(scores) / len(scores)
                for lang, scores in textract_scores.items()
            },
            "time_per_img": textract_time / len(images),
        }
        print(len(flat_textract_scores))

        with open(os.path.join(result_path, "textract_scores.json"), "w+") as f:
            json.dump(textract_scores, f)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(benchmark_stats, f)

    key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
    table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages
    table_data = [
        [
            "surya",
            benchmark_stats["surya"]["time_per_img"],
            benchmark_stats["surya"]["avg_score"],
        ]
        + [benchmark_stats["surya"]["lang_scores"][lang] for lang in key_languages],
    ]
    if tesseract:
        table_data.append(
            [
                "tesseract",
                benchmark_stats["tesseract"]["time_per_img"],
                benchmark_stats["tesseract"]["avg_score"],
            ]
            + [
                benchmark_stats["tesseract"]["lang_scores"].get(lang, 0)
                for lang in key_languages
            ]
        )
    if textract:
        table_data.append(
            [
                "textract",
                benchmark_stats["textract"]["time_per_img"],
                benchmark_stats["textract"]["avg_score"],
            ]
            + [
                benchmark_stats["textract"]["lang_scores"][lang]
                for lang in key_languages
            ],
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print(
        "Only a few major languages are displayed. See the result path for additional languages."
    )

    if debug >= 1:
        bad_detections = []
        for idx, (score, lang) in enumerate(zip(img_surya_scores, lang_list)):
            if score < 0.8:
                bad_detections.append((idx, lang, score))
        print(f"Found {len(bad_detections)} bad detections. Writing to file...")
        with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
            json.dump(bad_detections, f)

    if debug == 2:
        for idx, (image, pred, ref_text, bbox, lang) in enumerate(
            zip(images, predictions_by_image, line_text, bboxes, lang_list)
        ):
            pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
            ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
            pred_text = [line.text for line in pred.text_lines]
            pred_image = draw_text_on_image(bbox, pred_text, image.size)
            pred_image.save(os.path.join(result_path, pred_image_name))
            ref_image = draw_text_on_image(bbox, ref_text, image.size)
            ref_image.save(os.path.join(result_path, ref_image_name))
            image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))

    print(f"Wrote results to {result_path}")

    if print_results:
        for idx, (pred, ref_text) in enumerate(zip(predictions_by_image, line_text)):
            print(f"Image {idx}")
            print("----")
            for line_idx, (pred_line, ref_line) in enumerate(
                zip(pred.text_lines, ref_text)
            ):
                print(f"Sample {line_idx}")
                print(f"Pred: {pred_line.text}")
                print(f"Ref: {ref_line}")
                print()

    if settings.TORCH_DEVICE == "xla":
        import torch_xla.debug.metrics as met

        print(met.short_metrics_report())


if __name__ == "__main__":
    main()

```
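A brief sketch (not from the repository) of the normalization applied to predictions and references before scoring; the sample string is made up and the import assumes the benchmark package is on the path:

```python
from benchmark.recognition import normalize_text

raw = "<b>Total,</b> \\textbf{2024} • First item"
print(normalize_text(raw))
# HTML tags and LaTeX commands are stripped, the bullet becomes "*",
# the string is NFKC-normalized and lowercased, and "," is replaced with "."
```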

--------------------------------------------------------------------------------
/surya/detection/processor.py:
--------------------------------------------------------------------------------

```python
import warnings
from typing import Any, Dict, List, Optional, Union

import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    make_list_of_images,
)
from transformers.utils import TensorType


import PIL.Image
import torch

from surya.common.s3 import S3DownloaderMixin


class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
    r"""
    Constructs a Segformer image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
            size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Factor by which to rescale the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
            `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: bool = False,
        **kwargs,
    ) -> None:
        if "reduce_labels" in kwargs:
            warnings.warn(
                "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
                "`do_reduce_labels` instead.",
                FutureWarning,
            )
            do_reduce_labels = kwargs.pop("reduce_labels")

        super().__init__(**kwargs)
        size = size if size is not None else {"height": 512, "width": 512}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        self.do_reduce_labels = do_reduce_labels
        self._valid_processor_keys = [
            "images",
            "segmentation_maps",
            "do_resize",
            "size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_reduce_labels",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image
        processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
        reduce_labels=True)`
        """
        image_processor_dict = image_processor_dict.copy()
        if "reduce_labels" in kwargs:
            image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
        return super().from_dict(image_processor_dict, **kwargs)

    def _preprocess(
        self,
        image: ImageInput,
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        # All transformations expect numpy arrays.
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        image = self._preprocess(
            image=image,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
        )
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        return image

    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Preprocesses a batch of images and optionally segmentation maps.

        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
        passed in as positional arguments.
        """
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            segmentation_maps (`ImageInput`, *optional*):
                Segmentation map to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after `resize` is applied.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to the [0, 1] range.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
                is used for background, and background itself is not included in all classes of a dataset (e.g.
                ADE20k). The background label will be replaced by 255.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample
        size = size if size is not None else self.size
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        images = make_list_of_images(images)
        images = [
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                resample=resample,
                size=size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
```
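A minimal usage sketch (not part of the repository) for the processor above; the input array is made up, and the pipeline rescales to [0, 1], normalizes with ImageNet statistics, and returns channels-first pixel values:

```python
import numpy as np
from surya.detection.processor import SegformerImageProcessor

processor = SegformerImageProcessor()
image = np.zeros((512, 512, 3), dtype=np.uint8)  # (height, width, channels)

batch = processor(image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 512, 512])
```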

--------------------------------------------------------------------------------
/surya/foundation/cache/dynamic_ops.py:
--------------------------------------------------------------------------------

```python
from typing import Any, Dict, List, Optional, Tuple
import torch
from transformers import PretrainedConfig

"""
Special cache class for the surya foundation model that supports - 
1) Static shape
2) A custom sliding window, where image tokens stay in cache, and text tokens are popped
3) Continuous batching - merging etc
4) Attention mask management - To match with what's currently in the cache

Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079
"""


class DynamicOpsCache:
    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int,
        max_cache_len: int,
        text_sliding_window: int,
        device: torch.device | str,
        dtype: torch.dtype,
    ):
        self.text_sliding_window = text_sliding_window
        self.num_layers = config.num_hidden_layers
        self.max_batch_size = batch_size
        self.max_cache_len = max_cache_len
        self.head_dim = (
            getattr(config, "head_dim", None)
            or config.hidden_size // config.num_attention_heads
        )
        self._dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None
            else config.num_key_value_heads
        )

        # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        cache_shape = (
            self.max_batch_size,
            self.num_key_value_heads,
            self.max_cache_len,
            self.head_dim,
        )
        device = torch.device(device) if device is not None else None
        for _ in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(
                cache_shape, dtype=self._dtype, device=device
            )
            new_layer_value_cache = torch.zeros(
                cache_shape, dtype=self._dtype, device=device
            )
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

        self.attention_mask = torch.zeros(
            (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long
        )
        self.text_token_counts = [
            torch.zeros(self.max_batch_size, dtype=torch.long, device=device)
            for _ in range(self.num_layers)
        ]

        self.dtype = dtype
        self.device = device

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        prefill = cache_kwargs.get("prefill", False)
        update_fn = self._prefill_update if prefill else self._decode_update
        return update_fn(
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
            key_states,
            value_states,
            self.text_token_counts[layer_idx],
            cache_kwargs,
        )

    def update_text_counts(
        self,
        merge_idxs: torch.Tensor,
        valid_batch_size: torch.Tensor,
        new_text_lens: torch.Tensor,
    ):
        new_text_len_tensor = new_text_lens.to(device=self.device)

        for layer_idx in range(self.num_layers):
            self.text_token_counts[layer_idx][merge_idxs] = new_text_len_tensor[
                :valid_batch_size
            ]

    # Mirrors the logic of _prefill_update; the cache layout is explained in more
    # detail in that function.
    def prefill_attention_mask_update(
        self,
        prefill_attention_mask: torch.Tensor,
        merge_idxs: torch.Tensor,
        valid_batch_mask: torch.Tensor,
        text_lengths: List[int],
    ):
        seq_len = prefill_attention_mask.shape[1]
        sliding_window = self.text_sliding_window
        total_cache_len = self.max_cache_len
        prefix_cache_space = total_cache_len - sliding_window

        for batch_idx, cache_idx in enumerate(merge_idxs):
            text_len = text_lengths[batch_idx]
            prefix_len = seq_len - text_len
            self.attention_mask[cache_idx] = 0  # Set default

            assert prefix_len > 0, "There are no prefix (image) tokens!"

            end_pos = prefix_cache_space
            # Handle prefix part - Which may be left padded
            if prefix_len <= prefix_cache_space:
                start_pos = prefix_cache_space - prefix_len
                self.attention_mask[cache_idx, start_pos:end_pos] = (
                    prefill_attention_mask[batch_idx, :prefix_len]
                )
            else:
                self.attention_mask[cache_idx, :end_pos] = prefill_attention_mask[
                    batch_idx, prefix_len - prefix_cache_space : prefix_len
                ]

            # Handle text part, keeping sliding window in consideration
            # All of the left padding is before the prefix, so we can ignore the prefill_attention_mask here
            if text_len > 0:
                text_cache_start = prefix_cache_space
                if text_len <= sliding_window:
                    self.attention_mask[
                        cache_idx, text_cache_start : text_cache_start + text_len
                    ] = 1
                else:
                    self.attention_mask[cache_idx, -sliding_window:] = 1

    # Slow impl for now - Prefill time is dominated by the large sequence length forward pass
    def _prefill_update(
        self,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        text_token_counts: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ):
        cache_idxs: List[int] = cache_kwargs.get("cache_idxs", None)
        text_lengths: List[int] = cache_kwargs.get("text_lengths", None)
        assert cache_idxs is not None, "cache_idxs must be specified during prefill"
        assert text_lengths is not None, "text_lengths must be specified during prefill"

        _, _, seq_len, _ = key_states.shape
        total_cache_len = self.max_cache_len
        sliding_window = self.text_sliding_window
        prefix_cache_space = total_cache_len - sliding_window

        for batch_idx, cache_idx in enumerate(cache_idxs):
            text_len = text_lengths[batch_idx]
            prefix_len = seq_len - text_len

            ###### Handle Image Tokens (Prefix) #####
            # Place image tokens in appropriate cache space, aligned to the **right edge**
            assert prefix_len > 0, "There are no prefix (image) tokens!"

            # prefix_len may be greater than the prefix cache space due to left padding - This happens when
            # a different batch element has a large input text during prefill, causing others to have a lot of
            # left padding. We can safely take the last `prefix_cache_space` elements from the kv states, since
            # `prefix_cache_space` is large enough to fit any image, and the rest **has to be** padding
            end_pos = prefix_cache_space
            if prefix_len <= prefix_cache_space:
                start_pos = prefix_cache_space - prefix_len
                key_cache[cache_idx, :, start_pos:end_pos] = key_states[
                    batch_idx, :, :prefix_len
                ]
                value_cache[cache_idx, :, start_pos:end_pos] = value_states[
                    batch_idx, :, :prefix_len
                ]
            else:
                key_cache[cache_idx, :, :end_pos] = key_states[
                    batch_idx, :, prefix_len - prefix_cache_space : prefix_len
                ]
                value_cache[cache_idx, :, :end_pos] = value_states[
                    batch_idx, :, prefix_len - prefix_cache_space : prefix_len
                ]

            ###### Handle Text Tokens #####
            # Text tokens start at the **left edge** of sliding window cache space
            if text_len > 0:
                text_cache_start = prefix_cache_space

                if text_len <= sliding_window:
                    key_cache[
                        cache_idx, :, text_cache_start : text_cache_start + text_len
                    ] = key_states[batch_idx, :, prefix_len : prefix_len + text_len]
                    value_cache[
                        cache_idx, :, text_cache_start : text_cache_start + text_len
                    ] = value_states[batch_idx, :, prefix_len : prefix_len + text_len]
                else:
                    start_in_text = text_len - sliding_window
                    key_cache[
                        cache_idx,
                        :,
                        text_cache_start : text_cache_start + sliding_window,
                    ] = key_states[
                        batch_idx, :, prefix_len + start_in_text : prefix_len + text_len
                    ]
                    value_cache[
                        cache_idx,
                        :,
                        text_cache_start : text_cache_start + sliding_window,
                    ] = value_states[
                        batch_idx, :, prefix_len + start_in_text : prefix_len + text_len
                    ]

        # Return the full key/value states (not just cached) for use in subsequent layers
        return key_states, value_states

    # """
    # Matches the logic of the decode update, but needs to be called before the updates
    # since some parts of the model depend on the attention mask
    # """
    def decode_attention_mask_update(
        self, num_valid_tokens: torch.Tensor, cache_idxs: List[int]
    ):
        sliding_window = self.text_sliding_window
        text_cache_start = self.max_cache_len - sliding_window

        # Using text_token_counts of first layer, should be same for all though
        current_text_lens = self.text_token_counts[0]
        cache_idxs_tensor = torch.tensor(cache_idxs, device=current_text_lens.device)

        # Get current text lengths for the relevant cache indices
        current_lens = current_text_lens[cache_idxs_tensor]
        new_text_lens = current_lens + num_valid_tokens
        is_full = new_text_lens > sliding_window

        # Handle full caches - set entire sliding window to 1
        if is_full.any():
            full_mask = is_full
            full_cache_idxs = cache_idxs_tensor[full_mask]
            self.attention_mask[full_cache_idxs, text_cache_start:] = 1

        # Handle non-full caches - set specific ranges to 1
        if (~is_full).any():
            non_full_mask = ~is_full
            non_full_cache_idxs = cache_idxs_tensor[non_full_mask]
            non_full_current_lens = current_lens[non_full_mask]
            non_full_valid_tokens = num_valid_tokens[non_full_mask]

            max_valid_tokens = (
                non_full_valid_tokens.max().item()
                if len(non_full_valid_tokens) > 0
                else 0
            )
            if max_valid_tokens > 0:
                batch_size = len(non_full_cache_idxs)
                offset_range = torch.arange(
                    max_valid_tokens, device=current_text_lens.device
                )
                batch_offsets = offset_range.unsqueeze(0).expand(batch_size, -1)
                start_positions = non_full_current_lens.unsqueeze(1)
                valid_token_counts = non_full_valid_tokens.unsqueeze(1)

                position_indices = start_positions + batch_offsets
                valid_mask = batch_offsets < valid_token_counts

                row_indices = non_full_cache_idxs.unsqueeze(1).expand(
                    -1, max_valid_tokens
                )[valid_mask]
                col_indices = text_cache_start + position_indices[valid_mask]

                self.attention_mask[row_indices, col_indices] = 1

    """
    Static cache update
    - respects per-batch text token limits
    - per-batch valid token lengths (right-padded inputs)

    kv states are expected to have shape [batch_size, kv_heads, T_pad, head_dim]
    They may have different `true` lengths, to account for multi token preds, or beacon tokens
    Expects `num_valid_tokens` in cache_kwargs: a tensor of shape (B,) indicating the number
    of actual (non-padded) tokens to add per batch element.
    """

    def _decode_update(
        self,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        text_token_counts: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_valid_tokens: torch.Tensor = cache_kwargs.get(
            "num_valid_tokens"
        )  # shape: (B,)
        assert num_valid_tokens is not None, (
            "`num_valid_tokens` must be provided in `cache_kwargs`"
        )
        device = key_states.device

        batch_size, num_head, seq_len, head_dim = key_states.shape
        sliding_window = self.text_sliding_window
        max_cache_len = self.max_cache_len
        cache_text_start = max_cache_len - sliding_window
        new_text_lengths = text_token_counts + num_valid_tokens
        slide_amounts = torch.clamp(new_text_lengths - sliding_window, min=0)
        needs_rotate = slide_amounts > 0

        # Rotate the cache if needed
        if torch.any(needs_rotate):
            k_slice = key_cache[:, :, -sliding_window:]  # shape: [B, H, W, D]
            v_slice = value_cache[:, :, -sliding_window:]  # same shape

            cache_indices = (
                torch.arange(sliding_window, device=device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )  # [B, W]
            rolled_indices = (
                cache_indices + slide_amounts.unsqueeze(1)
            ) % sliding_window  # [B, W]

            # We need to expand indices to shape: [B, 1, W, 1] to broadcast with k_slice
            rolled_indices = (
                rolled_indices.unsqueeze(1)
                .unsqueeze(-1)
                .expand(-1, num_head, -1, head_dim)
            )

            k_slice_rolled = k_slice.gather(dim=2, index=rolled_indices)
            v_slice_rolled = v_slice.gather(dim=2, index=rolled_indices)

            key_cache[:, :, -sliding_window:] = k_slice_rolled
            value_cache[:, :, -sliding_window:] = v_slice_rolled

        # Insert only **valid tokens** into the cache. These are **right aligned** within the input sequence
        insert_positions = torch.where(
            needs_rotate,
            max_cache_len - num_valid_tokens,
            text_token_counts + cache_text_start,
        )

        max_tokens = num_valid_tokens.max().item()
        offsets = torch.arange(max_tokens, device=device).unsqueeze(0)  # [1, max_T]
        valid_mask = offsets < num_valid_tokens.unsqueeze(1)  # [B, max_T]
        src_indices = (seq_len - num_valid_tokens).unsqueeze(1) + offsets  # [B, max_T]
        src_indices = src_indices.clamp(max=seq_len - 1)  # safety

        tgt_indices = insert_positions.unsqueeze(1) + offsets  # [B, max_T]
        tgt_indices = tgt_indices.clamp(max=max_cache_len - 1)  # safety

        src_idx_exp = (
            src_indices.unsqueeze(1)
            .unsqueeze(-1)
            .expand(batch_size, num_head, max_tokens, head_dim)
        )
        tgt_idx_exp = (
            tgt_indices.unsqueeze(1)
            .unsqueeze(-1)
            .expand(batch_size, num_head, max_tokens, head_dim)
        )
        valid_mask_exp = (
            valid_mask.unsqueeze(1)
            .unsqueeze(-1)
            .expand(batch_size, num_head, max_tokens, head_dim)
        )

        k_src = torch.gather(key_states, 2, src_idx_exp)
        v_src = torch.gather(value_states, 2, src_idx_exp)
        k_src = k_src * valid_mask_exp
        v_src = v_src * valid_mask_exp

        # Write into cache
        key_cache.scatter_(2, tgt_idx_exp, k_src)
        value_cache.scatter_(2, tgt_idx_exp, v_src)

        # In-place edit - Mutates
        text_token_counts += num_valid_tokens
        text_token_counts.clamp_(max=sliding_window)

        return key_cache, value_cache

    # We have a non-uniform cache, so it's better not to return a sequence length here
    # and to handle any logic that requires it ourselves.
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        raise NotImplementedError()

```
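
For orientation, a minimal construction sketch of `DynamicOpsCache`; the config values below are illustrative and only cover the attributes the class reads:

```python
import torch
from transformers import PretrainedConfig

from surya.foundation.cache.dynamic_ops import DynamicOpsCache

# Illustrative hyperparameters (not taken from the repository's configs).
config = PretrainedConfig(
    num_hidden_layers=2,
    hidden_size=256,
    num_attention_heads=8,
    num_key_value_heads=4,
)

cache = DynamicOpsCache(
    config=config,
    batch_size=4,             # number of cache slots (max batch size)
    max_cache_len=1024,       # prefix (image) space + text sliding window
    text_sliding_window=256,  # only the most recent 256 text tokens are kept
    device="cpu",
    dtype=torch.float32,
)

# One static (batch, kv_heads, max_cache_len, head_dim) tensor per layer.
assert cache.key_cache[0].shape == (4, 4, 1024, 32)
assert cache.attention_mask.shape == (4, 1024)
```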

--------------------------------------------------------------------------------
/surya/table_rec/__init__.py:
--------------------------------------------------------------------------------

```python
from copy import deepcopy
from itertools import chain
from typing import List

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from surya.common.xla import mark_step
from surya.common.predictor import BasePredictor
from surya.table_rec.schema import TableCell, TableRow, TableCol, TableResult
from surya.common.polygon import PolygonBox
from surya.settings import settings
from surya.table_rec.loader import TableRecModelLoader
from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM, CATEGORY_TO_ID, MERGE_KEYS, \
    MERGE_VALUES
from surya.table_rec.shaper import LabelShaper


class TableRecPredictor(BasePredictor):
    model_loader_cls = TableRecModelLoader
    batch_size = settings.TABLE_REC_BATCH_SIZE
    default_batch_sizes = {
        "cpu": 8,
        "mps": 8,
        "cuda": 32,
        "xla": 16
    }

    def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]:
        return self.batch_table_recognition(images, batch_size)

    def inference_loop(
            self,
            encoder_hidden_states: torch.Tensor,
            batch_input_ids: torch.Tensor,
            current_batch_size: int,
            batch_size: int
    ):
        shaper = LabelShaper()
        batch_predictions = [[] for _ in range(current_batch_size)]
        max_tokens = settings.TABLE_REC_MAX_BOXES
        decoder_position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=self.model.device).cumsum(
            0) - 1
        inference_token_count = batch_input_ids.shape[1]

        if settings.TABLE_REC_STATIC_CACHE:
            encoder_hidden_states = self.pad_to_batch_size(encoder_hidden_states, batch_size)
            batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)

        # Move to device after padding for XLA
        encoder_hidden_states = encoder_hidden_states.to(self.model.device)
        batch_input_ids = batch_input_ids.to(self.model.device)

        self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype)

        with settings.INFERENCE_MODE():
            token_count = 0
            all_done = torch.zeros(encoder_hidden_states.shape[0], dtype=torch.bool, device=self.model.device)

            while token_count < max_tokens:
                is_prefill = token_count == 0
                return_dict = self.model.decoder(
                    input_ids=batch_input_ids,
                    encoder_hidden_states=encoder_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill
                )

                decoder_position_ids = decoder_position_ids[-1:] + 1

                # Get predictions for each box element
                box_properties = []
                done = []

                # Pre-process all logits at once
                processed_logits = {}
                for k, _, mode in BOX_PROPERTIES:
                    k_logits = return_dict["box_property_logits"][k][:, -1, :]  # Get all batch logits at once
                    
                    if mode == "classification":
                        # Process all classification logits in one operation
                        items = torch.argmax(k_logits, dim=-1)
                        if k == "category":
                            done = (items == self.model.decoder.config.eos_token_id) | (items == self.model.decoder.config.pad_token_id)
                        items = items - SPECIAL_TOKENS
                        processed_logits[k] = items
                    elif mode == "regression":
                        if k == "bbox":
                            k_logits = k_logits * BOX_DIM
                            processed_logits[k] = k_logits
                        elif k == "colspan":
                            k_logits = torch.clamp(k_logits, min=1)
                            processed_logits[k] = torch.round(k_logits)

                items = {k: processed_logits[k].cpu() for k, _, _ in BOX_PROPERTIES}
                for j in range(current_batch_size):
                    box_property = {}
                    for k, _, mode in BOX_PROPERTIES:
                        if mode == "classification":
                            box_property[k] = int(items[k][j].item())
                        elif mode == "regression":
                            if k == "bbox":
                                box_property[k] = items[k][j].tolist()
                            elif k == "colspan":
                                box_property[k] = int(items[k][j].item())
                    box_properties.append(box_property)

                all_done = all_done | done
                all_done_cpu = all_done.cpu()

                if all_done_cpu[:current_batch_size].all():
                    break

                batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long)
                batch_input_ids = batch_input_ids.unsqueeze(1)  # Add sequence length dimension

                for j, (box_property, status) in enumerate(zip(box_properties, all_done_cpu)):
                    if not status:
                        batch_predictions[j].append(box_property)

                token_count += inference_token_count
                inference_token_count = batch_input_ids.shape[1]

                if settings.TABLE_REC_STATIC_CACHE:
                    batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)

                # Move to device after padding for XLA
                batch_input_ids = batch_input_ids.to(self.model.device)
        return batch_predictions

    def batch_table_recognition(
            self,
            images: List,
            batch_size=None) -> List[TableResult]:
        assert all([isinstance(image, Image.Image) for image in images])
        if batch_size is None:
            batch_size = self.get_batch_size()

        if len(images) == 0:
            return []

        query_items = []
        for image in images:
            query_items.append({
                "polygon": [[0, 0], [image.width, 0], [image.width, image.height], [0, image.height]],
                "category": CATEGORY_TO_ID["Table"],
                "colspan": 0,
                "merges": 0,
                "is_header": 0
            })

        output_order = []
        for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables", disable=self.disable_tqdm):
            batch_query_items = query_items[i:i + batch_size]

            batch_images = images[i:i + batch_size]
            batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images

            current_batch_size = len(batch_images)

            orig_sizes = [image.size for image in batch_images]
            model_inputs = self.processor(images=batch_images, query_items=batch_query_items)

            batch_pixel_values = model_inputs["pixel_values"]

            batch_input_ids = model_inputs["input_ids"]
            batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=self.model.dtype)

            if settings.TABLE_REC_STATIC_CACHE:
                batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size)

            # Move to device after padding for XLA
            batch_pixel_values = batch_pixel_values.to(self.model.device)

            shaper = LabelShaper()

            # We only need to process each image once
            with settings.INFERENCE_MODE():
                encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state

            # Inference to get rows and columns
            rowcol_predictions = self.inference_loop(
                encoder_hidden_states,
                batch_input_ids,
                current_batch_size,
                batch_size
            )
            mark_step()

            row_query_items = []
            row_encoder_hidden_states = []
            idx_map = []
            columns = []
            for j, img_predictions in enumerate(rowcol_predictions):
                for row_prediction in img_predictions:
                    polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
                    if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]:
                        row_query_items.append({
                            "polygon": polygon,
                            "category": row_prediction["category"],
                            "colspan": 0,
                            "merges": 0,
                            "is_header": int(row_prediction["is_header"] == 1)
                        })
                        row_encoder_hidden_states.append(encoder_hidden_states[j])
                        idx_map.append(j)
                    elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]:
                        columns.append({
                            "polygon": polygon,
                            "category": row_prediction["category"],
                            "colspan": 0,
                            "merges": 0,
                            "is_header": int(row_prediction["is_header"] == 1)
                        })

            # Re-inference to predict cells
            row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
            row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)
            row_input_ids = row_inputs["input_ids"]
            cell_predictions = []
            for j in range(0, len(row_input_ids), batch_size):
                cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size]
                cell_batch_input_ids = row_input_ids[j:j + batch_size]
                cell_batch_size = len(cell_batch_input_ids)
                cell_predictions.extend(
                    self.inference_loop(cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
                )
                mark_step()

            result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper)
            output_order.extend(result)

        return output_order


    def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper):
        results = []
        for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)):
            row_cell_predictions = [c for i, c in enumerate(cell_predictions) if idx_map[i] == j]
            # Each row prediction matches a cell prediction
            rows = []
            cells = []
            columns = []

            cell_id = 0
            row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]]
            col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]]

            # Generate table columns
            for z, col_prediction in enumerate(col_predictions):
                polygon = shaper.convert_bbox_to_polygon(col_prediction["bbox"])
                polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
                columns.append(
                    TableCol(
                        polygon=polygon,
                        col_id=z,
                        is_header=col_prediction["is_header"] == 1
                    )
                )

            # Generate table rows
            for z, row_prediction in enumerate(row_predictions):
                polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
                polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
                row = TableRow(
                    polygon=polygon,
                    row_id=z,
                    is_header=row_prediction["is_header"] == 1
                )
                rows.append(row)

                # Get cells that span multiple columns within a row
                spanning_cells = []
                for l, spanning_cell in enumerate(row_cell_predictions[z]):
                    polygon = shaper.convert_bbox_to_polygon(spanning_cell["bbox"])
                    polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
                    colspan = max(1, int(spanning_cell["colspan"]))
                    if colspan == 1 and spanning_cell["merges"] not in MERGE_VALUES:
                        # Skip single column cells if they don't merge
                        continue
                    if PolygonBox(polygon=polygon).height < row.height * .85:
                        # Spanning cell must cover most of the row
                        continue

                    spanning_cells.append(
                        TableCell(
                            polygon=polygon,
                            row_id=z,
                            rowspan=1,
                            cell_id=cell_id,
                            within_row_id=l,
                            colspan=colspan,
                            merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]],
                            merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"],
                                                                   MERGE_KEYS["merge_both"]],
                            is_header=row.is_header or z == 0
                        )
                    )
                    cell_id += 1

                # Add cells - either add spanning cells (multiple cols), or generate a cell based on row/col
                used_spanning_cells = set()
                skip_columns = 0
                for l, col in enumerate(columns):
                    if skip_columns:
                        skip_columns -= 1
                        continue
                    cell_polygon = row.intersection_polygon(col)
                    cell_added = False
                    for zz, spanning_cell in enumerate(spanning_cells):
                        cell_polygonbox = PolygonBox(polygon=cell_polygon)
                        intersection_pct = cell_polygonbox.intersection_pct(spanning_cell)
                        # Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns)
                        correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]])
                        if intersection_pct > .9:
                            if spanning_cell.width > (correct_col_width * .85):
                                cell_added = True
                                if zz not in used_spanning_cells:
                                    used_spanning_cells.add(zz)
                                    spanning_cell.col_id = l
                                    cells.append(spanning_cell)
                                    skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell
                            else:
                                used_spanning_cells.add(zz) # Skip this spanning cell

                    if not cell_added:
                        cells.append(
                            TableCell(
                                polygon=cell_polygon,
                                row_id=z,
                                rowspan=1,
                                cell_id=cell_id,
                                within_row_id=l,
                                colspan=1,
                                merge_up=False,
                                merge_down=False,
                                col_id=l,
                                is_header=row.is_header or col.is_header or z == 0
                            )
                        )
                        cell_id += 1

            # Turn cells into a row grid
            grid_cells = deepcopy([
                [cell for cell in cells if cell.row_id == row.row_id]
                for row in rows
            ])

            # Merge cells across rows
            for z, grid_row in enumerate(grid_cells[1:]):
                prev_row = grid_cells[z]
                for l, cell in enumerate(grid_row):
                    if l >= len(prev_row):
                        continue

                    above_cell = prev_row[l]
                    if all([
                        above_cell.merge_down,
                        cell.merge_up,
                        above_cell.col_id == cell.col_id,
                        above_cell.colspan == cell.colspan,
                    ]):
                        above_cell.merge(cell)
                        above_cell.rowspan += cell.rowspan
                        grid_row[l] = above_cell

            merged_cells_all = list(chain.from_iterable(grid_cells))
            used_ids = set()
            merged_cells = []
            for cell in merged_cells_all:
                if cell.cell_id in used_ids:
                    continue
                used_ids.add(cell.cell_id)
                merged_cells.append(cell)

            result = TableResult(
                cells=merged_cells,
                unmerged_cells=cells,
                rows=rows,
                cols=columns,
                image_bbox=[0, 0, orig_size[0], orig_size[1]],
            )
            results.append(result)
        return results

```
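
A minimal end-to-end usage sketch for `TableRecPredictor`, assuming the default constructor loads the model through `TableRecModelLoader` using the checkpoint configured in `surya.settings`; the image path is illustrative:

```python
from PIL import Image

from surya.table_rec import TableRecPredictor

# Assumes the default constructor resolves and loads the table recognition model.
predictor = TableRecPredictor()

image = Image.open("table_page.png").convert("RGB")  # illustrative input

# __call__ dispatches to batch_table_recognition and returns one TableResult per image.
results = predictor([image])

for result in results:
    print(len(result.rows), len(result.cols), len(result.cells))
```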