This is page 2 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmark │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── CITATION.cff ├── CLA.md ├── detect_layout.py ├── detect_text.py ├── LICENSE ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ ├── fonts │ │ └── .gitignore │ └── images │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── arabic.jpg │ ├── benchmark_chart_small.png │ ├── benchmark_chart.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chi_hind.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── chinese.jpg │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── excerpt.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── funsd.png │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── hindi.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── japanese.jpg │ ├── latex_ocr.png │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── nyt.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── paper.jpg │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── pres.png │ ├── rec_acc_table.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── scanned.png │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ ├── textbook_text.jpg │ └── textbook.jpg ├── surya │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── adetr │ │ │ └── decoder.py │ │ ├── donut │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder │ │ │ │ └── __init__.py │ │ │ ├── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation │ │ ├── __init__.py │ │ ├── cache 
│ │ │ ├── __init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input │ │ ├── load.py │ │ └── processing.py │ ├── layout │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests │ ├── assets │ │ └── test_latex.png │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ``` # Files -------------------------------------------------------------------------------- /benchmark/utils/tatr.py: -------------------------------------------------------------------------------- ```python 1 | import torch 2 | from transformers import AutoModelForObjectDetection 3 | from surya.settings import settings 4 | import numpy as np 5 | 6 | 7 | class MaxResize(object): 8 | def __init__(self, max_size=800): 9 | self.max_size = max_size 10 | 11 | def __call__(self, image): 12 | width, height = image.size 13 | current_max_size = max(width, height) 14 | scale = self.max_size / current_max_size 15 | resized_image = image.resize((int(round(scale * width)), int(round(scale * height)))) 16 | 17 | return resized_image 18 | 19 | 20 | def to_tensor(image): 21 | # Convert PIL Image to NumPy array 22 | np_image = np.array(image).astype(np.float32) 23 | 24 | # Rearrange dimensions to [C, H, W] format 25 | np_image = np_image.transpose((2, 0, 1)) 26 | 27 | # Normalize to [0.0, 1.0] 28 | np_image /= 255.0 29 | 30 | return torch.from_numpy(np_image) 31 | 32 | 33 | def normalize(tensor, mean, std): 34 | for t, m, s in zip(tensor, mean, std): 35 | t.sub_(m).div_(s) 36 | return tensor 37 | 38 | 39 | def structure_transform(image): 40 | image = MaxResize(1000)(image) 41 | tensor = to_tensor(image) 42 | normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 43 | return normalized_tensor 44 | 45 | 46 | def box_cxcywh_to_xyxy(x): 47 | x_c, y_c, w, h = x.unbind(-1) 48 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 49 | return torch.stack(b, dim=1) 50 | 51 | 52 | def rescale_bboxes(out_bbox, size): 53 | width, height = size 54 | boxes = box_cxcywh_to_xyxy(out_bbox) 55 | boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32) 56 | return boxes 57 | 58 | 59 | def outputs_to_objects(outputs, img_sizes, id2label): 60 | m = outputs.logits.softmax(-1).max(-1) 61 | batch_labels = list(m.indices.detach().cpu().numpy()) 62 | batch_scores = list(m.values.detach().cpu().numpy()) 63 | batch_bboxes = outputs['pred_boxes'].detach().cpu() 64 | 65 | batch_objects = [] 66 | for i in 
range(len(img_sizes)): 67 | pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])] 68 | pred_scores = batch_scores[i] 69 | pred_labels = batch_labels[i] 70 | 71 | objects = [] 72 | for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes): 73 | class_label = id2label[int(label)] 74 | if not class_label == 'no object': 75 | objects.append({ 76 | 'label': class_label, 77 | 'score': float(score), 78 | 'bbox': [float(elem) for elem in bbox]} 79 | ) 80 | 81 | rows = [] 82 | cols = [] 83 | for cell in objects: 84 | if cell["label"] == "table column": 85 | cols.append(cell) 86 | 87 | if cell["label"] == "table row": 88 | rows.append(cell) 89 | batch_objects.append({ 90 | "rows": rows, 91 | "cols": cols 92 | }) 93 | 94 | return batch_objects 95 | 96 | 97 | def load_tatr(): 98 | return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL) 99 | 100 | 101 | def batch_inference_tatr(model, images, batch_size): 102 | device = model.device 103 | rows_cols = [] 104 | for i in range(0, len(images), batch_size): 105 | batch_images = images[i:i + batch_size] 106 | pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device) 107 | 108 | # forward pass 109 | with torch.no_grad(): 110 | outputs = model(pixel_values) 111 | 112 | id2label = model.config.id2label 113 | id2label[len(model.config.id2label)] = "no object" 114 | rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label)) 115 | return rows_cols ``` -------------------------------------------------------------------------------- /benchmark/texify.py: -------------------------------------------------------------------------------- ```python 1 | import os.path 2 | import re 3 | import time 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import click 8 | import datasets 9 | from tabulate import tabulate 10 | from bs4 import BeautifulSoup 11 | 12 | from surya.common.surya.schema import TaskNames 13 | from surya.settings import settings 14 | from surya.foundation import FoundationPredictor 15 | from surya.recognition import RecognitionPredictor, OCRResult 16 | import json 17 | from rapidfuzz.distance import Levenshtein 18 | 19 | 20 | def normalize_text(text): 21 | soup = BeautifulSoup(text, "html.parser") 22 | # Unwrap math tags 23 | for tag in soup.find_all(): 24 | if tag.name == "math": 25 | tag.unwrap() 26 | text = soup.get_text() 27 | text = re.sub(r"\n", " ", text) 28 | text = re.sub(r"\s+", " ", text) 29 | return text.strip() 30 | 31 | 32 | def score_text(predictions, references): 33 | lev_dist = [] 34 | for p, r in zip(predictions, references): 35 | p = normalize_text(p) 36 | r = normalize_text(r) 37 | lev_dist.append(Levenshtein.normalized_distance(p, r)) 38 | 39 | return sum(lev_dist) / len(lev_dist) 40 | 41 | 42 | def inference_texify( 43 | source_data, predictor: RecognitionPredictor, line_mode: bool = False 44 | ): 45 | images = [sd["image"] for sd in source_data] 46 | mode = TaskNames.ocr_with_boxes if line_mode else TaskNames.block_without_boxes 47 | tasks = [mode] * len(images) 48 | bboxes = [[[0, 0, image.width, image.height]] for image in images] 49 | texify_predictions: List[OCRResult] = predictor(images, tasks, bboxes=bboxes) 50 | out_data = [ 51 | { 52 | "text": texify_predictions[i].text_lines[0].text, 53 | "equation": source_data[i]["equation"], 54 | } 55 | for i in range(len(texify_predictions)) 56 | ] 57 | 58 | return 
out_data 59 | 60 | 61 | @click.command(help="Benchmark the performance of texify.") 62 | @click.option( 63 | "--ds_name", 64 | type=str, 65 | help="Path to dataset file with source images/equations.", 66 | default=settings.TEXIFY_BENCHMARK_DATASET, 67 | ) 68 | @click.option( 69 | "--results_dir", 70 | type=str, 71 | help="Path to JSON file with benchmark results.", 72 | default=os.path.join(settings.RESULT_DIR, "benchmark"), 73 | ) 74 | @click.option( 75 | "--max_rows", type=int, help="Maximum number of images to benchmark.", default=None 76 | ) 77 | @click.option( 78 | "--line_mode", is_flag=True, help="Use line mode for texify.", default=False 79 | ) 80 | def main(ds_name: str, results_dir: str, max_rows: int, line_mode: bool): 81 | foundation_predictor = FoundationPredictor() 82 | predictor = RecognitionPredictor(foundation_predictor) 83 | ds = datasets.load_dataset(ds_name, split="train") 84 | 85 | if max_rows: 86 | ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True) 87 | 88 | start = time.time() 89 | predictions = inference_texify(ds, predictor, line_mode) 90 | time_taken = time.time() - start 91 | 92 | text = [p["text"] for p in predictions] 93 | references = [p["equation"] for p in predictions] 94 | scores = score_text(text, references) 95 | 96 | write_data = { 97 | "scores": scores, 98 | "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)], 99 | } 100 | 101 | score_table = [["texify", write_data["scores"], time_taken]] 102 | score_headers = ["edit", "time taken (s)"] 103 | score_dirs = ["⬇", "⬇"] 104 | 105 | score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)] 106 | table = tabulate(score_table, headers=["Method", *score_headers]) 107 | print() 108 | print(table) 109 | 110 | result_path = Path(results_dir) / "texify_bench" 111 | result_path.mkdir(parents=True, exist_ok=True) 112 | with open(result_path / "results.json", "w", encoding="utf-8") as f: 113 | json.dump(write_data, f, indent=4) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | ``` -------------------------------------------------------------------------------- /surya/table_rec/model/encoder.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional, Union, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder 7 | 8 | 9 | class DonutSwinModel(DonutSwinPreTrainedModel): 10 | def __init__(self, config, add_pooling_layer=True, use_mask_token=False): 11 | super().__init__(config) 12 | self.config = config 13 | self.num_layers = len(config.depths) 14 | self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) 15 | 16 | self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) 17 | self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) 18 | 19 | self.position_embeddings = None 20 | if hasattr(config, "encoder_length"): 21 | self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size)) 22 | 23 | # Initialize weights and apply final processing 24 | self.post_init() 25 | 26 | def get_input_embeddings(self): 27 | return self.embeddings.patch_embeddings 28 | 29 | def _prune_heads(self, heads_to_prune): 30 | """ 31 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 32 | class PreTrainedModel 33 | """ 34 | for layer, heads in heads_to_prune.items(): 35 | self.encoder.layer[layer].attention.prune_heads(heads) 36 | 37 | def forward( 38 | self, 39 | pixel_values: Optional[torch.FloatTensor] = None, 40 | bool_masked_pos: Optional[torch.BoolTensor] = None, 41 | head_mask: Optional[torch.FloatTensor] = None, 42 | output_attentions: Optional[bool] = None, 43 | output_hidden_states: Optional[bool] = None, 44 | interpolate_pos_encoding: bool = False, 45 | return_dict: Optional[bool] = None, 46 | ) -> Union[Tuple, DonutSwinModelOutput]: 47 | r""" 48 | bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): 49 | Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 50 | """ 51 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 52 | output_hidden_states = ( 53 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 54 | ) 55 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 56 | 57 | if pixel_values is None: 58 | raise ValueError("You have to specify pixel_values") 59 | 60 | # Prepare head mask if needed 61 | # 1.0 in head_mask indicate we keep the head 62 | # attention_probs has shape bsz x n_heads x N x N 63 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 64 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 65 | head_mask = self.get_head_mask(head_mask, len(self.config.depths)) 66 | 67 | embedding_output, input_dimensions = self.embeddings( 68 | pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding 69 | ) 70 | 71 | encoder_outputs = self.encoder( 72 | embedding_output, 73 | input_dimensions, 74 | head_mask=head_mask, 75 | output_attentions=output_attentions, 76 | output_hidden_states=output_hidden_states, 77 | return_dict=return_dict, 78 | ) 79 | 80 | last_hidden_state = encoder_outputs[0] 81 | 82 | if self.position_embeddings is not None: 83 | last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :] 84 | 85 | return DonutSwinModelOutput( 86 | last_hidden_state=last_hidden_state, 87 | ) 88 | ``` -------------------------------------------------------------------------------- /surya/table_rec/model/encoderdecoder.py: -------------------------------------------------------------------------------- ```python 1 | from dataclasses import dataclass 2 | from typing import Optional, Union, Tuple, Dict 3 | 4 | import torch 5 | from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig 6 | 7 | from surya.common.pretrained import SuryaPreTrainedModel 8 | from surya.common.s3 import S3DownloaderMixin 9 | from surya.table_rec.model.decoder import SuryaTableRecDecoder 10 | from surya.table_rec.model.encoder import DonutSwinModel 11 | from transformers.utils import ModelOutput 12 | 13 | 14 | @dataclass 15 | class TableRecOutput(ModelOutput): 16 | box_property_logits: Dict[str, torch.FloatTensor] 17 | decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 18 | 19 | 20 | class TableRecEncoderDecoderModel(S3DownloaderMixin, SuryaPreTrainedModel): 21 | config_class = VisionEncoderDecoderConfig 22 | base_model_prefix = "vision_encoder_decoder" 23 | main_input_name = "pixel_values" 24 | supports_gradient_checkpointing = True 25 | 
_supports_param_buffer_assignment = False 26 | 27 | def __init__( 28 | self, 29 | config: Optional[PretrainedConfig] = None, 30 | encoder: Optional[PreTrainedModel] = None, 31 | decoder: Optional[PreTrainedModel] = None, 32 | **kwargs, 33 | ): 34 | # initialize with config 35 | # make sure input & output embeddings is not tied 36 | config.tie_word_embeddings = False 37 | config.decoder.tie_word_embeddings = False 38 | super().__init__(config, **kwargs) 39 | 40 | if encoder is None: 41 | encoder = DonutSwinModel(config.encoder) 42 | 43 | if decoder is None: 44 | decoder = SuryaTableRecDecoder( 45 | config.decoder, attn_implementation=config._attn_implementation 46 | ) 47 | 48 | self.encoder = encoder 49 | self.decoder = decoder 50 | 51 | # make sure that the individual model's config refers to the shared config 52 | # so that the updates to the config will be synced 53 | self.encoder.config = self.config.encoder 54 | self.decoder.config = self.config.decoder 55 | 56 | def get_encoder(self): 57 | return self.encoder 58 | 59 | def get_decoder(self): 60 | return self.decoder 61 | 62 | def get_output_embeddings(self): 63 | return self.decoder.get_output_embeddings() 64 | 65 | def set_output_embeddings(self, new_embeddings): 66 | return self.decoder.set_output_embeddings(new_embeddings) 67 | 68 | def forward( 69 | self, 70 | decoder_input_ids: torch.LongTensor = None, 71 | decoder_cache_position: Optional[torch.LongTensor] = None, 72 | decoder_attention_mask: Optional[torch.LongTensor] = None, 73 | encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, 74 | use_cache: Optional[bool] = None, 75 | return_dict: Optional[bool] = None, 76 | **kwargs, 77 | ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]: 78 | kwargs_decoder = { 79 | argument[len("decoder_") :]: value 80 | for argument, value in kwargs.items() 81 | if argument.startswith("decoder_") 82 | } 83 | 84 | # Decode 85 | decoder_outputs = self.decoder( 86 | input_labels=decoder_input_ids, 87 | input_boxes_counts=None, 88 | cache_position=decoder_cache_position, 89 | attention_mask=decoder_attention_mask, 90 | encoder_hidden_states=encoder_outputs, 91 | encoder_attention_mask=None, 92 | use_cache=use_cache, 93 | **kwargs_decoder, 94 | ) 95 | 96 | return TableRecOutput( 97 | box_property_logits=decoder_outputs.box_property_logits, 98 | decoder_hidden_states=decoder_outputs.hidden_states, 99 | ) 100 | 101 | def resize_token_embeddings(self, *args, **kwargs): 102 | raise NotImplementedError( 103 | "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" 104 | " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" 105 | ) 106 | 107 | def _reorder_cache(self, past_key_values, beam_idx): 108 | # apply decoder cache reordering here 109 | return self.decoder._reorder_cache(past_key_values, beam_idx) 110 | ``` -------------------------------------------------------------------------------- /signatures/version1/cla.json: -------------------------------------------------------------------------------- ```json 1 | { 2 | "signedContributors": [ 3 | { 4 | "name": "rishiraj", 5 | "id": 44090649, 6 | "comment_id": 2170578748, 7 | "created_at": "2024-06-15T19:31:20Z", 8 | "repoId": 741297064, 9 | "pullRequestNo": 135 10 | }, 11 | { 12 | "name": "mmacvicar", 13 | "id": 59354, 14 | "comment_id": 2236493182, 15 | "created_at": "2024-07-18T13:17:43Z", 16 | "repoId": 741297064, 17 | "pullRequestNo": 152 18 | }, 19 | { 20 | "name": "jimexist", 21 | "id": 622789, 22 | 
"comment_id": 2255151376, 23 | "created_at": "2024-07-29T07:23:55Z", 24 | "repoId": 741297064, 25 | "pullRequestNo": 160 26 | }, 27 | { 28 | "name": "michaeldriscoll-avant", 29 | "id": 85255083, 30 | "comment_id": 2259143427, 31 | "created_at": "2024-07-30T20:21:33Z", 32 | "repoId": 741297064, 33 | "pullRequestNo": 161 34 | }, 35 | { 36 | "name": "EdoardoPona", 37 | "id": 29152472, 38 | "comment_id": 2271115922, 39 | "created_at": "2024-08-06T11:58:00Z", 40 | "repoId": 741297064, 41 | "pullRequestNo": 167 42 | }, 43 | { 44 | "name": "hidenori-endo", 45 | "id": 15546605, 46 | "comment_id": 2307217499, 47 | "created_at": "2024-08-23T14:31:17Z", 48 | "repoId": 741297064, 49 | "pullRequestNo": 182 50 | }, 51 | { 52 | "name": "dobosevych", 53 | "id": 12053536, 54 | "comment_id": 2430376828, 55 | "created_at": "2024-10-22T21:48:34Z", 56 | "repoId": 741297064, 57 | "pullRequestNo": 220 58 | }, 59 | { 60 | "name": "iammosespaulr", 61 | "id": 28682735, 62 | "comment_id": 2447941238, 63 | "created_at": "2024-10-30T17:55:23Z", 64 | "repoId": 741297064, 65 | "pullRequestNo": 235 66 | }, 67 | { 68 | "name": "ArthurMor4is", 69 | "id": 42987302, 70 | "comment_id": 2515315717, 71 | "created_at": "2024-12-03T18:37:45Z", 72 | "repoId": 741297064, 73 | "pullRequestNo": 255 74 | }, 75 | { 76 | "name": "tarun-menta", 77 | "id": 66506307, 78 | "comment_id": 2543457960, 79 | "created_at": "2024-12-15T05:43:33Z", 80 | "repoId": 741297064, 81 | "pullRequestNo": 261 82 | }, 83 | { 84 | "name": "jonaskahn", 85 | "id": 4338500, 86 | "comment_id": 2556622097, 87 | "created_at": "2024-12-20T09:36:20Z", 88 | "repoId": 741297064, 89 | "pullRequestNo": 269 90 | }, 91 | { 92 | "name": "kumsumit", 93 | "id": 95072784, 94 | "comment_id": 2574534622, 95 | "created_at": "2025-01-07T07:05:59Z", 96 | "repoId": 741297064, 97 | "pullRequestNo": 276 98 | }, 99 | { 100 | "name": "kevinhu", 101 | "id": 6051736, 102 | "comment_id": 2614135351, 103 | "created_at": "2025-01-25T23:34:12Z", 104 | "repoId": 741297064, 105 | "pullRequestNo": 291 106 | }, 107 | { 108 | "name": "zanussbaum", 109 | "id": 33707069, 110 | "comment_id": 3008673416, 111 | "created_at": "2025-06-26T14:20:46Z", 112 | "repoId": 741297064, 113 | "pullRequestNo": 403 114 | }, 115 | { 116 | "name": "mebriki", 117 | "id": 35892987, 118 | "comment_id": 3154706976, 119 | "created_at": "2025-08-05T10:54:27Z", 120 | "repoId": 741297064, 121 | "pullRequestNo": 418 122 | }, 123 | { 124 | "name": "starikovplusplus", 125 | "id": 56602036, 126 | "comment_id": 3168958011, 127 | "created_at": "2025-08-08T18:29:50Z", 128 | "repoId": 741297064, 129 | "pullRequestNo": 423 130 | }, 131 | { 132 | "name": "sandy0kwon", 133 | "id": 78377296, 134 | "comment_id": 3207932260, 135 | "created_at": "2025-08-20T20:07:15Z", 136 | "repoId": 741297064, 137 | "pullRequestNo": 434 138 | }, 139 | { 140 | "name": "n0kovo", 141 | "id": 16690056, 142 | "comment_id": 3208251881, 143 | "created_at": "2025-08-20T22:22:06Z", 144 | "repoId": 741297064, 145 | "pullRequestNo": 435 146 | }, 147 | { 148 | "name": "davidxifeng", 149 | "id": 158052, 150 | "comment_id": 3249594859, 151 | "created_at": "2025-09-03T14:52:16Z", 152 | "repoId": 741297064, 153 | "pullRequestNo": 445 154 | }, 155 | { 156 | "name": "u-ashish", 157 | "id": 14264791, 158 | "comment_id": 3258734182, 159 | "created_at": "2025-09-05T15:16:48Z", 160 | "repoId": 741297064, 161 | "pullRequestNo": 447 162 | }, 163 | { 164 | "name": "Mohking1", 165 | "id": 63689545, 166 | "comment_id": 3314908963, 167 | "created_at": "2025-09-20T11:21:42Z", 168 | 
"repoId": 741297064, 169 | "pullRequestNo": 462 170 | }, 171 | { 172 | "name": "wkpark", 173 | "id": 232347, 174 | "comment_id": 3330009557, 175 | "created_at": "2025-09-24T17:42:55Z", 176 | "repoId": 741297064, 177 | "pullRequestNo": 464 178 | } 179 | ] 180 | } ``` -------------------------------------------------------------------------------- /surya/layout/__init__.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | 3 | from PIL import Image 4 | 5 | from surya.common.predictor import BasePredictor 6 | from surya.layout.schema import LayoutBox, LayoutResult 7 | from surya.settings import settings 8 | from surya.foundation import FoundationPredictor, TaskNames 9 | from surya.foundation.util import prediction_to_polygon_batch 10 | from surya.input.processing import convert_if_not_rgb 11 | from surya.layout.label import LAYOUT_PRED_RELABEL 12 | from surya.common.util import clean_boxes 13 | 14 | 15 | class LayoutPredictor(BasePredictor): 16 | batch_size = settings.LAYOUT_BATCH_SIZE 17 | default_batch_sizes = {"cpu": 4, "mps": 4, "cuda": 32, "xla": 16} 18 | 19 | # Override base init - Do not load model 20 | def __init__(self, foundation_predictor: FoundationPredictor): 21 | self.foundation_predictor = foundation_predictor 22 | self.processor = self.foundation_predictor.processor 23 | self.bbox_size = self.foundation_predictor.model.config.bbox_size 24 | self.tasks = self.foundation_predictor.tasks 25 | 26 | # Special handling for disable tqdm to pass into foundation predictor 27 | # Make sure they are kept in sync 28 | @property 29 | def disable_tqdm(self) -> bool: 30 | return super().disable_tqdm 31 | 32 | @disable_tqdm.setter 33 | def disable_tqdm(self, value: bool) -> None: 34 | self._disable_tqdm = bool(value) 35 | self.foundation_predictor.disable_tqdm = bool(value) 36 | 37 | def __call__( 38 | self, images: List[Image.Image], batch_size: int | None = None, top_k: int = 5 39 | ) -> List[LayoutResult]: 40 | assert all([isinstance(image, Image.Image) for image in images]) 41 | if batch_size is None: 42 | batch_size = self.get_batch_size() 43 | 44 | if len(images) == 0: 45 | return [] 46 | 47 | images = convert_if_not_rgb(images) 48 | images = [self.processor.image_processor(image) for image in images] 49 | 50 | predicted_tokens, batch_bboxes, scores, topk_scores = ( 51 | self.foundation_predictor.prediction_loop( 52 | images=images, 53 | input_texts=["" for _ in range(len(images))], 54 | task_names=[TaskNames.layout for _ in range(len(images))], 55 | batch_size=batch_size, 56 | max_lookahead_tokens=0, # Do not do MTP for layout 57 | top_k=5, 58 | max_sliding_window=576, 59 | max_tokens=500, 60 | tqdm_desc="Recognizing Layout" 61 | ) 62 | ) 63 | 64 | image_sizes = [img.shape for img in images] 65 | predicted_polygons = prediction_to_polygon_batch( 66 | batch_bboxes, image_sizes, self.bbox_size, self.bbox_size // 2 67 | ) 68 | layout_results = [] 69 | for image, image_tokens, image_polygons, image_scores, image_topk_scores in zip( 70 | images, predicted_tokens, predicted_polygons, scores, topk_scores 71 | ): 72 | layout_boxes = [] 73 | for z, (tok, poly, score, tok_topk) in enumerate( 74 | zip(image_tokens, image_polygons, image_scores, image_topk_scores) 75 | ): 76 | if tok == self.processor.eos_token_id: 77 | break 78 | 79 | predicted_label = self.processor.decode([tok], "layout") 80 | label = LAYOUT_PRED_RELABEL.get(predicted_label) 81 | if not label: 82 | # Layout can sometimes return unknown labels from other 
objectives 83 | continue 84 | 85 | top_k_dict = {} 86 | for k, v in tok_topk.items(): 87 | topk_label = self.processor.decode([k], "layout") 88 | if topk_label in LAYOUT_PRED_RELABEL: 89 | topk_label = LAYOUT_PRED_RELABEL[topk_label] 90 | if not topk_label.strip(): 91 | continue 92 | top_k_dict.update({topk_label: v}) 93 | layout_boxes.append( 94 | LayoutBox( 95 | polygon=poly.tolist(), 96 | label=label, 97 | position=z, 98 | top_k=top_k_dict, 99 | confidence=score, 100 | ) 101 | ) 102 | layout_boxes = clean_boxes(layout_boxes) 103 | layout_results.append( 104 | LayoutResult( 105 | bboxes=layout_boxes, 106 | image_bbox=[0, 0, image.shape[1], image.shape[0]], 107 | ) # Image is numpy array 108 | ) 109 | 110 | assert len(layout_results) == len(images) 111 | return layout_results 112 | ``` -------------------------------------------------------------------------------- /surya/scripts/table_recognition.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | import click 3 | import copy 4 | import json 5 | from collections import defaultdict 6 | 7 | from surya.logging import configure_logging, get_logger 8 | from surya.scripts.config import CLILoader 9 | from surya.foundation import FoundationPredictor 10 | from surya.layout import LayoutPredictor 11 | from surya.table_rec import TableRecPredictor 12 | from surya.debug.draw import draw_bboxes_on_image 13 | from surya.common.util import rescale_bbox, expand_bbox 14 | from surya.settings import settings 15 | 16 | configure_logging() 17 | logger = get_logger() 18 | 19 | 20 | @click.command(help="Detect layout of an input file or folder (PDFs or image).") 21 | @CLILoader.common_options 22 | @click.option( 23 | "--skip_table_detection", 24 | is_flag=True, 25 | help="Tables are already cropped, so don't re-detect tables.", 26 | default=False, 27 | ) 28 | def table_recognition_cli(input_path: str, skip_table_detection: bool, **kwargs): 29 | loader = CLILoader(input_path, kwargs, highres=True) 30 | 31 | foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) 32 | layout_predictor = LayoutPredictor(foundation_predictor) 33 | table_rec_predictor = TableRecPredictor() 34 | 35 | pnums = [] 36 | prev_name = None 37 | for i, name in enumerate(loader.names): 38 | if prev_name is None or prev_name != name: 39 | pnums.append(0) 40 | else: 41 | pnums.append(pnums[-1] + 1) 42 | 43 | prev_name = name 44 | 45 | layout_predictions = layout_predictor(loader.images) 46 | 47 | table_imgs = [] 48 | table_counts = [] 49 | 50 | for layout_pred, img, highres_img in zip( 51 | layout_predictions, loader.images, loader.highres_images 52 | ): 53 | # The table may already be cropped 54 | if skip_table_detection: 55 | table_imgs.append(highres_img) 56 | table_counts.append(1) 57 | else: 58 | # The bbox for the entire table 59 | bbox = [ 60 | line.bbox 61 | for line in layout_pred.bboxes 62 | if line.label in ["Table", "TableOfContents"] 63 | ] 64 | # Number of tables per page 65 | table_counts.append(len(bbox)) 66 | 67 | if len(bbox) == 0: 68 | continue 69 | 70 | page_table_imgs = [] 71 | highres_bbox = [] 72 | for bb in bbox: 73 | highres_bb = rescale_bbox(bb, img.size, highres_img.size) 74 | highres_bb = expand_bbox(highres_bb) 75 | page_table_imgs.append(highres_img.crop(highres_bb)) 76 | highres_bbox.append(highres_bb) 77 | 78 | table_imgs.extend(page_table_imgs) 79 | 80 | table_preds = table_rec_predictor(table_imgs) 81 | 82 | img_idx = 0 83 | prev_count = 0 84 | table_predictions = 
defaultdict(list) 85 | for i in range(sum(table_counts)): 86 | while i >= prev_count + table_counts[img_idx]: 87 | prev_count += table_counts[img_idx] 88 | img_idx += 1 89 | 90 | pred = table_preds[i] 91 | orig_name = loader.names[img_idx] 92 | pnum = pnums[img_idx] 93 | table_img = table_imgs[i] 94 | 95 | out_pred = pred.model_dump() 96 | out_pred["page"] = pnum + 1 97 | table_idx = i - prev_count 98 | out_pred["table_idx"] = table_idx 99 | table_predictions[orig_name].append(out_pred) 100 | 101 | if loader.save_images: 102 | rows = [line.bbox for line in pred.rows] 103 | cols = [line.bbox for line in pred.cols] 104 | row_labels = [f"Row {line.row_id}" for line in pred.rows] 105 | col_labels = [f"Col {line.col_id}" for line in pred.cols] 106 | cells = [line.bbox for line in pred.cells] 107 | 108 | rc_image = copy.deepcopy(table_img) 109 | rc_image = draw_bboxes_on_image( 110 | rows, rc_image, labels=row_labels, label_font_size=20, color="blue" 111 | ) 112 | rc_image = draw_bboxes_on_image( 113 | cols, rc_image, labels=col_labels, label_font_size=20, color="red" 114 | ) 115 | rc_image.save( 116 | os.path.join( 117 | loader.result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png" 118 | ) 119 | ) 120 | 121 | cell_image = copy.deepcopy(table_img) 122 | cell_image = draw_bboxes_on_image(cells, cell_image, color="green") 123 | cell_image.save( 124 | os.path.join( 125 | loader.result_path, 126 | f"{name}_page{pnum + 1}_table{table_idx}_cells.png", 127 | ) 128 | ) 129 | 130 | with open( 131 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 132 | ) as f: 133 | json.dump(table_predictions, f, ensure_ascii=False) 134 | 135 | logger.info(f"Wrote results to {loader.result_path}") 136 | ``` -------------------------------------------------------------------------------- /CLA.md: -------------------------------------------------------------------------------- ```markdown 1 | Surya Contributor Agreement 2 | 3 | This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below. 4 | 5 | If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement. 6 | 7 | 1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project. 8 | 2. With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution: 9 | - you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty free, unrestricted license to exercise all rights under those copyrights. 
This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers; 10 | - you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made) will be the sole owner of that derivative work; 11 | - you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees; 12 | - you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and 13 | - you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of your contribution. 14 | 3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to: 15 | - make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and 16 | - at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements. 17 | If you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed. 18 | 4. Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license. 19 | 5. You covenant, represent, warrant and agree that: 20 | - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA; 21 | - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and 22 | - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws. 23 | You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the SCA. 24 | 6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply.
``` -------------------------------------------------------------------------------- /surya/scripts/texify_app.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | import re 3 | from typing import List 4 | 5 | from surya.recognition import RecognitionPredictor 6 | from surya.foundation import FoundationPredictor 7 | from surya.common.surya.schema import TaskNames 8 | 9 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = ( 10 | "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS 11 | ) 12 | 13 | import io 14 | 15 | import pandas as pd 16 | import streamlit as st 17 | from streamlit_drawable_canvas import st_canvas 18 | import hashlib 19 | import pypdfium2 20 | 21 | from surya.settings import settings 22 | from PIL import Image 23 | 24 | MAX_WIDTH = 800 25 | MAX_HEIGHT = 1000 26 | 27 | 28 | def replace_fences(text): 29 | text = re.sub(r'<math display="block">(.*?)</math>', r"$$\1$$", text) 30 | text = re.sub(r"<math>(.*?)</math>", r"$\1$", text) 31 | text = re.sub(r'<math display="inline">(.*?)</math>', r"$\1$", text) 32 | return text 33 | 34 | 35 | @st.cache_resource() 36 | def load_predictor(): 37 | foundation_predictor = FoundationPredictor() 38 | return RecognitionPredictor(foundation_predictor) 39 | 40 | 41 | @st.cache_data() 42 | def inference(pil_image: Image.Image, bbox: List[float]): 43 | input_img = pil_image.crop(bbox) 44 | bbox = [0, 0, input_img.width, input_img.height] 45 | model_output = predictor( 46 | [input_img], [TaskNames.block_without_boxes], bboxes=[[bbox]] 47 | ) 48 | return model_output[0].text_lines[0].text 49 | 50 | 51 | def open_pdf(pdf_file): 52 | stream = io.BytesIO(pdf_file.getvalue()) 53 | return pypdfium2.PdfDocument(stream) 54 | 55 | 56 | @st.cache_data() 57 | def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI_HIGHRES): 58 | doc = open_pdf(pdf_file) 59 | renderer = doc.render( 60 | pypdfium2.PdfBitmap.to_pil, 61 | page_indices=[page_num - 1], 62 | scale=dpi / 72, 63 | ) 64 | png = list(renderer)[0] 65 | png_image = png.convert("RGB") 66 | doc.close() 67 | return png_image 68 | 69 | 70 | @st.cache_data() 71 | def page_counter(pdf_file): 72 | doc = open_pdf(pdf_file) 73 | doc_len = len(doc) 74 | doc.close() 75 | return doc_len 76 | 77 | 78 | def resize_image(pil_image): 79 | if pil_image is None: 80 | return 81 | pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) 82 | 83 | 84 | def get_canvas_hash(pil_image): 85 | return hashlib.md5(pil_image.tobytes()).hexdigest() 86 | 87 | 88 | st.set_page_config(layout="wide") 89 | 90 | top_message = """### LaTeX OCR 91 | 92 | After the model loads, upload an image or a pdf, then draw a box around the equation or text you want to OCR by clicking and dragging. Surya will convert it to Markdown with LaTeX math on the right. 
93 | """ 94 | 95 | st.markdown(top_message) 96 | col1, col2 = st.columns([0.7, 0.3]) 97 | 98 | predictor = load_predictor() 99 | 100 | in_file = st.sidebar.file_uploader( 101 | "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"] 102 | ) 103 | if in_file is None: 104 | st.stop() 105 | 106 | if in_file is None: 107 | st.stop() 108 | 109 | filetype = in_file.type 110 | page_count = None 111 | if "pdf" in filetype: 112 | page_count = page_counter(in_file) 113 | page_number = st.sidebar.number_input( 114 | f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count 115 | ) 116 | pil_image = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES) 117 | else: 118 | pil_image = Image.open(in_file).convert("RGB") 119 | page_number = None 120 | 121 | if pil_image is None: 122 | st.stop() 123 | 124 | pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) 125 | canvas_hash = get_canvas_hash(pil_image) 126 | 127 | with col1: 128 | # Create a canvas component 129 | canvas_result = st_canvas( 130 | fill_color="rgba(255, 165, 0, 0.1)", # Fixed fill color with some opacity 131 | stroke_width=1, 132 | stroke_color="#FFAA00", 133 | background_color="#FFF", 134 | background_image=pil_image, 135 | update_streamlit=True, 136 | height=pil_image.height, 137 | width=pil_image.width, 138 | drawing_mode="rect", 139 | point_display_radius=0, 140 | key=canvas_hash, 141 | ) 142 | 143 | if not canvas_result.json_data: 144 | st.stop() 145 | 146 | objects = pd.json_normalize( 147 | canvas_result.json_data["objects"] 148 | ) # need to convert obj to str because PyArrow 149 | bbox_list = None 150 | if objects.shape[0] > 0: 151 | boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]] 152 | boxes["right"] = boxes["left"] + boxes["width"] 153 | boxes["bottom"] = boxes["top"] + boxes["height"] 154 | bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist() 155 | 156 | if bbox_list: 157 | with col2: 158 | texts = [inference(pil_image, bbox) for bbox in bbox_list] 159 | for idx, latex in enumerate(reversed(texts)): 160 | st.markdown(f"### {len(texts) - idx}") 161 | st.markdown(replace_fences(latex), unsafe_allow_html=True) 162 | st.code(latex) 163 | st.divider() 164 | 165 | with col2: 166 | tips = """ 167 | ### Usage tips 168 | - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple. 
169 | """ 170 | st.markdown(tips) 171 | ``` -------------------------------------------------------------------------------- /surya/scripts/finetune_ocr.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | from dataclasses import dataclass, field 3 | from typing import Optional, Tuple 4 | from datasets import load_dataset 5 | import numpy as np 6 | import torch 7 | from transformers import ( 8 | HfArgumentParser, 9 | TrainingArguments, 10 | Trainer, 11 | ) 12 | 13 | from surya.common.surya import SuryaModel 14 | from surya.common.surya.processor import SuryaOCRProcessor 15 | from surya.foundation import FoundationPredictor 16 | from surya.common.surya.processor.schema import ImageInput, TextInput 17 | from surya.common.surya.schema import TaskNames 18 | from surya.common.util import get_top_scripts, SCRIPT_TOKEN_MAPPING 19 | 20 | # Do not change these defaults 21 | OCR_TASK_NAME = TaskNames.ocr_with_boxes 22 | OCR_MAX_IMAGE_SIZE = (1024, 512) 23 | 24 | # Simple wrapper for huggingface dataset 25 | class SuryaOCRDataset(torch.utils.data.Dataset): 26 | def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments): 27 | super().__init__() 28 | self.hf_dataset = load_dataset(data_args.dataset_name, num_proc=data_args.num_loading_proc, split="train") 29 | self.processor = processor 30 | 31 | def __len__(self): 32 | return len(self.hf_dataset) 33 | 34 | def get_script_text(self, text: str) -> str: 35 | scripts = get_top_scripts(text) 36 | script_text = "".join(SCRIPT_TOKEN_MAPPING[script] for script in scripts) 37 | return script_text 38 | 39 | def __getitem__(self, index): 40 | try: 41 | data = self.hf_dataset[index] 42 | image = data["image"] 43 | image = image.convert("RGB") 44 | image = np.asarray(image, dtype=np.float32) 45 | image = self.processor.scale_to_fit(image, max_size=OCR_MAX_IMAGE_SIZE) 46 | 47 | # Add in script information 48 | gt_text = data["text"] 49 | gt_text = self.get_script_text(gt_text) + gt_text 50 | 51 | return_dict = { 52 | "task": TaskNames.ocr_with_boxes, 53 | "inputs": [ 54 | ImageInput(type="image", image=image, rotated=False), 55 | # This empty TextInput **must be included** to match the original format 56 | TextInput(type="text", text=""), 57 | TextInput(type="text",text=gt_text), 58 | ], 59 | } 60 | return return_dict 61 | except: 62 | import traceback; traceback.print_exc() 63 | return self.__getitem__((index + 1) % self.__len__()) 64 | 65 | class SuryaOCRDataCollator: 66 | def __init__(self, processor: SuryaOCRProcessor, data_args: SuryaOCRDataArguments): 67 | self.processor = processor 68 | self.max_sequence_length = data_args.max_sequence_length 69 | 70 | def __call__(self, inputs): 71 | # Use right padding for training. 
Defaults to left for inference 72 | processed_batch = self.processor(inputs, padding_side="right") 73 | 74 | if self.max_sequence_length is not None: 75 | processed_batch["input_ids"] = processed_batch["input_ids"][:, :self.max_sequence_length] 76 | processed_batch["attention_mask"] = processed_batch["attention_mask"][:, :self.max_sequence_length] 77 | processed_batch["position_ids"] = processed_batch["position_ids"][:, :self.max_sequence_length] 78 | 79 | lm_labels = processed_batch["input_ids"].clone() 80 | skip_label_mask = ( 81 | (lm_labels == self.processor.pad_token_id ) 82 | | (lm_labels == self.processor.bos_token_id[TaskNames.ocr_with_boxes]) 83 | | (lm_labels == self.processor.eoi_token_id) 84 | | (lm_labels == self.processor.image_token_id) 85 | ) 86 | lm_labels[skip_label_mask] = -100 87 | processed_batch["labels"] = lm_labels 88 | 89 | return processed_batch 90 | 91 | def load_model_and_processor(checkpoint_path: Optional[str] = None) -> Tuple[SuryaModel, SuryaOCRProcessor]: 92 | foundation_predictor = FoundationPredictor(checkpoint=checkpoint_path) 93 | return foundation_predictor.model, foundation_predictor.processor 94 | 95 | @dataclass 96 | class SuryaOCRModelArguments: 97 | pretrained_checkpoint_path: Optional[str] = field(default=None) 98 | 99 | @dataclass 100 | class SuryaOCRDataArguments: 101 | dataset_name: str = field(default="datalab-to/ocr_finetune_example") 102 | num_loading_proc: int = field(default=16) 103 | max_sequence_length: Optional[int] = field(default=None) 104 | 105 | @dataclass 106 | class SuryaOCRTrainingArguments(TrainingArguments): 107 | remove_unused_columns: bool = field(default=False) 108 | 109 | def main(): 110 | parser = HfArgumentParser((SuryaOCRModelArguments, SuryaOCRDataArguments, SuryaOCRTrainingArguments)) 111 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 112 | 113 | model, processor = load_model_and_processor(model_args.pretrained_checkpoint_path) 114 | dataset = SuryaOCRDataset(processor, data_args) 115 | collator = SuryaOCRDataCollator(processor, data_args) 116 | 117 | trainer = Trainer( 118 | model=model, 119 | args=training_args, 120 | train_dataset=dataset, 121 | data_collator=collator 122 | ) 123 | 124 | trainer.train() 125 | 126 | if __name__ == "__main__": 127 | main() ``` -------------------------------------------------------------------------------- /benchmark/utils/tesseract.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from surya.input.processing import slice_bboxes_from_image 7 | from surya.settings import settings 8 | import os 9 | from concurrent.futures import ProcessPoolExecutor 10 | from surya.recognition.languages import CODE_TO_LANGUAGE 11 | from surya.recognition import RecognitionPredictor 12 | from surya.detection import DetectionPredictor 13 | 14 | 15 | def surya_lang_to_tesseract(code: str) -> Optional[str]: 16 | lang_str = CODE_TO_LANGUAGE[code] 17 | try: 18 | tess_lang = TESS_LANGUAGE_TO_CODE[lang_str] 19 | except KeyError: 20 | return None 21 | return tess_lang 22 | 23 | 24 | def tesseract_ocr(img, bboxes, lang: str): 25 | import pytesseract 26 | line_imgs = slice_bboxes_from_image(img, bboxes) 27 | config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"' 28 | lines = [] 29 | for line_img in line_imgs: 30 | line = pytesseract.image_to_string(line_img, lang=lang, config=config) 31 | lines.append(line) 32 | return lines 33 | 34 | 35 | 
def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None): 36 | tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size()) 37 | if not cpus: 38 | cpus = os.cpu_count() 39 | tess_parallel_cores = min(tess_parallel_cores, cpus) 40 | 41 | # Tesseract uses up to 4 processes per instance 42 | # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images 43 | tess_parallel = max(tess_parallel_cores // 2, 1) 44 | 45 | with ProcessPoolExecutor(max_workers=tess_parallel) as executor: 46 | tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR") 47 | tess_text = list(tess_text) 48 | return tess_text 49 | 50 | 51 | def tesseract_bboxes(img): 52 | import pytesseract 53 | from pytesseract import Output 54 | arr_img = np.asarray(img, dtype=np.uint8) 55 | ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT) 56 | 57 | bboxes = [] 58 | n_boxes = len(ocr['level']) 59 | for i in range(n_boxes): 60 | # It is possible to merge by line here with line number, but it gives bad results. 61 | _, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i] 62 | bbox = (x, y, x + w, y + h) 63 | bboxes.append(bbox) 64 | 65 | return bboxes 66 | 67 | 68 | def tesseract_parallel(imgs): 69 | # Tesseract uses 4 threads per instance 70 | tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size()) 71 | cpus = os.cpu_count() 72 | tess_parallel_cores = min(tess_parallel_cores, cpus) 73 | 74 | # Tesseract uses 4 threads per instance 75 | tess_parallel = max(tess_parallel_cores // 4, 1) 76 | 77 | with ProcessPoolExecutor(max_workers=tess_parallel) as executor: 78 | tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection") 79 | tess_bboxes = list(tess_bboxes) 80 | return tess_bboxes 81 | 82 | 83 | TESS_CODE_TO_LANGUAGE = { 84 | "afr": "Afrikaans", 85 | "amh": "Amharic", 86 | "ara": "Arabic", 87 | "asm": "Assamese", 88 | "aze": "Azerbaijani", 89 | "bel": "Belarusian", 90 | "ben": "Bengali", 91 | "bod": "Tibetan", 92 | "bos": "Bosnian", 93 | "bre": "Breton", 94 | "bul": "Bulgarian", 95 | "cat": "Catalan", 96 | "ceb": "Cebuano", 97 | "ces": "Czech", 98 | "chi_sim": "Chinese", 99 | "chr": "Cherokee", 100 | "cym": "Welsh", 101 | "dan": "Danish", 102 | "deu": "German", 103 | "dzo": "Dzongkha", 104 | "ell": "Greek", 105 | "eng": "English", 106 | "epo": "Esperanto", 107 | "est": "Estonian", 108 | "eus": "Basque", 109 | "fas": "Persian", 110 | "fin": "Finnish", 111 | "fra": "French", 112 | "fry": "Western Frisian", 113 | "guj": "Gujarati", 114 | "gla": "Scottish Gaelic", 115 | "gle": "Irish", 116 | "glg": "Galician", 117 | "heb": "Hebrew", 118 | "hin": "Hindi", 119 | "hrv": "Croatian", 120 | "hun": "Hungarian", 121 | "hye": "Armenian", 122 | "iku": "Inuktitut", 123 | "ind": "Indonesian", 124 | "isl": "Icelandic", 125 | "ita": "Italian", 126 | "jav": "Javanese", 127 | "jpn": "Japanese", 128 | "kan": "Kannada", 129 | "kat": "Georgian", 130 | "kaz": "Kazakh", 131 | "khm": "Khmer", 132 | "kir": "Kyrgyz", 133 | "kor": "Korean", 134 | "lao": "Lao", 135 | "lat": "Latin", 136 | "lav": "Latvian", 137 | "lit": "Lithuanian", 138 | "mal": "Malayalam", 139 | "mar": "Marathi", 140 | "mkd": "Macedonian", 141 | "mlt": "Maltese", 142 | "mon": "Mongolian", 143 | "msa": "Malay", 144 | "mya": "Burmese", 145 | "nep": "Nepali", 146 | "nld": "Dutch", 147 | "nor": "Norwegian", 148 | "ori": "Oriya", 149 | "pan": "Punjabi", 150 | 
"pol": "Polish", 151 | "por": "Portuguese", 152 | "pus": "Pashto", 153 | "ron": "Romanian", 154 | "rus": "Russian", 155 | "san": "Sanskrit", 156 | "sin": "Sinhala", 157 | "slk": "Slovak", 158 | "slv": "Slovenian", 159 | "snd": "Sindhi", 160 | "spa": "Spanish", 161 | "sqi": "Albanian", 162 | "srp": "Serbian", 163 | "swa": "Swahili", 164 | "swe": "Swedish", 165 | "syr": "Syriac", 166 | "tam": "Tamil", 167 | "tel": "Telugu", 168 | "tgk": "Tajik", 169 | "tha": "Thai", 170 | "tir": "Tigrinya", 171 | "tur": "Turkish", 172 | "uig": "Uyghur", 173 | "ukr": "Ukrainian", 174 | "urd": "Urdu", 175 | "uzb": "Uzbek", 176 | "vie": "Vietnamese", 177 | "yid": "Yiddish" 178 | } 179 | 180 | TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()} 181 | ``` -------------------------------------------------------------------------------- /benchmark/layout.py: -------------------------------------------------------------------------------- ```python 1 | import collections 2 | import copy 3 | import json 4 | 5 | import click 6 | 7 | from benchmark.utils.metrics import precision_recall 8 | from surya.foundation import FoundationPredictor 9 | from surya.layout import LayoutPredictor 10 | from surya.input.processing import convert_if_not_rgb 11 | from surya.debug.draw import draw_bboxes_on_image 12 | from surya.settings import settings 13 | import os 14 | import time 15 | from tabulate import tabulate 16 | import datasets 17 | 18 | 19 | @click.command(help="Benchmark surya layout model.") 20 | @click.option( 21 | "--results_dir", 22 | type=str, 23 | help="Path to JSON file with OCR results.", 24 | default=os.path.join(settings.RESULT_DIR, "benchmark"), 25 | ) 26 | @click.option( 27 | "--max_rows", 28 | type=int, 29 | help="Maximum number of images to run benchmark on.", 30 | default=100, 31 | ) 32 | @click.option("--debug", is_flag=True, help="Run in debug mode.", default=False) 33 | def main(results_dir: str, max_rows: int, debug: bool): 34 | foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) 35 | layout_predictor = LayoutPredictor(foundation_predictor) 36 | 37 | pathname = "layout_bench" 38 | # These have already been shuffled randomly, so sampling from the start is fine 39 | dataset = datasets.load_dataset( 40 | settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]" 41 | ) 42 | images = list(dataset["image"]) 43 | images = convert_if_not_rgb(images) 44 | 45 | if settings.LAYOUT_STATIC_CACHE: 46 | layout_predictor(images[:1]) 47 | 48 | start = time.time() 49 | layout_predictions = layout_predictor(images) 50 | surya_time = time.time() - start 51 | 52 | folder_name = os.path.basename(pathname).split(".")[0] 53 | result_path = os.path.join(results_dir, folder_name) 54 | os.makedirs(result_path, exist_ok=True) 55 | 56 | label_alignment = { # First is publaynet, second is surya 57 | "Image": [["Figure"], ["Picture", "Figure"]], 58 | "Table": [["Table"], ["Table", "Form", "TableOfContents"]], 59 | "Text": [ 60 | ["Text"], 61 | [ 62 | "Text", 63 | "Formula", 64 | "Footnote", 65 | "Caption", 66 | "TextInlineMath", 67 | "Code", 68 | "Handwriting", 69 | ], 70 | ], 71 | "List": [["List"], ["ListItem"]], 72 | "Title": [["Title"], ["SectionHeader", "Title"]], 73 | } 74 | 75 | page_metrics = collections.OrderedDict() 76 | for idx, pred in enumerate(layout_predictions): 77 | row = dataset[idx] 78 | all_correct_bboxes = [] 79 | page_results = {} 80 | for label_name in label_alignment: 81 | correct_cats, surya_cats = label_alignment[label_name] 82 | correct_bboxes = [ 83 
| b 84 | for b, category in zip(row["bboxes"], row["labels"]) 85 | if category in correct_cats 86 | ] 87 | all_correct_bboxes.extend(correct_bboxes) 88 | pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats] 89 | 90 | metrics = precision_recall( 91 | pred_bboxes, correct_bboxes, penalize_double=False 92 | ) 93 | weight = len(correct_bboxes) 94 | metrics["weight"] = weight 95 | page_results[label_name] = metrics 96 | 97 | page_metrics[idx] = page_results 98 | 99 | if debug: 100 | bbox_image = draw_bboxes_on_image( 101 | all_correct_bboxes, copy.deepcopy(images[idx]) 102 | ) 103 | bbox_image.save(os.path.join(result_path, f"{idx}_layout.png")) 104 | 105 | mean_metrics = collections.defaultdict(dict) 106 | layout_types = sorted(page_metrics[0].keys()) 107 | metric_types = sorted(page_metrics[0][layout_types[0]].keys()) 108 | metric_types.remove("weight") 109 | for label in layout_types: 110 | for m in metric_types: 111 | metric = [] 112 | total = 0 113 | for page in page_metrics: 114 | metric.append( 115 | page_metrics[page][label][m] * page_metrics[page][label]["weight"] 116 | ) 117 | total += page_metrics[page][label]["weight"] 118 | 119 | value = sum(metric) 120 | if value > 0: 121 | value /= total 122 | mean_metrics[label][m] = value 123 | 124 | out_data = { 125 | "time": surya_time, 126 | "metrics": mean_metrics, 127 | "page_metrics": page_metrics, 128 | } 129 | 130 | with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: 131 | json.dump(out_data, f, indent=4) 132 | 133 | table_headers = [ 134 | "Layout Type", 135 | ] + metric_types 136 | table_data = [] 137 | for layout_type in layout_types: 138 | table_data.append( 139 | [ 140 | layout_type, 141 | ] 142 | + [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types] 143 | ) 144 | 145 | print(tabulate(table_data, headers=table_headers, tablefmt="github")) 146 | print( 147 | f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total." 148 | ) 149 | print( 150 | "Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold." 
151 | ) 152 | print(f"Wrote results to {result_path}") 153 | 154 | 155 | if __name__ == "__main__": 156 | main() 157 | ``` -------------------------------------------------------------------------------- /surya/foundation/loader.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional 2 | 3 | import torch 4 | from transformers.utils import is_flash_attn_2_available 5 | 6 | from surya.common.load import ModelLoader 7 | from surya.common.surya.config import SuryaModelConfig 8 | from surya.common.surya import SuryaModel, SuryaXLAModel 9 | from surya.common.surya.processor import SuryaOCRProcessor 10 | from surya.common.surya.processor.tokenizer import SuryaOCRTokenizer 11 | from surya.common.util import is_flash_attn_2_supported 12 | from surya.common.xla import get_compile_args 13 | from surya.logging import get_logger 14 | from surya.settings import settings 15 | 16 | logger = get_logger() 17 | 18 | 19 | class FoundationModelLoader(ModelLoader): 20 | def __init__(self, checkpoint: Optional[str] = None): 21 | super().__init__(checkpoint) 22 | 23 | if self.checkpoint is None: 24 | self.checkpoint = settings.FOUNDATION_MODEL_CHECKPOINT 25 | 26 | def model( 27 | self, 28 | device=settings.TORCH_DEVICE_MODEL, 29 | dtype=None, 30 | attention_implementation: Optional[str] = None, 31 | ) -> SuryaModel: 32 | if device is None: 33 | device = settings.TORCH_DEVICE_MODEL 34 | if dtype is None: 35 | # See https://github.com/pytorch/pytorch/issues/118122 - T4 (device version 7.5) will return true since it supports 36 | # emulated bf16, but falls back to very slow kernels, especially for SDPA 37 | dtype = settings.MODEL_DTYPE_BFLOAT 38 | if device == "cuda" and not torch.cuda.is_bf16_supported( 39 | including_emulation=False 40 | ): 41 | # If the device is cuda, we check if bf16 is supported, and if not, we use float16 42 | dtype = settings.MODEL_DTYPE 43 | elif dtype == torch.float16: 44 | dtype = torch.bfloat16 # Model weights in bfloat16 45 | 46 | config = SuryaModelConfig.from_pretrained(self.checkpoint) 47 | 48 | if attention_implementation is not None: 49 | config.decoder._attn_implementation = attention_implementation 50 | config.vision_encoder._attn_implementation = attention_implementation 51 | elif is_flash_attn_2_available() and is_flash_attn_2_supported(device): 52 | config.decoder._attn_implementation = "flash_attention_2" 53 | config.vision_encoder._attn_implementation = "flash_attention_2" 54 | elif device == "xla": 55 | config.decoder._attn_implementation = "sdpa" 56 | config.vision_encoder._attn_implementation = "sdpa" 57 | else: 58 | config.decoder._attn_implementation = "sdpa" 59 | config.vision_encoder._attn_implementation = "sdpa" 60 | 61 | model_cls = SuryaModel 62 | if device == "xla": 63 | model_cls = SuryaXLAModel 64 | 65 | config._attn_implementation_autoset = True 66 | config.vision_encoder._attn_implementation_autoset = True 67 | config.decoder._attn_implementation_autoset = True 68 | 69 | model = model_cls.from_pretrained( 70 | self.checkpoint, dtype=dtype, config=config, ignore_mismatched_sizes=True 71 | ).to(device) 72 | model = model.eval() 73 | 74 | if settings.COMPILE_ALL or settings.COMPILE_FOUNDATION: 75 | torch._dynamo.config.cache_size_limit = 1000 76 | torch._dynamo.config.suppress_errors = True 77 | torch._dynamo.config.specialize_int = False 78 | torch._dynamo.config.allow_unspec_int_on_nn_module = True 79 | torch._dynamo.config.capture_scalar_outputs = True 80 | 
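# NOTE: the Dynamo flags in this block relax TorchDynamo's recompile limits and error
# handling before the encoder and decoder are compiled below. A minimal usage sketch for
# this loader, assuming default settings (illustrative only, not part of this file):
#
#   loader = FoundationModelLoader()   # defaults to settings.FOUNDATION_MODEL_CHECKPOINT
#   model = loader.model()             # device, dtype and attention picked automatically
#   processor = loader.processor()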
torch._dynamo.config.recompile_limit = 32 81 | 82 | logger.info( 83 | f"Compiling foundation model {self.checkpoint} on device {device} with dtype {dtype}" 84 | ) 85 | compile_args = get_compile_args(device) 86 | model.vision_encoder = torch.compile(model.vision_encoder, **compile_args) 87 | model.decoder = torch.compile(model.decoder, **compile_args) 88 | 89 | logger.debug( 90 | f"Loaded recognition model {self.checkpoint} from {SuryaModel.get_local_path(self.checkpoint)} onto device {model.device} with dtype {dtype}, using decoder attention mechanism {model.config.decoder._attn_implementation}, encoder attention mechanism {model.config.vision_encoder._attn_implementation}." 91 | ) 92 | return model 93 | 94 | def processor( 95 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE_BFLOAT 96 | ) -> SuryaOCRProcessor: 97 | config: SuryaModelConfig = SuryaModelConfig.from_pretrained(self.checkpoint) 98 | 99 | ocr_tokenizer = SuryaOCRTokenizer( 100 | special_tokens=config.special_ocr_tokens, model_checkpoint=self.checkpoint 101 | ) 102 | 103 | processor = SuryaOCRProcessor( 104 | ocr_tokenizer=ocr_tokenizer, 105 | blank_bbox_token_id=config.blank_bbox_token_id, 106 | num_register_tokens=config.num_register_tokens, 107 | sequence_length=None, 108 | patch_size=config.vision_encoder.patch_size, 109 | merge_size=config.vision_encoder.spatial_merge_size, 110 | model_device=device, 111 | num_beacon_tokens=config.num_beacon_tokens, 112 | beacon_token_interval=config.beacon_token_interval, 113 | ) 114 | 115 | return processor 116 | ``` -------------------------------------------------------------------------------- /surya/table_rec/shaper.py: -------------------------------------------------------------------------------- ```python 1 | import math 2 | from typing import List, Dict 3 | import numpy as np 4 | 5 | from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM 6 | 7 | 8 | class LabelShaper: 9 | def __init__(self): 10 | self.property_keys = [k for (k, kcount, mode) in BOX_PROPERTIES] 11 | 12 | def dict_to_labels(self, label_components: List[dict]): 13 | if len(label_components) == 0: 14 | return [] 15 | 16 | out_list = [] 17 | for (k, kcount, mode) in BOX_PROPERTIES: 18 | for label_component in label_components: 19 | if k not in label_component: 20 | raise ValueError(f"Missing key {k} in label component {label_component}") 21 | 22 | if mode == "classification": 23 | assert isinstance(label_component[k], int) 24 | elif mode == "regression": 25 | assert (isinstance(label_component[k], (int, float)) and kcount == 1) or len(label_component[k]) == kcount 26 | else: 27 | raise ValueError(f"Invalid mode {k['mode']} for key {k}") 28 | 29 | for label_component in label_components: 30 | bbox = label_component["bbox"] 31 | for i in range(len(bbox)): 32 | if bbox[i] < 0: 33 | bbox[i] = 0 34 | if bbox[i] > BOX_DIM: 35 | bbox[i] = BOX_DIM 36 | 37 | vector = [] 38 | for (k, kcount, mode) in BOX_PROPERTIES: 39 | item = label_component[k] 40 | if isinstance(item, (list, tuple)): 41 | vector += list(item) 42 | elif isinstance(item, (float, int)): 43 | if mode == "classification": 44 | # Shift up for model 45 | item += SPECIAL_TOKENS 46 | vector.append(item) 47 | else: 48 | raise ValueError(f"Invalid item {item} for key {k}") 49 | 50 | out_list.append(vector) 51 | 52 | return out_list 53 | 54 | def component_idx(self, key): 55 | idx = 0 56 | for (k, kcount, mode) in BOX_PROPERTIES: 57 | if mode == "regression": 58 | incr = kcount 59 | elif mode == "classification": 60 | 
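# The packed label vector lays properties out back to back: a regression property spans
# `kcount` slots, a classification property a single slot. Hypothetical example (the
# property names here are illustrative, not the real BOX_PROPERTIES):
#   with [("bbox", 6, "regression"), ("category", 1, "classification")],
#   component_idx("bbox") returns (0, 6) and component_idx("category") returns (6, 7).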
incr = 1 61 | else: 62 | raise ValueError(f"Invalid mode {mode} for key {k}") 63 | if k == key: 64 | return (idx, idx + incr) 65 | idx += incr 66 | raise ValueError(f"Key {key} not found in properties") 67 | 68 | def get_box_property(self, key, add_special_tokens=True): 69 | for (k, kcount, mode) in BOX_PROPERTIES: 70 | if k == key: 71 | # Add special token count 72 | if mode == "classification" and add_special_tokens: 73 | kcount += SPECIAL_TOKENS 74 | return (k, kcount, mode) 75 | raise ValueError(f"Key {key} not found in properties") 76 | 77 | def component_idx_dict(self): 78 | idx_dict = {} 79 | for (k, kcount, mode) in BOX_PROPERTIES: 80 | idx_dict[k] = self.component_idx(k) 81 | return idx_dict 82 | 83 | def convert_polygons_to_bboxes(self, label_components: List[Dict]): 84 | for i, label_component in enumerate(label_components): 85 | poly = label_component["polygon"] 86 | poly = np.clip(poly, 0, BOX_DIM) 87 | 88 | (x1, y1), (x2, y2), (x3, y3), (x4, y4) = poly 89 | cx = (x1 + x2 + x3 + x4) / 4 90 | cy = (y1 + y2 + y3 + y4) / 4 91 | width = (x2 + x3) / 2 - (x1 + x4) / 2 92 | height = (y3 + y4) / 2 - (y2 + y1) / 2 93 | bottom_avg_x = (x3 + x4) / 2 94 | top_avg_x = (x1 + x2) / 2 95 | right_avg_y = (y2 + y3) / 2 96 | left_avg_y = (y1 + y4) / 2 97 | 98 | x_skew = bottom_avg_x - top_avg_x 99 | y_skew = right_avg_y - left_avg_y 100 | x_skew += BOX_DIM // 2 # Shift up into positive space 101 | y_skew += BOX_DIM // 2 # Shift up into positive space 102 | new_poly = [ 103 | cx, 104 | cy, 105 | width, 106 | height, 107 | x_skew, 108 | y_skew 109 | ] 110 | label_component["bbox"] = new_poly 111 | 112 | return label_components 113 | 114 | def convert_bbox_to_polygon(self, box, skew_scaler=BOX_DIM // 2, skew_min=.001): 115 | cx = box[0] 116 | cy = box[1] 117 | width = box[2] 118 | height = box[3] 119 | x1 = cx - width / 2 120 | y1 = cy - height / 2 121 | x2 = cx + width / 2 122 | y2 = cy + height / 2 123 | skew_x = math.floor((box[4] - skew_scaler) / 2) 124 | skew_y = math.floor((box[5] - skew_scaler) / 2) 125 | 126 | # Ensures we don't get slightly warped boxes 127 | # Note that the values are later scaled, so this is in 1/1024 space 128 | if abs(skew_x) < skew_min: 129 | skew_x = 0 130 | 131 | if abs(skew_y) < skew_min: 132 | skew_y = 0 133 | 134 | polygon = [x1 - skew_x, y1 - skew_y, x2 - skew_x, y1 + skew_y, x2 + skew_x, y2 + skew_y, x1 + skew_x, 135 | y2 - skew_y] 136 | poly = [] 137 | for i in range(4): 138 | poly.append([ 139 | polygon[2 * i], 140 | polygon[2 * i + 1] 141 | ]) 142 | return poly 143 | 144 | 145 | 146 | ``` -------------------------------------------------------------------------------- /benchmark/detection.py: -------------------------------------------------------------------------------- ```python 1 | import argparse 2 | import collections 3 | import copy 4 | import json 5 | 6 | import click 7 | 8 | from benchmark.utils.bbox import get_pdf_lines 9 | from benchmark.utils.metrics import precision_recall 10 | from benchmark.utils.tesseract import tesseract_parallel 11 | from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb 12 | from surya.debug.draw import draw_polys_on_image 13 | from surya.common.util import rescale_bbox 14 | from surya.settings import settings 15 | from surya.detection import DetectionPredictor 16 | 17 | import os 18 | import time 19 | from tabulate import tabulate 20 | import datasets 21 | 22 | 23 | @click.command(help="Benchmark detection model.") 24 | @click.option("--pdf_path", type=str, help="Path to PDF to detect bboxes 
in.", default=None) 25 | @click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark")) 26 | @click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100) 27 | @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False) 28 | @click.option("--tesseract", is_flag=True, help="Run tesseract as well.", default=False) 29 | def main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool): 30 | det_predictor = DetectionPredictor() 31 | 32 | if pdf_path is not None: 33 | pathname = pdf_path 34 | doc = open_pdf(pdf_path) 35 | page_count = len(doc) 36 | page_indices = list(range(page_count)) 37 | page_indices = page_indices[:max_rows] 38 | 39 | images = get_page_images(doc, page_indices) 40 | doc.close() 41 | 42 | image_sizes = [img.size for img in images] 43 | correct_boxes = get_pdf_lines(pdf_path, image_sizes) 44 | else: 45 | pathname = "det_bench" 46 | # These have already been shuffled randomly, so sampling from the start is fine 47 | dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{max_rows}]") 48 | images = list(dataset["image"]) 49 | images = convert_if_not_rgb(images) 50 | correct_boxes = [] 51 | for i, boxes in enumerate(dataset["bboxes"]): 52 | img_size = images[i].size 53 | # 1000,1000 is bbox size for doclaynet 54 | correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes]) 55 | 56 | if settings.DETECTOR_STATIC_CACHE: 57 | # Run through one batch to compile the model 58 | det_predictor(images[:1]) 59 | 60 | start = time.time() 61 | predictions = det_predictor(images) 62 | surya_time = time.time() - start 63 | 64 | if tesseract: 65 | start = time.time() 66 | tess_predictions = tesseract_parallel(images) 67 | tess_time = time.time() - start 68 | else: 69 | tess_predictions = [None] * len(images) 70 | tess_time = None 71 | 72 | folder_name = os.path.basename(pathname).split(".")[0] 73 | result_path = os.path.join(results_dir, folder_name) 74 | os.makedirs(result_path, exist_ok=True) 75 | 76 | page_metrics = collections.OrderedDict() 77 | for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)): 78 | surya_boxes = [s.bbox for s in sb.bboxes] 79 | surya_polys = [s.polygon for s in sb.bboxes] 80 | 81 | surya_metrics = precision_recall(surya_boxes, cb) 82 | if tb is not None: 83 | tess_metrics = precision_recall(tb, cb) 84 | else: 85 | tess_metrics = None 86 | 87 | page_metrics[idx] = { 88 | "surya": surya_metrics, 89 | "tesseract": tess_metrics 90 | } 91 | 92 | if debug: 93 | bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx])) 94 | bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png")) 95 | 96 | mean_metrics = {} 97 | metric_types = sorted(page_metrics[0]["surya"].keys()) 98 | models = ["surya"] 99 | if tesseract: 100 | models.append("tesseract") 101 | 102 | for k in models: 103 | for m in metric_types: 104 | metric = [] 105 | for page in page_metrics: 106 | metric.append(page_metrics[page][k][m]) 107 | if k not in mean_metrics: 108 | mean_metrics[k] = {} 109 | mean_metrics[k][m] = sum(metric) / len(metric) 110 | 111 | out_data = { 112 | "times": { 113 | "surya": surya_time, 114 | "tesseract": tess_time 115 | }, 116 | "metrics": mean_metrics, 117 | "page_metrics": page_metrics 118 | } 119 | 120 | with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: 121 | json.dump(out_data, f, indent=4) 122 | 123 | 
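# A hedged sketch of consuming the results file written above; the keys mirror
# `out_data`, while the concrete path depends on settings.RESULT_DIR and the input,
# so the filename below is illustrative only:
#
#   import json
#   with open("results/benchmark/det_bench/results.json") as f:
#       results = json.load(f)
#   print(results["metrics"]["surya"]["precision"], results["metrics"]["surya"]["recall"])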
table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types 124 | table_data = [ 125 | ["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types], 126 | ] 127 | if tesseract: 128 | table_data.append( 129 | ["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types] 130 | ) 131 | 132 | print(tabulate(table_data, headers=table_headers, tablefmt="github")) 133 | print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.") 134 | print(f"Wrote results to {result_path}") 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | ``` -------------------------------------------------------------------------------- /surya/detection/heatmap.py: -------------------------------------------------------------------------------- ```python 1 | from typing import List 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from surya.common.util import clean_boxes 8 | from surya.detection import TextDetectionResult 9 | from surya.common.polygon import PolygonBox 10 | from surya.settings import settings 11 | 12 | 13 | def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7): 14 | # Find average intensity of top 10% pixels 15 | flat_map = linemap.ravel() 16 | top_10_count = int(len(flat_map) * 0.9) 17 | avg_intensity = np.mean(np.partition(flat_map, top_10_count)[top_10_count:]) 18 | scaling_factor = np.clip(avg_intensity / typical_top10_avg, 0, 1) ** (1 / 2) 19 | 20 | low_text = np.clip(low_text * scaling_factor, 0.1, 0.6) 21 | text_threshold = np.clip(text_threshold * scaling_factor, 0.15, 0.8) 22 | 23 | return text_threshold, low_text 24 | 25 | 26 | def detect_boxes(linemap, text_threshold, low_text): 27 | # From CRAFT - https://github.com/clovaai/CRAFT-pytorch 28 | # Modified to return boxes and for speed, accuracy 29 | img_h, img_w = linemap.shape 30 | 31 | text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text) 32 | 33 | text_score_comb = (linemap > low_text).astype(np.uint8) 34 | label_count, labels, stats, centroids = cv2.connectedComponentsWithStats( 35 | text_score_comb, connectivity=4 36 | ) 37 | 38 | det = [] 39 | confidences = [] 40 | max_confidence = 0 41 | 42 | for k in range(1, label_count): 43 | # size filtering 44 | size = stats[k, cv2.CC_STAT_AREA] 45 | if size < 10: 46 | continue 47 | 48 | # make segmentation map 49 | x, y, w, h = stats[ 50 | k, 51 | [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT], 52 | ] 53 | 54 | try: 55 | niter = int(np.sqrt(min(w, h))) 56 | except ValueError: 57 | niter = 0 58 | 59 | buffer = 1 60 | sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer) 61 | ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer) 62 | 63 | mask = labels[sy:ey, sx:ex] == k 64 | selected_linemap = linemap[sy:ey, sx:ex][mask] 65 | if selected_linemap.size == 0: 66 | continue 67 | 68 | line_max = np.max(selected_linemap) 69 | 70 | # thresholding 71 | if line_max < text_threshold: 72 | continue 73 | 74 | segmap = mask.astype(np.uint8) 75 | 76 | ksize = buffer + niter 77 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (ksize, ksize)) 78 | selected_segmap = cv2.dilate(segmap, kernel) 79 | 80 | # make box 81 | y_inds, x_inds = np.nonzero(selected_segmap) 82 | x_inds += sx 83 | y_inds += sy 84 | np_contours = 
np.column_stack((x_inds, y_inds)) 85 | rectangle = cv2.minAreaRect(np_contours) 86 | box = cv2.boxPoints(rectangle) 87 | 88 | # align diamond-shape 89 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) 90 | box_ratio = max(w, h) / (min(w, h) + 1e-5) 91 | if abs(1 - box_ratio) <= 0.1: 92 | left, right = np_contours[:, 0].min(), np_contours[:, 0].max() 93 | top, bottom = np_contours[:, 1].min(), np_contours[:, 1].max() 94 | box = np.array( 95 | [[left, top], [right, top], [right, bottom], [left, bottom]], 96 | dtype=np.float32, 97 | ) 98 | 99 | # make clock-wise order 100 | startidx = box.sum(axis=1).argmin() 101 | box = np.roll(box, 4 - startidx, 0) 102 | 103 | max_confidence = max(max_confidence, line_max) 104 | 105 | confidences.append(line_max) 106 | det.append(box) 107 | 108 | if max_confidence > 0: 109 | confidences = [c / max_confidence for c in confidences] 110 | return det, confidences 111 | 112 | 113 | def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]: 114 | if text_threshold is None: 115 | text_threshold = settings.DETECTOR_TEXT_THRESHOLD 116 | if low_text is None: 117 | low_text = settings.DETECTOR_BLANK_THRESHOLD 118 | 119 | if textmap.dtype != np.float32: 120 | textmap = textmap.astype(np.float32) 121 | 122 | boxes, confidences = detect_boxes(textmap, text_threshold, low_text) 123 | # From point form to box form 124 | return [ 125 | PolygonBox(polygon=box, confidence=confidence) 126 | for box, confidence in zip(boxes, confidences) 127 | ] 128 | 129 | 130 | def get_and_clean_boxes( 131 | textmap, processor_size, image_size, text_threshold=None, low_text=None 132 | ) -> List[PolygonBox]: 133 | bboxes = get_detected_boxes(textmap, text_threshold, low_text) 134 | for bbox in bboxes: 135 | bbox.rescale(processor_size, image_size) 136 | bbox.fit_to_bounds([0, 0, image_size[0], image_size[1]]) 137 | 138 | bboxes = clean_boxes(bboxes) 139 | return bboxes 140 | 141 | 142 | def parallel_get_boxes(preds, orig_sizes, include_maps=False): 143 | heatmap, affinity_map = preds 144 | heat_img, aff_img = None, None 145 | 146 | if include_maps: 147 | heat_img = Image.fromarray((heatmap * 255).astype(np.uint8)) 148 | aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8)) 149 | heatmap_size = list(reversed(heatmap.shape)) 150 | bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes) 151 | for box in bboxes: 152 | # Skip for vertical boxes 153 | if box.height < 3 * box.width: 154 | box.expand(x_margin=0, y_margin=settings.DETECTOR_BOX_Y_EXPAND_MARGIN) 155 | box.fit_to_bounds( 156 | [0, 0, orig_sizes[0], orig_sizes[1]] 157 | ) # Fix any bad expands 158 | 159 | result = TextDetectionResult( 160 | bboxes=bboxes, 161 | heatmap=heat_img, 162 | affinity_map=aff_img, 163 | image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]], 164 | ) 165 | return result 166 | ``` -------------------------------------------------------------------------------- /surya/recognition/util.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | from typing import List, Tuple 3 | 4 | import numpy 5 | import torch 6 | 7 | from surya.common.polygon import PolygonBox 8 | from surya.recognition.schema import TextLine, TextWord, TextChar 9 | 10 | MATH_SYMBOLS = ["+", "-", "*", "=", "^", "_", "\\", "{", "}"] 11 | 12 | 13 | def unwrap_math(text: str) -> str: 14 | if len(text) > 50: 15 | return text 16 | 17 | # Detected as math, but does not contain LaTeX commands 18 | if ( 19 | 
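# The three clauses below unwrap the <math> tags only when the whole string is a single
# math block containing none of MATH_SYMBOLS, e.g. unwrap_math("<math>1984</math>")
# returns "1984", while "<math>x^2</math>" is returned unchanged because of the "^".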
re.match(r'^\s*<math(?:\s+display="inline")?.*?</math>\s*$', text, re.DOTALL) 20 | and text.count("<math") == 1 21 | and not any([symb in text for symb in MATH_SYMBOLS]) 22 | ): 23 | # Remove math tags 24 | text = re.sub(r"<math.*?>", "", text) 25 | text = re.sub(r"</math>", "", text) 26 | 27 | return text 28 | 29 | 30 | MATH_BLOCK = re.compile(r"(<math\b[^>]*>)(.*?)</math>", flags=re.I | re.S) 31 | STRIP_TAGS = re.compile(r"</?(?:br|u|del|mark|i|b|sup|sub)\b[^>]*>", flags=re.I | re.S) 32 | DEFAULT_TAGS_TO_FILTER = ["p", "li", "ul", "ol", "table", "td", "tr", "th", "tbody", "pre"] 33 | 34 | def filter_blacklist_tags(text_chars: List[TextChar], tags_to_filter: List[str] = None) -> List[TextChar]: 35 | filtered_chars = [] 36 | char_buffer = [] 37 | in_tag = False 38 | if tags_to_filter is None: 39 | tags_to_filter = DEFAULT_TAGS_TO_FILTER 40 | 41 | for text_char in text_chars: 42 | char = text_char.text 43 | 44 | if char.startswith("<") or in_tag: 45 | in_tag = True 46 | char_buffer.append(text_char) 47 | if char.endswith(">"): 48 | full_tag = ''.join(c.text for c in char_buffer) 49 | inner = full_tag[1:-1].strip() # remove < > 50 | inner = inner.strip("/") # remove '/' 51 | 52 | # Possible that it is just an empty <> 53 | if not inner: 54 | filtered_chars.extend(char_buffer) 55 | in_tag = False 56 | char_buffer = [] 57 | continue 58 | 59 | tag_name_candidate = inner.split()[0] # remove any attributes 60 | if tag_name_candidate in tags_to_filter: 61 | # Discard tag 62 | pass 63 | else: 64 | # Keep tag 65 | filtered_chars.extend(char_buffer) 66 | 67 | in_tag = False 68 | char_buffer = [] 69 | else: 70 | filtered_chars.append(text_char) 71 | 72 | # Flush buffer if we never reached a tag close 73 | if char_buffer: 74 | filtered_chars.extend(char_buffer) 75 | 76 | return filtered_chars 77 | 78 | 79 | def clean_math_tags(html: str) -> str: 80 | # strip unwanted tags inside every well‑formed <math>…</math> 81 | def _inner(m): 82 | inner = STRIP_TAGS.sub("", m.group(2)) 83 | return f"{m.group(1)}{inner}</math>" if inner.strip() else "" 84 | 85 | cleaned = MATH_BLOCK.sub(_inner, html) 86 | 87 | # drop only orphan *closing* </math> tags 88 | depth = 0 89 | parts = [] 90 | for token in re.split(r"(</?math[^>]*>)", cleaned, flags=re.I): 91 | if token.lower().startswith("<math"): 92 | depth += 1 93 | parts.append(token) 94 | elif token.lower() == "</math>": 95 | if depth: # keep it only if it matches an open 96 | depth -= 1 97 | parts.append(token) 98 | # else: skip orphan closing tag 99 | else: 100 | parts.append(token) 101 | return "".join(parts) 102 | 103 | 104 | def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25): 105 | # Sorts in reading order. Not 100% accurate, this should only 106 | # be used as a starting point for more advanced sorting. 
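# For example, with the default tolerance of 1.25 the dict branch snaps a top edge at
# y = 10.3 to round(10.3 / 1.25) * 1.25 = 10.0; a line with its top at y = 10.6 lands in
# the same 10.0 bucket, so the two are treated as one row and then sorted left to right.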
107 | vertical_groups = {} 108 | for line in lines: 109 | group_key = ( 110 | round( 111 | line.bbox[1] 112 | if isinstance(line, TextLine) 113 | else line["bbox"][1] / tolerance 114 | ) 115 | * tolerance 116 | ) 117 | if group_key not in vertical_groups: 118 | vertical_groups[group_key] = [] 119 | vertical_groups[group_key].append(line) 120 | 121 | # Sort each group horizontally and flatten the groups into a single list 122 | sorted_lines = [] 123 | for _, group in sorted(vertical_groups.items()): 124 | sorted_group = sorted( 125 | group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0] 126 | ) 127 | sorted_lines.extend(sorted_group) 128 | 129 | return sorted_lines 130 | 131 | 132 | def clean_close_polygons(bboxes: List[List[List[int]]], thresh: float = 0.1): 133 | if len(bboxes) < 2: 134 | return bboxes 135 | 136 | new_bboxes = [bboxes[0]] 137 | for i in range(1, len(bboxes)): 138 | close = True 139 | prev_bbox = bboxes[i - 1] 140 | bbox = bboxes[i] 141 | for j in range(4): 142 | if ( 143 | abs(bbox[j][0] - prev_bbox[j][0]) > thresh 144 | or abs(bbox[j][1] - prev_bbox[j][1]) > thresh 145 | ): 146 | close = False 147 | break 148 | 149 | if not close: 150 | new_bboxes.append(bboxes[i]) 151 | 152 | return new_bboxes 153 | 154 | 155 | def words_from_chars(chars: List[TextChar], line_box: PolygonBox): 156 | words = [] 157 | word = None 158 | for i, char in enumerate(chars): 159 | if not char.bbox_valid: 160 | if word: 161 | words.append(word) 162 | word = None 163 | continue 164 | 165 | if not word: 166 | word = TextWord(**char.model_dump()) 167 | 168 | # Fit bounds to line if first word 169 | if i == 0: 170 | word.merge_left(line_box) 171 | 172 | elif not char.text.strip(): 173 | if word: 174 | words.append(word) 175 | word = None 176 | else: 177 | # Merge bboxes 178 | word.merge(char) 179 | word.text = word.text + char.text 180 | 181 | if i == len(chars) - 1: 182 | word.merge_right(line_box) 183 | if word: 184 | words.append(word) 185 | 186 | return words ``` -------------------------------------------------------------------------------- /surya/common/s3.py: -------------------------------------------------------------------------------- ```python 1 | import json 2 | import os 3 | import shutil 4 | import tempfile 5 | import time 6 | from concurrent.futures import ThreadPoolExecutor 7 | from pathlib import Path 8 | 9 | import requests 10 | from tqdm import tqdm 11 | 12 | from surya.logging import get_logger 13 | from surya.settings import settings 14 | 15 | logger = get_logger() 16 | 17 | # Lock file expiration time in seconds (10 minutes) 18 | LOCK_EXPIRATION = 600 19 | 20 | 21 | def join_urls(url1: str, url2: str): 22 | url1 = url1.rstrip("/") 23 | url2 = url2.lstrip("/") 24 | return f"{url1}/{url2}" 25 | 26 | 27 | def get_model_name(pretrained_model_name_or_path: str): 28 | return pretrained_model_name_or_path.split("/")[0] 29 | 30 | 31 | def download_file(remote_path: str, local_path: str, chunk_size: int = 1024 * 1024): 32 | local_path = Path(local_path) 33 | try: 34 | response = requests.get(remote_path, stream=True, allow_redirects=True) 35 | response.raise_for_status() # Raise an exception for bad status codes 36 | 37 | # Get file size from headers for progress bar 38 | total_size = int(response.headers.get('content-length', 0)) 39 | 40 | # Create progress bar with file name and size info 41 | filename = local_path.name 42 | pbar = tqdm( 43 | total=total_size, 44 | unit='B', 45 | unit_scale=True, 46 | unit_divisor=1024, 47 | desc=f"Downloading {filename}", 48 | 
miniters=1 49 | ) 50 | 51 | with open(local_path, "wb") as f: 52 | downloaded = 0 53 | for chunk in response.iter_content(chunk_size=chunk_size): 54 | if chunk: 55 | f.write(chunk) 56 | downloaded += len(chunk) 57 | pbar.update(len(chunk)) 58 | 59 | pbar.close() 60 | return local_path 61 | except Exception as e: 62 | if local_path.exists(): 63 | local_path.unlink() 64 | logger.error(f"Download error for file {remote_path}: {str(e)}") 65 | raise 66 | 67 | 68 | def check_manifest(local_dir: str): 69 | local_dir = Path(local_dir) 70 | manifest_path = local_dir / "manifest.json" 71 | if not os.path.exists(manifest_path): 72 | return False 73 | 74 | try: 75 | with open(manifest_path, "r") as f: 76 | manifest = json.load(f) 77 | for file in manifest["files"]: 78 | if not os.path.exists(local_dir / file): 79 | return False 80 | except Exception: 81 | return False 82 | 83 | return True 84 | 85 | 86 | def download_directory(remote_path: str, local_dir: str): 87 | model_name = get_model_name(remote_path) 88 | s3_url = join_urls(settings.S3_BASE_URL, remote_path) 89 | # Check to see if it's already downloaded 90 | model_exists = check_manifest(local_dir) 91 | if model_exists: 92 | return 93 | 94 | # Use tempfile.TemporaryDirectory to automatically clean up 95 | with tempfile.TemporaryDirectory() as temp_dir: 96 | # Download the manifest file 97 | manifest_file = join_urls(s3_url, "manifest.json") 98 | manifest_path = os.path.join(temp_dir, "manifest.json") 99 | download_file(manifest_file, manifest_path) 100 | 101 | # List and download all files 102 | with open(manifest_path, "r") as f: 103 | manifest = json.load(f) 104 | 105 | pbar = tqdm( 106 | desc=f"Downloading {model_name} model to {local_dir}", 107 | total=len(manifest["files"]), 108 | ) 109 | 110 | with ThreadPoolExecutor( 111 | max_workers=settings.PARALLEL_DOWNLOAD_WORKERS 112 | ) as executor: 113 | futures = [] 114 | for file in manifest["files"]: 115 | remote_file = join_urls(s3_url, file) 116 | local_file = os.path.join(temp_dir, file) 117 | futures.append(executor.submit(download_file, remote_file, local_file)) 118 | 119 | for future in futures: 120 | future.result() 121 | pbar.update(1) 122 | 123 | pbar.close() 124 | 125 | # Move all files to new directory 126 | for file in os.listdir(temp_dir): 127 | shutil.move(os.path.join(temp_dir, file), local_dir) 128 | 129 | 130 | class S3DownloaderMixin: 131 | s3_prefix = "s3://" 132 | 133 | @classmethod 134 | def get_local_path(cls, pretrained_model_name_or_path) -> str: 135 | if pretrained_model_name_or_path.startswith(cls.s3_prefix): 136 | pretrained_model_name_or_path = pretrained_model_name_or_path.replace( 137 | cls.s3_prefix, "" 138 | ) 139 | cache_dir = settings.MODEL_CACHE_DIR 140 | local_path = os.path.join(cache_dir, pretrained_model_name_or_path) 141 | os.makedirs(local_path, exist_ok=True) 142 | else: 143 | local_path = "" 144 | return local_path 145 | 146 | @classmethod 147 | def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): 148 | # Allow loading models directly from the hub, or using s3 149 | if not pretrained_model_name_or_path.startswith(cls.s3_prefix): 150 | return super().from_pretrained( 151 | pretrained_model_name_or_path, *args, **kwargs 152 | ) 153 | 154 | local_path = cls.get_local_path(pretrained_model_name_or_path) 155 | pretrained_model_name_or_path = pretrained_model_name_or_path.replace( 156 | cls.s3_prefix, "" 157 | ) 158 | 159 | # Retry logic for downloading the model folder 160 | retries = 3 161 | delay = 5 162 | attempt = 0 163 | success = 
False 164 | while not success and attempt < retries: 165 | try: 166 | download_directory(pretrained_model_name_or_path, local_path) 167 | success = True # If download succeeded 168 | except Exception as e: 169 | logger.error( 170 | f"Error downloading model from {pretrained_model_name_or_path}. Attempt {attempt + 1} of {retries}. Error: {e}" 171 | ) 172 | attempt += 1 173 | if attempt < retries: 174 | logger.info(f"Retrying in {delay} seconds...") 175 | time.sleep(delay) # Wait before retrying 176 | else: 177 | logger.error( 178 | f"Failed to download {pretrained_model_name_or_path} after {retries} attempts." 179 | ) 180 | raise e # Reraise exception after max retries 181 | 182 | return super().from_pretrained(local_path, *args, **kwargs) 183 | ``` -------------------------------------------------------------------------------- /surya/detection/__init__.py: -------------------------------------------------------------------------------- ```python 1 | from concurrent.futures import ThreadPoolExecutor 2 | from typing import List, Generator, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | from surya.common.predictor import BasePredictor 12 | from surya.common.xla import mark_step 13 | 14 | from surya.detection.loader import DetectionModelLoader 15 | from surya.detection.parallel import FakeExecutor 16 | from surya.detection.util import get_total_splits, split_image 17 | from surya.detection.schema import TextDetectionResult 18 | from surya.settings import settings 19 | from surya.detection.heatmap import parallel_get_boxes 20 | 21 | 22 | class DetectionPredictor(BasePredictor): 23 | model_loader_cls = DetectionModelLoader 24 | batch_size = settings.DETECTOR_BATCH_SIZE 25 | default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 36, "xla": 18} 26 | 27 | def __call__( 28 | self, images: List[Image.Image], batch_size=None, include_maps=False 29 | ) -> List[TextDetectionResult]: 30 | detection_generator = self.batch_detection( 31 | images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE 32 | ) 33 | 34 | postprocessing_futures = [] 35 | max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) 36 | parallelize = ( 37 | not settings.IN_STREAMLIT 38 | and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH 39 | ) 40 | executor = ThreadPoolExecutor if parallelize else FakeExecutor 41 | with executor(max_workers=max_workers) as e: 42 | for preds, orig_sizes in detection_generator: 43 | for pred, orig_size in zip(preds, orig_sizes): 44 | postprocessing_futures.append( 45 | e.submit(parallel_get_boxes, pred, orig_size, include_maps) 46 | ) 47 | 48 | return [future.result() for future in postprocessing_futures] 49 | 50 | def prepare_image(self, img): 51 | new_size = (self.processor.size["width"], self.processor.size["height"]) 52 | 53 | # This double resize actually necessary for downstream accuracy 54 | img.thumbnail(new_size, Image.Resampling.LANCZOS) 55 | img = img.resize( 56 | new_size, Image.Resampling.LANCZOS 57 | ) # Stretch smaller dimension to fit new size 58 | 59 | img = np.asarray(img, dtype=np.uint8) 60 | img = self.processor(img)["pixel_values"][0] 61 | img = torch.from_numpy(img) 62 | return img 63 | 64 | def batch_detection( 65 | self, images: List, batch_size=None, static_cache=False 66 | ) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]: 67 | assert all([isinstance(image, Image.Image) for image in images]) 68 | if 
batch_size is None: 69 | batch_size = self.get_batch_size() 70 | heatmap_count = self.model.config.num_labels 71 | 72 | orig_sizes = [image.size for image in images] 73 | splits_per_image = [ 74 | get_total_splits(size, self.processor.size["height"]) for size in orig_sizes 75 | ] 76 | 77 | batches = [] 78 | current_batch_size = 0 79 | current_batch = [] 80 | for i in range(len(images)): 81 | if current_batch_size + splits_per_image[i] > batch_size: 82 | if len(current_batch) > 0: 83 | batches.append(current_batch) 84 | current_batch = [] 85 | current_batch_size = 0 86 | current_batch.append(i) 87 | current_batch_size += splits_per_image[i] 88 | 89 | if len(current_batch) > 0: 90 | batches.append(current_batch) 91 | 92 | for batch_idx in tqdm( 93 | range(len(batches)), desc="Detecting bboxes", disable=self.disable_tqdm 94 | ): 95 | batch_image_idxs = batches[batch_idx] 96 | batch_images = [images[j].convert("RGB") for j in batch_image_idxs] 97 | 98 | split_index = [] 99 | split_heights = [] 100 | image_splits = [] 101 | for image_idx, image in enumerate(batch_images): 102 | image_parts, split_height = split_image( 103 | image, self.processor.size["height"] 104 | ) 105 | image_splits.extend(image_parts) 106 | split_index.extend([image_idx] * len(image_parts)) 107 | split_heights.extend(split_height) 108 | 109 | image_splits = [self.prepare_image(image) for image in image_splits] 110 | # Batch images in dim 0 111 | batch = torch.stack(image_splits, dim=0).to(self.model.dtype) 112 | if static_cache: 113 | batch = self.pad_to_batch_size(batch, batch_size) 114 | 115 | with settings.INFERENCE_MODE(): 116 | pred = self.model( 117 | pixel_values=batch.to(self.model.device) 118 | ) # Moving the to device here fixes issues with xla recompilation 119 | 120 | logits = pred.logits 121 | correct_shape = [ 122 | self.processor.size["height"], 123 | self.processor.size["width"], 124 | ] 125 | current_shape = list(logits.shape[2:]) 126 | if current_shape != correct_shape: 127 | logits = F.interpolate( 128 | logits, size=correct_shape, mode="bilinear", align_corners=False 129 | ) 130 | mark_step() 131 | 132 | logits = logits.to(torch.float32).cpu().numpy() 133 | preds = [] 134 | for i, (idx, height) in enumerate(zip(split_index, split_heights)): 135 | # If our current prediction length is below the image idx, that means we have a new image 136 | # Otherwise, we need to add to the current image 137 | if len(preds) <= idx: 138 | preds.append([logits[i][k] for k in range(heatmap_count)]) 139 | else: 140 | heatmaps = preds[idx] 141 | pred_heatmaps = [logits[i][k] for k in range(heatmap_count)] 142 | 143 | if height < self.processor.size["height"]: 144 | # Cut off padding to get original height 145 | pred_heatmaps = [ 146 | pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps 147 | ] 148 | 149 | for k in range(heatmap_count): 150 | heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]]) 151 | preds[idx] = heatmaps 152 | 153 | yield preds, [orig_sizes[j] for j in batch_image_idxs] 154 | 155 | torch.cuda.empty_cache() 156 | ``` -------------------------------------------------------------------------------- /benchmark/utils/metrics.py: -------------------------------------------------------------------------------- ```python 1 | from functools import partial 2 | from itertools import repeat 3 | 4 | import numpy as np 5 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor 6 | 7 | 8 | def box_area(box): 9 | return (box[2] - box[0]) * (box[3] - box[1]) 10 | 11 | 12 | def calculate_iou(box1, 
box2, box1_only=False): 13 | intersection = intersection_area(box1, box2) 14 | union = box_area(box1) 15 | if not box1_only: 16 | union += box_area(box2) - intersection 17 | 18 | if union == 0: 19 | return 0 20 | return intersection / union 21 | 22 | 23 | def match_boxes(preds, references): 24 | num_actual = len(references) 25 | num_predicted = len(preds) 26 | 27 | iou_matrix = np.zeros((num_actual, num_predicted)) 28 | for i, actual in enumerate(references): 29 | for j, pred in enumerate(preds): 30 | iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True) 31 | 32 | sorted_indices = np.argsort(iou_matrix, axis=None)[::-1] 33 | sorted_ious = iou_matrix.flatten()[sorted_indices] 34 | actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape) 35 | 36 | assigned_actual = set() 37 | assigned_pred = set() 38 | 39 | matches = [] 40 | for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious): 41 | i, j = idx 42 | if i not in assigned_actual and j not in assigned_pred: 43 | iou_val = iou_matrix[i, j] 44 | if iou_val > .95: # Account for rounding on box edges 45 | iou_val = 1.0 46 | matches.append((i, j, iou_val)) 47 | assigned_actual.add(i) 48 | assigned_pred.add(j) 49 | 50 | unassigned_actual = set(range(num_actual)) - assigned_actual 51 | unassigned_pred = set(range(num_predicted)) - assigned_pred 52 | matches.extend([(i, None, -1.0) for i in unassigned_actual]) 53 | matches.extend([(None, j, 0.0) for j in unassigned_pred]) 54 | 55 | return matches 56 | 57 | def penalized_iou_score(preds, references): 58 | matches = match_boxes(preds, references) 59 | iou = sum([match[2] for match in matches]) / len(matches) 60 | return iou 61 | 62 | def intersection_pixels(box1, box2): 63 | x_left = max(box1[0], box2[0]) 64 | y_top = max(box1[1], box2[1]) 65 | x_right = min(box1[2], box2[2]) 66 | y_bottom = min(box1[3], box2[3]) 67 | 68 | if x_right < x_left or y_bottom < y_top: 69 | return set() 70 | 71 | x_left, x_right = int(x_left), int(x_right) 72 | y_top, y_bottom = int(y_top), int(y_bottom) 73 | 74 | coords = np.meshgrid(np.arange(x_left, x_right), np.arange(y_top, y_bottom)) 75 | pixels = set(zip(coords[0].flat, coords[1].flat)) 76 | 77 | return pixels 78 | 79 | 80 | def calculate_coverage(box, other_boxes, penalize_double=False): 81 | box_area = (box[2] - box[0]) * (box[3] - box[1]) 82 | if box_area == 0: 83 | return 0 84 | 85 | # find total coverage of the box 86 | covered_pixels = set() 87 | double_coverage = list() 88 | for other_box in other_boxes: 89 | ia = intersection_pixels(box, other_box) 90 | double_coverage.append(list(covered_pixels.intersection(ia))) 91 | covered_pixels = covered_pixels.union(ia) 92 | 93 | # Penalize double coverage - having multiple bboxes overlapping the same pixels 94 | double_coverage_penalty = len(double_coverage) 95 | if not penalize_double: 96 | double_coverage_penalty = 0 97 | covered_pixels_count = max(0, len(covered_pixels) - double_coverage_penalty) 98 | return covered_pixels_count / box_area 99 | 100 | 101 | def intersection_area(box1, box2): 102 | x_left = max(box1[0], box2[0]) 103 | y_top = max(box1[1], box2[1]) 104 | x_right = min(box1[2], box2[2]) 105 | y_bottom = min(box1[3], box2[3]) 106 | 107 | if x_right < x_left or y_bottom < y_top: 108 | return 0.0 109 | 110 | return (x_right - x_left) * (y_bottom - y_top) 111 | 112 | 113 | def calculate_coverage_fast(box, other_boxes, penalize_double=False): 114 | box = np.array(box) 115 | other_boxes = np.array(other_boxes) 116 | 117 | # Calculate box area 118 | 
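# The vectorized version below clips the overlap against each box,
# max(0, min(x2a, x2b) - max(x1a, x1b)) times the clipped height, sums those intersection
# areas, and caps the ratio at 1.0; unlike calculate_coverage it does not penalize double
# coverage, which is why precision_recall only selects it when penalize_double is False.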
box_area = (box[2] - box[0]) * (box[3] - box[1]) 119 | if box_area == 0: 120 | return 0 121 | 122 | x_left = np.maximum(box[0], other_boxes[:, 0]) 123 | y_top = np.maximum(box[1], other_boxes[:, 1]) 124 | x_right = np.minimum(box[2], other_boxes[:, 2]) 125 | y_bottom = np.minimum(box[3], other_boxes[:, 3]) 126 | 127 | widths = np.maximum(0, x_right - x_left) 128 | heights = np.maximum(0, y_bottom - y_top) 129 | intersect_areas = widths * heights 130 | 131 | total_intersect = np.sum(intersect_areas) 132 | 133 | return min(1.0, total_intersect / box_area) 134 | 135 | 136 | def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True): 137 | if len(references) == 0: 138 | return { 139 | "precision": 1, 140 | "recall": 1, 141 | } 142 | 143 | if len(preds) == 0: 144 | return { 145 | "precision": 0, 146 | "recall": 0, 147 | } 148 | 149 | # If we're not penalizing double coverage, we can use a faster calculation 150 | coverage_func = calculate_coverage_fast 151 | if penalize_double: 152 | coverage_func = calculate_coverage 153 | 154 | with ThreadPoolExecutor(max_workers=workers) as executor: 155 | precision_func = partial(coverage_func, penalize_double=penalize_double) 156 | precision_iou = executor.map(precision_func, preds, repeat(references)) 157 | reference_iou = executor.map(coverage_func, references, repeat(preds)) 158 | 159 | precision_classes = [1 if i > threshold else 0 for i in precision_iou] 160 | precision = sum(precision_classes) / len(precision_classes) 161 | 162 | recall_classes = [1 if i > threshold else 0 for i in reference_iou] 163 | recall = sum(recall_classes) / len(recall_classes) 164 | 165 | return { 166 | "precision": precision, 167 | "recall": recall, 168 | } 169 | 170 | 171 | def mean_coverage(preds, references): 172 | coverages = [] 173 | 174 | for box1 in references: 175 | coverage = calculate_coverage(box1, preds) 176 | coverages.append(coverage) 177 | 178 | for box2 in preds: 179 | coverage = calculate_coverage(box2, references) 180 | coverages.append(coverage) 181 | 182 | # Calculate the average coverage over all comparisons 183 | if len(coverages) == 0: 184 | return 0 185 | coverage = sum(coverages) / len(coverages) 186 | return {"coverage": coverage} 187 | 188 | 189 | def rank_accuracy(preds, references): 190 | # Preds and references need to be aligned so each position refers to the same bbox 191 | pairs = [] 192 | for i, pred in enumerate(preds): 193 | for j, pred2 in enumerate(preds): 194 | if i == j: 195 | continue 196 | pairs.append((i, j, pred > pred2)) 197 | 198 | # Find how many of the prediction rankings are correct 199 | correct = 0 200 | for i, ref in enumerate(references): 201 | for j, ref2 in enumerate(references): 202 | if (i, j, ref > ref2) in pairs: 203 | correct += 1 204 | 205 | return correct / len(pairs) ``` -------------------------------------------------------------------------------- /surya/common/donut/processor.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Dict, Union, Optional, List, Iterable 2 | 3 | import cv2 4 | from torch import TensorType 5 | from transformers import ImageProcessingMixin 6 | from transformers.image_processing_utils import BatchFeature 7 | from transformers.image_transforms import pad, normalize 8 | from transformers.image_utils import ( 9 | ImageInput, 10 | ChannelDimension, 11 | make_list_of_images, 12 | get_image_size, 13 | ) 14 | import numpy as np 15 | from PIL import Image 16 | import PIL 17 | from transformers.utils 
import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD 18 | 19 | from surya.common.s3 import S3DownloaderMixin 20 | from surya.settings import settings 21 | 22 | 23 | class SuryaEncoderImageProcessor(S3DownloaderMixin, ImageProcessingMixin): 24 | def __init__( 25 | self, 26 | *args, 27 | max_size=None, 28 | align_long_axis=False, 29 | rescale_factor: Union[int, float] = 1 / 255, 30 | image_mean: Optional[Union[float, List[float]]] = None, 31 | image_std: Optional[Union[float, List[float]]] = None, 32 | **kwargs, 33 | ): 34 | super().__init__(*args, **kwargs) 35 | 36 | self.patch_size = kwargs.get("patch_size", (4, 4)) 37 | self.max_size = max_size 38 | self.do_align_long_axis = align_long_axis 39 | self.resample = Image.Resampling.BILINEAR 40 | self.rescale_factor = rescale_factor 41 | self.image_mean = ( 42 | image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN 43 | ) 44 | self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD 45 | 46 | def __call__(self, images, **kwargs) -> PIL.Image.Image: 47 | """Preprocess an image or a batch of images.""" 48 | return self.preprocess(images, **kwargs) 49 | 50 | @classmethod 51 | def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4): 52 | max_width, max_height = size["width"], size["height"] 53 | 54 | resized_image = cv2.resize( 55 | image, (max_width, max_height), interpolation=interpolation 56 | ) 57 | resized_image = resized_image.transpose(2, 0, 1) 58 | 59 | return resized_image 60 | 61 | def process_inner(self, images: List[np.ndarray]): 62 | assert images[0].shape[2] == 3 # RGB input images, channel dim last 63 | 64 | if self.do_align_long_axis: 65 | # Rotate if the bbox is wider than it is tall 66 | images = [ 67 | SuryaEncoderImageProcessor.align_long_axis( 68 | image, size=self.max_size, input_data_format=ChannelDimension.LAST 69 | ) 70 | for image in images 71 | ] 72 | 73 | # Verify that the image is wider than it is tall 74 | for img in images: 75 | assert img.shape[1] >= img.shape[0] 76 | 77 | # This also applies the right channel dim format, to channel x height x width 78 | images = [ 79 | SuryaEncoderImageProcessor.numpy_resize(img, self.max_size, self.resample) 80 | for img in images 81 | ] 82 | assert images[0].shape[0] == 3 # RGB input images, channel dim first 83 | 84 | # Convert to float32 for rescale/normalize 85 | images = [img.astype(np.float32) for img in images] 86 | 87 | # Pads with 255 (whitespace) 88 | # Pad to max size to improve performance 89 | max_size = self.max_size 90 | images = [ 91 | SuryaEncoderImageProcessor.pad_image( 92 | image=image, 93 | size=max_size, 94 | input_data_format=ChannelDimension.FIRST, 95 | pad_value=settings.RECOGNITION_PAD_VALUE, 96 | ) 97 | for image in images 98 | ] 99 | 100 | # Rescale and normalize 101 | for idx in range(len(images)): 102 | images[idx] = (images[idx].astype(np.float64) * self.rescale_factor).astype( 103 | np.float32 104 | ) 105 | 106 | images = [ 107 | SuryaEncoderImageProcessor.normalize( 108 | img, 109 | mean=self.image_mean, 110 | std=self.image_std, 111 | input_data_format=ChannelDimension.FIRST, 112 | ) 113 | for img in images 114 | ] 115 | 116 | return images 117 | 118 | def preprocess( 119 | self, 120 | images: ImageInput, 121 | return_tensors: Optional[Union[str, TensorType]] = None, 122 | **kwargs, 123 | ) -> PIL.Image.Image: 124 | images = make_list_of_images(images) 125 | 126 | # Convert to numpy for later processing steps 127 | images = [np.array(img) for img in images] 128 | images = 
self.process_inner(images) 129 | 130 | data = {"pixel_values": images} 131 | return BatchFeature(data=data, tensor_type=return_tensors) 132 | 133 | @classmethod 134 | def pad_image( 135 | cls, 136 | image: np.ndarray, 137 | size: Dict[str, int], 138 | data_format: Optional[Union[str, ChannelDimension]] = None, 139 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 140 | pad_value: float = 0.0, 141 | ) -> np.ndarray: 142 | output_height, output_width = size["height"], size["width"] 143 | input_height, input_width = get_image_size(image, channel_dim=input_data_format) 144 | 145 | delta_width = output_width - input_width 146 | delta_height = output_height - input_height 147 | 148 | assert delta_width >= 0 and delta_height >= 0 149 | 150 | pad_top = delta_height // 2 151 | pad_left = delta_width // 2 152 | 153 | pad_bottom = delta_height - pad_top 154 | pad_right = delta_width - pad_left 155 | 156 | padding = ((pad_top, pad_bottom), (pad_left, pad_right)) 157 | return pad( 158 | image, 159 | padding, 160 | data_format=data_format, 161 | input_data_format=input_data_format, 162 | constant_values=pad_value, 163 | ) 164 | 165 | @classmethod 166 | def align_long_axis( 167 | cls, image: np.ndarray, size: Dict[str, int], **kwargs 168 | ) -> np.ndarray: 169 | input_height, input_width = image.shape[:2] 170 | output_height, output_width = size["height"], size["width"] 171 | 172 | if (output_width < output_height and input_width > input_height) or ( 173 | output_width > output_height and input_width < input_height 174 | ): 175 | image = np.rot90(image, 3) 176 | 177 | return image 178 | 179 | @classmethod 180 | def normalize( 181 | cls, 182 | image: np.ndarray, 183 | mean: Union[float, Iterable[float]], 184 | std: Union[float, Iterable[float]], 185 | data_format: Optional[Union[str, ChannelDimension]] = None, 186 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 187 | **kwargs, 188 | ) -> np.ndarray: 189 | return normalize( 190 | image, 191 | mean=mean, 192 | std=std, 193 | data_format=data_format, 194 | input_data_format=input_data_format, 195 | **kwargs, 196 | ) 197 | ``` -------------------------------------------------------------------------------- /benchmark/table_recognition.py: -------------------------------------------------------------------------------- ```python 1 | import click 2 | import collections 3 | import json 4 | 5 | from surya.debug.draw import draw_bboxes_on_image 6 | from tabulate import tabulate 7 | 8 | from surya.input.processing import convert_if_not_rgb 9 | from surya.table_rec import TableRecPredictor 10 | from surya.settings import settings 11 | from benchmark.utils.metrics import penalized_iou_score 12 | from benchmark.utils.tatr import load_tatr, batch_inference_tatr 13 | import os 14 | import time 15 | import datasets 16 | 17 | 18 | @click.command(help="Benchmark table rec dataset") 19 | @click.option( 20 | "--results_dir", 21 | type=str, 22 | help="Path to JSON file with benchmark results.", 23 | default=os.path.join(settings.RESULT_DIR, "benchmark"), 24 | ) 25 | @click.option( 26 | "--max_rows", 27 | type=int, 28 | help="Maximum number of images to run benchmark on.", 29 | default=512, 30 | ) 31 | @click.option("--tatr", is_flag=True, help="Run table transformer.", default=False) 32 | @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False) 33 | def main(results_dir: str, max_rows: int, tatr: bool, debug: bool): 34 | table_rec_predictor = TableRecPredictor() 35 | 36 | pathname = "table_rec_bench" 37 | # These 
have already been shuffled randomly, so sampling from the start is fine 38 | split = "train" 39 | if max_rows is not None: 40 | split = f"train[:{max_rows}]" 41 | dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split) 42 | images = list(dataset["image"]) 43 | images = convert_if_not_rgb(images) 44 | 45 | if settings.TABLE_REC_STATIC_CACHE: 46 | # Run through one batch to compile the model 47 | table_rec_predictor(images[:1]) 48 | 49 | start = time.time() 50 | table_rec_predictions = table_rec_predictor(images) 51 | surya_time = time.time() - start 52 | 53 | folder_name = os.path.basename(pathname).split(".")[0] 54 | result_path = os.path.join(results_dir, folder_name) 55 | os.makedirs(result_path, exist_ok=True) 56 | 57 | page_metrics = collections.OrderedDict() 58 | mean_col_iou = 0 59 | mean_row_iou = 0 60 | for idx, (pred, image) in enumerate(zip(table_rec_predictions, images)): 61 | row = dataset[idx] 62 | pred_row_boxes = [p.bbox for p in pred.rows] 63 | pred_col_bboxes = [p.bbox for p in pred.cols] 64 | actual_row_bboxes = [r["bbox"] for r in row["rows"]] 65 | actual_col_bboxes = [c["bbox"] for c in row["columns"]] 66 | row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes) 67 | col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes) 68 | page_results = { 69 | "row_score": row_score, 70 | "col_score": col_score, 71 | "row_count": len(actual_row_bboxes), 72 | "col_count": len(actual_col_bboxes), 73 | } 74 | 75 | mean_col_iou += col_score 76 | mean_row_iou += row_score 77 | 78 | page_metrics[idx] = page_results 79 | 80 | if debug: 81 | # Save debug images 82 | draw_img = image.copy() 83 | draw_bboxes_on_image( 84 | pred_row_boxes, 85 | draw_img, 86 | [f"Row {i}" for i in range(len(pred_row_boxes))], 87 | ) 88 | draw_bboxes_on_image( 89 | pred_col_bboxes, 90 | draw_img, 91 | [f"Col {i}" for i in range(len(pred_col_bboxes))], 92 | color="blue", 93 | ) 94 | draw_img.save(os.path.join(result_path, f"{idx}_bbox.png")) 95 | 96 | actual_draw_image = image.copy() 97 | draw_bboxes_on_image( 98 | actual_row_bboxes, 99 | actual_draw_image, 100 | [f"Row {i}" for i in range(len(actual_row_bboxes))], 101 | ) 102 | draw_bboxes_on_image( 103 | actual_col_bboxes, 104 | actual_draw_image, 105 | [f"Col {i}" for i in range(len(actual_col_bboxes))], 106 | color="blue", 107 | ) 108 | actual_draw_image.save(os.path.join(result_path, f"{idx}_actual.png")) 109 | 110 | mean_col_iou /= len(table_rec_predictions) 111 | mean_row_iou /= len(table_rec_predictions) 112 | 113 | out_data = { 114 | "surya": { 115 | "time": surya_time, 116 | "mean_row_iou": mean_row_iou, 117 | "mean_col_iou": mean_col_iou, 118 | "page_metrics": page_metrics, 119 | } 120 | } 121 | 122 | if tatr: 123 | tatr_model = load_tatr() 124 | start = time.time() 125 | tatr_predictions = batch_inference_tatr(tatr_model, images, 1) 126 | tatr_time = time.time() - start 127 | 128 | page_metrics = collections.OrderedDict() 129 | mean_col_iou = 0 130 | mean_row_iou = 0 131 | for idx, pred in enumerate(tatr_predictions): 132 | row = dataset[idx] 133 | pred_row_boxes = [p["bbox"] for p in pred["rows"]] 134 | pred_col_bboxes = [p["bbox"] for p in pred["cols"]] 135 | actual_row_bboxes = [r["bbox"] for r in row["rows"]] 136 | actual_col_bboxes = [c["bbox"] for c in row["columns"]] 137 | row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes) 138 | col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes) 139 | page_results = { 140 | "row_score": row_score, 141 | "col_score": col_score, 
142 | "row_count": len(actual_row_bboxes), 143 | "col_count": len(actual_col_bboxes), 144 | } 145 | 146 | mean_col_iou += col_score 147 | mean_row_iou += row_score 148 | 149 | page_metrics[idx] = page_results 150 | 151 | mean_col_iou /= len(tatr_predictions) 152 | mean_row_iou /= len(tatr_predictions) 153 | 154 | out_data["tatr"] = { 155 | "time": tatr_time, 156 | "mean_row_iou": mean_row_iou, 157 | "mean_col_iou": mean_col_iou, 158 | "page_metrics": page_metrics, 159 | } 160 | 161 | with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: 162 | json.dump(out_data, f, indent=4) 163 | 164 | table = [ 165 | ["Model", "Row Intersection", "Col Intersection", "Time Per Image"], 166 | [ 167 | "Surya", 168 | f"{out_data['surya']['mean_row_iou']:.2f}", 169 | f"{out_data['surya']['mean_col_iou']:.5f}", 170 | f"{surya_time / len(images):.5f}", 171 | ], 172 | ] 173 | 174 | if tatr: 175 | table.append( 176 | [ 177 | "Table transformer", 178 | f"{out_data['tatr']['mean_row_iou']:.2f}", 179 | f"{out_data['tatr']['mean_col_iou']:.5f}", 180 | f"{tatr_time / len(images):.5f}", 181 | ] 182 | ) 183 | 184 | print(tabulate(table, headers="firstrow", tablefmt="github")) 185 | 186 | print( 187 | "Intersection is the average of the intersection % between each actual row/column, and the predictions. With penalties for too many/few predictions." 188 | ) 189 | print( 190 | "Note that table transformers is unbatched, since the example code in the repo is unbatched." 191 | ) 192 | print(f"Wrote results to {result_path}") 193 | 194 | 195 | if __name__ == "__main__": 196 | main() 197 | ``` -------------------------------------------------------------------------------- /surya/table_rec/model/decoder.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from surya.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel 7 | from surya.table_rec.model.config import TableRecModelOutput 8 | from surya.table_rec.shaper import LabelShaper 9 | from surya.settings import settings 10 | 11 | 12 | class LabelEmbedding(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | 16 | # Bboxes 17 | self.w_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 18 | self.h_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 19 | self.cx_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 20 | self.cy_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 21 | self.xskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 22 | self.yskew_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 23 | 24 | self.x1_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 25 | self.y1_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 26 | self.x2_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 27 | self.y2_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 28 | self.x3_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 29 | self.y3_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 30 | self.x4_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 31 | self.y4_embed = nn.Embedding(config.vocab_size, config.box_embed_size) 32 | 33 | # Get indexes for passed in tensor 34 | shaper = LabelShaper() 35 | self.component_idxs = shaper.component_idx_dict() 36 | merge_count = shaper.get_box_property("merges")[1] + 
config.special_token_count 37 | category_count = shaper.get_box_property("category")[1] + config.special_token_count 38 | 39 | # Other box properties 40 | self.category_embed = nn.Embedding(category_count, config.property_embed_size) 41 | self.merge_embed = nn.Embedding(merge_count, config.property_embed_size) 42 | self.colspan_embed = nn.Embedding(config.vocab_size, config.property_embed_size) 43 | 44 | self.config = config 45 | 46 | def forward(self, boxes: torch.LongTensor, *args): 47 | # Need to keep *args for compatibility with common decoder 48 | boxes = boxes.to(torch.long).clamp(0, self.config.vocab_size) 49 | 50 | boxes_unbound = boxes.to(torch.long).unbind(dim=-1) 51 | cx, cy, w, h, xskew, yskew = boxes_unbound[self.component_idxs["bbox"][0]:self.component_idxs["bbox"][1]] 52 | category = boxes_unbound[self.component_idxs["category"][0]:self.component_idxs["category"][1]][0] 53 | merges = boxes_unbound[self.component_idxs["merges"][0]:self.component_idxs["merges"][1]][0] 54 | colspan = boxes_unbound[self.component_idxs["colspan"][0]:self.component_idxs["colspan"][1]][0] 55 | 56 | xskew_actual = ((xskew - self.config.bbox_size // 2) / 2).to(torch.long) 57 | yskew_actual = ((yskew - self.config.bbox_size // 2) / 2).to(torch.long) 58 | 59 | x1 = (cx - w // 2 - xskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 60 | y1 = (cy - h // 2 - yskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 61 | x3 = (cx + w // 2 + xskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 62 | y3 = (cy + h // 2 + yskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 63 | 64 | size_embeds = self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) 65 | skew_embeds = self.xskew_embed(xskew) + self.yskew_embed(yskew) 66 | corner_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x3_embed(x3) + self.y3_embed(y3) 67 | box_embeds = size_embeds + skew_embeds + corner_embeds 68 | 69 | property_embeds = self.category_embed(category) + self.merge_embed(merges) + self.colspan_embed(colspan) 70 | 71 | # Cat bbox and property embeddings 72 | embedded = torch.cat([box_embeds, property_embeds], dim=-1) 73 | return embedded 74 | 75 | 76 | class SuryaTableRecDecoder(SuryaADETRDecoderPreTrainedModel): 77 | _tied_weights_keys = None 78 | 79 | def __init__(self, config, **kwargs): 80 | super().__init__(config) 81 | embed_tokens = LabelEmbedding(config) 82 | self.model = SuryaADETRDecoderModel( 83 | config, 84 | embedder=embed_tokens, 85 | static_cache=settings.TABLE_REC_STATIC_CACHE, 86 | max_boxes=settings.TABLE_REC_MAX_BOXES 87 | ) 88 | self.vocab_size = config.vocab_size 89 | 90 | shaper = LabelShaper() 91 | property_heads = {} 92 | for k in shaper.property_keys: 93 | _, kcount, mode = shaper.get_box_property(k) 94 | property_heads[k] = nn.Linear(config.hidden_size, kcount, bias=False) 95 | 96 | self.box_property_heads = nn.ModuleDict(property_heads) 97 | self.pre_output_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 98 | 99 | # Initialize weights and apply final processing 100 | self.post_init() 101 | 102 | def get_input_embeddings(self): 103 | return self.model.embed_tokens 104 | 105 | def set_input_embeddings(self, value): 106 | self.model.embed_tokens = value 107 | 108 | def get_output_embeddings(self): 109 | return self.lm_head 110 | 111 | def set_output_embeddings(self, new_embeddings): 112 | self.lm_head = new_embeddings 113 | 114 | def set_decoder(self, decoder): 115 | self.model = decoder 116 | 117 | def get_decoder(self): 118 | return 
self.model 119 | 120 | # Ignore copy 121 | def forward( 122 | self, 123 | input_ids: Optional[torch.LongTensor] = None, 124 | cache_position: Optional[torch.LongTensor] = None, 125 | attention_mask: Optional[torch.Tensor] = None, 126 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 127 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 128 | use_cache: Optional[bool] = None, 129 | prefill: bool = False, 130 | **kwargs 131 | ) -> Union[Tuple, TableRecModelOutput]: 132 | outputs = self.model( 133 | input_ids=input_ids, 134 | cache_position=cache_position, 135 | attention_mask=attention_mask, 136 | encoder_hidden_states=encoder_hidden_states, 137 | encoder_attention_mask=encoder_attention_mask, 138 | use_cache=use_cache, 139 | output_hidden_states=True, 140 | return_dict=True, 141 | prefill=prefill, 142 | ) 143 | 144 | hidden_states = self.pre_output_norm(outputs[0]) 145 | box_property_logits = {} 146 | for key in self.box_property_heads: 147 | box_property_logits[key] = self.box_property_heads[key](hidden_states) 148 | 149 | bbox_logits = nn.functional.sigmoid(box_property_logits["bbox"]) 150 | box_property_logits["bbox"] = bbox_logits 151 | 152 | return TableRecModelOutput( 153 | box_property_logits=box_property_logits, 154 | hidden_states=hidden_states, 155 | ) ``` -------------------------------------------------------------------------------- /surya/common/polygon.py: -------------------------------------------------------------------------------- ```python 1 | import copy 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | from pydantic import BaseModel, field_validator, computed_field 6 | import numbers 7 | 8 | 9 | class PolygonBox(BaseModel): 10 | polygon: List[List[float]] 11 | confidence: Optional[float] = None 12 | 13 | @field_validator("polygon", mode="before") 14 | @classmethod 15 | def convert_bbox_to_polygon(cls, value): 16 | if isinstance(value, (list, tuple)) and len(value) == 4: 17 | if all(isinstance(x, numbers.Number) for x in value): 18 | value = [float(v) for v in value] 19 | x_min, y_min, x_max, y_max = value 20 | polygon = [ 21 | [x_min, y_min], 22 | [x_max, y_min], 23 | [x_max, y_max], 24 | [x_min, y_max], 25 | ] 26 | return polygon 27 | elif all( 28 | isinstance(point, (list, tuple)) and len(point) == 2 for point in value 29 | ): 30 | value = [[float(v) for v in point] for point in value] 31 | return value 32 | elif isinstance(value, np.ndarray): 33 | if value.shape == (4, 2): 34 | return value.tolist() 35 | 36 | raise ValueError( 37 | f"Input must be either a bbox [x_min, y_min, x_max, y_max] or a polygon with 4 corners [(x,y), (x,y), (x,y), (x,y)]. All values must be numeric. You passed {value} of type {type(value)}. The first value is of type {type(value[0])}." 
38 | ) 39 | 40 | @property 41 | def height(self): 42 | return self.bbox[3] - self.bbox[1] 43 | 44 | @property 45 | def width(self): 46 | return self.bbox[2] - self.bbox[0] 47 | 48 | @property 49 | def area(self): 50 | return self.width * self.height 51 | 52 | @computed_field 53 | @property 54 | def bbox(self) -> List[float]: 55 | x_coords = [point[0] for point in self.polygon] 56 | y_coords = [point[1] for point in self.polygon] 57 | return [min(x_coords), min(y_coords), max(x_coords), max(y_coords)] 58 | 59 | def rescale(self, processor_size, image_size): 60 | # Point is in x, y format 61 | page_width, page_height = processor_size 62 | 63 | img_width, img_height = image_size 64 | width_scaler = img_width / page_width 65 | height_scaler = img_height / page_height 66 | 67 | for corner in self.polygon: 68 | corner[0] = int(corner[0] * width_scaler) 69 | corner[1] = int(corner[1] * height_scaler) 70 | 71 | def round(self, divisor): 72 | for corner in self.polygon: 73 | corner[0] = int(corner[0] / divisor) * divisor 74 | corner[1] = int(corner[1] / divisor) * divisor 75 | 76 | def fit_to_bounds(self, bounds): 77 | new_corners = copy.deepcopy(self.polygon) 78 | for corner in new_corners: 79 | corner[0] = max(min(corner[0], bounds[2]), bounds[0]) 80 | corner[1] = max(min(corner[1], bounds[3]), bounds[1]) 81 | self.polygon = new_corners 82 | 83 | def merge(self, other): 84 | x1 = min(self.bbox[0], other.bbox[0]) 85 | y1 = min(self.bbox[1], other.bbox[1]) 86 | x2 = max(self.bbox[2], other.bbox[2]) 87 | y2 = max(self.bbox[3], other.bbox[3]) 88 | self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] 89 | 90 | def merge_left(self, other): 91 | x1 = min(self.bbox[0], other.bbox[0]) 92 | self.polygon[0][0] = x1 93 | self.polygon[3][0] = x1 94 | 95 | def merge_right(self, other): 96 | x2 = max(self.bbox[2], other.bbox[2]) 97 | self.polygon[1][0] = x2 98 | self.polygon[2][0] = x2 99 | 100 | def expand(self, x_margin: float, y_margin: float): 101 | new_polygon = [] 102 | x_margin = x_margin * self.width 103 | y_margin = y_margin * self.height 104 | for idx, poly in enumerate(self.polygon): 105 | if idx == 0: 106 | new_polygon.append([int(poly[0] - x_margin), int(poly[1] - y_margin)]) 107 | elif idx == 1: 108 | new_polygon.append([int(poly[0] + x_margin), int(poly[1] - y_margin)]) 109 | elif idx == 2: 110 | new_polygon.append([int(poly[0] + x_margin), int(poly[1] + y_margin)]) 111 | elif idx == 3: 112 | new_polygon.append([int(poly[0] - x_margin), int(poly[1] + y_margin)]) 113 | self.polygon = new_polygon 114 | 115 | def intersection_polygon(self, other) -> List[List[float]]: 116 | new_poly = [] 117 | for i in range(4): 118 | if i == 0: 119 | new_corner = [ 120 | max(self.polygon[0][0], other.polygon[0][0]), 121 | max(self.polygon[0][1], other.polygon[0][1]), 122 | ] 123 | elif i == 1: 124 | new_corner = [ 125 | min(self.polygon[1][0], other.polygon[1][0]), 126 | max(self.polygon[1][1], other.polygon[1][1]), 127 | ] 128 | elif i == 2: 129 | new_corner = [ 130 | min(self.polygon[2][0], other.polygon[2][0]), 131 | min(self.polygon[2][1], other.polygon[2][1]), 132 | ] 133 | elif i == 3: 134 | new_corner = [ 135 | max(self.polygon[3][0], other.polygon[3][0]), 136 | min(self.polygon[3][1], other.polygon[3][1]), 137 | ] 138 | new_poly.append(new_corner) 139 | 140 | return new_poly 141 | 142 | def intersection_area(self, other, x_margin=0, y_margin=0): 143 | x_overlap = self.x_overlap(other, x_margin) 144 | y_overlap = self.y_overlap(other, y_margin) 145 | return x_overlap * y_overlap 146 | 147 | def 
x_overlap(self, other, x_margin=0): 148 | return max( 149 | 0, 150 | min(self.bbox[2] + x_margin, other.bbox[2] + x_margin) 151 | - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin), 152 | ) 153 | 154 | def y_overlap(self, other, y_margin=0): 155 | return max( 156 | 0, 157 | min(self.bbox[3] + y_margin, other.bbox[3] + y_margin) 158 | - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin), 159 | ) 160 | 161 | def intersection_pct(self, other, x_margin=0, y_margin=0): 162 | assert 0 <= x_margin <= 1 163 | assert 0 <= y_margin <= 1 164 | if self.area == 0: 165 | return 0 166 | 167 | if x_margin: 168 | x_margin = int(min(self.width, other.width) * x_margin) 169 | if y_margin: 170 | y_margin = int(min(self.height, other.height) * y_margin) 171 | 172 | intersection = self.intersection_area(other, x_margin, y_margin) 173 | return intersection / self.area 174 | 175 | def shift(self, x_shift: float | None = None, y_shift: float | None = None): 176 | if x_shift is not None: 177 | for corner in self.polygon: 178 | corner[0] += x_shift 179 | if y_shift is not None: 180 | for corner in self.polygon: 181 | corner[1] += y_shift 182 | 183 | def clamp(self, bbox: List[float]): 184 | for corner in self.polygon: 185 | corner[0] = max(min(corner[0], bbox[2]), bbox[0]) 186 | corner[1] = max(min(corner[1], bbox[3]), bbox[1]) 187 | 188 | @property 189 | def center(self): 190 | return [(self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2] 191 | 192 | def distance(self, other): 193 | center = self.center 194 | other_center = other.center 195 | 196 | return ( 197 | (center[0] - other_center[0]) ** 2 + (center[1] - other_center[1]) ** 2 198 | ) ** 0.5 199 | 200 | def __hash__(self): 201 | return hash(tuple(self.bbox)) 202 | ``` -------------------------------------------------------------------------------- /surya/settings.py: -------------------------------------------------------------------------------- ```python 1 | import os 2 | from typing import Callable, Dict, Optional 3 | 4 | import torch 5 | from dotenv import find_dotenv 6 | from pydantic import computed_field 7 | from pydantic_settings import BaseSettings 8 | from pathlib import Path 9 | from platformdirs import user_cache_dir 10 | 11 | 12 | class Settings(BaseSettings): 13 | # General 14 | TORCH_DEVICE: Optional[str] = None 15 | IMAGE_DPI: int = 96 # Used for detection, layout, reading order 16 | IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec 17 | IN_STREAMLIT: bool = False # Whether we're running in streamlit 18 | FLATTEN_PDF: bool = True # Flatten PDFs by merging form fields before processing 19 | DISABLE_TQDM: bool = False # Disable tqdm progress bars 20 | S3_BASE_URL: str = "https://models.datalab.to" 21 | PARALLEL_DOWNLOAD_WORKERS: int = ( 22 | 10 # Number of workers for parallel model downloads 23 | ) 24 | MODEL_CACHE_DIR: str = str(Path(user_cache_dir("datalab")) / "models") 25 | LOGLEVEL: str = "INFO" # Logging level 26 | 27 | # Paths 28 | DATA_DIR: str = "data" 29 | RESULT_DIR: str = "results" 30 | BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 31 | FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts") 32 | 33 | @computed_field 34 | def TORCH_DEVICE_MODEL(self) -> str: 35 | if self.TORCH_DEVICE is not None: 36 | return self.TORCH_DEVICE 37 | 38 | if torch.cuda.is_available(): 39 | return "cuda" 40 | 41 | if torch.backends.mps.is_available(): 42 | return "mps" 43 | 44 | try: 45 | import torch_xla 46 | 47 | if len(torch_xla.devices()) > 0: 48 | return "xla" 49 | except Exception: 50 
| pass 51 | 52 | return "cpu" 53 | 54 | # Text detection 55 | DETECTOR_BATCH_SIZE: Optional[int] = None # Defaults to 2 for CPU/MPS, 32 otherwise 56 | DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_05_07" 57 | DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench" 58 | DETECTOR_IMAGE_CHUNK_HEIGHT: int = ( 59 | 1400 # Height at which to slice images vertically 60 | ) 61 | DETECTOR_TEXT_THRESHOLD: float = ( 62 | 0.6 # Threshold for text detection (above this is considered text) 63 | ) 64 | DETECTOR_BLANK_THRESHOLD: float = ( 65 | 0.35 # Threshold for blank space (below this is considered blank) 66 | ) 67 | DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min( 68 | 8, os.cpu_count() 69 | ) # Number of workers for postprocessing 70 | DETECTOR_MIN_PARALLEL_THRESH: int = ( 71 | 3 # Minimum number of images before we parallelize 72 | ) 73 | DETECTOR_BOX_Y_EXPAND_MARGIN: float = ( 74 | 0.05 # Margin by which to expand detected boxes vertically 75 | ) 76 | COMPILE_DETECTOR: bool = False 77 | 78 | # Text recognition 79 | FOUNDATION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23" 80 | FOUNDATION_MODEL_QUANTIZE: bool = False 81 | FOUNDATION_MAX_TOKENS: Optional[int] = None 82 | FOUNDATION_CHUNK_SIZE: Optional[int] = None 83 | FOUNDATION_PAD_TO_NEAREST: int = 256 84 | COMPILE_FOUNDATION: bool = False 85 | FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE: float = 0.9 86 | 87 | RECOGNITION_MODEL_CHECKPOINT: str = "s3://text_recognition/2025_09_23" 88 | RECOGNITION_BATCH_SIZE: Optional[int] = ( 89 | None # Defaults to 8 for CPU/MPS, 256 otherwise 90 | ) 91 | RECOGNITION_RENDER_FONTS: Dict[str, str] = { 92 | "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"), 93 | "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), 94 | "ja": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), 95 | "ko": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), 96 | } 97 | RECOGNITION_FONT_DL_BASE: str = ( 98 | "https://github.com/satbyy/go-noto-universal/releases/download/v7.0" 99 | ) 100 | RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench" 101 | RECOGNITION_PAD_VALUE: int = 255 # Should be 0 or 255 102 | 103 | # Layout 104 | LAYOUT_MODEL_CHECKPOINT: str = "s3://layout/2025_09_23" 105 | LAYOUT_IMAGE_SIZE: Dict = {"height": 768, "width": 768} 106 | LAYOUT_SLICE_MIN: Dict = { 107 | "height": 1500, 108 | "width": 1500, 109 | } # When to start slicing images 110 | LAYOUT_SLICE_SIZE: Dict = {"height": 1200, "width": 1200} # Size of slices 111 | LAYOUT_BATCH_SIZE: Optional[int] = None 112 | LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench" 113 | LAYOUT_MAX_BOXES: int = 100 114 | COMPILE_LAYOUT: bool = False 115 | LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench" 116 | ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench" 117 | 118 | # Table Rec 119 | TABLE_REC_MODEL_CHECKPOINT: str = "s3://table_recognition/2025_02_18" 120 | TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768} 121 | TABLE_REC_MAX_BOXES: int = 150 122 | TABLE_REC_BATCH_SIZE: Optional[int] = None 123 | TABLE_REC_BENCH_DATASET_NAME: str = "datalab-to/fintabnet_bench" 124 | COMPILE_TABLE_REC: bool = False 125 | 126 | # Texify 127 | TEXIFY_BENCHMARK_DATASET: str = "datalab-to/texify_bench" 128 | 129 | # OCR Error Detection 130 | OCR_ERROR_MODEL_CHECKPOINT: str = "s3://ocr_error_detection/2025_02_18" 131 | OCR_ERROR_BATCH_SIZE: Optional[int] = None 132 | COMPILE_OCR_ERROR: bool = False 133 | 134 | # Tesseract (for benchmarks only) 135 | TESSDATA_PREFIX: Optional[str] = None 136 | 137 | COMPILE_ALL: bool = False 138 | 139 | @computed_field 140 | def 
DETECTOR_STATIC_CACHE(self) -> bool: 141 | return ( 142 | self.COMPILE_ALL 143 | or self.COMPILE_DETECTOR 144 | or self.TORCH_DEVICE_MODEL == "xla" 145 | ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise 146 | 147 | @computed_field 148 | def LAYOUT_STATIC_CACHE(self) -> bool: 149 | return ( 150 | self.COMPILE_ALL or self.COMPILE_LAYOUT or self.TORCH_DEVICE_MODEL == "xla" 151 | ) 152 | 153 | @computed_field 154 | def FOUNDATION_XLA(self) -> bool: 155 | return ( 156 | self.TORCH_DEVICE_MODEL == "xla" 157 | ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise 158 | 159 | @computed_field 160 | def FOUNDATION_STATIC_CACHE(self) -> bool: 161 | return ( 162 | self.COMPILE_ALL 163 | or self.COMPILE_FOUNDATION 164 | or self.TORCH_DEVICE_MODEL == "xla" 165 | ) # We need to static cache and pad to batch size for XLA, since it will recompile otherwise 166 | 167 | @computed_field 168 | def TABLE_REC_STATIC_CACHE(self) -> bool: 169 | return ( 170 | self.COMPILE_ALL 171 | or self.COMPILE_TABLE_REC 172 | or self.TORCH_DEVICE_MODEL == "xla" 173 | ) 174 | 175 | @computed_field 176 | def OCR_ERROR_STATIC_CACHE(self) -> bool: 177 | return ( 178 | self.COMPILE_ALL 179 | or self.COMPILE_OCR_ERROR 180 | or self.TORCH_DEVICE_MODEL == "xla" 181 | ) 182 | 183 | @computed_field 184 | def MODEL_DTYPE(self) -> torch.dtype: 185 | if self.TORCH_DEVICE_MODEL == "cpu": 186 | return torch.float32 187 | if self.TORCH_DEVICE_MODEL == "xla": 188 | return torch.bfloat16 189 | return torch.float16 190 | 191 | @computed_field 192 | def MODEL_DTYPE_BFLOAT(self) -> torch.dtype: 193 | if self.TORCH_DEVICE_MODEL == "cpu": 194 | return torch.float32 195 | if self.TORCH_DEVICE_MODEL == "mps": 196 | return torch.bfloat16 197 | return torch.bfloat16 198 | 199 | @computed_field 200 | def INFERENCE_MODE(self) -> Callable: 201 | if self.TORCH_DEVICE_MODEL == "xla": 202 | return torch.no_grad 203 | return torch.inference_mode 204 | 205 | class Config: 206 | env_file = find_dotenv("local.env") 207 | extra = "ignore" 208 | 209 | 210 | settings = Settings() 211 | ``` -------------------------------------------------------------------------------- /surya/foundation/cache/static_ops.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Dict, List, Optional, Tuple 2 | import torch 3 | from transformers import PretrainedConfig 4 | 5 | from surya.foundation.cache.dynamic_ops import DynamicOpsCache 6 | 7 | """ 8 | Special cache class for the surya foundation model that supports - 9 | 1) Static shape 10 | 2) A custom sliding window, where image tokens stay in cache, and text tokens are popped 11 | 3) Continuous batching - merging etc 12 | 4) Attention mask management - To match with what's currently in the cache 13 | 14 | Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079 15 | """ 16 | 17 | 18 | class StaticOpsCache(DynamicOpsCache): 19 | def __init__( 20 | self, 21 | config: PretrainedConfig, 22 | batch_size: int, 23 | max_cache_len: int, 24 | text_sliding_window: int, 25 | device: int, 26 | dtype: int, 27 | ): 28 | self.text_sliding_window = text_sliding_window 29 | self.num_layers = config.num_hidden_layers 30 | self.max_batch_size = batch_size 31 | self.max_cache_len = max_cache_len 32 | self.head_dim = ( 33 | getattr(config, "head_dim", None) 34 | or config.hidden_size // config.num_attention_heads 
35 | ) 36 | self._dtype = dtype 37 | self.num_key_value_heads = ( 38 | config.num_attention_heads 39 | if getattr(config, "num_key_value_heads", None) is None 40 | else config.num_key_value_heads 41 | ) 42 | 43 | # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125 44 | self.key_cache: list[torch.Tensor] = [] 45 | self.value_cache: list[torch.Tensor] = [] 46 | cache_shape = ( 47 | self.max_batch_size, 48 | self.num_key_value_heads, 49 | self.max_cache_len, 50 | self.head_dim, 51 | ) 52 | device = torch.device(device) if device is not None else None 53 | for _ in range(config.num_hidden_layers): 54 | new_layer_key_cache = torch.zeros( 55 | cache_shape, dtype=self._dtype, device=device 56 | ) 57 | new_layer_value_cache = torch.zeros( 58 | cache_shape, dtype=self._dtype, device=device 59 | ) 60 | torch._dynamo.mark_static_address(new_layer_key_cache) 61 | torch._dynamo.mark_static_address(new_layer_value_cache) 62 | self.key_cache.append(new_layer_key_cache) 63 | self.value_cache.append(new_layer_value_cache) 64 | 65 | self.attention_mask = torch.zeros( 66 | (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long 67 | ) 68 | self.text_token_counts = [ 69 | torch.zeros(self.max_batch_size, dtype=torch.long, device=device) 70 | for _ in range(self.num_layers) 71 | ] 72 | 73 | self.dtype = dtype 74 | self.device = device 75 | 76 | def update( 77 | self, 78 | key_states: torch.Tensor, 79 | value_states: torch.Tensor, 80 | layer_idx: int, 81 | cache_kwargs: Optional[Dict[str, Any]] = None, 82 | ) -> Tuple[torch.Tensor, torch.Tensor]: 83 | prefill = cache_kwargs.get("prefill", False) 84 | update_fn = self._prefill_update if prefill else self._decode_update 85 | return update_fn( 86 | self.key_cache[layer_idx], 87 | self.value_cache[layer_idx], 88 | key_states, 89 | value_states, 90 | self.text_token_counts[layer_idx], 91 | cache_kwargs, 92 | ) 93 | 94 | def _prefill_update( 95 | self, 96 | key_cache: torch.Tensor, 97 | value_cache: torch.Tensor, 98 | key_states: torch.Tensor, 99 | value_states: torch.Tensor, 100 | text_token_counts: torch.Tensor, 101 | cache_kwargs: Optional[Dict[str, Any]] = None, 102 | ): 103 | cache_idxs: torch.tensor = cache_kwargs.get("cache_idxs", None) 104 | text_lengths: List[int] = cache_kwargs.get("text_lengths", None) 105 | assert cache_idxs is not None, "cache_idxs must be specified during prefill" 106 | assert text_lengths is not None, "text_lengths must be specified during prefill" 107 | 108 | cache_idx_length = len(cache_idxs) 109 | full_batch = len(cache_idxs) == self.max_batch_size 110 | 111 | # Insert key and value states at the end of the cache 112 | new_tokens = key_states.shape[2] 113 | 114 | # Direct right-aligned assignment 115 | if full_batch: 116 | key_cache[:, :, -new_tokens:] = key_states 117 | value_cache[:, :, -new_tokens:] = value_states 118 | else: 119 | key_cache[cache_idxs, :, -new_tokens:] = key_states[:cache_idx_length] 120 | value_cache[cache_idxs, :, -new_tokens:] = value_states[:cache_idx_length] 121 | 122 | return key_states, value_states 123 | 124 | # """ 125 | # Matches the logic of the decode update, but needs to be called before the updates 126 | # since some parts of the model depend on the attention mask 127 | # """ 128 | def decode_attention_mask_update( 129 | self, num_valid_tokens: torch.Tensor, cache_idxs: List[int] 130 | ): 131 | max_valid_tokens = num_valid_tokens.max().item() 132 | if 
max_valid_tokens == 0: 133 | # If no valid tokens, we don't need to update the attention mask 134 | return 135 | 136 | # Shift the attention mask to the left by max_valid_tokens 137 | self.attention_mask = self.attention_mask.roll(-1 * max_valid_tokens, dims=1) 138 | self.attention_mask[:, -max_valid_tokens:] = ( 139 | 1 # Full attention to all new tokens 140 | ) 141 | 142 | # Mirrors the logic from _prefill_update 143 | def prefill_attention_mask_update( 144 | self, 145 | attention_mask: torch.Tensor, 146 | merge_idxs: torch.Tensor, 147 | valid_batch_size: torch.Tensor, 148 | text_lengths: List[int], 149 | ): 150 | # Set from -(image_length + text_length) to end to 1 for each batch element 151 | seq_len = attention_mask.shape[1] 152 | self.attention_mask[merge_idxs] = ( 153 | 0 # Reset the attention mask for the current batch elements 154 | ) 155 | self.attention_mask[merge_idxs, -seq_len:] = attention_mask[:valid_batch_size] 156 | 157 | def _decode_update( 158 | self, 159 | key_cache: torch.Tensor, 160 | value_cache: torch.Tensor, 161 | key_states: torch.Tensor, 162 | value_states: torch.Tensor, 163 | text_token_counts: torch.Tensor, 164 | cache_kwargs: Optional[Dict[str, Any]] = None, 165 | ) -> Tuple[torch.Tensor, torch.Tensor]: 166 | # Naive, always assumes we'll roll by a fixed amount 167 | # Needs left padding with beacons to work properly 168 | 169 | num_valid_tokens: torch.Tensor = cache_kwargs.get( 170 | "num_valid_tokens" 171 | ) # shape: (B,) 172 | assert num_valid_tokens is not None, ( 173 | "`num_valid_tokens` must be provided in `cache_kwargs`" 174 | ) 175 | # (B, H, L, D) 176 | 177 | valid_tokens = key_states.shape[2] 178 | 179 | key_cache.copy_(torch.roll(key_cache, -valid_tokens, dims=2)) 180 | value_cache.copy_(torch.roll(value_cache, -valid_tokens, dims=2)) 181 | 182 | key_cache[:, :, -valid_tokens:, :] = key_states 183 | value_cache[:, :, -valid_tokens:, :] = value_states 184 | 185 | # In-place edit - Mutates 186 | text_token_counts += num_valid_tokens 187 | text_token_counts.clamp_(max=self.text_sliding_window) 188 | return key_cache, value_cache 189 | 190 | # The attention mask managed by our kv cache automatically masks the tokens 191 | # in the cache, so we can return full length for HF to use in other places 192 | # This is mainly utilized in the cache_positions creation 193 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 194 | return self.max_cache_len 195 | ``` -------------------------------------------------------------------------------- /surya/table_rec/model/config.py: -------------------------------------------------------------------------------- ```python 1 | from dataclasses import dataclass 2 | from typing import Dict 3 | 4 | import torch 5 | from transformers import PretrainedConfig 6 | from transformers.utils import ModelOutput 7 | 8 | from surya.common.s3 import S3DownloaderMixin 9 | from surya.settings import settings 10 | 11 | BOX_DIM = 1024 12 | SPECIAL_TOKENS = 5 13 | MAX_BOXES = 150 14 | 15 | MERGE_KEYS = { 16 | "none": 0, 17 | "merge_up": 1, 18 | "merge_down": 2, 19 | "merge_both": 3 20 | } 21 | MERGE_VALUES = [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]] 22 | 23 | ID_TO_CATEGORY = { 24 | 0: 'Blank', 25 | 1: 'Table-row', 26 | 2: 'Table-column', 27 | 3: 'Table-cell', 28 | 4: 'Table' 29 | } 30 | CATEGORY_TO_ID = {v: k for k, v in ID_TO_CATEGORY.items()} 31 | 32 | ID_TO_HEADER = { 33 | 0: "None", 34 | 1: "Header" 35 | } 36 | HEADER_TO_ID = {v: k for k, v in ID_TO_HEADER.items()} 37 | 38 | 
BOX_PROPERTIES = [ 39 | ("bbox", 6, "regression"), 40 | ("category", len(ID_TO_CATEGORY), "classification"), 41 | ("merges", len(MERGE_KEYS), "classification"), 42 | ("colspan", 1, "regression"), 43 | ("is_header", len(ID_TO_HEADER), "classification") 44 | ] 45 | 46 | 47 | @dataclass 48 | class TableRecModelOutput(ModelOutput): 49 | box_property_logits: Dict[str, torch.Tensor] 50 | hidden_states: torch.Tensor | None = None 51 | 52 | 53 | class SuryaTableRecConfig(S3DownloaderMixin, PretrainedConfig): 54 | model_type = "vision-encoder-decoder" 55 | is_composition = True 56 | 57 | def __init__(self, **kwargs): 58 | super().__init__(**kwargs) 59 | 60 | if "encoder" in kwargs: 61 | encoder_config = kwargs.pop("encoder") 62 | decoder_config = kwargs.pop("decoder") 63 | else: 64 | encoder_config = DonutSwinTableRecConfig() 65 | decoder_config = SuryaTableRecDecoderConfig() 66 | 67 | self.encoder = encoder_config 68 | self.decoder = decoder_config 69 | self.is_encoder_decoder = True 70 | 71 | if isinstance(decoder_config, dict): 72 | self.decoder_start_token_id = decoder_config["bos_token_id"] 73 | self.pad_token_id = decoder_config["pad_token_id"] 74 | self.eos_token_id = decoder_config["eos_token_id"] 75 | else: 76 | self.decoder_start_token_id = decoder_config.bos_token_id 77 | self.pad_token_id = decoder_config.pad_token_id 78 | self.eos_token_id = decoder_config.eos_token_id 79 | 80 | 81 | class DonutSwinTableRecConfig(PretrainedConfig): 82 | model_type = "donut-swin" 83 | 84 | attribute_map = { 85 | "num_attention_heads": "num_heads", 86 | "num_hidden_layers": "num_layers", 87 | } 88 | 89 | def __init__( 90 | self, 91 | image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]), 92 | patch_size=4, 93 | num_channels=3, 94 | embed_dim=128, 95 | depths=[2, 2, 12, 2], 96 | num_heads=[4, 8, 16, 32], 97 | num_kv_heads=[4, 8, 16, 32], 98 | window_size=8, 99 | mlp_ratio=4.0, 100 | qkv_bias=True, 101 | hidden_dropout_prob=0.0, 102 | attention_probs_dropout_prob=0.0, 103 | drop_path_rate=0.1, 104 | hidden_act="gelu", 105 | use_absolute_embeddings=False, 106 | initializer_range=0.02, 107 | layer_norm_eps=1e-5, 108 | encoder_length=1024, 109 | use_positional_embeddings=True, 110 | **kwargs, 111 | ): 112 | super().__init__(**kwargs) 113 | 114 | self.image_size = image_size 115 | self.patch_size = patch_size 116 | self.num_channels = num_channels 117 | self.embed_dim = embed_dim 118 | self.depths = depths 119 | self.num_layers = len(depths) 120 | self.num_heads = num_heads 121 | self.num_kv_heads = num_kv_heads 122 | self.window_size = window_size 123 | self.mlp_ratio = mlp_ratio 124 | self.qkv_bias = qkv_bias 125 | self.hidden_dropout_prob = hidden_dropout_prob 126 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 127 | self.drop_path_rate = drop_path_rate 128 | self.hidden_act = hidden_act 129 | self.use_absolute_embeddings = use_absolute_embeddings 130 | self.layer_norm_eps = layer_norm_eps 131 | self.initializer_range = initializer_range 132 | # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel 133 | # this indicates the channel dimension after the last stage of the model 134 | self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) 135 | self.encoder_length = encoder_length 136 | self.use_positional_embeddings = use_positional_embeddings 137 | 138 | 139 | class SuryaTableRecDecoderConfig(PretrainedConfig): 140 | model_type = "surya_tablerec" 141 | 142 | def __init__( 143 | self, 144 | num_hidden_layers=6, 
145 | vocab_size=BOX_DIM + 1, 146 | bbox_size=BOX_DIM, 147 | hidden_size=512, 148 | property_embed_size=64, 149 | box_embed_size=512 - 64, 150 | intermediate_size=4 * 512, 151 | encoder_hidden_size=1024, 152 | num_attention_heads=8, 153 | lru_width=None, 154 | attention_window_size=16, 155 | conv1d_width=4, 156 | logits_soft_cap=30.0, 157 | rms_norm_eps=1e-6, 158 | use_cache=True, 159 | pad_token_id=0, 160 | eos_token_id=1, 161 | bos_token_id=1, 162 | pause_token_id=2, 163 | query_end_token_id=4, 164 | hidden_activation="gelu_pytorch_tanh", 165 | rope_theta=10000.0, 166 | block_types=("attention",), 167 | cross_attn_layers=tuple(range(10)), 168 | encoder_cross_attn_layers=tuple(range(10)), 169 | self_attn_layers=tuple(range(10)), 170 | global_attn_layers=tuple(range(10)), 171 | attention_dropout=0.0, 172 | num_key_value_heads=4, 173 | attention_bias=False, 174 | w_init_variance_scale=0.01, 175 | init_std=0.02, 176 | tie_word_embeddings=False, 177 | aux_heads=0, # How many n-token-ahead heads to add 178 | causal=True, 179 | layer_norm_eps=1e-5, 180 | dropout=0.0, 181 | special_token_count=SPECIAL_TOKENS, 182 | **kwargs, 183 | ): 184 | self.num_hidden_layers = num_hidden_layers 185 | self.vocab_size = vocab_size 186 | self.hidden_size = hidden_size 187 | self.intermediate_size = intermediate_size 188 | self.num_attention_heads = num_attention_heads 189 | self.lru_width = lru_width if lru_width is not None else hidden_size 190 | self.attention_window_size = attention_window_size 191 | self.conv1d_width = conv1d_width 192 | self.logits_soft_cap = logits_soft_cap 193 | self.rms_norm_eps = rms_norm_eps 194 | self.use_cache = use_cache 195 | self.rope_theta = rope_theta 196 | self.block_types = list(block_types) 197 | self.hidden_activation = hidden_activation 198 | self.head_dim = self.hidden_size // self.num_attention_heads 199 | self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads 200 | if self.num_key_value_heads > self.num_attention_heads: 201 | raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") 202 | self.cross_attn_layers = cross_attn_layers 203 | self.self_attn_layers = self_attn_layers 204 | self.global_attn_layers = global_attn_layers 205 | self.attention_dropout = attention_dropout 206 | self.attention_bias = attention_bias 207 | self.w_init_variance_scale = w_init_variance_scale 208 | self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers 209 | self.init_std = init_std 210 | self.tie_word_embeddings = tie_word_embeddings 211 | self.aux_heads = aux_heads 212 | self.encoder_hidden_size=encoder_hidden_size 213 | self.causal = causal 214 | self.encoder_cross_attn_layers = encoder_cross_attn_layers 215 | self.layer_norm_eps = layer_norm_eps 216 | self.dropout = dropout 217 | self.bbox_size = bbox_size 218 | self.pause_token_id = pause_token_id 219 | self.box_properties = BOX_PROPERTIES 220 | self.property_embed_size = property_embed_size 221 | self.box_embed_size = box_embed_size 222 | self.special_token_count = special_token_count 223 | self.query_end_token_id = query_end_token_id 224 | self.double_residual_flow = False 225 | 226 | super().__init__( 227 | pad_token_id=pad_token_id, 228 | bos_token_id=bos_token_id, 229 | eos_token_id=eos_token_id, 230 | **kwargs, 231 | ) 232 | 233 | @property 234 | def layers_block_type(self): 235 | return (self.block_types * 100)[: self.num_hidden_layers] ```
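The following is a minimal usage sketch, not part of the repository: it only exercises classes defined in the listings above (`SuryaTableRecDecoderConfig` and the category/merge maps from `surya/table_rec/model/config.py`, plus `PolygonBox` from `surya/common/polygon.py`) and assumes the `surya` package is importable. The expected values in the comments follow directly from the defaults and validators shown in those files.

```python
# Illustrative sketch only (not repository code). Assumes `surya` is installed
# so the modules shown above can be imported.
from surya.common.polygon import PolygonBox
from surya.table_rec.model.config import (
    CATEGORY_TO_ID,
    MERGE_KEYS,
    SuryaTableRecDecoderConfig,
)

# The decoder config defaults to attention-only blocks, so layers_block_type
# simply repeats "attention" once per hidden layer.
config = SuryaTableRecDecoderConfig()
assert config.layers_block_type == ["attention"] * config.num_hidden_layers
assert config.head_dim == config.hidden_size // config.num_attention_heads  # 512 // 8 = 64
print(CATEGORY_TO_ID["Table-row"], MERGE_KEYS["merge_down"])  # 1 2

# PolygonBox accepts either a 4-value bbox or four corner points; the field
# validator expands the bbox form into corners before storing the polygon.
pred = PolygonBox(polygon=[10, 10, 110, 40])
actual = PolygonBox(polygon=[[12, 8], [108, 8], [108, 42], [12, 42]])
print(pred.bbox)                                # [10.0, 10.0, 110.0, 40.0]
print(round(pred.intersection_pct(actual), 2))  # 0.96
```

Note that `intersection_pct` normalizes by `self.area`, so `pred.intersection_pct(actual)` and `actual.intersection_pct(pred)` generally differ; callers choose which box the overlap is measured against.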