This is page 2 of 4. Use http://codebase.md/datalab-to/surya?page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/benchmark/utils/tesseract.py:
--------------------------------------------------------------------------------

```python
from typing import List, Optional

import numpy as np
from tqdm import tqdm

from surya.input.processing import slice_bboxes_from_image
from surya.settings import settings
import os
from concurrent.futures import ProcessPoolExecutor

from surya.recognition.languages import CODE_TO_LANGUAGE
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor


def surya_lang_to_tesseract(code: str) -> Optional[str]:
    lang_str = CODE_TO_LANGUAGE[code]
    try:
        tess_lang = TESS_LANGUAGE_TO_CODE[lang_str]
    except KeyError:
        return None
    return tess_lang


def tesseract_ocr(img, bboxes, lang: str):
    import pytesseract
    line_imgs = slice_bboxes_from_image(img, bboxes)
    config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
    lines = []
    for line_img in line_imgs:
        line = pytesseract.image_to_string(line_img, lang=lang, config=config)
        lines.append(line)
    return lines


def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
    tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size())
    if not cpus:
        cpus = os.cpu_count()
    tess_parallel_cores = min(tess_parallel_cores, cpus)

    # Tesseract uses up to 4 processes per instance
    # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images
    tess_parallel = max(tess_parallel_cores // 2, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR")
        tess_text = list(tess_text)
    return tess_text


def tesseract_bboxes(img):
    import pytesseract
    from pytesseract import Output
    arr_img = np.asarray(img, dtype=np.uint8)
    ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT)

    bboxes = []
    n_boxes = len(ocr['level'])
    for i in range(n_boxes):
        # It is possible to merge by line here with line number, but it gives bad results.
        _, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i]
        bbox = (x, y, x + w, y + h)
        bboxes.append(bbox)

    return bboxes


def tesseract_parallel(imgs):
    # Tesseract uses 4 threads per instance
    tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size())
    cpus = os.cpu_count()
    tess_parallel_cores = min(tess_parallel_cores, cpus)
    tess_parallel = max(tess_parallel_cores // 4, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection")
        tess_bboxes = list(tess_bboxes)
    return tess_bboxes


TESS_CODE_TO_LANGUAGE = {
    "afr": "Afrikaans", "amh": "Amharic", "ara": "Arabic", "asm": "Assamese",
    "aze": "Azerbaijani", "bel": "Belarusian", "ben": "Bengali", "bod": "Tibetan",
    "bos": "Bosnian", "bre": "Breton", "bul": "Bulgarian", "cat": "Catalan",
    "ceb": "Cebuano", "ces": "Czech", "chi_sim": "Chinese", "chr": "Cherokee",
    "cym": "Welsh", "dan": "Danish", "deu": "German", "dzo": "Dzongkha",
    "ell": "Greek", "eng": "English", "epo": "Esperanto", "est": "Estonian",
    "eus": "Basque", "fas": "Persian", "fin": "Finnish", "fra": "French",
    "fry": "Western Frisian", "guj": "Gujarati", "gla": "Scottish Gaelic", "gle": "Irish",
    "glg": "Galician", "heb": "Hebrew", "hin": "Hindi", "hrv": "Croatian",
    "hun": "Hungarian", "hye": "Armenian", "iku": "Inuktitut", "ind": "Indonesian",
    "isl": "Icelandic", "ita": "Italian", "jav": "Javanese", "jpn": "Japanese",
    "kan": "Kannada", "kat": "Georgian", "kaz": "Kazakh", "khm": "Khmer",
    "kir": "Kyrgyz", "kor": "Korean", "lao": "Lao", "lat": "Latin",
    "lav": "Latvian", "lit": "Lithuanian", "mal": "Malayalam", "mar": "Marathi",
    "mkd": "Macedonian", "mlt": "Maltese", "mon": "Mongolian", "msa": "Malay",
    "mya": "Burmese", "nep": "Nepali", "nld": "Dutch", "nor": "Norwegian",
    "ori": "Oriya", "pan": "Punjabi", "pol": "Polish", "por": "Portuguese",
    "pus": "Pashto", "ron": "Romanian", "rus": "Russian", "san": "Sanskrit",
    "sin": "Sinhala", "slk": "Slovak", "slv": "Slovenian", "snd": "Sindhi",
    "spa": "Spanish", "sqi": "Albanian", "srp": "Serbian", "swa": "Swahili",
    "swe": "Swedish", "syr": "Syriac", "tam": "Tamil", "tel": "Telugu",
    "tgk": "Tajik", "tha": "Thai", "tir": "Tigrinya", "tur": "Turkish",
    "uig": "Uyghur", "ukr": "Ukrainian", "urd": "Urdu", "uzb": "Uzbek",
    "vie": "Vietnamese", "yid": "Yiddish"
}

TESS_LANGUAGE_TO_CODE = {v: k for k, v in TESS_CODE_TO_LANGUAGE.items()}
```
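A minimal sketch of how these helpers compose (the image path is hypothetical; a working Tesseract install is assumed, with `settings.TESSDATA_PREFIX` pointing at its tessdata directory, and `"en"` is assumed to map to `"English"` in `CODE_TO_LANGUAGE`):

```python
from PIL import Image

from benchmark.utils.tesseract import surya_lang_to_tesseract, tesseract_ocr_parallel

imgs = [Image.open("page.png").convert("RGB")]   # hypothetical input image
bboxes = [[[10, 10, 200, 40]]]                   # one list of line bboxes per image, in pixels
langs = [surya_lang_to_tesseract("en")]          # "en" -> "eng"; None if tesseract lacks the language
lines = tesseract_ocr_parallel(imgs, bboxes, langs)
print(lines[0])  # recognized text for each line bbox of the first image
```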
--------------------------------------------------------------------------------
/benchmark/layout.py:
--------------------------------------------------------------------------------

```python
import collections
import copy
import json

import click

from benchmark.utils.metrics import precision_recall
from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.input.processing import convert_if_not_rgb
from surya.debug.draw import draw_bboxes_on_image
from surya.settings import settings

import os
import time
from tabulate import tabulate
import datasets


@click.command(help="Benchmark surya layout model.")
@click.option(
    "--results_dir",
    type=str,
    help="Path to JSON file with OCR results.",
    default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
@click.option(
    "--max_rows",
    type=int,
    help="Maximum number of images to run benchmark on.",
    default=100,
)
@click.option("--debug", is_flag=True, help="Run in debug mode.", default=False)
def main(results_dir: str, max_rows: int, debug: bool):
    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_predictor = LayoutPredictor(foundation_predictor)
    pathname = "layout_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    dataset = datasets.load_dataset(
        settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]"
    )
    images = list(dataset["image"])
    images = convert_if_not_rgb(images)

    if settings.LAYOUT_STATIC_CACHE:
        layout_predictor(images[:1])

    start = time.time()
    layout_predictions = layout_predictor(images)
    surya_time = time.time() - start

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    label_alignment = {  # First is publaynet, second is surya
        "Image": [["Figure"], ["Picture", "Figure"]],
        "Table": [["Table"], ["Table", "Form", "TableOfContents"]],
        "Text": [
            ["Text"],
            [
                "Text",
                "Formula",
                "Footnote",
                "Caption",
                "TextInlineMath",
                "Code",
                "Handwriting",
            ],
        ],
        "List": [["List"], ["ListItem"]],
        "Title": [["Title"], ["SectionHeader", "Title"]],
    }

    page_metrics = collections.OrderedDict()
    for idx, pred in enumerate(layout_predictions):
        row = dataset[idx]
        all_correct_bboxes = []
        page_results = {}
        for label_name in label_alignment:
            correct_cats, surya_cats = label_alignment[label_name]
            correct_bboxes = [
                b
                for b, category in zip(row["bboxes"], row["labels"])
                if category in correct_cats
            ]
            all_correct_bboxes.extend(correct_bboxes)
            pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats]

            metrics = precision_recall(
                pred_bboxes, correct_bboxes, penalize_double=False
            )
            weight = len(correct_bboxes)
            metrics["weight"] = weight
            page_results[label_name] = metrics

        page_metrics[idx] = page_results

        if debug:
            bbox_image = draw_bboxes_on_image(
                all_correct_bboxes, copy.deepcopy(images[idx])
            )
            bbox_image.save(os.path.join(result_path, f"{idx}_layout.png"))

    mean_metrics = collections.defaultdict(dict)
    layout_types = sorted(page_metrics[0].keys())
    metric_types = sorted(page_metrics[0][layout_types[0]].keys())
    metric_types.remove("weight")

    for label in layout_types:
        for m in metric_types:
            metric = []
            total = 0
            for page in page_metrics:
                metric.append(
                    page_metrics[page][label][m] * page_metrics[page][label]["weight"]
                )
                total += page_metrics[page][label]["weight"]

            value = sum(metric)
            if value > 0:
                value /= total
            mean_metrics[label][m] = value

    out_data = {
        "time": surya_time,
        "metrics": mean_metrics,
        "page_metrics": page_metrics,
    }

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)

    table_headers = [
        "Layout Type",
    ] + metric_types
    table_data = []
    for layout_type in layout_types:
        table_data.append(
            [
                layout_type,
            ]
            + [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types]
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print(
        f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total."
    )
    print(
        "Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold."
    )
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
```
--------------------------------------------------------------------------------
/surya/foundation/loader.py:
--------------------------------------------------------------------------------

```python
from typing import Optional

import torch
from transformers.utils import is_flash_attn_2_available

from surya.common.load import ModelLoader
from surya.common.surya.config import SuryaModelConfig
from surya.common.surya import SuryaModel, SuryaXLAModel
from surya.common.surya.processor import SuryaOCRProcessor
from surya.common.surya.processor.tokenizer import SuryaOCRTokenizer
from surya.common.util import is_flash_attn_2_supported
from surya.common.xla import get_compile_args
from surya.logging import get_logger
from surya.settings import settings

logger = get_logger()


class FoundationModelLoader(ModelLoader):
    def __init__(self, checkpoint: Optional[str] = None):
        super().__init__(checkpoint)

        if self.checkpoint is None:
            self.checkpoint = settings.FOUNDATION_MODEL_CHECKPOINT

    def model(
        self,
        device=settings.TORCH_DEVICE_MODEL,
        dtype=None,
        attention_implementation: Optional[str] = None,
    ) -> SuryaModel:
        if device is None:
            device = settings.TORCH_DEVICE_MODEL
        if dtype is None:
            # See https://github.com/pytorch/pytorch/issues/118122 - T4 (device version 7.5) will return true since it supports
            # emulated bf16, but falls back to very slow kernels, especially for SDPA
            dtype = settings.MODEL_DTYPE_BFLOAT
            if device == "cuda" and not torch.cuda.is_bf16_supported(
                including_emulation=False
            ):
                # If the device is cuda, we check if bf16 is supported, and if not, we use float16
                dtype = settings.MODEL_DTYPE
        elif dtype == torch.float16:
            dtype = torch.bfloat16  # Model weights in bfloat16

        config = SuryaModelConfig.from_pretrained(self.checkpoint)

        if attention_implementation is not None:
            config.decoder._attn_implementation = attention_implementation
            config.vision_encoder._attn_implementation = attention_implementation
        elif is_flash_attn_2_available() and is_flash_attn_2_supported(device):
            config.decoder._attn_implementation = "flash_attention_2"
            config.vision_encoder._attn_implementation = "flash_attention_2"
        elif device == "xla":
            config.decoder._attn_implementation = "sdpa"
            config.vision_encoder._attn_implementation = "sdpa"
        else:
            config.decoder._attn_implementation = "sdpa"
            config.vision_encoder._attn_implementation = "sdpa"

        model_cls = SuryaModel
        if device == "xla":
            model_cls = SuryaXLAModel

        config._attn_implementation_autoset = True
        config.vision_encoder._attn_implementation_autoset = True
        config.decoder._attn_implementation_autoset = True

        model = model_cls.from_pretrained(
            self.checkpoint, dtype=dtype, config=config, ignore_mismatched_sizes=True
        ).to(device)
        model = model.eval()

        if settings.COMPILE_ALL or settings.COMPILE_FOUNDATION:
            torch._dynamo.config.cache_size_limit = 1000
            torch._dynamo.config.suppress_errors = True
            torch._dynamo.config.specialize_int = False
            torch._dynamo.config.allow_unspec_int_on_nn_module = True
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.recompile_limit = 32

            logger.info(
                f"Compiling foundation model {self.checkpoint} on device {device} with dtype {dtype}"
            )
            compile_args = get_compile_args(device)
            model.vision_encoder = torch.compile(model.vision_encoder, **compile_args)
            model.decoder = torch.compile(model.decoder, **compile_args)

        logger.debug(
            f"Loaded foundation model {self.checkpoint} from {SuryaModel.get_local_path(self.checkpoint)} onto device {model.device} with dtype {dtype}, using decoder attention mechanism {model.config.decoder._attn_implementation}, encoder attention mechanism {model.config.vision_encoder._attn_implementation}."
        )
        return model

    def processor(
        self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE_BFLOAT
    ) -> SuryaOCRProcessor:
        config: SuryaModelConfig = SuryaModelConfig.from_pretrained(self.checkpoint)
        ocr_tokenizer = SuryaOCRTokenizer(
            special_tokens=config.special_ocr_tokens, model_checkpoint=self.checkpoint
        )

        processor = SuryaOCRProcessor(
            ocr_tokenizer=ocr_tokenizer,
            blank_bbox_token_id=config.blank_bbox_token_id,
            num_register_tokens=config.num_register_tokens,
            sequence_length=None,
            patch_size=config.vision_encoder.patch_size,
            merge_size=config.vision_encoder.spatial_merge_size,
            model_device=device,
            num_beacon_tokens=config.num_beacon_tokens,
            beacon_token_interval=config.beacon_token_interval,
        )
        return processor
```
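A minimal sketch of the loader's two entry points, following the signatures above (weights are fetched on first use):

```python
from surya.foundation.loader import FoundationModelLoader

loader = FoundationModelLoader()  # falls back to settings.FOUNDATION_MODEL_CHECKPOINT
model = loader.model()            # device/dtype/attention backend chosen automatically
processor = loader.processor()    # tokenizer + processor built from the same config
```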
--------------------------------------------------------------------------------
/surya/table_rec/shaper.py:
--------------------------------------------------------------------------------

```python
import math
from typing import List, Dict

import numpy as np

from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM


class LabelShaper:
    def __init__(self):
        self.property_keys = [k for (k, kcount, mode) in BOX_PROPERTIES]

    def dict_to_labels(self, label_components: List[dict]):
        if len(label_components) == 0:
            return []

        out_list = []
        for (k, kcount, mode) in BOX_PROPERTIES:
            for label_component in label_components:
                if k not in label_component:
                    raise ValueError(f"Missing key {k} in label component {label_component}")

                if mode == "classification":
                    assert isinstance(label_component[k], int)
                elif mode == "regression":
                    assert (isinstance(label_component[k], (int, float)) and kcount == 1) or len(label_component[k]) == kcount
                else:
                    raise ValueError(f"Invalid mode {mode} for key {k}")

        for label_component in label_components:
            bbox = label_component["bbox"]
            for i in range(len(bbox)):
                if bbox[i] < 0:
                    bbox[i] = 0
                if bbox[i] > BOX_DIM:
                    bbox[i] = BOX_DIM

            vector = []
            for (k, kcount, mode) in BOX_PROPERTIES:
                item = label_component[k]
                if isinstance(item, (list, tuple)):
                    vector += list(item)
                elif isinstance(item, (float, int)):
                    if mode == "classification":
                        # Shift up for model
                        item += SPECIAL_TOKENS
                    vector.append(item)
                else:
                    raise ValueError(f"Invalid item {item} for key {k}")

            out_list.append(vector)

        return out_list

    def component_idx(self, key):
        idx = 0
        for (k, kcount, mode) in BOX_PROPERTIES:
            if mode == "regression":
                incr = kcount
            elif mode == "classification":
                incr = 1
            else:
                raise ValueError(f"Invalid mode {mode} for key {k}")
            if k == key:
                return (idx, idx + incr)
            idx += incr
        raise ValueError(f"Key {key} not found in properties")

    def get_box_property(self, key, add_special_tokens=True):
        for (k, kcount, mode) in BOX_PROPERTIES:
            if k == key:
                # Add special token count
                if mode == "classification" and add_special_tokens:
                    kcount += SPECIAL_TOKENS
                return (k, kcount, mode)
        raise ValueError(f"Key {key} not found in properties")

    def component_idx_dict(self):
        idx_dict = {}
        for (k, kcount, mode) in BOX_PROPERTIES:
            idx_dict[k] = self.component_idx(k)
        return idx_dict

    def convert_polygons_to_bboxes(self, label_components: List[Dict]):
        for i, label_component in enumerate(label_components):
            poly = label_component["polygon"]
            poly = np.clip(poly, 0, BOX_DIM)

            (x1, y1), (x2, y2), (x3, y3), (x4, y4) = poly
            cx = (x1 + x2 + x3 + x4) / 4
            cy = (y1 + y2 + y3 + y4) / 4
            width = (x2 + x3) / 2 - (x1 + x4) / 2
            height = (y3 + y4) / 2 - (y2 + y1) / 2
            bottom_avg_x = (x3 + x4) / 2
            top_avg_x = (x1 + x2) / 2
            right_avg_y = (y2 + y3) / 2
            left_avg_y = (y1 + y4) / 2

            x_skew = bottom_avg_x - top_avg_x
            y_skew = right_avg_y - left_avg_y
            x_skew += BOX_DIM // 2  # Shift up into positive space
            y_skew += BOX_DIM // 2  # Shift up into positive space
            new_poly = [
                cx,
                cy,
                width,
                height,
                x_skew,
                y_skew
            ]
            label_component["bbox"] = new_poly

        return label_components

    def convert_bbox_to_polygon(self, box, skew_scaler=BOX_DIM // 2, skew_min=.001):
        cx = box[0]
        cy = box[1]
        width = box[2]
        height = box[3]
        x1 = cx - width / 2
        y1 = cy - height / 2
        x2 = cx + width / 2
        y2 = cy + height / 2
        skew_x = math.floor((box[4] - skew_scaler) / 2)
        skew_y = math.floor((box[5] - skew_scaler) / 2)

        # Ensures we don't get slightly warped boxes
        # Note that the values are later scaled, so this is in 1/1024 space
        if abs(skew_x) < skew_min:
            skew_x = 0
        if abs(skew_y) < skew_min:
            skew_y = 0

        polygon = [x1 - skew_x, y1 - skew_y, x2 - skew_x, y1 + skew_y, x2 + skew_x, y2 + skew_y, x1 + skew_x, y2 - skew_y]
        poly = []
        for i in range(4):
            poly.append([
                polygon[2 * i],
                polygon[2 * i + 1]
            ])
        return poly
```
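A round-trip sketch of the two conversions above, in the model's 0 to `BOX_DIM` coordinate space (the quadrilateral was chosen so the skew terms come out to whole pixels):

```python
from surya.table_rec.shaper import LabelShaper

shaper = LabelShaper()
components = [{"polygon": [[10, 10], [110, 12], [112, 52], [12, 50]]}]
components = shaper.convert_polygons_to_bboxes(components)
print(components[0]["bbox"])  # [cx, cy, width, height, x_skew, y_skew]; skews shifted by BOX_DIM // 2
poly = shaper.convert_bbox_to_polygon(components[0]["bbox"])
print(poly)  # recovers the original quadrilateral
```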
--------------------------------------------------------------------------------
/benchmark/detection.py:
--------------------------------------------------------------------------------

```python
import collections
import copy
import json

import click

from benchmark.utils.bbox import get_pdf_lines
from benchmark.utils.metrics import precision_recall
from benchmark.utils.tesseract import tesseract_parallel
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
from surya.debug.draw import draw_polys_on_image
from surya.common.util import rescale_bbox
from surya.settings import settings
from surya.detection import DetectionPredictor

import os
import time
from tabulate import tabulate
import datasets


@click.command(help="Benchmark detection model.")
@click.option("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
@click.option("--tesseract", is_flag=True, help="Run tesseract as well.", default=False)
def main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool):
    det_predictor = DetectionPredictor()

    if pdf_path is not None:
        pathname = pdf_path
        doc = open_pdf(pdf_path)
        page_count = len(doc)
        page_indices = list(range(page_count))
        page_indices = page_indices[:max_rows]

        images = get_page_images(doc, page_indices)
        doc.close()

        image_sizes = [img.size for img in images]
        correct_boxes = get_pdf_lines(pdf_path, image_sizes)
    else:
        pathname = "det_bench"
        # These have already been shuffled randomly, so sampling from the start is fine
        dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
        images = list(dataset["image"])
        images = convert_if_not_rgb(images)

        correct_boxes = []
        for i, boxes in enumerate(dataset["bboxes"]):
            img_size = images[i].size
            # 1000,1000 is bbox size for doclaynet
            correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])

    if settings.DETECTOR_STATIC_CACHE:
        # Run through one batch to compile the model
        det_predictor(images[:1])

    start = time.time()
    predictions = det_predictor(images)
    surya_time = time.time() - start

    if tesseract:
        start = time.time()
        tess_predictions = tesseract_parallel(images)
        tess_time = time.time() - start
    else:
        tess_predictions = [None] * len(images)
        tess_time = None

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    page_metrics = collections.OrderedDict()
    for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
        surya_boxes = [s.bbox for s in sb.bboxes]
        surya_polys = [s.polygon for s in sb.bboxes]

        surya_metrics = precision_recall(surya_boxes, cb)
        if tb is not None:
            tess_metrics = precision_recall(tb, cb)
        else:
            tess_metrics = None

        page_metrics[idx] = {
            "surya": surya_metrics,
            "tesseract": tess_metrics
        }

        if debug:
            bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
            bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))

    mean_metrics = {}
    metric_types = sorted(page_metrics[0]["surya"].keys())
    models = ["surya"]
    if tesseract:
        models.append("tesseract")

    for k in models:
        for m in metric_types:
            metric = []
            for page in page_metrics:
                metric.append(page_metrics[page][k][m])
            if k not in mean_metrics:
                mean_metrics[k] = {}
            mean_metrics[k][m] = sum(metric) / len(metric)

    out_data = {
        "times": {
            "surya": surya_time,
            "tesseract": tess_time
        },
        "metrics": mean_metrics,
        "page_metrics": page_metrics
    }

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)

    table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
    table_data = [
        ["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
    ]
    if tesseract:
        table_data.append(
            ["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types]
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.")
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
```
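The DocLayNet ground truth comes in a fixed 1000x1000 coordinate space, so the benchmark rescales it per image. A sketch of that one step in isolation; the printed result assumes `rescale_bbox` scales corner coordinates proportionally, which is how the benchmark uses it:

```python
from surya.common.util import rescale_bbox

bbox_1000 = [100, 200, 300, 400]  # DocLayNet-space box
print(rescale_bbox(bbox_1000, (1000, 1000), (2000, 1000)))
# expected [200, 200, 600, 400]: x scaled by 2, y unchanged
```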
--------------------------------------------------------------------------------
/surya/detection/heatmap.py:
--------------------------------------------------------------------------------

```python
from typing import List

import cv2
import numpy as np
from PIL import Image

from surya.common.util import clean_boxes
from surya.detection import TextDetectionResult
from surya.common.polygon import PolygonBox
from surya.settings import settings


def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7):
    # Find average intensity of top 10% pixels
    flat_map = linemap.ravel()
    top_10_count = int(len(flat_map) * 0.9)
    avg_intensity = np.mean(np.partition(flat_map, top_10_count)[top_10_count:])
    scaling_factor = np.clip(avg_intensity / typical_top10_avg, 0, 1) ** (1 / 2)

    low_text = np.clip(low_text * scaling_factor, 0.1, 0.6)
    text_threshold = np.clip(text_threshold * scaling_factor, 0.15, 0.8)

    return text_threshold, low_text


def detect_boxes(linemap, text_threshold, low_text):
    # From CRAFT - https://github.com/clovaai/CRAFT-pytorch
    # Modified to return boxes and for speed, accuracy
    img_h, img_w = linemap.shape

    text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text)

    text_score_comb = (linemap > low_text).astype(np.uint8)
    label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(
        text_score_comb, connectivity=4
    )

    det = []
    confidences = []
    max_confidence = 0

    for k in range(1, label_count):
        # size filtering
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10:
            continue

        # make segmentation map
        x, y, w, h = stats[
            k,
            [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
        ]

        try:
            niter = int(np.sqrt(min(w, h)))
        except ValueError:
            niter = 0

        buffer = 1
        sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer)
        ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer)

        mask = labels[sy:ey, sx:ex] == k
        selected_linemap = linemap[sy:ey, sx:ex][mask]
        if selected_linemap.size == 0:
            continue

        line_max = np.max(selected_linemap)

        # thresholding
        if line_max < text_threshold:
            continue

        segmap = mask.astype(np.uint8)

        ksize = buffer + niter
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (ksize, ksize))
        selected_segmap = cv2.dilate(segmap, kernel)

        # make box
        y_inds, x_inds = np.nonzero(selected_segmap)
        x_inds += sx
        y_inds += sy
        np_contours = np.column_stack((x_inds, y_inds))
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        # align diamond-shape
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            left, right = np_contours[:, 0].min(), np_contours[:, 0].max()
            top, bottom = np_contours[:, 1].min(), np_contours[:, 1].max()
            box = np.array(
                [[left, top], [right, top], [right, bottom], [left, bottom]],
                dtype=np.float32,
            )

        # make clock-wise order
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4 - startidx, 0)

        max_confidence = max(max_confidence, line_max)

        confidences.append(line_max)
        det.append(box)

    if max_confidence > 0:
        confidences = [c / max_confidence for c in confidences]
    return det, confidences


def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:
    if text_threshold is None:
        text_threshold = settings.DETECTOR_TEXT_THRESHOLD
    if low_text is None:
        low_text = settings.DETECTOR_BLANK_THRESHOLD

    if textmap.dtype != np.float32:
        textmap = textmap.astype(np.float32)
    boxes, confidences = detect_boxes(textmap, text_threshold, low_text)
    # From point form to box form
    return [
        PolygonBox(polygon=box, confidence=confidence)
        for box, confidence in zip(boxes, confidences)
    ]


def get_and_clean_boxes(
    textmap, processor_size, image_size, text_threshold=None, low_text=None
) -> List[PolygonBox]:
    bboxes = get_detected_boxes(textmap, text_threshold, low_text)
    for bbox in bboxes:
        bbox.rescale(processor_size, image_size)
        bbox.fit_to_bounds([0, 0, image_size[0], image_size[1]])

    bboxes = clean_boxes(bboxes)
    return bboxes


def parallel_get_boxes(preds, orig_sizes, include_maps=False):
    heatmap, affinity_map = preds
    heat_img, aff_img = None, None

    if include_maps:
        heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
        aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
    for box in bboxes:
        # Skip for vertical boxes
        if box.height < 3 * box.width:
            box.expand(x_margin=0, y_margin=settings.DETECTOR_BOX_Y_EXPAND_MARGIN)
            box.fit_to_bounds(
                [0, 0, orig_sizes[0], orig_sizes[1]]
            )  # Fix any bad expands

    result = TextDetectionResult(
        bboxes=bboxes,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]],
    )
    return result
```
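A self-contained sketch on a synthetic heatmap, exercising the CRAFT-style pipeline end to end; one bright strip should come back as a single polygon, with confidence 1.0 after normalization:

```python
import numpy as np

from surya.detection.heatmap import get_detected_boxes

heatmap = np.zeros((100, 100), dtype=np.float32)
heatmap[40:50, 10:60] = 0.9  # one synthetic text line
boxes = get_detected_boxes(heatmap)
print(len(boxes), boxes[0].confidence)  # 1 box, confidence normalized to 1.0
```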
<math>…</math> def _inner(m): inner = STRIP_TAGS.sub("", m.group(2)) return f"{m.group(1)}{inner}</math>" if inner.strip() else "" cleaned = MATH_BLOCK.sub(_inner, html) # drop only orphan *closing* </math> tags depth = 0 parts = [] for token in re.split(r"(</?math[^>]*>)", cleaned, flags=re.I): if token.lower().startswith("<math"): depth += 1 parts.append(token) elif token.lower() == "</math>": if depth: # keep it only if it matches an open depth -= 1 parts.append(token) # else: skip orphan closing tag else: parts.append(token) return "".join(parts) def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25): # Sorts in reading order. Not 100% accurate, this should only # be used as a starting point for more advanced sorting. vertical_groups = {} for line in lines: group_key = ( round( line.bbox[1] if isinstance(line, TextLine) else line["bbox"][1] / tolerance ) * tolerance ) if group_key not in vertical_groups: vertical_groups[group_key] = [] vertical_groups[group_key].append(line) # Sort each group horizontally and flatten the groups into a single list sorted_lines = [] for _, group in sorted(vertical_groups.items()): sorted_group = sorted( group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0] ) sorted_lines.extend(sorted_group) return sorted_lines def clean_close_polygons(bboxes: List[List[List[int]]], thresh: float = 0.1): if len(bboxes) < 2: return bboxes new_bboxes = [bboxes[0]] for i in range(1, len(bboxes)): close = True prev_bbox = bboxes[i - 1] bbox = bboxes[i] for j in range(4): if ( abs(bbox[j][0] - prev_bbox[j][0]) > thresh or abs(bbox[j][1] - prev_bbox[j][1]) > thresh ): close = False break if not close: new_bboxes.append(bboxes[i]) return new_bboxes def words_from_chars(chars: List[TextChar], line_box: PolygonBox): words = [] word = None for i, char in enumerate(chars): if not char.bbox_valid: if word: words.append(word) word = None continue if not word: word = TextWord(**char.model_dump()) # Fit bounds to line if first word if i == 0: word.merge_left(line_box) elif not char.text.strip(): if word: words.append(word) word = None else: # Merge bboxes word.merge(char) word.text = word.text + char.text if i == len(chars) - 1: word.merge_right(line_box) if word: words.append(word) return words ``` -------------------------------------------------------------------------------- /surya/common/s3.py: -------------------------------------------------------------------------------- ```python import json import os import shutil import tempfile import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path import requests from tqdm import tqdm from surya.logging import get_logger from surya.settings import settings logger = get_logger() # Lock file expiration time in seconds (10 minutes) LOCK_EXPIRATION = 600 def join_urls(url1: str, url2: str): url1 = url1.rstrip("/") url2 = url2.lstrip("/") return f"{url1}/{url2}" def get_model_name(pretrained_model_name_or_path: str): return pretrained_model_name_or_path.split("/")[0] def download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024): local_path = Path(local_path) try: response = requests.get(remote_path, stream=True, allow_redirects=True) response.raise_for_status() # Raise an exception for bad status codes # Get file size from headers for progress bar total_size = int(response.headers.get('content-length', 0)) # Create progress bar with file name and size info filename = local_path.name pbar = tqdm( total=total_size, unit='B', unit_scale=True, 
--------------------------------------------------------------------------------
/surya/common/s3.py:
--------------------------------------------------------------------------------

```python
import json
import os
import shutil
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import requests
from tqdm import tqdm

from surya.logging import get_logger
from surya.settings import settings

logger = get_logger()

# Lock file expiration time in seconds (10 minutes)
LOCK_EXPIRATION = 600


def join_urls(url1: str, url2: str):
    url1 = url1.rstrip("/")
    url2 = url2.lstrip("/")
    return f"{url1}/{url2}"


def get_model_name(pretrained_model_name_or_path: str):
    return pretrained_model_name_or_path.split("/")[0]


def download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024):
    local_path = Path(local_path)
    try:
        response = requests.get(remote_path, stream=True, allow_redirects=True)
        response.raise_for_status()  # Raise an exception for bad status codes

        # Get file size from headers for progress bar
        total_size = int(response.headers.get('content-length', 0))

        # Create progress bar with file name and size info
        filename = local_path.name
        pbar = tqdm(
            total=total_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
            desc=f"Downloading {filename}",
            miniters=1
        )

        with open(local_path, "wb") as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    pbar.update(len(chunk))

        pbar.close()
        return local_path
    except Exception as e:
        if local_path.exists():
            local_path.unlink()
        logger.error(f"Download error for file {remote_path}: {str(e)}")
        raise


def check_manifest(local_dir: str):
    local_dir = Path(local_dir)
    manifest_path = local_dir / "manifest.json"
    if not os.path.exists(manifest_path):
        return False

    try:
        with open(manifest_path, "r") as f:
            manifest = json.load(f)
        for file in manifest["files"]:
            if not os.path.exists(local_dir / file):
                return False
    except Exception:
        return False

    return True


def download_directory(remote_path: str, local_dir: str):
    model_name = get_model_name(remote_path)
    s3_url = join_urls(settings.S3_BASE_URL, remote_path)

    # Check to see if it's already downloaded
    model_exists = check_manifest(local_dir)
    if model_exists:
        return

    # Use tempfile.TemporaryDirectory to automatically clean up
    with tempfile.TemporaryDirectory() as temp_dir:
        # Download the manifest file
        manifest_file = join_urls(s3_url, "manifest.json")
        manifest_path = os.path.join(temp_dir, "manifest.json")
        download_file(manifest_file, manifest_path)

        # List and download all files
        with open(manifest_path, "r") as f:
            manifest = json.load(f)

        pbar = tqdm(
            desc=f"Downloading {model_name} model to {local_dir}",
            total=len(manifest["files"]),
        )

        with ThreadPoolExecutor(
            max_workers=settings.PARALLEL_DOWNLOAD_WORKERS
        ) as executor:
            futures = []
            for file in manifest["files"]:
                remote_file = join_urls(s3_url, file)
                local_file = os.path.join(temp_dir, file)
                futures.append(executor.submit(download_file, remote_file, local_file))

            for future in futures:
                future.result()
                pbar.update(1)
        pbar.close()

        # Move all files to new directory
        for file in os.listdir(temp_dir):
            shutil.move(os.path.join(temp_dir, file), local_dir)


class S3DownloaderMixin:
    s3_prefix = "s3://"

    @classmethod
    def get_local_path(cls, pretrained_model_name_or_path) -> str:
        if pretrained_model_name_or_path.startswith(cls.s3_prefix):
            pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
                cls.s3_prefix, ""
            )
            cache_dir = settings.MODEL_CACHE_DIR
            local_path = os.path.join(cache_dir, pretrained_model_name_or_path)
            os.makedirs(local_path, exist_ok=True)
        else:
            local_path = ""
        return local_path

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Allow loading models directly from the hub, or using s3
        if not pretrained_model_name_or_path.startswith(cls.s3_prefix):
            return super().from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )

        local_path = cls.get_local_path(pretrained_model_name_or_path)
        pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
            cls.s3_prefix, ""
        )

        # Retry logic for downloading the model folder
        retries = 3
        delay = 5
        attempt = 0
        success = False

        while not success and attempt < retries:
            try:
                download_directory(pretrained_model_name_or_path, local_path)
                success = True  # If download succeeded
            except Exception as e:
                logger.error(
                    f"Error downloading model from {pretrained_model_name_or_path}. Attempt {attempt + 1} of {retries}. Error: {e}"
                )
                attempt += 1
                if attempt < retries:
                    logger.info(f"Retrying in {delay} seconds...")
                    time.sleep(delay)  # Wait before retrying
                else:
                    logger.error(
                        f"Failed to download {pretrained_model_name_or_path} after {retries} attempts."
                    )
                    raise e  # Reraise exception after max retries

        return super().from_pretrained(local_path, *args, **kwargs)
```
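A sketch of how the mixin changes `from_pretrained` dispatch. Everything below is hypothetical: `MyModel` and both checkpoint strings are made up, and the mixin works the same for any class that puts it ahead of a `from_pretrained`-style base:

```python
# Hypothetical sketch: MyModel and both checkpoint names are invented.
# Names with the "s3://" prefix get mirrored (via their manifest.json)
# into settings.MODEL_CACHE_DIR, then loaded from the local path;
# anything else is passed straight through to the base class.
from transformers import PreTrainedModel

from surya.common.s3 import S3DownloaderMixin


class MyModel(S3DownloaderMixin, PreTrainedModel):
    pass


model = MyModel.from_pretrained("s3://some_model/2025_01_01")  # download, then local load
model = MyModel.from_pretrained("org/some-hub-model")          # untouched: plain hub load
```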
self.processor.size["height"] ) image_splits.extend(image_parts) split_index.extend([image_idx] * len(image_parts)) split_heights.extend(split_height) image_splits = [self.prepare_image(image) for image in image_splits] # Batch images in dim 0 batch = torch.stack(image_splits, dim=0).to(self.model.dtype) if static_cache: batch = self.pad_to_batch_size(batch, batch_size) with settings.INFERENCE_MODE(): pred = self.model( pixel_values=batch.to(self.model.device) ) # Moving the to device here fixes issues with xla recompilation logits = pred.logits correct_shape = [ self.processor.size["height"], self.processor.size["width"], ] current_shape = list(logits.shape[2:]) if current_shape != correct_shape: logits = F.interpolate( logits, size=correct_shape, mode="bilinear", align_corners=False ) mark_step() logits = logits.to(torch.float32).cpu().numpy() preds = [] for i, (idx, height) in enumerate(zip(split_index, split_heights)): # If our current prediction length is below the image idx, that means we have a new image # Otherwise, we need to add to the current image if len(preds) <= idx: preds.append([logits[i][k] for k in range(heatmap_count)]) else: heatmaps = preds[idx] pred_heatmaps = [logits[i][k] for k in range(heatmap_count)] if height < self.processor.size["height"]: # Cut off padding to get original height pred_heatmaps = [ pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps ] for k in range(heatmap_count): heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]]) preds[idx] = heatmaps yield preds, [orig_sizes[j] for j in batch_image_idxs] torch.cuda.empty_cache() ``` -------------------------------------------------------------------------------- /benchmark/utils/metrics.py: -------------------------------------------------------------------------------- ```python from functools import partial from itertools import repeat import numpy as np from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor def box_area(box): return (box[2] - box[0]) * (box[3] - box[1]) def calculate_iou(box1, box2, box1_only=False): intersection = intersection_area(box1, box2) union = box_area(box1) if not box1_only: union += box_area(box2) - intersection if union == 0: return 0 return intersection / union def match_boxes(preds, references): num_actual = len(references) num_predicted = len(preds) iou_matrix = np.zeros((num_actual, num_predicted)) for i, actual in enumerate(references): for j, pred in enumerate(preds): iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True) sorted_indices = np.argsort(iou_matrix, axis=None)[::-1] sorted_ious = iou_matrix.flatten()[sorted_indices] actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape) assigned_actual = set() assigned_pred = set() matches = [] for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious): i, j = idx if i not in assigned_actual and j not in assigned_pred: iou_val = iou_matrix[i, j] if iou_val > .95: # Account for rounding on box edges iou_val = 1.0 matches.append((i, j, iou_val)) assigned_actual.add(i) assigned_pred.add(j) unassigned_actual = set(range(num_actual)) - assigned_actual unassigned_pred = set(range(num_predicted)) - assigned_pred matches.extend([(i, None, -1.0) for i in unassigned_actual]) matches.extend([(None, j, 0.0) for j in unassigned_pred]) return matches def penalized_iou_score(preds, references): matches = match_boxes(preds, references) iou = sum([match[2] for match in matches]) / len(matches) return iou def intersection_pixels(box1, box2): x_left = max(box1[0], 
--------------------------------------------------------------------------------
/benchmark/utils/metrics.py:
--------------------------------------------------------------------------------

```python
from functools import partial
from itertools import repeat

import numpy as np
from concurrent.futures import ThreadPoolExecutor


def box_area(box):
    return (box[2] - box[0]) * (box[3] - box[1])


def calculate_iou(box1, box2, box1_only=False):
    intersection = intersection_area(box1, box2)
    union = box_area(box1)
    if not box1_only:
        union += box_area(box2) - intersection

    if union == 0:
        return 0
    return intersection / union


def match_boxes(preds, references):
    num_actual = len(references)
    num_predicted = len(preds)

    iou_matrix = np.zeros((num_actual, num_predicted))
    for i, actual in enumerate(references):
        for j, pred in enumerate(preds):
            iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True)

    sorted_indices = np.argsort(iou_matrix, axis=None)[::-1]
    sorted_ious = iou_matrix.flatten()[sorted_indices]
    actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape)

    assigned_actual = set()
    assigned_pred = set()

    matches = []
    for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious):
        i, j = idx
        if i not in assigned_actual and j not in assigned_pred:
            iou_val = iou_matrix[i, j]
            if iou_val > .95:  # Account for rounding on box edges
                iou_val = 1.0
            matches.append((i, j, iou_val))
            assigned_actual.add(i)
            assigned_pred.add(j)

    unassigned_actual = set(range(num_actual)) - assigned_actual
    unassigned_pred = set(range(num_predicted)) - assigned_pred

    matches.extend([(i, None, -1.0) for i in unassigned_actual])
    matches.extend([(None, j, 0.0) for j in unassigned_pred])

    return matches


def penalized_iou_score(preds, references):
    matches = match_boxes(preds, references)
    iou = sum([match[2] for match in matches]) / len(matches)
    return iou


def intersection_pixels(box1, box2):
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    if x_right < x_left or y_bottom < y_top:
        return set()

    x_left, x_right = int(x_left), int(x_right)
    y_top, y_bottom = int(y_top), int(y_bottom)

    coords = np.meshgrid(np.arange(x_left, x_right), np.arange(y_top, y_bottom))
    pixels = set(zip(coords[0].flat, coords[1].flat))

    return pixels


def calculate_coverage(box, other_boxes, penalize_double=False):
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    if box_area == 0:
        return 0

    # find total coverage of the box
    covered_pixels = set()
    double_coverage = list()
    for other_box in other_boxes:
        ia = intersection_pixels(box, other_box)
        double_coverage.append(list(covered_pixels.intersection(ia)))
        covered_pixels = covered_pixels.union(ia)

    # Penalize double coverage - having multiple bboxes overlapping the same pixels.
    # Count the doubly-covered pixels themselves, not the number of overlap lists.
    double_coverage_penalty = sum(len(overlap) for overlap in double_coverage)
    if not penalize_double:
        double_coverage_penalty = 0
    covered_pixels_count = max(0, len(covered_pixels) - double_coverage_penalty)
    return covered_pixels_count / box_area


def intersection_area(box1, box2):
    x_left = max(box1[0], box2[0])
    y_top = max(box1[1], box2[1])
    x_right = min(box1[2], box2[2])
    y_bottom = min(box1[3], box2[3])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    return (x_right - x_left) * (y_bottom - y_top)


def calculate_coverage_fast(box, other_boxes, penalize_double=False):
    box = np.array(box)
    other_boxes = np.array(other_boxes)

    # Calculate box area
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    if box_area == 0:
        return 0

    x_left = np.maximum(box[0], other_boxes[:, 0])
    y_top = np.maximum(box[1], other_boxes[:, 1])
    x_right = np.minimum(box[2], other_boxes[:, 2])
    y_bottom = np.minimum(box[3], other_boxes[:, 3])

    widths = np.maximum(0, x_right - x_left)
    heights = np.maximum(0, y_bottom - y_top)
    intersect_areas = widths * heights

    total_intersect = np.sum(intersect_areas)

    return min(1.0, total_intersect / box_area)


def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
    if len(references) == 0:
        return {
            "precision": 1,
            "recall": 1,
        }

    if len(preds) == 0:
        return {
            "precision": 0,
            "recall": 0,
        }

    # If we're not penalizing double coverage, we can use a faster calculation
    coverage_func = calculate_coverage_fast
    if penalize_double:
        coverage_func = calculate_coverage

    with ThreadPoolExecutor(max_workers=workers) as executor:
        precision_func = partial(coverage_func, penalize_double=penalize_double)
        precision_iou = executor.map(precision_func, preds, repeat(references))
        reference_iou = executor.map(coverage_func, references, repeat(preds))

    precision_classes = [1 if i > threshold else 0 for i in precision_iou]
    precision = sum(precision_classes) / len(precision_classes)

    recall_classes = [1 if i > threshold else 0 for i in reference_iou]
    recall = sum(recall_classes) / len(recall_classes)

    return {
        "precision": precision,
        "recall": recall,
    }


def mean_coverage(preds, references):
    coverages = []

    for box1 in references:
        coverage = calculate_coverage(box1, preds)
        coverages.append(coverage)

    for box2 in preds:
        coverage = calculate_coverage(box2, references)
        coverages.append(coverage)

    # Calculate the average coverage over all comparisons
    if len(coverages) == 0:
        return {"coverage": 0}
    coverage = sum(coverages) / len(coverages)
    return {"coverage": coverage}


def rank_accuracy(preds, references):
    # Preds and references need to be aligned so each position refers to the same bbox
    pairs = []
    for i, pred in enumerate(preds):
        for j, pred2 in enumerate(preds):
            if i == j:
                continue
            pairs.append((i, j, pred > pred2))

    # Find how many of the prediction rankings are correct
    if not pairs:  # Fewer than two predictions: no rankings to compare
        return 0
    correct = 0
    for i, ref in enumerate(references):
        for j, ref2 in enumerate(references):
            if (i, j, ref > ref2) in pairs:
                correct += 1

    return correct / len(pairs)
```
--------------------------------------------------------------------------------
/surya/common/donut/processor.py:
--------------------------------------------------------------------------------

```python
from typing import Dict, Union, Optional, List, Iterable

import cv2
from torch import TensorType
from transformers import ImageProcessingMixin
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import pad, normalize
from transformers.image_utils import (
    ImageInput,
    ChannelDimension,
    make_list_of_images,
    get_image_size,
)
import numpy as np
from PIL import Image
from transformers.utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD

from surya.common.s3 import S3DownloaderMixin
from surya.settings import settings


class SuryaEncoderImageProcessor(S3DownloaderMixin, ImageProcessingMixin):
    def __init__(
        self,
        *args,
        max_size=None,
        align_long_axis=False,
        rescale_factor: Union[int, float] = 1 / 255,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
        self.max_size = max_size
        self.do_align_long_axis = align_long_axis
        self.resample = Image.Resampling.BILINEAR
        self.rescale_factor = rescale_factor
        self.image_mean = (
            image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        )
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images."""
        return self.preprocess(images, **kwargs)

    @classmethod
    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
        max_width, max_height = size["width"], size["height"]

        resized_image = cv2.resize(
            image, (max_width, max_height), interpolation=interpolation
        )
        resized_image = resized_image.transpose(2, 0, 1)

        return resized_image

    def process_inner(self, images: List[np.ndarray]):
        assert images[0].shape[2] == 3  # RGB input images, channel dim last

        if self.do_align_long_axis:
            # Rotate if the bbox is wider than it is tall
            images = [
                SuryaEncoderImageProcessor.align_long_axis(
                    image, size=self.max_size, input_data_format=ChannelDimension.LAST
                )
                for image in images
            ]

            # Verify that the image is wider than it is tall
            for img in images:
                assert img.shape[1] >= img.shape[0]

        # This also applies the right channel dim format, to channel x height x width
        images = [
            SuryaEncoderImageProcessor.numpy_resize(img, self.max_size, self.resample)
            for img in images
        ]
        assert images[0].shape[0] == 3  # RGB input images, channel dim first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Pads with 255 (whitespace)
        # Pad to max size to improve performance
        max_size = self.max_size
        images = [
            SuryaEncoderImageProcessor.pad_image(
                image=image,
                size=max_size,
                input_data_format=ChannelDimension.FIRST,
                pad_value=settings.RECOGNITION_PAD_VALUE,
            )
            for image in images
        ]

        # Rescale and normalize
        for idx in range(len(images)):
            images[idx] = (images[idx].astype(np.float64) * self.rescale_factor).astype(
                np.float32
            )
        images = [
            SuryaEncoderImageProcessor.normalize(
                img,
                mean=self.image_mean,
                std=self.image_std,
                input_data_format=ChannelDimension.FIRST,
            )
            for img in images
        ]

        return images

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        images = make_list_of_images(images)

        # Convert to numpy for later processing steps
        images = [np.array(img) for img in images]
        images = self.process_inner(images)

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @classmethod
    def pad_image(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_value: float = 0.0,
    ) -> np.ndarray:
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_width = output_width - input_width
        delta_height = output_height - input_height

        assert delta_width >= 0 and delta_height >= 0

        pad_top = delta_height // 2
        pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(
            image,
            padding,
            data_format=data_format,
            input_data_format=input_data_format,
            constant_values=pad_value,
        )

    @classmethod
    def align_long_axis(
        cls, image: np.ndarray, size: Dict[str, int], **kwargs
    ) -> np.ndarray:
        input_height, input_width = image.shape[:2]
        output_height, output_width = size["height"], size["width"]

        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            image = np.rot90(image, 3)

        return image

    @classmethod
    def normalize(
        cls,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        return normalize(
            image,
            mean=mean,
            std=std,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
```
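A minimal sketch of the processor on a blank image; the output is a channel-first array resized to `max_size`, padded, rescaled, and normalized:

```python
import numpy as np

from surya.common.donut.processor import SuryaEncoderImageProcessor

processor = SuryaEncoderImageProcessor(max_size={"height": 256, "width": 256})
img = np.full((100, 200, 3), 255, dtype=np.uint8)  # white RGB image, HxWxC
batch = processor(img)
print(batch["pixel_values"][0].shape)  # (3, 256, 256)
```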
settings.TABLE_REC_STATIC_CACHE: # Run through one batch to compile the model table_rec_predictor(images[:1]) start = time.time() table_rec_predictions = table_rec_predictor(images) surya_time = time.time() - start folder_name = os.path.basename(pathname).split(".")[0] result_path = os.path.join(results_dir, folder_name) os.makedirs(result_path, exist_ok=True) page_metrics = collections.OrderedDict() mean_col_iou = 0 mean_row_iou = 0 for idx, (pred, image) in enumerate(zip(table_rec_predictions, images)): row = dataset[idx] pred_row_boxes = [p.bbox for p in pred.rows] pred_col_bboxes = [p.bbox for p in pred.cols] actual_row_bboxes = [r["bbox"] for r in row["rows"]] actual_col_bboxes = [c["bbox"] for c in row["columns"]] row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes) col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes) page_results = { "row_score": row_score, "col_score": col_score, "row_count": len(actual_row_bboxes), "col_count": len(actual_col_bboxes), } mean_col_iou += col_score mean_row_iou += row_score page_metrics[idx] = page_results if debug: # Save debug images draw_img = image.copy() draw_bboxes_on_image( pred_row_boxes, draw_img, [f"Row {i}" for i in range(len(pred_row_boxes))], ) draw_bboxes_on_image( pred_col_bboxes, draw_img, [f"Col {i}" for i in range(len(pred_col_bboxes))], color="blue", ) draw_img.save(os.path.join(result_path, f"{idx}_bbox.png")) actual_draw_image = image.copy() draw_bboxes_on_image( actual_row_bboxes, actual_draw_image, [f"Row {i}" for i in range(len(actual_row_bboxes))], ) draw_bboxes_on_image( actual_col_bboxes, actual_draw_image, [f"Col {i}" for i in range(len(actual_col_bboxes))], color="blue", ) actual_draw_image.save(os.path.join(result_path, f"{idx}_actual.png")) mean_col_iou /= len(table_rec_predictions) mean_row_iou /= len(table_rec_predictions) out_data = { "surya": { "time": surya_time, "mean_row_iou": mean_row_iou, "mean_col_iou": mean_col_iou, "page_metrics": page_metrics, } } if tatr: tatr_model = load_tatr() start = time.time() tatr_predictions = batch_inference_tatr(tatr_model, images, 1) tatr_time = time.time() - start page_metrics = collections.OrderedDict() mean_col_iou = 0 mean_row_iou = 0 for idx, pred in enumerate(tatr_predictions): row = dataset[idx] pred_row_boxes = [p["bbox"] for p in pred["rows"]] pred_col_bboxes = [p["bbox"] for p in pred["cols"]] actual_row_bboxes = [r["bbox"] for r in row["rows"]] actual_col_bboxes = [c["bbox"] for c in row["columns"]] row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes) col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes) page_results = { "row_score": row_score, "col_score": col_score, "row_count": len(actual_row_bboxes), "col_count": len(actual_col_bboxes), } mean_col_iou += col_score mean_row_iou += row_score page_metrics[idx] = page_results mean_col_iou /= len(tatr_predictions) mean_row_iou /= len(tatr_predictions) out_data["tatr"] = { "time": tatr_time, "mean_row_iou": mean_row_iou, "mean_col_iou": mean_col_iou, "page_metrics": page_metrics, } with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(out_data, f, indent=4) table = [ ["Model", "Row Intersection", "Col Intersection", "Time Per Image"], [ "Surya", f"{out_data['surya']['mean_row_iou']:.2f}", f"{out_data['surya']['mean_col_iou']:.5f}", f"{surya_time / len(images):.5f}", ], ] if tatr: table.append( [ "Table transformer", f"{out_data['tatr']['mean_row_iou']:.2f}", f"{out_data['tatr']['mean_col_iou']:.5f}", f"{tatr_time / 
len(images):.5f}", ] ) print(tabulate(table, headers="firstrow", tablefmt="github")) print( "Intersection is the average of the intersection % between each actual row/column, and the predictions. With penalties for too many/few predictions." ) print( "Note that table transformers is unbatched, since the example code in the repo is unbatched." ) print(f"Wrote results to {result_path}") if __name__ == "__main__": main() ``` -------------------------------------------------------------------------------- /surya/table_rec/model/decoder.py: -------------------------------------------------------------------------------- ```python from typing import Optional, Tuple, Union import torch from torch import nn from surya.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel from surya.table_rec.model.config import TableRecModelOutput from surya.table_rec.shaper import LabelShaper from surya.settings import settings class LabelEmbedding(nn.Module): def __init__(self, config): super().__init__() # Bboxes self.w_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.h_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.cx_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.cy_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.xskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.yskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.x1_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.y1_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.x2_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.y2_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.x3_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.y3_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.x4_embed = nn.Embedding(config.vocab_size, config.box_embed_size) self.y4_embed = nn.Embedding(config.vocab_size, config.box_embed_size) # Get indexes for passed in tensor shaper = LabelShaper() self.component_idxs = shaper.component_idx_dict() merge_count = shaper.get_box_property("merges")[1] + config.special_token_count category_count = shaper.get_box_property("category")[1] + config.special_token_count # Other box properties self.category_embed = nn.Embedding(category_count, config.property_embed_size) self.merge_embed = nn.Embedding(merge_count, config.property_embed_size) self.colspan_embed = nn.Embedding(config.vocab_size, config.property_embed_size) self.config = config def forward(self, boxes: torch.LongTensor, *args): # Need to keep *args for compatibility with common decoder boxes = boxes.to(torch.long).clamp(0, self.config.vocab_size) boxes_unbound = boxes.to(torch.long).unbind(dim=-1) cx, cy, w, h, xskew, yskew = boxes_unbound[self.component_idxs["bbox"][0]:self.component_idxs["bbox"][1]] category = boxes_unbound[self.component_idxs["category"][0]:self.component_idxs["category"][1]][0] merges = boxes_unbound[self.component_idxs["merges"][0]:self.component_idxs["merges"][1]][0] colspan = boxes_unbound[self.component_idxs["colspan"][0]:self.component_idxs["colspan"][1]][0] xskew_actual = ((xskew - self.config.bbox_size // 2) / 2).to(torch.long) yskew_actual = ((yskew - self.config.bbox_size // 2) / 2).to(torch.long) x1 = (cx - w // 2 - xskew_actual).clamp(0, self.config.bbox_size).to(torch.long) y1 = (cy - h // 2 - yskew_actual).clamp(0, self.config.bbox_size).to(torch.long) x3 = (cx + w // 2 + xskew_actual).clamp(0, 
self.config.bbox_size).to(torch.long) y3 = (cy + h // 2 + yskew_actual).clamp(0, self.config.bbox_size).to(torch.long) size_embeds = self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) skew_embeds = self.xskew_embed(xskew) + self.yskew_embed(yskew) corner_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x3_embed(x3) + self.y3_embed(y3) box_embeds = size_embeds + skew_embeds + corner_embeds property_embeds = self.category_embed(category) + self.merge_embed(merges) + self.colspan_embed(colspan) # Cat bbox and property embeddings embedded = torch.cat([box_embeds, property_embeds], dim=-1) return embedded class SuryaTableRecDecoder(SuryaADETRDecoderPreTrainedModel): _tied_weights_keys = None def __init__(self, config, **kwargs): super().__init__(config) embed_tokens = LabelEmbedding(config) self.model = SuryaADETRDecoderModel( config, embedder=embed_tokens, static_cache=settings.TABLE_REC_STATIC_CACHE, max_boxes=settings.TABLE_REC_MAX_BOXES ) self.vocab_size = config.vocab_size shaper = LabelShaper() property_heads = {} for k in shaper.property_keys: _, kcount, mode = shaper.get_box_property(k) property_heads[k] = nn.Linear(config.hidden_size, kcount, bias=False) self.box_property_heads = nn.ModuleDict(property_heads) self.pre_output_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model # Ignore copy def forward( self, input_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, prefill: bool = False, **kwargs ) -> Union[Tuple, TableRecModelOutput]: outputs = self.model( input_ids=input_ids, cache_position=cache_position, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_hidden_states=True, return_dict=True, prefill=prefill, ) hidden_states = self.pre_output_norm(outputs[0]) box_property_logits = {} for key in self.box_property_heads: box_property_logits[key] = self.box_property_heads[key](hidden_states) bbox_logits = nn.functional.sigmoid(box_property_logits["bbox"]) box_property_logits["bbox"] = bbox_logits return TableRecModelOutput( box_property_logits=box_property_logits, hidden_states=hidden_states, ) ``` -------------------------------------------------------------------------------- /surya/common/polygon.py: -------------------------------------------------------------------------------- ```python import copy from typing import List, Optional import numpy as np from pydantic import BaseModel, field_validator, computed_field import numbers class PolygonBox(BaseModel): polygon: List[List[float]] confidence: Optional[float] = None @field_validator("polygon", mode="before") @classmethod def convert_bbox_to_polygon(cls, value): if isinstance(value, (list, tuple)) and len(value) == 4: if all(isinstance(x, numbers.Number) for x in value): value = [float(v) for v in value] x_min, y_min, x_max, y_max = value 
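                # Corners run clockwise from the top-left, matching the polygon
                # convention used across surya. Illustrative example: the bbox
                # [10, 20, 110, 70] becomes [[10, 20], [110, 20], [110, 70], [10, 70]].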
polygon = [ [x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max], ] return polygon elif all( isinstance(point, (list, tuple)) and len(point) == 2 for point in value ): value = [[float(v) for v in point] for point in value] return value elif isinstance(value, np.ndarray): if value.shape == (4, 2): return value.tolist() raise ValueError( f"Input must be either a bbox [x_min, y_min, x_max, y_max] or a polygon with 4 corners [(x,y), (x,y), (x,y), (x,y)]. All values must be numeric. You passed {value} of type {type(value)}. The first value is of type {type(value[0])}." ) @property def height(self): return self.bbox[3] - self.bbox[1] @property def width(self): return self.bbox[2] - self.bbox[0] @property def area(self): return self.width * self.height @computed_field @property def bbox(self) -> List[float]: x_coords = [point[0] for point in self.polygon] y_coords = [point[1] for point in self.polygon] return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)] def rescale(self, processor_size, image_size): # Point is in x, y format page_width, page_height = processor_size img_width, img_height = image_size width_scaler = img_width / page_width height_scaler = img_height / page_height for corner in self.polygon: corner[0] = int(corner[0] * width_scaler) corner[1] = int(corner[1] * height_scaler) def round(self, divisor): for corner in self.polygon: corner[0] = int(corner[0] / divisor) * divisor corner[1] = int(corner[1] / divisor) * divisor def fit_to_bounds(self, bounds): new_corners = copy.deepcopy(self.polygon) for corner in new_corners: corner[0] = max(min(corner[0], bounds[2]), bounds[0]) corner[1] = max(min(corner[1], bounds[3]), bounds[1]) self.polygon = new_corners def merge(self, other): x1 = min(self.bbox[0], other.bbox[0]) y1 = min(self.bbox[1], other.bbox[1]) x2 = max(self.bbox[2], other.bbox[2]) y2 = max(self.bbox[3], other.bbox[3]) self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] def merge_left(self, other): x1 = min(self.bbox[0], other.bbox[0]) self.polygon[0][0] = x1 self.polygon[3][0] = x1 def merge_right(self, other): x2 = max(self.bbox[2], other.bbox[2]) self.polygon[1][0] = x2 self.polygon[2][0] = x2 def expand(self, x_margin: float, y_margin: float): new_polygon = [] x_margin = x_margin * self.width y_margin = y_margin * self.height for idx, poly in enumerate(self.polygon): if idx == 0: new_polygon.append([int(poly[0] - x_margin), int(poly[1] - y_margin)]) elif idx == 1: new_polygon.append([int(poly[0] + x_margin), int(poly[1] - y_margin)]) elif idx == 2: new_polygon.append([int(poly[0] + x_margin), int(poly[1] + y_margin)]) elif idx == 3: new_polygon.append([int(poly[0] - x_margin), int(poly[1] + y_margin)]) self.polygon = new_polygon def intersection_polygon(self, other) -> List[List[float]]: new_poly = [] for i in range(4): if i == 0: new_corner = [ max(self.polygon[0][0], other.polygon[0][0]), max(self.polygon[0][1], other.polygon[0][1]), ] elif i == 1: new_corner = [ min(self.polygon[1][0], other.polygon[1][0]), max(self.polygon[1][1], other.polygon[1][1]), ] elif i == 2: new_corner = [ min(self.polygon[2][0], other.polygon[2][0]), min(self.polygon[2][1], other.polygon[2][1]), ] elif i == 3: new_corner = [ max(self.polygon[3][0], other.polygon[3][0]), min(self.polygon[3][1], other.polygon[3][1]), ] new_poly.append(new_corner) return new_poly def intersection_area(self, other, x_margin=0, y_margin=0): x_overlap = self.x_overlap(other, x_margin) y_overlap = self.y_overlap(other, y_margin) return x_overlap * y_overlap def x_overlap(self, other, 
x_margin=0): return max( 0, min(self.bbox[2] + x_margin, other.bbox[2] + x_margin) - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin), ) def y_overlap(self, other, y_margin=0): return max( 0, min(self.bbox[3] + y_margin, other.bbox[3] + y_margin) - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin), ) def intersection_pct(self, other, x_margin=0, y_margin=0): assert 0 <= x_margin <= 1 assert 0 <= y_margin <= 1 if self.area == 0: return 0 if x_margin: x_margin = int(min(self.width, other.width) * x_margin) if y_margin: y_margin = int(min(self.height, other.height) * y_margin) intersection = self.intersection_area(other, x_margin, y_margin) return intersection / self.area def shift(self, x_shift: float | None = None, y_shift: float | None = None): if x_shift is not None: for corner in self.polygon: corner[0] += x_shift if y_shift is not None: for corner in self.polygon: corner[1] += y_shift def clamp(self, bbox: List[float]): for corner in self.polygon: corner[0] = max(min(corner[0], bbox[2]), bbox[0]) corner[1] = max(min(corner[1], bbox[3]), bbox[1]) @property def center(self): return [(self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2] def distance(self, other): center = self.center other_center = other.center return ( (center[0] - other_center[0]) ** 2 + (center[1] - other_center[1]) ** 2 ) ** 0.5 def __hash__(self): return hash(tuple(self.bbox)) ``` -------------------------------------------------------------------------------- /surya/settings.py: -------------------------------------------------------------------------------- ```python import os from typing import Callable, Dict, Optional import torch from dotenv import find_dotenv from pydantic import computed_field from pydantic_settings import BaseSettings from pathlib import Path from platformdirs import user_cache_dir class Settings(BaseSettings): # General TORCH_DEVICE: Optional[str] = None IMAGE_DPI: int = 96 # Used for detection, layout, reading order IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec IN_STREAMLIT: bool = False # Whether we're running in streamlit FLATTEN_PDF: bool = True # Flatten PDFs by merging form fields before processing DISABLE_TQDM: bool = False # Disable tqdm progress bars S3_BASE_URL: str = "https://models.datalab.to" PARALLEL_DOWNLOAD_WORKERS: int = ( 10 # Number of workers for parallel model downloads ) MODEL_CACHE_DIR: str = str(Path(user_cache_dir("datalab")) / "models") LOGLEVEL: str = "INFO" # Logging level # Paths DATA_DIR: str = "data" RESULT_DIR: str = "results" BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts") @computed_field def TORCH_DEVICE_MODEL(self) -> str: if self.TORCH_DEVICE is not None: return self.TORCH_DEVICE if torch.cuda.is_available(): return "cuda" if torch.backends.mps.is_available(): return "mps" try: import torch_xla if len(torch_xla.devices()) > 0: return "xla" except Exception: pass return "cpu" # Text detection DETECTOR_BATCH_SIZE: Optional[int] = None # Defaults to 2 for CPU/MPS, 32 otherwise DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_05_07" DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench" DETECTOR_IMAGE_CHUNK_HEIGHT: int = ( 1400 # Height at which to slice images vertically ) DETECTOR_TEXT_THRESHOLD: float = ( 0.6 # Threshold for text detection (above this is considered text) ) DETECTOR_BLANK_THRESHOLD: float = ( 0.35 # Threshold for blank space (below this is considered blank) ) DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min( 8, 
os.cpu_count() ) # Number of workers for postprocessing DETECTOR_MIN_PARALLEL_THRESH: int = ( 3 # Minimum number of images before we parallelize ) DETECTOR_BOX_Y_EXPAND_MARGIN: float = ( 0.05 # Margin by which to expand detected boxes vertically ) COMPILE_DETECTOR: bool = False # Text recognition FOUNDATION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23" FOUNDATION_MODEL_QUANTIZE: bool = False FOUNDATION_MAX_TOKENS: Optional[int] = None FOUNDATION_CHUNK_SIZE: Optional[int] = None FOUNDATION_PAD_TO_NEAREST: int = 256 COMPILE_FOUNDATION: bool = False FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE: float = 0.9 RECOGNITION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23" RECOGNITION_BATCH_SIZE: Optional[int] = ( None # Defaults to 8 for CPU/MPS, 256 otherwise ) RECOGNITION_RENDER_FONTS: Dict[str, str] = { "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"), "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), "ja": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), "ko": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), } RECOGNITION_FONT_DL_BASE: str = ( "https://github.com/satbyy/go-noto-universal/releases/download/v7.0" ) RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench" RECOGNITION_PAD_VALUE: int = 255 # Should be 0 or 255 # Layout LAYOUT_MODEL_CHECKPOINT: str = "s3://layout/2025_09_23" LAYOUT_IMAGE_SIZE: Dict = {"height": 768, "width": 768} LAYOUT_SLICE_MIN: Dict = { "height": 1500, "width": 1500, } # When to start slicing images LAYOUT_SLICE_SIZE: Dict = {"height": 1200, "width": 1200} # Size of slices LAYOUT_BATCH_SIZE: Optional[int] = None LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench" LAYOUT_MAX_BOXES: int = 100 COMPILE_LAYOUT: bool = False ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench" # Table Rec TABLE_REC_MODEL_CHECKPOINT: str = "s3://table_recognition/2025_02_18" TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768} TABLE_REC_MAX_BOXES: int = 150 TABLE_REC_BATCH_SIZE: Optional[int] = None TABLE_REC_BENCH_DATASET_NAME: str = "datalab-to/fintabnet_bench" COMPILE_TABLE_REC: bool = False # Texify TEXIFY_BENCHMARK_DATASET: str = "datalab-to/texify_bench" # OCR Error Detection OCR_ERROR_MODEL_CHECKPOINT: str = "s3://ocr_error_detection/2025_02_18" OCR_ERROR_BATCH_SIZE: Optional[int] = None COMPILE_OCR_ERROR: bool = False # Tesseract (for benchmarks only) TESSDATA_PREFIX: Optional[str] = None COMPILE_ALL: bool = False @computed_field def DETECTOR_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_DETECTOR or self.TORCH_DEVICE_MODEL == "xla" ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise @computed_field def LAYOUT_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_LAYOUT or self.TORCH_DEVICE_MODEL == "xla" ) @computed_field def FOUNDATION_XLA(self) -> bool: return ( self.TORCH_DEVICE_MODEL == "xla" ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise @computed_field def FOUNDATION_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_FOUNDATION or self.TORCH_DEVICE_MODEL == "xla" ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise @computed_field def TABLE_REC_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_TABLE_REC or self.TORCH_DEVICE_MODEL == "xla" ) @computed_field def OCR_ERROR_STATIC_CACHE(self) -> bool: return ( self.COMPILE_ALL or self.COMPILE_OCR_ERROR or self.TORCH_DEVICE_MODEL == "xla" )
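    # NOTE: since this is a pydantic BaseSettings subclass, any field above can
    # be overridden from the environment or a local.env file, e.g.
    # COMPILE_ALL=true or TORCH_DEVICE=cpu (illustrative values; see the Config
    # class below for how the env file is located).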
@computed_field def MODEL_DTYPE(self) -> torch.dtype: if self.TORCH_DEVICE_MODEL == "cpu": return torch.float32 if self.TORCH_DEVICE_MODEL == "xla": return torch.bfloat16 return torch.float16 @computed_field def MODEL_DTYPE_BFLOAT(self) -> torch.dtype: if self.TORCH_DEVICE_MODEL == "cpu": return torch.float32 if self.TORCH_DEVICE_MODEL == "mps": return torch.bfloat16 return torch.bfloat16 @computed_field def INFERENCE_MODE(self) -> Callable: if self.TORCH_DEVICE_MODEL == "xla": return torch.no_grad return torch.inference_mode class Config: env_file = find_dotenv("local.env") extra = "ignore" settings = Settings() ``` -------------------------------------------------------------------------------- /surya/foundation/cache/static_ops.py: -------------------------------------------------------------------------------- ```python from typing import Any, Dict, List, Optional, Tuple import torch from transformers import PretrainedConfig from surya.foundation.cache.dynamic_ops import DynamicOpsCache """ Special cache class for the surya foundation model that supports - 1) Static shape 2) A custom sliding window, where image tokens stay in cache, and text tokens are popped 3) Continuous batching - merging etc 4) Attention mask management - To match with what's currently in the cache Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079 """ class StaticOpsCache(DynamicOpsCache): def __init__( self, config: PretrainedConfig, batch_size: int, max_cache_len: int, text_sliding_window: int, device: int, dtype: int, ): self.text_sliding_window = text_sliding_window self.num_layers = config.num_hidden_layers self.max_batch_size = batch_size self.max_cache_len = max_cache_len self.head_dim = ( getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads ) self._dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None else config.num_key_value_heads ) # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125 self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] cache_shape = ( self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim, ) device = torch.device(device) if device is not None else None for _ in range(config.num_hidden_layers): new_layer_key_cache = torch.zeros( cache_shape, dtype=self._dtype, device=device ) new_layer_value_cache = torch.zeros( cache_shape, dtype=self._dtype, device=device ) torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_value_cache) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) self.attention_mask = torch.zeros( (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long ) self.text_token_counts = [ torch.zeros(self.max_batch_size, dtype=torch.long, device=device) for _ in range(self.num_layers) ] self.dtype = dtype self.device = device def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: prefill = cache_kwargs.get("prefill", False) update_fn = self._prefill_update if prefill else self._decode_update return update_fn( self.key_cache[layer_idx], self.value_cache[layer_idx], key_states, 
value_states, self.text_token_counts[layer_idx], cache_kwargs, ) def _prefill_update( self, key_cache: torch.Tensor, value_cache: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, text_token_counts: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None, ): cache_idxs: torch.tensor = cache_kwargs.get("cache_idxs", None) text_lengths: List[int] = cache_kwargs.get("text_lengths", None) assert cache_idxs is not None, "cache_idxs must be specified during prefill" assert text_lengths is not None, "text_lengths must be specified during prefill" cache_idx_length = len(cache_idxs) full_batch = len(cache_idxs) == self.max_batch_size # Insert key and value states at the end of the cache new_tokens = key_states.shape[2] # Direct right-aligned assignment if full_batch: key_cache[:, :, -new_tokens:] = key_states value_cache[:, :, -new_tokens:] = value_states else: key_cache[cache_idxs, :, -new_tokens:] = key_states[:cache_idx_length] value_cache[cache_idxs, :, -new_tokens:] = value_states[:cache_idx_length] return key_states, value_states # """ # Matches the logic of the decode update, but needs to be called before the updates # since some parts of the model depend on the attention mask # """ def decode_attention_mask_update( self, num_valid_tokens: torch.Tensor, cache_idxs: List[int] ): max_valid_tokens = num_valid_tokens.max().item() if max_valid_tokens == 0: # If no valid tokens, we don't need to update the attention mask return # Shift the attention mask to the left by max_valid_tokens self.attention_mask = self.attention_mask.roll(-1 * max_valid_tokens, dims=1) self.attention_mask[:, -max_valid_tokens:] = ( 1 # Full attention to all new tokens ) # Mirrors the logic from _prefill_update def prefill_attention_mask_update( self, attention_mask: torch.Tensor, merge_idxs: torch.Tensor, valid_batch_size: torch.Tensor, text_lengths: List[int], ): # Set from -(image_length + text_length) to end to 1 for each batch element seq_len = attention_mask.shape[1] self.attention_mask[merge_idxs] = ( 0 # Reset the attention mask for the current batch elements ) self.attention_mask[merge_idxs, -seq_len:] = attention_mask[:valid_batch_size] def _decode_update( self, key_cache: torch.Tensor, value_cache: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, text_token_counts: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Naive, always assumes we'll roll by a fixed amount # Needs left padding with beacons to work properly num_valid_tokens: torch.Tensor = cache_kwargs.get( "num_valid_tokens" ) # shape: (B,) assert num_valid_tokens is not None, ( "`num_valid_tokens` must be provided in `cache_kwargs`" ) # (B, H, L, D) valid_tokens = key_states.shape[2] key_cache.copy_(torch.roll(key_cache, -valid_tokens, dims=2)) value_cache.copy_(torch.roll(value_cache, -valid_tokens, dims=2)) key_cache[:, :, -valid_tokens:, :] = key_states value_cache[:, :, -valid_tokens:, :] = value_states # In-place edit - Mutates text_token_counts += num_valid_tokens text_token_counts.clamp_(max=self.text_sliding_window) return key_cache, value_cache # The attention mask managed by our kv cache automatically masks the tokens # in the cache, so we can return full length for HF to use in other places # This is mainly utilized in the cache_positions creation def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return self.max_cache_len ``` -------------------------------------------------------------------------------- 
/surya/table_rec/model/config.py: -------------------------------------------------------------------------------- ```python from dataclasses import dataclass from typing import Dict import torch from transformers import PretrainedConfig from transformers.utils import ModelOutput from surya.common.s3 import S3DownloaderMixin from surya.settings import settings BOX_DIM = 1024 SPECIAL_TOKENS = 5 MAX_BOXES = 150 MERGE_KEYS = { "none": 0, "merge_up": 1, "merge_down": 2, "merge_both": 3 } MERGE_VALUES = [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]] ID_TO_CATEGORY = { 0: 'Blank', 1: 'Table-row', 2: 'Table-column', 3: 'Table-cell', 4: 'Table' } CATEGORY_TO_ID = {v: k for k, v in ID_TO_CATEGORY.items()} ID_TO_HEADER = { 0: "None", 1: "Header" } HEADER_TO_ID = {v: k for k, v in ID_TO_HEADER.items()} BOX_PROPERTIES = [ ("bbox", 6, "regression"), ("category", len(ID_TO_CATEGORY), "classification"), ("merges", len(MERGE_KEYS), "classification"), ("colspan", 1, "regression"), ("is_header", len(ID_TO_HEADER), "classification") ] @dataclass class TableRecModelOutput(ModelOutput): box_property_logits: Dict[str, torch.Tensor] hidden_states: torch.Tensor | None = None class SuryaTableRecConfig(S3DownloaderMixin, PretrainedConfig): model_type = "vision-encoder-decoder" is_composition = True def __init__(self, **kwargs): super().__init__(**kwargs) if "encoder" in kwargs: encoder_config = kwargs.pop("encoder") decoder_config = kwargs.pop("decoder") else: encoder_config = DonutSwinTableRecConfig() decoder_config = SuryaTableRecDecoderConfig() self.encoder = encoder_config self.decoder = decoder_config self.is_encoder_decoder = True if isinstance(decoder_config, dict): self.decoder_start_token_id = decoder_config["bos_token_id"] self.pad_token_id = decoder_config["pad_token_id"] self.eos_token_id = decoder_config["eos_token_id"] else: self.decoder_start_token_id = decoder_config.bos_token_id self.pad_token_id = decoder_config.pad_token_id self.eos_token_id = decoder_config.eos_token_id class DonutSwinTableRecConfig(PretrainedConfig): model_type = "donut-swin" attribute_map = { "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers", } def __init__( self, image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]), patch_size=4, num_channels=3, embed_dim=128, depths=[2, 2, 12, 2], num_heads=[4, 8, 16, 32], num_kv_heads=[4, 8, 16, 32], window_size=8, mlp_ratio=4.0, qkv_bias=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, drop_path_rate=0.1, hidden_act="gelu", use_absolute_embeddings=False, initializer_range=0.02, layer_norm_eps=1e-5, encoder_length=1024, use_positional_embeddings=True, **kwargs, ): super().__init__(**kwargs) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.embed_dim = embed_dim self.depths = depths self.num_layers = len(depths) self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.window_size = window_size self.mlp_ratio = mlp_ratio self.qkv_bias = qkv_bias self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.drop_path_rate = drop_path_rate self.hidden_act = hidden_act self.use_absolute_embeddings = use_absolute_embeddings self.layer_norm_eps = layer_norm_eps self.initializer_range = initializer_range # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel # this indicates the channel dimension after the last stage of the model self.hidden_size = 
int(embed_dim * 2 ** (len(depths) - 1)) self.encoder_length = encoder_length self.use_positional_embeddings = use_positional_embeddings class SuryaTableRecDecoderConfig(PretrainedConfig): model_type = "surya_tablerec" def __init__( self, num_hidden_layers=6, vocab_size=BOX_DIM + 1, bbox_size=BOX_DIM, hidden_size=512, property_embed_size=64, box_embed_size=512 - 64, intermediate_size=4 * 512, encoder_hidden_size=1024, num_attention_heads=8, lru_width=None, attention_window_size=16, conv1d_width=4, logits_soft_cap=30.0, rms_norm_eps=1e-6, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=1, pause_token_id=2, query_end_token_id=4, hidden_activation="gelu_pytorch_tanh", rope_theta=10000.0, block_types=("attention",), cross_attn_layers=tuple(range(10)), encoder_cross_attn_layers=tuple(range(10)), self_attn_layers=tuple(range(10)), global_attn_layers=tuple(range(10)), attention_dropout=0.0, num_key_value_heads=4, attention_bias=False, w_init_variance_scale=0.01, init_std=0.02, tie_word_embeddings=False, aux_heads=0, # How many n-token-ahead heads to add causal=True, layer_norm_eps=1e-5, dropout=0.0, special_token_count=SPECIAL_TOKENS, **kwargs, ): self.num_hidden_layers = num_hidden_layers self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_attention_heads = num_attention_heads self.lru_width = lru_width if lru_width is not None else hidden_size self.attention_window_size = attention_window_size self.conv1d_width = conv1d_width self.logits_soft_cap = logits_soft_cap self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.block_types = list(block_types) self.hidden_activation = hidden_activation self.head_dim = self.hidden_size // self.num_attention_heads self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads if self.num_key_value_heads > self.num_attention_heads: raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") self.cross_attn_layers = cross_attn_layers self.self_attn_layers = self_attn_layers self.global_attn_layers = global_attn_layers self.attention_dropout = attention_dropout self.attention_bias = attention_bias self.w_init_variance_scale = w_init_variance_scale self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers self.init_std = init_std self.tie_word_embeddings = tie_word_embeddings self.aux_heads = aux_heads self.encoder_hidden_size=encoder_hidden_size self.causal = causal self.encoder_cross_attn_layers = encoder_cross_attn_layers self.layer_norm_eps = layer_norm_eps self.dropout = dropout self.bbox_size = bbox_size self.pause_token_id = pause_token_id self.box_properties = BOX_PROPERTIES self.property_embed_size = property_embed_size self.box_embed_size = box_embed_size self.special_token_count = special_token_count self.query_end_token_id = query_end_token_id self.double_residual_flow = False super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs, ) @property def layers_block_type(self): return (self.block_types * 100)[: self.num_hidden_layers] ``` -------------------------------------------------------------------------------- /surya/common/surya/flash_attn_utils.py: -------------------------------------------------------------------------------- ```python from typing import Optional import torch import torch.nn.functional as F from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func from flash_attn 
import flash_attn_with_kvcache as _flash_attn_with_kvcache from flash_attn.bert_padding import index_first_axis as _index_first_axis from flash_attn.bert_padding import pad_input def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. Arguments: attention_mask (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. Return: indices (`torch.Tensor`): The indices of non-masked tokens from the flattened input sequence. cu_seqlens (`torch.Tensor`): The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). max_seqlen_in_batch (`int`): Maximum sequence length in batch. """ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) def _upad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, query_length: int, indices_k, cu_seqlens_k, max_seqlen_in_batch_k ): """ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. Arguments: query_layer (`torch.Tensor`): Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). key_layer (`torch.Tensor`): Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). value_layer (`torch.Tensor`): Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). attention_mask (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. query_length (`int`): Target length. Return: query_layer (`torch.Tensor`): Query state without padding. Shape: (total_target_length, num_heads, head_dim). key_layer (`torch.Tensor`): Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). value_layer (`torch.Tensor`): Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). indices_q (`torch.Tensor`): The indices of non-masked tokens from the flattened input target sequence. (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
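    Illustrative example (not from the original source): for an attention mask
    [[1, 1, 0], [1, 1, 1]], `_get_unpad_data` yields indices [0, 1, 3, 4, 5],
    cu_seqlens [0, 2, 5] and max_seqlen_in_batch 3, so key/value tensors of
    shape (2, 3, num_heads, head_dim) are gathered into shape (5, num_heads, head_dim).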
""" batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = _index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) value_layer = _index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) if query_length == kv_seq_len: query_layer = _index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: raise NotImplementedError() return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) def flash_attn_prefill( module: torch.nn.Module, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: torch.Tensor, dropout: float, scaling: float, query_length: int, batch_size: int, indices_k: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_in_batch_k: int, **kwargs ): """ Wrapper for flash attention during the prefill stage query_states must have shape (batch_size, num_heads, seq_len, head_dim) key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim) This is the opposite of what is required by flash attention, but keeps parity with the HF convention query_length, batch_size, indices_k, cu_seqlens_k, and max_seqlen_in_batch_k should come from the flash attention kwargs """ query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2) q_flash, k_flash, v_flash, indices_q, cu_seq_lens, max_seq_lens = _upad_input( query_states, key_states, value_states, query_length, indices_k, cu_seqlens_k, max_seqlen_in_batch_k ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens # Returning None for attn_weights to match other attention interfaces flash_attn_out = _flash_attn_varlen_func( q_flash, k_flash, v_flash, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=scaling, causal=module.is_causal, ) return pad_input(flash_attn_out, indices_q, batch_size, query_length), None # NOTE: Does not support dropout, accepts argument as kwargs to maintain compatibility # This function is an order of magnitude faster than the prefill variant, or using the HF interface def flash_attn_decode( module: torch.nn.Module, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: torch.Tensor, scaling: float, **kwargs, ): """ Wrapper for flash attention during the decode stage query_states must have shape (batch_size, num_heads, seq_len, head_dim), 1 is the seq length in the decoding stage key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim) This is the opposite of what is required by flash attention, but keeps parity with the HF convention This function computes the left pad and cache seqlens to pass into FA2. 
For example - Given an attention_mask shaped (batch_size=2, seq_len=8), where 0 = padding, 1 = real token attention_mask = tensor([ [0, 0, 1, 1, 1, 0, 0, 0], # ← batch 0 [0, 1, 1, 1, 1, 1, 1, 0], # ← batch 1 ]) cache_leftpad = tensor([2, 1], dtype=torch.int32) cache_seqlens = tensor([5, 7], dtype=torch.int32) These values allow FlashAttention to use a static cache layout with efficient slicing during decoding. """ query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2) cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1).to(torch.int32) cache_seqlens = (attention_mask * torch.arange(attention_mask.size(1), device=attention_mask.device)).argmax(dim=1).to(torch.int32) + 1 # Returning None for attn_weights to match other attention interfaces return _flash_attn_with_kvcache( q=query_states, k_cache=key_states, v_cache=value_states, cache_leftpad=cache_leftpad, cache_seqlens=cache_seqlens, causal=module.is_causal, softmax_scale=scaling, ), None ``` -------------------------------------------------------------------------------- /surya/common/util.py: -------------------------------------------------------------------------------- ```python import copy from typing import List import torch from functools import lru_cache import torch.nn.functional as F from surya.common.polygon import PolygonBox def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: new_boxes = [] for box_obj in boxes: xs = [point[0] for point in box_obj.polygon] ys = [point[1] for point in box_obj.polygon] if max(xs) == min(xs) or max(ys) == min(ys): continue box = box_obj.bbox contained = False for other_box_obj in boxes: if other_box_obj.polygon == box_obj.polygon: continue other_box = other_box_obj.bbox if box == other_box: continue if ( box[0] >= other_box[0] and box[1] >= other_box[1] and box[2] <= other_box[2] and box[3] <= other_box[3] ): contained = True break if not contained: new_boxes.append(box_obj) return new_boxes def rescale_bbox(bbox, processor_size, image_size): page_width, page_height = processor_size img_width, img_height = image_size width_scaler = img_width / page_width height_scaler = img_height / page_height new_bbox = copy.deepcopy(bbox) new_bbox[0] = int(new_bbox[0] * width_scaler) new_bbox[1] = int(new_bbox[1] * height_scaler) new_bbox[2] = int(new_bbox[2] * width_scaler) new_bbox[3] = int(new_bbox[3] * height_scaler) return new_bbox def expand_bbox(bbox, expansion_factor=0.01): expansion_low = 1 - expansion_factor expansion_high = 1 + expansion_factor return [ bbox[0] * expansion_low, bbox[1] * expansion_low, bbox[2] * expansion_high, bbox[3] * expansion_high, ] SCRIPT_TOKEN_MAPPING = { "latin": "<SCRIPT-LATIN>", "punctuation": "<SCRIPT-PUNCTUATION>", "cyrillic": "<SCRIPT-CYRILLIC>", "arabic": "<SCRIPT-ARABIC>", "chinese": "<SCRIPT-CHINESE>", "japanese": "<SCRIPT-JAPANESE>", "korean": "<SCRIPT-KOREAN>", "symbols": "<SCRIPT-SYMBOLS>", "greek": "<SCRIPT-GREEK>", "armenian": "<SCRIPT-ARMENIAN>", "hebrew": "<SCRIPT-HEBREW>", "devanagari": "<SCRIPT-DEVANAGARI>", "bengali": "<SCRIPT-BENGALI>", "gurmukhi": "<SCRIPT-GURMUKHI>", "gujarati": "<SCRIPT-GUJARATI>", "oriya": "<SCRIPT-ORIYA>", "tamil": "<SCRIPT-TAMIL>", "telugu": "<SCRIPT-TELUGU>", "kannada": "<SCRIPT-KANNADA>", "malayalam": "<SCRIPT-MALAYALAM>", "sinhala": "<SCRIPT-SINHALA>", "thai": "<SCRIPT-THAI>", "lao": "<SCRIPT-LAO>", "myanmar": "<SCRIPT-MYANMAR>", "georgian": "<SCRIPT-GEORGIAN>", "ethiopic": "<SCRIPT-ETHIOPIC>", "khmer": "<SCRIPT-KHMER>", "mongolian": 
"<SCRIPT-MONGOLIAN>", "math": "<SCRIPT-MATH>", } @lru_cache(maxsize=1) def script_ranges(): script_categories = { # Latin-based scripts (used by English, French, German, etc.) "latin": [ (0x0041, 0x005A), # Latin uppercase A-Z (0x0061, 0x007A), # Latin lowercase a-z (0x0080, 0x00FF), # Latin-1 Supplement (0x0100, 0x017F), # Latin Extended-A (0x0180, 0x024F), # Latin Extended-B (0x0250, 0x02AF), # IPA Extensions (0x02B0, 0x02FF), # Spacing Modifier Letters (0x0300, 0x036F), # Combining Diacritical Marks (0x1E00, 0x1EFF), # Latin Extended Additional (0x2C60, 0x2C7F), # Latin Extended-C (0xA720, 0xA7FF), # Latin Extended-D ], # Punctuation, universal characters, and general symbols "punctuation": [ (0x0020, 0x0020), # Space (0x0021, 0x002F), # Basic punctuation and symbols (0x0030, 0x0039), # Digits 0-9 (0x003A, 0x0040), # More punctuation and symbols (0x005B, 0x0060), # More punctuation and symbols (0x007B, 0x007F), # More punctuation and symbols (0x2000, 0x206F), # General Punctuation ], # Cyrillic scripts (used by Russian, Ukrainian, etc.) "cyrillic": [ (0x0400, 0x04FF), # Cyrillic (0x0500, 0x052F), # Cyrillic Supplement ], # Arabic scripts "arabic": [ (0x0600, 0x06FF), # Arabic (0x0750, 0x077F), # Arabic Supplement (0x08A0, 0x08FF), # Arabic Extended-A ], # Chinese characters "chinese": [ (0x4E00, 0x9FFF), # Common CJK Unified Ideographs (0x3400, 0x4DBF), # CJK Extension A (0x20000, 0x2A6DF), # CJK Extension B ], # Japanese-specific scripts (excluding shared CJK) "japanese": [ (0x3040, 0x30FF), # Hiragana and Katakana ], # Korean-specific scripts "korean": [ (0x1100, 0x11FF), # Hangul Jamo (0x3130, 0x318F), # Hangul Compatibility Jamo (0xAC00, 0xD7AF), # Hangul Syllables ], # Various mathematical and technical symbols "symbols": [ (0x2070, 0x209F), # Superscripts and Subscripts (0x20A0, 0x20CF), # Currency Symbols (0x2100, 0x214F), # Letterlike Symbols (0x2150, 0x218F), # Number Forms (0x2190, 0x21FF), # Arrows (0x2200, 0x22FF), # Mathematical Operators (0x2300, 0x23FF), # Miscellaneous Technical (0x2500, 0x257F), # Box Drawing (0x2580, 0x259F), # Block Elements (0x25A0, 0x25FF), # Geometric Shapes (0x2600, 0x26FF), # Miscellaneous Symbols (0x2700, 0x27BF), # Dingbats (0x27C0, 0x27EF), # Miscellaneous Mathematical Symbols-A (0x2980, 0x29FF), # Miscellaneous Mathematical Symbols-B (0x2A00, 0x2AFF), # Supplemental Mathematical Operators (0x1D400, 0x1D7FF), # Mathematical Alphanumeric Symbols ], # Individual scripts for languages with unique writing systems "greek": [(0x0370, 0x03FF)], # Greek and Coptic "armenian": [(0x0530, 0x058F)], # Armenian "hebrew": [(0x0590, 0x05FF)], # Hebrew "devanagari": [(0x0900, 0x097F)], # Devanagari (Hindi, Sanskrit) "bengali": [(0x0980, 0x09FF)], # Bengali "gurmukhi": [(0x0A00, 0x0A7F)], # Gurmukhi (Punjabi) "gujarati": [(0x0A80, 0x0AFF)], # Gujarati "oriya": [(0x0B00, 0x0B7F)], # Oriya "tamil": [(0x0B80, 0x0BFF)], # Tamil "telugu": [(0x0C00, 0x0C7F)], # Telugu "kannada": [(0x0C80, 0x0CFF)], # Kannada "malayalam": [(0x0D00, 0x0D7F)], # Malayalam "sinhala": [(0x0D80, 0x0DFF)], # Sinhala "thai": [(0x0E00, 0x0E7F)], # Thai "lao": [(0x0E80, 0x0EFF)], # Lao "myanmar": [(0x1000, 0x109F)], # Myanmar "georgian": [(0x10A0, 0x10FF)], # Georgian "ethiopic": [(0x1200, 0x137F)], # Ethiopic "khmer": [(0x1780, 0x17FF)], # Khmer "mongolian": [(0x1800, 0x18AF)], # Mongolian } # Convert to a flat structure with character ranges flat_ranges = {} for category, ranges in script_categories.items(): # Create a set of all characters in this category char_set = set() for start, end 
in ranges: char_set.update(range(start, end + 1)) # Store the set in flat_ranges flat_ranges[category] = char_set return script_categories, flat_ranges def get_top_scripts(text: str, max_scripts: int = 5): script_categories, flat_ranges = script_ranges() char_count = {category: 0 for category in script_categories.keys()} for char in text: for category, char_set in flat_ranges.items(): if ord(char) in char_set: char_count[category] += 1 break top_scripts = sorted(char_count.items(), key=lambda x: x[1], reverse=True) top_scripts = [ts[0] for ts in top_scripts if ts[1] > 0] if "<math" in text: top_scripts.insert(0, "math") return top_scripts[:max_scripts] def is_flash_attn_2_supported(device: str | torch.device) -> bool: if not torch.cuda.is_available(): return False if "cuda" not in str(device): return False # Check CUDA version >= 12.0 cuda_version_str = torch.version.cuda if cuda_version_str is None: return False cuda_version = tuple(map(int, cuda_version_str.split("."))) if cuda_version < (12, 0): return False # Check GPU compute capability (Ampere, Ada, Hopper GPUs) major, minor = torch.cuda.get_device_capability() compute_capability = major + minor / 10 if compute_capability < 8.0: return False return True def pad_to_batch_size_repeat(tensor: torch.Tensor, batch_size: int): current_batch_size = tensor.shape[0] if current_batch_size >= batch_size: return tensor pad_size = batch_size - current_batch_size if pad_size < 0: return tensor # Repeat the last row pad_size times last_row = tensor[-1:].repeat(pad_size, 1, 1) # Concatenate original tensor with repeated last rows return torch.cat([tensor, last_row], dim=0) def pad_to_batch_size(tensor: torch.Tensor, batch_size: int): current_batch_size = tensor.shape[0] if current_batch_size >= batch_size: return tensor pad_size = batch_size - current_batch_size padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) return F.pad(tensor, padding, mode="constant", value=0) ``` -------------------------------------------------------------------------------- /surya/scripts/streamlit_app.py: -------------------------------------------------------------------------------- ```python import io import tempfile from typing import List import pypdfium2 import streamlit as st from surya.common.surya.schema import TaskNames from surya.models import load_predictors from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image from surya.debug.text import draw_text_on_image from PIL import Image, ImageDraw from surya.table_rec import TableResult from surya.detection import TextDetectionResult from surya.recognition import OCRResult from surya.layout import LayoutResult from surya.settings import settings from surya.common.util import rescale_bbox, expand_bbox @st.cache_resource() def load_predictors_cached(): return load_predictors() def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15): from pdftext.extraction import plain_text_output with tempfile.NamedTemporaryFile(suffix=".pdf") as f: f.write(pdf_file.getvalue()) f.seek(0) # Sample the text from the middle of the PDF page_middle = page_count // 2 page_range = range( max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count) ) text = plain_text_output(f.name, page_range=page_range) sample_gap = len(text) // max_samples if len(text) == 0 or sample_gap == 0: return "This PDF has no text or very little text", ["no text"] if sample_gap < sample_len: sample_gap = sample_len # Split the text into samples for the model samples = [] for i in range(0, len(text), 
sample_gap): samples.append(text[i : i + sample_len]) results = predictors["ocr_error"](samples) label = "This PDF has good text." if results.labels.count("bad") / len(results.labels) > 0.2: label = "This PDF may have garbled or bad OCR text." return label, results.labels def text_detection(img) -> (Image.Image, TextDetectionResult): text_pred = predictors["detection"]([img])[0] text_polygons = [p.polygon for p in text_pred.bboxes] det_img = draw_polys_on_image(text_polygons, img.copy()) return det_img, text_pred def layout_detection(img) -> (Image.Image, LayoutResult): pred = predictors["layout"]([img])[0] polygons = [p.polygon for p in pred.bboxes] labels = [ f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}" for p in pred.bboxes ] layout_img = draw_polys_on_image( polygons, img.copy(), labels=labels, label_font_size=18 ) return layout_img, pred def table_recognition( img, highres_img, skip_table_detection: bool ) -> (Image.Image, List[TableResult]): if skip_table_detection: layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])] table_imgs = [highres_img] else: _, layout_pred = layout_detection(img) layout_tables_lowres = [ line.bbox for line in layout_pred.bboxes if line.label in ["Table", "TableOfContents"] ] table_imgs = [] layout_tables = [] for tb in layout_tables_lowres: highres_bbox = rescale_bbox(tb, img.size, highres_img.size) # Slightly expand the box highres_bbox = expand_bbox(highres_bbox) table_imgs.append(highres_img.crop(highres_bbox)) layout_tables.append(highres_bbox) table_preds = predictors["table_rec"](table_imgs) table_img = img.copy() for results, table_bbox in zip(table_preds, layout_tables): adjusted_bboxes = [] labels = [] colors = [] for item in results.cells: adjusted_bboxes.append( [ (item.bbox[0] + table_bbox[0]), (item.bbox[1] + table_bbox[1]), (item.bbox[2] + table_bbox[0]), (item.bbox[3] + table_bbox[1]), ] ) labels.append(item.label) if "Row" in item.label: colors.append("blue") else: colors.append("red") table_img = draw_bboxes_on_image( adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors, ) return table_img, table_preds # Function for OCR def ocr( img: Image.Image, highres_img: Image.Image, skip_text_detection: bool = False, recognize_math: bool = True, with_bboxes: bool = True, ) -> (Image.Image, OCRResult): if skip_text_detection: img = highres_img bboxes = [[[0, 0, img.width, img.height]]] else: bboxes = None if with_bboxes: tasks = [TaskNames.ocr_with_boxes] else: tasks = [TaskNames.ocr_without_boxes] img_pred = predictors["recognition"]( [img], task_names=tasks, bboxes=bboxes, det_predictor=predictors["detection"], highres_images=[highres_img], math_mode=recognize_math, return_words=True, )[0] bboxes = [line.bbox for line in img_pred.text_lines] text = [line.text for line in img_pred.text_lines] rec_img = draw_text_on_image(bboxes, text, img.size) word_boxes = [] for line in img_pred.text_lines: if line.words: word_boxes.extend([word.bbox for word in line.words]) box_img = img.copy() draw = ImageDraw.Draw(box_img) for word_box in word_boxes: draw.rectangle(word_box, outline="red", width=2) return rec_img, img_pred, box_img def open_pdf(pdf_file): stream = io.BytesIO(pdf_file.getvalue()) return pypdfium2.PdfDocument(stream) @st.cache_data() def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI): doc = open_pdf(pdf_file) renderer = doc.render( pypdfium2.PdfBitmap.to_pil, page_indices=[page_num - 1], scale=dpi / 72, ) png = list(renderer)[0] png_image = png.convert("RGB") doc.close() return png_image 
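# PDF pages are laid out in points (72 per inch), so `scale = dpi / 72` above
# renders at the requested DPI: with the default settings this is ~1.33x for
# IMAGE_DPI=96 and ~2.67x for IMAGE_DPI_HIGHRES=192.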
@st.cache_data() def page_counter(pdf_file): doc = open_pdf(pdf_file) doc_len = len(doc) doc.close() return doc_len st.set_page_config(layout="wide") col1, col2 = st.columns([0.5, 0.5]) predictors = load_predictors_cached() st.markdown(""" # Surya OCR Demo This app will let you try surya, a multilingual OCR toolkit. Notes: - This works best on documents with printed text. - For OCR, the formatting (math, italics, etc) will not show up in the image preview, but it will show up in the returned text lines. - If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease). Find the project [here](https://github.com/VikParuchuri/surya). """) in_file = st.sidebar.file_uploader( "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"] ) if in_file is None: st.stop() filetype = in_file.type page_count = None if "pdf" in filetype: page_count = page_counter(in_file) page_number = st.sidebar.number_input( f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count ) pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI) pil_image_highres = get_page_image( in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES ) else: pil_image = Image.open(in_file).convert("RGB") pil_image_highres = pil_image page_number = None run_text_det = st.sidebar.button("Run Text Detection") run_text_rec = st.sidebar.button("Run OCR") run_layout_det = st.sidebar.button("Run Layout Analysis") run_table_rec = st.sidebar.button("Run Table Rec") run_ocr_errors = st.sidebar.button("Run bad PDF text detection") use_pdf_boxes = st.sidebar.checkbox( "PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.", ) skip_table_detection = st.sidebar.checkbox( "Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.", ) skip_text_detection = st.sidebar.checkbox( "Skip text detection", value=False, help="OCR only: Skip text detection and treat the whole image as a single line.", ) recognize_math = st.sidebar.checkbox( "Recognize math in OCR", value=True, help="Enable math mode in OCR - this will recognize math.", ) ocr_with_boxes = st.sidebar.checkbox( "OCR with boxes", value=True, help="Enable OCR with boxes - this will predict character-level boxes.", ) if pil_image is None: st.stop() # Run Text Detection if run_text_det: det_img, text_pred = text_detection(pil_image) with col1: st.image(det_img, caption="Detected Text", use_container_width=True) st.json( text_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True ) # Run layout if run_layout_det: layout_img, pred = layout_detection(pil_image) with col1: st.image(layout_img, caption="Detected Layout", use_container_width=True) st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True) # Run OCR if run_text_rec: rec_img, pred, box_img = ocr( pil_image, pil_image_highres, skip_text_detection, recognize_math, with_bboxes=ocr_with_boxes, ) with col1: st.image(rec_img, caption="OCR Result", use_container_width=True) json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"]) with json_tab: st.json(pred.model_dump(), expanded=False) with text_tab: st.text("\n".join([p.text for p in pred.text_lines])) st.image( box_img, caption="OCR with Word Boxes (for debugging)", use_container_width=True, ) if run_table_rec: table_img, pred = table_recognition( pil_image, pil_image_highres, skip_table_detection ) with col1: 
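        # The table boxes are drawn onto the high-res page image, so the
        # rendered overlay uses IMAGE_DPI_HIGHRES coordinates.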
if run_table_rec: table_img, pred = table_recognition( pil_image, pil_image_highres, skip_table_detection ) with col1: st.image(table_img, caption="Table Recognition", use_container_width=True) st.json([p.model_dump() for p in pred], expanded=True) if run_ocr_errors: if "pdf" not in filetype: st.error("This feature only works with PDFs.") st.stop() label, results = ocr_errors(in_file, page_count) with col1: st.write(label) st.json(results) with col2: st.image(pil_image, caption="Uploaded Image", use_container_width=True) ``` -------------------------------------------------------------------------------- /benchmark/recognition.py: -------------------------------------------------------------------------------- ```python import re import unicodedata from collections import defaultdict import click from benchmark.utils.scoring import overlap_score, overlap_score_exact from surya.input.processing import convert_if_not_rgb from surya.debug.text import draw_text_on_image from surya.foundation import FoundationPredictor from surya.recognition import RecognitionPredictor from surya.settings import settings from surya.recognition.languages import CODE_TO_LANGUAGE from benchmark.utils.tesseract import ( tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE, ) from benchmark.utils.textract import textract_ocr_parallel import os import datasets import json import time from tabulate import tabulate KEY_LANGUAGES = [ "Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese", ] def list_in(lst: str | list, lst2: list): if isinstance(lst, str): lst = [lst] return any([item in lst for item in lst2]) def standardize_bullets(text): patterns = [ r"•\s+", r"·\s+", r"○\s+", r"◦\s+", r"▪\s+", r"▫\s+", r"➢\s+", r"➤\s+", r"★\s+", r"✓\s+", r"✗\s+", r"✦\s+", r"\\bullet\s+", ] combined_pattern = "|".join(patterns) text = re.sub(combined_pattern, "*", text) return text def normalize_text(text: str) -> str: # Remove HTML tags text = re.sub(r"<[^>]+>", "", text) # Remove LaTeX tags text = re.sub(r"\\[a-zA-Z]+", "", text) text = standardize_bullets(text) text = unicodedata.normalize("NFKC", text) return text.strip().lower().replace(",", ".") @click.command(help="Benchmark recognition model.") @click.option( "--results_dir", type=str, help="Directory to write benchmark results to.", default=os.path.join(settings.RESULT_DIR, "benchmark"), ) @click.option( "--max_rows", type=int, help="Maximum number of dataset rows to benchmark.", default=None ) @click.option("--debug", count=True, help="Debug level: pass once to log bad detections, twice to also save debug images.", default=0) @click.option( "--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False ) @click.option( "--textract", is_flag=True, help="Run benchmarks on textract.", default=False ) @click.option( "--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28 ) @click.option( "--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28 ) @click.option( "--languages", type=str, help="Comma-separated list of languages to benchmark.", default=None, ) @click.option( "--print_results", is_flag=True, help="Print predictions and references for each image.", ) def main( results_dir: str, max_rows: int, debug: int, tesseract: bool, textract: bool, tess_cpus: int, textract_cpus: int, languages: str | None, print_results: bool, ): foundation_predictor = FoundationPredictor() rec_predictor = RecognitionPredictor(foundation_predictor) split = "train" dataset = datasets.load_dataset( settings.RECOGNITION_BENCH_DATASET_NAME, split=split ) if languages: languages = languages.split(",") dataset = dataset.filter( lambda x: list_in(x["language"], languages), num_proc=4 ) if max_rows and max_rows < len(dataset): dataset = 
dataset.shuffle(seed=1).select(range(max_rows)) images = list(dataset["image"]) images = convert_if_not_rgb(images) bboxes = list(dataset["bboxes"]) line_text = list(dataset["text"]) languages = list(dataset["language"]) print(f"Loaded {len(images)} images. Running OCR...") start = time.time() predictions_by_image = rec_predictor(images, None, bboxes=bboxes) surya_time = time.time() - start lang_list = [] for lang in languages: if not isinstance(lang, list): lang_list.append([lang]) else: lang_list.append(lang) surya_scores = defaultdict(list) img_surya_scores = [] outputs = [] for idx, (pred, ref_text, langs) in enumerate( zip(predictions_by_image, line_text, lang_list) ): pred_text = [line.text for line in pred.text_lines] score_ref_text = [normalize_text(line) for line in ref_text] score_pred_text = [normalize_text(text) for text in pred_text] image_scores, image_weights = overlap_score_exact( score_pred_text, score_ref_text ) normalized_scores = [ score / max(1, weight) for score, weight in zip(image_scores, image_weights) ] image_score = sum(image_scores) / max(1, sum(image_weights)) img_surya_scores.append(image_score) for lang in langs: surya_scores[CODE_TO_LANGUAGE[lang]].append(image_score) assert len(pred_text) == len(ref_text) == len(bboxes[idx]) if debug: for j, (pred_line, ref_line, score, bbox) in enumerate( zip(pred_text, ref_text, normalized_scores, bboxes[idx]) ): image_slice = images[idx].crop(bbox) outputs.append( { "image": image_slice, "bbox": bbox, "score": score, "pred": pred_line, "ref": ref_line, "langs": ",".join(langs), } ) if debug: out_ds = datasets.Dataset.from_list(outputs) out_ds.push_to_hub("datalab-to/rec_bench_outputs", private=True) flat_surya_scores = [score for lang in surya_scores for score in surya_scores[lang]] benchmark_stats = { "surya": { "avg_score": sum(flat_surya_scores) / max(1, len(flat_surya_scores)), "lang_scores": { lang: sum(scores) / max(1, len(scores)) for lang, scores in surya_scores.items() }, "time_per_img": surya_time / max(1, len(images)), } } result_path = os.path.join(results_dir, "rec_bench") os.makedirs(result_path, exist_ok=True) with open(os.path.join(result_path, "surya_scores.json"), "w+") as f: json.dump(surya_scores, f) if tesseract: tess_valid = [] tess_langs = [] for idx, lang in enumerate(lang_list): # Tesseract does not support all languages tess_lang = surya_lang_to_tesseract(lang[0]) if tess_lang is None: continue tess_valid.append(idx) tess_langs.append(tess_lang) tess_imgs = [images[i] for i in tess_valid] tess_bboxes = [bboxes[i] for i in tess_valid] tess_reference = [line_text[i] for i in tess_valid] start = time.time() tess_predictions = tesseract_ocr_parallel( tess_imgs, tess_bboxes, tess_langs, cpus=tess_cpus ) tesseract_time = time.time() - start tess_scores = defaultdict(list) for idx, (pred, ref_text, lang) in enumerate( zip(tess_predictions, tess_reference, tess_langs) ): image_scores, image_weights, _ = overlap_score(pred, ref_text) image_score = sum(image_scores) / max(1, sum(image_weights)) tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score) flat_tess_scores = [ score for lang in tess_scores for score in tess_scores[lang] ] benchmark_stats["tesseract"] = { "avg_score": sum(flat_tess_scores) / len(flat_tess_scores), "lang_scores": { lang: sum(scores) / len(scores) for lang, scores in tess_scores.items() }, "time_per_img": tesseract_time / len(tess_imgs), } with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f: json.dump(tess_scores, f) if textract: start = time.time() 
textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus) textract_time = time.time() - start textract_scores = defaultdict(list) for idx, (pred, ref_text, langs) in enumerate( zip(textract_predictions, line_text, lang_list) ): image_scores, image_weights, _ = overlap_score(pred, ref_text) image_score = sum(image_scores) / max(1, sum(image_weights)) for lang in langs: textract_scores[CODE_TO_LANGUAGE[lang]].append(image_score) flat_textract_scores = [ score for lang in textract_scores for score in textract_scores[lang] ] benchmark_stats["textract"] = { "avg_score": sum(flat_textract_scores) / len(flat_textract_scores), "lang_scores": { lang: sum(scores) / len(scores) for lang, scores in textract_scores.items() }, "time_per_img": textract_time / len(images), } with open(os.path.join(result_path, "textract_scores.json"), "w+") as f: json.dump(textract_scores, f) with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(benchmark_stats, f) key_languages = [k for k in KEY_LANGUAGES if k in surya_scores] table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages table_data = [ [ "surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"], ] + [benchmark_stats["surya"]["lang_scores"][lang] for lang in key_languages], ] if tesseract: table_data.append( [ "tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"], ] + [ benchmark_stats["tesseract"]["lang_scores"].get(lang, 0) for lang in key_languages ] ) if textract: table_data.append( [ "textract", benchmark_stats["textract"]["time_per_img"], benchmark_stats["textract"]["avg_score"], ] + [ benchmark_stats["textract"]["lang_scores"][lang] for lang in key_languages ], ) print(tabulate(table_data, headers=table_headers, tablefmt="github")) print( "Only a few major languages are displayed. See the result path for additional languages." )
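# Debug output below is tiered (using the count-style --debug option above):
# passing --debug once logs lines scoring under 0.8 to bad_detections.json;
# passing --debug --debug additionally renders prediction/reference images.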
if debug >= 1: bad_detections = [] for idx, (score, langs) in enumerate(zip(img_surya_scores, lang_list)): if score < 0.8: bad_detections.append((idx, langs, score)) print(f"Found {len(bad_detections)} bad detections. Writing to file...") with open(os.path.join(result_path, "bad_detections.json"), "w+") as f: json.dump(bad_detections, f) if debug == 2: for idx, (image, pred, ref_text, bbox, lang) in enumerate( zip(images, predictions_by_image, line_text, bboxes, lang_list) ): pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png" ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png" pred_text = [line.text for line in pred.text_lines] pred_image = draw_text_on_image(bbox, pred_text, image.size) pred_image.save(os.path.join(result_path, pred_image_name)) ref_image = draw_text_on_image(bbox, ref_text, image.size) ref_image.save(os.path.join(result_path, ref_image_name)) image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png")) print(f"Wrote results to {result_path}") if print_results: for idx, (pred, ref_text) in enumerate(zip(predictions_by_image, line_text)): print(f"Image {idx}") print("----") for line_idx, (pred_line, ref_line) in enumerate( zip(pred.text_lines, ref_text) ): print(f"Sample {line_idx}") print(f"Pred: {pred_line.text}") print(f"Ref: {ref_line}") print() if settings.TORCH_DEVICE == "xla": import torch_xla.debug.metrics as met print(met.short_metrics_report()) if __name__ == "__main__": main() ``` -------------------------------------------------------------------------------- /surya/detection/processor.py: -------------------------------------------------------------------------------- ```python import warnings from typing import Any, Dict, List, Optional, Union import numpy as np from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from transformers.image_transforms import to_channel_dimension_format from transformers.image_utils import ( IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, ChannelDimension, ImageInput, PILImageResampling, infer_channel_dimension_format, make_list_of_images, ) from transformers.utils import TensorType import PIL.Image import torch from surya.common.s3 import S3DownloaderMixin class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor): r""" Constructs a Segformer image processor. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to resize the image's (height, width) dimensions to the specified `(size["height"], size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`): Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` method. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the `preprocess` method. do_rescale (`bool`, *optional*, defaults to `True`): Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` parameter in the `preprocess` method. rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the `preprocess` method. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` method. image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): Mean to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. 
Can be overridden by the `image_mean` parameter in the `preprocess` method. image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. do_reduce_labels (`bool`, *optional*, defaults to `False`): Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the `preprocess` method. """ model_input_names = ["pixel_values"] def __init__( self, do_resize: bool = True, size: Dict[str, int] = None, resample: PILImageResampling = PILImageResampling.BILINEAR, do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_reduce_labels: bool = False, **kwargs, ) -> None: if "reduce_labels" in kwargs: warnings.warn( "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use " "`do_reduce_labels` instead.", FutureWarning, ) do_reduce_labels = kwargs.pop("reduce_labels") super().__init__(**kwargs) size = size if size is not None else {"height": 512, "width": 512} size = get_size_dict(size) self.do_resize = do_resize self.size = size self.resample = resample self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_reduce_labels = do_reduce_labels self._valid_processor_keys = [ "images", "segmentation_maps", "do_resize", "size", "resample", "do_rescale", "rescale_factor", "do_normalize", "image_mean", "image_std", "do_reduce_labels", "return_tensors", "data_format", "input_data_format", ] @classmethod def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): """ Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image processor is created using from_dict and kwargs e.g. 
`SegformerImageProcessor.from_pretrained(checkpoint, reduce_labels=True)` """ image_processor_dict = image_processor_dict.copy() if "reduce_labels" in kwargs: image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels") return super().from_dict(image_processor_dict, **kwargs) def _preprocess( self, image: ImageInput, do_resize: bool, do_rescale: bool, do_normalize: bool, size: Optional[Dict[str, int]] = None, resample: PILImageResampling = None, rescale_factor: Optional[float] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ): if do_rescale: image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) if do_normalize: image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) return image def _preprocess_image( self, image: ImageInput, do_resize: bool = None, size: Dict[str, int] = None, resample: PILImageResampling = None, do_rescale: bool = None, rescale_factor: float = None, do_normalize: bool = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: """Preprocesses a single image.""" # All transformations expect numpy arrays. if input_data_format is None: input_data_format = infer_channel_dimension_format(image) image = self._preprocess( image=image, do_resize=do_resize, size=size, resample=resample, do_rescale=do_rescale, rescale_factor=rescale_factor, do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format, ) if data_format is not None: image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) return image def __call__(self, images, segmentation_maps=None, **kwargs): """ Preprocesses a batch of images and optionally segmentation maps. Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be passed in as positional arguments. """ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) def preprocess( self, images: ImageInput, segmentation_maps: Optional[ImageInput] = None, do_resize: Optional[bool] = None, size: Optional[Dict[str, int]] = None, resample: PILImageResampling = None, do_rescale: Optional[bool] = None, rescale_factor: Optional[float] = None, do_normalize: Optional[bool] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_reduce_labels: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs, ) -> PIL.Image.Image: """ Preprocess an image or batch of images. Args: images (`ImageInput`): Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. segmentation_maps (`ImageInput`, *optional*): Segmentation map to preprocess. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): Size of the image after `resize` is applied. 
resample (`int`, *optional*, defaults to `self.resample`): Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only has an effect if `do_resize` is set to `True`. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): Whether to rescale the image values to [0, 1]. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): Rescale factor to rescale the image by if `do_rescale` is set to `True`. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): Whether to normalize the image. image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): Image mean. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): Image standard deviation. do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The background label will be replaced by 255. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): The channel dimension format for the output image. Can be one of: - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """
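# Note: `do_resize`, `size`, and `resample` are accepted for interface
# compatibility with the stock transformers processor, but `_preprocess` above
# only rescales and normalizes - no resize is applied, and `segmentation_maps`
# are not transformed. Inputs are expected to be sized by the caller.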
""" do_resize = do_resize if do_resize is not None else self.do_resize do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_normalize = do_normalize if do_normalize is not None else self.do_normalize resample = resample if resample is not None else self.resample size = size if size is not None else self.size rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std images = make_list_of_images(images) images = [ self._preprocess_image( image=img, do_resize=do_resize, resample=resample, size=size, do_rescale=do_rescale, rescale_factor=rescale_factor, do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, data_format=data_format, input_data_format=input_data_format, ) for img in images ] data = {"pixel_values": images} return BatchFeature(data=data, tensor_type=return_tensors) ``` -------------------------------------------------------------------------------- /surya/foundation/cache/dynamic_ops.py: -------------------------------------------------------------------------------- ```python from typing import Any, Dict, List, Optional, Tuple import torch from transformers import PretrainedConfig """ Special cache class for the surya foundation model that supports - 1) Static shape 2) A custom sliding window, where image tokens stay in cache, and text tokens are popped 3) Continuous batching - merging etc 4) Attention mask management - To match with what's currently in the cache Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079 """ class DynamicOpsCache: def __init__( self, config: PretrainedConfig, batch_size: int, max_cache_len: int, text_sliding_window: int, device: int, dtype: int, ): self.text_sliding_window = text_sliding_window self.num_layers = config.num_hidden_layers self.max_batch_size = batch_size self.max_cache_len = max_cache_len self.head_dim = ( getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads ) self._dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None else config.num_key_value_heads ) # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125 self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] cache_shape = ( self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim, ) device = torch.device(device) if device is not None else None for _ in range(config.num_hidden_layers): new_layer_key_cache = torch.zeros( cache_shape, dtype=self._dtype, device=device ) new_layer_value_cache = torch.zeros( cache_shape, dtype=self._dtype, device=device ) torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_value_cache) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) self.attention_mask = torch.zeros( (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long ) self.text_token_counts = [ torch.zeros(self.max_batch_size, dtype=torch.long, device=device) for _ in range(self.num_layers) ] self.dtype = dtype self.device = device def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, 
def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: prefill = cache_kwargs.get("prefill", False) update_fn = self._prefill_update if prefill else self._decode_update return update_fn( self.key_cache[layer_idx], self.value_cache[layer_idx], key_states, value_states, self.text_token_counts[layer_idx], cache_kwargs, ) def update_text_counts( self, merge_idxs: torch.Tensor, valid_batch_size: torch.Tensor, new_text_lens: torch.Tensor, ): new_text_len_tensor = new_text_lens.to(device=self.device) for layer_idx in range(self.num_layers): self.text_token_counts[layer_idx][merge_idxs] = new_text_len_tensor[ :valid_batch_size ] # Mirrors the logic from _prefill_update # The logic is explained in more detail in that function def prefill_attention_mask_update( self, prefill_attention_mask: torch.Tensor, merge_idxs: torch.Tensor, valid_batch_mask: torch.Tensor, text_lengths: List[int], ): seq_len = prefill_attention_mask.shape[1] sliding_window = self.text_sliding_window total_cache_len = self.max_cache_len prefix_cache_space = total_cache_len - sliding_window for batch_idx, cache_idx in enumerate(merge_idxs): text_len = text_lengths[batch_idx] prefix_len = seq_len - text_len self.attention_mask[cache_idx] = 0 # Set default assert prefix_len > 0, "There are no prefix (image) tokens!" end_pos = prefix_cache_space # Handle prefix part - Which may be left padded if prefix_len <= prefix_cache_space: start_pos = prefix_cache_space - prefix_len self.attention_mask[cache_idx, start_pos:end_pos] = ( prefill_attention_mask[batch_idx, :prefix_len] ) else: self.attention_mask[cache_idx, :end_pos] = prefill_attention_mask[ batch_idx, prefix_len - prefix_cache_space : prefix_len ] # Handle text part, keeping sliding window in consideration # All of the left padding is before the prefix, so we can ignore the prefill_attention_mask here if text_len > 0: text_cache_start = prefix_cache_space if text_len <= sliding_window: self.attention_mask[ cache_idx, text_cache_start : text_cache_start + text_len ] = 1 else: self.attention_mask[cache_idx, -sliding_window:] = 1 # Slow impl for now - Prefill time is dominated by the large sequence length forward pass def _prefill_update( self, key_cache: torch.Tensor, value_cache: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, text_token_counts: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None, ): cache_idxs: List[int] = cache_kwargs.get("cache_idxs", None) text_lengths: List[int] = cache_kwargs.get("text_lengths", None) assert cache_idxs is not None, "cache_idxs must be specified during prefill" assert text_lengths is not None, "text_lengths must be specified during prefill" _, _, seq_len, _ = key_states.shape total_cache_len = self.max_cache_len sliding_window = self.text_sliding_window prefix_cache_space = total_cache_len - sliding_window for batch_idx, cache_idx in enumerate(cache_idxs): text_len = text_lengths[batch_idx] prefix_len = seq_len - text_len ###### Handle Image Tokens (Prefix) ##### # Place image tokens in appropriate cache space, aligned to the **right edge** assert prefix_len > 0, "There are no prefix (image) tokens!" # prefix_len may be greater than the prefix cache space due to left padding - This happens when # a different batch element has a large input text during prefill, causing others to have a lot of # left padding. 
We can safely take the last `prefix_cache_space` elements from the kv states, since # `prefix_cache_space` is large enough to fit any image, and the rest **has to be** padding end_pos = prefix_cache_space if prefix_len <= prefix_cache_space: start_pos = prefix_cache_space - prefix_len key_cache[cache_idx, :, start_pos:end_pos] = key_states[ batch_idx, :, :prefix_len ] value_cache[cache_idx, :, start_pos:end_pos] = value_states[ batch_idx, :, :prefix_len ] else: key_cache[cache_idx, :, :end_pos] = key_states[ batch_idx, :, prefix_len - prefix_cache_space : prefix_len ] value_cache[cache_idx, :, :end_pos] = value_states[ batch_idx, :, prefix_len - prefix_cache_space : prefix_len ] ###### Handle Text Tokens ##### # Text tokens start at the **left edge** of sliding window cache space if text_len > 0: text_cache_start = prefix_cache_space if text_len <= sliding_window: key_cache[ cache_idx, :, text_cache_start : text_cache_start + text_len ] = key_states[batch_idx, :, prefix_len : prefix_len + text_len] value_cache[ cache_idx, :, text_cache_start : text_cache_start + text_len ] = value_states[batch_idx, :, prefix_len : prefix_len + text_len] else: start_in_text = text_len - sliding_window key_cache[ cache_idx, :, text_cache_start : text_cache_start + sliding_window, ] = key_states[ batch_idx, :, prefix_len + start_in_text : prefix_len + text_len ] value_cache[ cache_idx, :, text_cache_start : text_cache_start + sliding_window, ] = value_states[ batch_idx, :, prefix_len + start_in_text : prefix_len + text_len ] # Return the full key/value states (not just cached) for use in subsequent layers return key_states, value_states # """ # Matches the logic of the decode update, but needs to be called before the updates # since some parts of the model depend on the attention mask # """ def decode_attention_mask_update( self, num_valid_tokens: torch.Tensor, cache_idxs: List[int] ): sliding_window = self.text_sliding_window text_cache_start = self.max_cache_len - sliding_window # Using text_token_counts of first layer, should be same for all though current_text_lens = self.text_token_counts[0] cache_idxs_tensor = torch.tensor(cache_idxs, device=current_text_lens.device) # Get current text lengths for the relevant cache indices current_lens = current_text_lens[cache_idxs_tensor] new_text_lens = current_lens + num_valid_tokens is_full = new_text_lens > sliding_window # Handle full caches - set entire sliding window to 1 if is_full.any(): full_mask = is_full full_cache_idxs = cache_idxs_tensor[full_mask] self.attention_mask[full_cache_idxs, text_cache_start:] = 1 # Handle non-full caches - set specific ranges to 1 if (~is_full).any(): non_full_mask = ~is_full non_full_cache_idxs = cache_idxs_tensor[non_full_mask] non_full_current_lens = current_lens[non_full_mask] non_full_valid_tokens = num_valid_tokens[non_full_mask] max_valid_tokens = ( non_full_valid_tokens.max().item() if len(non_full_valid_tokens) > 0 else 0 ) if max_valid_tokens > 0: batch_size = len(non_full_cache_idxs) offset_range = torch.arange( max_valid_tokens, device=current_text_lens.device ) batch_offsets = offset_range.unsqueeze(0).expand(batch_size, -1) start_positions = non_full_current_lens.unsqueeze(1) valid_token_counts = non_full_valid_tokens.unsqueeze(1) position_indices = start_positions + batch_offsets valid_mask = batch_offsets < valid_token_counts row_indices = non_full_cache_idxs.unsqueeze(1).expand( -1, max_valid_tokens )[valid_mask] col_indices = text_cache_start + position_indices[valid_mask] self.attention_mask[row_indices, 
col_indices] = 1 """ Static cache update - respects per-batch text token limits - per-batch valid token lengths (right-padded inputs) kv states are expected to have shape [batch_size, kv_heads, T_pad, head_dim] They may have different `true` lengths, to account for multi token preds, or beacon tokens Expects `num_valid_tokens` in cache_kwargs: a tensor of shape (B,) indicating the number of actual (non-padded) tokens to add per batch element. """ def _decode_update( self, key_cache: torch.Tensor, value_cache: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, text_token_counts: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: num_valid_tokens: torch.Tensor = cache_kwargs.get( "num_valid_tokens" ) # shape: (B,) assert num_valid_tokens is not None, ( "`num_valid_tokens` must be provided in `cache_kwargs`" ) device = key_states.device batch_size, num_head, seq_len, head_dim = key_states.shape sliding_window = self.text_sliding_window max_cache_len = self.max_cache_len cache_text_start = max_cache_len - sliding_window new_text_lengths = text_token_counts + num_valid_tokens slide_amounts = torch.clamp(new_text_lengths - sliding_window, min=0) needs_rotate = slide_amounts > 0 # Rotate the cache if needed if torch.any(needs_rotate): k_slice = key_cache[:, :, -sliding_window:] # shape: [B, H, W, D] v_slice = value_cache[:, :, -sliding_window:] # same shape cache_indices = ( torch.arange(sliding_window, device=device) .unsqueeze(0) .repeat(batch_size, 1) ) # [B, W] rolled_indices = ( cache_indices + slide_amounts.unsqueeze(1) ) % sliding_window # [B, W] # We need to expand indices to shape: [B, 1, W, 1] to broadcast with k_slice rolled_indices = ( rolled_indices.unsqueeze(1) .unsqueeze(-1) .expand(-1, num_head, -1, head_dim) ) k_slice_rolled = k_slice.gather(dim=2, index=rolled_indices) v_slice_rolled = v_slice.gather(dim=2, index=rolled_indices) key_cache[:, :, -sliding_window:] = k_slice_rolled value_cache[:, :, -sliding_window:] = v_slice_rolled # Insert only **valid tokens** into the cache. 
These are **right aligned** within the input sequence insert_positions = torch.where( needs_rotate, max_cache_len - num_valid_tokens, text_token_counts + cache_text_start, ) max_tokens = num_valid_tokens.max().item() offsets = torch.arange(max_tokens, device=device).unsqueeze(0) # [1, max_T] valid_mask = offsets < num_valid_tokens.unsqueeze(1) # [B, max_T] src_indices = (seq_len - num_valid_tokens).unsqueeze(1) + offsets # [B, max_T] src_indices = src_indices.clamp(max=seq_len - 1) # safety tgt_indices = insert_positions.unsqueeze(1) + offsets # [B, max_T] tgt_indices = tgt_indices.clamp(max=max_cache_len - 1) # safety src_idx_exp = ( src_indices.unsqueeze(1) .unsqueeze(-1) .expand(batch_size, num_head, max_tokens, head_dim) ) tgt_idx_exp = ( tgt_indices.unsqueeze(1) .unsqueeze(-1) .expand(batch_size, num_head, max_tokens, head_dim) ) valid_mask_exp = ( valid_mask.unsqueeze(1) .unsqueeze(-1) .expand(batch_size, num_head, max_tokens, head_dim) ) k_src = torch.gather(key_states, 2, src_idx_exp) v_src = torch.gather(value_states, 2, src_idx_exp) k_src = k_src * valid_mask_exp v_src = v_src * valid_mask_exp # Write into cache key_cache.scatter_(2, tgt_idx_exp, k_src) value_cache.scatter_(2, tgt_idx_exp, v_src) # In-place edit - Mutates text_token_counts += num_valid_tokens text_token_counts.clamp_(max=sliding_window) return key_cache, value_cache # We have a non-uniform cache, so its better to not return it and handle any logic # that requires this ourselves def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: raise NotImplementedError() ``` -------------------------------------------------------------------------------- /surya/table_rec/__init__.py: -------------------------------------------------------------------------------- ```python from copy import deepcopy from itertools import chain from typing import List import numpy as np import torch from PIL import Image from tqdm import tqdm from surya.common.xla import mark_step from surya.common.predictor import BasePredictor from surya.table_rec.schema import TableCell, TableRow, TableCol, TableResult from surya.common.polygon import PolygonBox from surya.settings import settings from surya.table_rec.loader import TableRecModelLoader from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM, CATEGORY_TO_ID, MERGE_KEYS, \ MERGE_VALUES from surya.table_rec.shaper import LabelShaper class TableRecPredictor(BasePredictor): model_loader_cls = TableRecModelLoader batch_size = settings.TABLE_REC_BATCH_SIZE default_batch_sizes = { "cpu": 8, "mps": 8, "cuda": 32, "xla": 16 } def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]: return self.batch_table_recognition(images, batch_size) def inference_loop( self, encoder_hidden_states: torch.Tensor, batch_input_ids: torch.Tensor, current_batch_size: int, batch_size: int ): shaper = LabelShaper() batch_predictions = [[] for _ in range(current_batch_size)] max_tokens = settings.TABLE_REC_MAX_BOXES decoder_position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=self.model.device).cumsum( 0) - 1 inference_token_count = batch_input_ids.shape[1] if settings.TABLE_REC_STATIC_CACHE: encoder_hidden_states = self.pad_to_batch_size(encoder_hidden_states, batch_size) batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) # Move to device after padding for XLA encoder_hidden_states = encoder_hidden_states.to(self.model.device) batch_input_ids = batch_input_ids.to(self.model.device) 
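# With the static cache enabled (e.g. XLA / compiled runs), the inputs above
        # were padded to a fixed batch size so cache shapes stay constant; the decoder
        # cache is (re)allocated here for the full padded batch before decoding starts.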
self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype) with settings.INFERENCE_MODE(): token_count = 0 all_done = torch.zeros(encoder_hidden_states.shape[0], dtype=torch.bool, device=self.model.device) while token_count < max_tokens: is_prefill = token_count == 0 return_dict = self.model.decoder( input_ids=batch_input_ids, encoder_hidden_states=encoder_hidden_states, cache_position=decoder_position_ids, use_cache=True, prefill=is_prefill ) decoder_position_ids = decoder_position_ids[-1:] + 1 # Get predictions for each box element box_properties = [] done = [] # Pre-process all logits at once processed_logits = {} for k, _, mode in BOX_PROPERTIES: k_logits = return_dict["box_property_logits"][k][:, -1, :] # Get all batch logits at once if mode == "classification": # Process all classification logits in one operation items = torch.argmax(k_logits, dim=-1) if k == "category": done = (items == self.model.decoder.config.eos_token_id) | (items == self.model.decoder.config.pad_token_id) items = items - SPECIAL_TOKENS processed_logits[k] = items elif mode == "regression": if k == "bbox": k_logits = k_logits * BOX_DIM processed_logits[k] = k_logits elif k == "colspan": k_logits = torch.clamp(k_logits, min=1) processed_logits[k] = torch.round(k_logits) items = {k: processed_logits[k].cpu() for k, _, _ in BOX_PROPERTIES} for j in range(current_batch_size): box_property = {} for k, _, mode in BOX_PROPERTIES: if mode == "classification": box_property[k] = int(items[k][j].item()) elif mode == "regression": if k == "bbox": box_property[k] = items[k][j].tolist() elif k == "colspan": box_property[k] = int(items[k][j].item()) box_properties.append(box_property) all_done = all_done | done all_done_cpu = all_done.cpu() if all_done_cpu[:current_batch_size].all(): break batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long) batch_input_ids = batch_input_ids.unsqueeze(1) # Add sequence length dimension for j, (box_property, status) in enumerate(zip(box_properties, all_done_cpu)): if not status: batch_predictions[j].append(box_property) token_count += inference_token_count inference_token_count = batch_input_ids.shape[1] if settings.TABLE_REC_STATIC_CACHE: batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) # Move to device after padding for XLA batch_input_ids = batch_input_ids.to(self.model.device) return batch_predictions def batch_table_recognition( self, images: List, batch_size=None) -> List[TableResult]: assert all([isinstance(image, Image.Image) for image in images]) if batch_size is None: batch_size = self.get_batch_size() if len(images) == 0: return [] query_items = [] for image in images: query_items.append({ "polygon": [[0, 0], [image.width, 0], [image.width, image.height], [0, image.height]], "category": CATEGORY_TO_ID["Table"], "colspan": 0, "merges": 0, "is_header": 0 }) output_order = [] for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables", disable=self.disable_tqdm): batch_query_items = query_items[i:i + batch_size] batch_images = images[i:i + batch_size] batch_images = [image.convert("RGB") for image in batch_images] # also copies the images current_batch_size = len(batch_images) orig_sizes = [image.size for image in batch_images] model_inputs = self.processor(images=batch_images, query_items=batch_query_items) batch_pixel_values = model_inputs["pixel_values"] batch_input_ids = model_inputs["input_ids"] batch_pixel_values = torch.tensor(np.array(batch_pixel_values), 
dtype=self.model.dtype) if settings.TABLE_REC_STATIC_CACHE: batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size) # Move to device after padding for XLA batch_pixel_values = batch_pixel_values.to(self.model.device) shaper = LabelShaper() # We only need to process each image once with settings.INFERENCE_MODE(): encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state # Inference to get rows and columns rowcol_predictions = self.inference_loop( encoder_hidden_states, batch_input_ids, current_batch_size, batch_size ) mark_step() row_query_items = [] row_encoder_hidden_states = [] idx_map = [] columns = [] for j, img_predictions in enumerate(rowcol_predictions): for row_prediction in img_predictions: polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"]) if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]: row_query_items.append({ "polygon": polygon, "category": row_prediction["category"], "colspan": 0, "merges": 0, "is_header": int(row_prediction["is_header"] == 1) }) row_encoder_hidden_states.append(encoder_hidden_states[j]) idx_map.append(j) elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]: columns.append({ "polygon": polygon, "category": row_prediction["category"], "colspan": 0, "merges": 0, "is_header": int(row_prediction["is_header"] == 1) }) # Re-inference to predict cells row_encoder_hidden_states = torch.stack(row_encoder_hidden_states) row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False) row_input_ids = row_inputs["input_ids"] cell_predictions = [] for j in range(0, len(row_input_ids), batch_size): cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size] cell_batch_input_ids = row_input_ids[j:j + batch_size] cell_batch_size = len(cell_batch_input_ids) cell_predictions.extend( self.inference_loop(cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size) ) mark_step() result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper) output_order.extend(result) return output_order def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper): results = [] for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)): row_cell_predictions = [c for i, c in enumerate(cell_predictions) if idx_map[i] == j] # Each row prediction matches a cell prediction rows = [] cells = [] columns = [] cell_id = 0 row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]] col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]] # Generate table columns for z, col_prediction in enumerate(col_predictions): polygon = shaper.convert_bbox_to_polygon(col_prediction["bbox"]) polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) columns.append( TableCol( polygon=polygon, col_id=z, is_header=col_prediction["is_header"] == 1 ) ) # Generate table rows for z, row_prediction in enumerate(row_predictions): polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"]) polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) row = TableRow( polygon=polygon, row_id=z, is_header=row_prediction["is_header"] == 1 ) rows.append(row) # Get cells that span multiple columns within a row spanning_cells = [] for l, spanning_cell in enumerate(row_cell_predictions[z]): polygon = 
shaper.convert_bbox_to_polygon(spanning_cell["bbox"]) polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) colspan = max(1, int(spanning_cell["colspan"])) if colspan == 1 and spanning_cell["merges"] not in MERGE_VALUES: # Skip single column cells if they don't merge continue if PolygonBox(polygon=polygon).height < row.height * .85: # Spanning cell must cover most of the row continue spanning_cells.append( TableCell( polygon=polygon, row_id=z, rowspan=1, cell_id=cell_id, within_row_id=l, colspan=colspan, merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]], merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]], is_header=row.is_header or z == 0 ) ) cell_id += 1 # Add cells - either add spanning cells (multiple cols), or generate a cell based on row/col used_spanning_cells = set() skip_columns = 0 for l, col in enumerate(columns): if skip_columns: skip_columns -= 1 continue cell_polygon = row.intersection_polygon(col) cell_added = False for zz, spanning_cell in enumerate(spanning_cells): cell_polygonbox = PolygonBox(polygon=cell_polygon) intersection_pct = cell_polygonbox.intersection_pct(spanning_cell) # Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns) correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]]) if intersection_pct > .9: if spanning_cell.width > (correct_col_width * .85): cell_added = True if zz not in used_spanning_cells: used_spanning_cells.add(zz) spanning_cell.col_id = l cells.append(spanning_cell) skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell else: used_spanning_cells.add(zz) # Skip this spanning cell if not cell_added: cells.append( TableCell( polygon=cell_polygon, row_id=z, rowspan=1, cell_id=cell_id, within_row_id=l, colspan=1, merge_up=False, merge_down=False, col_id=l, is_header=row.is_header or col.is_header or z == 0 ) ) cell_id += 1 # Turn cells into a row grid grid_cells = deepcopy([ [cell for cell in cells if cell.row_id == row.row_id] for row in rows ]) # Merge cells across rows for z, grid_row in enumerate(grid_cells[1:]): prev_row = grid_cells[z] for l, cell in enumerate(grid_row): if l >= len(prev_row): continue above_cell = prev_row[l] if all([ above_cell.merge_down, cell.merge_up, above_cell.col_id == cell.col_id, above_cell.colspan == cell.colspan, ]): above_cell.merge(cell) above_cell.rowspan += cell.rowspan grid_row[l] = above_cell merged_cells_all = list(chain.from_iterable(grid_cells)) used_ids = set() merged_cells = [] for cell in merged_cells_all: if cell.cell_id in used_ids: continue used_ids.add(cell.cell_id) merged_cells.append(cell) result = TableResult( cells=merged_cells, unmerged_cells=cells, rows=rows, cols=columns, image_bbox=[0, 0, orig_size[0], orig_size[1]], ) results.append(result) return results ```
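For reference, here is a minimal usage sketch for `TableRecPredictor`, based on the `__call__` signature and the `TableResult` fields used above. The no-argument constructor loading default weights and the input filename are assumptions:

```python
# Minimal usage sketch (assumptions: the no-arg constructor loads default
# weights via its model loader; "table_page.png" is a placeholder path).
from PIL import Image

from surya.table_rec import TableRecPredictor

predictor = TableRecPredictor()
image = Image.open("table_page.png").convert("RGB")

# __call__ accepts a list of PIL images and returns one TableResult per image
results = predictor([image])

for table in results:
    print(f"rows={len(table.rows)} cols={len(table.cols)} cells={len(table.cells)}")
    for cell in table.cells:
        # row_id/col_id index into table.rows/table.cols; colspan >= 1
        print(cell.row_id, cell.col_id, cell.colspan, cell.polygon)
```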