This is page 3 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmark │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── CITATION.cff ├── CLA.md ├── detect_layout.py ├── detect_text.py ├── LICENSE ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ ├── fonts │ │ └── .gitignore │ └── images │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── arabic.jpg │ ├── benchmark_chart_small.png │ ├── benchmark_chart.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chi_hind.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── chinese.jpg │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── excerpt.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── funsd.png │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── hindi.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── japanese.jpg │ ├── latex_ocr.png │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── nyt.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── paper.jpg │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── pres.png │ ├── rec_acc_table.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── scanned.png │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ ├── textbook_text.jpg │ └── textbook.jpg ├── surya │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── adetr │ │ │ └── decoder.py │ │ ├── donut │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder │ │ │ │ └── __init__.py │ │ │ ├── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation │ │ ├── __init__.py │ │ ├── cache 
│ │ │ ├── __init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input │ │ ├── load.py │ │ └── processing.py │ ├── layout │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests │ ├── assets │ │ └── test_latex.png │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ``` # Files -------------------------------------------------------------------------------- /surya/common/surya/flash_attn_utils.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Optional 2 | import torch 3 | import torch.nn.functional as F 4 | from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func 5 | from flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache 6 | from flash_attn.bert_padding import index_first_axis as _index_first_axis 7 | from flash_attn.bert_padding import pad_input 8 | 9 | def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: 10 | """ 11 | Retrieves indexing data required to repad unpadded (ragged) tensors. 12 | 13 | Arguments: 14 | attention_mask (`torch.Tensor`): 15 | Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. 16 | 17 | Return: 18 | indices (`torch.Tensor`): 19 | The indices of non-masked tokens from the flattened input sequence. 20 | cu_seqlens (`torch.Tensor`): 21 | The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). 22 | max_seqlen_in_batch (`int`): 23 | Maximum sequence length in batch. 24 | """ 25 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 26 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 27 | max_seqlen_in_batch = seqlens_in_batch.max().item() 28 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 29 | return ( 30 | indices, 31 | cu_seqlens, 32 | max_seqlen_in_batch, 33 | ) 34 | 35 | def _upad_input( 36 | query_layer: torch.Tensor, 37 | key_layer: torch.Tensor, 38 | value_layer: torch.Tensor, 39 | query_length: int, 40 | indices_k, 41 | cu_seqlens_k, 42 | max_seqlen_in_batch_k 43 | ): 44 | """ 45 | Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. 
46 | 47 | This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary 48 | tensors for query, key, value tensors. 49 | 50 | Arguments: 51 | query_layer (`torch.Tensor`): 52 | Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). 53 | key_layer (`torch.Tensor`): 54 | Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). 55 | value_layer (`torch.Tensor`): 56 | Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). 57 | indices_k (`torch.Tensor`), cu_seqlens_k (`torch.Tensor`), max_seqlen_in_batch_k (`int`): 58 | Unpadding metadata for the key/value sequence (non-masked token indices, cumulative sequence lengths of shape (batch_size + 1,), and maximum sequence length), as returned by `_get_unpad_data`. 59 | query_length (`int`): 60 | Target length. 61 | 62 | Return: 63 | query_layer (`torch.Tensor`): 64 | Query state without padding. Shape: (total_target_length, num_heads, head_dim). 65 | key_layer (`torch.Tensor`): 66 | Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim). 67 | value_layer (`torch.Tensor`): 68 | Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim). 69 | indices_q (`torch.Tensor`): 70 | The indices of non-masked tokens from the flattened input target sequence. 71 | (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`): 72 | The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). 73 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): 74 | Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 75 | """ 76 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 77 | 78 | key_layer = _index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) 79 | value_layer = _index_first_axis( 80 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 81 | ) 82 | if query_length == kv_seq_len: 83 | query_layer = _index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) 84 | cu_seqlens_q = cu_seqlens_k 85 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 86 | indices_q = indices_k 87 | elif query_length == 1: 88 | max_seqlen_in_batch_q = 1 89 | cu_seqlens_q = torch.arange( 90 | batch_size + 1, dtype=torch.int32, device=query_layer.device 91 | ) # There is a memcpy here, that is very bad.
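        # For example, with batch_size = 3 this branch gives cu_seqlens_q == tensor([0, 1, 2, 3]),
        # and the slice on the next line gives indices_q == tensor([0, 1, 2]): each decode query
        # contributes exactly one token.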
92 | indices_q = cu_seqlens_q[:-1] 93 | query_layer = query_layer.squeeze(1) 94 | else: 95 | raise NotImplementedError() 96 | 97 | return ( 98 | query_layer, 99 | key_layer, 100 | value_layer, 101 | indices_q, 102 | (cu_seqlens_q, cu_seqlens_k), 103 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 104 | ) 105 | 106 | def flash_attn_prefill( 107 | module: torch.nn.Module, 108 | query_states: torch.Tensor, 109 | key_states: torch.Tensor, 110 | value_states: torch.Tensor, 111 | attention_mask: torch.Tensor, 112 | dropout: float, 113 | scaling: float, 114 | query_length: int, 115 | batch_size: int, 116 | indices_k: torch.Tensor, 117 | cu_seqlens_k: torch.Tensor, 118 | max_seqlen_in_batch_k: int, 119 | **kwargs 120 | ): 121 | """ 122 | Wrapper for flash attention during the prefill stage 123 | query_states must have shape (batch_size, num_heads, seq_len, head_dim) 124 | key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim) 125 | 126 | This is the opposite of what is required by flash attention, but keeps parity with the HF convention 127 | 128 | query_length, batch_size, indices_k, cu_seqlens_k, and max_seqlen_in_batch_k should come from the flash attention kwargs 129 | """ 130 | query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2) 131 | q_flash, k_flash, v_flash, indices_q, cu_seq_lens, max_seq_lens = _upad_input( 132 | query_states, key_states, value_states, query_length, indices_k, cu_seqlens_k, max_seqlen_in_batch_k 133 | ) 134 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 135 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 136 | 137 | # Returning None for attn_weights to match other attention interfaces 138 | flash_attn_out = _flash_attn_varlen_func( 139 | q_flash, 140 | k_flash, 141 | v_flash, 142 | cu_seqlens_q=cu_seqlens_q, 143 | cu_seqlens_k=cu_seqlens_k, 144 | max_seqlen_q=max_seqlen_in_batch_q, 145 | max_seqlen_k=max_seqlen_in_batch_k, 146 | dropout_p=dropout, 147 | softmax_scale=scaling, 148 | causal=module.is_causal, 149 | ) 150 | return pad_input(flash_attn_out, indices_q, batch_size, query_length), None 151 | 152 | # NOTE: Does not support dropout, accepts argument as kwargs to maintain compatibility 153 | # This function is an order of magnitude faster than the prefill variant, or using the HF interface 154 | def flash_attn_decode( 155 | module: torch.nn.Module, 156 | query_states: torch.Tensor, 157 | key_states: torch.Tensor, 158 | value_states: torch.Tensor, 159 | attention_mask: torch.Tensor, 160 | scaling: float, 161 | **kwargs, 162 | ): 163 | """ 164 | Wrapper for flash attention during the decode stage 165 | 166 | query_states must have shape (batch_size, num_heads, seq_len, head_dim), 1 is the seq length in the decoding stage 167 | key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim) 168 | 169 | This is the opposite of what is required by flash attention, but keeps parity with the HF convention 170 | 171 | This function computes the left pad and cache seqlens to pass into FA2. 
For example - 172 | Given an attention_mask shaped (batch_size=2, seq_len=8), where 0 = padding, 1 = real token 173 | attention_mask = 174 | tensor([ 175 | [0, 0, 1, 1, 1, 0, 0, 0], # ← batch 0 176 | [0, 1, 1, 1, 1, 1, 1, 0], # ← batch 1 177 | ]) 178 | cache_leftpad = tensor([2, 1], dtype=torch.int32) 179 | cache_seqlens = tensor([5, 7], dtype=torch.int32) 180 | These values allow FlashAttention to use a static cache layout with efficient slicing during decoding. 181 | """ 182 | query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2) 183 | 184 | cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1).to(torch.int32) 185 | cache_seqlens = (attention_mask * torch.arange(attention_mask.size(1), device=attention_mask.device)).argmax(dim=1).to(torch.int32) + 1 186 | 187 | # Returning None for attn_weights to match other attention interfaces 188 | return _flash_attn_with_kvcache( 189 | q=query_states, 190 | k_cache=key_states, 191 | v_cache=value_states, 192 | cache_leftpad=cache_leftpad, 193 | cache_seqlens=cache_seqlens, 194 | causal=module.is_causal, 195 | softmax_scale=scaling, 196 | ), None ``` -------------------------------------------------------------------------------- /surya/common/util.py: -------------------------------------------------------------------------------- ```python 1 | import copy 2 | from typing import List 3 | import torch 4 | from functools import lru_cache 5 | 6 | import torch.nn.functional as F 7 | 8 | from surya.common.polygon import PolygonBox 9 | 10 | 11 | def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: 12 | new_boxes = [] 13 | for box_obj in boxes: 14 | xs = [point[0] for point in box_obj.polygon] 15 | ys = [point[1] for point in box_obj.polygon] 16 | if max(xs) == min(xs) or max(ys) == min(ys): 17 | continue 18 | 19 | box = box_obj.bbox 20 | contained = False 21 | for other_box_obj in boxes: 22 | if other_box_obj.polygon == box_obj.polygon: 23 | continue 24 | 25 | other_box = other_box_obj.bbox 26 | if box == other_box: 27 | continue 28 | if ( 29 | box[0] >= other_box[0] 30 | and box[1] >= other_box[1] 31 | and box[2] <= other_box[2] 32 | and box[3] <= other_box[3] 33 | ): 34 | contained = True 35 | break 36 | if not contained: 37 | new_boxes.append(box_obj) 38 | return new_boxes 39 | 40 | 41 | def rescale_bbox(bbox, processor_size, image_size): 42 | page_width, page_height = processor_size 43 | 44 | img_width, img_height = image_size 45 | width_scaler = img_width / page_width 46 | height_scaler = img_height / page_height 47 | 48 | new_bbox = copy.deepcopy(bbox) 49 | new_bbox[0] = int(new_bbox[0] * width_scaler) 50 | new_bbox[1] = int(new_bbox[1] * height_scaler) 51 | new_bbox[2] = int(new_bbox[2] * width_scaler) 52 | new_bbox[3] = int(new_bbox[3] * height_scaler) 53 | return new_bbox 54 | 55 | 56 | def expand_bbox(bbox, expansion_factor=0.01): 57 | expansion_low = 1 - expansion_factor 58 | expansion_high = 1 + expansion_factor 59 | return [ 60 | bbox[0] * expansion_low, 61 | bbox[1] * expansion_low, 62 | bbox[2] * expansion_high, 63 | bbox[3] * expansion_high, 64 | ] 65 | 66 | SCRIPT_TOKEN_MAPPING = { 67 | "latin": "<SCRIPT-LATIN>", 68 | "punctuation": "<SCRIPT-PUNCTUATION>", 69 | "cyrillic": "<SCRIPT-CYRILLIC>", 70 | "arabic": "<SCRIPT-ARABIC>", 71 | "chinese": "<SCRIPT-CHINESE>", 72 | "japanese": "<SCRIPT-JAPANESE>", 73 | "korean": "<SCRIPT-KOREAN>", 74 | "symbols": "<SCRIPT-SYMBOLS>", 75 | "greek": "<SCRIPT-GREEK>", 76 | "armenian": "<SCRIPT-ARMENIAN>", 77 | 
"hebrew": "<SCRIPT-HEBREW>", 78 | "devanagari": "<SCRIPT-DEVANAGARI>", 79 | "bengali": "<SCRIPT-BENGALI>", 80 | "gurmukhi": "<SCRIPT-GURMUKHI>", 81 | "gujarati": "<SCRIPT-GUJARATI>", 82 | "oriya": "<SCRIPT-ORIYA>", 83 | "tamil": "<SCRIPT-TAMIL>", 84 | "telugu": "<SCRIPT-TELUGU>", 85 | "kannada": "<SCRIPT-KANNADA>", 86 | "malayalam": "<SCRIPT-MALAYALAM>", 87 | "sinhala": "<SCRIPT-SINHALA>", 88 | "thai": "<SCRIPT-THAI>", 89 | "lao": "<SCRIPT-LAO>", 90 | "myanmar": "<SCRIPT-MYANMAR>", 91 | "georgian": "<SCRIPT-GEORGIAN>", 92 | "ethiopic": "<SCRIPT-ETHIOPIC>", 93 | "khmer": "<SCRIPT-KHMER>", 94 | "mongolian": "<SCRIPT-MONGOLIAN>", 95 | "math": "<SCRIPT-MATH>", 96 | } 97 | 98 | @lru_cache(maxsize=1) 99 | def script_ranges(): 100 | script_categories = { 101 | # Latin-based scripts (used by English, French, German, etc.) 102 | "latin": [ 103 | (0x0041, 0x005A), # Latin uppercase A-Z 104 | (0x0061, 0x007A), # Latin lowercase a-z 105 | (0x0080, 0x00FF), # Latin-1 Supplement 106 | (0x0100, 0x017F), # Latin Extended-A 107 | (0x0180, 0x024F), # Latin Extended-B 108 | (0x0250, 0x02AF), # IPA Extensions 109 | (0x02B0, 0x02FF), # Spacing Modifier Letters 110 | (0x0300, 0x036F), # Combining Diacritical Marks 111 | (0x1E00, 0x1EFF), # Latin Extended Additional 112 | (0x2C60, 0x2C7F), # Latin Extended-C 113 | (0xA720, 0xA7FF), # Latin Extended-D 114 | ], 115 | # Punctuation, universal characters, and general symbols 116 | "punctuation": [ 117 | (0x0020, 0x0020), # Space 118 | (0x0021, 0x002F), # Basic punctuation and symbols 119 | (0x0030, 0x0039), # Digits 0-9 120 | (0x003A, 0x0040), # More punctuation and symbols 121 | (0x005B, 0x0060), # More punctuation and symbols 122 | (0x007B, 0x007F), # More punctuation and symbols 123 | (0x2000, 0x206F), # General Punctuation 124 | ], 125 | # Cyrillic scripts (used by Russian, Ukrainian, etc.) 
126 | "cyrillic": [ 127 | (0x0400, 0x04FF), # Cyrillic 128 | (0x0500, 0x052F), # Cyrillic Supplement 129 | ], 130 | # Arabic scripts 131 | "arabic": [ 132 | (0x0600, 0x06FF), # Arabic 133 | (0x0750, 0x077F), # Arabic Supplement 134 | (0x08A0, 0x08FF), # Arabic Extended-A 135 | ], 136 | # Chinese characters 137 | "chinese": [ 138 | (0x4E00, 0x9FFF), # Common CJK Unified Ideographs 139 | (0x3400, 0x4DBF), # CJK Extension A 140 | (0x20000, 0x2A6DF), # CJK Extension B 141 | ], 142 | # Japanese-specific scripts (excluding shared CJK) 143 | "japanese": [ 144 | (0x3040, 0x30FF), # Hiragana and Katakana 145 | ], 146 | # Korean-specific scripts 147 | "korean": [ 148 | (0x1100, 0x11FF), # Hangul Jamo 149 | (0x3130, 0x318F), # Hangul Compatibility Jamo 150 | (0xAC00, 0xD7AF), # Hangul Syllables 151 | ], 152 | # Various mathematical and technical symbols 153 | "symbols": [ 154 | (0x2070, 0x209F), # Superscripts and Subscripts 155 | (0x20A0, 0x20CF), # Currency Symbols 156 | (0x2100, 0x214F), # Letterlike Symbols 157 | (0x2150, 0x218F), # Number Forms 158 | (0x2190, 0x21FF), # Arrows 159 | (0x2200, 0x22FF), # Mathematical Operators 160 | (0x2300, 0x23FF), # Miscellaneous Technical 161 | (0x2500, 0x257F), # Box Drawing 162 | (0x2580, 0x259F), # Block Elements 163 | (0x25A0, 0x25FF), # Geometric Shapes 164 | (0x2600, 0x26FF), # Miscellaneous Symbols 165 | (0x2700, 0x27BF), # Dingbats 166 | (0x27C0, 0x27EF), # Miscellaneous Mathematical Symbols-A 167 | (0x2980, 0x29FF), # Miscellaneous Mathematical Symbols-B 168 | (0x2A00, 0x2AFF), # Supplemental Mathematical Operators 169 | (0x1D400, 0x1D7FF), # Mathematical Alphanumeric Symbols 170 | ], 171 | # Individual scripts for languages with unique writing systems 172 | "greek": [(0x0370, 0x03FF)], # Greek and Coptic 173 | "armenian": [(0x0530, 0x058F)], # Armenian 174 | "hebrew": [(0x0590, 0x05FF)], # Hebrew 175 | "devanagari": [(0x0900, 0x097F)], # Devanagari (Hindi, Sanskrit) 176 | "bengali": [(0x0980, 0x09FF)], # Bengali 177 | "gurmukhi": [(0x0A00, 0x0A7F)], # Gurmukhi (Punjabi) 178 | "gujarati": [(0x0A80, 0x0AFF)], # Gujarati 179 | "oriya": [(0x0B00, 0x0B7F)], # Oriya 180 | "tamil": [(0x0B80, 0x0BFF)], # Tamil 181 | "telugu": [(0x0C00, 0x0C7F)], # Telugu 182 | "kannada": [(0x0C80, 0x0CFF)], # Kannada 183 | "malayalam": [(0x0D00, 0x0D7F)], # Malayalam 184 | "sinhala": [(0x0D80, 0x0DFF)], # Sinhala 185 | "thai": [(0x0E00, 0x0E7F)], # Thai 186 | "lao": [(0x0E80, 0x0EFF)], # Lao 187 | "myanmar": [(0x1000, 0x109F)], # Myanmar 188 | "georgian": [(0x10A0, 0x10FF)], # Georgian 189 | "ethiopic": [(0x1200, 0x137F)], # Ethiopic 190 | "khmer": [(0x1780, 0x17FF)], # Khmer 191 | "mongolian": [(0x1800, 0x18AF)], # Mongolian 192 | } 193 | 194 | # Convert to a flat structure with character ranges 195 | flat_ranges = {} 196 | for category, ranges in script_categories.items(): 197 | # Create a set of all characters in this category 198 | char_set = set() 199 | for start, end in ranges: 200 | char_set.update(range(start, end + 1)) 201 | 202 | # Store the set in flat_ranges 203 | flat_ranges[category] = char_set 204 | 205 | return script_categories, flat_ranges 206 | 207 | def get_top_scripts(text: str, max_scripts: int = 5): 208 | script_categories, flat_ranges = script_ranges() 209 | char_count = {category: 0 for category in script_categories.keys()} 210 | for char in text: 211 | for category, char_set in flat_ranges.items(): 212 | if ord(char) in char_set: 213 | char_count[category] += 1 214 | break 215 | 216 | top_scripts = sorted(char_count.items(), key=lambda x: x[1], 
reverse=True) 217 | top_scripts = [ts[0] for ts in top_scripts if ts[1] > 0] 218 | if "<math" in text: 219 | top_scripts.insert(0, "math") 220 | 221 | return top_scripts[:max_scripts] 222 | 223 | def is_flash_attn_2_supported(device: str | torch.device) -> bool: 224 | if not torch.cuda.is_available(): 225 | return False 226 | 227 | if "cuda" not in str(device): 228 | return False 229 | 230 | # Check CUDA version >= 12.0 231 | cuda_version_str = torch.version.cuda 232 | if cuda_version_str is None: 233 | return False 234 | cuda_version = tuple(map(int, cuda_version_str.split("."))) 235 | if cuda_version < (12, 0): 236 | return False 237 | 238 | # Check GPU compute capability (Ampere, Ada, Hopper GPUs) 239 | major, minor = torch.cuda.get_device_capability() 240 | compute_capability = major + minor / 10 241 | if compute_capability < 8.0: 242 | return False 243 | 244 | return True 245 | 246 | 247 | def pad_to_batch_size_repeat(tensor: torch.Tensor, batch_size: int): 248 | current_batch_size = tensor.shape[0] 249 | if current_batch_size >= batch_size: 250 | return tensor 251 | 252 | pad_size = batch_size - current_batch_size 253 | if pad_size < 0: 254 | return tensor 255 | 256 | # Repeat the last row pad_size times 257 | last_row = tensor[-1:].repeat(pad_size, 1, 1) 258 | 259 | # Concatenate original tensor with repeated last rows 260 | return torch.cat([tensor, last_row], dim=0) 261 | 262 | 263 | def pad_to_batch_size(tensor: torch.Tensor, batch_size: int): 264 | current_batch_size = tensor.shape[0] 265 | if current_batch_size >= batch_size: 266 | return tensor 267 | 268 | pad_size = batch_size - current_batch_size 269 | padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) 270 | 271 | return F.pad(tensor, padding, mode="constant", value=0) 272 | ``` -------------------------------------------------------------------------------- /surya/scripts/streamlit_app.py: -------------------------------------------------------------------------------- ```python 1 | import io 2 | import tempfile 3 | from typing import List 4 | 5 | import pypdfium2 6 | import streamlit as st 7 | 8 | from surya.common.surya.schema import TaskNames 9 | from surya.models import load_predictors 10 | 11 | from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image 12 | 13 | from surya.debug.text import draw_text_on_image 14 | from PIL import Image, ImageDraw 15 | from surya.table_rec import TableResult 16 | from surya.detection import TextDetectionResult 17 | from surya.recognition import OCRResult 18 | from surya.layout import LayoutResult 19 | from surya.settings import settings 20 | from surya.common.util import rescale_bbox, expand_bbox 21 | 22 | 23 | @st.cache_resource() 24 | def load_predictors_cached(): 25 | return load_predictors() 26 | 27 | 28 | def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15): 29 | from pdftext.extraction import plain_text_output 30 | 31 | with tempfile.NamedTemporaryFile(suffix=".pdf") as f: 32 | f.write(pdf_file.getvalue()) 33 | f.seek(0) 34 | 35 | # Sample the text from the middle of the PDF 36 | page_middle = page_count // 2 37 | page_range = range( 38 | max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count) 39 | ) 40 | text = plain_text_output(f.name, page_range=page_range) 41 | 42 | sample_gap = len(text) // max_samples 43 | if len(text) == 0 or sample_gap == 0: 44 | return "This PDF has no text or very little text", ["no text"] 45 | 46 | if sample_gap < sample_len: 47 | sample_gap = sample_len 48 | 49 | # Split the text into 
samples for the model 50 | samples = [] 51 | for i in range(0, len(text), sample_gap): 52 | samples.append(text[i : i + sample_len]) 53 | 54 | results = predictors["ocr_error"](samples) 55 | label = "This PDF has good text." 56 | if results.labels.count("bad") / len(results.labels) > 0.2: 57 | label = "This PDF may have garbled or bad OCR text." 58 | return label, results.labels 59 | 60 | 61 | def text_detection(img) -> (Image.Image, TextDetectionResult): 62 | text_pred = predictors["detection"]([img])[0] 63 | text_polygons = [p.polygon for p in text_pred.bboxes] 64 | det_img = draw_polys_on_image(text_polygons, img.copy()) 65 | return det_img, text_pred 66 | 67 | 68 | def layout_detection(img) -> (Image.Image, LayoutResult): 69 | pred = predictors["layout"]([img])[0] 70 | polygons = [p.polygon for p in pred.bboxes] 71 | labels = [ 72 | f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}" for p in pred.bboxes 73 | ] 74 | layout_img = draw_polys_on_image( 75 | polygons, img.copy(), labels=labels, label_font_size=18 76 | ) 77 | return layout_img, pred 78 | 79 | 80 | def table_recognition( 81 | img, highres_img, skip_table_detection: bool 82 | ) -> (Image.Image, List[TableResult]): 83 | if skip_table_detection: 84 | layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])] 85 | table_imgs = [highres_img] 86 | else: 87 | _, layout_pred = layout_detection(img) 88 | layout_tables_lowres = [ 89 | line.bbox 90 | for line in layout_pred.bboxes 91 | if line.label in ["Table", "TableOfContents"] 92 | ] 93 | table_imgs = [] 94 | layout_tables = [] 95 | for tb in layout_tables_lowres: 96 | highres_bbox = rescale_bbox(tb, img.size, highres_img.size) 97 | # Slightly expand the box 98 | highres_bbox = expand_bbox(highres_bbox) 99 | table_imgs.append(highres_img.crop(highres_bbox)) 100 | layout_tables.append(highres_bbox) 101 | 102 | table_preds = predictors["table_rec"](table_imgs) 103 | table_img = img.copy() 104 | 105 | for results, table_bbox in zip(table_preds, layout_tables): 106 | adjusted_bboxes = [] 107 | labels = [] 108 | colors = [] 109 | 110 | for item in results.cells: 111 | adjusted_bboxes.append( 112 | [ 113 | (item.bbox[0] + table_bbox[0]), 114 | (item.bbox[1] + table_bbox[1]), 115 | (item.bbox[2] + table_bbox[0]), 116 | (item.bbox[3] + table_bbox[1]), 117 | ] 118 | ) 119 | labels.append(item.label) 120 | if "Row" in item.label: 121 | colors.append("blue") 122 | else: 123 | colors.append("red") 124 | table_img = draw_bboxes_on_image( 125 | adjusted_bboxes, 126 | highres_img, 127 | labels=labels, 128 | label_font_size=18, 129 | color=colors, 130 | ) 131 | return table_img, table_preds 132 | 133 | 134 | # Function for OCR 135 | def ocr( 136 | img: Image.Image, 137 | highres_img: Image.Image, 138 | skip_text_detection: bool = False, 139 | recognize_math: bool = True, 140 | with_bboxes: bool = True, 141 | ) -> (Image.Image, OCRResult): 142 | if skip_text_detection: 143 | img = highres_img 144 | bboxes = [[[0, 0, img.width, img.height]]] 145 | else: 146 | bboxes = None 147 | 148 | if with_bboxes: 149 | tasks = [TaskNames.ocr_with_boxes] 150 | else: 151 | tasks = [TaskNames.ocr_without_boxes] 152 | 153 | img_pred = predictors["recognition"]( 154 | [img], 155 | task_names=tasks, 156 | bboxes=bboxes, 157 | det_predictor=predictors["detection"], 158 | highres_images=[highres_img], 159 | math_mode=recognize_math, 160 | return_words=True, 161 | )[0] 162 | 163 | bboxes = [line.bbox for line in img_pred.text_lines] 164 | text = [line.text for line in img_pred.text_lines] 165 | rec_img = 
draw_text_on_image(bboxes, text, img.size) 166 | 167 | word_boxes = [] 168 | for line in img_pred.text_lines: 169 | if line.words: 170 | word_boxes.extend([word.bbox for word in line.words]) 171 | 172 | box_img = img.copy() 173 | draw = ImageDraw.Draw(box_img) 174 | for word_box in word_boxes: 175 | draw.rectangle(word_box, outline="red", width=2) 176 | 177 | return rec_img, img_pred, box_img 178 | 179 | 180 | def open_pdf(pdf_file): 181 | stream = io.BytesIO(pdf_file.getvalue()) 182 | return pypdfium2.PdfDocument(stream) 183 | 184 | 185 | @st.cache_data() 186 | def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI): 187 | doc = open_pdf(pdf_file) 188 | renderer = doc.render( 189 | pypdfium2.PdfBitmap.to_pil, 190 | page_indices=[page_num - 1], 191 | scale=dpi / 72, 192 | ) 193 | png = list(renderer)[0] 194 | png_image = png.convert("RGB") 195 | doc.close() 196 | return png_image 197 | 198 | 199 | @st.cache_data() 200 | def page_counter(pdf_file): 201 | doc = open_pdf(pdf_file) 202 | doc_len = len(doc) 203 | doc.close() 204 | return doc_len 205 | 206 | 207 | st.set_page_config(layout="wide") 208 | col1, col2 = st.columns([0.5, 0.5]) 209 | 210 | predictors = load_predictors_cached() 211 | 212 | st.markdown(""" 213 | # Surya OCR Demo 214 | 215 | This app will let you try surya, a multilingual OCR toolkit. 216 | 217 | Notes: 218 | 219 | - This works best on documents with printed text. 220 | - For OCR, the formatting (math, italics, etc) will not show up in the image preview, but it will show up in the returned text lines. 221 | - If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease). 222 | 223 | Find the project [here](https://github.com/VikParuchuri/surya). 224 | """) 225 | 226 | in_file = st.sidebar.file_uploader( 227 | "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"] 228 | ) 229 | 230 | if in_file is None: 231 | st.stop() 232 | 233 | filetype = in_file.type 234 | page_count = None 235 | if "pdf" in filetype: 236 | page_count = page_counter(in_file) 237 | page_number = st.sidebar.number_input( 238 | f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count 239 | ) 240 | 241 | pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI) 242 | pil_image_highres = get_page_image( 243 | in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES 244 | ) 245 | else: 246 | pil_image = Image.open(in_file).convert("RGB") 247 | pil_image_highres = pil_image 248 | page_number = None 249 | 250 | run_text_det = st.sidebar.button("Run Text Detection") 251 | run_text_rec = st.sidebar.button("Run OCR") 252 | run_layout_det = st.sidebar.button("Run Layout Analysis") 253 | run_table_rec = st.sidebar.button("Run Table Rec") 254 | run_ocr_errors = st.sidebar.button("Run bad PDF text detection") 255 | use_pdf_boxes = st.sidebar.checkbox( 256 | "PDF table boxes", 257 | value=True, 258 | help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.", 259 | ) 260 | skip_table_detection = st.sidebar.checkbox( 261 | "Skip table detection", 262 | value=False, 263 | help="Table recognition only: Skip table detection and treat the whole image/page as a table.", 264 | ) 265 | skip_text_detection = st.sidebar.checkbox( 266 | "Skip text detection", 267 | value=False, 268 | help="OCR only: Skip text detection and treat the whole image as a single line.", 269 | ) 270 | recognize_math = st.sidebar.checkbox( 271 | "Recognize math in OCR", 272 | value=True, 273 | help="Enable 
math mode in OCR - this will recognize math.", 274 | ) 275 | ocr_with_boxes = st.sidebar.checkbox( 276 | "OCR with boxes", 277 | value=True, 278 | help="Enable OCR with boxes - this will predict character-level boxes.", 279 | ) 280 | 281 | if pil_image is None: 282 | st.stop() 283 | 284 | # Run Text Detection 285 | if run_text_det: 286 | det_img, text_pred = text_detection(pil_image) 287 | with col1: 288 | st.image(det_img, caption="Detected Text", use_container_width=True) 289 | st.json( 290 | text_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True 291 | ) 292 | 293 | 294 | # Run layout 295 | if run_layout_det: 296 | layout_img, pred = layout_detection(pil_image) 297 | with col1: 298 | st.image(layout_img, caption="Detected Layout", use_container_width=True) 299 | st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True) 300 | 301 | # Run OCR 302 | if run_text_rec: 303 | rec_img, pred, box_img = ocr( 304 | pil_image, 305 | pil_image_highres, 306 | skip_text_detection, 307 | recognize_math, 308 | with_bboxes=ocr_with_boxes, 309 | ) 310 | with col1: 311 | st.image(rec_img, caption="OCR Result", use_container_width=True) 312 | json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"]) 313 | with json_tab: 314 | st.json(pred.model_dump(), expanded=False) 315 | with text_tab: 316 | st.text("\n".join([p.text for p in pred.text_lines])) 317 | 318 | st.image( 319 | box_img, 320 | caption="OCR with Word Boxes (for debugging)", 321 | use_container_width=True, 322 | ) 323 | 324 | 325 | if run_table_rec: 326 | table_img, pred = table_recognition( 327 | pil_image, pil_image_highres, skip_table_detection 328 | ) 329 | with col1: 330 | st.image(table_img, caption="Table Recognition", use_container_width=True) 331 | st.json([p.model_dump() for p in pred], expanded=True) 332 | 333 | if run_ocr_errors: 334 | if "pdf" not in filetype: 335 | st.error("This feature only works with PDFs.") 336 | label, results = ocr_errors(in_file, page_count) 337 | with col1: 338 | st.write(label) 339 | st.json(results) 340 | 341 | with col2: 342 | st.image(pil_image, caption="Uploaded Image", use_container_width=True) 343 | ``` -------------------------------------------------------------------------------- /benchmark/recognition.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | import unicodedata 3 | from collections import defaultdict 4 | 5 | import click 6 | 7 | from benchmark.utils.scoring import overlap_score, overlap_score_exact 8 | from surya.input.processing import convert_if_not_rgb 9 | from surya.debug.text import draw_text_on_image 10 | from surya.foundation import FoundationPredictor 11 | from surya.recognition import RecognitionPredictor 12 | from surya.settings import settings 13 | from surya.recognition.languages import CODE_TO_LANGUAGE 14 | from benchmark.utils.tesseract import ( 15 | tesseract_ocr_parallel, 16 | surya_lang_to_tesseract, 17 | TESS_CODE_TO_LANGUAGE, 18 | ) 19 | from benchmark.utils.textract import textract_ocr_parallel 20 | import os 21 | import datasets 22 | import json 23 | import time 24 | from tabulate import tabulate 25 | 26 | KEY_LANGUAGES = [ 27 | "Chinese", 28 | "Spanish", 29 | "English", 30 | "Arabic", 31 | "Hindi", 32 | "Bengali", 33 | "Russian", 34 | "Japanese", 35 | ] 36 | 37 | 38 | def list_in(lst: str | list, lst2: list): 39 | if isinstance(lst, str): 40 | lst = [lst] 41 | return any([item in lst for item in lst2]) 42 | 43 | 44 | def standardize_bullets(text): 45 | patterns = [ 
46 | r"•\s+", 47 | r"·\s+", 48 | r"○\s+", 49 | r"◦\s+", 50 | r"▪\s+", 51 | r"▫\s+", 52 | r"➢\s+", 53 | r"➤\s+", 54 | r"★\s+", 55 | r"✓\s+", 56 | r"✗\s+", 57 | r"✦\s+", 58 | r"\\bullet\s+", 59 | ] 60 | 61 | combined_pattern = "|".join(patterns) 62 | text = re.sub(combined_pattern, "*", text) 63 | 64 | return text 65 | 66 | 67 | def normalize_text(text: str) -> str: 68 | # Remove HTML tags 69 | text = re.sub(r"<[^>]+>", "", text) 70 | # Remove LaTeX tags 71 | text = re.sub(r"\\[a-zA-Z]+", "", text) 72 | text = standardize_bullets(text) 73 | text = unicodedata.normalize("NFKC", text) 74 | return text.strip().lower().replace(",", ".") 75 | 76 | 77 | @click.command(help="Benchmark recognition model.") 78 | @click.option( 79 | "--results_dir", 80 | type=str, 81 | help="Path to JSON file with OCR results.", 82 | default=os.path.join(settings.RESULT_DIR, "benchmark"), 83 | ) 84 | @click.option( 85 | "--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None 86 | ) 87 | @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False) 88 | @click.option( 89 | "--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False 90 | ) 91 | @click.option( 92 | "--textract", is_flag=True, help="Run benchmarks on textract.", default=False 93 | ) 94 | @click.option( 95 | "--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28 96 | ) 97 | @click.option( 98 | "--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28 99 | ) 100 | @click.option( 101 | "--languages", 102 | type=str, 103 | help="Comma-separated list of languages to benchmark.", 104 | default=None, 105 | ) 106 | @click.option( 107 | "--print_results", 108 | is_flag=True, 109 | ) 110 | def main( 111 | results_dir: str, 112 | max_rows: int, 113 | debug: bool, 114 | tesseract: bool, 115 | textract: bool, 116 | tess_cpus: int, 117 | textract_cpus: int, 118 | languages: str | None, 119 | print_results: bool, 120 | ): 121 | foundation_predictor = FoundationPredictor() 122 | rec_predictor = RecognitionPredictor(foundation_predictor) 123 | 124 | split = "train" 125 | dataset = datasets.load_dataset( 126 | settings.RECOGNITION_BENCH_DATASET_NAME, split=split 127 | ) 128 | 129 | if languages: 130 | languages = languages.split(",") 131 | dataset = dataset.filter( 132 | lambda x: list_in(x["language"], languages), num_proc=4 133 | ) 134 | 135 | if max_rows and max_rows < len(dataset): 136 | dataset = dataset.shuffle(seed=1).select(range(max_rows)) 137 | 138 | images = list(dataset["image"]) 139 | images = convert_if_not_rgb(images) 140 | bboxes = list(dataset["bboxes"]) 141 | line_text = list(dataset["text"]) 142 | languages = list(dataset["language"]) 143 | 144 | print(f"Loaded {len(images)} images. 
Running OCR...") 145 | 146 | start = time.time() 147 | predictions_by_image = rec_predictor(images, None, bboxes=bboxes) 148 | surya_time = time.time() - start 149 | 150 | lang_list = [] 151 | for lang in languages: 152 | if not isinstance(lang, list): 153 | lang_list.append([lang]) 154 | else: 155 | lang_list.append(lang) 156 | 157 | surya_scores = defaultdict(list) 158 | img_surya_scores = [] 159 | outputs = [] 160 | for idx, (pred, ref_text, langs) in enumerate( 161 | zip(predictions_by_image, line_text, lang_list) 162 | ): 163 | pred_text = [line.text for line in pred.text_lines] 164 | 165 | score_ref_text = [normalize_text(line) for line in ref_text] 166 | score_pred_text = [normalize_text(text) for text in pred_text] 167 | image_scores, image_weights = overlap_score_exact( 168 | score_pred_text, score_ref_text 169 | ) 170 | normalized_scores = [ 171 | score / max(1, weight) for score, weight in zip(image_scores, image_weights) 172 | ] 173 | image_score = sum(image_scores) / max(1, sum(image_weights)) 174 | 175 | img_surya_scores.append(image_score) 176 | for lang in langs: 177 | surya_scores[CODE_TO_LANGUAGE[lang]].append(image_score) 178 | 179 | assert len(pred_text) == len(ref_text) == len(bboxes[idx]) 180 | if debug: 181 | for j, (pred_line, ref_line, score, bbox) in enumerate( 182 | zip(pred_text, ref_text, normalized_scores, bboxes[idx]) 183 | ): 184 | image_slice = images[idx].crop(bbox) 185 | 186 | outputs.append( 187 | { 188 | "image": image_slice, 189 | "bbox": bbox, 190 | "score": score, 191 | "pred": pred_line, 192 | "ref": ref_line, 193 | "langs": ",".join(langs), 194 | } 195 | ) 196 | 197 | if debug: 198 | out_ds = datasets.Dataset.from_list(outputs) 199 | out_ds.push_to_hub("datalab-to/rec_bench_outputs", private=True) 200 | 201 | flat_surya_scores = [score for lang in surya_scores for score in surya_scores[lang]] 202 | benchmark_stats = { 203 | "surya": { 204 | "avg_score": sum(flat_surya_scores) / max(1, len(flat_surya_scores)), 205 | "lang_scores": { 206 | lang: sum(scores) / max(1, len(scores)) 207 | for lang, scores in surya_scores.items() 208 | }, 209 | "time_per_img": surya_time / max(1, len(images)), 210 | } 211 | } 212 | 213 | result_path = os.path.join(results_dir, "rec_bench") 214 | os.makedirs(result_path, exist_ok=True) 215 | 216 | with open(os.path.join(result_path, "surya_scores.json"), "w+") as f: 217 | json.dump(surya_scores, f) 218 | 219 | if tesseract: 220 | tess_valid = [] 221 | tess_langs = [] 222 | for idx, lang in enumerate(lang_list): 223 | # Tesseract does not support all languages 224 | tess_lang = surya_lang_to_tesseract(lang[0]) 225 | if tess_lang is None: 226 | continue 227 | 228 | tess_valid.append(idx) 229 | tess_langs.append(tess_lang) 230 | 231 | tess_imgs = [images[i] for i in tess_valid] 232 | tess_bboxes = [bboxes[i] for i in tess_valid] 233 | tess_reference = [line_text[i] for i in tess_valid] 234 | start = time.time() 235 | tess_predictions = tesseract_ocr_parallel( 236 | tess_imgs, tess_bboxes, tess_langs, cpus=tess_cpus 237 | ) 238 | tesseract_time = time.time() - start 239 | 240 | tess_scores = defaultdict(list) 241 | for idx, (pred, ref_text, lang) in enumerate( 242 | zip(tess_predictions, tess_reference, tess_langs) 243 | ): 244 | image_scores, image_weights, _ = overlap_score(pred, ref_text) 245 | image_score = sum(image_scores) / max(1, sum(image_weights)) 246 | tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score) 247 | 248 | flat_tess_scores = [ 249 | score for lang in tess_scores for score in tess_scores[lang] 250 | 
] 251 | benchmark_stats["tesseract"] = { 252 | "avg_score": sum(flat_tess_scores) / len(flat_tess_scores), 253 | "lang_scores": { 254 | lang: sum(scores) / len(scores) for lang, scores in tess_scores.items() 255 | }, 256 | "time_per_img": tesseract_time / len(tess_imgs), 257 | } 258 | 259 | with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f: 260 | json.dump(tess_scores, f) 261 | 262 | if textract: 263 | start = time.time() 264 | textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus) 265 | textract_time = time.time() - start 266 | 267 | textract_scores = defaultdict(list) 268 | for idx, (pred, ref_text, lang) in enumerate( 269 | zip(textract_predictions, line_text, lang_list) 270 | ): 271 | image_scores, image_weights, _ = overlap_score(pred, ref_text) 272 | image_score = sum(image_scores) / max(1, sum(image_weights)) 273 | 274 | for lang in lang: 275 | textract_scores[CODE_TO_LANGUAGE[lang]].append(image_score) 276 | 277 | flat_textract_scores = [ 278 | score for lang in textract_scores for score in textract_scores[lang] 279 | ] 280 | benchmark_stats["textract"] = { 281 | "avg_score": sum(flat_textract_scores) / len(flat_textract_scores), 282 | "lang_scores": { 283 | lang: sum(scores) / len(scores) 284 | for lang, scores in textract_scores.items() 285 | }, 286 | "time_per_img": textract_time / len(images), 287 | } 288 | print(len(flat_textract_scores)) 289 | 290 | with open(os.path.join(result_path, "textract_scores.json"), "w+") as f: 291 | json.dump(textract_scores, f) 292 | 293 | with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: 294 | json.dump(benchmark_stats, f) 295 | 296 | key_languages = [k for k in KEY_LANGUAGES if k in surya_scores] 297 | table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages 298 | table_data = [ 299 | [ 300 | "surya", 301 | benchmark_stats["surya"]["time_per_img"], 302 | benchmark_stats["surya"]["avg_score"], 303 | ] 304 | + [benchmark_stats["surya"]["lang_scores"][lang] for lang in key_languages], 305 | ] 306 | if tesseract: 307 | table_data.append( 308 | [ 309 | "tesseract", 310 | benchmark_stats["tesseract"]["time_per_img"], 311 | benchmark_stats["tesseract"]["avg_score"], 312 | ] 313 | + [ 314 | benchmark_stats["tesseract"]["lang_scores"].get(lang, 0) 315 | for lang in key_languages 316 | ] 317 | ) 318 | if textract: 319 | table_data.append( 320 | [ 321 | "textract", 322 | benchmark_stats["textract"]["time_per_img"], 323 | benchmark_stats["textract"]["avg_score"], 324 | ] 325 | + [ 326 | benchmark_stats["textract"]["lang_scores"][lang] 327 | for lang in key_languages 328 | ], 329 | ) 330 | 331 | print(tabulate(table_data, headers=table_headers, tablefmt="github")) 332 | print( 333 | "Only a few major languages are displayed. See the result path for additional languages." 334 | ) 335 | 336 | if debug >= 1: 337 | bad_detections = [] 338 | for idx, (score, lang) in enumerate(zip(flat_surya_scores, lang_list)): 339 | if score < 0.8: 340 | bad_detections.append((idx, lang, score)) 341 | print(f"Found {len(bad_detections)} bad detections. 
Writing to file...") 342 | with open(os.path.join(result_path, "bad_detections.json"), "w+") as f: 343 | json.dump(bad_detections, f) 344 | 345 | if debug == 2: 346 | for idx, (image, pred, ref_text, bbox, lang) in enumerate( 347 | zip(images, predictions_by_image, line_text, bboxes, lang_list) 348 | ): 349 | pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png" 350 | ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png" 351 | pred_text = [line.text for line in pred.text_lines] 352 | pred_image = draw_text_on_image(bbox, pred_text, image.size) 353 | pred_image.save(os.path.join(result_path, pred_image_name)) 354 | ref_image = draw_text_on_image(bbox, ref_text, image.size) 355 | ref_image.save(os.path.join(result_path, ref_image_name)) 356 | image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png")) 357 | 358 | print(f"Wrote results to {result_path}") 359 | 360 | if print_results: 361 | for idx, (pred, ref_text) in enumerate(zip(predictions_by_image, line_text)): 362 | print(f"Image {idx}") 363 | print("----") 364 | for line_idx, (pred_line, ref_line) in enumerate( 365 | zip(pred.text_lines, ref_text) 366 | ): 367 | print(f"Sample {line_idx}") 368 | print(f"Pred: {pred_line.text}") 369 | print(f"Ref: {ref_line}") 370 | print() 371 | 372 | if settings.TORCH_DEVICE == "xla": 373 | import torch_xla.debug.metrics as met 374 | 375 | print(met.short_metrics_report()) 376 | 377 | 378 | if __name__ == "__main__": 379 | main() 380 | ``` -------------------------------------------------------------------------------- /surya/detection/processor.py: -------------------------------------------------------------------------------- ```python 1 | import warnings 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | import numpy as np 5 | 6 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict 7 | from transformers.image_transforms import to_channel_dimension_format 8 | from transformers.image_utils import ( 9 | IMAGENET_DEFAULT_MEAN, 10 | IMAGENET_DEFAULT_STD, 11 | ChannelDimension, 12 | ImageInput, 13 | PILImageResampling, 14 | infer_channel_dimension_format, 15 | make_list_of_images, 16 | ) 17 | from transformers.utils import TensorType 18 | 19 | 20 | import PIL.Image 21 | import torch 22 | 23 | from surya.common.s3 import S3DownloaderMixin 24 | 25 | 26 | class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor): 27 | r""" 28 | Constructs a Segformer image processor. 29 | 30 | Args: 31 | do_resize (`bool`, *optional*, defaults to `True`): 32 | Whether to resize the image's (height, width) dimensions to the specified `(size["height"], 33 | size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. 34 | size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`): 35 | Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` 36 | method. 37 | resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): 38 | Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the 39 | `preprocess` method. 40 | do_rescale (`bool`, *optional*, defaults to `True`): 41 | Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` 42 | parameter in the `preprocess` method. 43 | rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): 44 | Whether to normalize the image. 
Can be overridden by the `do_normalize` parameter in the `preprocess` 45 | method. 46 | do_normalize (`bool`, *optional*, defaults to `True`): 47 | Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` 48 | method. 49 | image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): 50 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of 51 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 52 | image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): 53 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the 54 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 55 | do_reduce_labels (`bool`, *optional*, defaults to `False`): 56 | Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is 57 | used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The 58 | background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the 59 | `preprocess` method. 60 | """ 61 | 62 | model_input_names = ["pixel_values"] 63 | 64 | def __init__( 65 | self, 66 | do_resize: bool = True, 67 | size: Dict[str, int] = None, 68 | resample: PILImageResampling = PILImageResampling.BILINEAR, 69 | do_rescale: bool = True, 70 | rescale_factor: Union[int, float] = 1 / 255, 71 | do_normalize: bool = True, 72 | image_mean: Optional[Union[float, List[float]]] = None, 73 | image_std: Optional[Union[float, List[float]]] = None, 74 | do_reduce_labels: bool = False, 75 | **kwargs, 76 | ) -> None: 77 | if "reduce_labels" in kwargs: 78 | warnings.warn( 79 | "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use " 80 | "`do_reduce_labels` instead.", 81 | FutureWarning, 82 | ) 83 | do_reduce_labels = kwargs.pop("reduce_labels") 84 | 85 | super().__init__(**kwargs) 86 | size = size if size is not None else {"height": 512, "width": 512} 87 | size = get_size_dict(size) 88 | self.do_resize = do_resize 89 | self.size = size 90 | self.resample = resample 91 | self.do_rescale = do_rescale 92 | self.rescale_factor = rescale_factor 93 | self.do_normalize = do_normalize 94 | self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN 95 | self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD 96 | self.do_reduce_labels = do_reduce_labels 97 | self._valid_processor_keys = [ 98 | "images", 99 | "segmentation_maps", 100 | "do_resize", 101 | "size", 102 | "resample", 103 | "do_rescale", 104 | "rescale_factor", 105 | "do_normalize", 106 | "image_mean", 107 | "image_std", 108 | "do_reduce_labels", 109 | "return_tensors", 110 | "data_format", 111 | "input_data_format", 112 | ] 113 | 114 | @classmethod 115 | def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): 116 | """ 117 | Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image 118 | processor is created using from_dict and kwargs e.g. 
`SegformerImageProcessor.from_pretrained(checkpoint, 119 | reduce_labels=True)` 120 | """ 121 | image_processor_dict = image_processor_dict.copy() 122 | if "reduce_labels" in kwargs: 123 | image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels") 124 | return super().from_dict(image_processor_dict, **kwargs) 125 | 126 | def _preprocess( 127 | self, 128 | image: ImageInput, 129 | do_resize: bool, 130 | do_rescale: bool, 131 | do_normalize: bool, 132 | size: Optional[Dict[str, int]] = None, 133 | resample: PILImageResampling = None, 134 | rescale_factor: Optional[float] = None, 135 | image_mean: Optional[Union[float, List[float]]] = None, 136 | image_std: Optional[Union[float, List[float]]] = None, 137 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 138 | ): 139 | 140 | if do_rescale: 141 | image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) 142 | 143 | if do_normalize: 144 | image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) 145 | 146 | return image 147 | 148 | def _preprocess_image( 149 | self, 150 | image: ImageInput, 151 | do_resize: bool = None, 152 | size: Dict[str, int] = None, 153 | resample: PILImageResampling = None, 154 | do_rescale: bool = None, 155 | rescale_factor: float = None, 156 | do_normalize: bool = None, 157 | image_mean: Optional[Union[float, List[float]]] = None, 158 | image_std: Optional[Union[float, List[float]]] = None, 159 | data_format: Optional[Union[str, ChannelDimension]] = None, 160 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 161 | ) -> np.ndarray: 162 | """Preprocesses a single image.""" 163 | # All transformations expect numpy arrays. 164 | if input_data_format is None: 165 | input_data_format = infer_channel_dimension_format(image) 166 | 167 | image = self._preprocess( 168 | image=image, 169 | do_resize=do_resize, 170 | size=size, 171 | resample=resample, 172 | do_rescale=do_rescale, 173 | rescale_factor=rescale_factor, 174 | do_normalize=do_normalize, 175 | image_mean=image_mean, 176 | image_std=image_std, 177 | input_data_format=input_data_format, 178 | ) 179 | if data_format is not None: 180 | image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) 181 | return image 182 | 183 | def __call__(self, images, segmentation_maps=None, **kwargs): 184 | """ 185 | Preprocesses a batch of images and optionally segmentation maps. 186 | 187 | Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be 188 | passed in as positional arguments. 
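        For example, `processor(image, segmentation_map, return_tensors="pt")` is equivalent to passing `segmentation_maps=segmentation_map` as a keyword argument (the variable names here are illustrative).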
189 | """ 190 | return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) 191 | 192 | def preprocess( 193 | self, 194 | images: ImageInput, 195 | segmentation_maps: Optional[ImageInput] = None, 196 | do_resize: Optional[bool] = None, 197 | size: Optional[Dict[str, int]] = None, 198 | resample: PILImageResampling = None, 199 | do_rescale: Optional[bool] = None, 200 | rescale_factor: Optional[float] = None, 201 | do_normalize: Optional[bool] = None, 202 | image_mean: Optional[Union[float, List[float]]] = None, 203 | image_std: Optional[Union[float, List[float]]] = None, 204 | do_reduce_labels: Optional[bool] = None, 205 | return_tensors: Optional[Union[str, TensorType]] = None, 206 | data_format: ChannelDimension = ChannelDimension.FIRST, 207 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 208 | **kwargs, 209 | ) -> PIL.Image.Image: 210 | """ 211 | Preprocess an image or batch of images. 212 | 213 | Args: 214 | images (`ImageInput`): 215 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 216 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 217 | segmentation_maps (`ImageInput`, *optional*): 218 | Segmentation map to preprocess. 219 | do_resize (`bool`, *optional*, defaults to `self.do_resize`): 220 | Whether to resize the image. 221 | size (`Dict[str, int]`, *optional*, defaults to `self.size`): 222 | Size of the image after `resize` is applied. 223 | resample (`int`, *optional*, defaults to `self.resample`): 224 | Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only 225 | has an effect if `do_resize` is set to `True`. 226 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 227 | Whether to rescale the image values between [0 - 1]. 228 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 229 | Rescale factor to rescale the image by if `do_rescale` is set to `True`. 230 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 231 | Whether to normalize the image. 232 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 233 | Image mean. 234 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 235 | Image standard deviation. 236 | do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): 237 | Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 238 | is used for background, and background itself is not included in all classes of a dataset (e.g. 239 | ADE20k). The background label will be replaced by 255. 240 | return_tensors (`str` or `TensorType`, *optional*): 241 | The type of tensors to return. Can be one of: 242 | - Unset: Return a list of `np.ndarray`. 243 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 244 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 245 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 246 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 247 | data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): 248 | The channel dimension format for the output image. Can be one of: 249 | - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 250 | - `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
251 | input_data_format (`ChannelDimension` or `str`, *optional*): 252 | The channel dimension format for the input image. If unset, the channel dimension format is inferred 253 | from the input image. Can be one of: 254 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 255 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 256 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 257 | """ 258 | do_resize = do_resize if do_resize is not None else self.do_resize 259 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 260 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 261 | resample = resample if resample is not None else self.resample 262 | size = size if size is not None else self.size 263 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 264 | image_mean = image_mean if image_mean is not None else self.image_mean 265 | image_std = image_std if image_std is not None else self.image_std 266 | 267 | images = make_list_of_images(images) 268 | images = [ 269 | self._preprocess_image( 270 | image=img, 271 | do_resize=do_resize, 272 | resample=resample, 273 | size=size, 274 | do_rescale=do_rescale, 275 | rescale_factor=rescale_factor, 276 | do_normalize=do_normalize, 277 | image_mean=image_mean, 278 | image_std=image_std, 279 | data_format=data_format, 280 | input_data_format=input_data_format, 281 | ) 282 | for img in images 283 | ] 284 | 285 | data = {"pixel_values": images} 286 | return BatchFeature(data=data, tensor_type=return_tensors) ``` -------------------------------------------------------------------------------- /surya/foundation/cache/dynamic_ops.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Any, Dict, List, Optional, Tuple 2 | import torch 3 | from transformers import PretrainedConfig 4 | 5 | """ 6 | Special cache class for the surya foundation model that supports - 7 | 1) Static shape 8 | 2) A custom sliding window, where image tokens stay in cache, and text tokens are popped 9 | 3) Continuous batching - merging etc 10 | 4) Attention mask management - To match with what's currently in the cache 11 | 12 | Heavily inspired from https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079 13 | """ 14 | 15 | 16 | class DynamicOpsCache: 17 | def __init__( 18 | self, 19 | config: PretrainedConfig, 20 | batch_size: int, 21 | max_cache_len: int, 22 | text_sliding_window: int, 23 | device: int, 24 | dtype: int, 25 | ): 26 | self.text_sliding_window = text_sliding_window 27 | self.num_layers = config.num_hidden_layers 28 | self.max_batch_size = batch_size 29 | self.max_cache_len = max_cache_len 30 | self.head_dim = ( 31 | getattr(config, "head_dim", None) 32 | or config.hidden_size // config.num_attention_heads 33 | ) 34 | self._dtype = dtype 35 | self.num_key_value_heads = ( 36 | config.num_attention_heads 37 | if getattr(config, "num_key_value_heads", None) is None 38 | else config.num_key_value_heads 39 | ) 40 | 41 | # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125 42 | self.key_cache: list[torch.Tensor] = [] 43 | self.value_cache: list[torch.Tensor] = [] 44 | cache_shape = ( 45 | self.max_batch_size, 46 | 
self.num_key_value_heads, 47 | self.max_cache_len, 48 | self.head_dim, 49 | ) 50 | device = torch.device(device) if device is not None else None 51 | for _ in range(config.num_hidden_layers): 52 | new_layer_key_cache = torch.zeros( 53 | cache_shape, dtype=self._dtype, device=device 54 | ) 55 | new_layer_value_cache = torch.zeros( 56 | cache_shape, dtype=self._dtype, device=device 57 | ) 58 | torch._dynamo.mark_static_address(new_layer_key_cache) 59 | torch._dynamo.mark_static_address(new_layer_value_cache) 60 | self.key_cache.append(new_layer_key_cache) 61 | self.value_cache.append(new_layer_value_cache) 62 | 63 | self.attention_mask = torch.zeros( 64 | (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long 65 | ) 66 | self.text_token_counts = [ 67 | torch.zeros(self.max_batch_size, dtype=torch.long, device=device) 68 | for _ in range(self.num_layers) 69 | ] 70 | 71 | self.dtype = dtype 72 | self.device = device 73 | 74 | def update( 75 | self, 76 | key_states: torch.Tensor, 77 | value_states: torch.Tensor, 78 | layer_idx: int, 79 | cache_kwargs: Optional[Dict[str, Any]] = None, 80 | ) -> Tuple[torch.Tensor, torch.Tensor]: 81 | prefill = cache_kwargs.get("prefill", False) 82 | update_fn = self._prefill_update if prefill else self._decode_update 83 | return update_fn( 84 | self.key_cache[layer_idx], 85 | self.value_cache[layer_idx], 86 | key_states, 87 | value_states, 88 | self.text_token_counts[layer_idx], 89 | cache_kwargs, 90 | ) 91 | 92 | def update_text_counts( 93 | self, 94 | merge_idxs: torch.Tensor, 95 | valid_batch_size: torch.Tensor, 96 | new_text_lens: torch.Tensor, 97 | ): 98 | new_text_len_tensor = new_text_lens.to(device=self.device) 99 | 100 | for layer_idx in range(self.num_layers): 101 | self.text_token_counts[layer_idx][merge_idxs] = new_text_len_tensor[ 102 | :valid_batch_size 103 | ] 104 | 105 | # Mirrors the logic from _prefill_update 106 | # The logic is explained in more detail in that function 107 | def prefill_attention_mask_update( 108 | self, 109 | prefill_attention_mask: torch.Tensor, 110 | merge_idxs: torch.Tensor, 111 | valid_batch_mask: torch.Tensor, 112 | text_lengths: List[int], 113 | ): 114 | seq_len = prefill_attention_mask.shape[1] 115 | sliding_window = self.text_sliding_window 116 | total_cache_len = self.max_cache_len 117 | prefix_cache_space = total_cache_len - sliding_window 118 | 119 | for batch_idx, cache_idx in enumerate(merge_idxs): 120 | text_len = text_lengths[batch_idx] 121 | prefix_len = seq_len - text_len 122 | self.attention_mask[cache_idx] = 0 # Set default 123 | 124 | assert prefix_len > 0, "There are no prefix (image) tokens!"
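            # Worked example (illustrative numbers, not read from the model config): with
            # max_cache_len=1536 and text_sliding_window=512, prefix_cache_space is 1024. A prefill
            # row with seq_len=900 and text_len=100 has prefix_len=800, so its prefix (image) mask
            # fills cache positions 224..1023, right-aligned against the sliding-window region, and
            # its text mask fills positions 1024..1123, left-aligned at the start of the window.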
125 | 126 | end_pos = prefix_cache_space 127 | # Handle prefix part - Which may be left padded 128 | if prefix_len <= prefix_cache_space: 129 | start_pos = prefix_cache_space - prefix_len 130 | self.attention_mask[cache_idx, start_pos:end_pos] = ( 131 | prefill_attention_mask[batch_idx, :prefix_len] 132 | ) 133 | else: 134 | self.attention_mask[cache_idx, :end_pos] = prefill_attention_mask[ 135 | batch_idx, prefix_len - prefix_cache_space : prefix_len 136 | ] 137 | 138 | # Handle text part, keeping sliding window in consideration 139 | # All of the left padding is before the prefix, so we can ignore the prefill_attention_mask here 140 | if text_len > 0: 141 | text_cache_start = prefix_cache_space 142 | if text_len <= sliding_window: 143 | self.attention_mask[ 144 | cache_idx, text_cache_start : text_cache_start + text_len 145 | ] = 1 146 | else: 147 | self.attention_mask[cache_idx, -sliding_window:] = 1 148 | 149 | # Slow impl for now - Prefill time is dominated by the large sequence length forward pass 150 | def _prefill_update( 151 | self, 152 | key_cache: torch.Tensor, 153 | value_cache: torch.Tensor, 154 | key_states: torch.Tensor, 155 | value_states: torch.Tensor, 156 | text_token_counts: torch.Tensor, 157 | cache_kwargs: Optional[Dict[str, Any]] = None, 158 | ): 159 | cache_idxs: List[int] = cache_kwargs.get("cache_idxs", None) 160 | text_lengths: List[int] = cache_kwargs.get("text_lengths", None) 161 | assert cache_idxs is not None, "cache_idxs must be specified during prefill" 162 | assert text_lengths is not None, "text_lengths must be specified during prefill" 163 | 164 | _, _, seq_len, _ = key_states.shape 165 | total_cache_len = self.max_cache_len 166 | sliding_window = self.text_sliding_window 167 | prefix_cache_space = total_cache_len - sliding_window 168 | 169 | for batch_idx, cache_idx in enumerate(cache_idxs): 170 | text_len = text_lengths[batch_idx] 171 | prefix_len = seq_len - text_len 172 | 173 | ###### Handle Image Tokens (Prefix) ##### 174 | # Place image tokens in appropriate cache space, aligned to the **right edge** 175 | assert prefix_len > 0, "There are no prefix (image) tokens!" 176 | 177 | # prefix_len may be greater than the prefix cache space due to left padding - This happens when 178 | # a different batch element has a large input text during prefill, causing others to have a lot of 179 | # left padding. 
We can safely take the last `prefix_cache_space` elements from the kv states, since 180 | # `prefix_cache_space` is large enough to fit any image, and the rest **has to be** padding 181 | end_pos = prefix_cache_space 182 | if prefix_len <= prefix_cache_space: 183 | start_pos = prefix_cache_space - prefix_len 184 | key_cache[cache_idx, :, start_pos:end_pos] = key_states[ 185 | batch_idx, :, :prefix_len 186 | ] 187 | value_cache[cache_idx, :, start_pos:end_pos] = value_states[ 188 | batch_idx, :, :prefix_len 189 | ] 190 | else: 191 | key_cache[cache_idx, :, :end_pos] = key_states[ 192 | batch_idx, :, prefix_len - prefix_cache_space : prefix_len 193 | ] 194 | value_cache[cache_idx, :, :end_pos] = value_states[ 195 | batch_idx, :, prefix_len - prefix_cache_space : prefix_len 196 | ] 197 | 198 | ###### Handle Text Tokens ##### 199 | # Text tokens start at the **left edge** of sliding window cache space 200 | if text_len > 0: 201 | text_cache_start = prefix_cache_space 202 | 203 | if text_len <= sliding_window: 204 | key_cache[ 205 | cache_idx, :, text_cache_start : text_cache_start + text_len 206 | ] = key_states[batch_idx, :, prefix_len : prefix_len + text_len] 207 | value_cache[ 208 | cache_idx, :, text_cache_start : text_cache_start + text_len 209 | ] = value_states[batch_idx, :, prefix_len : prefix_len + text_len] 210 | else: 211 | start_in_text = text_len - sliding_window 212 | key_cache[ 213 | cache_idx, 214 | :, 215 | text_cache_start : text_cache_start + sliding_window, 216 | ] = key_states[ 217 | batch_idx, :, prefix_len + start_in_text : prefix_len + text_len 218 | ] 219 | value_cache[ 220 | cache_idx, 221 | :, 222 | text_cache_start : text_cache_start + sliding_window, 223 | ] = value_states[ 224 | batch_idx, :, prefix_len + start_in_text : prefix_len + text_len 225 | ] 226 | 227 | # Return the full key/value states (not just cached) for use in subsequent layers 228 | return key_states, value_states 229 | 230 | # """ 231 | # Matches the logic of the decode update, but needs to be called before the updates 232 | # since some parts of the model depend on the attention mask 233 | # """ 234 | def decode_attention_mask_update( 235 | self, num_valid_tokens: torch.Tensor, cache_idxs: List[int] 236 | ): 237 | sliding_window = self.text_sliding_window 238 | text_cache_start = self.max_cache_len - sliding_window 239 | 240 | # Using text_token_counts of first layer, should be same for all though 241 | current_text_lens = self.text_token_counts[0] 242 | cache_idxs_tensor = torch.tensor(cache_idxs, device=current_text_lens.device) 243 | 244 | # Get current text lengths for the relevant cache indices 245 | current_lens = current_text_lens[cache_idxs_tensor] 246 | new_text_lens = current_lens + num_valid_tokens 247 | is_full = new_text_lens > sliding_window 248 | 249 | # Handle full caches - set entire sliding window to 1 250 | if is_full.any(): 251 | full_mask = is_full 252 | full_cache_idxs = cache_idxs_tensor[full_mask] 253 | self.attention_mask[full_cache_idxs, text_cache_start:] = 1 254 | 255 | # Handle non-full caches - set specific ranges to 1 256 | if (~is_full).any(): 257 | non_full_mask = ~is_full 258 | non_full_cache_idxs = cache_idxs_tensor[non_full_mask] 259 | non_full_current_lens = current_lens[non_full_mask] 260 | non_full_valid_tokens = num_valid_tokens[non_full_mask] 261 | 262 | max_valid_tokens = ( 263 | non_full_valid_tokens.max().item() 264 | if len(non_full_valid_tokens) > 0 265 | else 0 266 | ) 267 | if max_valid_tokens > 0: 268 | batch_size = len(non_full_cache_idxs) 269 | 
offset_range = torch.arange( 270 | max_valid_tokens, device=current_text_lens.device 271 | ) 272 | batch_offsets = offset_range.unsqueeze(0).expand(batch_size, -1) 273 | start_positions = non_full_current_lens.unsqueeze(1) 274 | valid_token_counts = non_full_valid_tokens.unsqueeze(1) 275 | 276 | position_indices = start_positions + batch_offsets 277 | valid_mask = batch_offsets < valid_token_counts 278 | 279 | row_indices = non_full_cache_idxs.unsqueeze(1).expand( 280 | -1, max_valid_tokens 281 | )[valid_mask] 282 | col_indices = text_cache_start + position_indices[valid_mask] 283 | 284 | self.attention_mask[row_indices, col_indices] = 1 285 | 286 | """ 287 | Static cache update 288 | - respects per-batch text token limits 289 | - per-batch valid token lengths (right-padded inputs) 290 | 291 | kv states are expected to have shape [batch_size, kv_heads, T_pad, head_dim] 292 | They may have different `true` lengths, to account for multi token preds, or beacon tokens 293 | Expects `num_valid_tokens` in cache_kwargs: a tensor of shape (B,) indicating the number 294 | of actual (non-padded) tokens to add per batch element. 295 | """ 296 | 297 | def _decode_update( 298 | self, 299 | key_cache: torch.Tensor, 300 | value_cache: torch.Tensor, 301 | key_states: torch.Tensor, 302 | value_states: torch.Tensor, 303 | text_token_counts: torch.Tensor, 304 | cache_kwargs: Optional[Dict[str, Any]] = None, 305 | ) -> Tuple[torch.Tensor, torch.Tensor]: 306 | num_valid_tokens: torch.Tensor = cache_kwargs.get( 307 | "num_valid_tokens" 308 | ) # shape: (B,) 309 | assert num_valid_tokens is not None, ( 310 | "`num_valid_tokens` must be provided in `cache_kwargs`" 311 | ) 312 | device = key_states.device 313 | 314 | batch_size, num_head, seq_len, head_dim = key_states.shape 315 | sliding_window = self.text_sliding_window 316 | max_cache_len = self.max_cache_len 317 | cache_text_start = max_cache_len - sliding_window 318 | new_text_lengths = text_token_counts + num_valid_tokens 319 | slide_amounts = torch.clamp(new_text_lengths - sliding_window, min=0) 320 | needs_rotate = slide_amounts > 0 321 | 322 | # Rotate the cache if needed 323 | if torch.any(needs_rotate): 324 | k_slice = key_cache[:, :, -sliding_window:] # shape: [B, H, W, D] 325 | v_slice = value_cache[:, :, -sliding_window:] # same shape 326 | 327 | cache_indices = ( 328 | torch.arange(sliding_window, device=device) 329 | .unsqueeze(0) 330 | .repeat(batch_size, 1) 331 | ) # [B, W] 332 | rolled_indices = ( 333 | cache_indices + slide_amounts.unsqueeze(1) 334 | ) % sliding_window # [B, W] 335 | 336 | # We need to expand indices to shape: [B, 1, W, 1] to broadcast with k_slice 337 | rolled_indices = ( 338 | rolled_indices.unsqueeze(1) 339 | .unsqueeze(-1) 340 | .expand(-1, num_head, -1, head_dim) 341 | ) 342 | 343 | k_slice_rolled = k_slice.gather(dim=2, index=rolled_indices) 344 | v_slice_rolled = v_slice.gather(dim=2, index=rolled_indices) 345 | 346 | key_cache[:, :, -sliding_window:] = k_slice_rolled 347 | value_cache[:, :, -sliding_window:] = v_slice_rolled 348 | 349 | # Insert only **valid tokens** into the cache. 
These are **right aligned** within the input sequence 350 | insert_positions = torch.where( 351 | needs_rotate, 352 | max_cache_len - num_valid_tokens, 353 | text_token_counts + cache_text_start, 354 | ) 355 | 356 | max_tokens = num_valid_tokens.max().item() 357 | offsets = torch.arange(max_tokens, device=device).unsqueeze(0) # [1, max_T] 358 | valid_mask = offsets < num_valid_tokens.unsqueeze(1) # [B, max_T] 359 | src_indices = (seq_len - num_valid_tokens).unsqueeze(1) + offsets # [B, max_T] 360 | src_indices = src_indices.clamp(max=seq_len - 1) # safety 361 | 362 | tgt_indices = insert_positions.unsqueeze(1) + offsets # [B, max_T] 363 | tgt_indices = tgt_indices.clamp(max=max_cache_len - 1) # safety 364 | 365 | src_idx_exp = ( 366 | src_indices.unsqueeze(1) 367 | .unsqueeze(-1) 368 | .expand(batch_size, num_head, max_tokens, head_dim) 369 | ) 370 | tgt_idx_exp = ( 371 | tgt_indices.unsqueeze(1) 372 | .unsqueeze(-1) 373 | .expand(batch_size, num_head, max_tokens, head_dim) 374 | ) 375 | valid_mask_exp = ( 376 | valid_mask.unsqueeze(1) 377 | .unsqueeze(-1) 378 | .expand(batch_size, num_head, max_tokens, head_dim) 379 | ) 380 | 381 | k_src = torch.gather(key_states, 2, src_idx_exp) 382 | v_src = torch.gather(value_states, 2, src_idx_exp) 383 | k_src = k_src * valid_mask_exp 384 | v_src = v_src * valid_mask_exp 385 | 386 | # Write into cache 387 | key_cache.scatter_(2, tgt_idx_exp, k_src) 388 | value_cache.scatter_(2, tgt_idx_exp, v_src) 389 | 390 | # In-place edit - Mutates 391 | text_token_counts += num_valid_tokens 392 | text_token_counts.clamp_(max=sliding_window) 393 | 394 | return key_cache, value_cache 395 | 396 | # We have a non-uniform cache, so its better to not return it and handle any logic 397 | # that requires this ourselves 398 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 399 | raise NotImplementedError() 400 | ``` -------------------------------------------------------------------------------- /surya/table_rec/__init__.py: -------------------------------------------------------------------------------- ```python 1 | from copy import deepcopy 2 | from itertools import chain 3 | from typing import List 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from tqdm import tqdm 9 | 10 | from surya.common.xla import mark_step 11 | from surya.common.predictor import BasePredictor 12 | from surya.table_rec.schema import TableCell, TableRow, TableCol, TableResult 13 | from surya.common.polygon import PolygonBox 14 | from surya.settings import settings 15 | from surya.table_rec.loader import TableRecModelLoader 16 | from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM, CATEGORY_TO_ID, MERGE_KEYS, \ 17 | MERGE_VALUES 18 | from surya.table_rec.shaper import LabelShaper 19 | 20 | 21 | class TableRecPredictor(BasePredictor): 22 | model_loader_cls = TableRecModelLoader 23 | batch_size = settings.TABLE_REC_BATCH_SIZE 24 | default_batch_sizes = { 25 | "cpu": 8, 26 | "mps": 8, 27 | "cuda": 32, 28 | "xla": 16 29 | } 30 | 31 | def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]: 32 | return self.batch_table_recognition(images, batch_size) 33 | 34 | def inference_loop( 35 | self, 36 | encoder_hidden_states: torch.Tensor, 37 | batch_input_ids: torch.Tensor, 38 | current_batch_size: int, 39 | batch_size: int 40 | ): 41 | shaper = LabelShaper() 42 | batch_predictions = [[] for _ in range(current_batch_size)] 43 | max_tokens = settings.TABLE_REC_MAX_BOXES 44 | decoder_position_ids = 
torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=self.model.device).cumsum( 45 | 0) - 1 46 | inference_token_count = batch_input_ids.shape[1] 47 | 48 | if settings.TABLE_REC_STATIC_CACHE: 49 | encoder_hidden_states = self.pad_to_batch_size(encoder_hidden_states, batch_size) 50 | batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) 51 | 52 | # Move to device after padding for XLA 53 | encoder_hidden_states = encoder_hidden_states.to(self.model.device) 54 | batch_input_ids = batch_input_ids.to(self.model.device) 55 | 56 | self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype) 57 | 58 | with settings.INFERENCE_MODE(): 59 | token_count = 0 60 | all_done = torch.zeros(encoder_hidden_states.shape[0], dtype=torch.bool, device=self.model.device) 61 | 62 | while token_count < max_tokens: 63 | is_prefill = token_count == 0 64 | return_dict = self.model.decoder( 65 | input_ids=batch_input_ids, 66 | encoder_hidden_states=encoder_hidden_states, 67 | cache_position=decoder_position_ids, 68 | use_cache=True, 69 | prefill=is_prefill 70 | ) 71 | 72 | decoder_position_ids = decoder_position_ids[-1:] + 1 73 | 74 | # Get predictions for each box element 75 | box_properties = [] 76 | done = [] 77 | 78 | # Pre-process all logits at once 79 | processed_logits = {} 80 | for k, _, mode in BOX_PROPERTIES: 81 | k_logits = return_dict["box_property_logits"][k][:, -1, :] # Get all batch logits at once 82 | 83 | if mode == "classification": 84 | # Process all classification logits in one operation 85 | items = torch.argmax(k_logits, dim=-1) 86 | if k == "category": 87 | done = (items == self.model.decoder.config.eos_token_id) | (items == self.model.decoder.config.pad_token_id) 88 | items = items - SPECIAL_TOKENS 89 | processed_logits[k] = items 90 | elif mode == "regression": 91 | if k == "bbox": 92 | k_logits = k_logits * BOX_DIM 93 | processed_logits[k] = k_logits 94 | elif k == "colspan": 95 | k_logits = torch.clamp(k_logits, min=1) 96 | processed_logits[k] = torch.round(k_logits) 97 | 98 | items = {k: processed_logits[k].cpu() for k, _, _ in BOX_PROPERTIES} 99 | for j in range(current_batch_size): 100 | box_property = {} 101 | for k, _, mode in BOX_PROPERTIES: 102 | if mode == "classification": 103 | box_property[k] = int(items[k][j].item()) 104 | elif mode == "regression": 105 | if k == "bbox": 106 | box_property[k] = items[k][j].tolist() 107 | elif k == "colspan": 108 | box_property[k] = int(items[k][j].item()) 109 | box_properties.append(box_property) 110 | 111 | all_done = all_done | done 112 | all_done_cpu = all_done.cpu() 113 | 114 | if all_done_cpu[:current_batch_size].all(): 115 | break 116 | 117 | batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long) 118 | batch_input_ids = batch_input_ids.unsqueeze(1) # Add sequence length dimension 119 | 120 | for j, (box_property, status) in enumerate(zip(box_properties, all_done_cpu)): 121 | if not status: 122 | batch_predictions[j].append(box_property) 123 | 124 | token_count += inference_token_count 125 | inference_token_count = batch_input_ids.shape[1] 126 | 127 | if settings.TABLE_REC_STATIC_CACHE: 128 | batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) 129 | 130 | # Move to device after padding for XLA 131 | batch_input_ids = batch_input_ids.to(self.model.device) 132 | return batch_predictions 133 | 134 | def batch_table_recognition( 135 | self, 136 | images: List, 137 | batch_size=None) -> List[TableResult]: 138 | assert 
all([isinstance(image, Image.Image) for image in images]) 139 | if batch_size is None: 140 | batch_size = self.get_batch_size() 141 | 142 | if len(images) == 0: 143 | return [] 144 | 145 | query_items = [] 146 | for image in images: 147 | query_items.append({ 148 | "polygon": [[0, 0], [image.width, 0], [image.width, image.height], [0, image.height]], 149 | "category": CATEGORY_TO_ID["Table"], 150 | "colspan": 0, 151 | "merges": 0, 152 | "is_header": 0 153 | }) 154 | 155 | output_order = [] 156 | for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables", disable=self.disable_tqdm): 157 | batch_query_items = query_items[i:i + batch_size] 158 | 159 | batch_images = images[i:i + batch_size] 160 | batch_images = [image.convert("RGB") for image in batch_images] # also copies the images 161 | 162 | current_batch_size = len(batch_images) 163 | 164 | orig_sizes = [image.size for image in batch_images] 165 | model_inputs = self.processor(images=batch_images, query_items=batch_query_items) 166 | 167 | batch_pixel_values = model_inputs["pixel_values"] 168 | 169 | batch_input_ids = model_inputs["input_ids"] 170 | batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=self.model.dtype) 171 | 172 | if settings.TABLE_REC_STATIC_CACHE: 173 | batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size) 174 | 175 | # Move to device after padding for XLA 176 | batch_pixel_values = batch_pixel_values.to(self.model.device) 177 | 178 | shaper = LabelShaper() 179 | 180 | # We only need to process each image once 181 | with settings.INFERENCE_MODE(): 182 | encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state 183 | 184 | # Inference to get rows and columns 185 | rowcol_predictions = self.inference_loop( 186 | encoder_hidden_states, 187 | batch_input_ids, 188 | current_batch_size, 189 | batch_size 190 | ) 191 | mark_step() 192 | 193 | row_query_items = [] 194 | row_encoder_hidden_states = [] 195 | idx_map = [] 196 | columns = [] 197 | for j, img_predictions in enumerate(rowcol_predictions): 198 | for row_prediction in img_predictions: 199 | polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"]) 200 | if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]: 201 | row_query_items.append({ 202 | "polygon": polygon, 203 | "category": row_prediction["category"], 204 | "colspan": 0, 205 | "merges": 0, 206 | "is_header": int(row_prediction["is_header"] == 1) 207 | }) 208 | row_encoder_hidden_states.append(encoder_hidden_states[j]) 209 | idx_map.append(j) 210 | elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]: 211 | columns.append({ 212 | "polygon": polygon, 213 | "category": row_prediction["category"], 214 | "colspan": 0, 215 | "merges": 0, 216 | "is_header": int(row_prediction["is_header"] == 1) 217 | }) 218 | 219 | # Re-inference to predict cells 220 | row_encoder_hidden_states = torch.stack(row_encoder_hidden_states) 221 | row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False) 222 | row_input_ids = row_inputs["input_ids"] 223 | cell_predictions = [] 224 | for j in range(0, len(row_input_ids), batch_size): 225 | cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size] 226 | cell_batch_input_ids = row_input_ids[j:j + batch_size] 227 | cell_batch_size = len(cell_batch_input_ids) 228 | cell_predictions.extend( 229 | self.inference_loop(cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size) 230 | ) 231 | mark_step() 
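            # End-to-end usage sketch for this predictor (illustrative, not from the repository docs;
            # assumes a PIL image of a table, e.g. one cropped out by the layout model):
            #
            #     from PIL import Image
            #     from surya.table_rec import TableRecPredictor
            #
            #     table_rec_predictor = TableRecPredictor()
            #     [table] = table_rec_predictor([Image.open("table.png")])
            #     for cell in table.cells:
            #         print(cell.row_id, cell.col_id, cell.colspan, cell.is_header)
            #
            # Each TableResult exposes rows, cols, cells (merged), unmerged_cells and image_bbox.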
232 | 233 | result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper) 234 | output_order.extend(result) 235 | 236 | return output_order 237 | 238 | 239 | def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper): 240 | results = [] 241 | for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)): 242 | row_cell_predictions = [c for i, c in enumerate(cell_predictions) if idx_map[i] == j] 243 | # Each row prediction matches a cell prediction 244 | rows = [] 245 | cells = [] 246 | columns = [] 247 | 248 | cell_id = 0 249 | row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]] 250 | col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]] 251 | 252 | # Generate table columns 253 | for z, col_prediction in enumerate(col_predictions): 254 | polygon = shaper.convert_bbox_to_polygon(col_prediction["bbox"]) 255 | polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) 256 | columns.append( 257 | TableCol( 258 | polygon=polygon, 259 | col_id=z, 260 | is_header=col_prediction["is_header"] == 1 261 | ) 262 | ) 263 | 264 | # Generate table rows 265 | for z, row_prediction in enumerate(row_predictions): 266 | polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"]) 267 | polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) 268 | row = TableRow( 269 | polygon=polygon, 270 | row_id=z, 271 | is_header=row_prediction["is_header"] == 1 272 | ) 273 | rows.append(row) 274 | 275 | # Get cells that span multiple columns within a row 276 | spanning_cells = [] 277 | for l, spanning_cell in enumerate(row_cell_predictions[z]): 278 | polygon = shaper.convert_bbox_to_polygon(spanning_cell["bbox"]) 279 | polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) 280 | colspan = max(1, int(spanning_cell["colspan"])) 281 | if colspan == 1 and spanning_cell["merges"] not in MERGE_VALUES: 282 | # Skip single column cells if they don't merge 283 | continue 284 | if PolygonBox(polygon=polygon).height < row.height * .85: 285 | # Spanning cell must cover most of the row 286 | continue 287 | 288 | spanning_cells.append( 289 | TableCell( 290 | polygon=polygon, 291 | row_id=z, 292 | rowspan=1, 293 | cell_id=cell_id, 294 | within_row_id=l, 295 | colspan=colspan, 296 | merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]], 297 | merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"], 298 | MERGE_KEYS["merge_both"]], 299 | is_header=row.is_header or z == 0 300 | ) 301 | ) 302 | cell_id += 1 303 | 304 | # Add cells - either add spanning cells (multiple cols), or generate a cell based on row/col 305 | used_spanning_cells = set() 306 | skip_columns = 0 307 | for l, col in enumerate(columns): 308 | if skip_columns: 309 | skip_columns -= 1 310 | continue 311 | cell_polygon = row.intersection_polygon(col) 312 | cell_added = False 313 | for zz, spanning_cell in enumerate(spanning_cells): 314 | cell_polygonbox = PolygonBox(polygon=cell_polygon) 315 | intersection_pct = cell_polygonbox.intersection_pct(spanning_cell) 316 | # Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns) 317 | correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]]) 318 | if intersection_pct > .9: 319 | if spanning_cell.width > (correct_col_width * 
.85): 320 | cell_added = True 321 | if zz not in used_spanning_cells: 322 | used_spanning_cells.add(zz) 323 | spanning_cell.col_id = l 324 | cells.append(spanning_cell) 325 | skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell 326 | else: 327 | used_spanning_cells.add(zz) # Skip this spanning cell 328 | 329 | if not cell_added: 330 | cells.append( 331 | TableCell( 332 | polygon=cell_polygon, 333 | row_id=z, 334 | rowspan=1, 335 | cell_id=cell_id, 336 | within_row_id=l, 337 | colspan=1, 338 | merge_up=False, 339 | merge_down=False, 340 | col_id=l, 341 | is_header=row.is_header or col.is_header or z == 0 342 | ) 343 | ) 344 | cell_id += 1 345 | 346 | # Turn cells into a row grid 347 | grid_cells = deepcopy([ 348 | [cell for cell in cells if cell.row_id == row.row_id] 349 | for row in rows 350 | ]) 351 | 352 | # Merge cells across rows 353 | for z, grid_row in enumerate(grid_cells[1:]): 354 | prev_row = grid_cells[z] 355 | for l, cell in enumerate(grid_row): 356 | if l >= len(prev_row): 357 | continue 358 | 359 | above_cell = prev_row[l] 360 | if all([ 361 | above_cell.merge_down, 362 | cell.merge_up, 363 | above_cell.col_id == cell.col_id, 364 | above_cell.colspan == cell.colspan, 365 | ]): 366 | above_cell.merge(cell) 367 | above_cell.rowspan += cell.rowspan 368 | grid_row[l] = above_cell 369 | 370 | merged_cells_all = list(chain.from_iterable(grid_cells)) 371 | used_ids = set() 372 | merged_cells = [] 373 | for cell in merged_cells_all: 374 | if cell.cell_id in used_ids: 375 | continue 376 | used_ids.add(cell.cell_id) 377 | merged_cells.append(cell) 378 | 379 | result = TableResult( 380 | cells=merged_cells, 381 | unmerged_cells=cells, 382 | rows=rows, 383 | cols=columns, 384 | image_bbox=[0, 0, orig_size[0], orig_size[1]], 385 | ) 386 | results.append(result) 387 | return results 388 | ``` -------------------------------------------------------------------------------- /surya/common/surya/processor/__init__.py: -------------------------------------------------------------------------------- ```python 1 | import math 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | from typing import List, Optional, Tuple 10 | 11 | from transformers.feature_extraction_utils import BatchFeature 12 | from transformers.processing_utils import ProcessorMixin 13 | from transformers.tokenization_utils import PreTrainedTokenizer 14 | 15 | from surya.common.s3 import S3DownloaderMixin 16 | from surya.common.surya.processor.schema import ( 17 | TextInput, 18 | ImageInput, 19 | ProcessorOutput, 20 | ) 21 | from surya.common.surya.schema import TaskNames 22 | from surya.logging import get_logger 23 | from surya.settings import settings 24 | 25 | logger = get_logger() 26 | 27 | # Task agnostic tokens - Every task will use these in some form or another 28 | EOS_TOKEN = "</S>" 29 | EOI_TOKEN = "<EOI>" # This is end of INPUT, not image. Images are always followed by a task specific BOS token, so that serves as a delimiter anyways. 
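# Illustrative prompt layout assembled by this processor for an OCR task (inferred from the
# _process_* methods below; a sketch, not documentation):
#
#     [<ROT>?] [IMAGE] * n_image_tokens [<REG1>..<REGk>] <OCR-WB> [<NO-MATH>?] input text <EOI> output text </S>
#
# At inference time only the image and input-text segments are supplied, so the sequence ends at <EOI>.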
30 | IMAGE_TOKEN = "<IMAGE>" 31 | PAD_TOKEN = "<PAD>" 32 | NO_OUTPUT_TOKEN = "<NOP>" 33 | IMAGE_ROTATED_TOKEN = "<ROT>" 34 | REGISTER_TOKENS = ["<REG1>", "<REG2>", "<REG3>", "<REG4>"] 35 | BEACON_TOKEN = "<BEACON>" 36 | NOMATH_TOKEN = "<NO-MATH>" 37 | 38 | # Task specific tokens 39 | OCR_WITH_BOXES_BOS_TOKEN = "<OCR-WB>" 40 | OCR_WITHOUT_BOXES_BOS_TOKEN = "<OCR-WOB>" 41 | BLOCK_WITHOUT_BOXES_TOKEN = "<BLOCKS-WOB>" 42 | LAYOUT_BOS_TOKEN = "<LAYOUT>" 43 | TABLE_STRUCTURE_BOS_TOKEN = "<TABLE-STRUCTURE>" 44 | 45 | 46 | class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin): 47 | attributes = ["image_processor", "ocr_tokenizer"] 48 | image_processor_class = "BaseImageProcessor" 49 | ocr_tokenizer_class = "PreTrainedTokenizer" 50 | rescale_factor = 1 / 255.0 51 | image_mean = (0.485, 0.456, 0.406) 52 | image_std = (0.229, 0.224, 0.225) 53 | 54 | def __init__( 55 | self, 56 | ocr_tokenizer: PreTrainedTokenizer, 57 | blank_bbox_token_id: int, 58 | num_register_tokens: int, 59 | patch_size: int, 60 | merge_size: int, 61 | num_beacon_tokens: int, 62 | beacon_token_interval: int, 63 | model_device: str, 64 | **kwargs, 65 | ): 66 | self.ocr_tokenizer = ocr_tokenizer 67 | self.patch_size = patch_size 68 | self.merge_size = merge_size 69 | self.num_register_tokens = num_register_tokens 70 | self.num_beacon_tokens = num_beacon_tokens 71 | self.beacon_token_interval = beacon_token_interval 72 | 73 | self.tokenizer_vocab_size = 0 74 | for attr in self.attributes: 75 | if "tokenizer" in attr: 76 | self.tokenizer_vocab_size += getattr(self, attr).vocab_size 77 | 78 | self.offsets = {"ocr": 0} 79 | 80 | # Create special token mapping 81 | self.special_token_mapping = self.ocr_tokenizer.system_tokens 82 | 83 | self.register_token_ids = [ 84 | self.special_token_mapping.get(r) for r in REGISTER_TOKENS 85 | ] 86 | self.beacon_token_id = self.special_token_mapping.get(BEACON_TOKEN) 87 | self.image_token_id = self.special_token_mapping.get(IMAGE_TOKEN) 88 | self.pad_token_id = self.special_token_mapping.get(PAD_TOKEN) 89 | self.eos_token_id = self.special_token_mapping.get(EOS_TOKEN) 90 | self.eoi_token_id = self.special_token_mapping.get(EOI_TOKEN) 91 | self.no_output_token = self.special_token_mapping.get(NO_OUTPUT_TOKEN) 92 | self.image_rotated_token = self.special_token_mapping.get(IMAGE_ROTATED_TOKEN) 93 | self.nomath_token = self.special_token_mapping.get(NOMATH_TOKEN) 94 | 95 | self.bos_token_id = { 96 | TaskNames.ocr_with_boxes: self.special_token_mapping.get( 97 | OCR_WITH_BOXES_BOS_TOKEN 98 | ), 99 | TaskNames.ocr_without_boxes: self.special_token_mapping.get( 100 | OCR_WITHOUT_BOXES_BOS_TOKEN 101 | ), 102 | TaskNames.block_without_boxes: self.special_token_mapping.get( 103 | BLOCK_WITHOUT_BOXES_TOKEN 104 | ), 105 | TaskNames.layout: self.special_token_mapping.get(LAYOUT_BOS_TOKEN), 106 | TaskNames.table_structure: self.special_token_mapping.get( 107 | TABLE_STRUCTURE_BOS_TOKEN 108 | ), 109 | } 110 | 111 | if self.image_token_id is None: 112 | logger.warning("Warning: Image token not found in special tokens") 113 | 114 | self.blank_bbox_token_id = blank_bbox_token_id 115 | self.bbox_pad_token_id = self.blank_bbox_token_id 116 | 117 | self.ignore_bbox_token_ids = [ 118 | v 119 | for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() 120 | if k not in self.ocr_tokenizer.special_tokens["math_external"] 121 | ] 122 | math_end_token = "</math>" 123 | self.math_start_token_ids = [ 124 | v 125 | for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() 126 | if k in 
self.ocr_tokenizer.special_tokens["math_external"] 127 | and k != math_end_token 128 | ] 129 | self.math_end_token_ids = [ 130 | v 131 | for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() 132 | if k == math_end_token 133 | ] 134 | 135 | if self.num_register_tokens > len(self.register_token_ids): 136 | raise ValueError( 137 | "The number of register tokens requested exceeds the number of register tokens defined in the special token mapping." 138 | ) 139 | 140 | self.image_mean = np.array(self.image_mean, dtype=np.float32) 141 | self.image_std = np.array(self.image_std, dtype=np.float32) 142 | self.model_device = model_device 143 | 144 | @property 145 | def vocab_size(self): 146 | return self.tokenizer_vocab_size 147 | 148 | def image_processor(self, image: Image.Image) -> np.ndarray: 149 | # Convert to array 150 | image = np.asarray(image, dtype=np.float32) 151 | return image 152 | 153 | @staticmethod 154 | def scale_to_fit( 155 | img: np.ndarray, 156 | max_size: Tuple[int, int], 157 | min_size: Tuple[int, int] = (168, 168), 158 | ): 159 | # Get current dimensions 160 | height, width = img.shape[:2] 161 | 162 | # Check for empty or invalid image 163 | if width == 0 or height == 0: 164 | return img 165 | 166 | max_width, max_height = max_size 167 | min_width, min_height = min_size 168 | 169 | # Calculate pixel counts 170 | current_pixels = width * height 171 | max_pixels = max_width * max_height 172 | min_pixels = min_width * min_height 173 | 174 | if current_pixels > max_pixels: 175 | scale_factor = (max_pixels / current_pixels) ** 0.5 176 | 177 | new_width = math.floor(width * scale_factor) 178 | new_height = math.floor(height * scale_factor) 179 | elif current_pixels == 0: 180 | return img 181 | elif current_pixels < min_pixels: 182 | scale_factor = (min_pixels / current_pixels) ** 0.5 183 | 184 | new_width = math.ceil(width * scale_factor) 185 | new_height = math.ceil(height * scale_factor) 186 | else: 187 | return img 188 | 189 | return cv2.resize( 190 | img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4 191 | ) 192 | 193 | def _image_processor(self, image: np.ndarray): 194 | image = image.astype(np.float64) * self.rescale_factor 195 | image = (image.astype(np.float32) - self.image_mean) / self.image_std 196 | return image 197 | 198 | def _process_and_tile( 199 | self, image: np.ndarray 200 | ) -> Tuple[torch.Tensor, Tuple[int, int, int]]: 201 | """ 202 | Resizes the input image to the closest multiple of tile_size while preserving the aspect ratio 203 | and returns a tensor of image tiles. 
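        Illustrative shape example (assumed hyperparameters, not read from a config): with
        patch_size=14, merge_size=2 and no XLA, the resize factor is 28, so a 500x300 (WxH) image
        is resized to 504x308, giving grid_h=22 and grid_w=36 and a flattened patch tensor of shape
        (22 * 36, 3 * 14 * 14) = (792, 588), which corresponds to 792 / merge_size**2 = 198 image
        tokens downstream.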
204 | """ 205 | extra_multipler = ( 206 | 4 if settings.FOUNDATION_XLA else 1 207 | ) # Needed to force same size grid_thws per row with padding 208 | 209 | factor = ( 210 | self.patch_size * self.merge_size * extra_multipler 211 | ) # Make a multiple of window size 212 | 213 | height, width = image.shape[:2] 214 | 215 | h_bar = math.ceil(height / factor) * factor 216 | w_bar = math.ceil(width / factor) * factor 217 | if h_bar != height or w_bar != width: 218 | if height == 0 or width == 0: 219 | image = np.zeros((h_bar, w_bar, 3), dtype=np.uint8) 220 | else: 221 | image = cv2.resize(image, (w_bar, h_bar), interpolation=cv2.INTER_CUBIC) 222 | 223 | # Handle scaling and normalization 224 | image = self._image_processor(image) 225 | height, width = image.shape[:2] 226 | 227 | # Numpy array to torch tensor 228 | img_tensor = torch.from_numpy(image.transpose(2, 0, 1)) 229 | patches = img_tensor.unsqueeze(0) 230 | 231 | channel = patches.shape[1] 232 | grid_t = patches.shape[0] 233 | grid_h, grid_w = height // self.patch_size, width // self.patch_size 234 | 235 | patches = patches.reshape( 236 | grid_t, 237 | 1, 238 | channel, 239 | grid_h // self.merge_size, 240 | self.merge_size, 241 | self.patch_size, 242 | grid_w // self.merge_size, 243 | self.merge_size, 244 | self.patch_size, 245 | ) 246 | patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) 247 | flatten_patches = patches.reshape( 248 | grid_t * grid_h * grid_w, channel * 1 * self.patch_size * self.patch_size 249 | ) 250 | 251 | return flatten_patches, (grid_t, grid_h, grid_w) 252 | 253 | # Handle image input dictionaries - Process image, tile accordingly, and setup the input ids and boxes correspondingly 254 | def _process_image_input(self, image_input: ImageInput) -> ProcessorOutput: 255 | rotated = image_input.get("rotated", False) 256 | image = image_input.get("image", None) 257 | 258 | assert image is not None, ( 259 | "A PIL Image must be provided when the input type is `image`" 260 | ) 261 | image_tiles, grid_thw = self._process_and_tile(image) 262 | 263 | num_tokens = image_tiles.shape[0] / self.merge_size**2 264 | assert num_tokens.is_integer(), ( 265 | f"Expected number of tokens to be an integer, got {num_tokens}" 266 | ) 267 | 268 | input_ids = [self.image_token_id] * int(num_tokens) 269 | input_ids += self.register_token_ids[: self.num_register_tokens] 270 | 271 | # Handle the image being rotated in the imdataset 272 | if rotated: 273 | input_ids = [self.image_rotated_token] + input_ids 274 | 275 | return ProcessorOutput( 276 | input_ids=input_ids, 277 | image_tiles=image_tiles, 278 | grid_thw=grid_thw, 279 | ) 280 | 281 | def _process_text_input(self, text_input: TextInput, task: str) -> ProcessorOutput: 282 | input_text = text_input.get("text", None) 283 | math_mode = text_input.get("math", False) 284 | 285 | input_ids = self.ocr_tokenizer(input_text, tasks=task)["input_ids"][0] 286 | input_ids = [self.offsets["ocr"] + id for id in input_ids] 287 | 288 | # nomath token does not work for layout 289 | if not math_mode and task != "layout": 290 | input_ids.insert(0, self.nomath_token) 291 | 292 | return ProcessorOutput( 293 | input_ids=input_ids, 294 | image_tiles=None, 295 | grid_thw=None, 296 | ) 297 | 298 | def _process_input(self, input_dict: dict, task: str): 299 | input_type = input_dict["type"] 300 | if input_type == "image": 301 | return self._process_image_input(input_dict) 302 | elif input_type == "text": 303 | return self._process_text_input(input_dict, task) 304 | 305 | raise NotImplementedError(f"Input of type 
`{input_type}` is not implemented") 306 | 307 | # Peprocessing for OCR task 308 | # The task is expected to have - image_dict, user_input_dict, output_dict 309 | # use_input_dict is allowed to have an empty input which is fine, but needs to be present 310 | def _process_ocr_with_boxes( 311 | self, 312 | mixed_input: List[dict], 313 | bos_token_id: int, 314 | task: str = TaskNames.ocr_with_boxes, 315 | ): 316 | processed_input_ids = [] 317 | all_image_tiles = [] 318 | all_grid_thw = [] 319 | 320 | # 1. Process the image input 321 | for i, input_dict in enumerate(mixed_input): 322 | processor_output = self._process_input(input_dict, task) 323 | input_ids = processor_output["input_ids"] 324 | image_tiles = processor_output["image_tiles"] 325 | grid_thw = processor_output["grid_thw"] 326 | 327 | # Special handling of some delimiter tokens 328 | if i == 1: 329 | assert input_dict["type"] == "text", ( 330 | "Expected text input for model input." 331 | ) 332 | # Case for input - Add task specific bos token + end_of_input token 333 | # We do not want the model to learn how to predict inputs. Hence IGNORE_INDEX for these 334 | input_ids = [bos_token_id] + input_ids + [self.eoi_token_id] 335 | if i == 2: 336 | assert input_dict["type"] == "text", ( 337 | "Expected text for final model input" 338 | ) 339 | input_ids = input_ids + [self.eos_token_id] 340 | elif i > 2: 341 | raise ValueError(f"Too many inputs received. Expected is 2 for inference, 3 for training. Received: {len(mixed_input)}") 342 | 343 | # Some input types don't return any image tiles, accounting for that 344 | if image_tiles is not None: 345 | all_image_tiles.append(image_tiles) 346 | all_grid_thw.append(grid_thw) 347 | 348 | processed_input_ids.extend(input_ids) 349 | 350 | return ( 351 | torch.tensor(processed_input_ids, dtype=torch.long), 352 | all_image_tiles, 353 | all_grid_thw, 354 | ) 355 | 356 | def _process_layout(self, mixed_input: List[dict], bos_token_id: int): 357 | return self._process_ocr_with_boxes( 358 | mixed_input, bos_token_id=bos_token_id, task="layout" 359 | ) 360 | 361 | def _process_table_structure(self, mixed_input: List[dict], bos_token_id: int): 362 | return self._process_ocr_with_boxes( 363 | mixed_input, bos_token_id=bos_token_id, task="table_structure" 364 | ) 365 | 366 | def _process_ocr_without_boxes( 367 | self, 368 | mixed_input: List[dict], 369 | bos_token_id: int, 370 | task: str = "ocr_without_boxes", 371 | ): 372 | # Boxes are set to None, so this will work 373 | # TODO: improve this behavior 374 | return self._process_ocr_with_boxes( 375 | mixed_input, bos_token_id=bos_token_id, task=task 376 | ) 377 | 378 | def _process_block_without_boxes( 379 | self, 380 | mixed_input: List[dict], 381 | bos_token_id: int, 382 | task: str = "block_without_boxes", 383 | ): 384 | return self._process_ocr_with_boxes( 385 | mixed_input, bos_token_id=bos_token_id, task=task 386 | ) 387 | 388 | def align_long_axis(self, image: np.ndarray) -> Tuple[np.ndarray, bool]: 389 | height, width, _ = image.shape 390 | if height > width: # Rotate vertical lines 391 | image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE) 392 | return image, True 393 | 394 | return image, False 395 | 396 | def __call__( 397 | self, 398 | mixed_batch: List[dict], 399 | padding_side: Optional[str] = "left", 400 | device: Optional[torch.device] = None, 401 | pad_to_multiple: Optional[int] = None, 402 | ): 403 | all_image_tiles = [] 404 | all_input_ids = [] 405 | all_grid_thw = [] 406 | 407 | for b in mixed_batch: 408 | mixed_input = b["inputs"] 
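            # Illustrative batch element (structure inferred from the _process_* methods above; the
            # concrete values are assumptions, not from the repository docs):
            #
            #     {
            #         "task": TaskNames.ocr_with_boxes,
            #         "inputs": [
            #             {"type": "image", "image": np_image, "rotated": False},
            #             {"type": "text", "text": "", "math": True},
            #         ],
            #     }
            #
            # The image dict becomes image tokens plus register tokens; the text dict is wrapped with
            # the task-specific BOS token and the end-of-input token.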
409 | task = b["task"] 410 | assert task in self.bos_token_id, f"Task {task} has no bos token defined." 411 | 412 | # Select the correct processing function based on the task type 413 | input_ids, image_tiles, grid_thw = getattr(self, f"_process_{task}")( 414 | mixed_input, self.bos_token_id[task] 415 | ) 416 | 417 | all_input_ids.append(input_ids) 418 | all_image_tiles.extend(image_tiles) 419 | all_grid_thw.extend(grid_thw) 420 | 421 | batched_input_ids = pad_sequence( 422 | all_input_ids, 423 | batch_first=True, 424 | padding_side=padding_side, 425 | padding_value=self.pad_token_id, 426 | ) 427 | 428 | if pad_to_multiple is not None: 429 | current_len = batched_input_ids.shape[1] 430 | # Calculate the next multiple of pad_to_multiple 431 | padded_len = ( 432 | (current_len + pad_to_multiple - 1) // pad_to_multiple 433 | ) * pad_to_multiple 434 | 435 | if padded_len > current_len: 436 | pad_len = padded_len - current_len 437 | batched_input_ids = torch.nn.functional.pad( 438 | batched_input_ids, (pad_len, 0), value=self.pad_token_id 439 | ) 440 | 441 | attention_mask = batched_input_ids.ne(self.pad_token_id) 442 | 443 | # Generating position IDs that are independent of left and right padding; 444 | # This should ensure same results for either padding side. Exact position id for the pad tokens themselves don't matter since they are masked 445 | position_ids = attention_mask.cumsum(dim=-1) - 1 446 | position_ids[position_ids < 0] = ( 447 | 0 # For left padding, the position ids for padding will become -1 because of the shift; Setting to 0 448 | ) 449 | position_ids = ( 450 | attention_mask.to(torch.long) * position_ids 451 | ) # Ensure right pad ids get set to zero 452 | 453 | batched_image_tiles = torch.cat(all_image_tiles, dim=0) 454 | batched_grid_thw = torch.from_numpy(np.array(all_grid_thw)) 455 | 456 | # Pin memory for CUDA 457 | if device == torch.device("cuda"): 458 | batched_image_tiles = batched_image_tiles.pin_memory() 459 | batched_grid_thw = batched_grid_thw.pin_memory() 460 | attention_mask = attention_mask.pin_memory() 461 | batched_input_ids = batched_input_ids.pin_memory() 462 | position_ids = position_ids.pin_memory() 463 | 464 | return BatchFeature( 465 | { 466 | "input_ids": batched_input_ids, 467 | "image_tiles": batched_image_tiles, 468 | "attention_mask": attention_mask, 469 | "position_ids": position_ids, 470 | "grid_thw": batched_grid_thw, 471 | } 472 | ) 473 | 474 | # Decode model outputs; Strips special tokens 475 | def decode(self, tokens: List[int], task: str): 476 | filtered_tokens = [ 477 | t 478 | for t in tokens 479 | if t not in self.special_token_mapping.values() and t != -100 480 | ] # Skip special tokens and loss ignore index 481 | return self.ocr_tokenizer.decode(filtered_tokens, task=task) 482 | ``` -------------------------------------------------------------------------------- /surya/recognition/__init__.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | 3 | import re 4 | from typing import List 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | import torch.nn.functional as F 10 | 11 | from surya.common.polygon import PolygonBox 12 | from surya.common.surya.processor import NOMATH_TOKEN 13 | from surya.common.predictor import BasePredictor 14 | from surya.detection import DetectionPredictor 15 | from surya.foundation import FoundationPredictor 16 | 17 | from surya.input.processing import ( 18 | convert_if_not_rgb, 19 | 
slice_polys_from_image, 20 | slice_bboxes_from_image, 21 | ) 22 | from surya.recognition.postprocessing import fix_unbalanced_tags 23 | from surya.recognition.util import ( 24 | sort_text_lines, 25 | clean_close_polygons, 26 | unwrap_math, 27 | clean_math_tags, 28 | filter_blacklist_tags, 29 | words_from_chars 30 | ) 31 | from surya.foundation.util import detect_repeat_token, prediction_to_polygon_batch 32 | from surya.recognition.schema import TextLine, OCRResult, TextChar 33 | from surya.common.surya.schema import TaskNames 34 | from surya.settings import settings 35 | from surya.logging import get_logger, configure_logging 36 | 37 | configure_logging() 38 | logger = get_logger() 39 | 40 | class RecognitionPredictor(BasePredictor): 41 | batch_size = settings.RECOGNITION_BATCH_SIZE 42 | default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 128} 43 | 44 | # Override base init - Do not load model 45 | def __init__(self, foundation_predictor: FoundationPredictor): 46 | self.foundation_predictor = foundation_predictor 47 | self.processor = self.foundation_predictor.processor 48 | self.bbox_size = self.foundation_predictor.model.config.bbox_size 49 | self.tasks = self.foundation_predictor.tasks 50 | 51 | # Special handling for disable tqdm to pass into foundation predictor 52 | # Make sure they are kept in sync 53 | @property 54 | def disable_tqdm(self) -> bool: 55 | return super().disable_tqdm 56 | 57 | @disable_tqdm.setter 58 | def disable_tqdm(self, value: bool) -> None: 59 | self._disable_tqdm = bool(value) 60 | self.foundation_predictor.disable_tqdm = bool(value) 61 | 62 | def detect_and_slice_bboxes( 63 | self, 64 | images: List[Image.Image], 65 | task_names: List[str], 66 | det_predictor: DetectionPredictor, 67 | detection_batch_size: int | None = None, 68 | highres_images: List[Image.Image] | None = None, 69 | ): 70 | det_predictions = det_predictor(images, batch_size=detection_batch_size) 71 | 72 | all_slices = [] 73 | slice_map = [] 74 | all_polygons = [] 75 | all_task_names = [] 76 | all_res_scales = [] 77 | 78 | for idx, (det_pred, image, highres_image, task_name) in enumerate( 79 | zip(det_predictions, images, highres_images, task_names) 80 | ): 81 | polygons = [p.polygon for p in det_pred.bboxes] 82 | if highres_image: 83 | width_scaler = highres_image.size[0] / image.size[0] 84 | height_scaler = highres_image.size[1] / image.size[1] 85 | scaled_polygons = [ 86 | [ 87 | [int(p[0] * width_scaler), int(p[1] * height_scaler)] 88 | for p in polygon 89 | ] 90 | for polygon in polygons 91 | ] 92 | highres_image = self.processor.image_processor(highres_image) 93 | slices = slice_polys_from_image(highres_image, scaled_polygons) 94 | res_scales = [(width_scaler, height_scaler) for _ in range(len(slices))] 95 | else: 96 | image = self.processor.image_processor(image) 97 | slices = slice_polys_from_image(image, polygons) 98 | res_scales = [(1, 1) for _ in range(len(slices))] 99 | 100 | slice_map.append(len(slices)) 101 | all_slices.extend(slices) 102 | all_polygons.extend(polygons) 103 | all_task_names.extend([task_name] * len(slices)) 104 | all_res_scales.extend(res_scales) 105 | 106 | assert ( 107 | len(all_slices) 108 | == sum(slice_map) 109 | == len(all_polygons) 110 | == len(all_task_names) 111 | == len(all_res_scales) 112 | ) 113 | 114 | return { 115 | "slices": all_slices, 116 | "slice_map": slice_map, 117 | "polygons": all_polygons, 118 | "task_names": all_task_names, 119 | "input_text": [None] * len(all_slices), 120 | "res_scales": all_res_scales, 121 | } 122 | 123 | def 
slice_bboxes( 124 | self, 125 | images: List[Image.Image], 126 | task_names: List[str], 127 | bboxes: List[List[List[int]]] | None = None, 128 | polygons: List[List[List[List[int]]]] | None = None, 129 | input_text: List[List[str | None]] | None = None, 130 | ) -> dict: 131 | assert bboxes is not None or polygons is not None 132 | slice_map = [] 133 | all_slices = [] 134 | all_polygons = [] 135 | all_text = [] 136 | all_task_names = [] 137 | 138 | for idx, image in enumerate(images): 139 | image = self.processor.image_processor(image) 140 | if polygons is not None: 141 | polys = polygons[idx] 142 | slices = slice_polys_from_image(image, polys) 143 | else: 144 | slices = slice_bboxes_from_image(image, bboxes[idx]) 145 | polys = [ 146 | [ 147 | [bbox[0], bbox[1]], 148 | [bbox[2], bbox[1]], 149 | [bbox[2], bbox[3]], 150 | [bbox[0], bbox[3]], 151 | ] 152 | for bbox in bboxes[idx] 153 | ] 154 | slice_map.append(len(slices)) 155 | all_slices.extend(slices) 156 | all_polygons.extend(polys) 157 | all_task_names.extend([task_names[idx]] * len(slices)) 158 | 159 | if input_text is None: 160 | all_text.extend([None] * len(slices)) 161 | else: 162 | all_text.extend(input_text[idx]) 163 | 164 | assert ( 165 | len(all_slices) 166 | == sum(slice_map) 167 | == len(all_polygons) 168 | == len(all_text) 169 | == len(all_task_names) 170 | ), ( 171 | f"Mismatch in lengths: {len(all_slices)}, {sum(slice_map)}, {len(all_polygons)}, {len(all_text)}, {len(all_task_names)}" 172 | ) 173 | 174 | return { 175 | "slices": all_slices, 176 | "slice_map": slice_map, 177 | "polygons": all_polygons, 178 | "input_text": all_text, 179 | "task_names": all_task_names, 180 | "res_scales": [(1, 1) for _ in range(len(all_slices))], 181 | } 182 | 183 | def get_bboxes_text( 184 | self, 185 | flat: dict, 186 | predicted_tokens: list, 187 | scores: list, 188 | predicted_polygons: list, 189 | drop_repeated_text: bool = False, 190 | ) -> list: 191 | char_predictions = [] 192 | needs_boxes = [ 193 | self.tasks[task_name]["needs_bboxes"] for task_name in flat["task_names"] 194 | ] 195 | 196 | for slice_idx, ( 197 | slice_image, 198 | image_tokens, 199 | image_polygons, 200 | image_scores, 201 | needs_box, 202 | ) in enumerate( 203 | zip( 204 | flat["slices"], 205 | predicted_tokens, 206 | predicted_polygons, 207 | scores, 208 | needs_boxes, 209 | ) 210 | ): 211 | blank_bbox = [[0, 0], [0, 1], [1, 1], [1, 0]] 212 | if self.processor.no_output_token in image_tokens: 213 | char_predictions.append(None) 214 | continue 215 | 216 | # If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely 217 | if drop_repeated_text and detect_repeat_token(image_tokens): 218 | char_predictions.append( 219 | [ 220 | TextChar( 221 | text="", 222 | polygon=blank_bbox, 223 | confidence=0, 224 | bbox_valid=False, 225 | ) 226 | ] 227 | ) 228 | continue 229 | 230 | image_polygons = image_polygons[: len(image_tokens)].cpu().numpy().tolist() 231 | 232 | detokenize_sequences = [] 233 | detokenize_sequence = [] 234 | past_char_qwen_token = False 235 | 236 | def _add_detokenize_sequence( 237 | special_token: bool, 238 | past_special_token: bool, 239 | force: bool = False, 240 | ): 241 | nonlocal detokenize_sequence, detokenize_sequences 242 | 243 | if ( 244 | special_token 245 | or past_special_token 246 | or force 247 | ) and detokenize_sequence: 248 | chars = [dt[0] for dt in detokenize_sequence] 249 | scores = [dt[1] for dt in detokenize_sequence] 250 | bboxes = [dt[2] for dt in detokenize_sequence] 251 | 252 | if 
past_special_token: 253 | detokenize_sequences.append((chars, scores, None, "special")) 254 | else: 255 | detokenize_sequences.append((chars, scores, bboxes, "ocr")) 256 | 257 | detokenize_sequence = [] 258 | 259 | # Split up into sequences to detokenize separately 260 | past_special_token = False 261 | for bbox, char_id, score in zip(image_polygons, image_tokens, image_scores): 262 | if char_id in [ 263 | self.processor.eos_token_id, 264 | self.processor.pad_token_id, 265 | ]: 266 | break 267 | 268 | special_token = ( 269 | char_id >= self.processor.ocr_tokenizer.ocr_tokenizer.SPECIAL_BASE 270 | ) 271 | _add_detokenize_sequence( 272 | special_token, past_special_token 273 | ) 274 | detokenize_sequence.append((char_id, score, bbox)) 275 | past_special_token = special_token 276 | 277 | _add_detokenize_sequence( 278 | False, past_special_token, force=True 279 | ) 280 | 281 | img_chars = [] 282 | for sequence in detokenize_sequences: 283 | token_ids, seq_score, bboxes, token_type = sequence 284 | if token_type == "ocr": 285 | text = self.processor.ocr_tokenizer.decode( 286 | token_ids, task=TaskNames.ocr_with_boxes 287 | ) 288 | bboxes = clean_close_polygons( 289 | bboxes 290 | ) # clean out bboxes that are close, like what happens with multiple utf-16 tokens per char 291 | bbox_idx = 0 292 | for text_idx, text_line in enumerate(text): 293 | img_chars.append( 294 | TextChar( 295 | text=text_line, 296 | polygon=bboxes[bbox_idx], 297 | confidence=seq_score[bbox_idx], 298 | bbox_valid=True, 299 | ) 300 | ) 301 | 302 | # Ensure we don't exceed the bbox count 303 | # Use the last bbox for the rest of the text 304 | if bbox_idx < len(bboxes) - 1: 305 | bbox_idx += 1 306 | elif token_type == "special": 307 | text = self.processor.ocr_tokenizer.decode( 308 | token_ids, task="ocr_without_boxes" 309 | ) 310 | if text in [NOMATH_TOKEN] or re.match(r"<SCRIPT-\w+>", text): 311 | continue 312 | 313 | img_chars.append( 314 | TextChar( 315 | text=text, 316 | polygon=blank_bbox, 317 | confidence=seq_score[0], 318 | bbox_valid=False, 319 | ) 320 | ) 321 | else: 322 | text = self.processor.ocr_tokenizer.decode( 323 | token_ids, task=TaskNames.block_without_boxes 324 | ) 325 | img_chars.append( 326 | TextChar( 327 | text=text, 328 | polygon=blank_bbox, 329 | confidence=seq_score[0], 330 | bbox_valid=False, 331 | ) 332 | ) 333 | 334 | char_predictions.append(img_chars) 335 | 336 | return char_predictions 337 | 338 | def __call__( 339 | self, 340 | images: List[Image.Image], 341 | task_names: List[str] | None = None, 342 | det_predictor: DetectionPredictor | None = None, 343 | detection_batch_size: int | None = None, 344 | recognition_batch_size: int | None = None, 345 | highres_images: List[Image.Image] | None = None, 346 | bboxes: List[List[List[int]]] | None = None, 347 | polygons: List[List[List[List[int]]]] | None = None, 348 | input_text: List[List[str | None]] | None = None, 349 | sort_lines: bool = False, 350 | math_mode: bool = True, 351 | return_words: bool = False, 352 | drop_repeated_text: bool = False, 353 | max_sliding_window: int | None = None, 354 | max_tokens: int | None = None, 355 | filter_tag_list: List[str] = None 356 | ) -> List[OCRResult]: 357 | if task_names is None: 358 | task_names = [TaskNames.ocr_with_boxes] * len(images) 359 | if recognition_batch_size is None: 360 | recognition_batch_size = self.get_batch_size() 361 | 362 | assert len(images) == len(task_names), ( 363 | "You need to pass in one task name for each image" 364 | ) 365 | 366 | images = convert_if_not_rgb(images) 367 | if 
highres_images is not None: 368 | assert len(images) == len(highres_images), ( 369 | "You need to pass in one highres image for each image" 370 | ) 371 | 372 | highres_images = ( 373 | convert_if_not_rgb(highres_images) 374 | if highres_images is not None 375 | else [None] * len(images) 376 | ) 377 | 378 | if bboxes is None and polygons is None: 379 | assert det_predictor is not None, ( 380 | "You need to pass in a detection predictor if you don't provide bboxes or polygons" 381 | ) 382 | 383 | # Detect then slice 384 | flat = self.detect_and_slice_bboxes( 385 | images, 386 | task_names, 387 | det_predictor, 388 | detection_batch_size=detection_batch_size, 389 | highres_images=highres_images, 390 | ) 391 | else: 392 | if bboxes is not None: 393 | assert len(images) == len(bboxes), ( 394 | "You need to pass in one list of bboxes for each image" 395 | ) 396 | if polygons is not None: 397 | assert len(images) == len(polygons), ( 398 | "You need to pass in one list of polygons for each image" 399 | ) 400 | 401 | flat = self.slice_bboxes( 402 | images, 403 | bboxes=bboxes, 404 | polygons=polygons, 405 | input_text=input_text, 406 | task_names=task_names, 407 | ) 408 | 409 | # No images passed, or no boxes passed, or no text detected in the images 410 | if len(flat["slices"]) == 0: 411 | return [ 412 | OCRResult( 413 | text_lines=[], image_bbox=[0, 0, im.size[0], im.size[1]] 414 | ) 415 | for im in images 416 | ] 417 | 418 | # Sort by image sizes. Negative so that longer images come first, fits in with continuous batching better 419 | sorted_pairs = sorted( 420 | enumerate(flat["slices"]), 421 | key=lambda x: -(x[1].shape[0] * x[1].shape[1]) # height * width 422 | ) 423 | indices, sorted_slices = zip(*sorted_pairs) 424 | 425 | # Reorder input_text and task_names based on the new order 426 | flat["slices"] = list(sorted_slices) 427 | flat["input_text"] = [flat["input_text"][i] for i in indices] 428 | flat["task_names"] = [flat["task_names"][i] for i in indices] 429 | 430 | # Make predictions 431 | predicted_tokens, batch_bboxes, scores, _ = self.foundation_predictor.prediction_loop( 432 | images=flat["slices"], 433 | input_texts=flat["input_text"], 434 | task_names=flat["task_names"], 435 | batch_size=recognition_batch_size, 436 | math_mode=math_mode, 437 | drop_repeated_tokens=True, 438 | max_lookahead_tokens=self.foundation_predictor.model.config.multi_output_distance, 439 | max_sliding_window=max_sliding_window, 440 | max_tokens=max_tokens, 441 | tqdm_desc="Recognizing Text" 442 | ) 443 | 444 | # Get text and bboxes in structured form 445 | bbox_size = self.bbox_size 446 | image_sizes = [img.shape for img in flat["slices"]] 447 | predicted_polygons = prediction_to_polygon_batch( 448 | batch_bboxes, image_sizes, bbox_size, bbox_size // 2 449 | ) 450 | char_predictions = self.get_bboxes_text( 451 | flat, 452 | predicted_tokens, 453 | scores, 454 | predicted_polygons, 455 | drop_repeated_text=drop_repeated_text, 456 | ) 457 | 458 | char_predictions = sorted(zip(indices, char_predictions), key=lambda x: x[0]) 459 | char_predictions = [pred for _, pred in char_predictions] 460 | 461 | predictions_by_image = [] 462 | slice_start = 0 463 | for idx, image in enumerate(images): 464 | slice_end = slice_start + flat["slice_map"][idx] 465 | image_lines = char_predictions[slice_start:slice_end] 466 | polygons = flat["polygons"][slice_start:slice_end] 467 | res_scales = flat["res_scales"][slice_start:slice_end] 468 | slice_start = slice_end 469 | 470 | lines = [] 471 | for text_line, polygon, res_scale in 
zip(image_lines, polygons, res_scales): 472 | # Special case when input text is good 473 | if not text_line: 474 | lines.append( 475 | TextLine( 476 | text="", 477 | polygon=polygon, 478 | chars=[], 479 | confidence=1, 480 | original_text_good=True, 481 | ) 482 | ) 483 | else: 484 | confidence = ( 485 | float(np.mean([char.confidence for char in text_line])) 486 | if len(text_line) > 0 487 | else 0 488 | ) 489 | poly_box = PolygonBox(polygon=polygon) 490 | for char in text_line: 491 | char.rescale( 492 | res_scale, (1, 1) 493 | ) # Rescale from highres if needed 494 | char.shift( 495 | poly_box.bbox[0], poly_box.bbox[1] 496 | ) # Ensure character boxes match line boxes (relative to page) 497 | char.clamp(poly_box.bbox) 498 | 499 | text_line = fix_unbalanced_tags( 500 | text_line, self.processor.ocr_tokenizer.special_tokens 501 | ) 502 | text_line = filter_blacklist_tags(text_line, filter_tag_list) 503 | text = "".join([char.text for char in text_line]) 504 | text = unwrap_math(text) 505 | text = clean_math_tags(text) 506 | lines.append( 507 | TextLine( 508 | text=text, 509 | polygon=polygon, 510 | chars=text_line, 511 | confidence=confidence, 512 | words=words_from_chars(text_line, poly_box) 513 | if return_words 514 | else [], 515 | ) 516 | ) 517 | 518 | if sort_lines: 519 | lines = sort_text_lines(lines) 520 | predictions_by_image.append( 521 | OCRResult( 522 | text_lines=lines, image_bbox=[0, 0, image.size[0], image.size[1]] 523 | ) 524 | ) 525 | 526 | return predictions_by_image 527 | ``` -------------------------------------------------------------------------------- /surya/common/surya/decoder/__init__.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Callable, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from transformers.activations import ACT2FN 7 | from transformers.cache_utils import ( 8 | Cache, 9 | ) 10 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs 11 | from transformers.modeling_outputs import ( 12 | BaseModelOutputWithPast, 13 | ) 14 | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS 15 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS 16 | from transformers.processing_utils import Unpack 17 | from transformers.utils import ( 18 | logging, 19 | ) 20 | 21 | from surya.common.pretrained import SuryaPreTrainedModel 22 | from surya.common.surya.decoder.config import SuryaDecoderConfig 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | class Qwen2MLP(nn.Module): 29 | def __init__(self, config): 30 | super().__init__() 31 | self.config = config 32 | self.hidden_size = config.hidden_size 33 | self.intermediate_size = config.intermediate_size 34 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 35 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 36 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 37 | self.act_fn = ACT2FN[config.hidden_act] 38 | 39 | def forward(self, x): 40 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 41 | return down_proj 42 | 43 | 44 | def rotate_half(x): 45 | """Rotates half the hidden dims of the input.""" 46 | x1 = x[..., : x.shape[-1] // 2] 47 | x2 = x[..., x.shape[-1] // 2 :] 48 | return torch.cat((-x2, x1), dim=-1) 49 | 50 | 51 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 52 | """Applies Rotary Position 
Embedding to the query and key tensors. 53 | 54 | Args: 55 | q (`torch.Tensor`): The query tensor. 56 | k (`torch.Tensor`): The key tensor. 57 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 58 | sin (`torch.Tensor`): The sine part of the rotary embedding. 59 | position_ids (`torch.Tensor`, *optional*): 60 | Deprecated and unused. 61 | unsqueeze_dim (`int`, *optional*, defaults to 1): 62 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 63 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 64 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 65 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 66 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 67 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 68 | Returns: 69 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 70 | """ 71 | cos = cos.unsqueeze(unsqueeze_dim) 72 | sin = sin.unsqueeze(unsqueeze_dim) 73 | q_embed = (q * cos) + (rotate_half(q) * sin) 74 | k_embed = (k * cos) + (rotate_half(k) * sin) 75 | return q_embed, k_embed 76 | 77 | 78 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 79 | """ 80 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 81 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 82 | """ 83 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 84 | if n_rep == 1: 85 | return hidden_states 86 | hidden_states = hidden_states[:, :, None, :, :].expand( 87 | batch, num_key_value_heads, n_rep, slen, head_dim 88 | ) 89 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 90 | 91 | 92 | def eager_attention_forward( 93 | module: nn.Module, 94 | query: torch.Tensor, 95 | key: torch.Tensor, 96 | value: torch.Tensor, 97 | attention_mask: Optional[torch.Tensor], 98 | scaling: float, 99 | dropout: float = 0.0, 100 | **kwargs, 101 | ): 102 | key_states = repeat_kv(key, module.num_key_value_groups) 103 | value_states = repeat_kv(value, module.num_key_value_groups) 104 | 105 | attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling 106 | if attention_mask is not None: 107 | causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] 108 | attn_weights = attn_weights + causal_mask 109 | 110 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( 111 | query.dtype 112 | ) 113 | attn_weights = nn.functional.dropout( 114 | attn_weights, p=dropout, training=module.training 115 | ) 116 | attn_output = torch.matmul(attn_weights, value_states) 117 | attn_output = attn_output.transpose(1, 2).contiguous() 118 | 119 | return attn_output, attn_weights 120 | 121 | 122 | class Qwen2Attention(nn.Module): 123 | """Multi-headed attention from 'Attention Is All You Need' paper""" 124 | 125 | def __init__(self, config: SuryaDecoderConfig, layer_idx: int): 126 | super().__init__() 127 | self.config = config 128 | self.layer_idx = layer_idx 129 | self.head_dim = getattr( 130 | config, "head_dim", config.hidden_size // config.num_attention_heads 131 | ) 132 | self.num_key_value_groups = ( 133 | config.num_attention_heads // config.num_key_value_heads 134 | ) 135 | 
self.scaling = self.head_dim**-0.5 136 | self.attention_dropout = config.attention_dropout 137 | self.is_causal = True 138 | self.q_proj = nn.Linear( 139 | config.hidden_size, config.num_attention_heads * self.head_dim, bias=True 140 | ) 141 | self.k_proj = nn.Linear( 142 | config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True 143 | ) 144 | self.v_proj = nn.Linear( 145 | config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True 146 | ) 147 | self.o_proj = nn.Linear( 148 | config.num_attention_heads * self.head_dim, config.hidden_size, bias=False 149 | ) 150 | 151 | def forward( 152 | self, 153 | hidden_states: torch.Tensor, 154 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 155 | attention_mask: Optional[torch.Tensor], 156 | past_key_value: Optional[Cache] = None, 157 | cache_position: Optional[torch.LongTensor] = None, 158 | cache_idxs: Optional[List[int]] = None, 159 | num_valid_tokens: Optional[List[int]] = None, 160 | text_lengths: Optional[List[int]] = None, 161 | prefill: bool = False, 162 | **kwargs: Unpack[FlashAttentionKwargs], 163 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 164 | input_shape = hidden_states.shape[:-1] 165 | hidden_shape = (*input_shape, -1, self.head_dim) 166 | 167 | query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) 168 | key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) 169 | value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) 170 | 171 | cos, sin = position_embeddings 172 | query_states, key_states = apply_rotary_pos_emb( 173 | query_states, key_states, cos, sin 174 | ) 175 | 176 | if past_key_value is not None: 177 | # sin and cos are specific to RoPE models; cache_position needed for the static cache 178 | # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism 179 | cache_kwargs = { 180 | "sin": sin, 181 | "cos": cos, 182 | "cache_position": cache_position, 183 | "cache_idxs": cache_idxs, 184 | "num_valid_tokens": num_valid_tokens, 185 | "prefill": prefill, 186 | "text_lengths": text_lengths, 187 | } 188 | key_states, value_states = past_key_value.update( 189 | key_states, value_states, self.layer_idx, cache_kwargs 190 | ) 191 | 192 | attention_interface: Callable = eager_attention_forward 193 | if self.config._attn_implementation != "eager": 194 | if self.config._attn_implementation == "sdpa" and kwargs.get( 195 | "output_attentions", False 196 | ): 197 | logger.warning_once( 198 | "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 199 | 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 200 | ) 201 | elif self.config._attn_implementation == "flash_attention_2": 202 | # Needed for CPU -> GPU 203 | from surya.common.surya.flash_attn_utils import ( 204 | flash_attn_decode, 205 | flash_attn_prefill, 206 | ) 207 | 208 | if prefill: 209 | attention_interface = flash_attn_prefill 210 | else: 211 | attention_interface = flash_attn_decode 212 | else: 213 | attention_interface = ALL_ATTENTION_FUNCTIONS[ 214 | self.config._attn_implementation 215 | ] 216 | 217 | """ 218 | IMPORTANT: 219 | We sometimes use a custom sliding window impl. 
during training 220 | 221 | We force this to None to ensure that the HF attention integrations do not 222 | perform any special handling - FA2 in particular will ignore the 4D mask, and use this instead 223 | to infer the final mask 224 | 225 | SDPA ignores this completely, and is fully dependent on the 4D mask - (https://github.com/huggingface/transformers/blob/b9faf2f93085e3cf2c65184a69d1d9e502f95786/src/transformers/integrations/sdpa_attention.py#L23) 226 | """ 227 | sliding_window = None 228 | 229 | attn_output, attn_weights = attention_interface( 230 | self, 231 | query_states, 232 | key_states, 233 | value_states, 234 | attention_mask, 235 | dropout=0.0 if not self.training else self.attention_dropout, 236 | scaling=self.scaling, 237 | sliding_window=sliding_window, # main diff with Llama 238 | **kwargs, 239 | ) 240 | 241 | attn_output = attn_output.reshape(*input_shape, -1).contiguous() 242 | attn_output = self.o_proj(attn_output) 243 | return attn_output, attn_weights 244 | 245 | 246 | class Qwen2RMSNorm(nn.Module): 247 | def __init__(self, hidden_size, eps=1e-6): 248 | """ 249 | Qwen2RMSNorm is equivalent to T5LayerNorm 250 | """ 251 | super().__init__() 252 | self.weight = nn.Parameter(torch.ones(hidden_size)) 253 | self.variance_epsilon = eps 254 | 255 | def forward(self, hidden_states): 256 | input_dtype = hidden_states.dtype 257 | hidden_states = hidden_states.to(torch.float32) 258 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 259 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 260 | return self.weight * hidden_states.to(input_dtype) 261 | 262 | def extra_repr(self): 263 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" 264 | 265 | 266 | class Qwen2DecoderLayer(nn.Module): 267 | def __init__(self, config: SuryaDecoderConfig, layer_idx: int): 268 | super().__init__() 269 | self.hidden_size = config.hidden_size 270 | self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) 271 | self.mlp = Qwen2MLP(config) 272 | self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 273 | self.post_attention_layernorm = Qwen2RMSNorm( 274 | config.hidden_size, eps=config.rms_norm_eps 275 | ) 276 | 277 | def forward( 278 | self, 279 | hidden_states: torch.Tensor, 280 | attention_mask: Optional[torch.Tensor] = None, 281 | position_ids: Optional[torch.LongTensor] = None, 282 | past_key_value: Optional[Cache] = None, 283 | output_attentions: Optional[bool] = False, 284 | use_cache: Optional[bool] = False, 285 | cache_position: Optional[torch.LongTensor] = None, 286 | cache_idxs: Optional[List[int]] = None, 287 | num_valid_tokens: Optional[List[int]] = None, 288 | text_lengths: Optional[List[int]] = None, 289 | prefill: bool = False, 290 | position_embeddings: Optional[ 291 | Tuple[torch.Tensor, torch.Tensor] 292 | ] = None, # necessary, but kept here for BC 293 | **kwargs: Unpack[FlashAttentionKwargs], 294 | ) -> Tuple[ 295 | torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] 296 | ]: 297 | residual = hidden_states 298 | 299 | hidden_states = self.input_layernorm(hidden_states) 300 | 301 | # Self Attention 302 | hidden_states, self_attn_weights = self.self_attn( 303 | hidden_states=hidden_states, 304 | attention_mask=attention_mask, 305 | position_ids=position_ids, 306 | past_key_value=past_key_value, 307 | output_attentions=output_attentions, 308 | use_cache=use_cache, 309 | cache_position=cache_position, 310 | position_embeddings=position_embeddings, 311 | cache_idxs=cache_idxs, 
312 | num_valid_tokens=num_valid_tokens, 313 | text_lengths=text_lengths, 314 | prefill=prefill, 315 | **kwargs, 316 | ) 317 | hidden_states = residual + hidden_states 318 | 319 | # Fully Connected 320 | residual = hidden_states 321 | hidden_states = self.post_attention_layernorm(hidden_states) 322 | hidden_states = self.mlp(hidden_states) 323 | hidden_states = residual + hidden_states 324 | 325 | outputs = (hidden_states,) 326 | if output_attentions: 327 | outputs += (self_attn_weights,) 328 | 329 | return outputs 330 | 331 | 332 | class Qwen2RotaryEmbedding(nn.Module): 333 | def __init__(self, config: SuryaDecoderConfig, device=None): 334 | super().__init__() 335 | # BC: "rope_type" was originally "type" 336 | if hasattr(config, "rope_scaling") and config.rope_scaling is not None: 337 | self.rope_type = config.rope_scaling.get( 338 | "rope_type", config.rope_scaling.get("type") 339 | ) 340 | else: 341 | self.rope_type = "default" 342 | self.max_seq_len_cached = config.max_position_embeddings 343 | self.original_max_seq_len = config.max_position_embeddings 344 | 345 | self.config = config 346 | self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] 347 | 348 | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) 349 | self.register_buffer("inv_freq", inv_freq, persistent=False) 350 | self.original_inv_freq = self.inv_freq 351 | 352 | def _dynamic_frequency_update(self, position_ids, device): 353 | """ 354 | dynamic RoPE layers should recompute `inv_freq` in the following situations: 355 | 1 - growing beyond the cached sequence length (allow scaling) 356 | 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) 357 | """ 358 | seq_len = torch.max(position_ids) + 1 359 | if seq_len > self.max_seq_len_cached: # growth 360 | inv_freq, self.attention_scaling = self.rope_init_fn( 361 | self.config, device, seq_len=seq_len 362 | ) 363 | self.register_buffer( 364 | "inv_freq", inv_freq, persistent=False 365 | ) # TODO joao: may break with compilation 366 | self.max_seq_len_cached = seq_len 367 | 368 | if ( 369 | seq_len < self.original_max_seq_len 370 | and self.max_seq_len_cached > self.original_max_seq_len 371 | ): # reset 372 | # This .to() is needed if the model has been moved to a device after being initialized (because 373 | # the buffer is automatically moved, but not the original copy) 374 | self.original_inv_freq = self.original_inv_freq.to(device) 375 | self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) 376 | self.max_seq_len_cached = self.original_max_seq_len 377 | 378 | @torch.no_grad() 379 | def forward(self, x, position_ids): 380 | if "dynamic" in self.rope_type: 381 | self._dynamic_frequency_update(position_ids, device=x.device) 382 | 383 | # Core RoPE block 384 | inv_freq_expanded = ( 385 | self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 386 | ) 387 | position_ids_expanded = position_ids[:, None, :].float() 388 | # Force float32 (see https://github.com/huggingface/transformers/pull/29285) 389 | device_type = x.device.type 390 | device_type = ( 391 | device_type 392 | if isinstance(device_type, str) and device_type != "mps" 393 | else "cpu" 394 | ) 395 | with torch.autocast(device_type=device_type, enabled=False): 396 | freqs = ( 397 | inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float() 398 | ).transpose(1, 2) 399 | emb = torch.cat((freqs, freqs), dim=-1) 400 | cos = emb.cos() 401 | sin = emb.sin() 402 | 403 | # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention 404 | cos = cos * self.attention_scaling 405 | sin = sin * self.attention_scaling 406 | 407 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 408 | 409 | 410 | class Qwen2PreTrainedModel(SuryaPreTrainedModel): 411 | config_class = SuryaDecoderConfig 412 | base_model_prefix = "model" 413 | supports_gradient_checkpointing = True 414 | _no_split_modules = ["Qwen2DecoderLayer"] 415 | _skip_keys_device_placement = ["past_key_values"] 416 | _supports_flash_attn_2 = True 417 | _supports_sdpa = True 418 | _supports_flex_attn = True 419 | _supports_cache_class = True 420 | _supports_quantized_cache = True 421 | _supports_static_cache = True 422 | _supports_attention_backend = True 423 | 424 | def _init_weights(self, module): 425 | std = self.config.initializer_range 426 | if isinstance(module, nn.Linear): 427 | module.weight.data.normal_(mean=0.0, std=std) 428 | if module.bias is not None: 429 | module.bias.data.zero_() 430 | elif isinstance(module, nn.Embedding): 431 | module.weight.data.normal_(mean=0.0, std=std) 432 | if module.padding_idx is not None: 433 | module.weight.data[module.padding_idx].zero_() 434 | 435 | 436 | class SuryaDecoderModel(Qwen2PreTrainedModel): 437 | """ 438 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] 439 | This variant has been modified to remove the embedding layer completely - It only supports inputs_embeds as an input 440 | 441 | Args: 442 | config: Qwen2Config 443 | """ 444 | 445 | def __init__(self, config: SuryaDecoderConfig): 446 | super().__init__(config) 447 | self.padding_idx = config.pad_token_id 448 | self.vocab_size = config.vocab_size 449 | 450 | self.layers = nn.ModuleList( 451 | [ 452 | Qwen2DecoderLayer(config, layer_idx) 453 | for layer_idx in range(config.num_hidden_layers) 454 | ] 455 | ) 456 | self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 457 | self.rotary_emb = Qwen2RotaryEmbedding(config=config) 458 | self.gradient_checkpointing = False 459 | 460 | # Initialize weights and apply final processing 461 | self.post_init() 462 | 463 | def forward( 464 | self, 465 | attention_mask: Optional[torch.Tensor] = None, 466 | position_ids: Optional[torch.LongTensor] = None, 467 | past_key_values: Optional[Cache] = None, 468 | inputs_embeds: Optional[torch.FloatTensor] = None, 469 | use_cache: Optional[bool] = None, 470 | output_attentions: Optional[bool] = None, 471 | output_hidden_states: Optional[bool] = None, 472 | return_dict: Optional[bool] = None, 473 | cache_position: Optional[torch.LongTensor] = None, 474 | cache_idxs: Optional[List[int]] = None, 475 | num_valid_tokens: Optional[List[int]] = None, 476 | text_lengths: Optional[List[int]] = None, 477 | prefill: bool = False, 478 | **flash_attn_kwargs: Unpack[FlashAttentionKwargs], 479 | ) -> Union[Tuple, BaseModelOutputWithPast]: 480 | use_cache = use_cache if use_cache is not None else self.config.use_cache 481 | return_dict = ( 482 | return_dict if return_dict is not None else self.config.use_return_dict 483 | ) 484 | 485 | if inputs_embeds is None: 486 | raise ValueError("You must specify inputs_embeds") 487 | 488 | if cache_position is None: 489 | raise ValueError("You must specify cache_position") 490 | 491 | if position_ids is None: 492 | raise ValueError("You must specify position_ids") 493 | 494 | hidden_states = inputs_embeds 495 | causal_mask = ( 496 | attention_mask # We make the 4D mask in the combined model when needed 497 | 
) 498 | 499 | # create position embeddings to be shared across the decoder layers 500 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 501 | 502 | # decoder layers 503 | for decoder_layer in self.layers[: self.config.num_hidden_layers]: 504 | layer_outputs = decoder_layer( 505 | hidden_states, 506 | attention_mask=causal_mask, 507 | position_ids=position_ids, 508 | past_key_value=past_key_values, 509 | output_attentions=output_attentions, 510 | use_cache=use_cache, 511 | cache_position=cache_position, 512 | position_embeddings=position_embeddings, 513 | cache_idxs=cache_idxs, 514 | num_valid_tokens=num_valid_tokens, 515 | prefill=prefill, 516 | text_lengths=text_lengths, 517 | **flash_attn_kwargs, 518 | ) 519 | 520 | hidden_states = layer_outputs[0] 521 | 522 | hidden_states = self.norm(hidden_states) 523 | 524 | output = BaseModelOutputWithPast( 525 | last_hidden_state=hidden_states, 526 | past_key_values=past_key_values if use_cache else None, 527 | ) 528 | return output if return_dict else output.to_tuple() 529 | ```
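The `rotate_half` / `apply_rotary_pos_emb` pair above implements standard rotary position embeddings (RoPE): `Qwen2RotaryEmbedding` builds per-position cos/sin tables from inverse frequencies, and each query/key vector is rotated in the 2D planes formed by dimension `i` and dimension `i + head_dim/2`. The sketch below reproduces that math on toy tensors; the head count, sequence length, and the base of 10000 are illustrative assumptions, not values read from `SuryaDecoderConfig`.

```python
# Minimal RoPE sketch mirroring rotate_half / apply_rotary_pos_emb above.
# All shapes and the 10000 base are illustrative assumptions, not SuryaDecoderConfig values.
import torch

batch, n_heads, seq_len, head_dim = 1, 2, 4, 8

# Inverse frequencies as in the default rope init, then per-position angles.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
positions = torch.arange(seq_len).float()[None, :]                             # (1, seq_len)
freqs = positions[:, :, None] * inv_freq[None, None, :]                        # (1, seq_len, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                                        # duplicate to head_dim
cos, sin = emb.cos(), emb.sin()

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_heads, seq_len, head_dim)

# unsqueeze_dim=1 broadcasts the (1, seq_len, head_dim) tables over the head dimension.
q_rot = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
k_rot = k * cos.unsqueeze(1) + rotate_half(k) * sin.unsqueeze(1)

# Because each (i, i + head_dim/2) pair is rotated by the same angle, norms are preserved,
# which is what makes the attention logits depend only on relative positions.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
```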
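`repeat_kv` exists because the decoder uses grouped-query attention: `num_key_value_heads` can be smaller than `num_attention_heads`, so cached keys and values must be tiled `num_key_value_groups` times before the plain matmul in `eager_attention_forward`. A shape-only sketch with made-up head counts:

```python
# Shape sketch for repeat_kv / grouped-query attention. Head counts are made up for illustration.
import torch

batch, seq_len, head_dim = 1, 5, 16
num_attention_heads, num_key_value_heads = 8, 2
n_rep = num_attention_heads // num_key_value_heads      # num_key_value_groups (4 here)

key = torch.randn(batch, num_key_value_heads, seq_len, head_dim)

# Same expand + reshape trick as repeat_kv above: each KV head is shared by n_rep query heads.
expanded = key[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
key_full = expanded.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

assert key_full.shape == (batch, num_attention_heads, seq_len, head_dim)
# Equivalent to torch.repeat_interleave(key, n_rep, dim=1); expand is a view,
# so nothing is materialized until the reshape.
assert torch.equal(key_full, torch.repeat_interleave(key, n_rep, dim=1))
```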
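`eager_attention_forward` is the fallback attention path: a scaled dot-product with an additive mask, softmax in float32, then a value matmul. With a float causal mask it should agree numerically with `torch.nn.functional.scaled_dot_product_attention`, which is what the `sdpa` branch dispatches to. A toy check (shapes here are arbitrary, and the grouped-KV expansion is omitted for brevity):

```python
# Toy check that the eager attention math matches PyTorch SDPA with an additive causal mask.
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 1, 2, 3, 4
scaling = head_dim ** -0.5   # same scaling as Qwen2Attention (head_dim**-0.5)

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Additive causal mask, broadcastable to (batch, heads, q_len, kv_len):
# 0 where attention is allowed, -inf strictly above the diagonal.
mask = torch.full((seq_len, seq_len), float("-inf")).triu(1)[None, None]

attn_weights = torch.softmax(q @ k.transpose(2, 3) * scaling + mask, dim=-1)
out_eager = attn_weights @ v

out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(out_eager, out_sdpa, atol=1e-5)
```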
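`Qwen2RMSNorm` normalizes by the root-mean-square of each hidden vector rather than subtracting a mean: `y = weight * x / sqrt(mean(x**2) + eps)`, with the statistics computed in float32. A quick numeric check of that formula; the hidden size below is arbitrary and eps uses the class default of 1e-6:

```python
# Numeric check of the RMSNorm formula used by Qwen2RMSNorm above.
# Hidden size is arbitrary; eps uses the class default of 1e-6.
import torch

hidden_size, eps = 6, 1e-6
weight = torch.ones(hidden_size)                     # same initialization as Qwen2RMSNorm
x = torch.randn(2, 3, hidden_size)

variance = x.float().pow(2).mean(-1, keepdim=True)   # mean of squares, not a centered variance
y = weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)

# Each hidden vector now has root-mean-square ~1; weight rescales it per channel.
rms = y.pow(2).mean(-1).sqrt()
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
```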