# datalab-to/surya
This is page 3 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/surya/common/surya/flash_attn_utils.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Optional
  2 | import torch
  3 | import torch.nn.functional as F
  4 | from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
  5 | from flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
  6 | from flash_attn.bert_padding import index_first_axis as _index_first_axis
  7 | from flash_attn.bert_padding import pad_input
  8 | 
  9 | def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
 10 |     """
 11 |     Retrieves indexing data required to repad unpadded (ragged) tensors.
 12 | 
 13 |     Arguments:
 14 |         attention_mask (`torch.Tensor`):
 15 |             Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
 16 | 
 17 |     Return:
 18 |         indices (`torch.Tensor`):
 19 |             The indices of non-masked tokens from the flattened input sequence.
 20 |         cu_seqlens (`torch.Tensor`):
 21 |             The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
 22 |         max_seqlen_in_batch (`int`):
 23 |             Maximum sequence length in batch.
 24 |     """
 25 |     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 26 |     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
 27 |     max_seqlen_in_batch = seqlens_in_batch.max().item()
 28 |     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
 29 |     return (
 30 |         indices,
 31 |         cu_seqlens,
 32 |         max_seqlen_in_batch,
 33 |     )
 34 | 
 35 | def _upad_input(
 36 |     query_layer: torch.Tensor,
 37 |     key_layer: torch.Tensor,
 38 |     value_layer: torch.Tensor,
 39 |     query_length: int,
 40 |     indices_k,
 41 |     cu_seqlens_k,
 42 |     max_seqlen_in_batch_k
 43 | ):
 44 |     """
 45 |     Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
 46 | 
 47 |     This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
 48 |     tensors for query, key, value tensors.
 49 | 
 50 |     Arguments:
 51 |         query_layer (`torch.Tensor`):
 52 |             Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
 53 |         key_layer (`torch.Tensor`):
 54 |             Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
 55 |         value_layer (`torch.Tensor`):
 56 |             Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
 57 |         query_length (`int`):
 58 |             Target length.
 59 |         indices_k (`torch.Tensor`), cu_seqlens_k (`torch.Tensor`), max_seqlen_in_batch_k (`int`):
 60 |             Unpadding metadata for the key/value sequence, as produced by `_get_unpad_data` from the attention mask.
 61 | 
 62 |     Return:
 63 |         query_layer (`torch.Tensor`):
 64 |             Query state without padding. Shape: (total_target_length, num_heads, head_dim).
 65 |         key_layer (`torch.Tensor`):
 66 |             Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
 67 |         value_layer (`torch.Tensor`):
 68 |             Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
 69 |         indices_q (`torch.Tensor`):
 70 |             The indices of non-masked tokens from the flattened input target sequence.
 71 |         (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):
 72 |             The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
 73 |         (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
 74 |             Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
 75 |     """
 76 |     batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 77 | 
 78 |     key_layer = _index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
 79 |     value_layer = _index_first_axis(
 80 |         value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 81 |     )
 82 |     if query_length == kv_seq_len:
 83 |         query_layer = _index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
 84 |         cu_seqlens_q = cu_seqlens_k
 85 |         max_seqlen_in_batch_q = max_seqlen_in_batch_k
 86 |         indices_q = indices_k
 87 |     elif query_length == 1:
 88 |         max_seqlen_in_batch_q = 1
 89 |         cu_seqlens_q = torch.arange(
 90 |             batch_size + 1, dtype=torch.int32, device=query_layer.device
 91 |         )  # There is a memcpy here, that is very bad.
 92 |         indices_q = cu_seqlens_q[:-1]
 93 |         query_layer = query_layer.squeeze(1)
 94 |     else:
 95 |         raise NotImplementedError()
 96 | 
 97 |     return (
 98 |         query_layer,
 99 |         key_layer,
100 |         value_layer,
101 |         indices_q,
102 |         (cu_seqlens_q, cu_seqlens_k),
103 |         (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
104 |     )
105 | 
106 | def flash_attn_prefill(
107 |     module: torch.nn.Module,
108 |     query_states: torch.Tensor,
109 |     key_states: torch.Tensor,
110 |     value_states: torch.Tensor,
111 |     attention_mask: torch.Tensor,
112 |     dropout: float,
113 |     scaling: float,
114 |     query_length: int,
115 |     batch_size: int,
116 |     indices_k: torch.Tensor,
117 |     cu_seqlens_k: torch.Tensor,
118 |     max_seqlen_in_batch_k: int,
119 |     **kwargs
120 | ):
121 |     """
122 |     Wrapper for flash attention during the prefill stage
123 |     query_states must have shape (batch_size, num_heads, seq_len, head_dim)
124 |     key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)
125 | 
126 |     This is the opposite of what is required by flash attention, but keeps parity with the HF convention
127 | 
128 |     query_length, batch_size, indices_k, cu_seqlens_k, and max_seqlen_in_batch_k should come from the flash attention kwargs
129 |     """
130 |     query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2)
131 |     q_flash, k_flash, v_flash, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
132 |         query_states, key_states, value_states, query_length, indices_k, cu_seqlens_k, max_seqlen_in_batch_k
133 |     )
134 |     cu_seqlens_q, cu_seqlens_k = cu_seq_lens
135 |     max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
136 | 
137 |     # Returning None for attn_weights to match other attention interfaces
138 |     flash_attn_out = _flash_attn_varlen_func(
139 |         q_flash,
140 |         k_flash,
141 |         v_flash,
142 |         cu_seqlens_q=cu_seqlens_q,
143 |         cu_seqlens_k=cu_seqlens_k,
144 |         max_seqlen_q=max_seqlen_in_batch_q,
145 |         max_seqlen_k=max_seqlen_in_batch_k,
146 |         dropout_p=dropout,
147 |         softmax_scale=scaling,
148 |         causal=module.is_causal,
149 |     )
150 |     return pad_input(flash_attn_out, indices_q, batch_size, query_length), None
151 | 
152 | # NOTE: Does not support dropout; the dropout argument is only accepted via kwargs for interface compatibility
153 | # This function is an order of magnitude faster than the prefill variant or the HF attention interface
154 | def flash_attn_decode(
155 |     module: torch.nn.Module,
156 |     query_states: torch.Tensor,
157 |     key_states: torch.Tensor,
158 |     value_states: torch.Tensor,
159 |     attention_mask: torch.Tensor,
160 |     scaling: float,
161 |     **kwargs,
162 | ):
163 |     """
164 |     Wrapper for flash attention during the decode stage
165 |     
166 |     query_states must have shape (batch_size, num_heads, seq_len, head_dim), 1 is the seq length in the decoding stage
167 |     key_states and value_states must have shape (batch_size, num_kv_heads, kv_len, head_dim)
168 | 
169 |     This is the opposite of what is required by flash attention, but keeps parity with the HF convention
170 | 
171 |     This function computes the left pad and cache seqlens to pass into FA2. For example - 
172 |     Given an attention_mask shaped (batch_size=2, seq_len=8), where 0 = padding, 1 = real token
173 |     attention_mask =
174 |     tensor([
175 |         [0, 0, 1, 1, 1, 0, 0, 0],  # ← batch 0
176 |         [0, 1, 1, 1, 1, 1, 1, 0],  # ← batch 1
177 |     ])
178 |     cache_leftpad = tensor([2, 1], dtype=torch.int32)
179 |     cache_seqlens = tensor([5, 7], dtype=torch.int32)
180 |     These values allow FlashAttention to use a static cache layout with efficient slicing during decoding.
181 |     """
182 |     query_states, key_states, value_states = query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2)
183 | 
184 |     cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1).to(torch.int32)
185 |     cache_seqlens = (attention_mask * torch.arange(attention_mask.size(1), device=attention_mask.device)).argmax(dim=1).to(torch.int32) + 1
186 | 
187 |     # Returning None for attn_weights to match other attention interfaces
188 |     return _flash_attn_with_kvcache(
189 |         q=query_states,
190 |         k_cache=key_states,
191 |         v_cache=value_states,
192 |         cache_leftpad=cache_leftpad,
193 |         cache_seqlens=cache_seqlens,
194 |         causal=module.is_causal,
195 |         softmax_scale=scaling,
196 |     ), None
```
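
For reference, a minimal CPU-only sketch (plain `torch`, no `flash_attn` install required) that reproduces the padding metadata these wrappers derive from an attention mask, using the toy mask from the `flash_attn_decode` docstring:

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 0, 0, 0],  # batch 0: left pad of 2, last real token at index 4
    [0, 1, 1, 1, 1, 1, 1, 0],  # batch 1: left pad of 1, last real token at index 6
])

# Decode path: number of leading pad tokens per row. The cumprod of the
# "is padding" indicator stays 1 only while we are still in the left-pad region.
cache_leftpad = (attention_mask == 0).cumprod(dim=1).sum(dim=1).to(torch.int32)

# Decode path: index of the last real token per row, plus one.
positions = torch.arange(attention_mask.size(1))
cache_seqlens = (attention_mask * positions).argmax(dim=1).to(torch.int32) + 1

print(cache_leftpad)  # tensor([2, 1], dtype=torch.int32)
print(cache_seqlens)  # tensor([5, 7], dtype=torch.int32)

# Prefill path: _get_unpad_data-style cumulative sequence lengths for the
# flattened (ragged) batch.
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 9], dtype=torch.int32)
```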

--------------------------------------------------------------------------------
/surya/common/util.py:
--------------------------------------------------------------------------------

```python
  1 | import copy
  2 | from typing import List
  3 | import torch
  4 | from functools import lru_cache
  5 | 
  6 | import torch.nn.functional as F
  7 | 
  8 | from surya.common.polygon import PolygonBox
  9 | 
 10 | 
 11 | def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
 12 |     new_boxes = []
 13 |     for box_obj in boxes:
 14 |         xs = [point[0] for point in box_obj.polygon]
 15 |         ys = [point[1] for point in box_obj.polygon]
 16 |         if max(xs) == min(xs) or max(ys) == min(ys):
 17 |             continue
 18 | 
 19 |         box = box_obj.bbox
 20 |         contained = False
 21 |         for other_box_obj in boxes:
 22 |             if other_box_obj.polygon == box_obj.polygon:
 23 |                 continue
 24 | 
 25 |             other_box = other_box_obj.bbox
 26 |             if box == other_box:
 27 |                 continue
 28 |             if (
 29 |                 box[0] >= other_box[0]
 30 |                 and box[1] >= other_box[1]
 31 |                 and box[2] <= other_box[2]
 32 |                 and box[3] <= other_box[3]
 33 |             ):
 34 |                 contained = True
 35 |                 break
 36 |         if not contained:
 37 |             new_boxes.append(box_obj)
 38 |     return new_boxes
 39 | 
 40 | 
 41 | def rescale_bbox(bbox, processor_size, image_size):
 42 |     page_width, page_height = processor_size
 43 | 
 44 |     img_width, img_height = image_size
 45 |     width_scaler = img_width / page_width
 46 |     height_scaler = img_height / page_height
 47 | 
 48 |     new_bbox = copy.deepcopy(bbox)
 49 |     new_bbox[0] = int(new_bbox[0] * width_scaler)
 50 |     new_bbox[1] = int(new_bbox[1] * height_scaler)
 51 |     new_bbox[2] = int(new_bbox[2] * width_scaler)
 52 |     new_bbox[3] = int(new_bbox[3] * height_scaler)
 53 |     return new_bbox
 54 | 
 55 | 
 56 | def expand_bbox(bbox, expansion_factor=0.01):
 57 |     expansion_low = 1 - expansion_factor
 58 |     expansion_high = 1 + expansion_factor
 59 |     return [
 60 |         bbox[0] * expansion_low,
 61 |         bbox[1] * expansion_low,
 62 |         bbox[2] * expansion_high,
 63 |         bbox[3] * expansion_high,
 64 |     ]
 65 | 
 66 | SCRIPT_TOKEN_MAPPING = {
 67 |     "latin": "<SCRIPT-LATIN>",
 68 |     "punctuation": "<SCRIPT-PUNCTUATION>",
 69 |     "cyrillic": "<SCRIPT-CYRILLIC>",
 70 |     "arabic": "<SCRIPT-ARABIC>",
 71 |     "chinese": "<SCRIPT-CHINESE>",
 72 |     "japanese": "<SCRIPT-JAPANESE>",
 73 |     "korean": "<SCRIPT-KOREAN>",
 74 |     "symbols": "<SCRIPT-SYMBOLS>",
 75 |     "greek": "<SCRIPT-GREEK>",
 76 |     "armenian": "<SCRIPT-ARMENIAN>",
 77 |     "hebrew": "<SCRIPT-HEBREW>",
 78 |     "devanagari": "<SCRIPT-DEVANAGARI>",
 79 |     "bengali": "<SCRIPT-BENGALI>",
 80 |     "gurmukhi": "<SCRIPT-GURMUKHI>",
 81 |     "gujarati": "<SCRIPT-GUJARATI>",
 82 |     "oriya": "<SCRIPT-ORIYA>",
 83 |     "tamil": "<SCRIPT-TAMIL>",
 84 |     "telugu": "<SCRIPT-TELUGU>",
 85 |     "kannada": "<SCRIPT-KANNADA>",
 86 |     "malayalam": "<SCRIPT-MALAYALAM>",
 87 |     "sinhala": "<SCRIPT-SINHALA>",
 88 |     "thai": "<SCRIPT-THAI>",
 89 |     "lao": "<SCRIPT-LAO>",
 90 |     "myanmar": "<SCRIPT-MYANMAR>",
 91 |     "georgian": "<SCRIPT-GEORGIAN>",
 92 |     "ethiopic": "<SCRIPT-ETHIOPIC>",
 93 |     "khmer": "<SCRIPT-KHMER>",
 94 |     "mongolian": "<SCRIPT-MONGOLIAN>",
 95 |     "math": "<SCRIPT-MATH>",
 96 | }
 97 | 
 98 | @lru_cache(maxsize=1)
 99 | def script_ranges():
100 |     script_categories = {
101 |         # Latin-based scripts (used by English, French, German, etc.)
102 |         "latin": [
103 |             (0x0041, 0x005A),  # Latin uppercase A-Z
104 |             (0x0061, 0x007A),  # Latin lowercase a-z
105 |             (0x0080, 0x00FF),  # Latin-1 Supplement
106 |             (0x0100, 0x017F),  # Latin Extended-A
107 |             (0x0180, 0x024F),  # Latin Extended-B
108 |             (0x0250, 0x02AF),  # IPA Extensions
109 |             (0x02B0, 0x02FF),  # Spacing Modifier Letters
110 |             (0x0300, 0x036F),  # Combining Diacritical Marks
111 |             (0x1E00, 0x1EFF),  # Latin Extended Additional
112 |             (0x2C60, 0x2C7F),  # Latin Extended-C
113 |             (0xA720, 0xA7FF),  # Latin Extended-D
114 |         ],
115 |         # Punctuation, universal characters, and general symbols
116 |         "punctuation": [
117 |             (0x0020, 0x0020),  # Space
118 |             (0x0021, 0x002F),  # Basic punctuation and symbols
119 |             (0x0030, 0x0039),  # Digits 0-9
120 |             (0x003A, 0x0040),  # More punctuation and symbols
121 |             (0x005B, 0x0060),  # More punctuation and symbols
122 |             (0x007B, 0x007F),  # More punctuation and symbols
123 |             (0x2000, 0x206F),  # General Punctuation
124 |         ],
125 |         # Cyrillic scripts (used by Russian, Ukrainian, etc.)
126 |         "cyrillic": [
127 |             (0x0400, 0x04FF),  # Cyrillic
128 |             (0x0500, 0x052F),  # Cyrillic Supplement
129 |         ],
130 |         # Arabic scripts
131 |         "arabic": [
132 |             (0x0600, 0x06FF),  # Arabic
133 |             (0x0750, 0x077F),  # Arabic Supplement
134 |             (0x08A0, 0x08FF),  # Arabic Extended-A
135 |         ],
136 |         # Chinese characters
137 |         "chinese": [
138 |             (0x4E00, 0x9FFF),  # Common CJK Unified Ideographs
139 |             (0x3400, 0x4DBF),  # CJK Extension A
140 |             (0x20000, 0x2A6DF),  # CJK Extension B
141 |         ],
142 |         # Japanese-specific scripts (excluding shared CJK)
143 |         "japanese": [
144 |             (0x3040, 0x30FF),  # Hiragana and Katakana
145 |         ],
146 |         # Korean-specific scripts
147 |         "korean": [
148 |             (0x1100, 0x11FF),  # Hangul Jamo
149 |             (0x3130, 0x318F),  # Hangul Compatibility Jamo
150 |             (0xAC00, 0xD7AF),  # Hangul Syllables
151 |         ],
152 |         # Various mathematical and technical symbols
153 |         "symbols": [
154 |             (0x2070, 0x209F),  # Superscripts and Subscripts
155 |             (0x20A0, 0x20CF),  # Currency Symbols
156 |             (0x2100, 0x214F),  # Letterlike Symbols
157 |             (0x2150, 0x218F),  # Number Forms
158 |             (0x2190, 0x21FF),  # Arrows
159 |             (0x2200, 0x22FF),  # Mathematical Operators
160 |             (0x2300, 0x23FF),  # Miscellaneous Technical
161 |             (0x2500, 0x257F),  # Box Drawing
162 |             (0x2580, 0x259F),  # Block Elements
163 |             (0x25A0, 0x25FF),  # Geometric Shapes
164 |             (0x2600, 0x26FF),  # Miscellaneous Symbols
165 |             (0x2700, 0x27BF),  # Dingbats
166 |             (0x27C0, 0x27EF),  # Miscellaneous Mathematical Symbols-A
167 |             (0x2980, 0x29FF),  # Miscellaneous Mathematical Symbols-B
168 |             (0x2A00, 0x2AFF),  # Supplemental Mathematical Operators
169 |             (0x1D400, 0x1D7FF),  # Mathematical Alphanumeric Symbols
170 |         ],
171 |         # Individual scripts for languages with unique writing systems
172 |         "greek": [(0x0370, 0x03FF)],  # Greek and Coptic
173 |         "armenian": [(0x0530, 0x058F)],  # Armenian
174 |         "hebrew": [(0x0590, 0x05FF)],  # Hebrew
175 |         "devanagari": [(0x0900, 0x097F)],  # Devanagari (Hindi, Sanskrit)
176 |         "bengali": [(0x0980, 0x09FF)],  # Bengali
177 |         "gurmukhi": [(0x0A00, 0x0A7F)],  # Gurmukhi (Punjabi)
178 |         "gujarati": [(0x0A80, 0x0AFF)],  # Gujarati
179 |         "oriya": [(0x0B00, 0x0B7F)],  # Oriya
180 |         "tamil": [(0x0B80, 0x0BFF)],  # Tamil
181 |         "telugu": [(0x0C00, 0x0C7F)],  # Telugu
182 |         "kannada": [(0x0C80, 0x0CFF)],  # Kannada
183 |         "malayalam": [(0x0D00, 0x0D7F)],  # Malayalam
184 |         "sinhala": [(0x0D80, 0x0DFF)],  # Sinhala
185 |         "thai": [(0x0E00, 0x0E7F)],  # Thai
186 |         "lao": [(0x0E80, 0x0EFF)],  # Lao
187 |         "myanmar": [(0x1000, 0x109F)],  # Myanmar
188 |         "georgian": [(0x10A0, 0x10FF)],  # Georgian
189 |         "ethiopic": [(0x1200, 0x137F)],  # Ethiopic
190 |         "khmer": [(0x1780, 0x17FF)],  # Khmer
191 |         "mongolian": [(0x1800, 0x18AF)],  # Mongolian
192 |     }
193 | 
194 |     # Convert to a flat structure with character ranges
195 |     flat_ranges = {}
196 |     for category, ranges in script_categories.items():
197 |         # Create a set of all characters in this category
198 |         char_set = set()
199 |         for start, end in ranges:
200 |             char_set.update(range(start, end + 1))
201 | 
202 |         # Store the set in flat_ranges
203 |         flat_ranges[category] = char_set
204 | 
205 |     return script_categories, flat_ranges
206 | 
207 | def get_top_scripts(text: str, max_scripts: int = 5):
208 |     script_categories, flat_ranges = script_ranges()
209 |     char_count = {category: 0 for category in script_categories.keys()}
210 |     for char in text:
211 |         for category, char_set in flat_ranges.items():
212 |             if ord(char) in char_set:
213 |                 char_count[category] += 1
214 |                 break
215 | 
216 |     top_scripts = sorted(char_count.items(), key=lambda x: x[1], reverse=True)
217 |     top_scripts = [ts[0] for ts in top_scripts if ts[1] > 0]
218 |     if "<math" in text:
219 |         top_scripts.insert(0, "math")
220 | 
221 |     return top_scripts[:max_scripts]
222 | 
223 | def is_flash_attn_2_supported(device: str | torch.device) -> bool:
224 |     if not torch.cuda.is_available():
225 |         return False
226 | 
227 |     if "cuda" not in str(device):
228 |         return False
229 | 
230 |     # Check CUDA version >= 12.0
231 |     cuda_version_str = torch.version.cuda
232 |     if cuda_version_str is None:
233 |         return False
234 |     cuda_version = tuple(map(int, cuda_version_str.split(".")))
235 |     if cuda_version < (12, 0):
236 |         return False
237 | 
238 |     # Check GPU compute capability (Ampere, Ada, Hopper GPUs)
239 |     major, minor = torch.cuda.get_device_capability()
240 |     compute_capability = major + minor / 10
241 |     if compute_capability < 8.0:
242 |         return False
243 | 
244 |     return True
245 | 
246 | 
247 | def pad_to_batch_size_repeat(tensor: torch.Tensor, batch_size: int):
248 |     current_batch_size = tensor.shape[0]
249 |     if current_batch_size >= batch_size:
250 |         return tensor
251 | 
252 |     pad_size = batch_size - current_batch_size
253 |     if pad_size < 0:
254 |         return tensor
255 | 
256 |     # Repeat the last row pad_size times
257 |     last_row = tensor[-1:].repeat(pad_size, 1, 1)
258 | 
259 |     # Concatenate original tensor with repeated last rows
260 |     return torch.cat([tensor, last_row], dim=0)
261 | 
262 | 
263 | def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
264 |     current_batch_size = tensor.shape[0]
265 |     if current_batch_size >= batch_size:
266 |         return tensor
267 | 
268 |     pad_size = batch_size - current_batch_size
269 |     padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
270 | 
271 |     return F.pad(tensor, padding, mode="constant", value=0)
272 | 
```
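
A short usage sketch for the helpers above, assuming the surya package and its dependencies are installed (expected outputs shown as comments):

```python
import torch

from surya.common.util import get_top_scripts, pad_to_batch_size

# get_top_scripts counts characters against the Unicode ranges defined above.
print(get_top_scripts("Привет world 123"))  # ['cyrillic', 'latin', 'punctuation']
print(get_top_scripts("これはテストです"))  # ['japanese']

# pad_to_batch_size zero-pads along the batch (first) dimension only.
x = torch.randn(3, 4, 8)
print(pad_to_batch_size(x, batch_size=5).shape)  # torch.Size([5, 4, 8])
```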

--------------------------------------------------------------------------------
/surya/scripts/streamlit_app.py:
--------------------------------------------------------------------------------

```python
  1 | import io
  2 | import tempfile
  3 | from typing import List
  4 | 
  5 | import pypdfium2
  6 | import streamlit as st
  7 | 
  8 | from surya.common.surya.schema import TaskNames
  9 | from surya.models import load_predictors
 10 | 
 11 | from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image
 12 | 
 13 | from surya.debug.text import draw_text_on_image
 14 | from PIL import Image, ImageDraw
 15 | from surya.table_rec import TableResult
 16 | from surya.detection import TextDetectionResult
 17 | from surya.recognition import OCRResult
 18 | from surya.layout import LayoutResult
 19 | from surya.settings import settings
 20 | from surya.common.util import rescale_bbox, expand_bbox
 21 | 
 22 | 
 23 | @st.cache_resource()
 24 | def load_predictors_cached():
 25 |     return load_predictors()
 26 | 
 27 | 
 28 | def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
 29 |     from pdftext.extraction import plain_text_output
 30 | 
 31 |     with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
 32 |         f.write(pdf_file.getvalue())
 33 |         f.seek(0)
 34 | 
 35 |         # Sample the text from the middle of the PDF
 36 |         page_middle = page_count // 2
 37 |         page_range = range(
 38 |             max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count)
 39 |         )
 40 |         text = plain_text_output(f.name, page_range=page_range)
 41 | 
 42 |     sample_gap = len(text) // max_samples
 43 |     if len(text) == 0 or sample_gap == 0:
 44 |         return "This PDF has no text or very little text", ["no text"]
 45 | 
 46 |     if sample_gap < sample_len:
 47 |         sample_gap = sample_len
 48 | 
 49 |     # Split the text into samples for the model
 50 |     samples = []
 51 |     for i in range(0, len(text), sample_gap):
 52 |         samples.append(text[i : i + sample_len])
 53 | 
 54 |     results = predictors["ocr_error"](samples)
 55 |     label = "This PDF has good text."
 56 |     if results.labels.count("bad") / len(results.labels) > 0.2:
 57 |         label = "This PDF may have garbled or bad OCR text."
 58 |     return label, results.labels
 59 | 
 60 | 
 61 | def text_detection(img) -> (Image.Image, TextDetectionResult):
 62 |     text_pred = predictors["detection"]([img])[0]
 63 |     text_polygons = [p.polygon for p in text_pred.bboxes]
 64 |     det_img = draw_polys_on_image(text_polygons, img.copy())
 65 |     return det_img, text_pred
 66 | 
 67 | 
 68 | def layout_detection(img) -> (Image.Image, LayoutResult):
 69 |     pred = predictors["layout"]([img])[0]
 70 |     polygons = [p.polygon for p in pred.bboxes]
 71 |     labels = [
 72 |         f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}" for p in pred.bboxes
 73 |     ]
 74 |     layout_img = draw_polys_on_image(
 75 |         polygons, img.copy(), labels=labels, label_font_size=18
 76 |     )
 77 |     return layout_img, pred
 78 | 
 79 | 
 80 | def table_recognition(
 81 |     img, highres_img, skip_table_detection: bool
 82 | ) -> (Image.Image, List[TableResult]):
 83 |     if skip_table_detection:
 84 |         layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
 85 |         table_imgs = [highres_img]
 86 |     else:
 87 |         _, layout_pred = layout_detection(img)
 88 |         layout_tables_lowres = [
 89 |             line.bbox
 90 |             for line in layout_pred.bboxes
 91 |             if line.label in ["Table", "TableOfContents"]
 92 |         ]
 93 |         table_imgs = []
 94 |         layout_tables = []
 95 |         for tb in layout_tables_lowres:
 96 |             highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
 97 |             # Slightly expand the box
 98 |             highres_bbox = expand_bbox(highres_bbox)
 99 |             table_imgs.append(highres_img.crop(highres_bbox))
100 |             layout_tables.append(highres_bbox)
101 | 
102 |     table_preds = predictors["table_rec"](table_imgs)
103 |     table_img = img.copy()
104 | 
105 |     for results, table_bbox in zip(table_preds, layout_tables):
106 |         adjusted_bboxes = []
107 |         labels = []
108 |         colors = []
109 | 
110 |         for item in results.cells:
111 |             adjusted_bboxes.append(
112 |                 [
113 |                     (item.bbox[0] + table_bbox[0]),
114 |                     (item.bbox[1] + table_bbox[1]),
115 |                     (item.bbox[2] + table_bbox[0]),
116 |                     (item.bbox[3] + table_bbox[1]),
117 |                 ]
118 |             )
119 |             labels.append(item.label)
120 |             if "Row" in item.label:
121 |                 colors.append("blue")
122 |             else:
123 |                 colors.append("red")
124 |         table_img = draw_bboxes_on_image(
125 |             adjusted_bboxes,
126 |             highres_img,
127 |             labels=labels,
128 |             label_font_size=18,
129 |             color=colors,
130 |         )
131 |     return table_img, table_preds
132 | 
133 | 
134 | # Function for OCR
135 | def ocr(
136 |     img: Image.Image,
137 |     highres_img: Image.Image,
138 |     skip_text_detection: bool = False,
139 |     recognize_math: bool = True,
140 |     with_bboxes: bool = True,
141 | ) -> (Image.Image, OCRResult, Image.Image):
142 |     if skip_text_detection:
143 |         img = highres_img
144 |         bboxes = [[[0, 0, img.width, img.height]]]
145 |     else:
146 |         bboxes = None
147 | 
148 |     if with_bboxes:
149 |         tasks = [TaskNames.ocr_with_boxes]
150 |     else:
151 |         tasks = [TaskNames.ocr_without_boxes]
152 | 
153 |     img_pred = predictors["recognition"](
154 |         [img],
155 |         task_names=tasks,
156 |         bboxes=bboxes,
157 |         det_predictor=predictors["detection"],
158 |         highres_images=[highres_img],
159 |         math_mode=recognize_math,
160 |         return_words=True,
161 |     )[0]
162 | 
163 |     bboxes = [line.bbox for line in img_pred.text_lines]
164 |     text = [line.text for line in img_pred.text_lines]
165 |     rec_img = draw_text_on_image(bboxes, text, img.size)
166 | 
167 |     word_boxes = []
168 |     for line in img_pred.text_lines:
169 |         if line.words:
170 |             word_boxes.extend([word.bbox for word in line.words])
171 | 
172 |     box_img = img.copy()
173 |     draw = ImageDraw.Draw(box_img)
174 |     for word_box in word_boxes:
175 |         draw.rectangle(word_box, outline="red", width=2)
176 | 
177 |     return rec_img, img_pred, box_img
178 | 
179 | 
180 | def open_pdf(pdf_file):
181 |     stream = io.BytesIO(pdf_file.getvalue())
182 |     return pypdfium2.PdfDocument(stream)
183 | 
184 | 
185 | @st.cache_data()
186 | def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
187 |     doc = open_pdf(pdf_file)
188 |     renderer = doc.render(
189 |         pypdfium2.PdfBitmap.to_pil,
190 |         page_indices=[page_num - 1],
191 |         scale=dpi / 72,
192 |     )
193 |     png = list(renderer)[0]
194 |     png_image = png.convert("RGB")
195 |     doc.close()
196 |     return png_image
197 | 
198 | 
199 | @st.cache_data()
200 | def page_counter(pdf_file):
201 |     doc = open_pdf(pdf_file)
202 |     doc_len = len(doc)
203 |     doc.close()
204 |     return doc_len
205 | 
206 | 
207 | st.set_page_config(layout="wide")
208 | col1, col2 = st.columns([0.5, 0.5])
209 | 
210 | predictors = load_predictors_cached()
211 | 
212 | st.markdown("""
213 | # Surya OCR Demo
214 | 
215 | This app will let you try surya, a multilingual OCR toolkit.
216 | 
217 | Notes:
218 | 
219 | - This works best on documents with printed text.
220 | - For OCR, the formatting (math, italics, etc) will not show up in the image preview, but it will show up in the returned text lines.
221 | - If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease).
222 | 
223 | Find the project [here](https://github.com/VikParuchuri/surya).
224 | """)
225 | 
226 | in_file = st.sidebar.file_uploader(
227 |     "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]
228 | )
229 | 
230 | if in_file is None:
231 |     st.stop()
232 | 
233 | filetype = in_file.type
234 | page_count = None
235 | if "pdf" in filetype:
236 |     page_count = page_counter(in_file)
237 |     page_number = st.sidebar.number_input(
238 |         f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count
239 |     )
240 | 
241 |     pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
242 |     pil_image_highres = get_page_image(
243 |         in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES
244 |     )
245 | else:
246 |     pil_image = Image.open(in_file).convert("RGB")
247 |     pil_image_highres = pil_image
248 |     page_number = None
249 | 
250 | run_text_det = st.sidebar.button("Run Text Detection")
251 | run_text_rec = st.sidebar.button("Run OCR")
252 | run_layout_det = st.sidebar.button("Run Layout Analysis")
253 | run_table_rec = st.sidebar.button("Run Table Rec")
254 | run_ocr_errors = st.sidebar.button("Run bad PDF text detection")
255 | use_pdf_boxes = st.sidebar.checkbox(
256 |     "PDF table boxes",
257 |     value=True,
258 |     help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.",
259 | )
260 | skip_table_detection = st.sidebar.checkbox(
261 |     "Skip table detection",
262 |     value=False,
263 |     help="Table recognition only: Skip table detection and treat the whole image/page as a table.",
264 | )
265 | skip_text_detection = st.sidebar.checkbox(
266 |     "Skip text detection",
267 |     value=False,
268 |     help="OCR only: Skip text detection and treat the whole image as a single line.",
269 | )
270 | recognize_math = st.sidebar.checkbox(
271 |     "Recognize math in OCR",
272 |     value=True,
273 |     help="Enable math mode in OCR - this will recognize math.",
274 | )
275 | ocr_with_boxes = st.sidebar.checkbox(
276 |     "OCR with boxes",
277 |     value=True,
278 |     help="Enable OCR with boxes - this will predict character-level boxes.",
279 | )
280 | 
281 | if pil_image is None:
282 |     st.stop()
283 | 
284 | # Run Text Detection
285 | if run_text_det:
286 |     det_img, text_pred = text_detection(pil_image)
287 |     with col1:
288 |         st.image(det_img, caption="Detected Text", use_container_width=True)
289 |         st.json(
290 |             text_pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True
291 |         )
292 | 
293 | 
294 | # Run layout
295 | if run_layout_det:
296 |     layout_img, pred = layout_detection(pil_image)
297 |     with col1:
298 |         st.image(layout_img, caption="Detected Layout", use_container_width=True)
299 |         st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True)
300 | 
301 | # Run OCR
302 | if run_text_rec:
303 |     rec_img, pred, box_img = ocr(
304 |         pil_image,
305 |         pil_image_highres,
306 |         skip_text_detection,
307 |         recognize_math,
308 |         with_bboxes=ocr_with_boxes,
309 |     )
310 |     with col1:
311 |         st.image(rec_img, caption="OCR Result", use_container_width=True)
312 |         json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
313 |         with json_tab:
314 |             st.json(pred.model_dump(), expanded=False)
315 |         with text_tab:
316 |             st.text("\n".join([p.text for p in pred.text_lines]))
317 | 
318 |         st.image(
319 |             box_img,
320 |             caption="OCR with Word Boxes (for debugging)",
321 |             use_container_width=True,
322 |         )
323 | 
324 | 
325 | if run_table_rec:
326 |     table_img, pred = table_recognition(
327 |         pil_image, pil_image_highres, skip_table_detection
328 |     )
329 |     with col1:
330 |         st.image(table_img, caption="Table Recognition", use_container_width=True)
331 |         st.json([p.model_dump() for p in pred], expanded=True)
332 | 
333 | if run_ocr_errors:
334 |     if "pdf" not in filetype:
335 |         st.error("This feature only works with PDFs.")
336 |         st.stop()  # page_count is None for non-PDF uploads
337 |     label, results = ocr_errors(in_file, page_count)
338 |     with col1:
339 |         st.write(label)
340 |         st.json(results)
341 | 
342 | with col2:
343 |     st.image(pil_image, caption="Uploaded Image", use_container_width=True)
344 | 
```
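
Outside of Streamlit, the same predictors can be driven directly. A minimal sketch based on the calls made in this app (the input path is hypothetical, and the surya models must be available locally or downloadable):

```python
from PIL import Image

from surya.common.surya.schema import TaskNames
from surya.models import load_predictors

predictors = load_predictors()
image = Image.open("page.png").convert("RGB")  # hypothetical input image

# Text detection: one TextDetectionResult per input image.
det_result = predictors["detection"]([image])[0]
print(len(det_result.bboxes), "text lines detected")

# OCR, mirroring the ocr() helper above; detection runs internally via det_predictor.
# The app passes a higher-DPI render as highres_images; the same image is reused here.
ocr_result = predictors["recognition"](
    [image],
    task_names=[TaskNames.ocr_with_boxes],
    bboxes=None,
    det_predictor=predictors["detection"],
    highres_images=[image],
    math_mode=True,
    return_words=True,
)[0]
for line in ocr_result.text_lines:
    print(line.bbox, line.text)
```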

--------------------------------------------------------------------------------
/benchmark/recognition.py:
--------------------------------------------------------------------------------

```python
  1 | import re
  2 | import unicodedata
  3 | from collections import defaultdict
  4 | 
  5 | import click
  6 | 
  7 | from benchmark.utils.scoring import overlap_score, overlap_score_exact
  8 | from surya.input.processing import convert_if_not_rgb
  9 | from surya.debug.text import draw_text_on_image
 10 | from surya.foundation import FoundationPredictor
 11 | from surya.recognition import RecognitionPredictor
 12 | from surya.settings import settings
 13 | from surya.recognition.languages import CODE_TO_LANGUAGE
 14 | from benchmark.utils.tesseract import (
 15 |     tesseract_ocr_parallel,
 16 |     surya_lang_to_tesseract,
 17 |     TESS_CODE_TO_LANGUAGE,
 18 | )
 19 | from benchmark.utils.textract import textract_ocr_parallel
 20 | import os
 21 | import datasets
 22 | import json
 23 | import time
 24 | from tabulate import tabulate
 25 | 
 26 | KEY_LANGUAGES = [
 27 |     "Chinese",
 28 |     "Spanish",
 29 |     "English",
 30 |     "Arabic",
 31 |     "Hindi",
 32 |     "Bengali",
 33 |     "Russian",
 34 |     "Japanese",
 35 | ]
 36 | 
 37 | 
 38 | def list_in(lst: str | list, lst2: list):
 39 |     if isinstance(lst, str):
 40 |         lst = [lst]
 41 |     return any([item in lst for item in lst2])
 42 | 
 43 | 
 44 | def standardize_bullets(text):
 45 |     patterns = [
 46 |         r"•\s+",
 47 |         r"·\s+",
 48 |         r"○\s+",
 49 |         r"◦\s+",
 50 |         r"▪\s+",
 51 |         r"▫\s+",
 52 |         r"➢\s+",
 53 |         r"➤\s+",
 54 |         r"★\s+",
 55 |         r"✓\s+",
 56 |         r"✗\s+",
 57 |         r"✦\s+",
 58 |         r"\\bullet\s+",
 59 |     ]
 60 | 
 61 |     combined_pattern = "|".join(patterns)
 62 |     text = re.sub(combined_pattern, "*", text)
 63 | 
 64 |     return text
 65 | 
 66 | 
 67 | def normalize_text(text: str) -> str:
 68 |     # Remove HTML tags
 69 |     text = re.sub(r"<[^>]+>", "", text)
 70 |     # Remove LaTeX tags
 71 |     text = re.sub(r"\\[a-zA-Z]+", "", text)
 72 |     text = standardize_bullets(text)
 73 |     text = unicodedata.normalize("NFKC", text)
 74 |     return text.strip().lower().replace(",", ".")
 75 | 
 76 | 
 77 | @click.command(help="Benchmark recognition model.")
 78 | @click.option(
 79 |     "--results_dir",
 80 |     type=str,
 81 |     help="Path to JSON file with OCR results.",
 82 |     default=os.path.join(settings.RESULT_DIR, "benchmark"),
 83 | )
 84 | @click.option(
 85 |     "--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None
 86 | )
 87 | @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
 88 | @click.option(
 89 |     "--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False
 90 | )
 91 | @click.option(
 92 |     "--textract", is_flag=True, help="Run benchmarks on textract.", default=False
 93 | )
 94 | @click.option(
 95 |     "--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28
 96 | )
 97 | @click.option(
 98 |     "--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28
 99 | )
100 | @click.option(
101 |     "--languages",
102 |     type=str,
103 |     help="Comma-separated list of languages to benchmark.",
104 |     default=None,
105 | )
106 | @click.option(
107 |     "--print_results",
108 |     is_flag=True,
109 | )
110 | def main(
111 |     results_dir: str,
112 |     max_rows: int,
113 |     debug: bool,
114 |     tesseract: bool,
115 |     textract: bool,
116 |     tess_cpus: int,
117 |     textract_cpus: int,
118 |     languages: str | None,
119 |     print_results: bool,
120 | ):
121 |     foundation_predictor = FoundationPredictor()
122 |     rec_predictor = RecognitionPredictor(foundation_predictor)
123 | 
124 |     split = "train"
125 |     dataset = datasets.load_dataset(
126 |         settings.RECOGNITION_BENCH_DATASET_NAME, split=split
127 |     )
128 | 
129 |     if languages:
130 |         languages = languages.split(",")
131 |         dataset = dataset.filter(
132 |             lambda x: list_in(x["language"], languages), num_proc=4
133 |         )
134 | 
135 |     if max_rows and max_rows < len(dataset):
136 |         dataset = dataset.shuffle(seed=1).select(range(max_rows))
137 | 
138 |     images = list(dataset["image"])
139 |     images = convert_if_not_rgb(images)
140 |     bboxes = list(dataset["bboxes"])
141 |     line_text = list(dataset["text"])
142 |     languages = list(dataset["language"])
143 | 
144 |     print(f"Loaded {len(images)} images. Running OCR...")
145 | 
146 |     start = time.time()
147 |     predictions_by_image = rec_predictor(images, None, bboxes=bboxes)
148 |     surya_time = time.time() - start
149 | 
150 |     lang_list = []
151 |     for lang in languages:
152 |         if not isinstance(lang, list):
153 |             lang_list.append([lang])
154 |         else:
155 |             lang_list.append(lang)
156 | 
157 |     surya_scores = defaultdict(list)
158 |     img_surya_scores = []
159 |     outputs = []
160 |     for idx, (pred, ref_text, langs) in enumerate(
161 |         zip(predictions_by_image, line_text, lang_list)
162 |     ):
163 |         pred_text = [line.text for line in pred.text_lines]
164 | 
165 |         score_ref_text = [normalize_text(line) for line in ref_text]
166 |         score_pred_text = [normalize_text(text) for text in pred_text]
167 |         image_scores, image_weights = overlap_score_exact(
168 |             score_pred_text, score_ref_text
169 |         )
170 |         normalized_scores = [
171 |             score / max(1, weight) for score, weight in zip(image_scores, image_weights)
172 |         ]
173 |         image_score = sum(image_scores) / max(1, sum(image_weights))
174 | 
175 |         img_surya_scores.append(image_score)
176 |         for lang in langs:
177 |             surya_scores[CODE_TO_LANGUAGE[lang]].append(image_score)
178 | 
179 |         assert len(pred_text) == len(ref_text) == len(bboxes[idx])
180 |         if debug:
181 |             for j, (pred_line, ref_line, score, bbox) in enumerate(
182 |                 zip(pred_text, ref_text, normalized_scores, bboxes[idx])
183 |             ):
184 |                 image_slice = images[idx].crop(bbox)
185 | 
186 |                 outputs.append(
187 |                     {
188 |                         "image": image_slice,
189 |                         "bbox": bbox,
190 |                         "score": score,
191 |                         "pred": pred_line,
192 |                         "ref": ref_line,
193 |                         "langs": ",".join(langs),
194 |                     }
195 |                 )
196 | 
197 |     if debug:
198 |         out_ds = datasets.Dataset.from_list(outputs)
199 |         out_ds.push_to_hub("datalab-to/rec_bench_outputs", private=True)
200 | 
201 |     flat_surya_scores = [score for lang in surya_scores for score in surya_scores[lang]]
202 |     benchmark_stats = {
203 |         "surya": {
204 |             "avg_score": sum(flat_surya_scores) / max(1, len(flat_surya_scores)),
205 |             "lang_scores": {
206 |                 lang: sum(scores) / max(1, len(scores))
207 |                 for lang, scores in surya_scores.items()
208 |             },
209 |             "time_per_img": surya_time / max(1, len(images)),
210 |         }
211 |     }
212 | 
213 |     result_path = os.path.join(results_dir, "rec_bench")
214 |     os.makedirs(result_path, exist_ok=True)
215 | 
216 |     with open(os.path.join(result_path, "surya_scores.json"), "w+") as f:
217 |         json.dump(surya_scores, f)
218 | 
219 |     if tesseract:
220 |         tess_valid = []
221 |         tess_langs = []
222 |         for idx, lang in enumerate(lang_list):
223 |             # Tesseract does not support all languages
224 |             tess_lang = surya_lang_to_tesseract(lang[0])
225 |             if tess_lang is None:
226 |                 continue
227 | 
228 |             tess_valid.append(idx)
229 |             tess_langs.append(tess_lang)
230 | 
231 |         tess_imgs = [images[i] for i in tess_valid]
232 |         tess_bboxes = [bboxes[i] for i in tess_valid]
233 |         tess_reference = [line_text[i] for i in tess_valid]
234 |         start = time.time()
235 |         tess_predictions = tesseract_ocr_parallel(
236 |             tess_imgs, tess_bboxes, tess_langs, cpus=tess_cpus
237 |         )
238 |         tesseract_time = time.time() - start
239 | 
240 |         tess_scores = defaultdict(list)
241 |         for idx, (pred, ref_text, lang) in enumerate(
242 |             zip(tess_predictions, tess_reference, tess_langs)
243 |         ):
244 |             image_scores, image_weights, _ = overlap_score(pred, ref_text)
245 |             image_score = sum(image_scores) / max(1, sum(image_weights))
246 |             tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)
247 | 
248 |         flat_tess_scores = [
249 |             score for lang in tess_scores for score in tess_scores[lang]
250 |         ]
251 |         benchmark_stats["tesseract"] = {
252 |             "avg_score": sum(flat_tess_scores) / len(flat_tess_scores),
253 |             "lang_scores": {
254 |                 lang: sum(scores) / len(scores) for lang, scores in tess_scores.items()
255 |             },
256 |             "time_per_img": tesseract_time / len(tess_imgs),
257 |         }
258 | 
259 |         with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
260 |             json.dump(tess_scores, f)
261 | 
262 |     if textract:
263 |         start = time.time()
264 |         textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus)
265 |         textract_time = time.time() - start
266 | 
267 |         textract_scores = defaultdict(list)
268 |         for idx, (pred, ref_text, lang) in enumerate(
269 |             zip(textract_predictions, line_text, lang_list)
270 |         ):
271 |             image_scores, image_weights, _ = overlap_score(pred, ref_text)
272 |             image_score = sum(image_scores) / max(1, sum(image_weights))
273 | 
274 |             for lang_code in lang:
275 |                 textract_scores[CODE_TO_LANGUAGE[lang_code]].append(image_score)
276 | 
277 |         flat_textract_scores = [
278 |             score for lang in textract_scores for score in textract_scores[lang]
279 |         ]
280 |         benchmark_stats["textract"] = {
281 |             "avg_score": sum(flat_textract_scores) / len(flat_textract_scores),
282 |             "lang_scores": {
283 |                 lang: sum(scores) / len(scores)
284 |                 for lang, scores in textract_scores.items()
285 |             },
286 |             "time_per_img": textract_time / len(images),
287 |         }
288 |         print(len(flat_textract_scores))
289 | 
290 |         with open(os.path.join(result_path, "textract_scores.json"), "w+") as f:
291 |             json.dump(textract_scores, f)
292 | 
293 |     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
294 |         json.dump(benchmark_stats, f)
295 | 
296 |     key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
297 |     table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages
298 |     table_data = [
299 |         [
300 |             "surya",
301 |             benchmark_stats["surya"]["time_per_img"],
302 |             benchmark_stats["surya"]["avg_score"],
303 |         ]
304 |         + [benchmark_stats["surya"]["lang_scores"][lang] for lang in key_languages],
305 |     ]
306 |     if tesseract:
307 |         table_data.append(
308 |             [
309 |                 "tesseract",
310 |                 benchmark_stats["tesseract"]["time_per_img"],
311 |                 benchmark_stats["tesseract"]["avg_score"],
312 |             ]
313 |             + [
314 |                 benchmark_stats["tesseract"]["lang_scores"].get(lang, 0)
315 |                 for lang in key_languages
316 |             ]
317 |         )
318 |     if textract:
319 |         table_data.append(
320 |             [
321 |                 "textract",
322 |                 benchmark_stats["textract"]["time_per_img"],
323 |                 benchmark_stats["textract"]["avg_score"],
324 |             ]
325 |             + [
326 |                 benchmark_stats["textract"]["lang_scores"][lang]
327 |                 for lang in key_languages
328 |             ],
329 |         )
330 | 
331 |     print(tabulate(table_data, headers=table_headers, tablefmt="github"))
332 |     print(
333 |         "Only a few major languages are displayed. See the result path for additional languages."
334 |     )
335 | 
336 |     if debug >= 1:
337 |         bad_detections = []
338 |         for idx, (score, lang) in enumerate(zip(flat_surya_scores, lang_list)):
339 |             if score < 0.8:
340 |                 bad_detections.append((idx, lang, score))
341 |         print(f"Found {len(bad_detections)} bad detections. Writing to file...")
342 |         with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
343 |             json.dump(bad_detections, f)
344 | 
345 |     if debug == 2:
346 |         for idx, (image, pred, ref_text, bbox, lang) in enumerate(
347 |             zip(images, predictions_by_image, line_text, bboxes, lang_list)
348 |         ):
349 |             pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
350 |             ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
351 |             pred_text = [line.text for line in pred.text_lines]
352 |             pred_image = draw_text_on_image(bbox, pred_text, image.size)
353 |             pred_image.save(os.path.join(result_path, pred_image_name))
354 |             ref_image = draw_text_on_image(bbox, ref_text, image.size)
355 |             ref_image.save(os.path.join(result_path, ref_image_name))
356 |             image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))
357 | 
358 |     print(f"Wrote results to {result_path}")
359 | 
360 |     if print_results:
361 |         for idx, (pred, ref_text) in enumerate(zip(predictions_by_image, line_text)):
362 |             print(f"Image {idx}")
363 |             print("----")
364 |             for line_idx, (pred_line, ref_line) in enumerate(
365 |                 zip(pred.text_lines, ref_text)
366 |             ):
367 |                 print(f"Sample {line_idx}")
368 |                 print(f"Pred: {pred_line.text}")
369 |                 print(f"Ref: {ref_line}")
370 |                 print()
371 | 
372 |     if settings.TORCH_DEVICE == "xla":
373 |         import torch_xla.debug.metrics as met
374 | 
375 |         print(met.short_metrics_report())
376 | 
377 | 
378 | if __name__ == "__main__":
379 |     main()
380 | 
```
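
To illustrate the scoring pipeline above without loading any models, here is a standalone sketch of the text normalization and the per-image aggregation. The bullet pattern is condensed and the line scores/weights are hypothetical; `overlap_score_exact` itself lives in `benchmark/utils/scoring.py` and is not reproduced here.

```python
import re
import unicodedata


def normalize_text(text: str) -> str:
    # Same steps as the benchmark: strip HTML tags, strip LaTeX commands,
    # normalize bullets to "*", NFKC-normalize, lowercase, unify commas to periods.
    text = re.sub(r"<[^>]+>", "", text)
    text = re.sub(r"\\[a-zA-Z]+", "", text)
    text = re.sub(r"[•·○◦▪▫➢➤★✓✗✦]\s+", "*", text)  # condensed bullet pattern
    text = unicodedata.normalize("NFKC", text)
    return text.strip().lower().replace(",", ".")


print(normalize_text("<b>Total:</b> 1,234 • Done"))  # "total: 1.234 *done"

# Per-image aggregation: weighted line scores divided by total weight, guarded
# against empty references. Values here are hypothetical.
image_scores, image_weights = [9.0, 1.5], [10, 2]
normalized_scores = [s / max(1, w) for s, w in zip(image_scores, image_weights)]
image_score = sum(image_scores) / max(1, sum(image_weights))
print(normalized_scores, image_score)  # [0.9, 0.75] 0.875
```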

--------------------------------------------------------------------------------
/surya/detection/processor.py:
--------------------------------------------------------------------------------

```python
  1 | import warnings
  2 | from typing import Any, Dict, List, Optional, Union
  3 | 
  4 | import numpy as np
  5 | 
  6 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
  7 | from transformers.image_transforms import to_channel_dimension_format
  8 | from transformers.image_utils import (
  9 |     IMAGENET_DEFAULT_MEAN,
 10 |     IMAGENET_DEFAULT_STD,
 11 |     ChannelDimension,
 12 |     ImageInput,
 13 |     PILImageResampling,
 14 |     infer_channel_dimension_format,
 15 |     make_list_of_images,
 16 | )
 17 | from transformers.utils import TensorType
 18 | 
 19 | 
 20 | import PIL.Image
 21 | import torch
 22 | 
 23 | from surya.common.s3 import S3DownloaderMixin
 24 | 
 25 | 
 26 | class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
 27 |     r"""
 28 |     Constructs a Segformer image processor.
 29 | 
 30 |     Args:
 31 |         do_resize (`bool`, *optional*, defaults to `True`):
 32 |             Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
 33 |             size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
 34 |         size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
 35 |             Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
 36 |             method.
 37 |         resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
 38 |             Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
 39 |             `preprocess` method.
 40 |         do_rescale (`bool`, *optional*, defaults to `True`):
 41 |             Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
 42 |             parameter in the `preprocess` method.
 43 |         rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
 44 |             Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
 45 |             `preprocess` method.
 46 |         do_normalize (`bool`, *optional*, defaults to `True`):
 47 |             Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
 48 |             method.
 49 |         image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
 50 |             Mean to use if normalizing the image. This is a float or list of floats the length of the number of
 51 |             channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
 52 |         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
 53 |             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
 54 |             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
 55 |         do_reduce_labels (`bool`, *optional*, defaults to `False`):
 56 |             Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
 57 |             used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
 58 |             background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
 59 |             `preprocess` method.
 60 |     """
 61 | 
 62 |     model_input_names = ["pixel_values"]
 63 | 
 64 |     def __init__(
 65 |         self,
 66 |         do_resize: bool = True,
 67 |         size: Dict[str, int] = None,
 68 |         resample: PILImageResampling = PILImageResampling.BILINEAR,
 69 |         do_rescale: bool = True,
 70 |         rescale_factor: Union[int, float] = 1 / 255,
 71 |         do_normalize: bool = True,
 72 |         image_mean: Optional[Union[float, List[float]]] = None,
 73 |         image_std: Optional[Union[float, List[float]]] = None,
 74 |         do_reduce_labels: bool = False,
 75 |         **kwargs,
 76 |     ) -> None:
 77 |         if "reduce_labels" in kwargs:
 78 |             warnings.warn(
 79 |                 "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
 80 |                 "`do_reduce_labels` instead.",
 81 |                 FutureWarning,
 82 |             )
 83 |             do_reduce_labels = kwargs.pop("reduce_labels")
 84 | 
 85 |         super().__init__(**kwargs)
 86 |         size = size if size is not None else {"height": 512, "width": 512}
 87 |         size = get_size_dict(size)
 88 |         self.do_resize = do_resize
 89 |         self.size = size
 90 |         self.resample = resample
 91 |         self.do_rescale = do_rescale
 92 |         self.rescale_factor = rescale_factor
 93 |         self.do_normalize = do_normalize
 94 |         self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
 95 |         self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
 96 |         self.do_reduce_labels = do_reduce_labels
 97 |         self._valid_processor_keys = [
 98 |             "images",
 99 |             "segmentation_maps",
100 |             "do_resize",
101 |             "size",
102 |             "resample",
103 |             "do_rescale",
104 |             "rescale_factor",
105 |             "do_normalize",
106 |             "image_mean",
107 |             "image_std",
108 |             "do_reduce_labels",
109 |             "return_tensors",
110 |             "data_format",
111 |             "input_data_format",
112 |         ]
113 | 
114 |     @classmethod
115 |     def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
116 |         """
117 |         Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image
118 |         processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
119 |         reduce_labels=True)`
120 |         """
121 |         image_processor_dict = image_processor_dict.copy()
122 |         if "reduce_labels" in kwargs:
123 |             image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
124 |         return super().from_dict(image_processor_dict, **kwargs)
125 | 
126 |     def _preprocess(
127 |         self,
128 |         image: ImageInput,
129 |         do_resize: bool,
130 |         do_rescale: bool,
131 |         do_normalize: bool,
132 |         size: Optional[Dict[str, int]] = None,
133 |         resample: PILImageResampling = None,
134 |         rescale_factor: Optional[float] = None,
135 |         image_mean: Optional[Union[float, List[float]]] = None,
136 |         image_std: Optional[Union[float, List[float]]] = None,
137 |         input_data_format: Optional[Union[str, ChannelDimension]] = None,
138 |     ):
139 | 
140 |         if do_rescale:
141 |             image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
142 | 
143 |         if do_normalize:
144 |             image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
145 | 
146 |         return image
147 | 
148 |     def _preprocess_image(
149 |         self,
150 |         image: ImageInput,
151 |         do_resize: bool = None,
152 |         size: Dict[str, int] = None,
153 |         resample: PILImageResampling = None,
154 |         do_rescale: bool = None,
155 |         rescale_factor: float = None,
156 |         do_normalize: bool = None,
157 |         image_mean: Optional[Union[float, List[float]]] = None,
158 |         image_std: Optional[Union[float, List[float]]] = None,
159 |         data_format: Optional[Union[str, ChannelDimension]] = None,
160 |         input_data_format: Optional[Union[str, ChannelDimension]] = None,
161 |     ) -> np.ndarray:
162 |         """Preprocesses a single image."""
163 |         # All transformations expect numpy arrays.
164 |         if input_data_format is None:
165 |             input_data_format = infer_channel_dimension_format(image)
166 | 
167 |         image = self._preprocess(
168 |             image=image,
169 |             do_resize=do_resize,
170 |             size=size,
171 |             resample=resample,
172 |             do_rescale=do_rescale,
173 |             rescale_factor=rescale_factor,
174 |             do_normalize=do_normalize,
175 |             image_mean=image_mean,
176 |             image_std=image_std,
177 |             input_data_format=input_data_format,
178 |         )
179 |         if data_format is not None:
180 |             image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
181 |         return image
182 | 
183 |     def __call__(self, images, segmentation_maps=None, **kwargs):
184 |         """
185 |         Preprocesses a batch of images and optionally segmentation maps.
186 | 
187 |         Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
188 |         passed in as positional arguments.
189 |         """
190 |         return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
191 | 
192 |     def preprocess(
193 |         self,
194 |         images: ImageInput,
195 |         segmentation_maps: Optional[ImageInput] = None,
196 |         do_resize: Optional[bool] = None,
197 |         size: Optional[Dict[str, int]] = None,
198 |         resample: PILImageResampling = None,
199 |         do_rescale: Optional[bool] = None,
200 |         rescale_factor: Optional[float] = None,
201 |         do_normalize: Optional[bool] = None,
202 |         image_mean: Optional[Union[float, List[float]]] = None,
203 |         image_std: Optional[Union[float, List[float]]] = None,
204 |         do_reduce_labels: Optional[bool] = None,
205 |         return_tensors: Optional[Union[str, TensorType]] = None,
206 |         data_format: ChannelDimension = ChannelDimension.FIRST,
207 |         input_data_format: Optional[Union[str, ChannelDimension]] = None,
208 |         **kwargs,
209 |     ) -> BatchFeature:
210 |         """
211 |         Preprocess an image or batch of images.
212 | 
213 |         Args:
214 |             images (`ImageInput`):
215 |                 Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
216 |                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
217 |             segmentation_maps (`ImageInput`, *optional*):
218 |                 Segmentation map to preprocess.
219 |             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
220 |                 Whether to resize the image.
221 |             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
222 |                 Size of the image after `resize` is applied.
223 |             resample (`int`, *optional*, defaults to `self.resample`):
224 |                 Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
225 |                 has an effect if `do_resize` is set to `True`.
226 |             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
227 |                 Whether to rescale the image values to the range [0, 1].
228 |             rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
229 |                 Rescale factor to rescale the image by if `do_rescale` is set to `True`.
230 |             do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
231 |                 Whether to normalize the image.
232 |             image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
233 |                 Image mean.
234 |             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
235 |                 Image standard deviation.
236 |             do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
237 |                 Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
238 |                 is used for background, and background itself is not included in all classes of a dataset (e.g.
239 |                 ADE20k). The background label will be replaced by 255.
240 |             return_tensors (`str` or `TensorType`, *optional*):
241 |                 The type of tensors to return. Can be one of:
242 |                     - Unset: Return a list of `np.ndarray`.
243 |                     - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
244 |                     - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
245 |                     - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
246 |                     - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
247 |             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
248 |                 The channel dimension format for the output image. Can be one of:
249 |                     - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
250 |                     - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
251 |             input_data_format (`ChannelDimension` or `str`, *optional*):
252 |                 The channel dimension format for the input image. If unset, the channel dimension format is inferred
253 |                 from the input image. Can be one of:
254 |                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
255 |                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
256 |                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
257 |         """
258 |         do_resize = do_resize if do_resize is not None else self.do_resize
259 |         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
260 |         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
261 |         resample = resample if resample is not None else self.resample
262 |         size = size if size is not None else self.size
263 |         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
264 |         image_mean = image_mean if image_mean is not None else self.image_mean
265 |         image_std = image_std if image_std is not None else self.image_std
266 | 
267 |         images = make_list_of_images(images)
268 |         images = [
269 |             self._preprocess_image(
270 |                 image=img,
271 |                 do_resize=do_resize,
272 |                 resample=resample,
273 |                 size=size,
274 |                 do_rescale=do_rescale,
275 |                 rescale_factor=rescale_factor,
276 |                 do_normalize=do_normalize,
277 |                 image_mean=image_mean,
278 |                 image_std=image_std,
279 |                 data_format=data_format,
280 |                 input_data_format=input_data_format,
281 |             )
282 |             for img in images
283 |         ]
284 | 
285 |         data = {"pixel_values": images}
286 |         return BatchFeature(data=data, tensor_type=return_tensors)
```
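
A minimal usage sketch of the processor above (the import path follows the file location shown here; note that this trimmed-down `_preprocess` only rescales and normalizes, and the detection pipeline is assumed to handle resizing itself):

```python
import numpy as np

from surya.detection.processor import SegformerImageProcessor

# Hypothetical standalone call; surya normally constructs this processor
# through its model loaders rather than instantiating it directly.
processor = SegformerImageProcessor()

# A dummy RGB page as an (H, W, C) float array with values in [0, 255].
page = np.random.randint(0, 255, (768, 1024, 3)).astype(np.float32)

batch = processor(images=page, return_tensors="pt")

# Pixel values are rescaled to [0, 1], normalized with the ImageNet mean/std,
# and returned channels-first by default.
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 768, 1024])
```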

--------------------------------------------------------------------------------
/surya/foundation/cache/dynamic_ops.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Any, Dict, List, Optional, Tuple
  2 | import torch
  3 | from transformers import PretrainedConfig
  4 | 
  5 | """
  6 | Special cache class for the surya foundation model that supports:
  7 | 1) Static shapes
  8 | 2) A custom sliding window, where image tokens stay in the cache and text tokens are popped
  9 | 3) Continuous batching - merging new sequences into existing cache slots
 10 | 4) Attention mask management - kept in sync with what is currently in the cache
 11 | 
 12 | Heavily inspired by https://github.com/huggingface/transformers/blob/0725cd6953803b8aacfc85288cbfb83dea30c469/src/transformers/cache_utils.py#L1079
 13 | """
 14 | 
 15 | 
 16 | class DynamicOpsCache:
 17 |     def __init__(
 18 |         self,
 19 |         config: PretrainedConfig,
 20 |         batch_size: int,
 21 |         max_cache_len: int,
 22 |         text_sliding_window: int,
 23 |         device: str | torch.device,
 24 |         dtype: torch.dtype,
 25 |     ):
 26 |         self.text_sliding_window = text_sliding_window
 27 |         self.num_layers = config.num_hidden_layers
 28 |         self.max_batch_size = batch_size
 29 |         self.max_cache_len = max_cache_len
 30 |         self.head_dim = (
 31 |             getattr(config, "head_dim", None)
 32 |             or config.hidden_size // config.num_attention_heads
 33 |         )
 34 |         self._dtype = dtype
 35 |         self.num_key_value_heads = (
 36 |             config.num_attention_heads
 37 |             if getattr(config, "num_key_value_heads", None) is None
 38 |             else config.num_key_value_heads
 39 |         )
 40 | 
 41 |         # Cache init is taken from huggingface StaticCache - https://github.com/huggingface/transformers/blob/67ddc82fbc7e52c6f42a395b4a6d278c55b77a39/src/transformers/cache_utils.py#L1125
 42 |         self.key_cache: list[torch.Tensor] = []
 43 |         self.value_cache: list[torch.Tensor] = []
 44 |         cache_shape = (
 45 |             self.max_batch_size,
 46 |             self.num_key_value_heads,
 47 |             self.max_cache_len,
 48 |             self.head_dim,
 49 |         )
 50 |         device = torch.device(device) if device is not None else None
 51 |         for _ in range(config.num_hidden_layers):
 52 |             new_layer_key_cache = torch.zeros(
 53 |                 cache_shape, dtype=self._dtype, device=device
 54 |             )
 55 |             new_layer_value_cache = torch.zeros(
 56 |                 cache_shape, dtype=self._dtype, device=device
 57 |             )
 58 |             torch._dynamo.mark_static_address(new_layer_key_cache)
 59 |             torch._dynamo.mark_static_address(new_layer_value_cache)
 60 |             self.key_cache.append(new_layer_key_cache)
 61 |             self.value_cache.append(new_layer_value_cache)
 62 | 
 63 |         self.attention_mask = torch.zeros(
 64 |             (self.max_batch_size, self.max_cache_len), device=device, dtype=torch.long
 65 |         )
 66 |         self.text_token_counts = [
 67 |             torch.zeros(self.max_batch_size, dtype=torch.long, device=device)
 68 |             for _ in range(self.num_layers)
 69 |         ]
 70 | 
 71 |         self.dtype = dtype
 72 |         self.device = device
 73 | 
 74 |     def update(
 75 |         self,
 76 |         key_states: torch.Tensor,
 77 |         value_states: torch.Tensor,
 78 |         layer_idx: int,
 79 |         cache_kwargs: Optional[Dict[str, Any]] = None,
 80 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
 81 |         prefill = cache_kwargs.get("prefill", False)
 82 |         update_fn = self._prefill_update if prefill else self._decode_update
 83 |         return update_fn(
 84 |             self.key_cache[layer_idx],
 85 |             self.value_cache[layer_idx],
 86 |             key_states,
 87 |             value_states,
 88 |             self.text_token_counts[layer_idx],
 89 |             cache_kwargs,
 90 |         )
 91 | 
 92 |     def update_text_counts(
 93 |         self,
 94 |         merge_idxs: torch.Tensor,
 95 |         valid_batch_size: torch.Tensor,
 96 |         new_text_lens: torch.Tensor,
 97 |     ):
 98 |         new_text_len_tensor = new_text_lens.to(device=self.device)
 99 | 
100 |         for layer_idx in range(self.num_layers):
101 |             self.text_token_counts[layer_idx][merge_idxs] = new_text_len_tensor[
102 |                 :valid_batch_size
103 |             ]
104 | 
105 |     # Mirrors the logic from _prefill_update
106 |     # The logic is explained in more detail in that function
107 |     def prefill_attention_mask_update(
108 |         self,
109 |         prefill_attention_mask: torch.Tensor,
110 |         merge_idxs: torch.Tensor,
111 |         valid_batch_mask: torch.Tensor,
112 |         text_lengths: List[int],
113 |     ):
114 |         seq_len = prefill_attention_mask.shape[1]
115 |         sliding_window = self.text_sliding_window
116 |         total_cache_len = self.max_cache_len
117 |         prefix_cache_space = total_cache_len - sliding_window
118 | 
119 |         for batch_idx, cache_idx in enumerate(merge_idxs):
120 |             text_len = text_lengths[batch_idx]
121 |             prefix_len = seq_len - text_len
122 |             self.attention_mask[cache_idx] = 0  # Set default
123 | 
124 |             assert prefix_len > 0, "There are no prefix (image) tokens!"
125 | 
126 |             end_pos = prefix_cache_space
127 |             # Handle prefix part - Which may be left padded
128 |             if prefix_len <= prefix_cache_space:
129 |                 start_pos = prefix_cache_space - prefix_len
130 |                 self.attention_mask[cache_idx, start_pos:end_pos] = (
131 |                     prefill_attention_mask[batch_idx, :prefix_len]
132 |                 )
133 |             else:
134 |                 self.attention_mask[cache_idx, :end_pos] = prefill_attention_mask[
135 |                     batch_idx, prefix_len - prefix_cache_space : prefix_len
136 |                 ]
137 | 
138 |             # Handle text part, keeping sliding window in consideration
139 |             # All of the left padding is before the prefix, so we can ignore the prefill_attention_mask here
140 |             if text_len > 0:
141 |                 text_cache_start = prefix_cache_space
142 |                 if text_len <= sliding_window:
143 |                     self.attention_mask[
144 |                         cache_idx, text_cache_start : text_cache_start + text_len
145 |                     ] = 1
146 |                 else:
147 |                     self.attention_mask[cache_idx, -sliding_window:] = 1
148 | 
149 |     # Slow impl for now - Prefill time is dominated by the large sequence length forward pass
150 |     def _prefill_update(
151 |         self,
152 |         key_cache: torch.Tensor,
153 |         value_cache: torch.Tensor,
154 |         key_states: torch.Tensor,
155 |         value_states: torch.Tensor,
156 |         text_token_counts: torch.Tensor,
157 |         cache_kwargs: Optional[Dict[str, Any]] = None,
158 |     ):
159 |         cache_idxs: List[int] = cache_kwargs.get("cache_idxs", None)
160 |         text_lengths: List[int] = cache_kwargs.get("text_lengths", None)
161 |         assert cache_idxs is not None, "cache_idxs must be specified during prefill"
162 |         assert text_lengths is not None, "text_lengths must be specified during prefill"
163 | 
164 |         _, _, seq_len, _ = key_states.shape
165 |         total_cache_len = self.max_cache_len
166 |         sliding_window = self.text_sliding_window
167 |         prefix_cache_space = total_cache_len - sliding_window
168 | 
169 |         for batch_idx, cache_idx in enumerate(cache_idxs):
170 |             text_len = text_lengths[batch_idx]
171 |             prefix_len = seq_len - text_len
172 | 
173 |             ###### Handle Image Tokens (Prefix) #####
174 |             # Place image tokens in appropriate cache space, aligned to the **right edge**
175 |             assert prefix_len > 0, "There are no prefix (image) tokens!"
176 | 
177 |             # prefix_len may be greater than the prefix cache space due to left padding - This happens when
178 |             # a different batch element has a large input text during prefill, causing others to have a lot of
179 |             # left padding. We can safely take the last `prefix_cache_space` elements from the kv states, since
180 |             # `prefix_cache_space` is large enough to fit any image, and the rest **has to be** padding
181 |             end_pos = prefix_cache_space
182 |             if prefix_len <= prefix_cache_space:
183 |                 start_pos = prefix_cache_space - prefix_len
184 |                 key_cache[cache_idx, :, start_pos:end_pos] = key_states[
185 |                     batch_idx, :, :prefix_len
186 |                 ]
187 |                 value_cache[cache_idx, :, start_pos:end_pos] = value_states[
188 |                     batch_idx, :, :prefix_len
189 |                 ]
190 |             else:
191 |                 key_cache[cache_idx, :, :end_pos] = key_states[
192 |                     batch_idx, :, prefix_len - prefix_cache_space : prefix_len
193 |                 ]
194 |                 value_cache[cache_idx, :, :end_pos] = value_states[
195 |                     batch_idx, :, prefix_len - prefix_cache_space : prefix_len
196 |                 ]
197 | 
198 |             ###### Handle Text Tokens #####
199 |             # Text tokens start at the **left edge** of sliding window cache space
200 |             if text_len > 0:
201 |                 text_cache_start = prefix_cache_space
202 | 
203 |                 if text_len <= sliding_window:
204 |                     key_cache[
205 |                         cache_idx, :, text_cache_start : text_cache_start + text_len
206 |                     ] = key_states[batch_idx, :, prefix_len : prefix_len + text_len]
207 |                     value_cache[
208 |                         cache_idx, :, text_cache_start : text_cache_start + text_len
209 |                     ] = value_states[batch_idx, :, prefix_len : prefix_len + text_len]
210 |                 else:
211 |                     start_in_text = text_len - sliding_window
212 |                     key_cache[
213 |                         cache_idx,
214 |                         :,
215 |                         text_cache_start : text_cache_start + sliding_window,
216 |                     ] = key_states[
217 |                         batch_idx, :, prefix_len + start_in_text : prefix_len + text_len
218 |                     ]
219 |                     value_cache[
220 |                         cache_idx,
221 |                         :,
222 |                         text_cache_start : text_cache_start + sliding_window,
223 |                     ] = value_states[
224 |                         batch_idx, :, prefix_len + start_in_text : prefix_len + text_len
225 |                     ]
226 | 
227 |         # Return the full key/value states (not just cached) for use in subsequent layers
228 |         return key_states, value_states
229 | 
230 |     # """
231 |     # Matches the logic of the decode update, but needs to be called before the updates
232 |     # since some parts of the model depend on the attention mask
233 |     # """
234 |     def decode_attention_mask_update(
235 |         self, num_valid_tokens: torch.Tensor, cache_idxs: List[int]
236 |     ):
237 |         sliding_window = self.text_sliding_window
238 |         text_cache_start = self.max_cache_len - sliding_window
239 | 
240 |         # Use the text_token_counts of the first layer; they should be the same for all layers
241 |         current_text_lens = self.text_token_counts[0]
242 |         cache_idxs_tensor = torch.tensor(cache_idxs, device=current_text_lens.device)
243 | 
244 |         # Get current text lengths for the relevant cache indices
245 |         current_lens = current_text_lens[cache_idxs_tensor]
246 |         new_text_lens = current_lens + num_valid_tokens
247 |         is_full = new_text_lens > sliding_window
248 | 
249 |         # Handle full caches - set entire sliding window to 1
250 |         if is_full.any():
251 |             full_mask = is_full
252 |             full_cache_idxs = cache_idxs_tensor[full_mask]
253 |             self.attention_mask[full_cache_idxs, text_cache_start:] = 1
254 | 
255 |         # Handle non-full caches - set specific ranges to 1
256 |         if (~is_full).any():
257 |             non_full_mask = ~is_full
258 |             non_full_cache_idxs = cache_idxs_tensor[non_full_mask]
259 |             non_full_current_lens = current_lens[non_full_mask]
260 |             non_full_valid_tokens = num_valid_tokens[non_full_mask]
261 | 
262 |             max_valid_tokens = (
263 |                 non_full_valid_tokens.max().item()
264 |                 if len(non_full_valid_tokens) > 0
265 |                 else 0
266 |             )
267 |             if max_valid_tokens > 0:
268 |                 batch_size = len(non_full_cache_idxs)
269 |                 offset_range = torch.arange(
270 |                     max_valid_tokens, device=current_text_lens.device
271 |                 )
272 |                 batch_offsets = offset_range.unsqueeze(0).expand(batch_size, -1)
273 |                 start_positions = non_full_current_lens.unsqueeze(1)
274 |                 valid_token_counts = non_full_valid_tokens.unsqueeze(1)
275 | 
276 |                 position_indices = start_positions + batch_offsets
277 |                 valid_mask = batch_offsets < valid_token_counts
278 | 
279 |                 row_indices = non_full_cache_idxs.unsqueeze(1).expand(
280 |                     -1, max_valid_tokens
281 |                 )[valid_mask]
282 |                 col_indices = text_cache_start + position_indices[valid_mask]
283 | 
284 |                 self.attention_mask[row_indices, col_indices] = 1
285 | 
286 |     """
287 |     Static cache update that:
288 |     - respects per-batch text token limits
289 |     - handles per-batch valid token lengths (right-padded inputs)
290 | 
291 |     kv states are expected to have shape [batch_size, kv_heads, T_pad, head_dim].
292 |     They may have different `true` lengths, to account for multi-token predictions or beacon tokens.
293 |     Expects `num_valid_tokens` in cache_kwargs: a tensor of shape (B,) indicating the number
294 |     of actual (non-padded) tokens to add per batch element.
295 |     """
296 | 
297 |     def _decode_update(
298 |         self,
299 |         key_cache: torch.Tensor,
300 |         value_cache: torch.Tensor,
301 |         key_states: torch.Tensor,
302 |         value_states: torch.Tensor,
303 |         text_token_counts: torch.Tensor,
304 |         cache_kwargs: Optional[Dict[str, Any]] = None,
305 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
306 |         num_valid_tokens: torch.Tensor = cache_kwargs.get(
307 |             "num_valid_tokens"
308 |         )  # shape: (B,)
309 |         assert num_valid_tokens is not None, (
310 |             "`num_valid_tokens` must be provided in `cache_kwargs`"
311 |         )
312 |         device = key_states.device
313 | 
314 |         batch_size, num_head, seq_len, head_dim = key_states.shape
315 |         sliding_window = self.text_sliding_window
316 |         max_cache_len = self.max_cache_len
317 |         cache_text_start = max_cache_len - sliding_window
318 |         new_text_lengths = text_token_counts + num_valid_tokens
319 |         slide_amounts = torch.clamp(new_text_lengths - sliding_window, min=0)
320 |         needs_rotate = slide_amounts > 0
321 | 
322 |         # Rotate the cache if needed
323 |         if torch.any(needs_rotate):
324 |             k_slice = key_cache[:, :, -sliding_window:]  # shape: [B, H, W, D]
325 |             v_slice = value_cache[:, :, -sliding_window:]  # same shape
326 | 
327 |             cache_indices = (
328 |                 torch.arange(sliding_window, device=device)
329 |                 .unsqueeze(0)
330 |                 .repeat(batch_size, 1)
331 |             )  # [B, W]
332 |             rolled_indices = (
333 |                 cache_indices + slide_amounts.unsqueeze(1)
334 |             ) % sliding_window  # [B, W]
335 | 
336 |             # We need to expand indices to shape: [B, 1, W, 1] to broadcast with k_slice
337 |             rolled_indices = (
338 |                 rolled_indices.unsqueeze(1)
339 |                 .unsqueeze(-1)
340 |                 .expand(-1, num_head, -1, head_dim)
341 |             )
342 | 
343 |             k_slice_rolled = k_slice.gather(dim=2, index=rolled_indices)
344 |             v_slice_rolled = v_slice.gather(dim=2, index=rolled_indices)
345 | 
346 |             key_cache[:, :, -sliding_window:] = k_slice_rolled
347 |             value_cache[:, :, -sliding_window:] = v_slice_rolled
348 | 
349 |         # Insert only **valid tokens** into the cache. These are **right aligned** within the input sequence
350 |         insert_positions = torch.where(
351 |             needs_rotate,
352 |             max_cache_len - num_valid_tokens,
353 |             text_token_counts + cache_text_start,
354 |         )
355 | 
356 |         max_tokens = num_valid_tokens.max().item()
357 |         offsets = torch.arange(max_tokens, device=device).unsqueeze(0)  # [1, max_T]
358 |         valid_mask = offsets < num_valid_tokens.unsqueeze(1)  # [B, max_T]
359 |         src_indices = (seq_len - num_valid_tokens).unsqueeze(1) + offsets  # [B, max_T]
360 |         src_indices = src_indices.clamp(max=seq_len - 1)  # safety
361 | 
362 |         tgt_indices = insert_positions.unsqueeze(1) + offsets  # [B, max_T]
363 |         tgt_indices = tgt_indices.clamp(max=max_cache_len - 1)  # safety
364 | 
365 |         src_idx_exp = (
366 |             src_indices.unsqueeze(1)
367 |             .unsqueeze(-1)
368 |             .expand(batch_size, num_head, max_tokens, head_dim)
369 |         )
370 |         tgt_idx_exp = (
371 |             tgt_indices.unsqueeze(1)
372 |             .unsqueeze(-1)
373 |             .expand(batch_size, num_head, max_tokens, head_dim)
374 |         )
375 |         valid_mask_exp = (
376 |             valid_mask.unsqueeze(1)
377 |             .unsqueeze(-1)
378 |             .expand(batch_size, num_head, max_tokens, head_dim)
379 |         )
380 | 
381 |         k_src = torch.gather(key_states, 2, src_idx_exp)
382 |         v_src = torch.gather(value_states, 2, src_idx_exp)
383 |         k_src = k_src * valid_mask_exp
384 |         v_src = v_src * valid_mask_exp
385 | 
386 |         # Write into cache
387 |         key_cache.scatter_(2, tgt_idx_exp, k_src)
388 |         value_cache.scatter_(2, tgt_idx_exp, v_src)
389 | 
390 |         # In-place edit - Mutates
391 |         text_token_counts += num_valid_tokens
392 |         text_token_counts.clamp_(max=sliding_window)
393 | 
394 |         return key_cache, value_cache
395 | 
396 |     # We have a non-uniform cache, so it's better not to return a sequence length here and to
397 |     # handle any logic that requires it ourselves
398 |     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
399 |         raise NotImplementedError()
400 | 
```
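
To make the layout described in the module docstring concrete: the first `max_cache_len - text_sliding_window` slots hold the image (prefix) tokens, right-aligned, and the last `text_sliding_window` slots hold text tokens. When the text window overflows, `_decode_update` rolls the window with modular index gathering so the oldest entries wrap to the right edge, where the incoming tokens are then scattered over them. A small sketch with toy sizes (illustrative numbers only, not the model's real configuration):

```python
import torch

# Toy sizes: a 12-slot cache with a 4-token text window leaves 8 slots
# for image (prefix) tokens.
max_cache_len, sliding_window = 12, 4
prefix_cache_space = max_cache_len - sliding_window  # 8

# Mirror of the roll in _decode_update for one batch element that must
# slide by 2 positions to make room for 2 new text tokens.
slide_amount = torch.tensor([2])
window_idx = torch.arange(sliding_window).unsqueeze(0)                  # [[0, 1, 2, 3]]
rolled_idx = (window_idx + slide_amount.unsqueeze(1)) % sliding_window  # [[2, 3, 0, 1]]

cached_text = torch.tensor([[10, 11, 12, 13]])  # stand-in for cached text keys
print(cached_text.gather(1, rolled_idx))
# tensor([[12, 13, 10, 11]]) - the oldest entries (10, 11) now sit at the right
# edge, where the new tokens get written next.
```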

--------------------------------------------------------------------------------
/surya/table_rec/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | from copy import deepcopy
  2 | from itertools import chain
  3 | from typing import List
  4 | 
  5 | import numpy as np
  6 | import torch
  7 | from PIL import Image
  8 | from tqdm import tqdm
  9 | 
 10 | from surya.common.xla import mark_step
 11 | from surya.common.predictor import BasePredictor
 12 | from surya.table_rec.schema import TableCell, TableRow, TableCol, TableResult
 13 | from surya.common.polygon import PolygonBox
 14 | from surya.settings import settings
 15 | from surya.table_rec.loader import TableRecModelLoader
 16 | from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM, CATEGORY_TO_ID, MERGE_KEYS, \
 17 |     MERGE_VALUES
 18 | from surya.table_rec.shaper import LabelShaper
 19 | 
 20 | 
 21 | class TableRecPredictor(BasePredictor):
 22 |     model_loader_cls = TableRecModelLoader
 23 |     batch_size = settings.TABLE_REC_BATCH_SIZE
 24 |     default_batch_sizes = {
 25 |         "cpu": 8,
 26 |         "mps": 8,
 27 |         "cuda": 32,
 28 |         "xla": 16
 29 |     }
 30 | 
 31 |     def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]:
 32 |         return self.batch_table_recognition(images, batch_size)
 33 | 
 34 |     def inference_loop(
 35 |             self,
 36 |             encoder_hidden_states: torch.Tensor,
 37 |             batch_input_ids: torch.Tensor,
 38 |             current_batch_size: int,
 39 |             batch_size: int
 40 |     ):
 41 |         shaper = LabelShaper()
 42 |         batch_predictions = [[] for _ in range(current_batch_size)]
 43 |         max_tokens = settings.TABLE_REC_MAX_BOXES
 44 |         decoder_position_ids = torch.ones_like(batch_input_ids[0, :, 0], dtype=torch.int64, device=self.model.device).cumsum(
 45 |             0) - 1
 46 |         inference_token_count = batch_input_ids.shape[1]
 47 | 
 48 |         if settings.TABLE_REC_STATIC_CACHE:
 49 |             encoder_hidden_states = self.pad_to_batch_size(encoder_hidden_states, batch_size)
 50 |             batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)
 51 | 
 52 |         # Move to device after padding for XLA
 53 |         encoder_hidden_states = encoder_hidden_states.to(self.model.device)
 54 |         batch_input_ids = batch_input_ids.to(self.model.device)
 55 | 
 56 |         self.model.decoder.model._setup_cache(self.model.config, batch_size, self.model.device, self.model.dtype)
 57 | 
 58 |         with settings.INFERENCE_MODE():
 59 |             token_count = 0
 60 |             all_done = torch.zeros(encoder_hidden_states.shape[0], dtype=torch.bool, device=self.model.device)
 61 | 
 62 |             while token_count < max_tokens:
 63 |                 is_prefill = token_count == 0
 64 |                 return_dict = self.model.decoder(
 65 |                     input_ids=batch_input_ids,
 66 |                     encoder_hidden_states=encoder_hidden_states,
 67 |                     cache_position=decoder_position_ids,
 68 |                     use_cache=True,
 69 |                     prefill=is_prefill
 70 |                 )
 71 | 
 72 |                 decoder_position_ids = decoder_position_ids[-1:] + 1
 73 | 
 74 |                 # Get predictions for each box element
 75 |                 box_properties = []
 76 |                 done = []
 77 | 
 78 |                 # Pre-process all logits at once
 79 |                 processed_logits = {}
 80 |                 for k, _, mode in BOX_PROPERTIES:
 81 |                     k_logits = return_dict["box_property_logits"][k][:, -1, :]  # Get all batch logits at once
 82 |                     
 83 |                     if mode == "classification":
 84 |                         # Process all classification logits in one operation
 85 |                         items = torch.argmax(k_logits, dim=-1)
 86 |                         if k == "category":
 87 |                             done = (items == self.model.decoder.config.eos_token_id) | (items == self.model.decoder.config.pad_token_id)
 88 |                         items = items - SPECIAL_TOKENS
 89 |                         processed_logits[k] = items
 90 |                     elif mode == "regression":
 91 |                         if k == "bbox":
 92 |                             k_logits = k_logits * BOX_DIM
 93 |                             processed_logits[k] = k_logits
 94 |                         elif k == "colspan":
 95 |                             k_logits = torch.clamp(k_logits, min=1)
 96 |                             processed_logits[k] = torch.round(k_logits)
 97 | 
 98 |                 items = {k: processed_logits[k].cpu() for k, _, _ in BOX_PROPERTIES}
 99 |                 for j in range(current_batch_size):
100 |                     box_property = {}
101 |                     for k, _, mode in BOX_PROPERTIES:
102 |                         if mode == "classification":
103 |                             box_property[k] = int(items[k][j].item())
104 |                         elif mode == "regression":
105 |                             if k == "bbox":
106 |                                 box_property[k] = items[k][j].tolist()
107 |                             elif k == "colspan":
108 |                                 box_property[k] = int(items[k][j].item())
109 |                     box_properties.append(box_property)
110 | 
111 |                 all_done = all_done | done
112 |                 all_done_cpu = all_done.cpu()
113 | 
114 |                 if all_done_cpu[:current_batch_size].all():
115 |                     break
116 | 
117 |                 batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long)
118 |                 batch_input_ids = batch_input_ids.unsqueeze(1)  # Add sequence length dimension
119 | 
120 |                 for j, (box_property, status) in enumerate(zip(box_properties, all_done_cpu)):
121 |                     if not status:
122 |                         batch_predictions[j].append(box_property)
123 | 
124 |                 token_count += inference_token_count
125 |                 inference_token_count = batch_input_ids.shape[1]
126 | 
127 |                 if settings.TABLE_REC_STATIC_CACHE:
128 |                     batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)
129 | 
130 |                 # Move to device after padding for XLA
131 |                 batch_input_ids = batch_input_ids.to(self.model.device)
132 |         return batch_predictions
133 | 
134 |     def batch_table_recognition(
135 |             self,
136 |             images: List,
137 |             batch_size=None) -> List[TableResult]:
138 |         assert all([isinstance(image, Image.Image) for image in images])
139 |         if batch_size is None:
140 |             batch_size = self.get_batch_size()
141 | 
142 |         if len(images) == 0:
143 |             return []
144 | 
145 |         query_items = []
146 |         for image in images:
147 |             query_items.append({
148 |                 "polygon": [[0, 0], [image.width, 0], [image.width, image.height], [0, image.height]],
149 |                 "category": CATEGORY_TO_ID["Table"],
150 |                 "colspan": 0,
151 |                 "merges": 0,
152 |                 "is_header": 0
153 |             })
154 | 
155 |         output_order = []
156 |         for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables", disable=self.disable_tqdm):
157 |             batch_query_items = query_items[i:i + batch_size]
158 | 
159 |             batch_images = images[i:i + batch_size]
160 |             batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images
161 | 
162 |             current_batch_size = len(batch_images)
163 | 
164 |             orig_sizes = [image.size for image in batch_images]
165 |             model_inputs = self.processor(images=batch_images, query_items=batch_query_items)
166 | 
167 |             batch_pixel_values = model_inputs["pixel_values"]
168 | 
169 |             batch_input_ids = model_inputs["input_ids"]
170 |             batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=self.model.dtype)
171 | 
172 |             if settings.TABLE_REC_STATIC_CACHE:
173 |                 batch_pixel_values = self.pad_to_batch_size(batch_pixel_values, batch_size)
174 | 
175 |             # Move to device after padding for XLA
176 |             batch_pixel_values = batch_pixel_values.to(self.model.device)
177 | 
178 |             shaper = LabelShaper()
179 | 
180 |             # We only need to process each image once
181 |             with settings.INFERENCE_MODE():
182 |                 encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state
183 | 
184 |             # Inference to get rows and columns
185 |             rowcol_predictions = self.inference_loop(
186 |                 encoder_hidden_states,
187 |                 batch_input_ids,
188 |                 current_batch_size,
189 |                 batch_size
190 |             )
191 |             mark_step()
192 | 
193 |             row_query_items = []
194 |             row_encoder_hidden_states = []
195 |             idx_map = []
196 |             columns = []
197 |             for j, img_predictions in enumerate(rowcol_predictions):
198 |                 for row_prediction in img_predictions:
199 |                     polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
200 |                     if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]:
201 |                         row_query_items.append({
202 |                             "polygon": polygon,
203 |                             "category": row_prediction["category"],
204 |                             "colspan": 0,
205 |                             "merges": 0,
206 |                             "is_header": int(row_prediction["is_header"] == 1)
207 |                         })
208 |                         row_encoder_hidden_states.append(encoder_hidden_states[j])
209 |                         idx_map.append(j)
210 |                     elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]:
211 |                         columns.append({
212 |                             "polygon": polygon,
213 |                             "category": row_prediction["category"],
214 |                             "colspan": 0,
215 |                             "merges": 0,
216 |                             "is_header": int(row_prediction["is_header"] == 1)
217 |                         })
218 | 
219 |             # Re-inference to predict cells
220 |             row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
221 |             row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)
222 |             row_input_ids = row_inputs["input_ids"]
223 |             cell_predictions = []
224 |             for j in range(0, len(row_input_ids), batch_size):
225 |                 cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size]
226 |                 cell_batch_input_ids = row_input_ids[j:j + batch_size]
227 |                 cell_batch_size = len(cell_batch_input_ids)
228 |                 cell_predictions.extend(
229 |                     self.inference_loop(cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
230 |                 )
231 |                 mark_step()
232 | 
233 |             result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper)
234 |             output_order.extend(result)
235 | 
236 |         return output_order
237 | 
238 | 
239 |     def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper):
240 |         results = []
241 |         for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)):
242 |             row_cell_predictions = [c for i, c in enumerate(cell_predictions) if idx_map[i] == j]
243 |             # Each row prediction matches a cell prediction
244 |             rows = []
245 |             cells = []
246 |             columns = []
247 | 
248 |             cell_id = 0
249 |             row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]]
250 |             col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]]
251 | 
252 |             # Generate table columns
253 |             for z, col_prediction in enumerate(col_predictions):
254 |                 polygon = shaper.convert_bbox_to_polygon(col_prediction["bbox"])
255 |                 polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
256 |                 columns.append(
257 |                     TableCol(
258 |                         polygon=polygon,
259 |                         col_id=z,
260 |                         is_header=col_prediction["is_header"] == 1
261 |                     )
262 |                 )
263 | 
264 |             # Generate table rows
265 |             for z, row_prediction in enumerate(row_predictions):
266 |                 polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
267 |                 polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
268 |                 row = TableRow(
269 |                     polygon=polygon,
270 |                     row_id=z,
271 |                     is_header=row_prediction["is_header"] == 1
272 |                 )
273 |                 rows.append(row)
274 | 
275 |                 # Get cells that span multiple columns within a row
276 |                 spanning_cells = []
277 |                 for l, spanning_cell in enumerate(row_cell_predictions[z]):
278 |                     polygon = shaper.convert_bbox_to_polygon(spanning_cell["bbox"])
279 |                     polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
280 |                     colspan = max(1, int(spanning_cell["colspan"]))
281 |                     if colspan == 1 and spanning_cell["merges"] not in MERGE_VALUES:
282 |                         # Skip single column cells if they don't merge
283 |                         continue
284 |                     if PolygonBox(polygon=polygon).height < row.height * .85:
285 |                         # Spanning cell must cover most of the row
286 |                         continue
287 | 
288 |                     spanning_cells.append(
289 |                         TableCell(
290 |                             polygon=polygon,
291 |                             row_id=z,
292 |                             rowspan=1,
293 |                             cell_id=cell_id,
294 |                             within_row_id=l,
295 |                             colspan=colspan,
296 |                             merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]],
297 |                             merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"],
298 |                                                                    MERGE_KEYS["merge_both"]],
299 |                             is_header=row.is_header or z == 0
300 |                         )
301 |                     )
302 |                     cell_id += 1
303 | 
304 |                 # Add cells - either add spanning cells (multiple cols), or generate a cell based on row/col
305 |                 used_spanning_cells = set()
306 |                 skip_columns = 0
307 |                 for l, col in enumerate(columns):
308 |                     if skip_columns:
309 |                         skip_columns -= 1
310 |                         continue
311 |                     cell_polygon = row.intersection_polygon(col)
312 |                     cell_added = False
313 |                     for zz, spanning_cell in enumerate(spanning_cells):
314 |                         cell_polygonbox = PolygonBox(polygon=cell_polygon)
315 |                         intersection_pct = cell_polygonbox.intersection_pct(spanning_cell)
316 |                         # Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns)
317 |                         correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]])
318 |                         if intersection_pct > .9:
319 |                             if spanning_cell.width > (correct_col_width * .85):
320 |                                 cell_added = True
321 |                                 if zz not in used_spanning_cells:
322 |                                     used_spanning_cells.add(zz)
323 |                                     spanning_cell.col_id = l
324 |                                     cells.append(spanning_cell)
325 |                                     skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell
326 |                             else:
327 |                                 used_spanning_cells.add(zz) # Skip this spanning cell
328 | 
329 |                     if not cell_added:
330 |                         cells.append(
331 |                             TableCell(
332 |                                 polygon=cell_polygon,
333 |                                 row_id=z,
334 |                                 rowspan=1,
335 |                                 cell_id=cell_id,
336 |                                 within_row_id=l,
337 |                                 colspan=1,
338 |                                 merge_up=False,
339 |                                 merge_down=False,
340 |                                 col_id=l,
341 |                                 is_header=row.is_header or col.is_header or z == 0
342 |                             )
343 |                         )
344 |                         cell_id += 1
345 | 
346 |             # Turn cells into a row grid
347 |             grid_cells = deepcopy([
348 |                 [cell for cell in cells if cell.row_id == row.row_id]
349 |                 for row in rows
350 |             ])
351 | 
352 |             # Merge cells across rows
353 |             for z, grid_row in enumerate(grid_cells[1:]):
354 |                 prev_row = grid_cells[z]
355 |                 for l, cell in enumerate(grid_row):
356 |                     if l >= len(prev_row):
357 |                         continue
358 | 
359 |                     above_cell = prev_row[l]
360 |                     if all([
361 |                         above_cell.merge_down,
362 |                         cell.merge_up,
363 |                         above_cell.col_id == cell.col_id,
364 |                         above_cell.colspan == cell.colspan,
365 |                     ]):
366 |                         above_cell.merge(cell)
367 |                         above_cell.rowspan += cell.rowspan
368 |                         grid_row[l] = above_cell
369 | 
370 |             merged_cells_all = list(chain.from_iterable(grid_cells))
371 |             used_ids = set()
372 |             merged_cells = []
373 |             for cell in merged_cells_all:
374 |                 if cell.cell_id in used_ids:
375 |                     continue
376 |                 used_ids.add(cell.cell_id)
377 |                 merged_cells.append(cell)
378 | 
379 |             result = TableResult(
380 |                 cells=merged_cells,
381 |                 unmerged_cells=cells,
382 |                 rows=rows,
383 |                 cols=columns,
384 |                 image_bbox=[0, 0, orig_size[0], orig_size[1]],
385 |             )
386 |             results.append(result)
387 |         return results
388 | 
```
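
A short usage sketch for the predictor above (assuming, as with the other surya predictors, that it loads its default checkpoint when constructed with no arguments; `table.png` is a placeholder path for a cropped table image):

```python
from PIL import Image

from surya.table_rec import TableRecPredictor

predictor = TableRecPredictor()

# Placeholder path; in practice this is a table region cropped from a page.
image = Image.open("table.png")
[result] = predictor([image])

# Each TableResult carries the merged cells, the unmerged cell grid, and the
# detected rows and columns, with polygons in original image coordinates.
print(len(result.rows), len(result.cols), len(result.cells))
for cell in result.cells:
    print(cell.row_id, cell.col_id, cell.colspan, cell.polygon)
```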

--------------------------------------------------------------------------------
/surya/common/surya/processor/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | import math
  2 | 
  3 | import cv2
  4 | import numpy as np
  5 | import torch
  6 | from PIL import Image
  7 | from torch.nn.utils.rnn import pad_sequence
  8 | 
  9 | from typing import List, Optional, Tuple
 10 | 
 11 | from transformers.feature_extraction_utils import BatchFeature
 12 | from transformers.processing_utils import ProcessorMixin
 13 | from transformers.tokenization_utils import PreTrainedTokenizer
 14 | 
 15 | from surya.common.s3 import S3DownloaderMixin
 16 | from surya.common.surya.processor.schema import (
 17 |     TextInput,
 18 |     ImageInput,
 19 |     ProcessorOutput,
 20 | )
 21 | from surya.common.surya.schema import TaskNames
 22 | from surya.logging import get_logger
 23 | from surya.settings import settings
 24 | 
 25 | logger = get_logger()
 26 | 
 27 | # Task-agnostic tokens - Every task will use these in some form or another
 28 | EOS_TOKEN = "</S>"
 29 | EOI_TOKEN = "<EOI>"  # This is end of INPUT, not image. Images are always followed by a task-specific BOS token, so that serves as a delimiter anyway.
 30 | IMAGE_TOKEN = "<IMAGE>"
 31 | PAD_TOKEN = "<PAD>"
 32 | NO_OUTPUT_TOKEN = "<NOP>"
 33 | IMAGE_ROTATED_TOKEN = "<ROT>"
 34 | REGISTER_TOKENS = ["<REG1>", "<REG2>", "<REG3>", "<REG4>"]
 35 | BEACON_TOKEN = "<BEACON>"
 36 | NOMATH_TOKEN = "<NO-MATH>"
 37 | 
 38 | # Task-specific tokens
 39 | OCR_WITH_BOXES_BOS_TOKEN = "<OCR-WB>"
 40 | OCR_WITHOUT_BOXES_BOS_TOKEN = "<OCR-WOB>"
 41 | BLOCK_WITHOUT_BOXES_TOKEN = "<BLOCKS-WOB>"
 42 | LAYOUT_BOS_TOKEN = "<LAYOUT>"
 43 | TABLE_STRUCTURE_BOS_TOKEN = "<TABLE-STRUCTURE>"
 44 | 
 45 | 
 46 | class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin):
 47 |     attributes = ["image_processor", "ocr_tokenizer"]
 48 |     image_processor_class = "BaseImageProcessor"
 49 |     ocr_tokenizer_class = "PreTrainedTokenizer"
 50 |     rescale_factor = 1 / 255.0
 51 |     image_mean = (0.485, 0.456, 0.406)
 52 |     image_std = (0.229, 0.224, 0.225)
 53 | 
 54 |     def __init__(
 55 |         self,
 56 |         ocr_tokenizer: PreTrainedTokenizer,
 57 |         blank_bbox_token_id: int,
 58 |         num_register_tokens: int,
 59 |         patch_size: int,
 60 |         merge_size: int,
 61 |         num_beacon_tokens: int,
 62 |         beacon_token_interval: int,
 63 |         model_device: str,
 64 |         **kwargs,
 65 |     ):
 66 |         self.ocr_tokenizer = ocr_tokenizer
 67 |         self.patch_size = patch_size
 68 |         self.merge_size = merge_size
 69 |         self.num_register_tokens = num_register_tokens
 70 |         self.num_beacon_tokens = num_beacon_tokens
 71 |         self.beacon_token_interval = beacon_token_interval
 72 | 
 73 |         self.tokenizer_vocab_size = 0
 74 |         for attr in self.attributes:
 75 |             if "tokenizer" in attr:
 76 |                 self.tokenizer_vocab_size += getattr(self, attr).vocab_size
 77 | 
 78 |         self.offsets = {"ocr": 0}
 79 | 
 80 |         # Create special token mapping
 81 |         self.special_token_mapping = self.ocr_tokenizer.system_tokens
 82 | 
 83 |         self.register_token_ids = [
 84 |             self.special_token_mapping.get(r) for r in REGISTER_TOKENS
 85 |         ]
 86 |         self.beacon_token_id = self.special_token_mapping.get(BEACON_TOKEN)
 87 |         self.image_token_id = self.special_token_mapping.get(IMAGE_TOKEN)
 88 |         self.pad_token_id = self.special_token_mapping.get(PAD_TOKEN)
 89 |         self.eos_token_id = self.special_token_mapping.get(EOS_TOKEN)
 90 |         self.eoi_token_id = self.special_token_mapping.get(EOI_TOKEN)
 91 |         self.no_output_token = self.special_token_mapping.get(NO_OUTPUT_TOKEN)
 92 |         self.image_rotated_token = self.special_token_mapping.get(IMAGE_ROTATED_TOKEN)
 93 |         self.nomath_token = self.special_token_mapping.get(NOMATH_TOKEN)
 94 | 
 95 |         self.bos_token_id = {
 96 |             TaskNames.ocr_with_boxes: self.special_token_mapping.get(
 97 |                 OCR_WITH_BOXES_BOS_TOKEN
 98 |             ),
 99 |             TaskNames.ocr_without_boxes: self.special_token_mapping.get(
100 |                 OCR_WITHOUT_BOXES_BOS_TOKEN
101 |             ),
102 |             TaskNames.block_without_boxes: self.special_token_mapping.get(
103 |                 BLOCK_WITHOUT_BOXES_TOKEN
104 |             ),
105 |             TaskNames.layout: self.special_token_mapping.get(LAYOUT_BOS_TOKEN),
106 |             TaskNames.table_structure: self.special_token_mapping.get(
107 |                 TABLE_STRUCTURE_BOS_TOKEN
108 |             ),
109 |         }
110 | 
111 |         if self.image_token_id is None:
112 |             logger.warning("Warning: Image token not found in special tokens")
113 | 
114 |         self.blank_bbox_token_id = blank_bbox_token_id
115 |         self.bbox_pad_token_id = self.blank_bbox_token_id
116 | 
117 |         self.ignore_bbox_token_ids = [
118 |             v
119 |             for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
120 |             if k not in self.ocr_tokenizer.special_tokens["math_external"]
121 |         ]
122 |         math_end_token = "</math>"
123 |         self.math_start_token_ids = [
124 |             v
125 |             for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
126 |             if k in self.ocr_tokenizer.special_tokens["math_external"]
127 |             and k != math_end_token
128 |         ]
129 |         self.math_end_token_ids = [
130 |             v
131 |             for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
132 |             if k == math_end_token
133 |         ]
134 | 
135 |         if self.num_register_tokens > len(self.register_token_ids):
136 |             raise ValueError(
137 |                 "The number of register tokens requested exceeds the number of register tokens defined in the special token mapping."
138 |             )
139 | 
140 |         self.image_mean = np.array(self.image_mean, dtype=np.float32)
141 |         self.image_std = np.array(self.image_std, dtype=np.float32)
142 |         self.model_device = model_device
143 | 
144 |     @property
145 |     def vocab_size(self):
146 |         return self.tokenizer_vocab_size
147 | 
148 |     def image_processor(self, image: Image.Image) -> np.ndarray:
149 |         # Convert to array
150 |         image = np.asarray(image, dtype=np.float32)
151 |         return image
152 | 
153 |     @staticmethod
154 |     def scale_to_fit(
155 |         img: np.ndarray,
156 |         max_size: Tuple[int, int],
157 |         min_size: Tuple[int, int] = (168, 168),
158 |     ):
159 |         # Get current dimensions
160 |         height, width = img.shape[:2]
161 | 
162 |         # Check for empty or invalid image
163 |         if width == 0 or height == 0:
164 |             return img
165 | 
166 |         max_width, max_height = max_size
167 |         min_width, min_height = min_size
168 | 
169 |         # Calculate pixel counts
170 |         current_pixels = width * height
171 |         max_pixels = max_width * max_height
172 |         min_pixels = min_width * min_height
173 | 
174 |         if current_pixels > max_pixels:
175 |             scale_factor = (max_pixels / current_pixels) ** 0.5
176 | 
177 |             new_width = math.floor(width * scale_factor)
178 |             new_height = math.floor(height * scale_factor)
179 |         elif current_pixels == 0:
180 |             return img
181 |         elif current_pixels < min_pixels:
182 |             scale_factor = (min_pixels / current_pixels) ** 0.5
183 | 
184 |             new_width = math.ceil(width * scale_factor)
185 |             new_height = math.ceil(height * scale_factor)
186 |         else:
187 |             return img
188 | 
189 |         return cv2.resize(
190 |             img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4
191 |         )
192 | 
193 |     def _image_processor(self, image: np.ndarray):
194 |         image = image.astype(np.float64) * self.rescale_factor
195 |         image = (image.astype(np.float32) - self.image_mean) / self.image_std
196 |         return image
197 | 
198 |     def _process_and_tile(
199 |         self, image: np.ndarray
200 |     ) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
201 |         """
202 |         Resizes the input image to the closest multiple of tile_size while preserving the aspect ratio
203 |         and returns a tensor of image tiles.
204 |         """
205 |         extra_multiplier = (
206 |             4 if settings.FOUNDATION_XLA else 1
207 |         )  # Needed to force the same grid_thw size per row with padding
208 | 
209 |         factor = (
210 |             self.patch_size * self.merge_size * extra_multiplier
211 |         )  # Make a multiple of window size
212 | 
213 |         height, width = image.shape[:2]
214 | 
215 |         h_bar = math.ceil(height / factor) * factor
216 |         w_bar = math.ceil(width / factor) * factor
217 |         if h_bar != height or w_bar != width:
218 |             if height == 0 or width == 0:
219 |                 image = np.zeros((h_bar, w_bar, 3), dtype=np.uint8)
220 |             else:
221 |                 image = cv2.resize(image, (w_bar, h_bar), interpolation=cv2.INTER_CUBIC)
222 | 
223 |         # Handle scaling and normalization
224 |         image = self._image_processor(image)
225 |         height, width = image.shape[:2]
226 | 
227 |         # Numpy array to torch tensor
228 |         img_tensor = torch.from_numpy(image.transpose(2, 0, 1))
229 |         patches = img_tensor.unsqueeze(0)
230 | 
231 |         channel = patches.shape[1]
232 |         grid_t = patches.shape[0]
233 |         grid_h, grid_w = height // self.patch_size, width // self.patch_size
234 | 
235 |         patches = patches.reshape(
236 |             grid_t,
237 |             1,
238 |             channel,
239 |             grid_h // self.merge_size,
240 |             self.merge_size,
241 |             self.patch_size,
242 |             grid_w // self.merge_size,
243 |             self.merge_size,
244 |             self.patch_size,
245 |         )
246 |         patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
247 |         flatten_patches = patches.reshape(
248 |             grid_t * grid_h * grid_w, channel * 1 * self.patch_size * self.patch_size
249 |         )
250 | 
251 |         return flatten_patches, (grid_t, grid_h, grid_w)
252 | 
253 |     # Handle image input dictionaries - process the image, tile it accordingly, and set up the corresponding input ids and boxes
254 |     def _process_image_input(self, image_input: ImageInput) -> ProcessorOutput:
255 |         rotated = image_input.get("rotated", False)
256 |         image = image_input.get("image", None)
257 | 
258 |         assert image is not None, (
259 |             "A PIL Image must be provided when the input type is `image`"
260 |         )
261 |         image_tiles, grid_thw = self._process_and_tile(image)
262 | 
263 |         num_tokens = image_tiles.shape[0] / self.merge_size**2
264 |         assert num_tokens.is_integer(), (
265 |             f"Expected number of tokens to be an integer, got {num_tokens}"
266 |         )
267 | 
268 |         input_ids = [self.image_token_id] * int(num_tokens)
269 |         input_ids += self.register_token_ids[: self.num_register_tokens]
270 | 
271 |         # Handle the image having been rotated in the dataset
272 |         if rotated:
273 |             input_ids = [self.image_rotated_token] + input_ids
274 | 
275 |         return ProcessorOutput(
276 |             input_ids=input_ids,
277 |             image_tiles=image_tiles,
278 |             grid_thw=grid_thw,
279 |         )
280 | 
281 |     def _process_text_input(self, text_input: TextInput, task: str) -> ProcessorOutput:
282 |         input_text = text_input.get("text", None)
283 |         math_mode = text_input.get("math", False)
284 | 
285 |         input_ids = self.ocr_tokenizer(input_text, tasks=task)["input_ids"][0]
286 |         input_ids = [self.offsets["ocr"] + id for id in input_ids]
287 | 
288 |         # nomath token does not work for layout
289 |         if not math_mode and task != "layout":
290 |             input_ids.insert(0, self.nomath_token)
291 | 
292 |         return ProcessorOutput(
293 |             input_ids=input_ids,
294 |             image_tiles=None,
295 |             grid_thw=None,
296 |         )
297 | 
298 |     def _process_input(self, input_dict: dict, task: str):
299 |         input_type = input_dict["type"]
300 |         if input_type == "image":
301 |             return self._process_image_input(input_dict)
302 |         elif input_type == "text":
303 |             return self._process_text_input(input_dict, task)
304 | 
305 |         raise NotImplementedError(f"Input of type `{input_type}` is not implemented")
306 | 
307 |     # Preprocessing for the OCR task
308 |     # The task is expected to have - image_dict, user_input_dict, output_dict
309 |     # user_input_dict is allowed to have an empty input, which is fine, but it needs to be present
310 |     def _process_ocr_with_boxes(
311 |         self,
312 |         mixed_input: List[dict],
313 |         bos_token_id: int,
314 |         task: str = TaskNames.ocr_with_boxes,
315 |     ):
316 |         processed_input_ids = []
317 |         all_image_tiles = []
318 |         all_grid_thw = []
319 | 
320 |         # 1. Process the image input
321 |         for i, input_dict in enumerate(mixed_input):
322 |             processor_output = self._process_input(input_dict, task)
323 |             input_ids = processor_output["input_ids"]
324 |             image_tiles = processor_output["image_tiles"]
325 |             grid_thw = processor_output["grid_thw"]
326 | 
327 |             # Special handling of some delimiter tokens
328 |             if i == 1:
329 |                 assert input_dict["type"] == "text", (
330 |                     "Expected text input for model input."
331 |                 )
332 |                 # Case for input - Add task specific bos token + end_of_input token
333 |                 # We do not want the model to learn how to predict inputs. Hence IGNORE_INDEX for these
334 |                 input_ids = [bos_token_id] + input_ids + [self.eoi_token_id]
335 |             if i == 2:
336 |                 assert input_dict["type"] == "text", (
337 |                     "Expected text for final model input"
338 |                 )
339 |                 input_ids = input_ids + [self.eos_token_id]
340 |             elif i > 2:
341 |                 raise ValueError(f"Too many inputs received. Expected 2 for inference or 3 for training. Received: {len(mixed_input)}")
342 | 
343 |             # Some input types don't return any image tiles, accounting for that
344 |             if image_tiles is not None:
345 |                 all_image_tiles.append(image_tiles)
346 |                 all_grid_thw.append(grid_thw)
347 | 
348 |             processed_input_ids.extend(input_ids)
349 | 
350 |         return (
351 |             torch.tensor(processed_input_ids, dtype=torch.long),
352 |             all_image_tiles,
353 |             all_grid_thw,
354 |         )
355 | 
356 |     def _process_layout(self, mixed_input: List[dict], bos_token_id: int):
357 |         return self._process_ocr_with_boxes(
358 |             mixed_input, bos_token_id=bos_token_id, task="layout"
359 |         )
360 | 
361 |     def _process_table_structure(self, mixed_input: List[dict], bos_token_id: int):
362 |         return self._process_ocr_with_boxes(
363 |             mixed_input, bos_token_id=bos_token_id, task="table_structure"
364 |         )
365 | 
366 |     def _process_ocr_without_boxes(
367 |         self,
368 |         mixed_input: List[dict],
369 |         bos_token_id: int,
370 |         task: str = "ocr_without_boxes",
371 |     ):
372 |         # Boxes are set to None, so this will work
373 |         # TODO: improve this behavior
374 |         return self._process_ocr_with_boxes(
375 |             mixed_input, bos_token_id=bos_token_id, task=task
376 |         )
377 | 
378 |     def _process_block_without_boxes(
379 |         self,
380 |         mixed_input: List[dict],
381 |         bos_token_id: int,
382 |         task: str = "block_without_boxes",
383 |     ):
384 |         return self._process_ocr_with_boxes(
385 |             mixed_input, bos_token_id=bos_token_id, task=task
386 |         )
387 | 
388 |     def align_long_axis(self, image: np.ndarray) -> Tuple[np.ndarray, bool]:
389 |         height, width, _ = image.shape
390 |         if height > width:  # Rotate vertical lines
391 |             image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
392 |             return image, True
393 | 
394 |         return image, False
395 | 
396 |     def __call__(
397 |         self,
398 |         mixed_batch: List[dict],
399 |         padding_side: Optional[str] = "left",
400 |         device: Optional[torch.device] = None,
401 |         pad_to_multiple: Optional[int] = None,
402 |     ):
403 |         all_image_tiles = []
404 |         all_input_ids = []
405 |         all_grid_thw = []
406 | 
407 |         for b in mixed_batch:
408 |             mixed_input = b["inputs"]
409 |             task = b["task"]
410 |             assert task in self.bos_token_id, f"Task {task} has no bos token defined."
411 | 
412 |             # Select the correct processing function based on the task type
413 |             input_ids, image_tiles, grid_thw = getattr(self, f"_process_{task}")(
414 |                 mixed_input, self.bos_token_id[task]
415 |             )
416 | 
417 |             all_input_ids.append(input_ids)
418 |             all_image_tiles.extend(image_tiles)
419 |             all_grid_thw.extend(grid_thw)
420 | 
421 |         batched_input_ids = pad_sequence(
422 |             all_input_ids,
423 |             batch_first=True,
424 |             padding_side=padding_side,
425 |             padding_value=self.pad_token_id,
426 |         )
427 | 
428 |         if pad_to_multiple is not None:
429 |             current_len = batched_input_ids.shape[1]
430 |             # Calculate the next multiple of pad_to_multiple
431 |             padded_len = (
432 |                 (current_len + pad_to_multiple - 1) // pad_to_multiple
433 |             ) * pad_to_multiple
434 | 
435 |             if padded_len > current_len:
436 |                 pad_len = padded_len - current_len
437 |                 batched_input_ids = torch.nn.functional.pad(
438 |                     batched_input_ids, (pad_len, 0), value=self.pad_token_id
439 |                 )
440 | 
441 |         attention_mask = batched_input_ids.ne(self.pad_token_id)
442 | 
443 |         # Generate position IDs that are independent of left and right padding;
444 |         # this should ensure the same results for either padding side. The exact position ids of the pad tokens themselves don't matter since they are masked
445 |         position_ids = attention_mask.cumsum(dim=-1) - 1
446 |         position_ids[position_ids < 0] = (
447 |             0  # For left padding, the position ids for padding will become -1 because of the shift; Setting to 0
448 |         )
449 |         position_ids = (
450 |             attention_mask.to(torch.long) * position_ids
451 |         )  # Ensure right pad ids get set to zero
452 | 
453 |         batched_image_tiles = torch.cat(all_image_tiles, dim=0)
454 |         batched_grid_thw = torch.from_numpy(np.array(all_grid_thw))
455 | 
456 |         # Pin memory for CUDA
457 |         if device == torch.device("cuda"):
458 |             batched_image_tiles = batched_image_tiles.pin_memory()
459 |             batched_grid_thw = batched_grid_thw.pin_memory()
460 |             attention_mask = attention_mask.pin_memory()
461 |             batched_input_ids = batched_input_ids.pin_memory()
462 |             position_ids = position_ids.pin_memory()
463 | 
464 |         return BatchFeature(
465 |             {
466 |                 "input_ids": batched_input_ids,
467 |                 "image_tiles": batched_image_tiles,
468 |                 "attention_mask": attention_mask,
469 |                 "position_ids": position_ids,
470 |                 "grid_thw": batched_grid_thw,
471 |             }
472 |         )
473 | 
474 |     # Decode model outputs; Strips special tokens
475 |     def decode(self, tokens: List[int], task: str):
476 |         filtered_tokens = [
477 |             t
478 |             for t in tokens
479 |             if t not in self.special_token_mapping.values() and t != -100
480 |         ]  # Skip special tokens and loss ignore index
481 |         return self.ocr_tokenizer.decode(filtered_tokens, task=task)
482 | 
```
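
The patch tiling in `_process_and_tile` rounds the image up to a multiple of `patch_size * merge_size`, then reshapes it so each output row holds one flattened `patch_size × patch_size` patch; every `merge_size**2` consecutive patches correspond to a single `<IMAGE>` token, which is exactly the `num_tokens` check in `_process_image_input`. Below is a minimal sketch of that shape arithmetic, with dummy `patch_size`/`merge_size` values standing in for the model config and a random array standing in for the resized, normalized image:

```python
import math

import numpy as np
import torch

# Dummy values for illustration; the real patch_size / merge_size come from the model config
patch_size, merge_size = 14, 2
factor = patch_size * merge_size

# The processor would resize the original image with cv2.resize; here we just
# create a dummy normalized image already at the rounded-up size.
height, width = 100, 250
h_bar = math.ceil(height / factor) * factor  # 112
w_bar = math.ceil(width / factor) * factor   # 252
image = np.random.rand(h_bar, w_bar, 3).astype(np.float32)

# Channels-first tensor, then split into a grid of patch_size x patch_size patches
img = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0)  # (1, 3, h_bar, w_bar)
grid_h, grid_w = h_bar // patch_size, w_bar // patch_size
patches = img.reshape(
    1, 1, 3,
    grid_h // merge_size, merge_size, patch_size,
    grid_w // merge_size, merge_size, patch_size,
)
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
flat = patches.reshape(grid_h * grid_w, 3 * patch_size * patch_size)

# Every merge_size**2 consecutive patches collapse into one <IMAGE> token
num_tokens = flat.shape[0] / merge_size**2
assert num_tokens.is_integer()
print(flat.shape, int(num_tokens))  # torch.Size([144, 588]) 36
```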

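The padding-independent position IDs in `__call__` are worth spelling out: because the ids come from the cumulative sum of the attention mask rather than from raw sequence positions, the real tokens receive the same positions whether the batch is left- or right-padded. A small sketch of the same arithmetic, assuming a pad id of 0 purely for illustration:

```python
import torch

pad_token_id = 0  # illustrative pad id, not the processor's real one


def positions(input_ids: torch.Tensor) -> torch.Tensor:
    attention_mask = input_ids.ne(pad_token_id)
    position_ids = attention_mask.cumsum(dim=-1) - 1
    position_ids[position_ids < 0] = 0  # left padding produces -1 before the first real token
    return attention_mask.to(torch.long) * position_ids  # zero out right-pad positions too


left_padded = torch.tensor([[0, 0, 5, 6, 7]])
right_padded = torch.tensor([[5, 6, 7, 0, 0]])

print(positions(left_padded))   # tensor([[0, 0, 0, 1, 2]])
print(positions(right_padded))  # tensor([[0, 1, 2, 0, 0]])
# The real tokens (5, 6, 7) receive positions 0, 1, 2 in both cases.
```
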
--------------------------------------------------------------------------------
/surya/recognition/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | 
  3 | import re
  4 | from typing import List
  5 | 
  6 | import numpy as np
  7 | import torch
  8 | from PIL import Image
  9 | import torch.nn.functional as F
 10 | 
 11 | from surya.common.polygon import PolygonBox
 12 | from surya.common.surya.processor import NOMATH_TOKEN
 13 | from surya.common.predictor import BasePredictor
 14 | from surya.detection import DetectionPredictor
 15 | from surya.foundation import FoundationPredictor
 16 | 
 17 | from surya.input.processing import (
 18 |     convert_if_not_rgb,
 19 |     slice_polys_from_image,
 20 |     slice_bboxes_from_image,
 21 | )
 22 | from surya.recognition.postprocessing import fix_unbalanced_tags
 23 | from surya.recognition.util import (
 24 |     sort_text_lines,
 25 |     clean_close_polygons,
 26 |     unwrap_math,
 27 |     clean_math_tags,
 28 |     filter_blacklist_tags,
 29 |     words_from_chars
 30 | )
 31 | from surya.foundation.util import detect_repeat_token, prediction_to_polygon_batch
 32 | from surya.recognition.schema import TextLine, OCRResult, TextChar
 33 | from surya.common.surya.schema import TaskNames
 34 | from surya.settings import settings
 35 | from surya.logging import get_logger, configure_logging
 36 | 
 37 | configure_logging()
 38 | logger = get_logger()
 39 | 
 40 | class RecognitionPredictor(BasePredictor):
 41 |     batch_size = settings.RECOGNITION_BATCH_SIZE
 42 |     default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 128}
 43 | 
 44 |     # Override base init - Do not load model
 45 |     def __init__(self, foundation_predictor: FoundationPredictor):
 46 |         self.foundation_predictor = foundation_predictor
 47 |         self.processor = self.foundation_predictor.processor
 48 |         self.bbox_size = self.foundation_predictor.model.config.bbox_size
 49 |         self.tasks = self.foundation_predictor.tasks
 50 | 
 51 |     # Special handling for disable_tqdm so the setting is passed into the foundation predictor
 52 |     # Make sure the two are kept in sync
 53 |     @property
 54 |     def disable_tqdm(self) -> bool:
 55 |         return super().disable_tqdm
 56 | 
 57 |     @disable_tqdm.setter
 58 |     def disable_tqdm(self, value: bool) -> None:
 59 |         self._disable_tqdm = bool(value)
 60 |         self.foundation_predictor.disable_tqdm = bool(value)
 61 | 
 62 |     def detect_and_slice_bboxes(
 63 |         self,
 64 |         images: List[Image.Image],
 65 |         task_names: List[str],
 66 |         det_predictor: DetectionPredictor,
 67 |         detection_batch_size: int | None = None,
 68 |         highres_images: List[Image.Image] | None = None,
 69 |     ):
 70 |         det_predictions = det_predictor(images, batch_size=detection_batch_size)
 71 | 
 72 |         all_slices = []
 73 |         slice_map = []
 74 |         all_polygons = []
 75 |         all_task_names = []
 76 |         all_res_scales = []
 77 | 
 78 |         for idx, (det_pred, image, highres_image, task_name) in enumerate(
 79 |             zip(det_predictions, images, highres_images, task_names)
 80 |         ):
 81 |             polygons = [p.polygon for p in det_pred.bboxes]
 82 |             if highres_image:
 83 |                 width_scaler = highres_image.size[0] / image.size[0]
 84 |                 height_scaler = highres_image.size[1] / image.size[1]
 85 |                 scaled_polygons = [
 86 |                     [
 87 |                         [int(p[0] * width_scaler), int(p[1] * height_scaler)]
 88 |                         for p in polygon
 89 |                     ]
 90 |                     for polygon in polygons
 91 |                 ]
 92 |                 highres_image = self.processor.image_processor(highres_image)
 93 |                 slices = slice_polys_from_image(highres_image, scaled_polygons)
 94 |                 res_scales = [(width_scaler, height_scaler) for _ in range(len(slices))]
 95 |             else:
 96 |                 image = self.processor.image_processor(image)
 97 |                 slices = slice_polys_from_image(image, polygons)
 98 |                 res_scales = [(1, 1) for _ in range(len(slices))]
 99 | 
100 |             slice_map.append(len(slices))
101 |             all_slices.extend(slices)
102 |             all_polygons.extend(polygons)
103 |             all_task_names.extend([task_name] * len(slices))
104 |             all_res_scales.extend(res_scales)
105 | 
106 |         assert (
107 |             len(all_slices)
108 |             == sum(slice_map)
109 |             == len(all_polygons)
110 |             == len(all_task_names)
111 |             == len(all_res_scales)
112 |         )
113 | 
114 |         return {
115 |             "slices": all_slices,
116 |             "slice_map": slice_map,
117 |             "polygons": all_polygons,
118 |             "task_names": all_task_names,
119 |             "input_text": [None] * len(all_slices),
120 |             "res_scales": all_res_scales,
121 |         }
122 | 
123 |     def slice_bboxes(
124 |         self,
125 |         images: List[Image.Image],
126 |         task_names: List[str],
127 |         bboxes: List[List[List[int]]] | None = None,
128 |         polygons: List[List[List[List[int]]]] | None = None,
129 |         input_text: List[List[str | None]] | None = None,
130 |     ) -> dict:
131 |         assert bboxes is not None or polygons is not None
132 |         slice_map = []
133 |         all_slices = []
134 |         all_polygons = []
135 |         all_text = []
136 |         all_task_names = []
137 | 
138 |         for idx, image in enumerate(images):
139 |             image = self.processor.image_processor(image)
140 |             if polygons is not None:
141 |                 polys = polygons[idx]
142 |                 slices = slice_polys_from_image(image, polys)
143 |             else:
144 |                 slices = slice_bboxes_from_image(image, bboxes[idx])
145 |                 polys = [
146 |                     [
147 |                         [bbox[0], bbox[1]],
148 |                         [bbox[2], bbox[1]],
149 |                         [bbox[2], bbox[3]],
150 |                         [bbox[0], bbox[3]],
151 |                     ]
152 |                     for bbox in bboxes[idx]
153 |                 ]
154 |             slice_map.append(len(slices))
155 |             all_slices.extend(slices)
156 |             all_polygons.extend(polys)
157 |             all_task_names.extend([task_names[idx]] * len(slices))
158 | 
159 |             if input_text is None:
160 |                 all_text.extend([None] * len(slices))
161 |             else:
162 |                 all_text.extend(input_text[idx])
163 | 
164 |         assert (
165 |             len(all_slices)
166 |             == sum(slice_map)
167 |             == len(all_polygons)
168 |             == len(all_text)
169 |             == len(all_task_names)
170 |         ), (
171 |             f"Mismatch in lengths: {len(all_slices)}, {sum(slice_map)}, {len(all_polygons)}, {len(all_text)}, {len(all_task_names)}"
172 |         )
173 | 
174 |         return {
175 |             "slices": all_slices,
176 |             "slice_map": slice_map,
177 |             "polygons": all_polygons,
178 |             "input_text": all_text,
179 |             "task_names": all_task_names,
180 |             "res_scales": [(1, 1) for _ in range(len(all_slices))],
181 |         }
182 | 
183 |     def get_bboxes_text(
184 |         self,
185 |         flat: dict,
186 |         predicted_tokens: list,
187 |         scores: list,
188 |         predicted_polygons: list,
189 |         drop_repeated_text: bool = False,
190 |     ) -> list:
191 |         char_predictions = []
192 |         needs_boxes = [
193 |             self.tasks[task_name]["needs_bboxes"] for task_name in flat["task_names"]
194 |         ]
195 | 
196 |         for slice_idx, (
197 |             slice_image,
198 |             image_tokens,
199 |             image_polygons,
200 |             image_scores,
201 |             needs_box,
202 |         ) in enumerate(
203 |             zip(
204 |                 flat["slices"],
205 |                 predicted_tokens,
206 |                 predicted_polygons,
207 |                 scores,
208 |                 needs_boxes,
209 |             )
210 |         ):
211 |             blank_bbox = [[0, 0], [0, 1], [1, 1], [1, 0]]
212 |             if self.processor.no_output_token in image_tokens:
213 |                 char_predictions.append(None)
214 |                 continue
215 | 
216 |             # If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely
217 |             if drop_repeated_text and detect_repeat_token(image_tokens):
218 |                 char_predictions.append(
219 |                     [
220 |                         TextChar(
221 |                             text="",
222 |                             polygon=blank_bbox,
223 |                             confidence=0,
224 |                             bbox_valid=False,
225 |                         )
226 |                     ]
227 |                 )
228 |                 continue
229 | 
230 |             image_polygons = image_polygons[: len(image_tokens)].cpu().numpy().tolist()
231 | 
232 |             detokenize_sequences = []
233 |             detokenize_sequence = []
234 |             past_char_qwen_token = False
235 | 
236 |             def _add_detokenize_sequence(
237 |                 special_token: bool,
238 |                 past_special_token: bool,
239 |                 force: bool = False,
240 |             ):
241 |                 nonlocal detokenize_sequence, detokenize_sequences
242 | 
243 |                 if (
244 |                     special_token
245 |                     or past_special_token
246 |                     or force
247 |                 ) and detokenize_sequence:
248 |                     chars = [dt[0] for dt in detokenize_sequence]
249 |                     scores = [dt[1] for dt in detokenize_sequence]
250 |                     bboxes = [dt[2] for dt in detokenize_sequence]
251 | 
252 |                     if past_special_token:
253 |                         detokenize_sequences.append((chars, scores, None, "special"))
254 |                     else:
255 |                         detokenize_sequences.append((chars, scores, bboxes, "ocr"))
256 | 
257 |                     detokenize_sequence = []
258 | 
259 |             # Split up into sequences to detokenize separately
260 |             past_special_token = False
261 |             for bbox, char_id, score in zip(image_polygons, image_tokens, image_scores):
262 |                 if char_id in [
263 |                     self.processor.eos_token_id,
264 |                     self.processor.pad_token_id,
265 |                 ]:
266 |                     break
267 | 
268 |                 special_token = (
269 |                     char_id >= self.processor.ocr_tokenizer.ocr_tokenizer.SPECIAL_BASE
270 |                 )
271 |                 _add_detokenize_sequence(
272 |                     special_token, past_special_token
273 |                 )
274 |                 detokenize_sequence.append((char_id, score, bbox))
275 |                 past_special_token = special_token
276 | 
277 |             _add_detokenize_sequence(
278 |                 False, past_special_token, force=True
279 |             )
280 | 
281 |             img_chars = []
282 |             for sequence in detokenize_sequences:
283 |                 token_ids, seq_score, bboxes, token_type = sequence
284 |                 if token_type == "ocr":
285 |                     text = self.processor.ocr_tokenizer.decode(
286 |                         token_ids, task=TaskNames.ocr_with_boxes
287 |                     )
288 |                     bboxes = clean_close_polygons(
289 |                         bboxes
290 |                     )  # clean out bboxes that are close, like what happens with multiple utf-16 tokens per char
291 |                     bbox_idx = 0
292 |                     for text_idx, text_line in enumerate(text):
293 |                         img_chars.append(
294 |                             TextChar(
295 |                                 text=text_line,
296 |                                 polygon=bboxes[bbox_idx],
297 |                                 confidence=seq_score[bbox_idx],
298 |                                 bbox_valid=True,
299 |                             )
300 |                         )
301 | 
302 |                         # Ensure we don't exceed the bbox count
303 |                         # Use the last bbox for the rest of the text
304 |                         if bbox_idx < len(bboxes) - 1:
305 |                             bbox_idx += 1
306 |                 elif token_type == "special":
307 |                     text = self.processor.ocr_tokenizer.decode(
308 |                         token_ids, task="ocr_without_boxes"
309 |                     )
310 |                     if text in [NOMATH_TOKEN] or re.match(r"<SCRIPT-\w+>", text):
311 |                         continue
312 | 
313 |                     img_chars.append(
314 |                         TextChar(
315 |                             text=text,
316 |                             polygon=blank_bbox,
317 |                             confidence=seq_score[0],
318 |                             bbox_valid=False,
319 |                         )
320 |                     )
321 |                 else:
322 |                     text = self.processor.ocr_tokenizer.decode(
323 |                         token_ids, task=TaskNames.block_without_boxes
324 |                     )
325 |                     img_chars.append(
326 |                         TextChar(
327 |                             text=text,
328 |                             polygon=blank_bbox,
329 |                             confidence=seq_score[0],
330 |                             bbox_valid=False,
331 |                         )
332 |                     )
333 | 
334 |             char_predictions.append(img_chars)
335 | 
336 |         return char_predictions
337 | 
338 |     def __call__(
339 |         self,
340 |         images: List[Image.Image],
341 |         task_names: List[str] | None = None,
342 |         det_predictor: DetectionPredictor | None = None,
343 |         detection_batch_size: int | None = None,
344 |         recognition_batch_size: int | None = None,
345 |         highres_images: List[Image.Image] | None = None,
346 |         bboxes: List[List[List[int]]] | None = None,
347 |         polygons: List[List[List[List[int]]]] | None = None,
348 |         input_text: List[List[str | None]] | None = None,
349 |         sort_lines: bool = False,
350 |         math_mode: bool = True,
351 |         return_words: bool = False,
352 |         drop_repeated_text: bool = False,
353 |         max_sliding_window: int | None = None,
354 |         max_tokens: int | None = None,
355 |         filter_tag_list: List[str] | None = None
356 |     ) -> List[OCRResult]:
357 |         if task_names is None:
358 |             task_names = [TaskNames.ocr_with_boxes] * len(images)
359 |         if recognition_batch_size is None:
360 |             recognition_batch_size = self.get_batch_size()
361 | 
362 |         assert len(images) == len(task_names), (
363 |             "You need to pass in one task name for each image"
364 |         )
365 | 
366 |         images = convert_if_not_rgb(images)
367 |         if highres_images is not None:
368 |             assert len(images) == len(highres_images), (
369 |                 "You need to pass in one highres image for each image"
370 |             )
371 | 
372 |         highres_images = (
373 |             convert_if_not_rgb(highres_images)
374 |             if highres_images is not None
375 |             else [None] * len(images)
376 |         )
377 | 
378 |         if bboxes is None and polygons is None:
379 |             assert det_predictor is not None, (
380 |                 "You need to pass in a detection predictor if you don't provide bboxes or polygons"
381 |             )
382 | 
383 |             # Detect then slice
384 |             flat = self.detect_and_slice_bboxes(
385 |                 images,
386 |                 task_names,
387 |                 det_predictor,
388 |                 detection_batch_size=detection_batch_size,
389 |                 highres_images=highres_images,
390 |             )
391 |         else:
392 |             if bboxes is not None:
393 |                 assert len(images) == len(bboxes), (
394 |                     "You need to pass in one list of bboxes for each image"
395 |                 )
396 |             if polygons is not None:
397 |                 assert len(images) == len(polygons), (
398 |                     "You need to pass in one list of polygons for each image"
399 |                 )
400 | 
401 |             flat = self.slice_bboxes(
402 |                 images,
403 |                 bboxes=bboxes,
404 |                 polygons=polygons,
405 |                 input_text=input_text,
406 |                 task_names=task_names,
407 |             )
408 | 
409 |         # No images passed, or no boxes passed, or no text detected in the images
410 |         if len(flat["slices"]) == 0:
411 |             return [
412 |                 OCRResult(
413 |                     text_lines=[], image_bbox=[0, 0, im.size[0], im.size[1]]
414 |                 )
415 |                 for im in images
416 |             ]
417 | 
418 |         # Sort by image size (area). The negative key puts larger images first, which fits in better with continuous batching
419 |         sorted_pairs = sorted(
420 |             enumerate(flat["slices"]),
421 |             key=lambda x: -(x[1].shape[0] * x[1].shape[1])  # height * width
422 |         )
423 |         indices, sorted_slices = zip(*sorted_pairs)
424 | 
425 |         # Reorder slices, input_text, and task_names based on the new order
426 |         flat["slices"] = list(sorted_slices)
427 |         flat["input_text"] = [flat["input_text"][i] for i in indices]
428 |         flat["task_names"] = [flat["task_names"][i] for i in indices]
429 | 
430 |         # Make predictions
431 |         predicted_tokens, batch_bboxes, scores, _ = self.foundation_predictor.prediction_loop(
432 |             images=flat["slices"],
433 |             input_texts=flat["input_text"],
434 |             task_names=flat["task_names"],
435 |             batch_size=recognition_batch_size,
436 |             math_mode=math_mode,
437 |             drop_repeated_tokens=True,
438 |             max_lookahead_tokens=self.foundation_predictor.model.config.multi_output_distance,
439 |             max_sliding_window=max_sliding_window,
440 |             max_tokens=max_tokens,
441 |             tqdm_desc="Recognizing Text"
442 |         )
443 | 
444 |         # Get text and bboxes in structured form
445 |         bbox_size = self.bbox_size
446 |         image_sizes = [img.shape for img in flat["slices"]]
447 |         predicted_polygons = prediction_to_polygon_batch(
448 |             batch_bboxes, image_sizes, bbox_size, bbox_size // 2
449 |         )
450 |         char_predictions = self.get_bboxes_text(
451 |             flat,
452 |             predicted_tokens,
453 |             scores,
454 |             predicted_polygons,
455 |             drop_repeated_text=drop_repeated_text,
456 |         )
457 | 
458 |         char_predictions = sorted(zip(indices, char_predictions), key=lambda x: x[0])
459 |         char_predictions = [pred for _, pred in char_predictions]
460 | 
461 |         predictions_by_image = []
462 |         slice_start = 0
463 |         for idx, image in enumerate(images):
464 |             slice_end = slice_start + flat["slice_map"][idx]
465 |             image_lines = char_predictions[slice_start:slice_end]
466 |             polygons = flat["polygons"][slice_start:slice_end]
467 |             res_scales = flat["res_scales"][slice_start:slice_end]
468 |             slice_start = slice_end
469 | 
470 |             lines = []
471 |             for text_line, polygon, res_scale in zip(image_lines, polygons, res_scales):
472 |                 # Special case when the original input text was judged good (the model produced no output)
473 |                 if not text_line:
474 |                     lines.append(
475 |                         TextLine(
476 |                             text="",
477 |                             polygon=polygon,
478 |                             chars=[],
479 |                             confidence=1,
480 |                             original_text_good=True,
481 |                         )
482 |                     )
483 |                 else:
484 |                     confidence = (
485 |                         float(np.mean([char.confidence for char in text_line]))
486 |                         if len(text_line) > 0
487 |                         else 0
488 |                     )
489 |                     poly_box = PolygonBox(polygon=polygon)
490 |                     for char in text_line:
491 |                         char.rescale(
492 |                             res_scale, (1, 1)
493 |                         )  # Rescale from highres if needed
494 |                         char.shift(
495 |                             poly_box.bbox[0], poly_box.bbox[1]
496 |                         )  # Ensure character boxes match line boxes (relative to page)
497 |                         char.clamp(poly_box.bbox)
498 | 
499 |                     text_line = fix_unbalanced_tags(
500 |                         text_line, self.processor.ocr_tokenizer.special_tokens
501 |                     )
502 |                     text_line = filter_blacklist_tags(text_line, filter_tag_list)
503 |                     text = "".join([char.text for char in text_line])
504 |                     text = unwrap_math(text)
505 |                     text = clean_math_tags(text)
506 |                     lines.append(
507 |                         TextLine(
508 |                             text=text,
509 |                             polygon=polygon,
510 |                             chars=text_line,
511 |                             confidence=confidence,
512 |                             words=words_from_chars(text_line, poly_box)
513 |                             if return_words
514 |                             else [],
515 |                         )
516 |                     )
517 | 
518 |             if sort_lines:
519 |                 lines = sort_text_lines(lines)
520 |             predictions_by_image.append(
521 |                 OCRResult(
522 |                     text_lines=lines, image_bbox=[0, 0, image.size[0], image.size[1]]
523 |                 )
524 |             )
525 | 
526 |         return predictions_by_image
527 | 
```
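
Character boxes are predicted relative to the cropped line slice (and at highres scale when a highres image was supplied), so `__call__` rescales, shifts, and clamps them back into page coordinates before building each `TextLine`. A minimal sketch of that coordinate transform on plain lists; the helper name `char_bbox_to_page` and all values are illustrative, and the real code goes through `TextChar.rescale`, `shift`, and `clamp`:

```python
def clamp(v: float, lo: float, hi: float) -> float:
    return max(lo, min(v, hi))


def char_bbox_to_page(char_bbox, res_scale, line_bbox):
    """char_bbox: [x1, y1, x2, y2] in (possibly highres) slice coordinates.
    res_scale: (width_scaler, height_scaler) used when slicing from the highres image.
    line_bbox: [x1, y1, x2, y2] of the text line on the page."""
    wx, hy = res_scale
    # Undo the highres scaling, then shift by the line's top-left corner
    x1 = char_bbox[0] / wx + line_bbox[0]
    y1 = char_bbox[1] / hy + line_bbox[1]
    x2 = char_bbox[2] / wx + line_bbox[0]
    y2 = char_bbox[3] / hy + line_bbox[1]
    # Clamp so character boxes never spill outside their line
    return [
        clamp(x1, line_bbox[0], line_bbox[2]),
        clamp(y1, line_bbox[1], line_bbox[3]),
        clamp(x2, line_bbox[0], line_bbox[2]),
        clamp(y2, line_bbox[1], line_bbox[3]),
    ]


print(char_bbox_to_page([10, 4, 30, 20], (2.0, 2.0), [100, 50, 400, 80]))
# [105.0, 52.0, 115.0, 60.0]
```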

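Before the prediction loop, slices are sorted largest-area-first so that similarly sized crops end up batched together, and the original order is restored afterwards by carrying the pre-sort indices along. The same pattern in isolation, on dummy data standing in for the image slices and model outputs:

```python
# Sort items by a size key, process them, then restore the original order.
# Dummy "slices" represented by (height, width) tuples instead of numpy arrays.
slices = [(10, 50), (40, 200), (20, 60)]

sorted_pairs = sorted(enumerate(slices), key=lambda x: -(x[1][0] * x[1][1]))
indices, sorted_slices = zip(*sorted_pairs)  # indices: (1, 2, 0)

# Pretend "predictions" are just labels of what was processed, in sorted order
predictions = [f"pred_for_{h}x{w}" for h, w in sorted_slices]

# Re-attach the original indices and sort by them to undo the shuffle
restored = [p for _, p in sorted(zip(indices, predictions), key=lambda x: x[0])]
print(restored)
# ['pred_for_10x50', 'pred_for_40x200', 'pred_for_20x60'] -> back in input order
```
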
--------------------------------------------------------------------------------
/surya/common/surya/decoder/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Callable, List, Optional, Tuple, Union
  2 | 
  3 | import torch
  4 | from torch import nn
  5 | 
  6 | from transformers.activations import ACT2FN
  7 | from transformers.cache_utils import (
  8 |     Cache,
  9 | )
 10 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 11 | from transformers.modeling_outputs import (
 12 |     BaseModelOutputWithPast,
 13 | )
 14 | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 15 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 16 | from transformers.processing_utils import Unpack
 17 | from transformers.utils import (
 18 |     logging,
 19 | )
 20 | 
 21 | from surya.common.pretrained import SuryaPreTrainedModel
 22 | from surya.common.surya.decoder.config import SuryaDecoderConfig
 23 | 
 24 | 
 25 | logger = logging.get_logger(__name__)
 26 | 
 27 | 
 28 | class Qwen2MLP(nn.Module):
 29 |     def __init__(self, config):
 30 |         super().__init__()
 31 |         self.config = config
 32 |         self.hidden_size = config.hidden_size
 33 |         self.intermediate_size = config.intermediate_size
 34 |         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
 35 |         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
 36 |         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 37 |         self.act_fn = ACT2FN[config.hidden_act]
 38 | 
 39 |     def forward(self, x):
 40 |         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 41 |         return down_proj
 42 | 
 43 | 
 44 | def rotate_half(x):
 45 |     """Rotates half the hidden dims of the input."""
 46 |     x1 = x[..., : x.shape[-1] // 2]
 47 |     x2 = x[..., x.shape[-1] // 2 :]
 48 |     return torch.cat((-x2, x1), dim=-1)
 49 | 
 50 | 
 51 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 52 |     """Applies Rotary Position Embedding to the query and key tensors.
 53 | 
 54 |     Args:
 55 |         q (`torch.Tensor`): The query tensor.
 56 |         k (`torch.Tensor`): The key tensor.
 57 |         cos (`torch.Tensor`): The cosine part of the rotary embedding.
 58 |         sin (`torch.Tensor`): The sine part of the rotary embedding.
 59 |         position_ids (`torch.Tensor`, *optional*):
 60 |             Deprecated and unused.
 61 |         unsqueeze_dim (`int`, *optional*, defaults to 1):
 62 |             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
 63 |             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
 64 |             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
 65 |             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
 66 |             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
 67 |             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
 68 |     Returns:
 69 |         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
 70 |     """
 71 |     cos = cos.unsqueeze(unsqueeze_dim)
 72 |     sin = sin.unsqueeze(unsqueeze_dim)
 73 |     q_embed = (q * cos) + (rotate_half(q) * sin)
 74 |     k_embed = (k * cos) + (rotate_half(k) * sin)
 75 |     return q_embed, k_embed
 76 | 
 77 | 
 78 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 79 |     """
 80 |     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 81 |     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
 82 |     """
 83 |     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
 84 |     if n_rep == 1:
 85 |         return hidden_states
 86 |     hidden_states = hidden_states[:, :, None, :, :].expand(
 87 |         batch, num_key_value_heads, n_rep, slen, head_dim
 88 |     )
 89 |     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 90 | 
 91 | 
 92 | def eager_attention_forward(
 93 |     module: nn.Module,
 94 |     query: torch.Tensor,
 95 |     key: torch.Tensor,
 96 |     value: torch.Tensor,
 97 |     attention_mask: Optional[torch.Tensor],
 98 |     scaling: float,
 99 |     dropout: float = 0.0,
100 |     **kwargs,
101 | ):
102 |     key_states = repeat_kv(key, module.num_key_value_groups)
103 |     value_states = repeat_kv(value, module.num_key_value_groups)
104 | 
105 |     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
106 |     if attention_mask is not None:
107 |         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
108 |         attn_weights = attn_weights + causal_mask
109 | 
110 |     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
111 |         query.dtype
112 |     )
113 |     attn_weights = nn.functional.dropout(
114 |         attn_weights, p=dropout, training=module.training
115 |     )
116 |     attn_output = torch.matmul(attn_weights, value_states)
117 |     attn_output = attn_output.transpose(1, 2).contiguous()
118 | 
119 |     return attn_output, attn_weights
120 | 
121 | 
122 | class Qwen2Attention(nn.Module):
123 |     """Multi-headed attention from 'Attention Is All You Need' paper"""
124 | 
125 |     def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
126 |         super().__init__()
127 |         self.config = config
128 |         self.layer_idx = layer_idx
129 |         self.head_dim = getattr(
130 |             config, "head_dim", config.hidden_size // config.num_attention_heads
131 |         )
132 |         self.num_key_value_groups = (
133 |             config.num_attention_heads // config.num_key_value_heads
134 |         )
135 |         self.scaling = self.head_dim**-0.5
136 |         self.attention_dropout = config.attention_dropout
137 |         self.is_causal = True
138 |         self.q_proj = nn.Linear(
139 |             config.hidden_size, config.num_attention_heads * self.head_dim, bias=True
140 |         )
141 |         self.k_proj = nn.Linear(
142 |             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True
143 |         )
144 |         self.v_proj = nn.Linear(
145 |             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True
146 |         )
147 |         self.o_proj = nn.Linear(
148 |             config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
149 |         )
150 | 
151 |     def forward(
152 |         self,
153 |         hidden_states: torch.Tensor,
154 |         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
155 |         attention_mask: Optional[torch.Tensor],
156 |         past_key_value: Optional[Cache] = None,
157 |         cache_position: Optional[torch.LongTensor] = None,
158 |         cache_idxs: Optional[List[int]] = None,
159 |         num_valid_tokens: Optional[List[int]] = None,
160 |         text_lengths: Optional[List[int]] = None,
161 |         prefill: bool = False,
162 |         **kwargs: Unpack[FlashAttentionKwargs],
163 |     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
164 |         input_shape = hidden_states.shape[:-1]
165 |         hidden_shape = (*input_shape, -1, self.head_dim)
166 | 
167 |         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
168 |         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
169 |         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
170 | 
171 |         cos, sin = position_embeddings
172 |         query_states, key_states = apply_rotary_pos_emb(
173 |             query_states, key_states, cos, sin
174 |         )
175 | 
176 |         if past_key_value is not None:
177 |             # sin and cos are specific to RoPE models; cache_position needed for the static cache
178 |             # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism
179 |             cache_kwargs = {
180 |                 "sin": sin,
181 |                 "cos": cos,
182 |                 "cache_position": cache_position,
183 |                 "cache_idxs": cache_idxs,
184 |                 "num_valid_tokens": num_valid_tokens,
185 |                 "prefill": prefill,
186 |                 "text_lengths": text_lengths,
187 |             }
188 |             key_states, value_states = past_key_value.update(
189 |                 key_states, value_states, self.layer_idx, cache_kwargs
190 |             )
191 | 
192 |         attention_interface: Callable = eager_attention_forward
193 |         if self.config._attn_implementation != "eager":
194 |             if self.config._attn_implementation == "sdpa" and kwargs.get(
195 |                 "output_attentions", False
196 |             ):
197 |                 logger.warning_once(
198 |                     "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
199 |                     'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
200 |                 )
201 |             elif self.config._attn_implementation == "flash_attention_2":
202 |                 # Needed for CPU -> GPU
203 |                 from surya.common.surya.flash_attn_utils import (
204 |                     flash_attn_decode,
205 |                     flash_attn_prefill,
206 |                 )
207 | 
208 |                 if prefill:
209 |                     attention_interface = flash_attn_prefill
210 |                 else:
211 |                     attention_interface = flash_attn_decode
212 |             else:
213 |                 attention_interface = ALL_ATTENTION_FUNCTIONS[
214 |                     self.config._attn_implementation
215 |                 ]
216 | 
217 |         """
218 |         IMPORTANT:
219 |         We sometimes use a custom sliding window impl. during training
220 | 
221 |         We force this to None to ensure that the HF attention integrations do not
222 |         perform any special handling - FA2 in particular will ignore the 4D mask, and use this instead
223 |         to infer the final mask
224 | 
225 |         SDPA ignores this completely, and is fully dependent on the 4D mask - (https://github.com/huggingface/transformers/blob/b9faf2f93085e3cf2c65184a69d1d9e502f95786/src/transformers/integrations/sdpa_attention.py#L23)
226 |         """
227 |         sliding_window = None
228 | 
229 |         attn_output, attn_weights = attention_interface(
230 |             self,
231 |             query_states,
232 |             key_states,
233 |             value_states,
234 |             attention_mask,
235 |             dropout=0.0 if not self.training else self.attention_dropout,
236 |             scaling=self.scaling,
237 |             sliding_window=sliding_window,  # main diff with Llama
238 |             **kwargs,
239 |         )
240 | 
241 |         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
242 |         attn_output = self.o_proj(attn_output)
243 |         return attn_output, attn_weights
244 | 
245 | 
246 | class Qwen2RMSNorm(nn.Module):
247 |     def __init__(self, hidden_size, eps=1e-6):
248 |         """
249 |         Qwen2RMSNorm is equivalent to T5LayerNorm
250 |         """
251 |         super().__init__()
252 |         self.weight = nn.Parameter(torch.ones(hidden_size))
253 |         self.variance_epsilon = eps
254 | 
255 |     def forward(self, hidden_states):
256 |         input_dtype = hidden_states.dtype
257 |         hidden_states = hidden_states.to(torch.float32)
258 |         variance = hidden_states.pow(2).mean(-1, keepdim=True)
259 |         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
260 |         return self.weight * hidden_states.to(input_dtype)
261 | 
262 |     def extra_repr(self):
263 |         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
264 | 
265 | 
266 | class Qwen2DecoderLayer(nn.Module):
267 |     def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
268 |         super().__init__()
269 |         self.hidden_size = config.hidden_size
270 |         self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
271 |         self.mlp = Qwen2MLP(config)
272 |         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
273 |         self.post_attention_layernorm = Qwen2RMSNorm(
274 |             config.hidden_size, eps=config.rms_norm_eps
275 |         )
276 | 
277 |     def forward(
278 |         self,
279 |         hidden_states: torch.Tensor,
280 |         attention_mask: Optional[torch.Tensor] = None,
281 |         position_ids: Optional[torch.LongTensor] = None,
282 |         past_key_value: Optional[Cache] = None,
283 |         output_attentions: Optional[bool] = False,
284 |         use_cache: Optional[bool] = False,
285 |         cache_position: Optional[torch.LongTensor] = None,
286 |         cache_idxs: Optional[List[int]] = None,
287 |         num_valid_tokens: Optional[List[int]] = None,
288 |         text_lengths: Optional[List[int]] = None,
289 |         prefill: bool = False,
290 |         position_embeddings: Optional[
291 |             Tuple[torch.Tensor, torch.Tensor]
292 |         ] = None,  # necessary, but kept here for BC
293 |         **kwargs: Unpack[FlashAttentionKwargs],
294 |     ) -> Tuple[
295 |         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
296 |     ]:
297 |         residual = hidden_states
298 | 
299 |         hidden_states = self.input_layernorm(hidden_states)
300 | 
301 |         # Self Attention
302 |         hidden_states, self_attn_weights = self.self_attn(
303 |             hidden_states=hidden_states,
304 |             attention_mask=attention_mask,
305 |             position_ids=position_ids,
306 |             past_key_value=past_key_value,
307 |             output_attentions=output_attentions,
308 |             use_cache=use_cache,
309 |             cache_position=cache_position,
310 |             position_embeddings=position_embeddings,
311 |             cache_idxs=cache_idxs,
312 |             num_valid_tokens=num_valid_tokens,
313 |             text_lengths=text_lengths,
314 |             prefill=prefill,
315 |             **kwargs,
316 |         )
317 |         hidden_states = residual + hidden_states
318 | 
319 |         # Fully Connected
320 |         residual = hidden_states
321 |         hidden_states = self.post_attention_layernorm(hidden_states)
322 |         hidden_states = self.mlp(hidden_states)
323 |         hidden_states = residual + hidden_states
324 | 
325 |         outputs = (hidden_states,)
326 |         if output_attentions:
327 |             outputs += (self_attn_weights,)
328 | 
329 |         return outputs
330 | 
331 | 
332 | class Qwen2RotaryEmbedding(nn.Module):
333 |     def __init__(self, config: SuryaDecoderConfig, device=None):
334 |         super().__init__()
335 |         # BC: "rope_type" was originally "type"
336 |         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
337 |             self.rope_type = config.rope_scaling.get(
338 |                 "rope_type", config.rope_scaling.get("type")
339 |             )
340 |         else:
341 |             self.rope_type = "default"
342 |         self.max_seq_len_cached = config.max_position_embeddings
343 |         self.original_max_seq_len = config.max_position_embeddings
344 | 
345 |         self.config = config
346 |         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
347 | 
348 |         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
349 |         self.register_buffer("inv_freq", inv_freq, persistent=False)
350 |         self.original_inv_freq = self.inv_freq
351 | 
352 |     def _dynamic_frequency_update(self, position_ids, device):
353 |         """
354 |         Dynamic RoPE layers should recompute `inv_freq` in two situations:
355 |         1 - the sequence grows beyond the cached length (allow scaling)
356 |         2 - the sequence shrinks back into the original range (reset to avoid losing precision on short sequences)
357 |         """
358 |         seq_len = torch.max(position_ids) + 1
359 |         if seq_len > self.max_seq_len_cached:  # growth
360 |             inv_freq, self.attention_scaling = self.rope_init_fn(
361 |                 self.config, device, seq_len=seq_len
362 |             )
363 |             self.register_buffer(
364 |                 "inv_freq", inv_freq, persistent=False
365 |             )  # TODO joao: may break with compilation
366 |             self.max_seq_len_cached = seq_len
367 | 
368 |         if (
369 |             seq_len < self.original_max_seq_len
370 |             and self.max_seq_len_cached > self.original_max_seq_len
371 |         ):  # reset
372 |             # This .to() is needed if the model has been moved to a device after being initialized (because
373 |             # the buffer is automatically moved, but not the original copy)
374 |             self.original_inv_freq = self.original_inv_freq.to(device)
375 |             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
376 |             self.max_seq_len_cached = self.original_max_seq_len
377 | 
378 |     @torch.no_grad()
379 |     def forward(self, x, position_ids):
380 |         if "dynamic" in self.rope_type:
381 |             self._dynamic_frequency_update(position_ids, device=x.device)
382 | 
383 |         # Core RoPE block
384 |         inv_freq_expanded = (
385 |             self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
386 |         )
387 |         position_ids_expanded = position_ids[:, None, :].float()
388 |         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
389 |         device_type = x.device.type
390 |         device_type = (
391 |             device_type
392 |             if isinstance(device_type, str) and device_type != "mps"
393 |             else "cpu"
394 |         )
395 |         with torch.autocast(device_type=device_type, enabled=False):
396 |             freqs = (
397 |                 inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
398 |             ).transpose(1, 2)
399 |             emb = torch.cat((freqs, freqs), dim=-1)
400 |             cos = emb.cos()
401 |             sin = emb.sin()
402 | 
403 |         # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
404 |         cos = cos * self.attention_scaling
405 |         sin = sin * self.attention_scaling
406 | 
407 |         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
408 | 
409 | 
410 | class Qwen2PreTrainedModel(SuryaPreTrainedModel):
411 |     config_class = SuryaDecoderConfig
412 |     base_model_prefix = "model"
413 |     supports_gradient_checkpointing = True
414 |     _no_split_modules = ["Qwen2DecoderLayer"]
415 |     _skip_keys_device_placement = ["past_key_values"]
416 |     _supports_flash_attn_2 = True
417 |     _supports_sdpa = True
418 |     _supports_flex_attn = True
419 |     _supports_cache_class = True
420 |     _supports_quantized_cache = True
421 |     _supports_static_cache = True
422 |     _supports_attention_backend = True
423 | 
424 |     def _init_weights(self, module):
425 |         std = self.config.initializer_range
426 |         if isinstance(module, nn.Linear):
427 |             module.weight.data.normal_(mean=0.0, std=std)
428 |             if module.bias is not None:
429 |                 module.bias.data.zero_()
430 |         elif isinstance(module, nn.Embedding):
431 |             module.weight.data.normal_(mean=0.0, std=std)
432 |             if module.padding_idx is not None:
433 |                 module.weight.data[module.padding_idx].zero_()
434 | 
435 | 
436 | class SuryaDecoderModel(Qwen2PreTrainedModel):
437 |     """
438 |     Transformer decoder consisting of *config.num_hidden_layers* layers, each a [`Qwen2DecoderLayer`].
439 |     This variant has had the embedding layer removed entirely; it only accepts `inputs_embeds` as input.
440 | 
441 |     Args:
442 |         config: SuryaDecoderConfig
443 |     """
444 | 
445 |     def __init__(self, config: SuryaDecoderConfig):
446 |         super().__init__(config)
447 |         self.padding_idx = config.pad_token_id
448 |         self.vocab_size = config.vocab_size
449 | 
450 |         self.layers = nn.ModuleList(
451 |             [
452 |                 Qwen2DecoderLayer(config, layer_idx)
453 |                 for layer_idx in range(config.num_hidden_layers)
454 |             ]
455 |         )
456 |         self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
457 |         self.rotary_emb = Qwen2RotaryEmbedding(config=config)
458 |         self.gradient_checkpointing = False
459 | 
460 |         # Initialize weights and apply final processing
461 |         self.post_init()
462 | 
463 |     def forward(
464 |         self,
465 |         attention_mask: Optional[torch.Tensor] = None,
466 |         position_ids: Optional[torch.LongTensor] = None,
467 |         past_key_values: Optional[Cache] = None,
468 |         inputs_embeds: Optional[torch.FloatTensor] = None,
469 |         use_cache: Optional[bool] = None,
470 |         output_attentions: Optional[bool] = None,
471 |         output_hidden_states: Optional[bool] = None,
472 |         return_dict: Optional[bool] = None,
473 |         cache_position: Optional[torch.LongTensor] = None,
474 |         cache_idxs: Optional[List[int]] = None,
475 |         num_valid_tokens: Optional[List[int]] = None,
476 |         text_lengths: Optional[List[int]] = None,
477 |         prefill: bool = False,
478 |         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
479 |     ) -> Union[Tuple, BaseModelOutputWithPast]:
480 |         use_cache = use_cache if use_cache is not None else self.config.use_cache
481 |         return_dict = (
482 |             return_dict if return_dict is not None else self.config.use_return_dict
483 |         )
484 | 
485 |         if inputs_embeds is None:
486 |             raise ValueError("You must specify inputs_embeds")
487 | 
488 |         if cache_position is None:
489 |             raise ValueError("You must specify cache_position")
490 | 
491 |         if position_ids is None:
492 |             raise ValueError("You must specify position_ids")
493 | 
494 |         hidden_states = inputs_embeds
495 |         causal_mask = (
496 |             attention_mask  # We make the 4D mask in the combined model when needed
497 |         )
498 | 
499 |         # create position embeddings to be shared across the decoder layers
500 |         position_embeddings = self.rotary_emb(hidden_states, position_ids)
501 | 
502 |         # decoder layers
503 |         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
504 |             layer_outputs = decoder_layer(
505 |                 hidden_states,
506 |                 attention_mask=causal_mask,
507 |                 position_ids=position_ids,
508 |                 past_key_value=past_key_values,
509 |                 output_attentions=output_attentions,
510 |                 use_cache=use_cache,
511 |                 cache_position=cache_position,
512 |                 position_embeddings=position_embeddings,
513 |                 cache_idxs=cache_idxs,
514 |                 num_valid_tokens=num_valid_tokens,
515 |                 prefill=prefill,
516 |                 text_lengths=text_lengths,
517 |                 **flash_attn_kwargs,
518 |             )
519 | 
520 |             hidden_states = layer_outputs[0]
521 | 
522 |         hidden_states = self.norm(hidden_states)
523 | 
524 |         output = BaseModelOutputWithPast(
525 |             last_hidden_state=hidden_states,
526 |             past_key_values=past_key_values if use_cache else None,
527 |         )
528 |         return output if return_dict else output.to_tuple()
529 | 
```
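
For readers tracing the math in `Qwen2RotaryEmbedding.forward`, the sketch below reproduces the core cos/sin computation outside the class (default rope type, so `attention_scaling` is 1.0) and shows how the resulting tables are typically applied to query states via the standard `rotate_half` trick. The `head_dim`, `base`, and `rotate_half` names are illustrative stand-ins, not part of the surya API.

```python
import torch

# Illustrative sizes; the real values come from SuryaDecoderConfig.
head_dim, base = 64, 10000.0
batch, num_heads, seq_len = 1, 2, 8

# inv_freq[i] = 1 / base^(2i / head_dim), one frequency per channel pair,
# mirroring what rope_init_fn produces for the "default" rope type.
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

position_ids = torch.arange(seq_len)[None, :].float()        # (batch, seq_len)
freqs = position_ids[:, :, None] * inv_freq[None, None, :]   # (batch, seq_len, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                      # (batch, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()                              # what forward() returns

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotates the second half of the channels into the first half's position.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Applying the tables to a toy query tensor (batch, num_heads, seq_len, head_dim);
# unsqueeze(1) broadcasts cos/sin across the head dimension.
q = torch.randn(batch, num_heads, seq_len, head_dim)
q_rot = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
print(q_rot.shape)  # torch.Size([1, 2, 8, 64])
```

Advanced rope types (e.g. yarn or dynamic scaling) additionally rescale cos/sin by `attention_scaling`, which the class applies before returning.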
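
Because `SuryaDecoderModel` has no embedding layer, callers must supply `inputs_embeds` together with `position_ids` and `cache_position` (the explicit `ValueError` checks in `forward` enforce this). The shape-level sketch below illustrates those inputs; the model construction is left commented out because it needs a fully populated `SuryaDecoderConfig`, and the tensor sizes shown are toy values.

```python
import torch

batch, seq_len, hidden_size = 2, 16, 1024   # illustrative sizes only

# Embeddings come from the caller (in surya, they are produced upstream in the
# combined model); this decoder variant never sees token ids.
inputs_embeds = torch.randn(batch, seq_len, hidden_size)
position_ids = torch.arange(seq_len)[None, :].expand(batch, -1)   # (batch, seq_len)
cache_position = torch.arange(seq_len)                            # (seq_len,)

# decoder = SuryaDecoderModel(config)        # config: a populated SuryaDecoderConfig
# outputs = decoder(
#     inputs_embeds=inputs_embeds,           # required: raises ValueError if None
#     position_ids=position_ids,             # required
#     cache_position=cache_position,         # required
#     use_cache=False,
# )
# outputs.last_hidden_state has shape (batch, seq_len, hidden_size)
```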