# datalab-to/surya
This is page 2 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/benchmark/utils/tatr.py:
--------------------------------------------------------------------------------

```python
  1 | import torch
  2 | from transformers import AutoModelForObjectDetection
  3 | from surya.settings import settings
  4 | import numpy as np
  5 | 
  6 | 
  7 | class MaxResize(object):
  8 |     def __init__(self, max_size=800):
  9 |         self.max_size = max_size
 10 | 
 11 |     def __call__(self, image):
 12 |         width, height = image.size
 13 |         current_max_size = max(width, height)
 14 |         scale = self.max_size / current_max_size
 15 |         resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
 16 | 
 17 |         return resized_image
 18 | 
 19 | 
 20 | def to_tensor(image):
 21 |     # Convert PIL Image to NumPy array
 22 |     np_image = np.array(image).astype(np.float32)
 23 | 
 24 |     # Rearrange dimensions to [C, H, W] format
 25 |     np_image = np_image.transpose((2, 0, 1))
 26 | 
 27 |     # Normalize to [0.0, 1.0]
 28 |     np_image /= 255.0
 29 | 
 30 |     return torch.from_numpy(np_image)
 31 | 
 32 | 
 33 | def normalize(tensor, mean, std):
 34 |     for t, m, s in zip(tensor, mean, std):
 35 |         t.sub_(m).div_(s)
 36 |     return tensor
 37 | 
 38 | 
 39 | def structure_transform(image):
 40 |     image = MaxResize(1000)(image)
 41 |     tensor = to_tensor(image)
 42 |     normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 43 |     return normalized_tensor
 44 | 
 45 | 
 46 | def box_cxcywh_to_xyxy(x):
 47 |     x_c, y_c, w, h = x.unbind(-1)
 48 |     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
 49 |     return torch.stack(b, dim=1)
 50 | 
 51 | 
 52 | def rescale_bboxes(out_bbox, size):
 53 |     width, height = size
 54 |     boxes = box_cxcywh_to_xyxy(out_bbox)
 55 |     boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
 56 |     return boxes
 57 | 
 58 | 
 59 | def outputs_to_objects(outputs, img_sizes, id2label):
 60 |     m = outputs.logits.softmax(-1).max(-1)
 61 |     batch_labels = list(m.indices.detach().cpu().numpy())
 62 |     batch_scores = list(m.values.detach().cpu().numpy())
 63 |     batch_bboxes = outputs['pred_boxes'].detach().cpu()
 64 | 
 65 |     batch_objects = []
 66 |     for i in range(len(img_sizes)):
 67 |         pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])]
 68 |         pred_scores = batch_scores[i]
 69 |         pred_labels = batch_labels[i]
 70 | 
 71 |         objects = []
 72 |         for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
 73 |             class_label = id2label[int(label)]
 74 |             if not class_label == 'no object':
 75 |                 objects.append({
 76 |                     'label': class_label,
 77 |                     'score': float(score),
 78 |                     'bbox': [float(elem) for elem in bbox]}
 79 |                 )
 80 | 
 81 |         rows = []
 82 |         cols = []
 83 |         for cell in objects:
 84 |             if cell["label"] == "table column":
 85 |                 cols.append(cell)
 86 | 
 87 |             if cell["label"] == "table row":
 88 |                 rows.append(cell)
 89 |         batch_objects.append({
 90 |             "rows": rows,
 91 |             "cols": cols
 92 |         })
 93 | 
 94 |     return batch_objects
 95 | 
 96 | 
 97 | def load_tatr():
 98 |     return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL)
 99 | 
100 | 
101 | def batch_inference_tatr(model, images, batch_size):
102 |     device = model.device
103 |     rows_cols = []
104 |     for i in range(0, len(images), batch_size):
105 |         batch_images = images[i:i + batch_size]
106 |         pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)
107 | 
108 |         # forward pass
109 |         with torch.no_grad():
110 |             outputs = model(pixel_values)
111 | 
112 |         id2label = model.config.id2label
113 |         id2label[len(model.config.id2label)] = "no object"
114 |         rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
115 |     return rows_cols
```
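
A minimal usage sketch (not part of the repository) for the TATR baseline above, assuming it is run from the repository root so that `benchmark.utils` is importable; the image paths are placeholders.

```python
from PIL import Image

from benchmark.utils.tatr import batch_inference_tatr, load_tatr

# Placeholder paths; any cropped table images will do.
images = [Image.open(p).convert("RGB") for p in ["table1.png", "table2.png"]]

model = load_tatr()  # downloads the table-transformer checkpoint from the HF hub
rows_cols = batch_inference_tatr(model, images, batch_size=2)

for pred in rows_cols:
    print(len(pred["rows"]), "rows,", len(pred["cols"]), "columns")
```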

--------------------------------------------------------------------------------
/benchmark/texify.py:
--------------------------------------------------------------------------------

```python
  1 | import os.path
  2 | import re
  3 | import time
  4 | from pathlib import Path
  5 | from typing import List
  6 | 
  7 | import click
  8 | import datasets
  9 | from tabulate import tabulate
 10 | from bs4 import BeautifulSoup
 11 | 
 12 | from surya.common.surya.schema import TaskNames
 13 | from surya.settings import settings
 14 | from surya.foundation import FoundationPredictor
 15 | from surya.recognition import RecognitionPredictor, OCRResult
 16 | import json
 17 | from rapidfuzz.distance import Levenshtein
 18 | 
 19 | 
 20 | def normalize_text(text):
 21 |     soup = BeautifulSoup(text, "html.parser")
 22 |     # Unwrap math tags
 23 |     for tag in soup.find_all():
 24 |         if tag.name == "math":
 25 |             tag.unwrap()
 26 |     text = soup.get_text()
 27 |     text = re.sub(r"\n", " ", text)
 28 |     text = re.sub(r"\s+", " ", text)
 29 |     return text.strip()
 30 | 
 31 | 
 32 | def score_text(predictions, references):
 33 |     lev_dist = []
 34 |     for p, r in zip(predictions, references):
 35 |         p = normalize_text(p)
 36 |         r = normalize_text(r)
 37 |         lev_dist.append(Levenshtein.normalized_distance(p, r))
 38 | 
 39 |     return sum(lev_dist) / len(lev_dist)
 40 | 
 41 | 
 42 | def inference_texify(
 43 |     source_data, predictor: RecognitionPredictor, line_mode: bool = False
 44 | ):
 45 |     images = [sd["image"] for sd in source_data]
 46 |     mode = TaskNames.ocr_with_boxes if line_mode else TaskNames.block_without_boxes
 47 |     tasks = [mode] * len(images)
 48 |     bboxes = [[[0, 0, image.width, image.height]] for image in images]
 49 |     texify_predictions: List[OCRResult] = predictor(images, tasks, bboxes=bboxes)
 50 |     out_data = [
 51 |         {
 52 |             "text": texify_predictions[i].text_lines[0].text,
 53 |             "equation": source_data[i]["equation"],
 54 |         }
 55 |         for i in range(len(texify_predictions))
 56 |     ]
 57 | 
 58 |     return out_data
 59 | 
 60 | 
 61 | @click.command(help="Benchmark the performance of texify.")
 62 | @click.option(
 63 |     "--ds_name",
 64 |     type=str,
 65 |     help="Path to dataset file with source images/equations.",
 66 |     default=settings.TEXIFY_BENCHMARK_DATASET,
 67 | )
 68 | @click.option(
 69 |     "--results_dir",
 70 |     type=str,
  71 |     help="Directory to write the benchmark results JSON to.",
 72 |     default=os.path.join(settings.RESULT_DIR, "benchmark"),
 73 | )
 74 | @click.option(
 75 |     "--max_rows", type=int, help="Maximum number of images to benchmark.", default=None
 76 | )
 77 | @click.option(
 78 |     "--line_mode", is_flag=True, help="Use line mode for texify.", default=False
 79 | )
 80 | def main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool):
 81 |     foundation_predictor = FoundationPredictor()
 82 |     predictor = RecognitionPredictor(foundation_predictor)
 83 |     ds = datasets.load_dataset(ds_name, split="train")
 84 | 
 85 |     if max_rows:
 86 |         ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True)
 87 | 
 88 |     start = time.time()
 89 |     predictions = inference_texify(ds, predictor, line_mode)
 90 |     time_taken = time.time() - start
 91 | 
 92 |     text = [p["text"] for p in predictions]
 93 |     references = [p["equation"] for p in predictions]
 94 |     scores = score_text(text, references)
 95 | 
 96 |     write_data = {
 97 |         "scores": scores,
 98 |         "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)],
 99 |     }
100 | 
101 |     score_table = [["texify", write_data["scores"], time_taken]]
102 |     score_headers = ["edit", "time taken (s)"]
103 |     score_dirs = ["⬇", "⬇"]
104 | 
105 |     score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
106 |     table = tabulate(score_table, headers=["Method", *score_headers])
107 |     print()
108 |     print(table)
109 | 
110 |     result_path = Path(results_dir) / "texify_bench"
111 |     result_path.mkdir(parents=True, exist_ok=True)
112 |     with open(result_path / "results.json", "w", encoding="utf-8") as f:
113 |         json.dump(write_data, f, indent=4)
114 | 
115 | 
116 | if __name__ == "__main__":
117 |     main()
118 | 
```
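
`score_text` averages a normalized Levenshtein distance over prediction/reference pairs after `normalize_text` unwraps math tags and collapses whitespace. A tiny illustration of the underlying rapidfuzz metric (0.0 means identical strings, 1.0 means entirely different), using made-up strings:

```python
from rapidfuzz.distance import Levenshtein

prediction = "E = mc^2"
reference = "E = mc^{2}"

# Normalized edit distance: edit count divided by the longer string's length.
print(Levenshtein.normalized_distance(prediction, reference))  # 0.2
```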

--------------------------------------------------------------------------------
/surya/table_rec/model/encoder.py:
--------------------------------------------------------------------------------

```python
 1 | from typing import Optional, Union, Tuple
 2 | 
 3 | import torch
 4 | import torch.nn as nn
 5 | 
 6 | from surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder
 7 | 
 8 | 
 9 | class DonutSwinModel(DonutSwinPreTrainedModel):
10 |     def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
11 |         super().__init__(config)
12 |         self.config = config
13 |         self.num_layers = len(config.depths)
14 |         self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
15 | 
16 |         self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
17 |         self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
18 | 
19 |         self.position_embeddings = None
20 |         if hasattr(config, "encoder_length"):
21 |             self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size))
22 | 
23 |         # Initialize weights and apply final processing
24 |         self.post_init()
25 | 
26 |     def get_input_embeddings(self):
27 |         return self.embeddings.patch_embeddings
28 | 
29 |     def _prune_heads(self, heads_to_prune):
30 |         """
31 |         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
32 |         class PreTrainedModel
33 |         """
34 |         for layer, heads in heads_to_prune.items():
35 |             self.encoder.layer[layer].attention.prune_heads(heads)
36 | 
37 |     def forward(
38 |         self,
39 |         pixel_values: Optional[torch.FloatTensor] = None,
40 |         bool_masked_pos: Optional[torch.BoolTensor] = None,
41 |         head_mask: Optional[torch.FloatTensor] = None,
42 |         output_attentions: Optional[bool] = None,
43 |         output_hidden_states: Optional[bool] = None,
44 |         interpolate_pos_encoding: bool = False,
45 |         return_dict: Optional[bool] = None,
46 |     ) -> Union[Tuple, DonutSwinModelOutput]:
47 |         r"""
48 |         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
49 |             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
50 |         """
51 |         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
52 |         output_hidden_states = (
53 |             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
54 |         )
55 |         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
56 | 
57 |         if pixel_values is None:
58 |             raise ValueError("You have to specify pixel_values")
59 | 
60 |         # Prepare head mask if needed
61 |         # 1.0 in head_mask indicate we keep the head
62 |         # attention_probs has shape bsz x n_heads x N x N
63 |         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
64 |         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
65 |         head_mask = self.get_head_mask(head_mask, len(self.config.depths))
66 | 
67 |         embedding_output, input_dimensions = self.embeddings(
68 |             pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
69 |         )
70 | 
71 |         encoder_outputs = self.encoder(
72 |             embedding_output,
73 |             input_dimensions,
74 |             head_mask=head_mask,
75 |             output_attentions=output_attentions,
76 |             output_hidden_states=output_hidden_states,
77 |             return_dict=return_dict,
78 |         )
79 | 
80 |         last_hidden_state = encoder_outputs[0]
81 | 
82 |         if self.position_embeddings is not None:
83 |             last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :]
84 | 
85 |         return DonutSwinModelOutput(
86 |             last_hidden_state=last_hidden_state,
87 |         )
88 | 
```

--------------------------------------------------------------------------------
/surya/table_rec/model/encoderdecoder.py:
--------------------------------------------------------------------------------

```python
  1 | from dataclasses import dataclass
  2 | from typing import Optional, Union, Tuple, Dict
  3 | 
  4 | import torch
  5 | from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
  6 | 
  7 | from surya.common.pretrained import SuryaPreTrainedModel
  8 | from surya.common.s3 import S3DownloaderMixin
  9 | from surya.table_rec.model.decoder import SuryaTableRecDecoder
 10 | from surya.table_rec.model.encoder import DonutSwinModel
 11 | from transformers.utils import ModelOutput
 12 | 
 13 | 
 14 | @dataclass
 15 | class TableRecOutput(ModelOutput):
 16 |     box_property_logits: Dict[str, torch.FloatTensor]
 17 |     decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
 18 | 
 19 | 
 20 | class TableRecEncoderDecoderModel(S3DownloaderMixin, SuryaPreTrainedModel):
 21 |     config_class = VisionEncoderDecoderConfig
 22 |     base_model_prefix = "vision_encoder_decoder"
 23 |     main_input_name = "pixel_values"
 24 |     supports_gradient_checkpointing = True
 25 |     _supports_param_buffer_assignment = False
 26 | 
 27 |     def __init__(
 28 |         self,
 29 |         config: Optional[PretrainedConfig] = None,
 30 |         encoder: Optional[PreTrainedModel] = None,
 31 |         decoder: Optional[PreTrainedModel] = None,
 32 |         **kwargs,
 33 |     ):
 34 |         # initialize with config
 35 |         # make sure input & output embeddings is not tied
 36 |         config.tie_word_embeddings = False
 37 |         config.decoder.tie_word_embeddings = False
 38 |         super().__init__(config, **kwargs)
 39 | 
 40 |         if encoder is None:
 41 |             encoder = DonutSwinModel(config.encoder)
 42 | 
 43 |         if decoder is None:
 44 |             decoder = SuryaTableRecDecoder(
 45 |                 config.decoder, attn_implementation=config._attn_implementation
 46 |             )
 47 | 
 48 |         self.encoder = encoder
 49 |         self.decoder = decoder
 50 | 
 51 |         # make sure that the individual model's config refers to the shared config
 52 |         # so that the updates to the config will be synced
 53 |         self.encoder.config = self.config.encoder
 54 |         self.decoder.config = self.config.decoder
 55 | 
 56 |     def get_encoder(self):
 57 |         return self.encoder
 58 | 
 59 |     def get_decoder(self):
 60 |         return self.decoder
 61 | 
 62 |     def get_output_embeddings(self):
 63 |         return self.decoder.get_output_embeddings()
 64 | 
 65 |     def set_output_embeddings(self, new_embeddings):
 66 |         return self.decoder.set_output_embeddings(new_embeddings)
 67 | 
 68 |     def forward(
 69 |         self,
 70 |         decoder_input_ids: torch.LongTensor = None,
 71 |         decoder_cache_position: Optional[torch.LongTensor] = None,
 72 |         decoder_attention_mask: Optional[torch.LongTensor] = None,
 73 |         encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
 74 |         use_cache: Optional[bool] = None,
 75 |         return_dict: Optional[bool] = None,
 76 |         **kwargs,
 77 |     ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]:
 78 |         kwargs_decoder = {
 79 |             argument[len("decoder_") :]: value
 80 |             for argument, value in kwargs.items()
 81 |             if argument.startswith("decoder_")
 82 |         }
 83 | 
 84 |         # Decode
 85 |         decoder_outputs = self.decoder(
 86 |             input_labels=decoder_input_ids,
 87 |             input_boxes_counts=None,
 88 |             cache_position=decoder_cache_position,
 89 |             attention_mask=decoder_attention_mask,
 90 |             encoder_hidden_states=encoder_outputs,
 91 |             encoder_attention_mask=None,
 92 |             use_cache=use_cache,
 93 |             **kwargs_decoder,
 94 |         )
 95 | 
 96 |         return TableRecOutput(
 97 |             box_property_logits=decoder_outputs.box_property_logits,
 98 |             decoder_hidden_states=decoder_outputs.hidden_states,
 99 |         )
100 | 
101 |     def resize_token_embeddings(self, *args, **kwargs):
102 |         raise NotImplementedError(
103 |             "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
104 |             " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
105 |         )
106 | 
107 |     def _reorder_cache(self, past_key_values, beam_idx):
108 |         # apply decoder cache reordering here
109 |         return self.decoder._reorder_cache(past_key_values, beam_idx)
110 | 
```
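
`forward()` routes any keyword argument prefixed with `decoder_` to the wrapped decoder after stripping the prefix. A tiny illustration of that routing, using made-up kwargs:

```python
# Illustration only: the "decoder_" prefix stripping done in forward() above.
kwargs = {"decoder_use_cache": True, "decoder_output_hidden_states": False}

kwargs_decoder = {
    key[len("decoder_"):]: value
    for key, value in kwargs.items()
    if key.startswith("decoder_")
}
print(kwargs_decoder)  # {'use_cache': True, 'output_hidden_states': False}
```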

--------------------------------------------------------------------------------
/signatures/version1/cla.json:
--------------------------------------------------------------------------------

```json
  1 | {
  2 |   "signedContributors": [
  3 |     {
  4 |       "name": "rishiraj",
  5 |       "id": 44090649,
  6 |       "comment_id": 2170578748,
  7 |       "created_at": "2024-06-15T19:31:20Z",
  8 |       "repoId": 741297064,
  9 |       "pullRequestNo": 135
 10 |     },
 11 |     {
 12 |       "name": "mmacvicar",
 13 |       "id": 59354,
 14 |       "comment_id": 2236493182,
 15 |       "created_at": "2024-07-18T13:17:43Z",
 16 |       "repoId": 741297064,
 17 |       "pullRequestNo": 152
 18 |     },
 19 |     {
 20 |       "name": "jimexist",
 21 |       "id": 622789,
 22 |       "comment_id": 2255151376,
 23 |       "created_at": "2024-07-29T07:23:55Z",
 24 |       "repoId": 741297064,
 25 |       "pullRequestNo": 160
 26 |     },
 27 |     {
 28 |       "name": "michaeldriscoll-avant",
 29 |       "id": 85255083,
 30 |       "comment_id": 2259143427,
 31 |       "created_at": "2024-07-30T20:21:33Z",
 32 |       "repoId": 741297064,
 33 |       "pullRequestNo": 161
 34 |     },
 35 |     {
 36 |       "name": "EdoardoPona",
 37 |       "id": 29152472,
 38 |       "comment_id": 2271115922,
 39 |       "created_at": "2024-08-06T11:58:00Z",
 40 |       "repoId": 741297064,
 41 |       "pullRequestNo": 167
 42 |     },
 43 |     {
 44 |       "name": "hidenori-endo",
 45 |       "id": 15546605,
 46 |       "comment_id": 2307217499,
 47 |       "created_at": "2024-08-23T14:31:17Z",
 48 |       "repoId": 741297064,
 49 |       "pullRequestNo": 182
 50 |     },
 51 |     {
 52 |       "name": "dobosevych",
 53 |       "id": 12053536,
 54 |       "comment_id": 2430376828,
 55 |       "created_at": "2024-10-22T21:48:34Z",
 56 |       "repoId": 741297064,
 57 |       "pullRequestNo": 220
 58 |     },
 59 |     {
 60 |       "name": "iammosespaulr",
 61 |       "id": 28682735,
 62 |       "comment_id": 2447941238,
 63 |       "created_at": "2024-10-30T17:55:23Z",
 64 |       "repoId": 741297064,
 65 |       "pullRequestNo": 235
 66 |     },
 67 |     {
 68 |       "name": "ArthurMor4is",
 69 |       "id": 42987302,
 70 |       "comment_id": 2515315717,
 71 |       "created_at": "2024-12-03T18:37:45Z",
 72 |       "repoId": 741297064,
 73 |       "pullRequestNo": 255
 74 |     },
 75 |     {
 76 |       "name": "tarun-menta",
 77 |       "id": 66506307,
 78 |       "comment_id": 2543457960,
 79 |       "created_at": "2024-12-15T05:43:33Z",
 80 |       "repoId": 741297064,
 81 |       "pullRequestNo": 261
 82 |     },
 83 |     {
 84 |       "name": "jonaskahn",
 85 |       "id": 4338500,
 86 |       "comment_id": 2556622097,
 87 |       "created_at": "2024-12-20T09:36:20Z",
 88 |       "repoId": 741297064,
 89 |       "pullRequestNo": 269
 90 |     },
 91 |     {
 92 |       "name": "kumsumit",
 93 |       "id": 95072784,
 94 |       "comment_id": 2574534622,
 95 |       "created_at": "2025-01-07T07:05:59Z",
 96 |       "repoId": 741297064,
 97 |       "pullRequestNo": 276
 98 |     },
 99 |     {
100 |       "name": "kevinhu",
101 |       "id": 6051736,
102 |       "comment_id": 2614135351,
103 |       "created_at": "2025-01-25T23:34:12Z",
104 |       "repoId": 741297064,
105 |       "pullRequestNo": 291
106 |     },
107 |     {
108 |       "name": "zanussbaum",
109 |       "id": 33707069,
110 |       "comment_id": 3008673416,
111 |       "created_at": "2025-06-26T14:20:46Z",
112 |       "repoId": 741297064,
113 |       "pullRequestNo": 403
114 |     },
115 |     {
116 |       "name": "mebriki",
117 |       "id": 35892987,
118 |       "comment_id": 3154706976,
119 |       "created_at": "2025-08-05T10:54:27Z",
120 |       "repoId": 741297064,
121 |       "pullRequestNo": 418
122 |     },
123 |     {
124 |       "name": "starikovplusplus",
125 |       "id": 56602036,
126 |       "comment_id": 3168958011,
127 |       "created_at": "2025-08-08T18:29:50Z",
128 |       "repoId": 741297064,
129 |       "pullRequestNo": 423
130 |     },
131 |     {
132 |       "name": "sandy0kwon",
133 |       "id": 78377296,
134 |       "comment_id": 3207932260,
135 |       "created_at": "2025-08-20T20:07:15Z",
136 |       "repoId": 741297064,
137 |       "pullRequestNo": 434
138 |     },
139 |     {
140 |       "name": "n0kovo",
141 |       "id": 16690056,
142 |       "comment_id": 3208251881,
143 |       "created_at": "2025-08-20T22:22:06Z",
144 |       "repoId": 741297064,
145 |       "pullRequestNo": 435
146 |     },
147 |     {
148 |       "name": "davidxifeng",
149 |       "id": 158052,
150 |       "comment_id": 3249594859,
151 |       "created_at": "2025-09-03T14:52:16Z",
152 |       "repoId": 741297064,
153 |       "pullRequestNo": 445
154 |     },
155 |     {
156 |       "name": "u-ashish",
157 |       "id": 14264791,
158 |       "comment_id": 3258734182,
159 |       "created_at": "2025-09-05T15:16:48Z",
160 |       "repoId": 741297064,
161 |       "pullRequestNo": 447
162 |     },
163 |     {
164 |       "name": "Mohking1",
165 |       "id": 63689545,
166 |       "comment_id": 3314908963,
167 |       "created_at": "2025-09-20T11:21:42Z",
168 |       "repoId": 741297064,
169 |       "pullRequestNo": 462
170 |     },
171 |     {
172 |       "name": "wkpark",
173 |       "id": 232347,
174 |       "comment_id": 3330009557,
175 |       "created_at": "2025-09-24T17:42:55Z",
176 |       "repoId": 741297064,
177 |       "pullRequestNo": 464
178 |     }
179 |   ]
180 | }
```

--------------------------------------------------------------------------------
/surya/layout/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import List
  2 | 
  3 | from PIL import Image
  4 | 
  5 | from surya.common.predictor import BasePredictor
  6 | from surya.layout.schema import LayoutBox, LayoutResult
  7 | from surya.settings import settings
  8 | from surya.foundation import FoundationPredictor, TaskNames
  9 | from surya.foundation.util import prediction_to_polygon_batch
 10 | from surya.input.processing import convert_if_not_rgb
 11 | from surya.layout.label import LAYOUT_PRED_RELABEL
 12 | from surya.common.util import clean_boxes
 13 | 
 14 | 
 15 | class LayoutPredictor(BasePredictor):
 16 |     batch_size = settings.LAYOUT_BATCH_SIZE
 17 |     default_batch_sizes = {"cpu": 4, "mps": 4, "cuda": 32, "xla": 16}
 18 | 
 19 |     # Override base init - Do not load model
 20 |     def __init__(self, foundation_predictor: FoundationPredictor):
 21 |         self.foundation_predictor = foundation_predictor
 22 |         self.processor = self.foundation_predictor.processor
 23 |         self.bbox_size = self.foundation_predictor.model.config.bbox_size
 24 |         self.tasks = self.foundation_predictor.tasks
 25 | 
 26 |     # Special handling for disable tqdm to pass into foundation predictor
 27 |     # Make sure they are kept in sync
 28 |     @property
 29 |     def disable_tqdm(self) -> bool:
 30 |         return super().disable_tqdm
 31 | 
 32 |     @disable_tqdm.setter
 33 |     def disable_tqdm(self, value: bool) -> None:
 34 |         self._disable_tqdm = bool(value)
 35 |         self.foundation_predictor.disable_tqdm = bool(value)
 36 | 
 37 |     def __call__(
 38 |         self, images: List[Image.Image], batch_size: int | None = None, top_k: int = 5
 39 |     ) -> List[LayoutResult]:
 40 |         assert all([isinstance(image, Image.Image) for image in images])
 41 |         if batch_size is None:
 42 |             batch_size = self.get_batch_size()
 43 | 
 44 |         if len(images) == 0:
 45 |             return []
 46 | 
 47 |         images = convert_if_not_rgb(images)
 48 |         images = [self.processor.image_processor(image) for image in images]
 49 | 
 50 |         predicted_tokens, batch_bboxes, scores, topk_scores = (
 51 |             self.foundation_predictor.prediction_loop(
 52 |                 images=images,
 53 |                 input_texts=["" for _ in range(len(images))],
 54 |                 task_names=[TaskNames.layout for _ in range(len(images))],
 55 |                 batch_size=batch_size,
 56 |                 max_lookahead_tokens=0,  # Do not do MTP for layout
  57 |                 top_k=top_k,
 58 |                 max_sliding_window=576,
 59 |                 max_tokens=500,
 60 |                 tqdm_desc="Recognizing Layout"
 61 |             )
 62 |         )
 63 | 
 64 |         image_sizes = [img.shape for img in images]
 65 |         predicted_polygons = prediction_to_polygon_batch(
 66 |             batch_bboxes, image_sizes, self.bbox_size, self.bbox_size // 2
 67 |         )
 68 |         layout_results = []
 69 |         for image, image_tokens, image_polygons, image_scores, image_topk_scores in zip(
 70 |             images, predicted_tokens, predicted_polygons, scores, topk_scores
 71 |         ):
 72 |             layout_boxes = []
 73 |             for z, (tok, poly, score, tok_topk) in enumerate(
 74 |                 zip(image_tokens, image_polygons, image_scores, image_topk_scores)
 75 |             ):
 76 |                 if tok == self.processor.eos_token_id:
 77 |                     break
 78 | 
 79 |                 predicted_label = self.processor.decode([tok], "layout")
 80 |                 label = LAYOUT_PRED_RELABEL.get(predicted_label)
 81 |                 if not label:
 82 |                     # Layout can sometimes return unknown labels from other objectives
 83 |                     continue
 84 | 
 85 |                 top_k_dict = {}
 86 |                 for k, v in tok_topk.items():
 87 |                     topk_label = self.processor.decode([k], "layout")
 88 |                     if topk_label in LAYOUT_PRED_RELABEL:
 89 |                         topk_label = LAYOUT_PRED_RELABEL[topk_label]
 90 |                     if not topk_label.strip():
 91 |                         continue
 92 |                     top_k_dict.update({topk_label: v})
 93 |                 layout_boxes.append(
 94 |                     LayoutBox(
 95 |                         polygon=poly.tolist(),
 96 |                         label=label,
 97 |                         position=z,
 98 |                         top_k=top_k_dict,
 99 |                         confidence=score,
100 |                     )
101 |                 )
102 |             layout_boxes = clean_boxes(layout_boxes)
103 |             layout_results.append(
104 |                 LayoutResult(
105 |                     bboxes=layout_boxes,
106 |                     image_bbox=[0, 0, image.shape[1], image.shape[0]],
107 |                 )  # Image is numpy array
108 |             )
109 | 
110 |         assert len(layout_results) == len(images)
111 |         return layout_results
112 | 
```
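
A minimal usage sketch (not from the repository) of the layout API above. It mirrors the CLI in `surya/scripts/table_recognition.py` by loading the foundation model with the layout checkpoint; `"page.png"` is a placeholder path.

```python
from PIL import Image

from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.settings import settings

foundation = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_predictor = LayoutPredictor(foundation)

image = Image.open("page.png")       # placeholder input page
results = layout_predictor([image])  # one LayoutResult per input image

for box in results[0].bboxes:
    print(box.label, box.confidence, box.polygon)
```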

--------------------------------------------------------------------------------
/surya/scripts/table_recognition.py:
--------------------------------------------------------------------------------

```python
  1 | import os
  2 | import click
  3 | import copy
  4 | import json
  5 | from collections import defaultdict
  6 | 
  7 | from surya.logging import configure_logging, get_logger
  8 | from surya.scripts.config import CLILoader
  9 | from surya.foundation import FoundationPredictor
 10 | from surya.layout import LayoutPredictor
 11 | from surya.table_rec import TableRecPredictor
 12 | from surya.debug.draw import draw_bboxes_on_image
 13 | from surya.common.util import rescale_bbox, expand_bbox
 14 | from surya.settings import settings
 15 | 
 16 | configure_logging()
 17 | logger = get_logger()
 18 | 
 19 | 
 20 | @click.command(help="Recognize table structure in an input file or folder (PDFs or image).")
 21 | @CLILoader.common_options
 22 | @click.option(
 23 |     "--skip_table_detection",
 24 |     is_flag=True,
 25 |     help="Tables are already cropped, so don't re-detect tables.",
 26 |     default=False,
 27 | )
 28 | def table_recognition_cli(input_path: str, skip_table_detection: bool, **kwargs):
 29 |     loader = CLILoader(input_path, kwargs, highres=True)
 30 | 
 31 |     foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
 32 |     layout_predictor = LayoutPredictor(foundation_predictor)
 33 |     table_rec_predictor = TableRecPredictor()
 34 | 
 35 |     pnums = []
 36 |     prev_name = None
 37 |     for i, name in enumerate(loader.names):
 38 |         if prev_name is None or prev_name != name:
 39 |             pnums.append(0)
 40 |         else:
 41 |             pnums.append(pnums[-1] + 1)
 42 | 
 43 |         prev_name = name
 44 | 
 45 |     layout_predictions = layout_predictor(loader.images)
 46 | 
 47 |     table_imgs = []
 48 |     table_counts = []
 49 | 
 50 |     for layout_pred, img, highres_img in zip(
 51 |         layout_predictions, loader.images, loader.highres_images
 52 |     ):
 53 |         # The table may already be cropped
 54 |         if skip_table_detection:
 55 |             table_imgs.append(highres_img)
 56 |             table_counts.append(1)
 57 |         else:
 58 |             # The bbox for the entire table
 59 |             bbox = [
 60 |                 line.bbox
 61 |                 for line in layout_pred.bboxes
 62 |                 if line.label in ["Table", "TableOfContents"]
 63 |             ]
 64 |             # Number of tables per page
 65 |             table_counts.append(len(bbox))
 66 | 
 67 |             if len(bbox) == 0:
 68 |                 continue
 69 | 
 70 |             page_table_imgs = []
 71 |             highres_bbox = []
 72 |             for bb in bbox:
 73 |                 highres_bb = rescale_bbox(bb, img.size, highres_img.size)
 74 |                 highres_bb = expand_bbox(highres_bb)
 75 |                 page_table_imgs.append(highres_img.crop(highres_bb))
 76 |                 highres_bbox.append(highres_bb)
 77 | 
 78 |             table_imgs.extend(page_table_imgs)
 79 | 
 80 |     table_preds = table_rec_predictor(table_imgs)
 81 | 
 82 |     img_idx = 0
 83 |     prev_count = 0
 84 |     table_predictions = defaultdict(list)
 85 |     for i in range(sum(table_counts)):
 86 |         while i >= prev_count + table_counts[img_idx]:
 87 |             prev_count += table_counts[img_idx]
 88 |             img_idx += 1
 89 | 
 90 |         pred = table_preds[i]
 91 |         orig_name = loader.names[img_idx]
 92 |         pnum = pnums[img_idx]
 93 |         table_img = table_imgs[i]
 94 | 
 95 |         out_pred = pred.model_dump()
 96 |         out_pred["page"] = pnum + 1
 97 |         table_idx = i - prev_count
 98 |         out_pred["table_idx"] = table_idx
 99 |         table_predictions[orig_name].append(out_pred)
100 | 
101 |         if loader.save_images:
102 |             rows = [line.bbox for line in pred.rows]
103 |             cols = [line.bbox for line in pred.cols]
104 |             row_labels = [f"Row {line.row_id}" for line in pred.rows]
105 |             col_labels = [f"Col {line.col_id}" for line in pred.cols]
106 |             cells = [line.bbox for line in pred.cells]
107 | 
108 |             rc_image = copy.deepcopy(table_img)
109 |             rc_image = draw_bboxes_on_image(
110 |                 rows, rc_image, labels=row_labels, label_font_size=20, color="blue"
111 |             )
112 |             rc_image = draw_bboxes_on_image(
113 |                 cols, rc_image, labels=col_labels, label_font_size=20, color="red"
114 |             )
115 |             rc_image.save(
116 |                 os.path.join(
117 |                     loader.result_path, f"{orig_name}_page{pnum + 1}_table{table_idx}_rc.png"
118 |                 )
119 |             )
120 | 
121 |             cell_image = copy.deepcopy(table_img)
122 |             cell_image = draw_bboxes_on_image(cells, cell_image, color="green")
123 |             cell_image.save(
124 |                 os.path.join(
125 |                     loader.result_path,
126 |                     f"{orig_name}_page{pnum + 1}_table{table_idx}_cells.png",
127 |                 )
128 |             )
129 | 
130 |     with open(
131 |         os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8"
132 |     ) as f:
133 |         json.dump(table_predictions, f, ensure_ascii=False)
134 | 
135 |     logger.info(f"Wrote results to {loader.result_path}")
136 | 
```
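
A minimal sketch (not from the repository) of running `TableRecPredictor` directly on an already-cropped table image, skipping the layout step the CLI above performs; `"table.png"` is a placeholder path.

```python
from PIL import Image

from surya.table_rec import TableRecPredictor

table_rec_predictor = TableRecPredictor()

table_img = Image.open("table.png").convert("RGB")  # placeholder cropped table
pred = table_rec_predictor([table_img])[0]

print(len(pred.rows), "rows,", len(pred.cols), "columns,", len(pred.cells), "cells")
for row in pred.rows:
    print("Row", row.row_id, row.bbox)
```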

--------------------------------------------------------------------------------
/CLA.md:
--------------------------------------------------------------------------------

```markdown
 1 | Surya Contributor Agreement
 2 | 
 3 | This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below. 
 4 | 
 5 | If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement.
 6 | 
 7 | 1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project. 
 8 | 2. With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution: 
 9 |    - you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty free, unrestricted license to exercise all rights under those copyrights. This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers; 
 10 |    - you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made) will be the sole owner of that derivative work;
11 |    - you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees; 
12 |    - you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and 
 13 |    - you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of your contribution. 
14 | 3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to:
15 |    - make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and
16 |    - at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements. 
17 | If you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed.
18 | 4. Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license. 
19 | 5. You covenant, represent, warrant and agree that: 
20 |    - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA; 
21 |    - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and 
22 |    - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws.
23 | You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA. 
24 | 6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply.
```

--------------------------------------------------------------------------------
/surya/scripts/texify_app.py:
--------------------------------------------------------------------------------

```python
  1 | import os
  2 | import re
  3 | from typing import List
  4 | 
  5 | from surya.recognition import RecognitionPredictor
  6 | from surya.foundation import FoundationPredictor
  7 | from surya.common.surya.schema import TaskNames
  8 | 
  9 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = (
 10 |     "1"  # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS
 11 | )
 12 | 
 13 | import io
 14 | 
 15 | import pandas as pd
 16 | import streamlit as st
 17 | from streamlit_drawable_canvas import st_canvas
 18 | import hashlib
 19 | import pypdfium2
 20 | 
 21 | from surya.settings import settings
 22 | from PIL import Image
 23 | 
 24 | MAX_WIDTH = 800
 25 | MAX_HEIGHT = 1000
 26 | 
 27 | 
 28 | def replace_fences(text):
 29 |     text = re.sub(r'<math display="block">(.*?)</math>', r"$$\1$$", text)
 30 |     text = re.sub(r"<math>(.*?)</math>", r"$\1$", text)
 31 |     text = re.sub(r'<math display="inline">(.*?)</math>', r"$\1$", text)
 32 |     return text
 33 | 
 34 | 
 35 | @st.cache_resource()
 36 | def load_predictor():
 37 |     foundation_predictor = FoundationPredictor()
 38 |     return RecognitionPredictor(foundation_predictor)
 39 | 
 40 | 
 41 | @st.cache_data()
 42 | def inference(pil_image: Image.Image, bbox: List[float]):
 43 |     input_img = pil_image.crop(bbox)
 44 |     bbox = [0, 0, input_img.width, input_img.height]
 45 |     model_output = predictor(
 46 |         [input_img], [TaskNames.block_without_boxes], bboxes=[[bbox]]
 47 |     )
 48 |     return model_output[0].text_lines[0].text
 49 | 
 50 | 
 51 | def open_pdf(pdf_file):
 52 |     stream = io.BytesIO(pdf_file.getvalue())
 53 |     return pypdfium2.PdfDocument(stream)
 54 | 
 55 | 
 56 | @st.cache_data()
 57 | def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI_HIGHRES):
 58 |     doc = open_pdf(pdf_file)
 59 |     renderer = doc.render(
 60 |         pypdfium2.PdfBitmap.to_pil,
 61 |         page_indices=[page_num - 1],
 62 |         scale=dpi / 72,
 63 |     )
 64 |     png = list(renderer)[0]
 65 |     png_image = png.convert("RGB")
 66 |     doc.close()
 67 |     return png_image
 68 | 
 69 | 
 70 | @st.cache_data()
 71 | def page_counter(pdf_file):
 72 |     doc = open_pdf(pdf_file)
 73 |     doc_len = len(doc)
 74 |     doc.close()
 75 |     return doc_len
 76 | 
 77 | 
 78 | def resize_image(pil_image):
 79 |     if pil_image is None:
 80 |         return
 81 |     pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
 82 | 
 83 | 
 84 | def get_canvas_hash(pil_image):
 85 |     return hashlib.md5(pil_image.tobytes()).hexdigest()
 86 | 
 87 | 
 88 | st.set_page_config(layout="wide")
 89 | 
 90 | top_message = """### LaTeX OCR
 91 | 
 92 | After the model loads, upload an image or a pdf, then draw a box around the equation or text you want to OCR by clicking and dragging. Surya will convert it to Markdown with LaTeX math on the right.
 93 | """
 94 | 
 95 | st.markdown(top_message)
 96 | col1, col2 = st.columns([0.7, 0.3])
 97 | 
 98 | predictor = load_predictor()
 99 | 
100 | in_file = st.sidebar.file_uploader(
101 |     "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]
102 | )
103 | if in_file is None:
104 |     st.stop()
105 | 
106 | if in_file is None:
107 |     st.stop()
108 | 
109 | filetype = in_file.type
110 | page_count = None
111 | if "pdf" in filetype:
112 |     page_count = page_counter(in_file)
113 |     page_number = st.sidebar.number_input(
114 |         f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count
115 |     )
116 |     pil_image = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
117 | else:
118 |     pil_image = Image.open(in_file).convert("RGB")
119 |     page_number = None
120 | 
121 | if pil_image is None:
122 |     st.stop()
123 | 
124 | pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
125 | canvas_hash = get_canvas_hash(pil_image)
126 | 
127 | with col1:
128 |     # Create a canvas component
129 |     canvas_result = st_canvas(
130 |         fill_color="rgba(255, 165, 0, 0.1)",  # Fixed fill color with some opacity
131 |         stroke_width=1,
132 |         stroke_color="#FFAA00",
133 |         background_color="#FFF",
134 |         background_image=pil_image,
135 |         update_streamlit=True,
136 |         height=pil_image.height,
137 |         width=pil_image.width,
138 |         drawing_mode="rect",
139 |         point_display_radius=0,
140 |         key=canvas_hash,
141 |     )
142 | 
143 | if not canvas_result.json_data:
144 |     st.stop()
145 | 
146 | objects = pd.json_normalize(
147 |     canvas_result.json_data["objects"]
148 | )  # need to convert obj to str because PyArrow
149 | bbox_list = None
150 | if objects.shape[0] > 0:
151 |     boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]]
152 |     boxes["right"] = boxes["left"] + boxes["width"]
153 |     boxes["bottom"] = boxes["top"] + boxes["height"]
154 |     bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist()
155 | 
156 | if bbox_list:
157 |     with col2:
158 |         texts = [inference(pil_image, bbox) for bbox in bbox_list]
159 |         for idx, latex in enumerate(reversed(texts)):
160 |             st.markdown(f"### {len(texts) - idx}")
161 |             st.markdown(replace_fences(latex), unsafe_allow_html=True)
162 |             st.code(latex)
163 |             st.divider()
164 | 
165 | with col2:
166 |     tips = """
167 |     ### Usage tips
168 |     - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple.
169 |     """
170 |     st.markdown(tips)
171 | 
```
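
A minimal sketch (not from the repository) of the same block-without-boxes recognition call that the app's `inference()` helper makes, outside of Streamlit; `"equation.png"` is a placeholder equation crop.

```python
from PIL import Image

from surya.common.surya.schema import TaskNames
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor

predictor = RecognitionPredictor(FoundationPredictor())

crop = Image.open("equation.png").convert("RGB")  # placeholder equation crop
bbox = [0, 0, crop.width, crop.height]

result = predictor([crop], [TaskNames.block_without_boxes], bboxes=[[bbox]])
print(result[0].text_lines[0].text)  # Markdown with <math> fences
```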

--------------------------------------------------------------------------------
/surya/scripts/finetune_ocr.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | from dataclasses import dataclass, field
  3 | from typing import Optional, Tuple
  4 | from datasets import load_dataset
  5 | import numpy as np
  6 | import torch
  7 | from transformers import (
  8 |     HfArgumentParser,
  9 |     TrainingArguments,
 10 |     Trainer,
 11 | )
 12 | 
 13 | from surya.common.surya import SuryaModel
 14 | from surya.common.surya.processor import SuryaOCRProcessor
 15 | from surya.foundation import FoundationPredictor
 16 | from surya.common.surya.processor.schema import ImageInput, TextInput
 17 | from surya.common.surya.schema import TaskNames
 18 | from surya.common.util import get_top_scripts, SCRIPT_TOKEN_MAPPING
 19 | 
 20 | # Do not change these defaults
 21 | OCR_TASK_NAME = TaskNames.ocr_with_boxes
 22 | OCR_MAX_IMAGE_SIZE = (1024, 512)
 23 | 
 24 | # Simple wrapper for huggingface dataset
 25 | class SuryaOCRDataset(torch.utils.data.Dataset):
 26 |     def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments):
 27 |         super().__init__()
 28 |         self.hf_dataset = load_dataset(data_args.dataset_name, num_proc=data_args.num_loading_proc, split="train")
 29 |         self.processor = processor
 30 | 
 31 |     def __len__(self):
 32 |         return len(self.hf_dataset)
 33 | 
 34 |     def get_script_text(self, text: str) -> str:
 35 |         scripts = get_top_scripts(text)
 36 |         script_text = "".join(SCRIPT_TOKEN_MAPPING[script] for script in scripts)
 37 |         return script_text
 38 | 
 39 |     def __getitem__(self, index):
 40 |         try:
 41 |             data = self.hf_dataset[index]
 42 |             image = data["image"]
 43 |             image = image.convert("RGB")
 44 |             image = np.asarray(image, dtype=np.float32)
 45 |             image = self.processor.scale_to_fit(image, max_size=OCR_MAX_IMAGE_SIZE)
 46 | 
 47 |             # Add in script information
 48 |             gt_text = data["text"]
 49 |             gt_text = self.get_script_text(gt_text) + gt_text
 50 | 
 51 |             return_dict = {
 52 |                 "task": TaskNames.ocr_with_boxes,
 53 |                 "inputs": [
 54 |                     ImageInput(type="image", image=image, rotated=False),
 55 |                     # This empty TextInput **must be included** to match the original format
 56 |                     TextInput(type="text", text=""),
 57 |                     TextInput(type="text",text=gt_text),
 58 |                 ],
 59 |             }
 60 |             return return_dict
 61 |         except:
 62 |             import traceback; traceback.print_exc()
 63 |             return self.__getitem__((index + 1) % self.__len__())
 64 | 
 65 | class SuryaOCRDataCollator:
 66 |     def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments):
 67 |         self.processor = processor
 68 |         self.max_sequence_length = data_args.max_sequence_length
 69 | 
 70 |     def __call__(self, inputs):
 71 |         # Use right padding for training. Defaults to left for inference
 72 |         processed_batch = self.processor(inputs, padding_side="right")
 73 |         
 74 |         if self.max_sequence_length is not None:
 75 |             processed_batch["input_ids"] = processed_batch["input_ids"][:, :self.max_sequence_length]
 76 |             processed_batch["attention_mask"] = processed_batch["attention_mask"][:, :self.max_sequence_length]
 77 |             processed_batch["position_ids"] = processed_batch["position_ids"][:, :self.max_sequence_length]
 78 | 
 79 |         lm_labels = processed_batch["input_ids"].clone()
 80 |         skip_label_mask = (
 81 |             (lm_labels == self.processor.pad_token_id )
 82 |             | (lm_labels == self.processor.bos_token_id[TaskNames.ocr_with_boxes])
 83 |             | (lm_labels == self.processor.eoi_token_id)
 84 |             | (lm_labels == self.processor.image_token_id)
 85 |         )
 86 |         lm_labels[skip_label_mask] = -100
 87 |         processed_batch["labels"] = lm_labels
 88 | 
 89 |         return processed_batch
 90 | 
 91 | def load_model_and_processor(checkpoint_path: Optional[str] = None) -> Tuple[SuryaModel, SuryaOCRProcessor]:
 92 |     foundation_predictor = FoundationPredictor(checkpoint=checkpoint_path)
 93 |     return foundation_predictor.model, foundation_predictor.processor
 94 | 
 95 | @dataclass
 96 | class SuryaOCRModelArguments:
 97 |     pretrained_checkpoint_path: Optional[str] = field(default=None)
 98 | 
 99 | @dataclass
100 | class SuryaOCRDataArguments:
101 |     dataset_name: str = field(default="datalab-to/ocr_finetune_example")
102 |     num_loading_proc: int = field(default=16)
103 |     max_sequence_length: Optional[int] = field(default=None)
104 | 
105 | @dataclass
106 | class SuryaOCRTrainingArguments(TrainingArguments):
107 |     remove_unused_columns: bool = field(default=False)
108 |     
109 | def main():
110 |     parser = HfArgumentParser((SuryaOCRModelArguments, SuryaOCRDataArguments, SuryaOCRTrainingArguments))
111 |     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
112 | 
113 |     model, processor = load_model_and_processor(model_args.pretrained_checkpoint_path)
114 |     dataset = SuryaOCRDataset(processor, data_args)
115 |     collator = SuryaOCRDataCollator(processor, data_args)
116 | 
117 |     trainer = Trainer(
118 |         model=model,
119 |         args=training_args,
120 |         train_dataset=dataset,
121 |         data_collator=collator
122 |     )
123 | 
124 |     trainer.train()
125 | 
126 | if __name__ == "__main__":
127 |     main()
```
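
A minimal, self-contained sketch of the label-masking rule used in `SuryaOCRDataCollator` above. The token ids below are invented for illustration; the real ids come from `SuryaOCRProcessor`.

```python
# Toy illustration of the collator's masking: pad / BOS / end-of-image / image
# tokens are set to -100 so the cross-entropy loss ignores them.
import torch

PAD_ID, BOS_ID, EOI_ID, IMAGE_ID = 0, 1, 2, 3  # hypothetical special-token ids

input_ids = torch.tensor([[BOS_ID, IMAGE_ID, IMAGE_ID, EOI_ID, 17, 42, PAD_ID]])
labels = input_ids.clone()
skip_label_mask = (
    (labels == PAD_ID)
    | (labels == BOS_ID)
    | (labels == EOI_ID)
    | (labels == IMAGE_ID)
)
labels[skip_label_mask] = -100  # ignore_index for torch.nn.CrossEntropyLoss
print(labels)  # tensor([[-100, -100, -100, -100,   17,   42, -100]])
```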

--------------------------------------------------------------------------------
/benchmark/utils/tesseract.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import List, Optional
  2 | 
  3 | import numpy as np
  4 | from tqdm import tqdm
  5 | 
  6 | from surya.input.processing import slice_bboxes_from_image
  7 | from surya.settings import settings
  8 | import os
  9 | from concurrent.futures import ProcessPoolExecutor
 10 | from surya.recognition.languages import CODE_TO_LANGUAGE
 11 | from surya.recognition import RecognitionPredictor
 12 | from surya.detection import DetectionPredictor
 13 | 
 14 | 
 15 | def surya_lang_to_tesseract(code: str) -> Optional[str]:
 16 |     lang_str = CODE_TO_LANGUAGE[code]
 17 |     try:
 18 |         tess_lang = TESS_LANGUAGE_TO_CODE[lang_str]
 19 |     except KeyError:
 20 |         return None
 21 |     return tess_lang
 22 | 
 23 | 
 24 | def tesseract_ocr(img, bboxes, lang: str):
 25 |     import pytesseract
 26 |     line_imgs = slice_bboxes_from_image(img, bboxes)
 27 |     config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
 28 |     lines = []
 29 |     for line_img in line_imgs:
 30 |         line = pytesseract.image_to_string(line_img, lang=lang, config=config)
 31 |         lines.append(line)
 32 |     return lines
 33 | 
 34 | 
 35 | def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
 36 |     tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size())
 37 |     if not cpus:
 38 |         cpus = os.cpu_count()
 39 |     tess_parallel_cores = min(tess_parallel_cores, cpus)
 40 | 
 41 |     # Tesseract uses up to 4 threads per instance
 42 |     # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images
 43 |     tess_parallel = max(tess_parallel_cores // 2, 1)
 44 | 
 45 |     with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
 46 |         tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR")
 47 |         tess_text = list(tess_text)
 48 |     return tess_text
 49 | 
 50 | 
 51 | def tesseract_bboxes(img):
 52 |     import pytesseract
 53 |     from pytesseract import Output
 54 |     arr_img = np.asarray(img, dtype=np.uint8)
 55 |     ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT)
 56 | 
 57 |     bboxes = []
 58 |     n_boxes = len(ocr['level'])
 59 |     for i in range(n_boxes):
 60 |         # It is possible to merge by line here with line number, but it gives bad results.
 61 |         _, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i]
 62 |         bbox = (x, y, x + w, y + h)
 63 |         bboxes.append(bbox)
 64 | 
 65 |     return bboxes
 66 | 
 67 | 
 68 | def tesseract_parallel(imgs):
 69 |     # Tesseract uses 4 threads per instance
 70 |     tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size())
 71 |     cpus = os.cpu_count()
 72 |     tess_parallel_cores = min(tess_parallel_cores, cpus)
 73 | 
 74 |     # Each tesseract instance uses up to 4 threads, so divide the core count by 4
 75 |     tess_parallel = max(tess_parallel_cores // 4, 1)
 76 | 
 77 |     with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
 78 |         tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection")
 79 |         tess_bboxes = list(tess_bboxes)
 80 |     return tess_bboxes
 81 | 
 82 | 
 83 | TESS_CODE_TO_LANGUAGE = {
 84 |     "afr": "Afrikaans",
 85 |     "amh": "Amharic",
 86 |     "ara": "Arabic",
 87 |     "asm": "Assamese",
 88 |     "aze": "Azerbaijani",
 89 |     "bel": "Belarusian",
 90 |     "ben": "Bengali",
 91 |     "bod": "Tibetan",
 92 |     "bos": "Bosnian",
 93 |     "bre": "Breton",
 94 |     "bul": "Bulgarian",
 95 |     "cat": "Catalan",
 96 |     "ceb": "Cebuano",
 97 |     "ces": "Czech",
 98 |     "chi_sim": "Chinese",
 99 |     "chr": "Cherokee",
100 |     "cym": "Welsh",
101 |     "dan": "Danish",
102 |     "deu": "German",
103 |     "dzo": "Dzongkha",
104 |     "ell": "Greek",
105 |     "eng": "English",
106 |     "epo": "Esperanto",
107 |     "est": "Estonian",
108 |     "eus": "Basque",
109 |     "fas": "Persian",
110 |     "fin": "Finnish",
111 |     "fra": "French",
112 |     "fry": "Western Frisian",
113 |     "guj": "Gujarati",
114 |     "gla": "Scottish Gaelic",
115 |     "gle": "Irish",
116 |     "glg": "Galician",
117 |     "heb": "Hebrew",
118 |     "hin": "Hindi",
119 |     "hrv": "Croatian",
120 |     "hun": "Hungarian",
121 |     "hye": "Armenian",
122 |     "iku": "Inuktitut",
123 |     "ind": "Indonesian",
124 |     "isl": "Icelandic",
125 |     "ita": "Italian",
126 |     "jav": "Javanese",
127 |     "jpn": "Japanese",
128 |     "kan": "Kannada",
129 |     "kat": "Georgian",
130 |     "kaz": "Kazakh",
131 |     "khm": "Khmer",
132 |     "kir": "Kyrgyz",
133 |     "kor": "Korean",
134 |     "lao": "Lao",
135 |     "lat": "Latin",
136 |     "lav": "Latvian",
137 |     "lit": "Lithuanian",
138 |     "mal": "Malayalam",
139 |     "mar": "Marathi",
140 |     "mkd": "Macedonian",
141 |     "mlt": "Maltese",
142 |     "mon": "Mongolian",
143 |     "msa": "Malay",
144 |     "mya": "Burmese",
145 |     "nep": "Nepali",
146 |     "nld": "Dutch",
147 |     "nor": "Norwegian",
148 |     "ori": "Oriya",
149 |     "pan": "Punjabi",
150 |     "pol": "Polish",
151 |     "por": "Portuguese",
152 |     "pus": "Pashto",
153 |     "ron": "Romanian",
154 |     "rus": "Russian",
155 |     "san": "Sanskrit",
156 |     "sin": "Sinhala",
157 |     "slk": "Slovak",
158 |     "slv": "Slovenian",
159 |     "snd": "Sindhi",
160 |     "spa": "Spanish",
161 |     "sqi": "Albanian",
162 |     "srp": "Serbian",
163 |     "swa": "Swahili",
164 |     "swe": "Swedish",
165 |     "syr": "Syriac",
166 |     "tam": "Tamil",
167 |     "tel": "Telugu",
168 |     "tgk": "Tajik",
169 |     "tha": "Thai",
170 |     "tir": "Tigrinya",
171 |     "tur": "Turkish",
172 |     "uig": "Uyghur",
173 |     "ukr": "Ukrainian",
174 |     "urd": "Urdu",
175 |     "uzb": "Uzbek",
176 |     "vie": "Vietnamese",
177 |     "yid": "Yiddish"
178 | }
179 | 
180 | TESS_LANGUAGE_TO_CODE = {v: k for k, v in TESS_CODE_TO_LANGUAGE.items()}
181 | 
```
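
A hedged usage sketch for the helpers above, assuming `pytesseract` is installed and `settings.TESSDATA_PREFIX` points at a valid tessdata directory; the image path and line boxes are placeholders.

```python
from PIL import Image

from benchmark.utils.tesseract import surya_lang_to_tesseract, tesseract_ocr_parallel

image = Image.open("page.png").convert("RGB")            # placeholder input image
line_bboxes = [[50, 40, 600, 80], [50, 100, 600, 140]]   # placeholder line boxes (x1, y1, x2, y2)

tess_lang = surya_lang_to_tesseract("en")                # -> "eng"; returns None if unmapped
lines = tesseract_ocr_parallel([image], [line_bboxes], [tess_lang])
print(lines[0])                                          # recognized strings, one per box
```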

--------------------------------------------------------------------------------
/benchmark/layout.py:
--------------------------------------------------------------------------------

```python
  1 | import collections
  2 | import copy
  3 | import json
  4 | 
  5 | import click
  6 | 
  7 | from benchmark.utils.metrics import precision_recall
  8 | from surya.foundation import FoundationPredictor
  9 | from surya.layout import LayoutPredictor
 10 | from surya.input.processing import convert_if_not_rgb
 11 | from surya.debug.draw import draw_bboxes_on_image
 12 | from surya.settings import settings
 13 | import os
 14 | import time
 15 | from tabulate import tabulate
 16 | import datasets
 17 | 
 18 | 
 19 | @click.command(help="Benchmark surya layout model.")
 20 | @click.option(
 21 |     "--results_dir",
 22 |     type=str,
 23 |     help="Path to JSON file with OCR results.",
 24 |     default=os.path.join(settings.RESULT_DIR, "benchmark"),
 25 | )
 26 | @click.option(
 27 |     "--max_rows",
 28 |     type=int,
 29 |     help="Maximum number of images to run benchmark on.",
 30 |     default=100,
 31 | )
 32 | @click.option("--debug", is_flag=True, help="Run in debug mode.", default=False)
 33 | def main(results_dir: str, max_rows: int, debug: bool):
 34 |     foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
 35 |     layout_predictor = LayoutPredictor(foundation_predictor)
 36 | 
 37 |     pathname = "layout_bench"
 38 |     # These have already been shuffled randomly, so sampling from the start is fine
 39 |     dataset = datasets.load_dataset(
 40 |         settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]"
 41 |     )
 42 |     images = list(dataset["image"])
 43 |     images = convert_if_not_rgb(images)
 44 | 
 45 |     if settings.LAYOUT_STATIC_CACHE:
 46 |         layout_predictor(images[:1])
 47 | 
 48 |     start = time.time()
 49 |     layout_predictions = layout_predictor(images)
 50 |     surya_time = time.time() - start
 51 | 
 52 |     folder_name = os.path.basename(pathname).split(".")[0]
 53 |     result_path = os.path.join(results_dir, folder_name)
 54 |     os.makedirs(result_path, exist_ok=True)
 55 | 
 56 |     label_alignment = {  # First is publaynet, second is surya
 57 |         "Image": [["Figure"], ["Picture", "Figure"]],
 58 |         "Table": [["Table"], ["Table", "Form", "TableOfContents"]],
 59 |         "Text": [
 60 |             ["Text"],
 61 |             [
 62 |                 "Text",
 63 |                 "Formula",
 64 |                 "Footnote",
 65 |                 "Caption",
 66 |                 "TextInlineMath",
 67 |                 "Code",
 68 |                 "Handwriting",
 69 |             ],
 70 |         ],
 71 |         "List": [["List"], ["ListItem"]],
 72 |         "Title": [["Title"], ["SectionHeader", "Title"]],
 73 |     }
 74 | 
 75 |     page_metrics = collections.OrderedDict()
 76 |     for idx, pred in enumerate(layout_predictions):
 77 |         row = dataset[idx]
 78 |         all_correct_bboxes = []
 79 |         page_results = {}
 80 |         for label_name in label_alignment:
 81 |             correct_cats, surya_cats = label_alignment[label_name]
 82 |             correct_bboxes = [
 83 |                 b
 84 |                 for b, category in zip(row["bboxes"], row["labels"])
 85 |                 if category in correct_cats
 86 |             ]
 87 |             all_correct_bboxes.extend(correct_bboxes)
 88 |             pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats]
 89 | 
 90 |             metrics = precision_recall(
 91 |                 pred_bboxes, correct_bboxes, penalize_double=False
 92 |             )
 93 |             weight = len(correct_bboxes)
 94 |             metrics["weight"] = weight
 95 |             page_results[label_name] = metrics
 96 | 
 97 |         page_metrics[idx] = page_results
 98 | 
 99 |         if debug:
100 |             bbox_image = draw_bboxes_on_image(
101 |                 all_correct_bboxes, copy.deepcopy(images[idx])
102 |             )
103 |             bbox_image.save(os.path.join(result_path, f"{idx}_layout.png"))
104 | 
105 |     mean_metrics = collections.defaultdict(dict)
106 |     layout_types = sorted(page_metrics[0].keys())
107 |     metric_types = sorted(page_metrics[0][layout_types[0]].keys())
108 |     metric_types.remove("weight")
109 |     for label in layout_types:
110 |         for m in metric_types:
111 |             metric = []
112 |             total = 0
113 |             for page in page_metrics:
114 |                 metric.append(
115 |                     page_metrics[page][label][m] * page_metrics[page][label]["weight"]
116 |                 )
117 |                 total += page_metrics[page][label]["weight"]
118 | 
119 |             value = sum(metric)
120 |             if value > 0:
121 |                 value /= total
122 |             mean_metrics[label][m] = value
123 | 
124 |     out_data = {
125 |         "time": surya_time,
126 |         "metrics": mean_metrics,
127 |         "page_metrics": page_metrics,
128 |     }
129 | 
130 |     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
131 |         json.dump(out_data, f, indent=4)
132 | 
133 |     table_headers = [
134 |         "Layout Type",
135 |     ] + metric_types
136 |     table_data = []
137 |     for layout_type in layout_types:
138 |         table_data.append(
139 |             [
140 |                 layout_type,
141 |             ]
142 |             + [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types]
143 |         )
144 | 
145 |     print(tabulate(table_data, headers=table_headers, tablefmt="github"))
146 |     print(
147 |         f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total."
148 |     )
149 |     print(
150 |         "Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold."
151 |     )
152 |     print(f"Wrote results to {result_path}")
153 | 
154 | 
155 | if __name__ == "__main__":
156 |     main()
157 | 
```
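
The aggregation above weights each page's precision/recall by the number of ground-truth boxes of that label on the page. A small sketch of that weighted mean, with invented numbers:

```python
# Weighted mean over pages: pages with more ground-truth boxes of a label count more.
page_metrics = {
    0: {"Text": {"precision": 0.9, "recall": 0.8, "weight": 10}},
    1: {"Text": {"precision": 0.5, "recall": 0.6, "weight": 2}},
}

total_weight = sum(page["Text"]["weight"] for page in page_metrics.values())
weighted_precision = sum(
    page["Text"]["precision"] * page["Text"]["weight"] for page in page_metrics.values()
) / total_weight
print(round(weighted_precision, 4))  # 0.8333 -- dominated by the page with more boxes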

--------------------------------------------------------------------------------
/surya/foundation/loader.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Optional
  2 | 
  3 | import torch
  4 | from transformers.utils import is_flash_attn_2_available
  5 | 
  6 | from surya.common.load import ModelLoader
  7 | from surya.common.surya.config import SuryaModelConfig
  8 | from surya.common.surya import SuryaModel, SuryaXLAModel
  9 | from surya.common.surya.processor import SuryaOCRProcessor
 10 | from surya.common.surya.processor.tokenizer import SuryaOCRTokenizer
 11 | from surya.common.util import is_flash_attn_2_supported
 12 | from surya.common.xla import get_compile_args
 13 | from surya.logging import get_logger
 14 | from surya.settings import settings
 15 | 
 16 | logger = get_logger()
 17 | 
 18 | 
 19 | class FoundationModelLoader(ModelLoader):
 20 |     def __init__(self, checkpoint: Optional[str] = None):
 21 |         super().__init__(checkpoint)
 22 | 
 23 |         if self.checkpoint is None:
 24 |             self.checkpoint = settings.FOUNDATION_MODEL_CHECKPOINT
 25 | 
 26 |     def model(
 27 |         self,
 28 |         device=settings.TORCH_DEVICE_MODEL,
 29 |         dtype=None,
 30 |         attention_implementation: Optional[str] = None,
 31 |     ) -> SuryaModel:
 32 |         if device is None:
 33 |             device = settings.TORCH_DEVICE_MODEL
 34 |         if dtype is None:
 35 |             # See https://github.com/pytorch/pytorch/issues/118122 - T4 (compute capability 7.5) will return True since it supports
 36 |             # emulated bf16, but falls back to very slow kernels, especially for SDPA
 37 |             dtype = settings.MODEL_DTYPE_BFLOAT
 38 |             if device == "cuda" and not torch.cuda.is_bf16_supported(
 39 |                 including_emulation=False
 40 |             ):
 41 |                 # If the device is cuda, we check if bf16 is supported, and if not, we use float16
 42 |                 dtype = settings.MODEL_DTYPE
 43 |         elif dtype == torch.float16:
 44 |             dtype = torch.bfloat16  # Model weights in bfloat16
 45 | 
 46 |         config = SuryaModelConfig.from_pretrained(self.checkpoint)
 47 | 
 48 |         if attention_implementation is not None:
 49 |             config.decoder._attn_implementation = attention_implementation
 50 |             config.vision_encoder._attn_implementation = attention_implementation
 51 |         elif is_flash_attn_2_available() and is_flash_attn_2_supported(device):
 52 |             config.decoder._attn_implementation = "flash_attention_2"
 53 |             config.vision_encoder._attn_implementation = "flash_attention_2"
 54 |         elif device == "xla":
 55 |             config.decoder._attn_implementation = "sdpa"
 56 |             config.vision_encoder._attn_implementation = "sdpa"
 57 |         else:
 58 |             config.decoder._attn_implementation = "sdpa"
 59 |             config.vision_encoder._attn_implementation = "sdpa"
 60 | 
 61 |         model_cls = SuryaModel
 62 |         if device == "xla":
 63 |             model_cls = SuryaXLAModel
 64 | 
 65 |         config._attn_implementation_autoset = True
 66 |         config.vision_encoder._attn_implementation_autoset = True
 67 |         config.decoder._attn_implementation_autoset = True
 68 | 
 69 |         model = model_cls.from_pretrained(
 70 |             self.checkpoint, dtype=dtype, config=config, ignore_mismatched_sizes=True
 71 |         ).to(device)
 72 |         model = model.eval()
 73 | 
 74 |         if settings.COMPILE_ALL or settings.COMPILE_FOUNDATION:
 75 |             torch._dynamo.config.cache_size_limit = 1000
 76 |             torch._dynamo.config.suppress_errors = True
 77 |             torch._dynamo.config.specialize_int = False
 78 |             torch._dynamo.config.allow_unspec_int_on_nn_module = True
 79 |             torch._dynamo.config.capture_scalar_outputs = True
 80 |             torch._dynamo.config.recompile_limit = 32
 81 | 
 82 |             logger.info(
 83 |                 f"Compiling foundation model {self.checkpoint} on device {device} with dtype {dtype}"
 84 |             )
 85 |             compile_args = get_compile_args(device)
 86 |             model.vision_encoder = torch.compile(model.vision_encoder, **compile_args)
 87 |             model.decoder = torch.compile(model.decoder, **compile_args)
 88 | 
 89 |         logger.debug(
 90 |             f"Loaded foundation model {self.checkpoint} from {SuryaModel.get_local_path(self.checkpoint)} onto device {model.device} with dtype {dtype}, using decoder attention mechanism {model.config.decoder._attn_implementation}, encoder attention mechanism {model.config.vision_encoder._attn_implementation}."
 91 |         )
 92 |         return model
 93 | 
 94 |     def processor(
 95 |         self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE_BFLOAT
 96 |     ) -> SuryaOCRProcessor:
 97 |         config: SuryaModelConfig = SuryaModelConfig.from_pretrained(self.checkpoint)
 98 | 
 99 |         ocr_tokenizer = SuryaOCRTokenizer(
100 |             special_tokens=config.special_ocr_tokens, model_checkpoint=self.checkpoint
101 |         )
102 | 
103 |         processor = SuryaOCRProcessor(
104 |             ocr_tokenizer=ocr_tokenizer,
105 |             blank_bbox_token_id=config.blank_bbox_token_id,
106 |             num_register_tokens=config.num_register_tokens,
107 |             sequence_length=None,
108 |             patch_size=config.vision_encoder.patch_size,
109 |             merge_size=config.vision_encoder.spatial_merge_size,
110 |             model_device=device,
111 |             num_beacon_tokens=config.num_beacon_tokens,
112 |             beacon_token_interval=config.beacon_token_interval,
113 |         )
114 | 
115 |         return processor
116 | 
```
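
A hedged usage sketch for `FoundationModelLoader`: with no arguments it resolves `settings.FOUNDATION_MODEL_CHECKPOINT`, so the first call may download weights.

```python
from surya.foundation.loader import FoundationModelLoader

loader = FoundationModelLoader()   # defaults to settings.FOUNDATION_MODEL_CHECKPOINT
model = loader.model()             # dtype and attention backend chosen per device
processor = loader.processor()
print(type(model).__name__, model.device)
```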

--------------------------------------------------------------------------------
/surya/table_rec/shaper.py:
--------------------------------------------------------------------------------

```python
  1 | import math
  2 | from typing import List, Dict
  3 | import numpy as np
  4 | 
  5 | from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM
  6 | 
  7 | 
  8 | class LabelShaper:
  9 |     def __init__(self):
 10 |         self.property_keys = [k for (k, kcount, mode) in BOX_PROPERTIES]
 11 | 
 12 |     def dict_to_labels(self, label_components: List[dict]):
 13 |         if len(label_components) == 0:
 14 |             return []
 15 | 
 16 |         out_list = []
 17 |         for (k, kcount, mode) in BOX_PROPERTIES:
 18 |             for label_component in label_components:
 19 |                 if k not in label_component:
 20 |                     raise ValueError(f"Missing key {k} in label component {label_component}")
 21 | 
 22 |                 if mode == "classification":
 23 |                     assert isinstance(label_component[k], int)
 24 |                 elif mode == "regression":
 25 |                     assert (isinstance(label_component[k], (int, float)) and kcount == 1) or len(label_component[k]) == kcount
 26 |                 else:
 27 |                     raise ValueError(f"Invalid mode {mode} for key {k}")
 28 | 
 29 |         for label_component in label_components:
 30 |             bbox = label_component["bbox"]
 31 |             for i in range(len(bbox)):
 32 |                 if bbox[i] < 0:
 33 |                     bbox[i] = 0
 34 |                 if bbox[i] > BOX_DIM:
 35 |                     bbox[i] = BOX_DIM
 36 | 
 37 |             vector = []
 38 |             for (k, kcount, mode) in BOX_PROPERTIES:
 39 |                 item = label_component[k]
 40 |                 if isinstance(item, (list, tuple)):
 41 |                     vector += list(item)
 42 |                 elif isinstance(item, (float, int)):
 43 |                     if mode == "classification":
 44 |                         # Shift up for model
 45 |                         item += SPECIAL_TOKENS
 46 |                     vector.append(item)
 47 |                 else:
 48 |                     raise ValueError(f"Invalid item {item} for key {k}")
 49 | 
 50 |             out_list.append(vector)
 51 | 
 52 |         return out_list
 53 | 
 54 |     def component_idx(self, key):
 55 |         idx = 0
 56 |         for (k, kcount, mode) in BOX_PROPERTIES:
 57 |             if mode == "regression":
 58 |                 incr = kcount
 59 |             elif mode == "classification":
 60 |                 incr = 1
 61 |             else:
 62 |                 raise ValueError(f"Invalid mode {mode} for key {k}")
 63 |             if k == key:
 64 |                 return (idx, idx + incr)
 65 |             idx += incr
 66 |         raise ValueError(f"Key {key} not found in properties")
 67 | 
 68 |     def get_box_property(self, key, add_special_tokens=True):
 69 |         for (k, kcount, mode) in BOX_PROPERTIES:
 70 |             if k == key:
 71 |                 # Add special token count
 72 |                 if mode == "classification" and add_special_tokens:
 73 |                     kcount += SPECIAL_TOKENS
 74 |                 return (k, kcount, mode)
 75 |         raise ValueError(f"Key {key} not found in properties")
 76 | 
 77 |     def component_idx_dict(self):
 78 |         idx_dict = {}
 79 |         for (k, kcount, mode) in BOX_PROPERTIES:
 80 |             idx_dict[k] = self.component_idx(k)
 81 |         return idx_dict
 82 | 
 83 |     def convert_polygons_to_bboxes(self, label_components: List[Dict]):
 84 |         for i, label_component in enumerate(label_components):
 85 |             poly = label_component["polygon"]
 86 |             poly = np.clip(poly, 0, BOX_DIM)
 87 | 
 88 |             (x1, y1), (x2, y2), (x3, y3), (x4, y4) = poly
 89 |             cx = (x1 + x2 + x3 + x4) / 4
 90 |             cy = (y1 + y2 + y3 + y4) / 4
 91 |             width = (x2 + x3) / 2 - (x1 + x4) / 2
 92 |             height = (y3 + y4) / 2 - (y2 + y1) / 2
 93 |             bottom_avg_x = (x3 + x4) / 2
 94 |             top_avg_x = (x1 + x2) / 2
 95 |             right_avg_y = (y2 + y3) / 2
 96 |             left_avg_y = (y1 + y4) / 2
 97 | 
 98 |             x_skew = bottom_avg_x - top_avg_x
 99 |             y_skew = right_avg_y - left_avg_y
100 |             x_skew += BOX_DIM // 2 # Shift up into positive space
101 |             y_skew += BOX_DIM // 2 # Shift up into positive space
102 |             new_poly = [
103 |                 cx,
104 |                 cy,
105 |                 width,
106 |                 height,
107 |                 x_skew,
108 |                 y_skew
109 |             ]
110 |             label_component["bbox"] = new_poly
111 | 
112 |         return label_components
113 | 
114 |     def convert_bbox_to_polygon(self, box, skew_scaler=BOX_DIM // 2, skew_min=.001):
115 |         cx = box[0]
116 |         cy = box[1]
117 |         width = box[2]
118 |         height = box[3]
119 |         x1 = cx - width / 2
120 |         y1 = cy - height / 2
121 |         x2 = cx + width / 2
122 |         y2 = cy + height / 2
123 |         skew_x = math.floor((box[4] - skew_scaler) / 2)
124 |         skew_y = math.floor((box[5] - skew_scaler) / 2)
125 | 
126 |         # Ensures we don't get slightly warped boxes
127 |         # Note that the values are later scaled, so this is in 1/1024 space
128 |         if abs(skew_x) < skew_min:
129 |             skew_x = 0
130 | 
131 |         if abs(skew_y) < skew_min:
132 |             skew_y = 0
133 | 
134 |         polygon = [x1 - skew_x, y1 - skew_y, x2 - skew_x, y1 + skew_y, x2 + skew_x, y2 + skew_y, x1 + skew_x,
135 |                    y2 - skew_y]
136 |         poly = []
137 |         for i in range(4):
138 |             poly.append([
139 |                 polygon[2 * i],
140 |                 polygon[2 * i + 1]
141 |             ])
142 |         return poly
143 | 
144 | 
145 | 
146 | 
```
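
A quick round-trip sketch for `LabelShaper` above: a quadrilateral is encoded as `(cx, cy, width, height, x_skew, y_skew)` with the skews shifted into positive space by `BOX_DIM // 2`, and decoded back with `convert_bbox_to_polygon`. The corner coordinates below are assumed to already be in the model's `0..BOX_DIM` space.

```python
from surya.table_rec.shaper import LabelShaper

shaper = LabelShaper()
component = {"polygon": [[100, 100], [300, 100], [300, 200], [100, 200]]}
(encoded,) = shaper.convert_polygons_to_bboxes([component])
print(encoded["bbox"])  # [cx, cy, width, height, x_skew + BOX_DIM // 2, y_skew + BOX_DIM // 2]
print(shaper.convert_bbox_to_polygon(encoded["bbox"]))  # recovers the original corners here
```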

--------------------------------------------------------------------------------
/benchmark/detection.py:
--------------------------------------------------------------------------------

```python
  1 | import argparse
  2 | import collections
  3 | import copy
  4 | import json
  5 | 
  6 | import click
  7 | 
  8 | from benchmark.utils.bbox import get_pdf_lines
  9 | from benchmark.utils.metrics import precision_recall
 10 | from benchmark.utils.tesseract import tesseract_parallel
 11 | from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
 12 | from surya.debug.draw import draw_polys_on_image
 13 | from surya.common.util import rescale_bbox
 14 | from surya.settings import settings
 15 | from surya.detection import DetectionPredictor
 16 | 
 17 | import os
 18 | import time
 19 | from tabulate import tabulate
 20 | import datasets
 21 | 
 22 | 
 23 | @click.command(help="Benchmark detection model.")
 24 | @click.option("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
 25 | @click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
 26 | @click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100)
 27 | @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
 28 | @click.option("--tesseract", is_flag=True, help="Run tesseract as well.", default=False)
 29 | def main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool):
 30 |     det_predictor = DetectionPredictor()
 31 | 
 32 |     if pdf_path is not None:
 33 |         pathname = pdf_path
 34 |         doc = open_pdf(pdf_path)
 35 |         page_count = len(doc)
 36 |         page_indices = list(range(page_count))
 37 |         page_indices = page_indices[:max_rows]
 38 | 
 39 |         images = get_page_images(doc, page_indices)
 40 |         doc.close()
 41 | 
 42 |         image_sizes = [img.size for img in images]
 43 |         correct_boxes = get_pdf_lines(pdf_path, image_sizes)
 44 |     else:
 45 |         pathname = "det_bench"
 46 |         # These have already been shuffled randomly, so sampling from the start is fine
 47 |         dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
 48 |         images = list(dataset["image"])
 49 |         images = convert_if_not_rgb(images)
 50 |         correct_boxes = []
 51 |         for i, boxes in enumerate(dataset["bboxes"]):
 52 |             img_size = images[i].size
 53 |             # 1000,1000 is bbox size for doclaynet
 54 |             correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])
 55 | 
 56 |     if settings.DETECTOR_STATIC_CACHE:
 57 |         # Run through one batch to compile the model
 58 |         det_predictor(images[:1])
 59 | 
 60 |     start = time.time()
 61 |     predictions = det_predictor(images)
 62 |     surya_time = time.time() - start
 63 | 
 64 |     if tesseract:
 65 |         start = time.time()
 66 |         tess_predictions = tesseract_parallel(images)
 67 |         tess_time = time.time() - start
 68 |     else:
 69 |         tess_predictions = [None] * len(images)
 70 |         tess_time = None
 71 | 
 72 |     folder_name = os.path.basename(pathname).split(".")[0]
 73 |     result_path = os.path.join(results_dir, folder_name)
 74 |     os.makedirs(result_path, exist_ok=True)
 75 | 
 76 |     page_metrics = collections.OrderedDict()
 77 |     for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
 78 |         surya_boxes = [s.bbox for s in sb.bboxes]
 79 |         surya_polys = [s.polygon for s in sb.bboxes]
 80 | 
 81 |         surya_metrics = precision_recall(surya_boxes, cb)
 82 |         if tb is not None:
 83 |             tess_metrics = precision_recall(tb, cb)
 84 |         else:
 85 |             tess_metrics = None
 86 | 
 87 |         page_metrics[idx] = {
 88 |             "surya": surya_metrics,
 89 |             "tesseract": tess_metrics
 90 |         }
 91 | 
 92 |         if debug:
 93 |             bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
 94 |             bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))
 95 | 
 96 |     mean_metrics = {}
 97 |     metric_types = sorted(page_metrics[0]["surya"].keys())
 98 |     models = ["surya"]
 99 |     if tesseract:
100 |         models.append("tesseract")
101 | 
102 |     for k in models:
103 |         for m in metric_types:
104 |             metric = []
105 |             for page in page_metrics:
106 |                 metric.append(page_metrics[page][k][m])
107 |             if k not in mean_metrics:
108 |                 mean_metrics[k] = {}
109 |             mean_metrics[k][m] = sum(metric) / len(metric)
110 | 
111 |     out_data = {
112 |         "times": {
113 |             "surya": surya_time,
114 |             "tesseract": tess_time
115 |         },
116 |         "metrics": mean_metrics,
117 |         "page_metrics": page_metrics
118 |     }
119 | 
120 |     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
121 |         json.dump(out_data, f, indent=4)
122 | 
123 |     table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
124 |     table_data = [
125 |         ["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
126 |     ]
127 |     if tesseract:
128 |         table_data.append(
129 |             ["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types]
130 |         )
131 | 
132 |     print(tabulate(table_data, headers=table_headers, tablefmt="github"))
133 |     print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.  There is a precision penalty for multiple boxes overlapping reference lines.")
134 |     print(f"Wrote results to {result_path}")
135 | 
136 | 
137 | if __name__ == "__main__":
138 |     main()
139 | 
```
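
A small sketch of the ground-truth rescaling step above: doclaynet-style boxes are stored in a 1000x1000 coordinate space and rescaled to each image's pixel size before being scored. The page size below is a placeholder.

```python
from surya.common.util import rescale_bbox

gt_box_1000 = [100, 200, 500, 250]   # box in the dataset's 1000x1000 space
image_size = (1654, 2339)            # placeholder page size in pixels (width, height)
print(rescale_bbox(gt_box_1000, (1000, 1000), image_size))
```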

--------------------------------------------------------------------------------
/surya/detection/heatmap.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import List
  2 | 
  3 | import cv2
  4 | import numpy as np
  5 | from PIL import Image
  6 | 
  7 | from surya.common.util import clean_boxes
  8 | from surya.detection import TextDetectionResult
  9 | from surya.common.polygon import PolygonBox
 10 | from surya.settings import settings
 11 | 
 12 | 
 13 | def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7):
 14 |     # Find average intensity of top 10% pixels
 15 |     flat_map = linemap.ravel()
 16 |     top_10_count = int(len(flat_map) * 0.9)
 17 |     avg_intensity = np.mean(np.partition(flat_map, top_10_count)[top_10_count:])
 18 |     scaling_factor = np.clip(avg_intensity / typical_top10_avg, 0, 1) ** (1 / 2)
 19 | 
 20 |     low_text = np.clip(low_text * scaling_factor, 0.1, 0.6)
 21 |     text_threshold = np.clip(text_threshold * scaling_factor, 0.15, 0.8)
 22 | 
 23 |     return text_threshold, low_text
 24 | 
 25 | 
 26 | def detect_boxes(linemap, text_threshold, low_text):
 27 |     # From CRAFT - https://github.com/clovaai/CRAFT-pytorch
 28 |     # Modified to return boxes and for speed, accuracy
 29 |     img_h, img_w = linemap.shape
 30 | 
 31 |     text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text)
 32 | 
 33 |     text_score_comb = (linemap > low_text).astype(np.uint8)
 34 |     label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(
 35 |         text_score_comb, connectivity=4
 36 |     )
 37 | 
 38 |     det = []
 39 |     confidences = []
 40 |     max_confidence = 0
 41 | 
 42 |     for k in range(1, label_count):
 43 |         # size filtering
 44 |         size = stats[k, cv2.CC_STAT_AREA]
 45 |         if size < 10:
 46 |             continue
 47 | 
 48 |         # make segmentation map
 49 |         x, y, w, h = stats[
 50 |             k,
 51 |             [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
 52 |         ]
 53 | 
 54 |         try:
 55 |             niter = int(np.sqrt(min(w, h)))
 56 |         except ValueError:
 57 |             niter = 0
 58 | 
 59 |         buffer = 1
 60 |         sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer)
 61 |         ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer)
 62 | 
 63 |         mask = labels[sy:ey, sx:ex] == k
 64 |         selected_linemap = linemap[sy:ey, sx:ex][mask]
 65 |         if selected_linemap.size == 0:
 66 |             continue
 67 | 
 68 |         line_max = np.max(selected_linemap)
 69 | 
 70 |         # thresholding
 71 |         if line_max < text_threshold:
 72 |             continue
 73 | 
 74 |         segmap = mask.astype(np.uint8)
 75 | 
 76 |         ksize = buffer + niter
 77 |         kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (ksize, ksize))
 78 |         selected_segmap = cv2.dilate(segmap, kernel)
 79 | 
 80 |         # make box
 81 |         y_inds, x_inds = np.nonzero(selected_segmap)
 82 |         x_inds += sx
 83 |         y_inds += sy
 84 |         np_contours = np.column_stack((x_inds, y_inds))
 85 |         rectangle = cv2.minAreaRect(np_contours)
 86 |         box = cv2.boxPoints(rectangle)
 87 | 
 88 |         # align diamond-shape
 89 |         w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
 90 |         box_ratio = max(w, h) / (min(w, h) + 1e-5)
 91 |         if abs(1 - box_ratio) <= 0.1:
 92 |             left, right = np_contours[:, 0].min(), np_contours[:, 0].max()
 93 |             top, bottom = np_contours[:, 1].min(), np_contours[:, 1].max()
 94 |             box = np.array(
 95 |                 [[left, top], [right, top], [right, bottom], [left, bottom]],
 96 |                 dtype=np.float32,
 97 |             )
 98 | 
 99 |         # make clock-wise order
100 |         startidx = box.sum(axis=1).argmin()
101 |         box = np.roll(box, 4 - startidx, 0)
102 | 
103 |         max_confidence = max(max_confidence, line_max)
104 | 
105 |         confidences.append(line_max)
106 |         det.append(box)
107 | 
108 |     if max_confidence > 0:
109 |         confidences = [c / max_confidence for c in confidences]
110 |     return det, confidences
111 | 
112 | 
113 | def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:
114 |     if text_threshold is None:
115 |         text_threshold = settings.DETECTOR_TEXT_THRESHOLD
116 |     if low_text is None:
117 |         low_text = settings.DETECTOR_BLANK_THRESHOLD
118 | 
119 |     if textmap.dtype != np.float32:
120 |         textmap = textmap.astype(np.float32)
121 | 
122 |     boxes, confidences = detect_boxes(textmap, text_threshold, low_text)
123 |     # From point form to box form
124 |     return [
125 |         PolygonBox(polygon=box, confidence=confidence)
126 |         for box, confidence in zip(boxes, confidences)
127 |     ]
128 | 
129 | 
130 | def get_and_clean_boxes(
131 |     textmap, processor_size, image_size, text_threshold=None, low_text=None
132 | ) -> List[PolygonBox]:
133 |     bboxes = get_detected_boxes(textmap, text_threshold, low_text)
134 |     for bbox in bboxes:
135 |         bbox.rescale(processor_size, image_size)
136 |         bbox.fit_to_bounds([0, 0, image_size[0], image_size[1]])
137 | 
138 |     bboxes = clean_boxes(bboxes)
139 |     return bboxes
140 | 
141 | 
142 | def parallel_get_boxes(preds, orig_sizes, include_maps=False):
143 |     heatmap, affinity_map = preds
144 |     heat_img, aff_img = None, None
145 | 
146 |     if include_maps:
147 |         heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
148 |         aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
149 |     heatmap_size = list(reversed(heatmap.shape))
150 |     bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
151 |     for box in bboxes:
152 |         # Skip expansion for vertical boxes (height >= 3x width)
153 |         if box.height < 3 * box.width:
154 |             box.expand(x_margin=0, y_margin=settings.DETECTOR_BOX_Y_EXPAND_MARGIN)
155 |             box.fit_to_bounds(
156 |                 [0, 0, orig_sizes[0], orig_sizes[1]]
157 |             )  # Fix any bad expands
158 | 
159 |     result = TextDetectionResult(
160 |         bboxes=bboxes,
161 |         heatmap=heat_img,
162 |         affinity_map=aff_img,
163 |         image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]],
164 |     )
165 |     return result
166 | 
```
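
A hedged sketch of `get_detected_boxes` on a synthetic heatmap: a single bright rectangular region should come back as one `PolygonBox`. Thresholds fall back to `settings.DETECTOR_TEXT_THRESHOLD` and `settings.DETECTOR_BLANK_THRESHOLD` when not passed.

```python
import numpy as np

from surya.detection.heatmap import get_detected_boxes

heatmap = np.zeros((256, 256), dtype=np.float32)
heatmap[100:120, 40:200] = 0.9      # fake "text line" activation

for box in get_detected_boxes(heatmap):
    print(box.bbox, box.confidence)
```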

--------------------------------------------------------------------------------
/surya/recognition/util.py:
--------------------------------------------------------------------------------

```python
  1 | import re
  2 | from typing import List, Tuple
  3 | 
  4 | import numpy
  5 | import torch
  6 | 
  7 | from surya.common.polygon import PolygonBox
  8 | from surya.recognition.schema import TextLine, TextWord, TextChar
  9 | 
 10 | MATH_SYMBOLS = ["+", "-", "*", "=", "^", "_", "\\", "{", "}"]
 11 | 
 12 | 
 13 | def unwrap_math(text: str) -> str:
 14 |     if len(text) > 50:
 15 |         return text
 16 | 
 17 |     # Detected as math, but does not contain LaTeX commands
 18 |     if (
 19 |         re.match(r'^\s*<math(?:\s+display="inline")?.*?</math>\s*$', text, re.DOTALL)
 20 |         and text.count("<math") == 1
 21 |         and not any([symb in text for symb in MATH_SYMBOLS])
 22 |     ):
 23 |         # Remove math tags
 24 |         text = re.sub(r"<math.*?>", "", text)
 25 |         text = re.sub(r"</math>", "", text)
 26 | 
 27 |     return text
 28 | 
 29 | 
 30 | MATH_BLOCK = re.compile(r"(<math\b[^>]*>)(.*?)</math>", flags=re.I | re.S)
 31 | STRIP_TAGS = re.compile(r"</?(?:br|u|del|mark|i|b|sup|sub)\b[^>]*>", flags=re.I | re.S)
 32 | DEFAULT_TAGS_TO_FILTER = ["p", "li", "ul", "ol", "table", "td", "tr", "th", "tbody", "pre"]
 33 | 
 34 | def filter_blacklist_tags(text_chars: List[TextChar], tags_to_filter: List[str] = None) -> List[TextChar]:
 35 |     filtered_chars = []
 36 |     char_buffer = []
 37 |     in_tag = False
 38 |     if tags_to_filter is None:
 39 |         tags_to_filter = DEFAULT_TAGS_TO_FILTER
 40 | 
 41 |     for text_char in text_chars:
 42 |         char = text_char.text
 43 | 
 44 |         if char.startswith("<") or in_tag:
 45 |             in_tag = True
 46 |             char_buffer.append(text_char)
 47 |             if char.endswith(">"):
 48 |                 full_tag = ''.join(c.text for c in char_buffer)
 49 |                 inner = full_tag[1:-1].strip()  # remove < >
 50 |                 inner = inner.strip("/")  # remove '/'
 51 |                 
 52 |                 # Possible that it is just an empty <>
 53 |                 if not inner:
 54 |                     filtered_chars.extend(char_buffer)
 55 |                     in_tag = False
 56 |                     char_buffer = []
 57 |                     continue
 58 |                 
 59 |                 tag_name_candidate = inner.split()[0]   # remove any attributes
 60 |                 if tag_name_candidate in tags_to_filter:
 61 |                     # Discard tag
 62 |                     pass
 63 |                 else:
 64 |                     # Keep tag
 65 |                     filtered_chars.extend(char_buffer)
 66 | 
 67 |                 in_tag = False
 68 |                 char_buffer = []
 69 |         else:
 70 |             filtered_chars.append(text_char)
 71 | 
 72 |     # Flush buffer if we never reached a tag close
 73 |     if char_buffer:
 74 |         filtered_chars.extend(char_buffer)
 75 | 
 76 |     return filtered_chars
 77 | 
 78 | 
 79 | def clean_math_tags(html: str) -> str:
 80 |     # strip unwanted tags inside every well‑formed <math>…</math>
 81 |     def _inner(m):
 82 |         inner = STRIP_TAGS.sub("", m.group(2))
 83 |         return f"{m.group(1)}{inner}</math>" if inner.strip() else ""
 84 | 
 85 |     cleaned = MATH_BLOCK.sub(_inner, html)
 86 | 
 87 |     # drop only orphan *closing* </math> tags
 88 |     depth = 0
 89 |     parts = []
 90 |     for token in re.split(r"(</?math[^>]*>)", cleaned, flags=re.I):
 91 |         if token.lower().startswith("<math"):
 92 |             depth += 1
 93 |             parts.append(token)
 94 |         elif token.lower() == "</math>":
 95 |             if depth:  # keep it only if it matches an open
 96 |                 depth -= 1
 97 |                 parts.append(token)
 98 |             # else: skip orphan closing tag
 99 |         else:
100 |             parts.append(token)
101 |     return "".join(parts)
102 | 
103 | 
104 | def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25):
105 |     # Sorts in reading order.  Not 100% accurate, this should only
106 |     # be used as a starting point for more advanced sorting.
107 |     vertical_groups = {}
108 |     for line in lines:
109 |         group_key = (
110 |             round(
111 |                 (line.bbox[1]
112 |                 if isinstance(line, TextLine)
113 |                 else line["bbox"][1]) / tolerance
114 |             )
115 |             * tolerance
116 |         )
117 |         if group_key not in vertical_groups:
118 |             vertical_groups[group_key] = []
119 |         vertical_groups[group_key].append(line)
120 | 
121 |     # Sort each group horizontally and flatten the groups into a single list
122 |     sorted_lines = []
123 |     for _, group in sorted(vertical_groups.items()):
124 |         sorted_group = sorted(
125 |             group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0]
126 |         )
127 |         sorted_lines.extend(sorted_group)
128 | 
129 |     return sorted_lines
130 | 
131 | 
132 | def clean_close_polygons(bboxes: List[List[List[int]]], thresh: float = 0.1):
133 |     if len(bboxes) < 2:
134 |         return bboxes
135 | 
136 |     new_bboxes = [bboxes[0]]
137 |     for i in range(1, len(bboxes)):
138 |         close = True
139 |         prev_bbox = bboxes[i - 1]
140 |         bbox = bboxes[i]
141 |         for j in range(4):
142 |             if (
143 |                 abs(bbox[j][0] - prev_bbox[j][0]) > thresh
144 |                 or abs(bbox[j][1] - prev_bbox[j][1]) > thresh
145 |             ):
146 |                 close = False
147 |                 break
148 | 
149 |         if not close:
150 |             new_bboxes.append(bboxes[i])
151 | 
152 |     return new_bboxes
153 | 
154 | 
155 | def words_from_chars(chars: List[TextChar], line_box: PolygonBox):
156 |     words = []
157 |     word = None
158 |     for i, char in enumerate(chars):
159 |         if not char.bbox_valid:
160 |             if word:
161 |                 words.append(word)
162 |                 word = None
163 |             continue
164 | 
165 |         if not word:
166 |             word = TextWord(**char.model_dump())
167 | 
168 |             # Fit bounds to line if first word
169 |             if i == 0:
170 |                 word.merge_left(line_box)
171 | 
172 |         elif not char.text.strip():
173 |             if word:
174 |                 words.append(word)
175 |             word = None
176 |         else:
177 |             # Merge bboxes
178 |             word.merge(char)
179 |             word.text = word.text + char.text
180 | 
181 |             if i == len(chars) - 1:
182 |                 word.merge_right(line_box)
183 |     if word:
184 |         words.append(word)
185 | 
186 |     return words
```
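
A quick illustration of `clean_math_tags` above: formatting tags inside a well-formed `<math>` block are stripped, and an orphan closing `</math>` with no matching opener is dropped.

```python
from surya.recognition.util import clean_math_tags

html = 'x <math display="inline"><i>a</i>+b</math> y</math>'
print(clean_math_tags(html))
# -> 'x <math display="inline">a+b</math> y'
```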

--------------------------------------------------------------------------------
/surya/common/s3.py:
--------------------------------------------------------------------------------

```python
  1 | import json
  2 | import os
  3 | import shutil
  4 | import tempfile
  5 | import time
  6 | from concurrent.futures import ThreadPoolExecutor
  7 | from pathlib import Path
  8 | 
  9 | import requests
 10 | from tqdm import tqdm
 11 | 
 12 | from surya.logging import get_logger
 13 | from surya.settings import settings
 14 | 
 15 | logger = get_logger()
 16 | 
 17 | # Lock file expiration time in seconds (10 minutes)
 18 | LOCK_EXPIRATION = 600
 19 | 
 20 | 
 21 | def join_urls(url1: str, url2: str):
 22 |     url1 = url1.rstrip("/")
 23 |     url2 = url2.lstrip("/")
 24 |     return f"{url1}/{url2}"
 25 | 
 26 | 
 27 | def get_model_name(pretrained_model_name_or_path: str):
 28 |     return pretrained_model_name_or_path.split("/")[0]
 29 | 
 30 | 
 31 | def download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024):
 32 |     local_path = Path(local_path)
 33 |     try:
 34 |         response = requests.get(remote_path, stream=True, allow_redirects=True)
 35 |         response.raise_for_status()  # Raise an exception for bad status codes
 36 | 
 37 |         # Get file size from headers for progress bar
 38 |         total_size = int(response.headers.get('content-length', 0))
 39 |         
 40 |         # Create progress bar with file name and size info
 41 |         filename = local_path.name
 42 |         pbar = tqdm(
 43 |             total=total_size,
 44 |             unit='B',
 45 |             unit_scale=True,
 46 |             unit_divisor=1024,
 47 |             desc=f"Downloading {filename}",
 48 |             miniters=1
 49 |         )
 50 | 
 51 |         with open(local_path, "wb") as f:
 52 |             downloaded = 0
 53 |             for chunk in response.iter_content(chunk_size=chunk_size):
 54 |                 if chunk:
 55 |                     f.write(chunk)
 56 |                     downloaded += len(chunk)
 57 |                     pbar.update(len(chunk))
 58 |         
 59 |         pbar.close()
 60 |         return local_path
 61 |     except Exception as e:
 62 |         if local_path.exists():
 63 |             local_path.unlink()
 64 |         logger.error(f"Download error for file {remote_path}: {str(e)}")
 65 |         raise
 66 | 
 67 | 
 68 | def check_manifest(local_dir: str):
 69 |     local_dir = Path(local_dir)
 70 |     manifest_path = local_dir / "manifest.json"
 71 |     if not os.path.exists(manifest_path):
 72 |         return False
 73 | 
 74 |     try:
 75 |         with open(manifest_path, "r") as f:
 76 |             manifest = json.load(f)
 77 |         for file in manifest["files"]:
 78 |             if not os.path.exists(local_dir / file):
 79 |                 return False
 80 |     except Exception:
 81 |         return False
 82 | 
 83 |     return True
 84 | 
 85 | 
 86 | def download_directory(remote_path: str, local_dir: str):
 87 |     model_name = get_model_name(remote_path)
 88 |     s3_url = join_urls(settings.S3_BASE_URL, remote_path)
 89 |     # Check to see if it's already downloaded
 90 |     model_exists = check_manifest(local_dir)
 91 |     if model_exists:
 92 |         return
 93 | 
 94 |     # Use tempfile.TemporaryDirectory to automatically clean up
 95 |     with tempfile.TemporaryDirectory() as temp_dir:
 96 |         # Download the manifest file
 97 |         manifest_file = join_urls(s3_url, "manifest.json")
 98 |         manifest_path = os.path.join(temp_dir, "manifest.json")
 99 |         download_file(manifest_file, manifest_path)
100 | 
101 |         # List and download all files
102 |         with open(manifest_path, "r") as f:
103 |             manifest = json.load(f)
104 | 
105 |         pbar = tqdm(
106 |             desc=f"Downloading {model_name} model to {local_dir}",
107 |             total=len(manifest["files"]),
108 |         )
109 | 
110 |         with ThreadPoolExecutor(
111 |             max_workers=settings.PARALLEL_DOWNLOAD_WORKERS
112 |         ) as executor:
113 |             futures = []
114 |             for file in manifest["files"]:
115 |                 remote_file = join_urls(s3_url, file)
116 |                 local_file = os.path.join(temp_dir, file)
117 |                 futures.append(executor.submit(download_file, remote_file, local_file))
118 | 
119 |             for future in futures:
120 |                 future.result()
121 |                 pbar.update(1)
122 | 
123 |         pbar.close()
124 | 
125 |         # Move all files to new directory
126 |         for file in os.listdir(temp_dir):
127 |             shutil.move(os.path.join(temp_dir, file), local_dir)
128 | 
129 | 
130 | class S3DownloaderMixin:
131 |     s3_prefix = "s3://"
132 | 
133 |     @classmethod
134 |     def get_local_path(cls, pretrained_model_name_or_path) -> str:
135 |         if pretrained_model_name_or_path.startswith(cls.s3_prefix):
136 |             pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
137 |                 cls.s3_prefix, ""
138 |             )
139 |             cache_dir = settings.MODEL_CACHE_DIR
140 |             local_path = os.path.join(cache_dir, pretrained_model_name_or_path)
141 |             os.makedirs(local_path, exist_ok=True)
142 |         else:
143 |             local_path = ""
144 |         return local_path
145 | 
146 |     @classmethod
147 |     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
148 |         # Allow loading models directly from the hub, or using s3
149 |         if not pretrained_model_name_or_path.startswith(cls.s3_prefix):
150 |             return super().from_pretrained(
151 |                 pretrained_model_name_or_path, *args, **kwargs
152 |             )
153 | 
154 |         local_path = cls.get_local_path(pretrained_model_name_or_path)
155 |         pretrained_model_name_or_path = pretrained_model_name_or_path.replace(
156 |             cls.s3_prefix, ""
157 |         )
158 | 
159 |         # Retry logic for downloading the model folder
160 |         retries = 3
161 |         delay = 5
162 |         attempt = 0
163 |         success = False
164 |         while not success and attempt < retries:
165 |             try:
166 |                 download_directory(pretrained_model_name_or_path, local_path)
167 |                 success = True  # If download succeeded
168 |             except Exception as e:
169 |                 logger.error(
170 |                     f"Error downloading model from {pretrained_model_name_or_path}. Attempt {attempt + 1} of {retries}. Error: {e}"
171 |                 )
172 |                 attempt += 1
173 |                 if attempt < retries:
174 |                     logger.info(f"Retrying in {delay} seconds...")
175 |                     time.sleep(delay)  # Wait before retrying
176 |                 else:
177 |                     logger.error(
178 |                         f"Failed to download {pretrained_model_name_or_path} after {retries} attempts."
179 |                     )
180 |                     raise e  # Reraise exception after max retries
181 | 
182 |         return super().from_pretrained(local_path, *args, **kwargs)
183 | 
```
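
A hedged sketch of `download_file` above: it streams any HTTP(S) URL to disk with a progress bar and deletes the partial file on failure. `S3DownloaderMixin.from_pretrained` builds on the same helper when a checkpoint path starts with `s3://`. Both the URL and target path below are placeholders.

```python
from surya.common.s3 import download_file

url = "https://example.com/manifest.json"   # placeholder URL
local_path = download_file(url, "/tmp/manifest.json")
print(local_path)
```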

--------------------------------------------------------------------------------
/surya/detection/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | from concurrent.futures import ThreadPoolExecutor
  2 | from typing import List, Generator, Tuple
  3 | 
  4 | import numpy as np
  5 | import torch
  6 | import torch.nn.functional as F
  7 | 
  8 | from PIL import Image
  9 | from tqdm import tqdm
 10 | 
 11 | from surya.common.predictor import BasePredictor
 12 | from surya.common.xla import mark_step
 13 | 
 14 | from surya.detection.loader import DetectionModelLoader
 15 | from surya.detection.parallel import FakeExecutor
 16 | from surya.detection.util import get_total_splits, split_image
 17 | from surya.detection.schema import TextDetectionResult
 18 | from surya.settings import settings
 19 | from surya.detection.heatmap import parallel_get_boxes
 20 | 
 21 | 
 22 | class DetectionPredictor(BasePredictor):
 23 |     model_loader_cls = DetectionModelLoader
 24 |     batch_size = settings.DETECTOR_BATCH_SIZE
 25 |     default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 36, "xla": 18}
 26 | 
 27 |     def __call__(
 28 |         self, images: List[Image.Image], batch_size=None, include_maps=False
 29 |     ) -> List[TextDetectionResult]:
 30 |         detection_generator = self.batch_detection(
 31 |             images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE
 32 |         )
 33 | 
 34 |         postprocessing_futures = []
 35 |         max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
 36 |         parallelize = (
 37 |             not settings.IN_STREAMLIT
 38 |             and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
 39 |         )
 40 |         executor = ThreadPoolExecutor if parallelize else FakeExecutor
 41 |         with executor(max_workers=max_workers) as e:
 42 |             for preds, orig_sizes in detection_generator:
 43 |                 for pred, orig_size in zip(preds, orig_sizes):
 44 |                     postprocessing_futures.append(
 45 |                         e.submit(parallel_get_boxes, pred, orig_size, include_maps)
 46 |                     )
 47 | 
 48 |         return [future.result() for future in postprocessing_futures]
 49 | 
 50 |     def prepare_image(self, img):
 51 |         new_size = (self.processor.size["width"], self.processor.size["height"])
 52 | 
 53 |         # This double resize is actually necessary for downstream accuracy
 54 |         img.thumbnail(new_size, Image.Resampling.LANCZOS)
 55 |         img = img.resize(
 56 |             new_size, Image.Resampling.LANCZOS
 57 |         )  # Stretch smaller dimension to fit new size
 58 | 
 59 |         img = np.asarray(img, dtype=np.uint8)
 60 |         img = self.processor(img)["pixel_values"][0]
 61 |         img = torch.from_numpy(img)
 62 |         return img
 63 | 
 64 |     def batch_detection(
 65 |         self, images: List, batch_size=None, static_cache=False
 66 |     ) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
 67 |         assert all([isinstance(image, Image.Image) for image in images])
 68 |         if batch_size is None:
 69 |             batch_size = self.get_batch_size()
 70 |         heatmap_count = self.model.config.num_labels
 71 | 
 72 |         orig_sizes = [image.size for image in images]
 73 |         splits_per_image = [
 74 |             get_total_splits(size, self.processor.size["height"]) for size in orig_sizes
 75 |         ]
 76 | 
 77 |         batches = []
 78 |         current_batch_size = 0
 79 |         current_batch = []
 80 |         for i in range(len(images)):
 81 |             if current_batch_size + splits_per_image[i] > batch_size:
 82 |                 if len(current_batch) > 0:
 83 |                     batches.append(current_batch)
 84 |                 current_batch = []
 85 |                 current_batch_size = 0
 86 |             current_batch.append(i)
 87 |             current_batch_size += splits_per_image[i]
 88 | 
 89 |         if len(current_batch) > 0:
 90 |             batches.append(current_batch)
 91 | 
 92 |         for batch_idx in tqdm(
 93 |             range(len(batches)), desc="Detecting bboxes", disable=self.disable_tqdm
 94 |         ):
 95 |             batch_image_idxs = batches[batch_idx]
 96 |             batch_images = [images[j].convert("RGB") for j in batch_image_idxs]
 97 | 
 98 |             split_index = []
 99 |             split_heights = []
100 |             image_splits = []
101 |             for image_idx, image in enumerate(batch_images):
102 |                 image_parts, split_height = split_image(
103 |                     image, self.processor.size["height"]
104 |                 )
105 |                 image_splits.extend(image_parts)
106 |                 split_index.extend([image_idx] * len(image_parts))
107 |                 split_heights.extend(split_height)
108 | 
109 |             image_splits = [self.prepare_image(image) for image in image_splits]
110 |             # Batch images in dim 0
111 |             batch = torch.stack(image_splits, dim=0).to(self.model.dtype)
112 |             if static_cache:
113 |                 batch = self.pad_to_batch_size(batch, batch_size)
114 | 
115 |             with settings.INFERENCE_MODE():
116 |                 pred = self.model(
117 |                     pixel_values=batch.to(self.model.device)
118 |                 )  # Moving the .to(device) call here fixes issues with xla recompilation
119 | 
120 |             logits = pred.logits
121 |             correct_shape = [
122 |                 self.processor.size["height"],
123 |                 self.processor.size["width"],
124 |             ]
125 |             current_shape = list(logits.shape[2:])
126 |             if current_shape != correct_shape:
127 |                 logits = F.interpolate(
128 |                     logits, size=correct_shape, mode="bilinear", align_corners=False
129 |                 )
130 |             mark_step()
131 | 
132 |             logits = logits.to(torch.float32).cpu().numpy()
133 |             preds = []
134 |             for i, (idx, height) in enumerate(zip(split_index, split_heights)):
135 |                 # If we don't yet have a prediction list for this image idx, start a new image
136 |                 # Otherwise, stack this split onto the current image's heatmaps
137 |                 if len(preds) <= idx:
138 |                     preds.append([logits[i][k] for k in range(heatmap_count)])
139 |                 else:
140 |                     heatmaps = preds[idx]
141 |                     pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]
142 | 
143 |                     if height < self.processor.size["height"]:
144 |                         # Cut off padding to get original height
145 |                         pred_heatmaps = [
146 |                             pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps
147 |                         ]
148 | 
149 |                     for k in range(heatmap_count):
150 |                         heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
151 |                     preds[idx] = heatmaps
152 | 
153 |             yield preds, [orig_sizes[j] for j in batch_image_idxs]
154 | 
155 |         torch.cuda.empty_cache()
156 | 
```
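
The batching above groups pages greedily by how many vertical slices each one produces, starting a new batch whenever adding another page would exceed the batch size. A minimal standalone sketch of that grouping, with made-up split counts and batch size:

```python
# Toy illustration of the greedy split-count batching used in batch_detection.
# splits_per_image and batch_size are made-up values, not real settings.
splits_per_image = [1, 3, 2, 4, 1]  # vertical slices each page produces
batch_size = 4

batches, current_batch, current_batch_size = [], [], 0
for i, n_splits in enumerate(splits_per_image):
    if current_batch_size + n_splits > batch_size:
        if current_batch:
            batches.append(current_batch)
        current_batch, current_batch_size = [], 0
    current_batch.append(i)
    current_batch_size += n_splits
if current_batch:
    batches.append(current_batch)

print(batches)  # [[0, 1], [2], [3], [4]]
```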

--------------------------------------------------------------------------------
/benchmark/utils/metrics.py:
--------------------------------------------------------------------------------

```python
  1 | from functools import partial
  2 | from itertools import repeat
  3 | 
  4 | import numpy as np
  5 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
  6 | 
  7 | 
  8 | def box_area(box):
  9 |     return (box[2] - box[0]) * (box[3] - box[1])
 10 | 
 11 | 
 12 | def calculate_iou(box1, box2, box1_only=False):
 13 |     intersection = intersection_area(box1, box2)
 14 |     union = box_area(box1)
 15 |     if not box1_only:
 16 |         union += box_area(box2) - intersection
 17 | 
 18 |     if union == 0:
 19 |         return 0
 20 |     return intersection / union
 21 | 
 22 | 
 23 | def match_boxes(preds, references):
 24 |     num_actual = len(references)
 25 |     num_predicted = len(preds)
 26 | 
 27 |     iou_matrix = np.zeros((num_actual, num_predicted))
 28 |     for i, actual in enumerate(references):
 29 |         for j, pred in enumerate(preds):
 30 |             iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True)
 31 | 
 32 |     sorted_indices = np.argsort(iou_matrix, axis=None)[::-1]
 33 |     sorted_ious = iou_matrix.flatten()[sorted_indices]
 34 |     actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape)
 35 | 
 36 |     assigned_actual = set()
 37 |     assigned_pred = set()
 38 | 
 39 |     matches = []
 40 |     for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious):
 41 |         i, j = idx
 42 |         if i not in assigned_actual and j not in assigned_pred:
 43 |             iou_val = iou_matrix[i, j]
 44 |             if iou_val > .95: # Account for rounding on box edges
 45 |                 iou_val = 1.0
 46 |             matches.append((i, j, iou_val))
 47 |             assigned_actual.add(i)
 48 |             assigned_pred.add(j)
 49 | 
 50 |     unassigned_actual = set(range(num_actual)) - assigned_actual
 51 |     unassigned_pred = set(range(num_predicted)) - assigned_pred
 52 |     matches.extend([(i, None, -1.0) for i in unassigned_actual])
 53 |     matches.extend([(None, j, 0.0) for j in unassigned_pred])
 54 | 
 55 |     return matches
 56 | 
 57 | def penalized_iou_score(preds, references):
 58 |     matches = match_boxes(preds, references)
 59 |     iou = sum([match[2] for match in matches]) / len(matches)
 60 |     return iou
 61 | 
 62 | def intersection_pixels(box1, box2):
 63 |     x_left = max(box1[0], box2[0])
 64 |     y_top = max(box1[1], box2[1])
 65 |     x_right = min(box1[2], box2[2])
 66 |     y_bottom = min(box1[3], box2[3])
 67 | 
 68 |     if x_right < x_left or y_bottom < y_top:
 69 |         return set()
 70 | 
 71 |     x_left, x_right = int(x_left), int(x_right)
 72 |     y_top, y_bottom = int(y_top), int(y_bottom)
 73 | 
 74 |     coords = np.meshgrid(np.arange(x_left, x_right), np.arange(y_top, y_bottom))
 75 |     pixels = set(zip(coords[0].flat, coords[1].flat))
 76 | 
 77 |     return pixels
 78 | 
 79 | 
 80 | def calculate_coverage(box, other_boxes, penalize_double=False):
 81 |     box_area = (box[2] - box[0]) * (box[3] - box[1])
 82 |     if box_area == 0:
 83 |         return 0
 84 | 
 85 |     # find total coverage of the box
 86 |     covered_pixels = set()
 87 |     double_coverage = list()
 88 |     for other_box in other_boxes:
 89 |         ia = intersection_pixels(box, other_box)
 90 |         double_coverage.append(list(covered_pixels.intersection(ia)))
 91 |         covered_pixels = covered_pixels.union(ia)
 92 | 
 93 |     # Penalize double coverage - having multiple bboxes overlapping the same pixels
 94 |     double_coverage_penalty = sum(len(pixels) for pixels in double_coverage)
 95 |     if not penalize_double:
 96 |         double_coverage_penalty = 0
 97 |     covered_pixels_count = max(0, len(covered_pixels) - double_coverage_penalty)
 98 |     return covered_pixels_count / box_area
 99 | 
100 | 
101 | def intersection_area(box1, box2):
102 |     x_left = max(box1[0], box2[0])
103 |     y_top = max(box1[1], box2[1])
104 |     x_right = min(box1[2], box2[2])
105 |     y_bottom = min(box1[3], box2[3])
106 | 
107 |     if x_right < x_left or y_bottom < y_top:
108 |         return 0.0
109 | 
110 |     return (x_right - x_left) * (y_bottom - y_top)
111 | 
112 | 
113 | def calculate_coverage_fast(box, other_boxes, penalize_double=False):
114 |     box = np.array(box)
115 |     other_boxes = np.array(other_boxes)
116 | 
117 |     # Calculate box area
118 |     box_area = (box[2] - box[0]) * (box[3] - box[1])
119 |     if box_area == 0:
120 |         return 0
121 | 
122 |     x_left = np.maximum(box[0], other_boxes[:, 0])
123 |     y_top = np.maximum(box[1], other_boxes[:, 1])
124 |     x_right = np.minimum(box[2], other_boxes[:, 2])
125 |     y_bottom = np.minimum(box[3], other_boxes[:, 3])
126 | 
127 |     widths = np.maximum(0, x_right - x_left)
128 |     heights = np.maximum(0, y_bottom - y_top)
129 |     intersect_areas = widths * heights
130 | 
131 |     total_intersect = np.sum(intersect_areas)
132 | 
133 |     return min(1.0, total_intersect / box_area)
134 | 
135 | 
136 | def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
137 |     if len(references) == 0:
138 |         return {
139 |             "precision": 1,
140 |             "recall": 1,
141 |         }
142 | 
143 |     if len(preds) == 0:
144 |         return {
145 |             "precision": 0,
146 |             "recall": 0,
147 |         }
148 | 
149 |     # If we're not penalizing double coverage, we can use a faster calculation
150 |     coverage_func = calculate_coverage_fast
151 |     if penalize_double:
152 |         coverage_func = calculate_coverage
153 | 
154 |     with ThreadPoolExecutor(max_workers=workers) as executor:
155 |         precision_func = partial(coverage_func, penalize_double=penalize_double)
156 |         precision_iou = executor.map(precision_func, preds, repeat(references))
157 |         reference_iou = executor.map(coverage_func, references, repeat(preds))
158 | 
159 |     precision_classes = [1 if i > threshold else 0 for i in precision_iou]
160 |     precision = sum(precision_classes) / len(precision_classes)
161 | 
162 |     recall_classes = [1 if i > threshold else 0 for i in reference_iou]
163 |     recall = sum(recall_classes) / len(recall_classes)
164 | 
165 |     return {
166 |         "precision": precision,
167 |         "recall": recall,
168 |     }
169 | 
170 | 
171 | def mean_coverage(preds, references):
172 |     coverages = []
173 | 
174 |     for box1 in references:
175 |         coverage = calculate_coverage(box1, preds)
176 |         coverages.append(coverage)
177 | 
178 |     for box2 in preds:
179 |         coverage = calculate_coverage(box2, references)
180 |         coverages.append(coverage)
181 | 
182 |     # Calculate the average coverage over all comparisons
183 |     if len(coverages) == 0:
184 |         return 0
185 |     coverage = sum(coverages) / len(coverages)
186 |     return {"coverage": coverage}
187 | 
188 | 
189 | def rank_accuracy(preds, references):
190 |     # Preds and references need to be aligned so each position refers to the same bbox
191 |     pairs = []
192 |     for i, pred in enumerate(preds):
193 |         for j, pred2 in enumerate(preds):
194 |             if i == j:
195 |                 continue
196 |             pairs.append((i, j, pred > pred2))
197 | 
198 |     # Find how many of the prediction rankings are correct
199 |     correct = 0
200 |     for i, ref in enumerate(references):
201 |         for j, ref2 in enumerate(references):
202 |             if (i, j, ref > ref2) in pairs:
203 |                 correct += 1
204 | 
205 |     return correct / len(pairs)
```
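
A small usage sketch for the penalized matching above, with made-up boxes (the import path assumes the benchmark package is importable as laid out in this repo): one exact match scores 1.0, the spurious extra prediction contributes 0.0, and the penalized score is their average.

```python
# Toy example of match_boxes / penalized_iou_score with made-up boxes.
from benchmark.utils.metrics import match_boxes, penalized_iou_score

references = [[0, 0, 10, 10]]                 # one ground-truth box
preds = [[0, 0, 10, 10], [20, 20, 30, 30]]    # one exact match + one spurious box

for ref_idx, pred_idx, iou in match_boxes(preds, references):
    print(ref_idx, pred_idx, iou)  # prints "0 0 1.0", then "None 1 0.0"

print(penalized_iou_score(preds, references))  # (1.0 + 0.0) / 2 = 0.5
```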

--------------------------------------------------------------------------------
/surya/common/donut/processor.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Dict, Union, Optional, List, Iterable
  2 | 
  3 | import cv2
  4 | from torch import TensorType
  5 | from transformers import ImageProcessingMixin
  6 | from transformers.image_processing_utils import BatchFeature
  7 | from transformers.image_transforms import pad, normalize
  8 | from transformers.image_utils import (
  9 |     ImageInput,
 10 |     ChannelDimension,
 11 |     make_list_of_images,
 12 |     get_image_size,
 13 | )
 14 | import numpy as np
 15 | from PIL import Image
 16 | import PIL
 17 | from transformers.utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
 18 | 
 19 | from surya.common.s3 import S3DownloaderMixin
 20 | from surya.settings import settings
 21 | 
 22 | 
 23 | class SuryaEncoderImageProcessor(S3DownloaderMixin, ImageProcessingMixin):
 24 |     def __init__(
 25 |         self,
 26 |         *args,
 27 |         max_size=None,
 28 |         align_long_axis=False,
 29 |         rescale_factor: Union[int, float] = 1 / 255,
 30 |         image_mean: Optional[Union[float, List[float]]] = None,
 31 |         image_std: Optional[Union[float, List[float]]] = None,
 32 |         **kwargs,
 33 |     ):
 34 |         super().__init__(*args, **kwargs)
 35 | 
 36 |         self.patch_size = kwargs.get("patch_size", (4, 4))
 37 |         self.max_size = max_size
 38 |         self.do_align_long_axis = align_long_axis
 39 |         self.resample = Image.Resampling.BILINEAR
 40 |         self.rescale_factor = rescale_factor
 41 |         self.image_mean = (
 42 |             image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
 43 |         )
 44 |         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
 45 | 
 46 |     def __call__(self, images, **kwargs) -> BatchFeature:
 47 |         """Preprocess an image or a batch of images."""
 48 |         return self.preprocess(images, **kwargs)
 49 | 
 50 |     @classmethod
 51 |     def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
 52 |         max_width, max_height = size["width"], size["height"]
 53 | 
 54 |         resized_image = cv2.resize(
 55 |             image, (max_width, max_height), interpolation=interpolation
 56 |         )
 57 |         resized_image = resized_image.transpose(2, 0, 1)
 58 | 
 59 |         return resized_image
 60 | 
 61 |     def process_inner(self, images: List[np.ndarray]):
 62 |         assert images[0].shape[2] == 3  # RGB input images, channel dim last
 63 | 
 64 |         if self.do_align_long_axis:
 65 |             # Rotate so the image's long axis matches the target size's long axis
 66 |             images = [
 67 |                 SuryaEncoderImageProcessor.align_long_axis(
 68 |                     image, size=self.max_size, input_data_format=ChannelDimension.LAST
 69 |                 )
 70 |                 for image in images
 71 |             ]
 72 | 
 73 |             # Verify that the image is wider than it is tall
 74 |             for img in images:
 75 |                 assert img.shape[1] >= img.shape[0]
 76 | 
 77 |         # This also applies the right channel dim format, to channel x height x width
 78 |         images = [
 79 |             SuryaEncoderImageProcessor.numpy_resize(img, self.max_size, self.resample)
 80 |             for img in images
 81 |         ]
 82 |         assert images[0].shape[0] == 3  # RGB input images, channel dim first
 83 | 
 84 |         # Convert to float32 for rescale/normalize
 85 |         images = [img.astype(np.float32) for img in images]
 86 | 
 87 |         # Pads with 255 (whitespace)
 88 |         # Pad to max size to improve performance
 89 |         max_size = self.max_size
 90 |         images = [
 91 |             SuryaEncoderImageProcessor.pad_image(
 92 |                 image=image,
 93 |                 size=max_size,
 94 |                 input_data_format=ChannelDimension.FIRST,
 95 |                 pad_value=settings.RECOGNITION_PAD_VALUE,
 96 |             )
 97 |             for image in images
 98 |         ]
 99 | 
100 |         # Rescale and normalize
101 |         for idx in range(len(images)):
102 |             images[idx] = (images[idx].astype(np.float64) * self.rescale_factor).astype(
103 |                 np.float32
104 |             )
105 | 
106 |         images = [
107 |             SuryaEncoderImageProcessor.normalize(
108 |                 img,
109 |                 mean=self.image_mean,
110 |                 std=self.image_std,
111 |                 input_data_format=ChannelDimension.FIRST,
112 |             )
113 |             for img in images
114 |         ]
115 | 
116 |         return images
117 | 
118 |     def preprocess(
119 |         self,
120 |         images: ImageInput,
121 |         return_tensors: Optional[Union[str, TensorType]] = None,
122 |         **kwargs,
123 |     ) -> BatchFeature:
124 |         images = make_list_of_images(images)
125 | 
126 |         # Convert to numpy for later processing steps
127 |         images = [np.array(img) for img in images]
128 |         images = self.process_inner(images)
129 | 
130 |         data = {"pixel_values": images}
131 |         return BatchFeature(data=data, tensor_type=return_tensors)
132 | 
133 |     @classmethod
134 |     def pad_image(
135 |         cls,
136 |         image: np.ndarray,
137 |         size: Dict[str, int],
138 |         data_format: Optional[Union[str, ChannelDimension]] = None,
139 |         input_data_format: Optional[Union[str, ChannelDimension]] = None,
140 |         pad_value: float = 0.0,
141 |     ) -> np.ndarray:
142 |         output_height, output_width = size["height"], size["width"]
143 |         input_height, input_width = get_image_size(image, channel_dim=input_data_format)
144 | 
145 |         delta_width = output_width - input_width
146 |         delta_height = output_height - input_height
147 | 
148 |         assert delta_width >= 0 and delta_height >= 0
149 | 
150 |         pad_top = delta_height // 2
151 |         pad_left = delta_width // 2
152 | 
153 |         pad_bottom = delta_height - pad_top
154 |         pad_right = delta_width - pad_left
155 | 
156 |         padding = ((pad_top, pad_bottom), (pad_left, pad_right))
157 |         return pad(
158 |             image,
159 |             padding,
160 |             data_format=data_format,
161 |             input_data_format=input_data_format,
162 |             constant_values=pad_value,
163 |         )
164 | 
165 |     @classmethod
166 |     def align_long_axis(
167 |         cls, image: np.ndarray, size: Dict[str, int], **kwargs
168 |     ) -> np.ndarray:
169 |         input_height, input_width = image.shape[:2]
170 |         output_height, output_width = size["height"], size["width"]
171 | 
172 |         if (output_width < output_height and input_width > input_height) or (
173 |             output_width > output_height and input_width < input_height
174 |         ):
175 |             image = np.rot90(image, 3)
176 | 
177 |         return image
178 | 
179 |     @classmethod
180 |     def normalize(
181 |         cls,
182 |         image: np.ndarray,
183 |         mean: Union[float, Iterable[float]],
184 |         std: Union[float, Iterable[float]],
185 |         data_format: Optional[Union[str, ChannelDimension]] = None,
186 |         input_data_format: Optional[Union[str, ChannelDimension]] = None,
187 |         **kwargs,
188 |     ) -> np.ndarray:
189 |         return normalize(
190 |             image,
191 |             mean=mean,
192 |             std=std,
193 |             data_format=data_format,
194 |             input_data_format=input_data_format,
195 |             **kwargs,
196 |         )
197 | 
```
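
A minimal usage sketch for this processor. The `max_size` dict and image dimensions below are illustrative values, not the model's real configuration; in practice the processor comes with the checkpoint's own size settings.

```python
# Minimal sketch: preprocess a single PIL image (sizes are illustrative).
from PIL import Image
from surya.common.donut.processor import SuryaEncoderImageProcessor

processor = SuryaEncoderImageProcessor(max_size={"height": 256, "width": 256})
image = Image.new("RGB", (180, 120), color="white")

features = processor(image)              # returns a BatchFeature
pixel_values = features["pixel_values"]  # list with one channel-first array
print(pixel_values[0].shape)             # (3, 256, 256)
```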

--------------------------------------------------------------------------------
/benchmark/table_recognition.py:
--------------------------------------------------------------------------------

```python
  1 | import click
  2 | import collections
  3 | import json
  4 | 
  5 | from surya.debug.draw import draw_bboxes_on_image
  6 | from tabulate import tabulate
  7 | 
  8 | from surya.input.processing import convert_if_not_rgb
  9 | from surya.table_rec import TableRecPredictor
 10 | from surya.settings import settings
 11 | from benchmark.utils.metrics import penalized_iou_score
 12 | from benchmark.utils.tatr import load_tatr, batch_inference_tatr
 13 | import os
 14 | import time
 15 | import datasets
 16 | 
 17 | 
 18 | @click.command(help="Benchmark table rec dataset")
 19 | @click.option(
 20 |     "--results_dir",
 21 |     type=str,
 22 |     help="Path to JSON file with benchmark results.",
 23 |     default=os.path.join(settings.RESULT_DIR, "benchmark"),
 24 | )
 25 | @click.option(
 26 |     "--max_rows",
 27 |     type=int,
 28 |     help="Maximum number of images to run benchmark on.",
 29 |     default=512,
 30 | )
 31 | @click.option("--tatr", is_flag=True, help="Run table transformer.", default=False)
 32 | @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
 33 | def main(results_dir: str, max_rows: int, tatr: bool, debug: bool):
 34 |     table_rec_predictor = TableRecPredictor()
 35 | 
 36 |     pathname = "table_rec_bench"
 37 |     # These have already been shuffled randomly, so sampling from the start is fine
 38 |     split = "train"
 39 |     if max_rows is not None:
 40 |         split = f"train[:{max_rows}]"
 41 |     dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)
 42 |     images = list(dataset["image"])
 43 |     images = convert_if_not_rgb(images)
 44 | 
 45 |     if settings.TABLE_REC_STATIC_CACHE:
 46 |         # Run through one batch to compile the model
 47 |         table_rec_predictor(images[:1])
 48 | 
 49 |     start = time.time()
 50 |     table_rec_predictions = table_rec_predictor(images)
 51 |     surya_time = time.time() - start
 52 | 
 53 |     folder_name = os.path.basename(pathname).split(".")[0]
 54 |     result_path = os.path.join(results_dir, folder_name)
 55 |     os.makedirs(result_path, exist_ok=True)
 56 | 
 57 |     page_metrics = collections.OrderedDict()
 58 |     mean_col_iou = 0
 59 |     mean_row_iou = 0
 60 |     for idx, (pred, image) in enumerate(zip(table_rec_predictions, images)):
 61 |         row = dataset[idx]
 62 |         pred_row_boxes = [p.bbox for p in pred.rows]
 63 |         pred_col_bboxes = [p.bbox for p in pred.cols]
 64 |         actual_row_bboxes = [r["bbox"] for r in row["rows"]]
 65 |         actual_col_bboxes = [c["bbox"] for c in row["columns"]]
 66 |         row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
 67 |         col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
 68 |         page_results = {
 69 |             "row_score": row_score,
 70 |             "col_score": col_score,
 71 |             "row_count": len(actual_row_bboxes),
 72 |             "col_count": len(actual_col_bboxes),
 73 |         }
 74 | 
 75 |         mean_col_iou += col_score
 76 |         mean_row_iou += row_score
 77 | 
 78 |         page_metrics[idx] = page_results
 79 | 
 80 |         if debug:
 81 |             # Save debug images
 82 |             draw_img = image.copy()
 83 |             draw_bboxes_on_image(
 84 |                 pred_row_boxes,
 85 |                 draw_img,
 86 |                 [f"Row {i}" for i in range(len(pred_row_boxes))],
 87 |             )
 88 |             draw_bboxes_on_image(
 89 |                 pred_col_bboxes,
 90 |                 draw_img,
 91 |                 [f"Col {i}" for i in range(len(pred_col_bboxes))],
 92 |                 color="blue",
 93 |             )
 94 |             draw_img.save(os.path.join(result_path, f"{idx}_bbox.png"))
 95 | 
 96 |             actual_draw_image = image.copy()
 97 |             draw_bboxes_on_image(
 98 |                 actual_row_bboxes,
 99 |                 actual_draw_image,
100 |                 [f"Row {i}" for i in range(len(actual_row_bboxes))],
101 |             )
102 |             draw_bboxes_on_image(
103 |                 actual_col_bboxes,
104 |                 actual_draw_image,
105 |                 [f"Col {i}" for i in range(len(actual_col_bboxes))],
106 |                 color="blue",
107 |             )
108 |             actual_draw_image.save(os.path.join(result_path, f"{idx}_actual.png"))
109 | 
110 |     mean_col_iou /= len(table_rec_predictions)
111 |     mean_row_iou /= len(table_rec_predictions)
112 | 
113 |     out_data = {
114 |         "surya": {
115 |             "time": surya_time,
116 |             "mean_row_iou": mean_row_iou,
117 |             "mean_col_iou": mean_col_iou,
118 |             "page_metrics": page_metrics,
119 |         }
120 |     }
121 | 
122 |     if tatr:
123 |         tatr_model = load_tatr()
124 |         start = time.time()
125 |         tatr_predictions = batch_inference_tatr(tatr_model, images, 1)
126 |         tatr_time = time.time() - start
127 | 
128 |         page_metrics = collections.OrderedDict()
129 |         mean_col_iou = 0
130 |         mean_row_iou = 0
131 |         for idx, pred in enumerate(tatr_predictions):
132 |             row = dataset[idx]
133 |             pred_row_boxes = [p["bbox"] for p in pred["rows"]]
134 |             pred_col_bboxes = [p["bbox"] for p in pred["cols"]]
135 |             actual_row_bboxes = [r["bbox"] for r in row["rows"]]
136 |             actual_col_bboxes = [c["bbox"] for c in row["columns"]]
137 |             row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
138 |             col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
139 |             page_results = {
140 |                 "row_score": row_score,
141 |                 "col_score": col_score,
142 |                 "row_count": len(actual_row_bboxes),
143 |                 "col_count": len(actual_col_bboxes),
144 |             }
145 | 
146 |             mean_col_iou += col_score
147 |             mean_row_iou += row_score
148 | 
149 |             page_metrics[idx] = page_results
150 | 
151 |         mean_col_iou /= len(tatr_predictions)
152 |         mean_row_iou /= len(tatr_predictions)
153 | 
154 |         out_data["tatr"] = {
155 |             "time": tatr_time,
156 |             "mean_row_iou": mean_row_iou,
157 |             "mean_col_iou": mean_col_iou,
158 |             "page_metrics": page_metrics,
159 |         }
160 | 
161 |     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
162 |         json.dump(out_data, f, indent=4)
163 | 
164 |     table = [
165 |         ["Model", "Row Intersection", "Col Intersection", "Time Per Image"],
166 |         [
167 |             "Surya",
168 |             f"{out_data['surya']['mean_row_iou']:.2f}",
169 |             f"{out_data['surya']['mean_col_iou']:.2f}",
170 |             f"{surya_time / len(images):.5f}",
171 |         ],
172 |     ]
173 | 
174 |     if tatr:
175 |         table.append(
176 |             [
177 |                 "Table transformer",
178 |                 f"{out_data['tatr']['mean_row_iou']:.2f}",
179 |                 f"{out_data['tatr']['mean_col_iou']:.2f}",
180 |                 f"{tatr_time / len(images):.5f}",
181 |             ]
182 |         )
183 | 
184 |     print(tabulate(table, headers="firstrow", tablefmt="github"))
185 | 
186 |     print(
187 |         "Intersection is the average intersection % between each actual row/column and the predictions, with penalties for too many/few predictions."
188 |     )
189 |     print(
190 |         "Note that table transformer is unbatched, since the example code in its repo is unbatched."
191 |     )
192 |     print(f"Wrote results to {result_path}")
193 | 
194 | 
195 | if __name__ == "__main__":
196 |     main()
197 | 
```
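
The per-page scoring in the loop above can be reproduced outside the CLI. In this sketch the image path and the ground-truth row/column boxes are placeholders; the real ground truth comes from the benchmark dataset rows.

```python
# Standalone version of one benchmark iteration (path and GT boxes are placeholders).
from PIL import Image
from surya.table_rec import TableRecPredictor
from benchmark.utils.metrics import penalized_iou_score

predictor = TableRecPredictor()
image = Image.open("table_page.png").convert("RGB")  # placeholder path

pred = predictor([image])[0]
pred_rows = [r.bbox for r in pred.rows]
pred_cols = [c.bbox for c in pred.cols]

# Placeholder ground truth; the benchmark reads these from row["rows"] / row["columns"].
gt_rows = [[10, 10, 500, 40], [10, 40, 500, 70]]
gt_cols = [[10, 10, 250, 70], [250, 10, 500, 70]]

print("row score:", penalized_iou_score(pred_rows, gt_rows))
print("col score:", penalized_iou_score(pred_cols, gt_cols))
```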

--------------------------------------------------------------------------------
/surya/table_rec/model/decoder.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Optional, Tuple, Union
  2 | 
  3 | import torch
  4 | from torch import nn
  5 | 
  6 | from surya.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel
  7 | from surya.table_rec.model.config import TableRecModelOutput
  8 | from surya.table_rec.shaper import LabelShaper
  9 | from surya.settings import settings
 10 | 
 11 | 
 12 | class LabelEmbedding(nn.Module):
 13 |     def __init__(self, config):
 14 |         super().__init__()
 15 | 
 16 |         # Bboxes
 17 |         self.w_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 18 |         self.h_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 19 |         self.cx_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 20 |         self.cy_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 21 |         self.xskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 22 |         self.yskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 23 | 
 24 |         self.x1_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 25 |         self.y1_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 26 |         self.x2_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 27 |         self.y2_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 28 |         self.x3_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 29 |         self.y3_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 30 |         self.x4_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 31 |         self.y4_embed = nn.Embedding(config.vocab_size, config.box_embed_size)
 32 | 
 33 |         # Get indexes for passed in tensor
 34 |         shaper = LabelShaper()
 35 |         self.component_idxs = shaper.component_idx_dict()
 36 |         merge_count = shaper.get_box_property("merges")[1] + config.special_token_count
 37 |         category_count = shaper.get_box_property("category")[1] + config.special_token_count
 38 | 
 39 |         # Other box properties
 40 |         self.category_embed = nn.Embedding(category_count, config.property_embed_size)
 41 |         self.merge_embed = nn.Embedding(merge_count, config.property_embed_size)
 42 |         self.colspan_embed = nn.Embedding(config.vocab_size, config.property_embed_size)
 43 | 
 44 |         self.config = config
 45 | 
 46 |     def forward(self, boxes: torch.LongTensor, *args):
 47 |         # Need to keep *args for compatibility with common decoder
 48 |         boxes = boxes.to(torch.long).clamp(0, self.config.vocab_size - 1)
 49 | 
 50 |         boxes_unbound = boxes.to(torch.long).unbind(dim=-1)
 51 |         cx, cy, w, h, xskew, yskew = boxes_unbound[self.component_idxs["bbox"][0]:self.component_idxs["bbox"][1]]
 52 |         category = boxes_unbound[self.component_idxs["category"][0]:self.component_idxs["category"][1]][0]
 53 |         merges = boxes_unbound[self.component_idxs["merges"][0]:self.component_idxs["merges"][1]][0]
 54 |         colspan = boxes_unbound[self.component_idxs["colspan"][0]:self.component_idxs["colspan"][1]][0]
 55 | 
 56 |         xskew_actual = ((xskew - self.config.bbox_size // 2) / 2).to(torch.long)
 57 |         yskew_actual = ((yskew - self.config.bbox_size // 2) / 2).to(torch.long)
 58 | 
 59 |         x1 = (cx - w // 2 - xskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
 60 |         y1 = (cy - h // 2 - yskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
 61 |         x3 = (cx + w // 2 + xskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
 62 |         y3 = (cy + h // 2 + yskew_actual).clamp(0, self.config.bbox_size).to(torch.long)
 63 | 
 64 |         size_embeds = self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)
 65 |         skew_embeds = self.xskew_embed(xskew) + self.yskew_embed(yskew)
 66 |         corner_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x3_embed(x3) + self.y3_embed(y3)
 67 |         box_embeds = size_embeds + skew_embeds + corner_embeds
 68 | 
 69 |         property_embeds = self.category_embed(category) + self.merge_embed(merges) + self.colspan_embed(colspan)
 70 | 
 71 |         # Cat bbox and property embeddings
 72 |         embedded = torch.cat([box_embeds, property_embeds], dim=-1)
 73 |         return embedded
 74 | 
 75 | 
 76 | class SuryaTableRecDecoder(SuryaADETRDecoderPreTrainedModel):
 77 |     _tied_weights_keys = None
 78 | 
 79 |     def __init__(self, config, **kwargs):
 80 |         super().__init__(config)
 81 |         embed_tokens = LabelEmbedding(config)
 82 |         self.model = SuryaADETRDecoderModel(
 83 |             config,
 84 |             embedder=embed_tokens,
 85 |             static_cache=settings.TABLE_REC_STATIC_CACHE,
 86 |             max_boxes=settings.TABLE_REC_MAX_BOXES
 87 |         )
 88 |         self.vocab_size = config.vocab_size
 89 | 
 90 |         shaper = LabelShaper()
 91 |         property_heads = {}
 92 |         for k in shaper.property_keys:
 93 |             _, kcount, mode = shaper.get_box_property(k)
 94 |             property_heads[k] = nn.Linear(config.hidden_size, kcount, bias=False)
 95 | 
 96 |         self.box_property_heads = nn.ModuleDict(property_heads)
 97 |         self.pre_output_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 98 | 
 99 |         # Initialize weights and apply final processing
100 |         self.post_init()
101 | 
102 |     def get_input_embeddings(self):
103 |         return self.model.embed_tokens
104 | 
105 |     def set_input_embeddings(self, value):
106 |         self.model.embed_tokens = value
107 | 
108 |     def get_output_embeddings(self):
109 |         return self.lm_head
110 | 
111 |     def set_output_embeddings(self, new_embeddings):
112 |         self.lm_head = new_embeddings
113 | 
114 |     def set_decoder(self, decoder):
115 |         self.model = decoder
116 | 
117 |     def get_decoder(self):
118 |         return self.model
119 | 
120 |     # Ignore copy
121 |     def forward(
122 |         self,
123 |         input_ids: Optional[torch.LongTensor] = None,
124 |         cache_position: Optional[torch.LongTensor] = None,
125 |         attention_mask: Optional[torch.Tensor] = None,
126 |         encoder_hidden_states: Optional[torch.FloatTensor] = None,
127 |         encoder_attention_mask: Optional[torch.FloatTensor] = None,
128 |         use_cache: Optional[bool] = None,
129 |         prefill: bool = False,
130 |         **kwargs
131 |     ) -> Union[Tuple, TableRecModelOutput]:
132 |         outputs = self.model(
133 |             input_ids=input_ids,
134 |             cache_position=cache_position,
135 |             attention_mask=attention_mask,
136 |             encoder_hidden_states=encoder_hidden_states,
137 |             encoder_attention_mask=encoder_attention_mask,
138 |             use_cache=use_cache,
139 |             output_hidden_states=True,
140 |             return_dict=True,
141 |             prefill=prefill,
142 |         )
143 | 
144 |         hidden_states = self.pre_output_norm(outputs[0])
145 |         box_property_logits = {}
146 |         for key in self.box_property_heads:
147 |             box_property_logits[key] = self.box_property_heads[key](hidden_states)
148 | 
149 |         bbox_logits = nn.functional.sigmoid(box_property_logits["bbox"])
150 |         box_property_logits["bbox"] = bbox_logits
151 | 
152 |         return TableRecModelOutput(
153 |             box_property_logits=box_property_logits,
154 |             hidden_states=hidden_states,
155 |         )
```
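
The corner reconstruction in `LabelEmbedding.forward` is easier to follow with concrete numbers. A worked toy example; `bbox_size` is an assumed value here, the real one comes from the model config.

```python
# Worked toy example of the x1/y1/x3/y3 reconstruction (bbox_size is assumed).
import torch

bbox_size = 1024
cx, cy = torch.tensor(500), torch.tensor(300)
w, h = torch.tensor(200), torch.tensor(100)
# The code subtracts bbox_size // 2 from the skew tokens, so that value means "no skew".
xskew, yskew = torch.tensor(bbox_size // 2 + 8), torch.tensor(bbox_size // 2)

xskew_actual = ((xskew - bbox_size // 2) / 2).to(torch.long)  # 4
yskew_actual = ((yskew - bbox_size // 2) / 2).to(torch.long)  # 0

x1 = (cx - w // 2 - xskew_actual).clamp(0, bbox_size)  # 396
y1 = (cy - h // 2 - yskew_actual).clamp(0, bbox_size)  # 250
x3 = (cx + w // 2 + xskew_actual).clamp(0, bbox_size)  # 604
y3 = (cy + h // 2 + yskew_actual).clamp(0, bbox_size)  # 350
print(x1.item(), y1.item(), x3.item(), y3.item())      # 396 250 604 350
```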

--------------------------------------------------------------------------------
/surya/common/polygon.py:
--------------------------------------------------------------------------------

```python
  1 | import copy
  2 | from typing import List, Optional
  3 | 
  4 | import numpy as np
  5 | from pydantic import BaseModel, field_validator, computed_field
  6 | import numbers
  7 | 
  8 | 
  9 | class PolygonBox(BaseModel):
 10 |     polygon: List[List[float]]
 11 |     confidence: Optional[float] = None
 12 | 
 13 |     @field_validator("polygon", mode="before")
 14 |     @classmethod
 15 |     def convert_bbox_to_polygon(cls, value):
 16 |         if isinstance(value, (list, tuple)) and len(value) == 4:
 17 |             if all(isinstance(x, numbers.Number) for x in value):
 18 |                 value = [float(v) for v in value]
 19 |                 x_min, y_min, x_max, y_max = value
 20 |                 polygon = [
 21 |                     [x_min, y_min],
 22 |                     [x_max, y_min],
 23 |                     [x_max, y_max],
 24 |                     [x_min, y_max],
 25 |                 ]
 26 |                 return polygon
 27 |             elif all(
 28 |                 isinstance(point, (list, tuple)) and len(point) == 2 for point in value
 29 |             ):
 30 |                 value = [[float(v) for v in point] for point in value]
 31 |                 return value
 32 |         elif isinstance(value, np.ndarray):
 33 |             if value.shape == (4, 2):
 34 |                 return value.tolist()
 35 | 
 36 |         raise ValueError(
 37 |             f"Input must be either a bbox [x_min, y_min, x_max, y_max] or a polygon with 4 corners [(x,y), (x,y), (x,y), (x,y)].  All values must be numeric. You passed {value} of type {type(value)}.  The first value is of type {type(value[0])}."
 38 |         )
 39 | 
 40 |     @property
 41 |     def height(self):
 42 |         return self.bbox[3] - self.bbox[1]
 43 | 
 44 |     @property
 45 |     def width(self):
 46 |         return self.bbox[2] - self.bbox[0]
 47 | 
 48 |     @property
 49 |     def area(self):
 50 |         return self.width * self.height
 51 | 
 52 |     @computed_field
 53 |     @property
 54 |     def bbox(self) -> List[float]:
 55 |         x_coords = [point[0] for point in self.polygon]
 56 |         y_coords = [point[1] for point in self.polygon]
 57 |         return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
 58 | 
 59 |     def rescale(self, processor_size, image_size):
 60 |         # Point is in x, y format
 61 |         page_width, page_height = processor_size
 62 | 
 63 |         img_width, img_height = image_size
 64 |         width_scaler = img_width / page_width
 65 |         height_scaler = img_height / page_height
 66 | 
 67 |         for corner in self.polygon:
 68 |             corner[0] = int(corner[0] * width_scaler)
 69 |             corner[1] = int(corner[1] * height_scaler)
 70 | 
 71 |     def round(self, divisor):
 72 |         for corner in self.polygon:
 73 |             corner[0] = int(corner[0] / divisor) * divisor
 74 |             corner[1] = int(corner[1] / divisor) * divisor
 75 | 
 76 |     def fit_to_bounds(self, bounds):
 77 |         new_corners = copy.deepcopy(self.polygon)
 78 |         for corner in new_corners:
 79 |             corner[0] = max(min(corner[0], bounds[2]), bounds[0])
 80 |             corner[1] = max(min(corner[1], bounds[3]), bounds[1])
 81 |         self.polygon = new_corners
 82 | 
 83 |     def merge(self, other):
 84 |         x1 = min(self.bbox[0], other.bbox[0])
 85 |         y1 = min(self.bbox[1], other.bbox[1])
 86 |         x2 = max(self.bbox[2], other.bbox[2])
 87 |         y2 = max(self.bbox[3], other.bbox[3])
 88 |         self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
 89 | 
 90 |     def merge_left(self, other):
 91 |         x1 = min(self.bbox[0], other.bbox[0])
 92 |         self.polygon[0][0] = x1
 93 |         self.polygon[3][0] = x1
 94 | 
 95 |     def merge_right(self, other):
 96 |         x2 = max(self.bbox[2], other.bbox[2])
 97 |         self.polygon[1][0] = x2
 98 |         self.polygon[2][0] = x2
 99 | 
100 |     def expand(self, x_margin: float, y_margin: float):
101 |         new_polygon = []
102 |         x_margin = x_margin * self.width
103 |         y_margin = y_margin * self.height
104 |         for idx, poly in enumerate(self.polygon):
105 |             if idx == 0:
106 |                 new_polygon.append([int(poly[0] - x_margin), int(poly[1] - y_margin)])
107 |             elif idx == 1:
108 |                 new_polygon.append([int(poly[0] + x_margin), int(poly[1] - y_margin)])
109 |             elif idx == 2:
110 |                 new_polygon.append([int(poly[0] + x_margin), int(poly[1] + y_margin)])
111 |             elif idx == 3:
112 |                 new_polygon.append([int(poly[0] - x_margin), int(poly[1] + y_margin)])
113 |         self.polygon = new_polygon
114 | 
115 |     def intersection_polygon(self, other) -> List[List[float]]:
116 |         new_poly = []
117 |         for i in range(4):
118 |             if i == 0:
119 |                 new_corner = [
120 |                     max(self.polygon[0][0], other.polygon[0][0]),
121 |                     max(self.polygon[0][1], other.polygon[0][1]),
122 |                 ]
123 |             elif i == 1:
124 |                 new_corner = [
125 |                     min(self.polygon[1][0], other.polygon[1][0]),
126 |                     max(self.polygon[1][1], other.polygon[1][1]),
127 |                 ]
128 |             elif i == 2:
129 |                 new_corner = [
130 |                     min(self.polygon[2][0], other.polygon[2][0]),
131 |                     min(self.polygon[2][1], other.polygon[2][1]),
132 |                 ]
133 |             elif i == 3:
134 |                 new_corner = [
135 |                     max(self.polygon[3][0], other.polygon[3][0]),
136 |                     min(self.polygon[3][1], other.polygon[3][1]),
137 |                 ]
138 |             new_poly.append(new_corner)
139 | 
140 |         return new_poly
141 | 
142 |     def intersection_area(self, other, x_margin=0, y_margin=0):
143 |         x_overlap = self.x_overlap(other, x_margin)
144 |         y_overlap = self.y_overlap(other, y_margin)
145 |         return x_overlap * y_overlap
146 | 
147 |     def x_overlap(self, other, x_margin=0):
148 |         return max(
149 |             0,
150 |             min(self.bbox[2] + x_margin, other.bbox[2] + x_margin)
151 |             - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin),
152 |         )
153 | 
154 |     def y_overlap(self, other, y_margin=0):
155 |         return max(
156 |             0,
157 |             min(self.bbox[3] + y_margin, other.bbox[3] + y_margin)
158 |             - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin),
159 |         )
160 | 
161 |     def intersection_pct(self, other, x_margin=0, y_margin=0):
162 |         assert 0 <= x_margin <= 1
163 |         assert 0 <= y_margin <= 1
164 |         if self.area == 0:
165 |             return 0
166 | 
167 |         if x_margin:
168 |             x_margin = int(min(self.width, other.width) * x_margin)
169 |         if y_margin:
170 |             y_margin = int(min(self.height, other.height) * y_margin)
171 | 
172 |         intersection = self.intersection_area(other, x_margin, y_margin)
173 |         return intersection / self.area
174 | 
175 |     def shift(self, x_shift: float | None = None, y_shift: float | None = None):
176 |         if x_shift is not None:
177 |             for corner in self.polygon:
178 |                 corner[0] += x_shift
179 |         if y_shift is not None:
180 |             for corner in self.polygon:
181 |                 corner[1] += y_shift
182 | 
183 |     def clamp(self, bbox: List[float]):
184 |         for corner in self.polygon:
185 |             corner[0] = max(min(corner[0], bbox[2]), bbox[0])
186 |             corner[1] = max(min(corner[1], bbox[3]), bbox[1])
187 | 
188 |     @property
189 |     def center(self):
190 |         return [(self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2]
191 | 
192 |     def distance(self, other):
193 |         center = self.center
194 |         other_center = other.center
195 | 
196 |         return (
197 |             (center[0] - other_center[0]) ** 2 + (center[1] - other_center[1]) ** 2
198 |         ) ** 0.5
199 | 
200 |     def __hash__(self):
201 |         return hash(tuple(self.bbox))
202 | 
```
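
A short usage sketch for `PolygonBox` with toy coordinates, showing the bbox-to-polygon conversion in the validator, the overlap helper, and `rescale` from processor space to image space:

```python
# Toy usage of PolygonBox; all coordinates are made up.
from surya.common.polygon import PolygonBox

box = PolygonBox(polygon=[0, 0, 100, 50])  # a bbox is auto-converted to 4 corners
other = PolygonBox(polygon=[[50, 0], [150, 0], [150, 50], [50, 50]])

print(box.bbox)                     # [0.0, 0.0, 100.0, 50.0]
print(box.intersection_pct(other))  # 0.5 -> half of `box` overlaps `other`

box.rescale((200, 100), (400, 200))  # (processor_size, image_size) doubles coordinates
print(box.bbox)                      # [0, 0, 200, 100]
```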

--------------------------------------------------------------------------------
/surya/settings.py:
--------------------------------------------------------------------------------

```python
  1 | import os
  2 | from typing import Callable, Dict, Optional
  3 | 
  4 | import torch
  5 | from dotenv import find_dotenv
  6 | from pydantic import computed_field
  7 | from pydantic_settings import BaseSettings
  8 | from pathlib import Path
  9 | from platformdirs import user_cache_dir
 10 | 
 11 | 
 12 | class Settings(BaseSettings):
 13 |     # General
 14 |     TORCH_DEVICE: Optional[str] = None
 15 |     IMAGE_DPI: int = 96  # Used for detection, layout, reading order
 16 |     IMAGE_DPI_HIGHRES: int = 192  # Used for OCR, table rec
 17 |     IN_STREAMLIT: bool = False  # Whether we're running in streamlit
 18 |     FLATTEN_PDF: bool = True  # Flatten PDFs by merging form fields before processing
 19 |     DISABLE_TQDM: bool = False  # Disable tqdm progress bars
 20 |     S3_BASE_URL: str = "https://models.datalab.to"
 21 |     PARALLEL_DOWNLOAD_WORKERS: int = (
 22 |         10  # Number of workers for parallel model downloads
 23 |     )
 24 |     MODEL_CACHE_DIR: str = str(Path(user_cache_dir("datalab")) / "models")
 25 |     LOGLEVEL: str = "INFO"  # Logging level
 26 | 
 27 |     # Paths
 28 |     DATA_DIR: str = "data"
 29 |     RESULT_DIR: str = "results"
 30 |     BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 31 |     FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts")
 32 | 
 33 |     @computed_field
 34 |     def TORCH_DEVICE_MODEL(self) -> str:
 35 |         if self.TORCH_DEVICE is not None:
 36 |             return self.TORCH_DEVICE
 37 | 
 38 |         if torch.cuda.is_available():
 39 |             return "cuda"
 40 | 
 41 |         if torch.backends.mps.is_available():
 42 |             return "mps"
 43 | 
 44 |         try:
 45 |             import torch_xla
 46 | 
 47 |             if len(torch_xla.devices()) > 0:
 48 |                 return "xla"
 49 |         except Exception:
 50 |             pass
 51 | 
 52 |         return "cpu"
 53 | 
 54 |     # Text detection
 55 |     DETECTOR_BATCH_SIZE: Optional[int] = None  # Defaults to 2 for CPU/MPS, 32 otherwise
 56 |     DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_05_07"
 57 |     DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench"
 58 |     DETECTOR_IMAGE_CHUNK_HEIGHT: int = (
 59 |         1400  # Height at which to slice images vertically
 60 |     )
 61 |     DETECTOR_TEXT_THRESHOLD: float = (
 62 |         0.6  # Threshold for text detection (above this is considered text)
 63 |     )
 64 |     DETECTOR_BLANK_THRESHOLD: float = (
 65 |         0.35  # Threshold for blank space (below this is considered blank)
 66 |     )
 67 |     DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(
 68 |         8, os.cpu_count()
 69 |     )  # Number of workers for postprocessing
 70 |     DETECTOR_MIN_PARALLEL_THRESH: int = (
 71 |         3  # Minimum number of images before we parallelize
 72 |     )
 73 |     DETECTOR_BOX_Y_EXPAND_MARGIN: float = (
 74 |         0.05  # Margin by which to expand detected boxes vertically
 75 |     )
 76 |     COMPILE_DETECTOR: bool = False
 77 | 
 78 |     # Text recognition
 79 |     FOUNDATION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23"
 80 |     FOUNDATION_MODEL_QUANTIZE: bool = False
 81 |     FOUNDATION_MAX_TOKENS: Optional[int] = None
 82 |     FOUNDATION_CHUNK_SIZE: Optional[int] = None
 83 |     FOUNDATION_PAD_TO_NEAREST: int = 256
 84 |     COMPILE_FOUNDATION: bool = False
 85 |     FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE: float = 0.9
 86 | 
 87 |     RECOGNITION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23"
 88 |     RECOGNITION_BATCH_SIZE: Optional[int] = (
 89 |         None  # Defaults to 8 for CPU/MPS, 256 otherwise
 90 |     )
 91 |     RECOGNITION_RENDER_FONTS: Dict[str, str] = {
 92 |         "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"),
 93 |         "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
 94 |         "ja": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
 95 |         "ko": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"),
 96 |     }
 97 |     RECOGNITION_FONT_DL_BASE: str = (
 98 |         "https://github.com/satbyy/go-noto-universal/releases/download/v7.0"
 99 |     )
100 |     RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench"
101 |     RECOGNITION_PAD_VALUE: int = 255  # Should be 0 or 255
102 | 
103 |     # Layout
104 |     LAYOUT_MODEL_CHECKPOINT: str = "s3://layout/2025_09_23"
105 |     LAYOUT_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
106 |     LAYOUT_SLICE_MIN: Dict = {
107 |         "height": 1500,
108 |         "width": 1500,
109 |     }  # When to start slicing images
110 |     LAYOUT_SLICE_SIZE: Dict = {"height": 1200, "width": 1200}  # Size of slices
111 |     LAYOUT_BATCH_SIZE: Optional[int] = None
112 |     LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench"
113 |     LAYOUT_MAX_BOXES: int = 100
114 |     COMPILE_LAYOUT: bool = False
116 |     ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"
117 | 
118 |     # Table Rec
119 |     TABLE_REC_MODEL_CHECKPOINT: str = "s3://table_recognition/2025_02_18"
120 |     TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
121 |     TABLE_REC_MAX_BOXES: int = 150
122 |     TABLE_REC_BATCH_SIZE: Optional[int] = None
123 |     TABLE_REC_BENCH_DATASET_NAME: str = "datalab-to/fintabnet_bench"
124 |     COMPILE_TABLE_REC: bool = False
125 | 
126 |     # Texify
127 |     TEXIFY_BENCHMARK_DATASET: str = "datalab-to/texify_bench"
128 | 
129 |     # OCR Error Detection
130 |     OCR_ERROR_MODEL_CHECKPOINT: str = "s3://ocr_error_detection/2025_02_18"
131 |     OCR_ERROR_BATCH_SIZE: Optional[int] = None
132 |     COMPILE_OCR_ERROR: bool = False
133 | 
134 |     # Tesseract (for benchmarks only)
135 |     TESSDATA_PREFIX: Optional[str] = None
136 | 
137 |     COMPILE_ALL: bool = False
138 | 
139 |     @computed_field
140 |     def DETECTOR_STATIC_CACHE(self) -> bool:
141 |         return (
142 |             self.COMPILE_ALL
143 |             or self.COMPILE_DETECTOR
144 |             or self.TORCH_DEVICE_MODEL == "xla"
145 |         )  # We need to static cache and pad to batch size for XLA, since it will recompile otherwise
146 | 
147 |     @computed_field
148 |     def LAYOUT_STATIC_CACHE(self) -> bool:
149 |         return (
150 |             self.COMPILE_ALL or self.COMPILE_LAYOUT or self.TORCH_DEVICE_MODEL == "xla"
151 |         )
152 | 
153 |     @computed_field
154 |     def FOUNDATION_XLA(self) -> bool:
155 |         return (
156 |             self.TORCH_DEVICE_MODEL == "xla"
157 |         )  # We need to static cache and pad to batch size for XLA, since it will recompile otherwise
158 | 
159 |     @computed_field
160 |     def FOUNDATION_STATIC_CACHE(self) -> bool:
161 |         return (
162 |             self.COMPILE_ALL
163 |             or self.COMPILE_FOUNDATION
164 |             or self.TORCH_DEVICE_MODEL == "xla"
165 |         )  # We need to static cache and pad to batch size for XLA, since it will recompile otherwise
166 | 
167 |     @computed_field
168 |     def TABLE_REC_STATIC_CACHE(self) -> bool:
169 |         return (
170 |             self.COMPILE_ALL
171 |             or self.COMPILE_TABLE_REC
172 |             or self.TORCH_DEVICE_MODEL == "xla"
173 |         )
174 | 
175 |     @computed_field
176 |     def OCR_ERROR_STATIC_CACHE(self) -> bool:
177 |         return (
178 |             self.COMPILE_ALL
179 |             or self.COMPILE_OCR_ERROR
180 |             or self.TORCH_DEVICE_MODEL == "xla"
181 |         )
182 | 
183 |     @computed_field
184 |     def MODEL_DTYPE(self) -> torch.dtype:
185 |         if self.TORCH_DEVICE_MODEL == "cpu":
186 |             return torch.float32
187 |         if self.TORCH_DEVICE_MODEL == "xla":
188 |             return torch.bfloat16
189 |         return torch.float16
190 | 
191 |     @computed_field
192 |     def MODEL_DTYPE_BFLOAT(self) -> torch.dtype:
193 |         if self.TORCH_DEVICE_MODEL == "cpu":
194 |             return torch.float32
195 |         if self.TORCH_DEVICE_MODEL == "mps":
196 |             return torch.bfloat16
197 |         return torch.bfloat16
198 | 
199 |     @computed_field
200 |     def INFERENCE_MODE(self) -> Callable:
201 |         if self.TORCH_DEVICE_MODEL == "xla":
202 |             return torch.no_grad
203 |         return torch.inference_mode
204 | 
205 |     class Config:
206 |         env_file = find_dotenv("local.env")
207 |         extra = "ignore"
208 | 
209 | 
210 | settings = Settings()
211 | 
```
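
Because `Settings` is a pydantic `BaseSettings` subclass (with `local.env` as an optional env file), any field can be overridden through environment variables. A minimal sketch; the chosen values are just examples:

```python
# Example of overriding settings via environment variables (values are examples).
import os

os.environ["TORCH_DEVICE"] = "cpu"       # pin inference to CPU
os.environ["COMPILE_DETECTOR"] = "true"  # turns on DETECTOR_STATIC_CACHE

from surya.settings import Settings

s = Settings()
print(s.TORCH_DEVICE_MODEL)     # "cpu"
print(s.DETECTOR_STATIC_CACHE)  # True
print(s.MODEL_DTYPE)            # torch.float32 on CPU
```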

--------------------------------------------------------------------------------
/surya/foundation/cache/static_ops.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Any, Dict, List, Optional, Tuple
  2 | import torch
  3 | from transformers import PretrainedConfig
  4 | 
  5 | from surya.foundation.cache.dynamic_ops import DynamicOpsCache
  6 | 
  7 | """
  8 | Special cache class for the surya foundation model that supports:
  9 | 1) Static shapes
 10 | 2) A custom sliding window, where image tokens stay in the cache and text tokens are popped
 11 | 3) Continuous batching (merging, etc.)
 12 | 4) Attention mask management, kept in sync with what's currently in the cache
 13 | 
 14 | Heavily inspired by https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079
 15 | """
 16 | 
 17 | 
 18 | class StaticOpsCache(DynamicOpsCache):
 19 |     def __init__(
 20 |         self,
 21 |         config: PretrainedConfig,
 22 |         batch_size: int,
 23 |         max_cache_len: int,
 24 |         text_sliding_window: int,
 25 |         device: torch.device | str,
 26 |         dtype: torch.dtype,
 27 |     ):
 28 |         self.text_sliding_window = text_sliding_window
 29 |         self.num_layers = config.num_hidden_layers
 30 |         self.max_batch_size = batch_size
 31 |         self.max_cache_len = max_cache_len
 32 |         self.head_dim = (
 33 |             getattr(config, "head_dim", None)
 34 |             or config.hidden_size // config.num_attention_heads
 35 |         )
 36 |         self._dtype = dtype
 37 |         self.num_key_value_heads = (
 38 |             config.num_attention_heads
 39 |             if getattr(config, "num_key_value_heads", None) is None
 40 |             else config.num_key_value_heads
 41 |         )
 42 | 
 43 |         # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125
 44 |         self.key_cache: list[torch.Tensor] = []
 45 |         self.value_cache: list[torch.Tensor] = []
 46 |         cache_shape = (
 47 |             self.max_batch_size,
 48 |             self.num_key_value_heads,
 49 |             self.max_cache_len,
 50 |             self.head_dim,
 51 |         )
 52 |         device = torch.device(device) if device is not None else None
 53 |         for _ in range(config.num_hidden_layers):
 54 |             new_layer_key_cache = torch.zeros(
 55 |                 cache_shape, dtype=self._dtype, device=device
 56 |             )
 57 |             new_layer_value_cache = torch.zeros(
 58 |                 cache_shape, dtype=self._dtype, device=device
 59 |             )
 60 |             torch._dynamo.mark_static_address(new_layer_key_cache)
 61 |             torch._dynamo.mark_static_address(new_layer_value_cache)
 62 |             self.key_cache.append(new_layer_key_cache)
 63 |             self.value_cache.append(new_layer_value_cache)
 64 | 
 65 |         self.attention_mask = torch.zeros(
 66 |             (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long
 67 |         )
 68 |         self.text_token_counts = [
 69 |             torch.zeros(self.max_batch_size, dtype=torch.long, device=device)
 70 |             for _ in range(self.num_layers)
 71 |         ]
 72 | 
 73 |         self.dtype = dtype
 74 |         self.device = device
 75 | 
 76 |     def update(
 77 |         self,
 78 |         key_states: torch.Tensor,
 79 |         value_states: torch.Tensor,
 80 |         layer_idx: int,
 81 |         cache_kwargs: Optional[Dict[str, Any]] = None,
 82 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
 83 |         prefill = cache_kwargs.get("prefill", False)
 84 |         update_fn = self._prefill_update if prefill else self._decode_update
 85 |         return update_fn(
 86 |             self.key_cache[layer_idx],
 87 |             self.value_cache[layer_idx],
 88 |             key_states,
 89 |             value_states,
 90 |             self.text_token_counts[layer_idx],
 91 |             cache_kwargs,
 92 |         )
 93 | 
 94 |     def _prefill_update(
 95 |         self,
 96 |         key_cache: torch.Tensor,
 97 |         value_cache: torch.Tensor,
 98 |         key_states: torch.Tensor,
 99 |         value_states: torch.Tensor,
100 |         text_token_counts: torch.Tensor,
101 |         cache_kwargs: Optional[Dict[str, Any]] = None,
102 |     ):
103 |         cache_idxs: torch.Tensor = cache_kwargs.get("cache_idxs", None)
104 |         text_lengths: List[int] = cache_kwargs.get("text_lengths", None)
105 |         assert cache_idxs is not None, "cache_idxs must be specified during prefill"
106 |         assert text_lengths is not None, "text_lengths must be specified during prefill"
107 | 
108 |         cache_idx_length = len(cache_idxs)
109 |         full_batch = len(cache_idxs) == self.max_batch_size
110 | 
111 |         # Insert key and value states at the end of the cache
112 |         new_tokens = key_states.shape[2]
113 | 
114 |         # Direct right-aligned assignment
115 |         if full_batch:
116 |             key_cache[:, :, -new_tokens:] = key_states
117 |             value_cache[:, :, -new_tokens:] = value_states
118 |         else:
119 |             key_cache[cache_idxs, :, -new_tokens:] = key_states[:cache_idx_length]
120 |             value_cache[cache_idxs, :, -new_tokens:] = value_states[:cache_idx_length]
121 | 
122 |         return key_states, value_states
123 | 
124 |     # """
125 |     # Matches the logic of the decode update, but needs to be called before the updates
126 |     # since some parts of the model depend on the attention mask
127 |     # """
128 |     def decode_attention_mask_update(
129 |         self, num_valid_tokens: torch.Tensor, cache_idxs: List[int]
130 |     ):
131 |         max_valid_tokens = num_valid_tokens.max().item()
132 |         if max_valid_tokens == 0:
133 |             # If no valid tokens, we don't need to update the attention mask
134 |             return
135 | 
136 |         # Shift the attention mask to the left by max_valid_tokens
137 |         self.attention_mask = self.attention_mask.roll(-1 * max_valid_tokens, dims=1)
138 |         self.attention_mask[:, -max_valid_tokens:] = (
139 |             1  # Full attention to all new tokens
140 |         )
141 | 
142 |     # Mirrors the logic from _prefill_update
143 |     def prefill_attention_mask_update(
144 |         self,
145 |         attention_mask: torch.Tensor,
146 |         merge_idxs: torch.Tensor,
147 |         valid_batch_size: torch.Tensor,
148 |         text_lengths: List[int],
149 |     ):
150 |         # Set the mask to 1 from -(image_length + text_length) through the end for each batch element
151 |         seq_len = attention_mask.shape[1]
152 |         self.attention_mask[merge_idxs] = (
153 |             0  # Reset the attention mask for the current batch elements
154 |         )
155 |         self.attention_mask[merge_idxs, -seq_len:] = attention_mask[:valid_batch_size]
156 | 
157 |     def _decode_update(
158 |         self,
159 |         key_cache: torch.Tensor,
160 |         value_cache: torch.Tensor,
161 |         key_states: torch.Tensor,
162 |         value_states: torch.Tensor,
163 |         text_token_counts: torch.Tensor,
164 |         cache_kwargs: Optional[Dict[str, Any]] = None,
165 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
166 |         # Naive, always assumes we'll roll by a fixed amount
167 |         # Needs left padding with beacons to work properly
168 | 
169 |         num_valid_tokens: torch.Tensor = cache_kwargs.get(
170 |             "num_valid_tokens"
171 |         )  # shape: (B,)
172 |         assert num_valid_tokens is not None, (
173 |             "`num_valid_tokens` must be provided in `cache_kwargs`"
174 |         )
175 |         # (B, H, L, D)
176 | 
177 |         valid_tokens = key_states.shape[2]
178 | 
179 |         key_cache.copy_(torch.roll(key_cache, -valid_tokens, dims=2))
180 |         value_cache.copy_(torch.roll(value_cache, -valid_tokens, dims=2))
181 | 
182 |         key_cache[:, :, -valid_tokens:, :] = key_states
183 |         value_cache[:, :, -valid_tokens:, :] = value_states
184 | 
185 |         # In-place edit - Mutates
186 |         text_token_counts += num_valid_tokens
187 |         text_token_counts.clamp_(max=self.text_sliding_window)
188 |         return key_cache, value_cache
189 | 
190 |     # The attention mask managed by our kv cache automatically masks the tokens
191 |     # in the cache, so we can return full length for HF to use in other places
192 |     # This is mainly utilized in the cache_positions creation
193 |     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
194 |         return self.max_cache_len
195 | 
```
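
A minimal standalone sketch of the roll-and-write pattern used by `_decode_update`, shown on toy tensors; the shapes and cache length below are illustrative assumptions, not values from the model config. The prefill path writes right-aligned directly (no roll is needed while the cache is still zeroed).

```python
import torch

# Toy shapes for illustration only (batch, heads, cache length, head dim).
batch, heads, cache_len, head_dim = 2, 4, 8, 16
key_cache = torch.zeros(batch, heads, cache_len, head_dim)

# One newly decoded token per sequence in this step.
new_tokens = 1
key_states = torch.randn(batch, heads, new_tokens, head_dim)

# Shift the whole cache left by the number of new tokens, then write the new
# states into the freed, right-aligned slots -- the same in-place
# roll + copy_ pattern as _decode_update above.
key_cache.copy_(torch.roll(key_cache, -new_tokens, dims=2))
key_cache[:, :, -new_tokens:, :] = key_states

# The newest token always sits at the right edge of the cache.
assert torch.equal(key_cache[:, :, -1, :], key_states[:, :, 0, :])
```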

--------------------------------------------------------------------------------
/surya/table_rec/model/config.py:
--------------------------------------------------------------------------------

```python
  1 | from dataclasses import dataclass
  2 | from typing import Dict
  3 | 
  4 | import torch
  5 | from transformers import PretrainedConfig
  6 | from transformers.utils import ModelOutput
  7 | 
  8 | from surya.common.s3 import S3DownloaderMixin
  9 | from surya.settings import settings
 10 | 
 11 | BOX_DIM = 1024
 12 | SPECIAL_TOKENS = 5
 13 | MAX_BOXES = 150
 14 | 
 15 | MERGE_KEYS = {
 16 |     "none": 0,
 17 |     "merge_up": 1,
 18 |     "merge_down": 2,
 19 |     "merge_both": 3
 20 | }
 21 | MERGE_VALUES = [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]]
 22 | 
 23 | ID_TO_CATEGORY = {
 24 |     0: 'Blank',
 25 |     1: 'Table-row',
 26 |     2: 'Table-column',
 27 |     3: 'Table-cell',
 28 |     4: 'Table'
 29 | }
 30 | CATEGORY_TO_ID = {v: k for k, v in ID_TO_CATEGORY.items()}
 31 | 
 32 | ID_TO_HEADER = {
 33 |     0: "None",
 34 |     1: "Header"
 35 | }
 36 | HEADER_TO_ID = {v: k for k, v in ID_TO_HEADER.items()}
 37 | 
 38 | BOX_PROPERTIES = [
 39 |     ("bbox", 6, "regression"),
 40 |     ("category", len(ID_TO_CATEGORY), "classification"),
 41 |     ("merges", len(MERGE_KEYS), "classification"),
 42 |     ("colspan", 1, "regression"),
 43 |     ("is_header", len(ID_TO_HEADER), "classification")
 44 | ]
 45 | 
 46 | 
 47 | @dataclass
 48 | class TableRecModelOutput(ModelOutput):
 49 |     box_property_logits: Dict[str, torch.Tensor]
 50 |     hidden_states: torch.Tensor | None = None
 51 | 
 52 | 
 53 | class SuryaTableRecConfig(S3DownloaderMixin, PretrainedConfig):
 54 |     model_type = "vision-encoder-decoder"
 55 |     is_composition = True
 56 | 
 57 |     def __init__(self, **kwargs):
 58 |         super().__init__(**kwargs)
 59 | 
 60 |         if "encoder" in kwargs:
 61 |             encoder_config = kwargs.pop("encoder")
 62 |             decoder_config = kwargs.pop("decoder")
 63 |         else:
 64 |             encoder_config = DonutSwinTableRecConfig()
 65 |             decoder_config = SuryaTableRecDecoderConfig()
 66 | 
 67 |         self.encoder = encoder_config
 68 |         self.decoder = decoder_config
 69 |         self.is_encoder_decoder = True
 70 | 
 71 |         if isinstance(decoder_config, dict):
 72 |             self.decoder_start_token_id = decoder_config["bos_token_id"]
 73 |             self.pad_token_id = decoder_config["pad_token_id"]
 74 |             self.eos_token_id = decoder_config["eos_token_id"]
 75 |         else:
 76 |             self.decoder_start_token_id = decoder_config.bos_token_id
 77 |             self.pad_token_id = decoder_config.pad_token_id
 78 |             self.eos_token_id = decoder_config.eos_token_id
 79 | 
 80 | 
 81 | class DonutSwinTableRecConfig(PretrainedConfig):
 82 |     model_type = "donut-swin"
 83 | 
 84 |     attribute_map = {
 85 |         "num_attention_heads": "num_heads",
 86 |         "num_hidden_layers": "num_layers",
 87 |     }
 88 | 
 89 |     def __init__(
 90 |         self,
 91 |         image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]),
 92 |         patch_size=4,
 93 |         num_channels=3,
 94 |         embed_dim=128,
 95 |         depths=[2, 2, 12, 2],
 96 |         num_heads=[4, 8, 16, 32],
 97 |         num_kv_heads=[4, 8, 16, 32],
 98 |         window_size=8,
 99 |         mlp_ratio=4.0,
100 |         qkv_bias=True,
101 |         hidden_dropout_prob=0.0,
102 |         attention_probs_dropout_prob=0.0,
103 |         drop_path_rate=0.1,
104 |         hidden_act="gelu",
105 |         use_absolute_embeddings=False,
106 |         initializer_range=0.02,
107 |         layer_norm_eps=1e-5,
108 |         encoder_length=1024,
109 |         use_positional_embeddings=True,
110 |         **kwargs,
111 |     ):
112 |         super().__init__(**kwargs)
113 | 
114 |         self.image_size = image_size
115 |         self.patch_size = patch_size
116 |         self.num_channels = num_channels
117 |         self.embed_dim = embed_dim
118 |         self.depths = depths
119 |         self.num_layers = len(depths)
120 |         self.num_heads = num_heads
121 |         self.num_kv_heads = num_kv_heads
122 |         self.window_size = window_size
123 |         self.mlp_ratio = mlp_ratio
124 |         self.qkv_bias = qkv_bias
125 |         self.hidden_dropout_prob = hidden_dropout_prob
126 |         self.attention_probs_dropout_prob = attention_probs_dropout_prob
127 |         self.drop_path_rate = drop_path_rate
128 |         self.hidden_act = hidden_act
129 |         self.use_absolute_embeddings = use_absolute_embeddings
130 |         self.layer_norm_eps = layer_norm_eps
131 |         self.initializer_range = initializer_range
132 |         # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
133 |         # this indicates the channel dimension after the last stage of the model
134 |         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
135 |         self.encoder_length = encoder_length
136 |         self.use_positional_embeddings = use_positional_embeddings
137 | 
138 | 
139 | class SuryaTableRecDecoderConfig(PretrainedConfig):
140 |     model_type = "surya_tablerec"
141 | 
142 |     def __init__(
143 |         self,
144 |         num_hidden_layers=6,
145 |         vocab_size=BOX_DIM + 1,
146 |         bbox_size=BOX_DIM,
147 |         hidden_size=512,
148 |         property_embed_size=64,
149 |         box_embed_size=512 - 64,
150 |         intermediate_size=4 * 512,
151 |         encoder_hidden_size=1024,
152 |         num_attention_heads=8,
153 |         lru_width=None,
154 |         attention_window_size=16,
155 |         conv1d_width=4,
156 |         logits_soft_cap=30.0,
157 |         rms_norm_eps=1e-6,
158 |         use_cache=True,
159 |         pad_token_id=0,
160 |         eos_token_id=1,
161 |         bos_token_id=1,
162 |         pause_token_id=2,
163 |         query_end_token_id=4,
164 |         hidden_activation="gelu_pytorch_tanh",
165 |         rope_theta=10000.0,
166 |         block_types=("attention",),
167 |         cross_attn_layers=tuple(range(10)),
168 |         encoder_cross_attn_layers=tuple(range(10)),
169 |         self_attn_layers=tuple(range(10)),
170 |         global_attn_layers=tuple(range(10)),
171 |         attention_dropout=0.0,
172 |         num_key_value_heads=4,
173 |         attention_bias=False,
174 |         w_init_variance_scale=0.01,
175 |         init_std=0.02,
176 |         tie_word_embeddings=False,
177 |         aux_heads=0, # How many n-token-ahead heads to add
178 |         causal=True,
179 |         layer_norm_eps=1e-5,
180 |         dropout=0.0,
181 |         special_token_count=SPECIAL_TOKENS,
182 |         **kwargs,
183 |     ):
184 |         self.num_hidden_layers = num_hidden_layers
185 |         self.vocab_size = vocab_size
186 |         self.hidden_size = hidden_size
187 |         self.intermediate_size = intermediate_size
188 |         self.num_attention_heads = num_attention_heads
189 |         self.lru_width = lru_width if lru_width is not None else hidden_size
190 |         self.attention_window_size = attention_window_size
191 |         self.conv1d_width = conv1d_width
192 |         self.logits_soft_cap = logits_soft_cap
193 |         self.rms_norm_eps = rms_norm_eps
194 |         self.use_cache = use_cache
195 |         self.rope_theta = rope_theta
196 |         self.block_types = list(block_types)
197 |         self.hidden_activation = hidden_activation
198 |         self.head_dim = self.hidden_size // self.num_attention_heads
199 |         self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
200 |         if self.num_key_value_heads > self.num_attention_heads:
201 |             raise ValueError("`num_key_value_heads` must not be greater than `num_attention_heads`")
202 |         self.cross_attn_layers = cross_attn_layers
203 |         self.self_attn_layers = self_attn_layers
204 |         self.global_attn_layers = global_attn_layers
205 |         self.attention_dropout = attention_dropout
206 |         self.attention_bias = attention_bias
207 |         self.w_init_variance_scale = w_init_variance_scale
208 |         self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
209 |         self.init_std = init_std
210 |         self.tie_word_embeddings = tie_word_embeddings
211 |         self.aux_heads = aux_heads
212 |         self.encoder_hidden_size = encoder_hidden_size
213 |         self.causal = causal
214 |         self.encoder_cross_attn_layers = encoder_cross_attn_layers
215 |         self.layer_norm_eps = layer_norm_eps
216 |         self.dropout = dropout
217 |         self.bbox_size = bbox_size
218 |         self.pause_token_id = pause_token_id
219 |         self.box_properties = BOX_PROPERTIES
220 |         self.property_embed_size = property_embed_size
221 |         self.box_embed_size = box_embed_size
222 |         self.special_token_count = special_token_count
223 |         self.query_end_token_id = query_end_token_id
224 |         self.double_residual_flow = False
225 | 
226 |         super().__init__(
227 |             pad_token_id=pad_token_id,
228 |             bos_token_id=bos_token_id,
229 |             eos_token_id=eos_token_id,
230 |             **kwargs,
231 |         )
232 | 
233 |     @property
234 |     def layers_block_type(self):
235 |         return (self.block_types * 100)[: self.num_hidden_layers]
```
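
For reference, a short sketch of how the lookup tables and the `layers_block_type` property behave with the defaults above. The import path assumes the package layout shown in this dump (`surya.table_rec.model.config`).

```python
from surya.table_rec.model.config import (
    CATEGORY_TO_ID,
    ID_TO_CATEGORY,
    MERGE_KEYS,
    SuryaTableRecDecoderConfig,
)

config = SuryaTableRecDecoderConfig()

# block_types is tiled and truncated so that every hidden layer gets a type;
# with the defaults this is simply "attention" repeated 6 times.
assert config.layers_block_type == ["attention"] * config.num_hidden_layers

# The category map and its inverse round-trip cleanly.
assert CATEGORY_TO_ID[ID_TO_CATEGORY[4]] == 4  # "Table"

# Merge labels are plain integer codes.
assert MERGE_KEYS["merge_both"] == 3
```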
Page 2/5