This is page 4 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/surya/ocr_error/tokenizer.py:
--------------------------------------------------------------------------------

```python
  1 | import collections
  2 | import os
  3 | import json
  4 | import unicodedata
  5 | from typing import List, Optional, Tuple
  6 | 
  7 | from tokenizers import normalizers
  8 | 
  9 | from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
 10 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 11 | 
 12 | from surya.common.s3 import S3DownloaderMixin
 13 | 
 14 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
 15 | 
 16 | # Copied from transformers.models.bert.tokenization_bert.load_vocab
 17 | def load_vocab(vocab_file):
 18 |     """Loads a vocabulary file into a dictionary."""
 19 |     vocab = collections.OrderedDict()
 20 |     with open(vocab_file, "r", encoding="utf-8") as reader:
 21 |         tokens = reader.readlines()
 22 |     for index, token in enumerate(tokens):
 23 |         token = token.rstrip("\n")
 24 |         vocab[token] = index
 25 |     return vocab
 26 | 
 27 | 
 28 | # Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
 29 | def whitespace_tokenize(text):
 30 |     """Runs basic whitespace cleaning and splitting on a piece of text."""
 31 |     text = text.strip()
 32 |     if not text:
 33 |         return []
 34 |     tokens = text.split()
 35 |     return tokens
 36 | 
 37 | 
 38 | class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
 39 |     r"""
 40 |     Construct a DistilBERT tokenizer. Based on WordPiece.
 41 | 
 42 |     This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
 43 |     this superclass for more information regarding those methods.
 44 | 
 45 |     Args:
 46 |         vocab_file (`str`):
 47 |             File containing the vocabulary.
 48 |         do_lower_case (`bool`, *optional*, defaults to `True`):
 49 |             Whether or not to lowercase the input when tokenizing.
 50 |         do_basic_tokenize (`bool`, *optional*, defaults to `True`):
 51 |             Whether or not to do basic tokenization before WordPiece.
 52 |         never_split (`Iterable`, *optional*):
 53 |             Collection of tokens which will never be split during tokenization. Only has an effect when
 54 |             `do_basic_tokenize=True`
 55 |         unk_token (`str`, *optional*, defaults to `"[UNK]"`):
 56 |             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
 57 |             token instead.
 58 |         sep_token (`str`, *optional*, defaults to `"[SEP]"`):
 59 |             The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
 60 |             sequence classification or for a text and a question for question answering. It is also used as the last
 61 |             token of a sequence built with special tokens.
 62 |         pad_token (`str`, *optional*, defaults to `"[PAD]"`):
 63 |             The token used for padding, for example when batching sequences of different lengths.
 64 |         cls_token (`str`, *optional*, defaults to `"[CLS]"`):
 65 |             The classifier token which is used when doing sequence classification (classification of the whole sequence
 66 |             instead of per-token classification). It is the first token of the sequence when built with special tokens.
 67 |         mask_token (`str`, *optional*, defaults to `"[MASK]"`):
 68 |             The token used for masking values. This is the token used when training this model with masked language
 69 |             modeling. This is the token which the model will try to predict.
 70 |         tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
 71 |             Whether or not to tokenize Chinese characters.
 72 | 
 73 |             This should likely be deactivated for Japanese (see this
 74 |             [issue](https://github.com/huggingface/transformers/issues/328)).
 75 |         strip_accents (`bool`, *optional*):
 76 |             Whether or not to strip all accents. If this option is not specified, then it will be determined by the
 77 |             value for `lowercase` (as in the original BERT).
 78 |     """
 79 | 
 80 |     vocab_files_names = VOCAB_FILES_NAMES
 81 |     model_input_names = ["input_ids", "attention_mask"]
 82 | 
 83 |     def __init__(
 84 |         self,
 85 |         vocab_file,
 86 |         do_lower_case=True,
 87 |         do_basic_tokenize=True,
 88 |         never_split=None,
 89 |         unk_token="[UNK]",
 90 |         sep_token="[SEP]",
 91 |         pad_token="[PAD]",
 92 |         cls_token="[CLS]",
 93 |         mask_token="[MASK]",
 94 |         tokenize_chinese_chars=True,
 95 |         strip_accents=None,
 96 |         **kwargs,
 97 |     ):
 98 |         if not os.path.isfile(vocab_file):
 99 |             raise ValueError(
100 |                 f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
101 |                 " model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
102 |             )
103 |         self.vocab = load_vocab(vocab_file)
104 |         self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
105 |         self.do_basic_tokenize = do_basic_tokenize
106 |         if do_basic_tokenize:
107 |             self.basic_tokenizer = BasicTokenizer(
108 |                 do_lower_case=do_lower_case,
109 |                 never_split=never_split,
110 |                 tokenize_chinese_chars=tokenize_chinese_chars,
111 |                 strip_accents=strip_accents,
112 |             )
113 |         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
114 | 
115 |         super().__init__(
116 |             do_lower_case=do_lower_case,
117 |             do_basic_tokenize=do_basic_tokenize,
118 |             never_split=never_split,
119 |             unk_token=unk_token,
120 |             sep_token=sep_token,
121 |             pad_token=pad_token,
122 |             cls_token=cls_token,
123 |             mask_token=mask_token,
124 |             tokenize_chinese_chars=tokenize_chinese_chars,
125 |             strip_accents=strip_accents,
126 |             **kwargs,
127 |         )
128 | 
129 |     @property
130 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case
131 |     def do_lower_case(self):
132 |         return self.basic_tokenizer.do_lower_case
133 | 
134 |     @property
135 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
136 |     def vocab_size(self):
137 |         return len(self.vocab)
138 | 
139 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
140 |     def get_vocab(self):
141 |         return dict(self.vocab, **self.added_tokens_encoder)
142 | 
143 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
144 |     def _tokenize(self, text, split_special_tokens=False):
145 |         split_tokens = []
146 |         if self.do_basic_tokenize:
147 |             for token in self.basic_tokenizer.tokenize(
148 |                 text, never_split=self.all_special_tokens if not split_special_tokens else None
149 |             ):
150 |                 # If the token is part of the never_split set
151 |                 if token in self.basic_tokenizer.never_split:
152 |                     split_tokens.append(token)
153 |                 else:
154 |                     split_tokens += self.wordpiece_tokenizer.tokenize(token)
155 |         else:
156 |             split_tokens = self.wordpiece_tokenizer.tokenize(text)
157 |         return split_tokens
158 | 
159 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
160 |     def _convert_token_to_id(self, token):
161 |         """Converts a token (str) to an id using the vocab."""
162 |         return self.vocab.get(token, self.vocab.get(self.unk_token))
163 | 
164 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
165 |     def _convert_id_to_token(self, index):
166 |         """Converts an index (integer) to a token (str) using the vocab."""
167 |         return self.ids_to_tokens.get(index, self.unk_token)
168 | 
169 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
170 |     def convert_tokens_to_string(self, tokens):
171 |         """Converts a sequence of tokens (string) into a single string."""
172 |         out_string = " ".join(tokens).replace(" ##", "").strip()
173 |         return out_string
174 | 
175 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens
176 |     def build_inputs_with_special_tokens(
177 |         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
178 |     ) -> List[int]:
179 |         """
180 |         Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
181 |         adding special tokens. A BERT sequence has the following format:
182 | 
183 |         - single sequence: `[CLS] X [SEP]`
184 |         - pair of sequences: `[CLS] A [SEP] B [SEP]`
185 | 
186 |         Args:
187 |             token_ids_0 (`List[int]`):
188 |                 List of IDs to which the special tokens will be added.
189 |             token_ids_1 (`List[int]`, *optional*):
190 |                 Optional second list of IDs for sequence pairs.
191 | 
192 |         Returns:
193 |             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
194 |         """
195 |         if token_ids_1 is None:
196 |             return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
197 |         cls = [self.cls_token_id]
198 |         sep = [self.sep_token_id]
199 |         return cls + token_ids_0 + sep + token_ids_1 + sep
200 | 
201 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
202 |     def get_special_tokens_mask(
203 |         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
204 |     ) -> List[int]:
205 |         """
206 |         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
207 |         special tokens using the tokenizer `prepare_for_model` method.
208 | 
209 |         Args:
210 |             token_ids_0 (`List[int]`):
211 |                 List of IDs.
212 |             token_ids_1 (`List[int]`, *optional*):
213 |                 Optional second list of IDs for sequence pairs.
214 |             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
215 |                 Whether or not the token list is already formatted with special tokens for the model.
216 | 
217 |         Returns:
218 |             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
219 |         """
220 | 
221 |         if already_has_special_tokens:
222 |             return super().get_special_tokens_mask(
223 |                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
224 |             )
225 | 
226 |         if token_ids_1 is not None:
227 |             return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
228 |         return [1] + ([0] * len(token_ids_0)) + [1]
229 | 
230 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences
231 |     def create_token_type_ids_from_sequences(
232 |         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
233 |     ) -> List[int]:
234 |         """
235 |         Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
236 |         pair mask has the following format:
237 | 
238 |         ```
239 |         0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
240 |         | first sequence    | second sequence |
241 |         ```
242 | 
243 |         If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
244 | 
245 |         Args:
246 |             token_ids_0 (`List[int]`):
247 |                 List of IDs.
248 |             token_ids_1 (`List[int]`, *optional*):
249 |                 Optional second list of IDs for sequence pairs.
250 | 
251 |         Returns:
252 |             `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
253 |         """
254 |         sep = [self.sep_token_id]
255 |         cls = [self.cls_token_id]
256 |         if token_ids_1 is None:
257 |             return len(cls + token_ids_0 + sep) * [0]
258 |         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
259 | 
260 |     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
261 |     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
262 |         index = 0
263 |         if os.path.isdir(save_directory):
264 |             vocab_file = os.path.join(
265 |                 save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
266 |             )
267 |         else:
268 |             vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
269 |         with open(vocab_file, "w", encoding="utf-8") as writer:
270 |             for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
271 |                 if index != token_index:
272 |                     # logger.warning(
273 |                     #     f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
274 |                     #     " Please check that the vocabulary is not corrupted!"
275 |                     # )
276 |                     index = token_index
277 |                 writer.write(token + "\n")
278 |                 index += 1
279 |         return (vocab_file,)
280 | 
281 | 
282 | # Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
283 | class BasicTokenizer(object):
284 |     """
285 |     Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
286 | 
287 |     Args:
288 |         do_lower_case (`bool`, *optional*, defaults to `True`):
289 |             Whether or not to lowercase the input when tokenizing.
290 |         never_split (`Iterable`, *optional*):
291 |             Collection of tokens which will never be split during tokenization. Only has an effect when
292 |             `do_basic_tokenize=True`
293 |         tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
294 |             Whether or not to tokenize Chinese characters.
295 | 
296 |             This should likely be deactivated for Japanese (see this
297 |             [issue](https://github.com/huggingface/transformers/issues/328)).
298 |         strip_accents (`bool`, *optional*):
299 |             Whether or not to strip all accents. If this option is not specified, then it will be determined by the
300 |             value for `lowercase` (as in the original BERT).
301 |         do_split_on_punc (`bool`, *optional*, defaults to `True`):
302 |             In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
303 |             the full context of the words, such as contractions.
304 |     """
305 | 
306 |     def __init__(
307 |         self,
308 |         do_lower_case=True,
309 |         never_split=None,
310 |         tokenize_chinese_chars=True,
311 |         strip_accents=None,
312 |         do_split_on_punc=True,
313 |     ):
314 |         if never_split is None:
315 |             never_split = []
316 |         self.do_lower_case = do_lower_case
317 |         self.never_split = set(never_split)
318 |         self.tokenize_chinese_chars = tokenize_chinese_chars
319 |         self.strip_accents = strip_accents
320 |         self.do_split_on_punc = do_split_on_punc
321 | 
322 |     def tokenize(self, text, never_split=None):
323 |         """
324 |         Basic Tokenization of a piece of text. For sub-word tokenization, see WordpieceTokenizer.
325 | 
326 |         Args:
327 |             never_split (`List[str]`, *optional*)
328 |                 Kept for backward compatibility purposes. Now implemented directly at the base class level (see
329 |                 [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
330 |         """
331 |         # union() returns a new set by concatenating the two sets.
332 |         never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
333 |         text = self._clean_text(text)
334 | 
335 |         # This was added on November 1st, 2018 for the multilingual and Chinese
336 |         # models. This is also applied to the English models now, but it doesn't
337 |         # matter since the English models were not trained on any Chinese data
338 |         # and generally don't have any Chinese data in them (there are Chinese
339 |         # characters in the vocabulary because Wikipedia does have some Chinese
340 |         # words in the English Wikipedia).
341 |         if self.tokenize_chinese_chars:
342 |             text = self._tokenize_chinese_chars(text)
343 |         # prevents treating the same character with different unicode codepoints as different characters
344 |         unicode_normalized_text = unicodedata.normalize("NFC", text)
345 |         orig_tokens = whitespace_tokenize(unicode_normalized_text)
346 |         split_tokens = []
347 |         for token in orig_tokens:
348 |             if token not in never_split:
349 |                 if self.do_lower_case:
350 |                     token = token.lower()
351 |                     if self.strip_accents is not False:
352 |                         token = self._run_strip_accents(token)
353 |                 elif self.strip_accents:
354 |                     token = self._run_strip_accents(token)
355 |             split_tokens.extend(self._run_split_on_punc(token, never_split))
356 | 
357 |         output_tokens = whitespace_tokenize(" ".join(split_tokens))
358 |         return output_tokens
359 | 
360 |     def _run_strip_accents(self, text):
361 |         """Strips accents from a piece of text."""
362 |         text = unicodedata.normalize("NFD", text)
363 |         output = []
364 |         for char in text:
365 |             cat = unicodedata.category(char)
366 |             if cat == "Mn":
367 |                 continue
368 |             output.append(char)
369 |         return "".join(output)
370 | 
371 |     def _run_split_on_punc(self, text, never_split=None):
372 |         """Splits punctuation on a piece of text."""
373 |         if not self.do_split_on_punc or (never_split is not None and text in never_split):
374 |             return [text]
375 |         chars = list(text)
376 |         i = 0
377 |         start_new_word = True
378 |         output = []
379 |         while i < len(chars):
380 |             char = chars[i]
381 |             if _is_punctuation(char):
382 |                 output.append([char])
383 |                 start_new_word = True
384 |             else:
385 |                 if start_new_word:
386 |                     output.append([])
387 |                 start_new_word = False
388 |                 output[-1].append(char)
389 |             i += 1
390 | 
391 |         return ["".join(x) for x in output]
392 | 
393 |     def _tokenize_chinese_chars(self, text):
394 |         """Adds whitespace around any CJK character."""
395 |         output = []
396 |         for char in text:
397 |             cp = ord(char)
398 |             if self._is_chinese_char(cp):
399 |                 output.append(" ")
400 |                 output.append(char)
401 |                 output.append(" ")
402 |             else:
403 |                 output.append(char)
404 |         return "".join(output)
405 | 
406 |     def _is_chinese_char(self, cp):
407 |         """Checks whether CP is the codepoint of a CJK character."""
408 |         # This defines a "chinese character" as anything in the CJK Unicode block:
409 |         #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
410 |         #
411 |         # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
412 |         # despite its name. The modern Korean Hangul alphabet is a different block,
413 |         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
414 |         # space-separated words, so they are not treated specially and handled
415 |         # like all of the other languages.
416 |         if (
417 |             (cp >= 0x4E00 and cp <= 0x9FFF)
418 |             or (cp >= 0x3400 and cp <= 0x4DBF)  #
419 |             or (cp >= 0x20000 and cp <= 0x2A6DF)  #
420 |             or (cp >= 0x2A700 and cp <= 0x2B73F)  #
421 |             or (cp >= 0x2B740 and cp <= 0x2B81F)  #
422 |             or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
423 |             or (cp >= 0xF900 and cp <= 0xFAFF)
424 |             or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
425 |         ):  #
426 |             return True
427 | 
428 |         return False
429 | 
430 |     def _clean_text(self, text):
431 |         """Performs invalid character removal and whitespace cleanup on text."""
432 |         output = []
433 |         for char in text:
434 |             cp = ord(char)
435 |             if cp == 0 or cp == 0xFFFD or _is_control(char):
436 |                 continue
437 |             if _is_whitespace(char):
438 |                 output.append(" ")
439 |             else:
440 |                 output.append(char)
441 |         return "".join(output)
442 | 
443 | 
444 | # Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
445 | class WordpieceTokenizer(object):
446 |     """Runs WordPiece tokenization."""
447 | 
448 |     def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
449 |         self.vocab = vocab
450 |         self.unk_token = unk_token
451 |         self.max_input_chars_per_word = max_input_chars_per_word
452 | 
453 |     def tokenize(self, text):
454 |         """
455 |         Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
456 |         tokenization using the given vocabulary.
457 | 
458 |         For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
459 | 
460 |         Args:
461 |             text: A single token or whitespace separated tokens. This should have
462 |                 already been passed through *BasicTokenizer*.
463 | 
464 |         Returns:
465 |             A list of wordpiece tokens.
466 |         """
467 | 
468 |         output_tokens = []
469 |         for token in whitespace_tokenize(text):
470 |             chars = list(token)
471 |             if len(chars) > self.max_input_chars_per_word:
472 |                 output_tokens.append(self.unk_token)
473 |                 continue
474 | 
475 |             is_bad = False
476 |             start = 0
477 |             sub_tokens = []
478 |             while start < len(chars):
479 |                 end = len(chars)
480 |                 cur_substr = None
481 |                 while start < end:
482 |                     substr = "".join(chars[start:end])
483 |                     if start > 0:
484 |                         substr = "##" + substr
485 |                     if substr in self.vocab:
486 |                         cur_substr = substr
487 |                         break
488 |                     end -= 1
489 |                 if cur_substr is None:
490 |                     is_bad = True
491 |                     break
492 |                 sub_tokens.append(cur_substr)
493 |                 start = end
494 | 
495 |             if is_bad:
496 |                 output_tokens.append(self.unk_token)
497 |             else:
498 |                 output_tokens.extend(sub_tokens)
499 |         return output_tokens
```
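
The `BasicTokenizer` and `WordpieceTokenizer` docstrings above describe a two-stage pipeline: basic whitespace/punctuation splitting followed by greedy longest-match-first WordPiece lookup. The sketch below (not part of the repo) shows how the two classes compose; the toy vocabulary is invented for illustration, and the real vocabulary is loaded from `vocab.txt` via `load_vocab`.

```python
# Minimal sketch, assuming the surya package is importable as laid out in the tree above.
from surya.ocr_error.tokenizer import BasicTokenizer, WordpieceTokenizer

toy_vocab = {"[UNK]": 0, "hello": 1, "un": 2, "##aff": 3, "##able": 4}  # made-up vocab

basic = BasicTokenizer(do_lower_case=True)
wordpiece = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")

tokens = []
for word in basic.tokenize("Hello, unaffable"):   # -> ["hello", ",", "unaffable"]
    tokens.extend(wordpiece.tokenize(word))       # greedy longest-match-first per word

print(tokens)  # ['hello', '[UNK]', 'un', '##aff', '##able']  (',' is not in the toy vocab)
```

`DistilBertTokenizer._tokenize` composes the two classes in the same way, and `build_inputs_with_special_tokens` then wraps the resulting ids as `[CLS] X [SEP]` (or `[CLS] A [SEP] B [SEP]` for pairs).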

--------------------------------------------------------------------------------
/surya/detection/model/encoderdecoder.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | This is an implementation of efficientvit, with some modifications (decode head, etc).
  3 | 
  4 | Original paper at https://arxiv.org/abs/2205.14756
  5 | 
  6 | Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py
  7 | Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit
  8 | """
  9 | 
 10 | from __future__ import annotations
 11 | 
 12 | from typing import Optional, Union, Tuple, List, Any
 13 | from functools import partial
 14 | 
 15 | import torch
 16 | import torch.nn as nn
 17 | import torch.nn.functional as F
 18 | 
 19 | from transformers.modeling_outputs import SemanticSegmenterOutput
 20 | 
 21 | from surya.common.pretrained import SuryaPreTrainedModel
 22 | from surya.common.s3 import S3DownloaderMixin
 23 | from surya.detection.model.config import EfficientViTConfig
 24 | 
 25 | 
 26 | def val2list(x: Union[List, Tuple, Any], repeat_time=1):
 27 |     if isinstance(x, (list, tuple)):
 28 |         return list(x)
 29 |     return [x for _ in range(repeat_time)]
 30 | 
 31 | 
 32 | def val2tuple(x: Union[List, Tuple, Any], min_len: int = 1, idx_repeat: int = -1):
 33 |     # repeat elements if necessary
 34 |     x = val2list(x)
 35 |     if len(x) > 0:
 36 |         x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
 37 | 
 38 |     return tuple(x)
 39 | 
 40 | 
 41 | def get_same_padding(
 42 |     kernel_size: Union[int, Tuple[int, ...]],
 43 | ) -> Union[int, Tuple[int, ...]]:
 44 |     if isinstance(kernel_size, tuple):
 45 |         return tuple([get_same_padding(ks) for ks in kernel_size])
 46 |     else:
 47 |         assert kernel_size % 2 > 0, "kernel size should be an odd number"
 48 |         return kernel_size // 2
 49 | 
 50 | 
 51 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
 52 |     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
 53 |     return padding
 54 | 
 55 | 
 56 | class ConvNormAct(nn.Module):
 57 |     def __init__(
 58 |         self,
 59 |         in_channels: int,
 60 |         out_channels: int,
 61 |         kernel_size=3,
 62 |         stride=1,
 63 |         dilation=1,
 64 |         groups=1,
 65 |         bias=False,
 66 |         dropout=0.0,
 67 |         norm_layer=nn.BatchNorm2d,
 68 |         act_layer=nn.ReLU,
 69 |     ):
 70 |         super(ConvNormAct, self).__init__()
 71 |         self.dropout = nn.Dropout(dropout, inplace=False)
 72 |         padding = get_padding(kernel_size, stride, dilation)
 73 |         self.conv = nn.Conv2d(
 74 |             in_channels,
 75 |             out_channels,
 76 |             kernel_size=kernel_size,
 77 |             stride=stride,
 78 |             dilation=dilation,
 79 |             groups=groups,
 80 |             bias=bias,
 81 |             padding=padding,
 82 |         )
 83 |         self.norm = (
 84 |             norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
 85 |         )
 86 |         self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()
 87 | 
 88 |     def forward(self, x):
 89 |         x = self.conv(x)
 90 |         x = self.norm(x)
 91 |         x = self.act(x)
 92 |         return x
 93 | 
 94 | 
 95 | class DSConv(nn.Module):
 96 |     def __init__(
 97 |         self,
 98 |         in_channels: int,
 99 |         out_channels: int,
100 |         kernel_size=3,
101 |         stride=1,
102 |         use_bias=False,
103 |         norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
104 |         act_layer=(nn.ReLU6, None),
105 |     ):
106 |         super(DSConv, self).__init__()
107 |         use_bias = val2tuple(use_bias, 2)
108 |         norm_layer = val2tuple(norm_layer, 2)
109 |         act_layer = val2tuple(act_layer, 2)
110 | 
111 |         self.depth_conv = ConvNormAct(
112 |             in_channels,
113 |             in_channels,
114 |             kernel_size,
115 |             stride,
116 |             groups=in_channels,
117 |             norm_layer=norm_layer[0],
118 |             act_layer=act_layer[0],
119 |             bias=use_bias[0],
120 |         )
121 |         self.point_conv = ConvNormAct(
122 |             in_channels,
123 |             out_channels,
124 |             1,
125 |             norm_layer=norm_layer[1],
126 |             act_layer=act_layer[1],
127 |             bias=use_bias[1],
128 |         )
129 | 
130 |     def forward(self, x):
131 |         x = self.depth_conv(x)
132 |         x = self.point_conv(x)
133 |         return x
134 | 
135 | 
136 | class ConvBlock(nn.Module):
137 |     def __init__(
138 |         self,
139 |         in_channels: int,
140 |         out_channels: int,
141 |         kernel_size=3,
142 |         stride=1,
143 |         mid_channels=None,
144 |         expand_ratio=1,
145 |         use_bias=False,
146 |         norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
147 |         act_layer=(nn.ReLU6, None),
148 |     ):
149 |         super(ConvBlock, self).__init__()
150 |         use_bias = val2tuple(use_bias, 2)
151 |         norm_layer = val2tuple(norm_layer, 2)
152 |         act_layer = val2tuple(act_layer, 2)
153 |         mid_channels = mid_channels or round(in_channels * expand_ratio)
154 | 
155 |         self.conv1 = ConvNormAct(
156 |             in_channels,
157 |             mid_channels,
158 |             kernel_size,
159 |             stride,
160 |             norm_layer=norm_layer[0],
161 |             act_layer=act_layer[0],
162 |             bias=use_bias[0],
163 |         )
164 |         self.conv2 = ConvNormAct(
165 |             mid_channels,
166 |             out_channels,
167 |             kernel_size,
168 |             1,
169 |             norm_layer=norm_layer[1],
170 |             act_layer=act_layer[1],
171 |             bias=use_bias[1],
172 |         )
173 | 
174 |     def forward(self, x):
175 |         x = self.conv1(x)
176 |         x = self.conv2(x)
177 |         return x
178 | 
179 | 
180 | class MBConv(nn.Module):
181 |     def __init__(
182 |         self,
183 |         in_channels: int,
184 |         out_channels: int,
185 |         kernel_size=3,
186 |         stride=1,
187 |         mid_channels=None,
188 |         expand_ratio=6,
189 |         use_bias=False,
190 |         norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
191 |         act_layer=(nn.ReLU6, nn.ReLU6, None),
192 |     ):
193 |         super(MBConv, self).__init__()
194 |         use_bias = val2tuple(use_bias, 3)
195 |         norm_layer = val2tuple(norm_layer, 3)
196 |         act_layer = val2tuple(act_layer, 3)
197 |         mid_channels = mid_channels or round(in_channels * expand_ratio)
198 | 
199 |         self.inverted_conv = ConvNormAct(
200 |             in_channels,
201 |             mid_channels,
202 |             1,
203 |             stride=1,
204 |             norm_layer=norm_layer[0],
205 |             act_layer=act_layer[0],
206 |             bias=use_bias[0],
207 |         )
208 |         self.depth_conv = ConvNormAct(
209 |             mid_channels,
210 |             mid_channels,
211 |             kernel_size,
212 |             stride=stride,
213 |             groups=mid_channels,
214 |             norm_layer=norm_layer[1],
215 |             act_layer=act_layer[1],
216 |             bias=use_bias[1],
217 |         )
218 |         self.point_conv = ConvNormAct(
219 |             mid_channels,
220 |             out_channels,
221 |             1,
222 |             norm_layer=norm_layer[2],
223 |             act_layer=act_layer[2],
224 |             bias=use_bias[2],
225 |         )
226 | 
227 |     def forward(self, x):
228 |         x = self.inverted_conv(x)
229 |         x = self.depth_conv(x)
230 |         x = self.point_conv(x)
231 |         return x
232 | 
233 | 
234 | class FusedMBConv(nn.Module):
235 |     def __init__(
236 |         self,
237 |         in_channels: int,
238 |         out_channels: int,
239 |         kernel_size=3,
240 |         stride=1,
241 |         mid_channels=None,
242 |         expand_ratio=6,
243 |         groups=1,
244 |         use_bias=False,
245 |         norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
246 |         act_layer=(nn.ReLU6, None),
247 |     ):
248 |         super(FusedMBConv, self).__init__()
249 |         use_bias = val2tuple(use_bias, 2)
250 |         norm_layer = val2tuple(norm_layer, 2)
251 |         act_layer = val2tuple(act_layer, 2)
252 |         mid_channels = mid_channels or round(in_channels * expand_ratio)
253 | 
254 |         self.spatial_conv = ConvNormAct(
255 |             in_channels,
256 |             mid_channels,
257 |             kernel_size,
258 |             stride=stride,
259 |             groups=groups,
260 |             norm_layer=norm_layer[0],
261 |             act_layer=act_layer[0],
262 |             bias=use_bias[0],
263 |         )
264 |         self.point_conv = ConvNormAct(
265 |             mid_channels,
266 |             out_channels,
267 |             1,
268 |             norm_layer=norm_layer[1],
269 |             act_layer=act_layer[1],
270 |             bias=use_bias[1],
271 |         )
272 | 
273 |     def forward(self, x):
274 |         x = self.spatial_conv(x)
275 |         x = self.point_conv(x)
276 |         return x
277 | 
278 | 
279 | class LiteMLA(nn.Module):
280 |     """Lightweight multi-scale linear attention"""
281 | 
282 |     def __init__(
283 |         self,
284 |         in_channels: int,
285 |         out_channels: int,
286 |         heads: Union[int, None] = None,
287 |         heads_ratio: float = 1.0,
288 |         dim=8,
289 |         use_bias=False,
290 |         norm_layer=(None, nn.BatchNorm2d),
291 |         act_layer=(None, None),
292 |         kernel_func=nn.ReLU,
293 |         scales=(5,),
294 |         eps=1e-5,
295 |     ):
296 |         super(LiteMLA, self).__init__()
297 |         self.eps = eps
298 |         heads = heads or int(in_channels // dim * heads_ratio)
299 |         total_dim = heads * dim
300 |         use_bias = val2tuple(use_bias, 2)
301 |         norm_layer = val2tuple(norm_layer, 2)
302 |         act_layer = val2tuple(act_layer, 2)
303 | 
304 |         self.dim = dim
305 |         self.qkv = ConvNormAct(
306 |             in_channels,
307 |             3 * total_dim,
308 |             1,
309 |             bias=use_bias[0],
310 |             norm_layer=norm_layer[0],
311 |             act_layer=act_layer[0],
312 |         )
313 |         self.aggreg = nn.ModuleList(
314 |             [
315 |                 nn.Sequential(
316 |                     nn.Conv2d(
317 |                         3 * total_dim,
318 |                         3 * total_dim,
319 |                         scale,
320 |                         padding=get_same_padding(scale),
321 |                         groups=3 * total_dim,
322 |                         bias=use_bias[0],
323 |                     ),
324 |                     nn.Conv2d(
325 |                         3 * total_dim,
326 |                         3 * total_dim,
327 |                         1,
328 |                         groups=3 * heads,
329 |                         bias=use_bias[0],
330 |                     ),
331 |                 )
332 |                 for scale in scales
333 |             ]
334 |         )
335 |         self.kernel_func = kernel_func(inplace=False)
336 | 
337 |         self.proj = ConvNormAct(
338 |             total_dim * (1 + len(scales)),
339 |             out_channels,
340 |             1,
341 |             bias=use_bias[1],
342 |             norm_layer=norm_layer[1],
343 |             act_layer=act_layer[1],
344 |         )
345 | 
346 |     def _attn(self, q, k, v):
347 |         dtype = v.dtype
348 |         q, k, v = q.float(), k.float(), v.float()
349 |         kv = k.transpose(-1, -2) @ v
350 |         out = q @ kv
351 |         out = out[..., :-1] / (out[..., -1:] + self.eps)
352 |         return out.to(dtype)
353 | 
354 |     def forward(self, x):
355 |         # Shape is B, C, H, W
356 |         B, _, H, W = x.shape
357 | 
358 |         # generate multi-scale q, k, v
359 |         qkv = self.qkv(x)
360 |         multi_scale_qkv = [qkv]
361 |         for op in self.aggreg:
362 |             multi_scale_qkv.append(op(qkv))
363 |         multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
364 |         multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(
365 |             -1, -2
366 |         )
367 |         # Shape for each is B, C, HW, head_dim
368 |         q, k, v = multi_scale_qkv.chunk(3, dim=-1)
369 | 
370 |         # lightweight global attention
371 |         q = self.kernel_func(q)
372 |         k = self.kernel_func(k)
373 |         v = F.pad(v, (0, 1), mode="constant", value=1.0)
374 | 
375 |         out = self._attn(q, k, v)
376 | 
377 |         # final projection
378 |         out = out.transpose(-1, -2).reshape(B, -1, H, W)
379 |         out = self.proj(out)
380 |         return out
381 | 
382 | 
383 | class EfficientVitBlock(nn.Module):
384 |     def __init__(
385 |         self,
386 |         in_channels,
387 |         heads_ratio=1.0,
388 |         head_dim=32,
389 |         expand_ratio=4,
390 |         norm_layer=nn.BatchNorm2d,
391 |         act_layer=nn.Hardswish,
392 |     ):
393 |         super(EfficientVitBlock, self).__init__()
394 |         self.context_module = ResidualBlock(
395 |             LiteMLA(
396 |                 in_channels=in_channels,
397 |                 out_channels=in_channels,
398 |                 heads_ratio=heads_ratio,
399 |                 dim=head_dim,
400 |                 norm_layer=(None, norm_layer),
401 |             ),
402 |             nn.Identity(),
403 |         )
404 |         self.local_module = ResidualBlock(
405 |             MBConv(
406 |                 in_channels=in_channels,
407 |                 out_channels=in_channels,
408 |                 expand_ratio=expand_ratio,
409 |                 use_bias=(True, True, False),
410 |                 norm_layer=(None, None, norm_layer),
411 |                 act_layer=(act_layer, act_layer, None),
412 |             ),
413 |             nn.Identity(),
414 |         )
415 | 
416 |     def forward(self, x):
417 |         x = self.context_module(x)
418 |         x = self.local_module(x)
419 |         return x
420 | 
421 | 
422 | class ResidualBlock(nn.Module):
423 |     def __init__(
424 |         self,
425 |         main: Optional[nn.Module],
426 |         shortcut: Optional[nn.Module] = None,
427 |         pre_norm: Optional[nn.Module] = None,
428 |     ):
429 |         super(ResidualBlock, self).__init__()
430 |         self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
431 |         self.main = main
432 |         self.shortcut = shortcut
433 | 
434 |     def forward(self, x):
435 |         res = self.main(self.pre_norm(x))
436 |         if self.shortcut is not None:
437 |             res = res + self.shortcut(x)
438 |         return res
439 | 
440 | 
441 | def build_local_block(
442 |     in_channels: int,
443 |     out_channels: int,
444 |     stride: int,
445 |     kernel_size: int,
446 |     expand_ratio: float,
447 |     norm_layer: str,
448 |     act_layer: str,
449 |     fewer_norm: bool = False,
450 |     block_type: str = "default",
451 | ):
452 |     assert block_type in ["default", "large", "fused"]
453 |     if expand_ratio == 1:
454 |         if block_type == "default":
455 |             block = DSConv(
456 |                 in_channels=in_channels,
457 |                 out_channels=out_channels,
458 |                 stride=stride,
459 |                 kernel_size=kernel_size,
460 |                 use_bias=(True, False) if fewer_norm else False,
461 |                 norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
462 |                 act_layer=(act_layer, None),
463 |             )
464 |         else:
465 |             block = ConvBlock(
466 |                 in_channels=in_channels,
467 |                 out_channels=out_channels,
468 |                 stride=stride,
469 |                 kernel_size=kernel_size,
470 |                 use_bias=(True, False) if fewer_norm else False,
471 |                 norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
472 |                 act_layer=(act_layer, None),
473 |             )
474 |     else:
475 |         if block_type == "default":
476 |             block = MBConv(
477 |                 in_channels=in_channels,
478 |                 out_channels=out_channels,
479 |                 stride=stride,
480 |                 kernel_size=kernel_size,
481 |                 expand_ratio=expand_ratio,
482 |                 use_bias=(True, True, False) if fewer_norm else False,
483 |                 norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
484 |                 act_layer=(act_layer, act_layer, None),
485 |             )
486 |         else:
487 |             block = FusedMBConv(
488 |                 in_channels=in_channels,
489 |                 out_channels=out_channels,
490 |                 stride=stride,
491 |                 kernel_size=kernel_size,
492 |                 expand_ratio=expand_ratio,
493 |                 use_bias=(True, False) if fewer_norm else False,
494 |                 norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
495 |                 act_layer=(act_layer, None),
496 |             )
497 |     return block
498 | 
499 | 
500 | class Stem(nn.Sequential):
501 |     def __init__(
502 |         self,
503 |         in_chs,
504 |         out_chs,
505 |         depth,
506 |         stride,
507 |         norm_layer,
508 |         act_layer,
509 |         block_type="default",
510 |     ):
511 |         super().__init__()
512 |         self.stride = stride
513 | 
514 |         self.add_module(
515 |             "in_conv",
516 |             ConvNormAct(
517 |                 in_chs,
518 |                 out_chs,
519 |                 kernel_size=stride + 1,
520 |                 stride=stride,
521 |                 norm_layer=norm_layer,
522 |                 act_layer=act_layer,
523 |             ),
524 |         )
525 |         stem_block = 0
526 |         for _ in range(depth):
527 |             self.add_module(
528 |                 f"res{stem_block}",
529 |                 ResidualBlock(
530 |                     build_local_block(
531 |                         in_channels=out_chs,
532 |                         out_channels=out_chs,
533 |                         stride=1,
534 |                         kernel_size=3,
535 |                         expand_ratio=1,
536 |                         norm_layer=norm_layer,
537 |                         act_layer=act_layer,
538 |                         block_type=block_type,
539 |                     ),
540 |                     nn.Identity(),
541 |                 ),
542 |             )
543 |             stem_block += 1
544 | 
545 | 
546 | class EfficientVitLargeStage(nn.Module):
547 |     def __init__(
548 |         self,
549 |         in_chs,
550 |         out_chs,
551 |         depth,
552 |         stride,
553 |         norm_layer,
554 |         act_layer,
555 |         head_dim,
556 |         vit_stage=False,
557 |         fewer_norm=False,
558 |     ):
559 |         super(EfficientVitLargeStage, self).__init__()
560 |         blocks = [
561 |             ResidualBlock(
562 |                 build_local_block(
563 |                     in_channels=in_chs,
564 |                     out_channels=out_chs,
565 |                     stride=stride,
566 |                     kernel_size=stride + 1,
567 |                     expand_ratio=24 if vit_stage else 16,
568 |                     norm_layer=norm_layer,
569 |                     act_layer=act_layer,
570 |                     fewer_norm=vit_stage or fewer_norm,
571 |                     block_type="default" if fewer_norm else "fused",
572 |                 ),
573 |                 None,
574 |             )
575 |         ]
576 |         in_chs = out_chs
577 | 
578 |         if vit_stage:
579 |             # for stage 4
580 |             for _ in range(depth):
581 |                 blocks.append(
582 |                     EfficientVitBlock(
583 |                         in_channels=in_chs,
584 |                         head_dim=head_dim,
585 |                         expand_ratio=6,
586 |                         norm_layer=norm_layer,
587 |                         act_layer=act_layer,
588 |                     )
589 |                 )
590 |         else:
591 |             # for stage 1, 2, 3
592 |             for i in range(depth):
593 |                 blocks.append(
594 |                     ResidualBlock(
595 |                         build_local_block(
596 |                             in_channels=in_chs,
597 |                             out_channels=out_chs,
598 |                             stride=1,
599 |                             kernel_size=3,
600 |                             expand_ratio=4,
601 |                             norm_layer=norm_layer,
602 |                             act_layer=act_layer,
603 |                             fewer_norm=fewer_norm,
604 |                             block_type="default" if fewer_norm else "fused",
605 |                         ),
606 |                         nn.Identity(),
607 |                     )
608 |                 )
609 | 
610 |         self.blocks = nn.Sequential(*blocks)
611 | 
612 |     def forward(self, x):
613 |         return self.blocks(x)
614 | 
615 | 
616 | class EfficientVitLarge(nn.Module):
617 |     def __init__(
618 |         self,
619 |         config: EfficientViTConfig,
620 |         norm_layer=nn.BatchNorm2d,
621 |         act_layer=nn.Hardswish,
622 |     ):
623 |         super(EfficientVitLarge, self).__init__()
624 |         self.grad_checkpointing = False
625 |         self.num_classes = config.num_classes
626 |         self.norm_eps = config.layer_norm_eps
627 |         norm_layer = partial(norm_layer, eps=self.norm_eps)
628 | 
629 |         # input stem
630 |         self.stem = Stem(
631 |             config.num_channels,
632 |             config.widths[0],
633 |             config.depths[0],
634 |             config.strides[0],
635 |             norm_layer,
636 |             act_layer,
637 |             block_type="large",
638 |         )
639 |         stride = config.strides[0]
640 | 
641 |         # stages
642 |         self.feature_info = []
643 |         self.stages = nn.Sequential()
644 |         in_channels = config.widths[0]
645 |         for i, (w, d, s) in enumerate(
646 |             zip(config.widths[1:], config.depths[1:], config.strides[1:])
647 |         ):
648 |             self.stages.append(
649 |                 EfficientVitLargeStage(
650 |                     in_channels,
651 |                     w,
652 |                     depth=d,
653 |                     stride=s,
654 |                     norm_layer=norm_layer,
655 |                     act_layer=act_layer,
656 |                     head_dim=config.head_dim,
657 |                     vit_stage=i >= 3,
658 |                     fewer_norm=i >= 2,
659 |                 )
660 |             )
661 |             stride *= s
662 |             in_channels = w
663 |             self.feature_info += [
664 |                 dict(num_chs=in_channels, reduction=stride, module=f"stages.{i}")
665 |             ]
666 | 
667 |         self.num_features = in_channels
668 | 
669 |     @torch.jit.ignore
670 |     def set_grad_checkpointing(self, enable=True):
671 |         self.grad_checkpointing = enable
672 | 
673 |     def forward(self, x):
674 |         x = self.stem(x)
675 |         encoder_hidden_states = []
676 |         for i, module in enumerate(self.stages):
677 |             x = module(x)
678 |             encoder_hidden_states.append(x)
679 | 
680 |         return encoder_hidden_states
681 | 
682 | 
683 | class EfficientViTPreTrainedModel(SuryaPreTrainedModel):
684 |     """
685 |     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
686 |     models.
687 |     """
688 | 
689 |     config_class = EfficientViTConfig
690 |     base_model_prefix = "efficientvit"
691 |     main_input_name = "pixel_values"
692 | 
693 |     def _init_weights(self, module):
694 |         """Initialize the weights"""
695 |         if isinstance(module, (nn.Linear, nn.Conv2d)):
696 |             # Slightly different from the TF version which uses truncated_normal for initialization
697 |             # cf https://github.com/pytorch/pytorch/pull/5617
698 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
699 |             if module.bias is not None:
700 |                 module.bias.data.zero_()
701 |         elif isinstance(module, nn.Embedding):
702 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
703 |             if module.padding_idx is not None:
704 |                 module.weight.data[module.padding_idx].zero_()
705 |         elif isinstance(module, nn.LayerNorm):
706 |             module.bias.data.zero_()
707 |             module.weight.data.fill_(1.0)
708 | 
709 | 
710 | class DecodeMLP(nn.Module):
711 |     def __init__(self, input_dim, output_dim):
712 |         super().__init__()
713 |         self.proj = nn.Linear(input_dim, output_dim)
714 | 
715 |     def forward(self, hidden_states: torch.Tensor):
716 |         # Input is B, C, H, W
717 |         hidden_states = hidden_states.flatten(2).transpose(1, 2)
718 |         # Output is B, HW, C
719 |         hidden_states = self.proj(hidden_states)
720 |         return hidden_states
721 | 
722 | 
723 | class DecodeHead(EfficientViTPreTrainedModel):
724 |     def __init__(self, config: EfficientViTConfig):
725 |         super().__init__(config)
726 | 
727 |         # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_layer_hidden_size
728 |         mlps = []
729 |         for width in config.widths[1:]:
730 |             mlp = DecodeMLP(
731 |                 input_dim=width, output_dim=config.decoder_layer_hidden_size
732 |             )
733 |             mlps.append(mlp)
734 |         self.linear_c = nn.ModuleList(mlps)
735 | 
736 |         # the following 3 layers implement the ConvModule of the original implementation
737 |         self.linear_fuse = nn.Conv2d(
738 |             in_channels=config.decoder_layer_hidden_size * config.num_stages,
739 |             out_channels=config.decoder_hidden_size,
740 |             kernel_size=1,
741 |             bias=False,
742 |         )
743 |         self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
744 |         self.activation = nn.ReLU()
745 | 
746 |         self.dropout = nn.Dropout(config.classifier_dropout_prob)
747 |         self.classifier = nn.Conv2d(
748 |             config.decoder_hidden_size, config.num_labels, kernel_size=1
749 |         )
750 | 
751 |         self.config = config
752 | 
753 |     def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
754 |         batch_size = encoder_hidden_states[-1].shape[0]
755 | 
756 |         all_hidden_states = ()
757 |         for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
758 |             height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
759 |             encoder_hidden_state = mlp(encoder_hidden_state)  # Output is B, HW, C
760 |             # Permute to B, C, HW
761 |             encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
762 |             encoder_hidden_state = encoder_hidden_state.reshape(
763 |                 batch_size, -1, height, width
764 |             )
765 |             # upsample
766 |             encoder_hidden_state = nn.functional.interpolate(
767 |                 encoder_hidden_state,
768 |                 size=encoder_hidden_states[0].size()[2:],
769 |                 mode="bilinear",
770 |                 align_corners=False,
771 |             )
772 |             all_hidden_states += (encoder_hidden_state,)
773 | 
774 |         hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
775 |         hidden_states = self.batch_norm(hidden_states)
776 |         hidden_states = self.activation(hidden_states)
777 | 
778 |         # logits are of shape (batch_size, num_labels, height/4, width/4)
779 |         logits = self.classifier(hidden_states)
780 | 
781 |         return logits
782 | 
783 | 
784 | class EfficientViTForSemanticSegmentation(
785 |     S3DownloaderMixin, EfficientViTPreTrainedModel
786 | ):
787 |     def __init__(self, config, **kwargs):
788 |         super().__init__(config)
789 |         self.vit = EfficientVitLarge(config)
790 |         self.decode_head = DecodeHead(config)
791 | 
792 |         # Initialize weights and apply final processing
793 |         self.post_init()
794 | 
795 |     def forward(
796 |         self, pixel_values: torch.FloatTensor
797 |     ) -> Union[Tuple, SemanticSegmenterOutput]:
798 |         # Pixel values should be B,C,H,W
799 |         encoder_hidden_states = self.vit(
800 |             pixel_values,
801 |         )
802 | 
803 |         logits = self.decode_head(encoder_hidden_states)
804 | 
805 |         # Apply sigmoid to get 0-1 output
806 |         logits = torch.special.expit(logits)
807 | 
808 |         return SemanticSegmenterOutput(
809 |             loss=None, logits=logits, hidden_states=encoder_hidden_states
810 |         )
811 | 
812 | 
813 | class EfficientViTForSemanticLayoutSegmentation(EfficientViTPreTrainedModel):
814 |     def __init__(self, config, **kwargs):
815 |         super().__init__(config, **kwargs)
816 |         self.vit = EfficientVitLarge(config)
817 |         self.decode_head = DecodeHead(config)
818 | 
819 |         # Initialize weights and apply final processing
820 |         self.post_init()
821 | 
822 |     def forward(
823 |         self, pixel_values: torch.FloatTensor
824 |     ) -> Union[Tuple, SemanticSegmenterOutput]:
825 |         # Pixel values should be B,C,H,W
826 |         encoder_hidden_states = self.vit(
827 |             pixel_values,
828 |         )
829 | 
830 |         logits = self.decode_head(encoder_hidden_states)
831 | 
832 |         # Apply sigmoid to get 0-1 output
833 |         logits = torch.special.expit(logits)
834 | 
835 |         return SemanticSegmenterOutput(
836 |             loss=None, logits=logits, hidden_states=encoder_hidden_states
837 |         )
838 | 
```
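
A minimal usage sketch for the segmentation wrapper defined above: the encoder returns one feature map per stage, `DecodeHead` projects each stage to a shared channel size, upsamples everything to the first stage's resolution, fuses the stack with a 1x1 conv + BN + ReLU, and `forward()` applies a sigmoid so each output channel is a 0-1 heatmap. The checkpoint id below is a placeholder and the class import is elided (the module path is given in the file header above); this assumes a compatible checkpoint is loadable through the `from_pretrained` interface these classes inherit.

```python
import torch

# Placeholder checkpoint id - substitute a real detection/layout checkpoint.
# Import of EfficientViTForSemanticSegmentation from the module above is elided.
model = EfficientViTForSemanticSegmentation.from_pretrained("<detection-checkpoint>")
model.eval()

# forward() expects pixel_values of shape B, C, H, W
pixel_values = torch.randn(1, 3, 1024, 1024)

with torch.inference_mode():
    out = model(pixel_values=pixel_values)

# expit (sigmoid) is applied inside forward(), so each channel of `logits`
# is already a per-label probability heatmap at the decode head's resolution.
heatmaps = out.logits  # (1, config.num_labels, H', W')
```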

--------------------------------------------------------------------------------
/surya/common/surya/processor/tokenizer.py:
--------------------------------------------------------------------------------

```python
  1 | import html
  2 | import re
  3 | from typing import List, Union, Dict, Optional, Tuple, Iterable
  4 | import numpy as np
  5 | import torch
  6 | from tokenizers import AddedToken
  7 | import json
  8 | import os
  9 | from transformers import PreTrainedTokenizer, Qwen2Tokenizer as Qwen2OriginalTokenizer
 10 | 
 11 | 
 12 | from surya.common.s3 import S3DownloaderMixin
 13 | from surya.common.surya.schema import TASK_NAMES, TaskNames
 14 | from surya.logging import get_logger
 15 | from surya.settings import settings
 16 | 
 17 | logger = get_logger()
 18 | 
 19 | 
 20 | def create_token_regex(tokens):
 21 |     escaped_tokens = [re.escape(token) for token in tokens]
 22 |     escaped_tokens.sort(key=len, reverse=True)
 23 |     pattern = r"^(" + "|".join(escaped_tokens) + r")"
 24 |     regex = re.compile(pattern)
 25 |     return regex
 26 | 
 27 | 
 28 | class InnerOCRTokenizer:
 29 |     def __init__(
 30 |         self,
 31 |         special_tokens: Dict[str, list] | None = None,
 32 |         qwen_tokenizer: Qwen2OriginalTokenizer | None = None,
 33 |         **kwargs,
 34 |     ):
 35 |         self.qwen_tokenizer = qwen_tokenizer
 36 |         self.qwen_token_offset = len(qwen_tokenizer)
 37 | 
 38 |         all_special_tokens = special_tokens.get("all", [])
 39 |         self.SPECIAL_TOKEN_MAPPING = {}
 40 | 
 41 |         idx = 0
 42 |         for tag in all_special_tokens:
 43 |             if tag in self.SPECIAL_TOKEN_MAPPING:
 44 |                 continue
 45 |             self.SPECIAL_TOKEN_MAPPING[tag] = (
 46 |                 idx + self.qwen_token_offset
 47 |             )  # Assign token ID
 48 |             idx += 1
 49 | 
 50 |         self.REVERSE_SPECIAL_TOKEN_MAPPING = {
 51 |             v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items()
 52 |         }
 53 |         self.SPECIAL_TOKEN_OFFSET = idx
 54 |         self.FORMAT_TAG_PATTERN = create_token_regex(special_tokens["formatting"])
 55 |         self.MATH_TAG_PATTERN = create_token_regex(special_tokens["math_external"])
 56 |         self.LAYOUT_TAG_PATTERN = create_token_regex(special_tokens["layout"])
 57 |         self.TABLE_STRUCTURE_TAG_PATTERN = create_token_regex(
 58 |             special_tokens["table_structure"]
 59 |         )
 60 |         self.SYSTEM_TAG_PATTERN = create_token_regex(special_tokens.get("system", []))
 61 |         if not special_tokens.get("system", []):
 62 |             logger.warning("No system tokens found in special_tokens")
 63 | 
 64 |         self.MATH_TAG_START = "<math"
 65 |         self.MATH_END_TAG = "</math>"
 66 | 
 67 |         super().__init__(**kwargs)
 68 | 
 69 |     @property
 70 |     def vocab_size(self):
 71 |         return (
 72 |             65536 + self.SPECIAL_TOKEN_OFFSET
 73 |         )  # 65536 UTF-16 code units (0-65535) plus the number of special tokens
 74 | 
 75 |     def _tokenize(self, text: str) -> List[int]:
 76 |         tokens = []
 77 |         in_math = False
 78 |         text = html.unescape(text)  # Unescape html entities like &lt; in equations
 79 |         while text:
 80 |             # Look for EOS, PAD, etc. tokens
 81 |             match = self.SYSTEM_TAG_PATTERN.search(text)
 82 |             if match:
 83 |                 tag = match.group(1)
 84 |                 tokens.append(
 85 |                     self.SPECIAL_TOKEN_MAPPING[tag]
 86 |                 )  # These are already offset
 87 |                 text = text[match.end() :]
 88 |                 continue
 89 | 
 90 |             # Look for layout tokens
 91 |             match = self.LAYOUT_TAG_PATTERN.search(text)
 92 |             if match:
 93 |                 tag = match.group(1)
 94 |                 tokens.append(
 95 |                     self.SPECIAL_TOKEN_MAPPING[tag]
 96 |                 )  # Layout tokens are already offset
 97 |                 text = text[match.end() :]
 98 |                 continue
 99 | 
100 |             match = self.TABLE_STRUCTURE_TAG_PATTERN.search(text)
101 |             if match:
102 |                 tag = match.group(1)
103 |                 tokens.append(self.SPECIAL_TOKEN_MAPPING[tag])
104 |                 text = text[match.end() :]
105 |                 continue
106 | 
107 |             # Check for math tags
108 |             match = self.MATH_TAG_PATTERN.search(text)
109 |             if match:
110 |                 # We found a tag
111 |                 tag = match.group(1)
112 |                 if tag.startswith(self.MATH_TAG_START):
113 |                     in_math = True
114 |                 elif tag == self.MATH_END_TAG:
115 |                     in_math = False
116 |                 tokens.append(
117 |                     self.SPECIAL_TOKEN_MAPPING[tag]  # Special tokens are already offset
118 |                 )  # Use special token ID
119 |                 text = text[match.end() :]
120 |                 continue
121 | 
122 |             # Tokenize math content with qwen2 tokenizer
123 |             if in_math:
124 |                 # If we're in a math block, check to see if we have a special math tag in the text
125 |                 math_end_position = text.find(self.MATH_END_TAG)
126 |                 math_str = text[:math_end_position]  # Gets the math content
127 |                 tokens += self.qwen_tokenizer(math_str)["input_ids"]
128 |                 text = text[math_end_position:]
129 |                 continue
130 | 
131 |             # Check for formatting tags
132 |             match = self.FORMAT_TAG_PATTERN.search(text)
133 |             if match:
134 |                 # We found a tag
135 |                 tag = match.group(1)
136 |                 tokens.append(
137 |                     self.SPECIAL_TOKEN_MAPPING[tag]  # Special tokens are already offset
138 |                 )  # Use special token ID
139 |                 text = text[match.end() :]
140 |                 continue
141 | 
142 |             # General case, utf-16 tokenization
143 |             utf_16_tokens = self.text_to_utf16_numbers(text[0])
144 |             tokens += [
145 |                 t + self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset
146 |                 for t in utf_16_tokens
147 |             ]
148 |             text = text[1:]
149 | 
150 |         return tokens
151 | 
152 |     def text_to_utf16_numbers(self, text: str):
153 |         """Converts text to UTF-16 encoded numbers."""
154 |         utf16_bytes = text.encode(
155 |             "utf-16le"
156 |         )  # Little-endian to simplify byte order handling
157 |         numbers = []
158 | 
159 |         for i in range(0, len(utf16_bytes), 2):
160 |             # Combine two adjacent bytes into a single number
161 |             number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8)
162 |             numbers.append(number)
163 | 
164 |         return numbers
165 | 
166 |     def utf16_numbers_to_text(self, numbers):
167 |         """Converts UTF-16 numbers back to text."""
168 |         byte_array = bytearray()
169 |         for number in numbers:
170 |             byte_array.append(number & 0xFF)  # Lower byte
171 |             byte_array.append((number >> 8) & 0xFF)  # Upper byte
172 | 
173 |         try:
174 |             text = byte_array.decode("utf-16le", errors="ignore")
175 |         except Exception as e:
176 |             logger.warning(f"Error decoding utf16: {e}")
177 |             text = ""
178 | 
179 |         return text
180 | 
181 |     def __call__(
182 |         self, texts: Union[str, List[str]], **kwargs
183 |     ) -> Dict[str, List[List[int]]]:
184 |         """Tokenizes text and returns input IDs."""
185 |         tokenized = []
186 | 
187 |         if isinstance(texts, str):
188 |             texts = [texts]
189 | 
190 |         for text in texts:
191 |             tokens = self._tokenize(text)
192 |             tokenized.append(tokens)
193 | 
194 |         return {"input_ids": tokenized}
195 | 
196 |     def decode(self, token_ids, **kwargs):
197 |         """Decodes token IDs back to text."""
198 |         if isinstance(token_ids, (np.ndarray, torch.Tensor)):
199 |             token_ids = token_ids.tolist()
200 | 
201 |         decoded_text = ""
202 |         token_buffer = []
203 |         decode_qwen = [False]
204 | 
205 |         def decode_buffer():
206 |             nonlocal decoded_text, token_buffer, decode_qwen
207 |             if token_buffer:
208 |                 if decode_qwen[0]:
209 |                     decoded_text += self.qwen_tokenizer.decode(token_buffer)
210 |                 else:
211 |                     token_buffer = [
212 |                         t - self.SPECIAL_TOKEN_OFFSET - self.qwen_token_offset
213 |                         for t in token_buffer
214 |                     ]
215 |                     decoded_text += self.utf16_numbers_to_text(token_buffer)
216 | 
217 |             token_buffer = []
218 |             decode_qwen[0] = False
219 | 
220 |         for t in token_ids:
221 |             if t < self.qwen_token_offset:
222 |                 # Qwen token ids, used here for math content
223 |                 if token_buffer and token_buffer[-1] >= self.qwen_token_offset:
224 |                     decode_buffer()
225 |                 token_buffer.append(t)
226 |                 decode_qwen[0] = True
227 |             elif t >= self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset:
228 |                 if token_buffer and token_buffer[-1] < self.qwen_token_offset:
229 |                     decode_buffer()
230 |                 token_buffer.append(t)  # We shift this down later on
231 |                 decode_qwen[0] = False
232 |             elif t in self.REVERSE_SPECIAL_TOKEN_MAPPING:
233 |                 decode_buffer()
234 |                 decoded_text += self.REVERSE_SPECIAL_TOKEN_MAPPING[t]
235 |                 decode_qwen[0] = False
236 |             else:
237 |                 raise ValueError(
238 |                     f'Unexpected token value while decoding, got "{t}" in token_ids {token_ids}'
239 |                 )
240 | 
241 |         # Detokenize remaining tokens
242 |         decode_buffer()
243 | 
244 |         return decoded_text
245 | 
246 | 
247 | class Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer):
248 |     pass
249 | 
250 | class GreedyMathUTF16Tokenizer(S3DownloaderMixin, PreTrainedTokenizer):
251 |     """
252 |     HuggingFace slow tokenizer implementing:
253 |       - UTF-16 code units as the base [0..65535]
254 |       - Math tokens as greedy-longest-match ids after UTF-16
255 |       - Literal special tokens after math tokens
256 |     Absolute ID layout:
257 |       [0 .. 65535]                      : UTF-16 units
258 |       [65536 .. 65536+M-1]              : math tokens
259 |       [65536+M .. 65536+M+S-1]          : special tokens
260 |     """
261 | 
262 |     vocab_files_names = {
263 |         "vocab_file": "vocab_math.json",  # {"\\frac": 0, "\\alpha": 1, ...} raw contiguous ids 0..M-1
264 |         "specials_file": "specials.json",  # [flat list for legacy]
265 |         "specials_dict_file": "specials_dict.json",  # category dict (preferred)
266 |     }
267 |     model_input_names = ["input_ids", "attention_mask"]
268 |     is_fast = False
269 | 
270 |     # ---------- helpers ----------
271 |     @staticmethod
272 |     def _to_utf16_units(s: str) -> List[int]:
273 |         b = s.encode("utf-16le")
274 |         return [int.from_bytes(b[i : i + 2], "little") for i in range(0, len(b), 2)]
275 | 
276 |     @staticmethod
277 |     def _from_utf16_units(units: List[int]) -> str:
278 |         b = bytearray()
279 |         for u in units:
280 |             b += int(u).to_bytes(2, "little")
281 |         return b.decode("utf-16le", errors="ignore")
282 | 
283 |     class _TrieNode:
284 |         __slots__ = ("child", "id", "leaf")
285 | 
286 |         def __init__(self):
287 |             self.child: Dict[str, "GreedyMathUTF16Tokenizer._TrieNode"] = {}
288 |             self.id: Optional[int] = None
289 |             self.leaf: bool = False
290 | 
291 |     @classmethod
292 |     def _build_trie(
293 |         cls, token_to_id: Dict[str, int]
294 |     ) -> "GreedyMathUTF16Tokenizer._TrieNode":
295 |         root = cls._TrieNode()
296 |         for tok, tid in token_to_id.items():
297 |             node = root
298 |             for ch in tok:
299 |                 node = node.child.setdefault(ch, cls._TrieNode())
300 |             node.leaf = True
301 |             node.id = tid
302 |         return root
303 | 
304 |     @classmethod
305 |     def _encode_math_greedy(
306 |         cls,
307 |         s: str,
308 |         trie: "GreedyMathUTF16Tokenizer._TrieNode",
309 |         math_base: int,
310 |         debug: bool = False,
311 |     ) -> List[int]:
312 |         i, n = 0, len(s)
313 |         out: List[int] = []
314 |         while i < n:
315 |             node = trie
316 |             j = i
317 |             last_id = None
318 |             last_j = i
319 |             while j < n and (ch := s[j]) in node.child:
320 |                 node = node.child[ch]
321 |                 j += 1
322 |                 if node.leaf:
323 |                     last_id, last_j = node.id, j
324 |             if last_id is not None:
325 |                 if debug:
326 |                     print(f"[MATH] matched {s[i:last_j]!r} -> {last_id}")
327 |                 out.append(math_base + last_id)
328 |                 i = last_j
329 |             else:
330 |                 units = cls._to_utf16_units(s[i])
331 |                 if debug:
332 |                     print(f"[MATH] fallback {s[i]!r} -> utf16 {units}")
333 |                 out.extend(units)
334 |                 i += 1
335 |         return out
336 | 
337 |     # ---------- init ----------
338 |     def __init__(
339 |         self,
340 |         vocab_file: Optional[str] = None,
341 |         specials_file: Optional[str] = None,
342 |         specials_dict_file: Optional[str] = None,
343 |         *,
344 |         # You can also pass programmatically instead of files:
345 |         math_vocab: Optional[Dict[str, int]] = None,
346 |         special_tokens: Optional[List[str]] = None,
347 |         special_tokens_dict: Optional[Dict[str, List[str]]] = None,
348 |         debug: bool = False,
349 |         # Standard HF special token kwargs:
350 |         bos_token: Optional[str] = None,
351 |         eos_token: Optional[str] = None,
352 |         pad_token: Optional[str] = None,
353 |         unk_token: Optional[str] = None,
354 |         **kwargs,
355 |     ):
356 |         # Load math vocab
357 |         if vocab_file and os.path.isfile(vocab_file):
358 |             with open(vocab_file, "r", encoding="utf-8") as f:
359 |                 mv = json.load(f)
360 |         else:
361 |             mv = math_vocab or {}
362 | 
363 |         # Make math ids contiguous if needed
364 |         if mv:
365 |             max_id = max(mv.values())
366 |             if set(mv.values()) != set(range(max_id + 1)):
367 |                 items = sorted(mv.items(), key=lambda kv: kv[1])
368 |                 mv = {tok: i for i, (tok, _) in enumerate(items)}
369 | 
370 |         # Load special tokens (prefer category dict; fallback to flat list or defaults)
371 |         sp_dict = None
372 |         if specials_dict_file and os.path.isfile(specials_dict_file):
373 |             with open(specials_dict_file, "r", encoding="utf-8") as f:
374 |                 sp_dict = json.load(f)
375 |         elif special_tokens_dict is not None:
376 |             sp_dict = dict(special_tokens_dict)
377 | 
378 |         if sp_dict is None:
379 |             # Legacy path: flat list from file or provided/default list
380 |             if specials_file and os.path.isfile(specials_file):
381 |                 with open(specials_file, "r", encoding="utf-8") as f:
382 |                     sp_list_flat = json.load(f)
383 |             else:
384 |                 sp_list_flat = special_tokens or SPECIAL_TOKENS
385 |             sp_dict = {"all": list(sp_list_flat)}
386 | 
387 |         # Ensure "all" exists and is unique/preserved in order.
388 |         if "all" not in sp_dict or not isinstance(sp_dict["all"], list):
389 |             order = [
390 |                 "system",
391 |                 "formatting",
392 |                 "math_external",
393 |                 "script",
394 |                 "layout",
395 |                 "reasoning",
396 |                 "table_structure",
397 |                 "reserved",
398 |             ]
399 |             seen = set()
400 |             all_tokens: List[str] = []
401 |             for k in order:
402 |                 if k in sp_dict and isinstance(sp_dict[k], list):
403 |                     for t in sp_dict[k]:
404 |                         if t not in seen:
405 |                             all_tokens.append(t)
406 |                             seen.add(t)
407 |             sp_dict["all"] = all_tokens
408 | 
409 |         # Keep a copy of categories (if present) for downstream processor logic.
410 |         self.special_tokens = sp_dict
411 |         sp_list = list(sp_dict.get("all", []))
412 |         # Regex list should favor longest-first to avoid partial matches.
413 |         specials_for_regex = sorted(sp_list, key=len, reverse=True)
414 | 
415 |         self.debug = debug
416 |         self.UTF16_SPACE = 65536
417 |         self.math_token_to_rawid = dict(mv)  # 0..M-1
418 |         self.math_vocab_size = len(self.math_token_to_rawid)
419 |         self.MATH_BASE = self.UTF16_SPACE
420 |         self.SPECIAL_BASE = self.UTF16_SPACE + self.math_vocab_size
421 | 
422 |         # Maps
423 |         self.math_absid_to_token = {
424 |             self.MATH_BASE + rid: tok for tok, rid in self.math_token_to_rawid.items()
425 |         }
426 |         self.special_tokens_list = sp_list  # ID assignment order
427 |         self.special_to_absid = {
428 |             tok: self.SPECIAL_BASE + i for i, tok in enumerate(self.special_tokens_list)
429 |         }
430 |         self.absid_to_special = {v: k for k, v in self.special_to_absid.items()}
431 | 
432 |         # Public attributes for legacy/processor:
433 |         # All specials mapping (token -> absolute id)
434 |         self.SPECIAL_TOKEN_MAPPING: Dict[str, int] = dict(self.special_to_absid)
435 |         # Subset used heavily by processor for quick access
436 |         self.reverse_special_token_mapping = {
437 |             v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items()
438 |         }
439 |         self.LAYOUT_LABEL2ID = {
440 |             k: v
441 |             for k, v in self.SPECIAL_TOKEN_MAPPING.items()
442 |             if k in self.special_tokens["layout"]
443 |         }
444 |         self.TABLE_STRUCTURE_LABEL2ID = {
445 |             k: v
446 |             for k, v in self.SPECIAL_TOKEN_MAPPING.items()
447 |             if k in self.special_tokens["table_structure"]
448 |         }
449 |         if not self.special_tokens.get("system", []):
450 |             logger.warning("No system tokens found in special_tokens")
451 | 
452 |         self.MATH_TAG_START = "<math"
453 |         self.MATH_END_TAG = "</math>"
454 | 
455 |         sys_list = self.special_tokens.get("system", [])
456 |         self.system_tokens: Dict[str, int] = {
457 |             t: self.special_to_absid[t] for t in sys_list if t in self.special_to_absid
458 |         }
459 | 
460 |         # Regex for literal specials
461 |         self.specials_pattern = (
462 |             re.compile(r"(" + "|".join(re.escape(k) for k in specials_for_regex) + r")")
463 |             if specials_for_regex
464 |             else None
465 |         )
466 | 
467 |         # Trie for math greedy match
468 |         self.trie = self._build_trie(self.math_token_to_rawid)
469 | 
470 |         # Tell HF about special tokens (metadata)
471 |         kwargs.setdefault("bos_token", bos_token)
472 |         kwargs.setdefault("eos_token", eos_token or "</S>")
473 |         kwargs.setdefault("pad_token", pad_token or "<PAD>")
474 |         kwargs.setdefault("unk_token", unk_token)
475 | 
476 |         super().__init__(
477 |             vocab_file=vocab_file,
478 |             specials_file=specials_file,
479 |             specials_dict_file=specials_dict_file,
480 |             **kwargs,
481 |         )
482 | 
483 |     # ---------- required HF surface ----------
484 |     @property
485 |     def vocab_size(self) -> int:
486 |         return self.UTF16_SPACE + self.math_vocab_size + len(self.special_tokens_list)
487 | 
488 |     def get_vocab(self) -> Dict[str, int]:
489 |         # Compact vocab: just math+specials with ABSOLUTE ids.
490 |         v = {tok: self.MATH_BASE + rid for tok, rid in self.math_token_to_rawid.items()}
491 |         v.update(self.special_to_absid)
492 |         return v
493 | 
494 |     def __len__(self) -> int:
495 |         return self.vocab_size
496 | 
497 |     # Core encode/decode on ABSOLUTE ids
498 |     def _encode_core(self, text: str) -> List[int]:
499 |         text = html.unescape(text)
500 |         ids: List[int] = []
501 |         in_math = False
502 |         chunks = self.specials_pattern.split(text) if self.specials_pattern else [text]
503 |         for chunk in chunks:
504 |             if chunk in self.special_to_absid:
505 |                 ids.append(self.special_to_absid[chunk])
506 |                 if chunk.startswith("<math"):
507 |                     in_math = True
508 |                 elif chunk.startswith("</math>"):
509 |                     in_math = False
510 |                 if self.debug:
511 |                     print(f"[TAG] {chunk!r} -> {self.special_to_absid[chunk]}")
512 |                 continue
513 | 
514 |             if in_math:
515 |                 ids.extend(
516 |                     self._encode_math_greedy(
517 |                         chunk, self.trie, self.MATH_BASE, debug=self.debug
518 |                     )
519 |                 )
520 |             else:
521 |                 units = self._to_utf16_units(chunk)
522 |                 if self.debug and units:
523 |                     print(
524 |                         f"[TEXT] utf16 {chunk[:32]!r} -> {units[:8]}{'...' if len(units) > 8 else ''}"
525 |                     )
526 |                 ids.extend(units)
527 |         return ids
528 | 
529 |     def _decode_core(self, ids: Iterable[int]) -> str:
530 |         out: List[str] = []
531 |         buf: List[int] = []
532 | 
533 |         def flush():
534 |             if buf:
535 |                 out.append(self._from_utf16_units(buf))
536 |                 buf.clear()
537 | 
538 |         for tid in ids:
539 |             if tid >= self.MATH_BASE and tid < self.SPECIAL_BASE:
540 |                 flush()
541 |                 out.append(self.math_absid_to_token.get(tid, ""))
542 |             elif tid >= self.SPECIAL_BASE:
543 |                 flush()
544 |                 out.append(self.absid_to_special.get(tid, ""))
545 |             else:
546 |                 buf.append(int(tid))
547 |         flush()
548 |         return "".join(out)
549 | 
550 |     # ---- Tokenizer interface ----
551 |     def _tokenize(self, text: str, **kwargs) -> List[str]:
552 |         ids = self._encode_core(text)
553 |         toks: List[str] = []
554 |         for i in ids:
555 |             if i < self.MATH_BASE:
556 |                 toks.append(f"<U+{i:04X}>")
557 |             elif i < self.SPECIAL_BASE:
558 |                 toks.append(self.math_absid_to_token.get(i, "<UNK_MATH>"))
559 |             else:
560 |                 toks.append(self.absid_to_special.get(i, "<UNK_SPECIAL>"))
561 |         return toks
562 | 
563 |     def _convert_token_to_id(self, token: str) -> int:
564 |         if token.startswith("<U+") and token.endswith(">"):
565 |             try:
566 |                 return int(token[3:-1], 16)  # UTF-16 unit
567 |             except Exception:
568 |                 return self.unk_token_id if self.unk_token_id is not None else 0
569 |         # math or specials
570 |         if token in self.math_token_to_rawid:
571 |             return self.MATH_BASE + self.math_token_to_rawid[token]
572 |         if token in self.special_to_absid:
573 |             return self.special_to_absid[token]
574 |         # rare path: single-char token -> its UTF-16 unit
575 |         if len(token) == 1:
576 |             u = self._to_utf16_units(token)
577 |             if len(u) == 1:
578 |                 return u[0]
579 |         return self.unk_token_id if self.unk_token_id is not None else 0
580 | 
581 |     def _convert_id_to_token(self, index: int) -> str:
582 |         if index < self.MATH_BASE:
583 |             return f"<U+{index:04X}>"
584 |         if index < self.SPECIAL_BASE:
585 |             return self.math_absid_to_token.get(index, "<UNK_MATH>")
586 |         return self.absid_to_special.get(index, "<UNK_SPECIAL>")
587 | 
588 |     def convert_tokens_to_string(self, tokens: List[str]) -> str:
589 |         ids = [self._convert_token_to_id(t) for t in tokens]
590 |         return self._decode_core(ids)
591 | 
592 |     def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str:
593 |         # Accept int, list, tuple, numpy, torch
594 |         if hasattr(token_ids, "tolist"):
595 |             token_ids = token_ids.tolist()
596 |         elif isinstance(token_ids, int):
597 |             token_ids = [token_ids]
598 |         else:
599 |             token_ids = list(token_ids)
600 |         token_ids = [int(i) for i in token_ids]  # normalize early
601 | 
602 |         if skip_special_tokens:
603 |             token_ids = [i for i in token_ids if i < self.SPECIAL_BASE]
604 |         return self._decode_core(token_ids)
605 | 
606 |     # HF plumbing
607 |     def build_inputs_with_special_tokens(
608 |         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
609 |     ) -> List[int]:
610 |         out = (
611 |             list(token_ids_0)
612 |             if token_ids_1 is None
613 |             else list(token_ids_0) + list(token_ids_1)
614 |         )
615 |         # if self.eos_token_id is not None and (not out or out[-1] != self.eos_token_id):
616 |         #     out.append(self.eos_token_id)
617 |         return out
618 | 
619 |     def get_special_tokens_mask(
620 |         self,
621 |         token_ids_0: List[int],
622 |         token_ids_1: Optional[List[int]] = None,
623 |         already_has_special_tokens: bool = False,
624 |     ) -> List[int]:
625 |         def mask(seq: List[int]) -> List[int]:
626 |             return [1 if i >= self.SPECIAL_BASE else 0 for i in seq]
627 | 
628 |         return (
629 |             mask(token_ids_0)
630 |             if token_ids_1 is None
631 |             else mask(token_ids_0) + mask(token_ids_1)
632 |         )
633 | 
634 |     def create_token_type_ids_from_sequences(
635 |         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
636 |     ) -> List[int]:
637 |         return [0] * (
638 |             len(token_ids_0)
639 |             if token_ids_1 is None
640 |             else len(token_ids_0) + len(token_ids_1)
641 |         )
642 | 
643 |     # Save/load raw assets
644 |     def save_vocabulary(
645 |         self, save_directory: str, filename_prefix: Optional[str] = None
646 |     ) -> Tuple[str, str]:
647 |         os.makedirs(save_directory, exist_ok=True)
648 |         pre = (filename_prefix + "-") if filename_prefix else ""
649 |         vocab_path = os.path.join(
650 |             save_directory, pre + self.vocab_files_names["vocab_file"]
651 |         )
652 |         specials_path = os.path.join(
653 |             save_directory, pre + self.vocab_files_names["specials_file"]
654 |         )
655 |         specials_dict_path = os.path.join(
656 |             save_directory, pre + self.vocab_files_names["specials_dict_file"]
657 |         )
658 |         with open(vocab_path, "w", encoding="utf-8") as f:
659 |             json.dump(self.math_token_to_rawid, f, ensure_ascii=False, indent=2)
660 |         # Save both the flat list ("all") and the category dict (preferred)
661 |         with open(specials_path, "w", encoding="utf-8") as f:
662 |             json.dump(self.special_tokens_list, f, ensure_ascii=False, indent=2)
663 |         with open(specials_dict_path, "w", encoding="utf-8") as f:
664 |             json.dump(self.special_tokens, f, ensure_ascii=False, indent=2)
665 |         return (vocab_path, specials_path)
666 | 
667 | 
668 | class SuryaOCRTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
669 |     def __init__(
670 |         self,
671 |         special_tokens: Dict[str, list] | None = None,
672 |         model_checkpoint: str = settings.FOUNDATION_MODEL_CHECKPOINT,
673 |         **kwargs,
674 |     ):
675 |         if special_tokens is None:
676 |             special_tokens = dict()
677 | 
678 |         self.special_tokens = special_tokens
679 | 
680 |         self.ocr_tokenizer = GreedyMathUTF16Tokenizer.from_pretrained(
681 |             model_checkpoint,
682 |         )
683 |         self.system_tokens = {
684 |             v: self.ocr_tokenizer(v)["input_ids"][0]
685 |             for v in special_tokens.get("system", [])
686 |         }
687 |         self.SPECIAL_TOKEN_MAPPING = self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING
688 | 
689 |         super().__init__(**kwargs)
690 | 
691 |     def get_vocab(self) -> Dict[str, int]:
692 |         return self.ocr_tokenizer.get_vocab()
693 | 
694 |     def _add_tokens(
695 |         self,
696 |         new_tokens: Union[List[str], List[AddedToken]],
697 |         special_tokens: bool = False,
698 |     ) -> int:
699 |         return self.ocr_tokenizer._add_tokens(
700 |             new_tokens, special_tokens=special_tokens
701 |         )
702 | 
703 |     @property
704 |     def vocab_size(self):
705 |         return self.ocr_tokenizer.vocab_size
706 | 
707 |     def _tokenize(self, text: str, **kwargs):
708 |         # task = kwargs.get("task", TaskNames.ocr_with_boxes)
709 |         # assert task in TASK_NAMES, f"Invalid task: {task}"
710 | 
711 |         tokens = self.ocr_tokenizer(text)["input_ids"]
712 | 
713 |         return tokens
714 | 
715 |     def __call__(
716 |         self,
717 |         texts: Union[str, List[str]],
718 |         tasks: Union[str, List[str]] = None,
719 |         **kwargs,
720 |     ) -> Dict[str, List[List[int]]]:
721 |         """Tokenizes text and returns input IDs."""
722 |         tokenized = []
723 | 
724 |         if isinstance(texts, str):
725 |             texts = [texts]
726 |             assert isinstance(tasks, str), "Tasks must be a string if texts is a string"
727 |             tasks = [tasks]
728 | 
729 |         if isinstance(texts, list):
730 |             assert isinstance(tasks, list), "Tasks must be a list if texts is a list"
731 | 
732 |         for text, task in zip(texts, tasks):
733 |             tokens = self._tokenize(text, task=task)
734 |             tokenized.append(tokens)
735 | 
736 |         return {"input_ids": tokenized}
737 | 
738 |     def decode(self, token_ids, **kwargs):
739 |         if isinstance(token_ids, (np.ndarray, torch.Tensor)):
740 |             token_ids = token_ids.tolist()
741 | 
742 |         decoded_text = self.ocr_tokenizer.decode(token_ids, skip_special_tokens=False)
743 |         # replace all <SCRIPT-...> tokens with empty strings
744 |         decoded_text = re.sub(r"<SCRIPT-.*?>", "", decoded_text)
745 |         # replace </S> with empty string
746 |         decoded_text = re.sub(r"</S>", "", decoded_text)
747 |         return decoded_text
748 | 
```
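
A small sketch of the absolute ID layout documented in `GreedyMathUTF16Tokenizer`: plain text maps to raw UTF-16 code units (0-65535), content between `<math>` and `</math>` is greedily matched against the math vocab via the trie, and literal special tokens sit above both ranges. The toy `math_vocab` and `special_tokens_dict` below are illustrative stand-ins for the checkpoint's `vocab_math.json` / `specials_dict.json`, and `_encode_core` is the internal path that `_tokenize`/`__call__` build on; exact behavior of the HuggingFace plumbing may vary by `transformers` version.

```python
from surya.common.surya.processor.tokenizer import GreedyMathUTF16Tokenizer

# Toy assets for illustration only (the real ones ship with the model checkpoint).
math_vocab = {"\\frac": 0, "\\alpha": 1}
special_tokens_dict = {
    "system": ["</S>", "<PAD>"],
    "math_external": ["<math>", "</math>"],
    "layout": [],
    "table_structure": [],
}

tok = GreedyMathUTF16Tokenizer(
    math_vocab=math_vocab,
    special_tokens_dict=special_tokens_dict,
)

ids = tok._encode_core("ab<math>\\alpha</math></S>")
# "a", "b"   -> UTF-16 units 97, 98
# "<math>"   -> SPECIAL_BASE + its index in the "all" list (turns math mode on)
# "\\alpha"  -> MATH_BASE + 1 via greedy longest match in the trie
# "</math>"  -> special token (math mode off); "</S>" -> special token (eos)
print(ids)

# skip_special_tokens drops everything >= SPECIAL_BASE, leaving "ab\alpha"
print(tok.decode(ids, skip_special_tokens=True))
```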

--------------------------------------------------------------------------------
/surya/common/surya/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | import warnings
  2 | from typing import Optional, Tuple, TypedDict
  3 | from dataclasses import dataclass
  4 | 
  5 | import torch
  6 | from torch import nn
  7 | import torch.nn.functional as F
  8 | from transformers.modeling_outputs import CausalLMOutputWithPast
  9 | from transformers.cache_utils import Cache
 10 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 11 | 
 12 | from surya.common.pretrained import SuryaPreTrainedModel
 13 | from surya.common.s3 import S3DownloaderMixin
 14 | from surya.common.surya.config import SuryaModelConfig
 15 | from surya.common.surya.decoder import SuryaDecoderModel
 16 | from surya.common.surya.embedder import SimpleTokenEmbedder
 17 | from surya.common.surya.encoder import SuryaEncoderModel
 18 | from surya.common.util import pad_to_batch_size, pad_to_batch_size_repeat
 19 | from surya.common.xla import get_nearest_pad
 20 | from surya.settings import settings
 21 | 
 22 | from surya.logging import get_logger
 23 | 
 24 | logger = get_logger()
 25 | 
 26 | 
 27 | @dataclass
 28 | class SuryaModelOutput(CausalLMOutputWithPast):
 29 |     bbox_logits: torch.FloatTensor = None
 30 |     lm_logits: torch.FloatTensor = None
 31 | 
 32 | 
 33 | class FlashAttentionKwargs(TypedDict, total=False):
 34 |     """
 35 |     Keyword arguments for Flash Attention with Compile.
 36 | 
 37 |     Attributes:
 38 |         cu_seq_lens_q (`torch.LongTensor`, *optional*):
 39 |             Cumulative sequence lengths for query state.
 40 |         cu_seq_lens_k (`torch.LongTensor`, *optional*):
 41 |             Cumulative sequence lengths for key state.
 42 |         max_length_q (`int`, *optional*):
 43 |             Maximum sequence length for query state.
 44 |         max_length_k (`int`, *optional*):
 45 |             Maximum sequence length for key state.
 46 |     """
 47 | 
 48 |     cu_seq_lens_q: Optional[torch.LongTensor]
 49 |     cu_seq_lens_k: Optional[torch.LongTensor]
 50 |     max_length_q: Optional[int]
 51 |     max_length_k: Optional[int]
 52 | 
 53 | 
 54 | class KwargsForCausalLM(FlashAttentionKwargs): ...
 55 | 
 56 | 
 57 | class DistanceProjection(nn.Module):
 58 |     def __init__(self, in_features: int, out_features: int):
 59 |         super().__init__()
 60 |         self.fc1 = nn.Linear(in_features, out_features)
 61 |         self.act = nn.SiLU()
 62 |         self.fc2 = nn.Linear(out_features, out_features)
 63 | 
 64 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
 65 |         x = self.fc1(x)
 66 |         x = self.act(x)
 67 |         x = self.fc2(x)
 68 |         return x
 69 | 
 70 |     def init_weights(self):
 71 |         nn.init.xavier_uniform_(self.fc1.weight)
 72 |         nn.init.xavier_uniform_(self.fc2.weight)
 73 |         nn.init.zeros_(self.fc1.bias)
 74 |         nn.init.zeros_(self.fc2.bias)
 75 | 
 76 | 
 77 | class BboxHead(nn.Module):
 78 |     def __init__(self, in_features: int, out_features: int):
 79 |         super().__init__()
 80 |         self.proj_layers = nn.ModuleList(
 81 |             [nn.Linear(in_features, in_features) for _ in range(6)]
 82 |         )
 83 |         self.act = nn.SiLU()
 84 |         self.out_proj = nn.Linear(in_features, out_features)
 85 | 
 86 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
 87 |         for layer in self.proj_layers:
 88 |             x = layer(x)
 89 |             x = self.act(x)
 90 | 
 91 |         x = self.out_proj(x)
 92 |         return x
 93 | 
 94 | 
 95 | class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
 96 |     config_class = SuryaModelConfig
 97 |     supports_gradient_checkpointing = True
 98 |     _skip_keys_device_placement = ["past_key_values"]
 99 |     _supports_flash_attn_2 = True
100 |     _supports_sdpa = True
101 |     _supports_flex_attn = True
102 |     _supports_cache_class = True
103 |     _supports_quantized_cache = True
104 |     _supports_static_cache = True
105 |     _supports_attention_backend = True
106 |     main_input_name = "input_ids"
107 |     _tied_weights_keys = ["lm_head.weight"]
108 | 
109 |     def __init__(
110 |         self,
111 |         config: SuryaModelConfig,
112 |         embedder: SimpleTokenEmbedder = None,
113 |         vision_encoder: SuryaEncoderModel = None,
114 |         decoder: SuryaDecoderModel = None,
115 |         **kwargs,
116 |     ):
117 |         super().__init__(config, **kwargs)
118 | 
119 |         if vision_encoder is None:
120 |             vision_encoder = SuryaEncoderModel(config.vision_encoder)
121 | 
122 |         if decoder is None:
123 |             decoder = SuryaDecoderModel(config.decoder)
124 | 
125 |         if embedder is None:
126 |             embedder = SimpleTokenEmbedder(config)
127 | 
128 |         self.vision_encoder = vision_encoder
129 |         self.decoder = decoder
130 |         self.embedder = embedder
131 | 
132 |         # Simple encoding for image patches
133 |         self.img_w_embed = nn.Embedding(
134 |             self.config.image_embed_encoding_size,
135 |             self.config.hidden_size,
136 |         )
137 | 
138 |         self.img_h_embed = nn.Embedding(
139 |             self.config.image_embed_encoding_size,
140 |             self.config.hidden_size,
141 |         )
142 | 
143 |         # Tying configs
144 |         self.vision_encoder.config = self.config.vision_encoder
145 |         self.decoder.config = self.config.decoder
146 | 
147 |         self.bbox_head = BboxHead(config.hidden_size, 6)
148 |         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
149 | 
150 |         if (
151 |             self.config.multi_output_distance is not None
152 |             and self.config.multi_output_distance > 0
153 |         ):
154 |             self.multi_output_projections = nn.ModuleList(
155 |                 [
156 |                     DistanceProjection(
157 |                         in_features=config.hidden_size, out_features=config.hidden_size
158 |                     )
159 |                     for _ in range(self.config.multi_output_distance)
160 |                 ]
161 |             )
162 | 
163 |     def tie_weights(self):
164 |         self._tie_weights()
165 | 
166 |     def _tie_weights(self):
167 |         # Tie weights of lm head and token embedder
168 |         self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed)
169 | 
170 |     def get_output_embeddings(self) -> nn.Module:
171 |         return self.lm_head
172 | 
173 |     def get_input_embeddings(self) -> nn.Module:
174 |         return self.embedder.token_embed
175 | 
176 |     def set_output_embeddings(self, new_embeddings: nn.Module):
177 |         self.lm_head = new_embeddings
178 | 
179 |     def set_input_embeddings(self, new_embeddings: nn.Module):
180 |         self.embedder.token_embed = new_embeddings
181 | 
182 |     def maybe_static_pad_image_inputs(
183 |         self,
184 |         chunk_pixels: torch.Tensor,
185 |         chunk_grid_thw: torch.Tensor,
186 |         actual_chunk_len: int,
187 |         encoder_chunk_size: int,
188 |     ) -> Tuple[torch.Tensor, torch.Tensor, int]:
189 |         valid_embed_len = actual_chunk_len // (
190 |             self.vision_encoder.spatial_merge_size**2
191 |         )
192 |         if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size:
193 |             padding_len = encoder_chunk_size - actual_chunk_len
194 |             chunk_pixels = F.pad(
195 |                 chunk_pixels,
196 |                 (0, 0, 0, padding_len),
197 |                 mode="constant",
198 |                 value=0.0,
199 |             )
200 | 
201 |             padding_grid = torch.tensor(
202 |                 [[1, 2, padding_len // 2]],
203 |                 device=chunk_grid_thw.device,
204 |                 dtype=chunk_grid_thw.dtype,
205 |             )
206 |             chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0)
207 | 
208 |         return chunk_pixels, chunk_grid_thw, valid_embed_len
209 | 
210 |     def get_image_embeddings(
211 |         self,
212 |         pixel_values: torch.Tensor,
213 |         grid_thw: torch.Tensor,
214 |         encoder_chunk_size: int,
215 |         valid_batch_size: torch.Tensor | None = None,
216 |         max_batch_size: int | None = None,
217 |     ):
218 |         # embed all images with the vision encoder after they have already been tiled and flattened into a single batch
219 |         chunks = [0]
220 |         grid_chunks = [0]
221 |         curr_chunk_len = 0
222 |         curr_seq_len = 0
223 |         for i in range(len(grid_thw)):
224 |             curr_chunk_len += (grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]).item()
225 |             if curr_chunk_len > encoder_chunk_size:
226 |                 chunks.append(curr_chunk_len + curr_seq_len)
227 |                 curr_seq_len += curr_chunk_len
228 |                 curr_chunk_len = 0
229 |                 grid_chunks.append(i + 1)
230 | 
231 |         if curr_chunk_len > 0:
232 |             chunks.append(pixel_values.shape[0])
233 |             grid_chunks.append(len(grid_thw))
234 | 
235 |         assert curr_chunk_len + curr_seq_len == pixel_values.shape[0], (
236 |             f"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {pixel_values.shape[0]}"
237 |         )
238 | 
239 |         logger.debug(
240 |             f"Chunking encoder sequence into {len(chunks) - 1} chunks of size {encoder_chunk_size} with lengths {chunks} and grids {grid_chunks}"
241 |         )
242 |         embeddings = []
243 |         for i in range(len(chunks) - 1):
244 |             start = chunks[i]
245 |             end = chunks[i + 1]
246 |             grid_start = grid_chunks[i]
247 |             grid_end = grid_chunks[i + 1]
248 | 
249 |             chunk_pixels = pixel_values[start:end]
250 |             chunk_grid_thw = grid_thw[grid_start:grid_end]
251 |             actual_chunk_len = end - start
252 |             chunk_pixels, chunk_grid_thw, valid_embed_len = (
253 |                 self.maybe_static_pad_image_inputs(
254 |                     chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size
255 |                 )
256 |             )
257 | 
258 |             chunk_embeddings = self.vision_encoder.embed_images(
259 |                 image_batch=chunk_pixels.unsqueeze(0).to(device=self.device),
260 |                 grid_thw=chunk_grid_thw.unsqueeze(0).to(device=self.device),
261 |             )
262 |             embeddings.append(chunk_embeddings[:valid_embed_len].squeeze(0))
263 | 
264 |         if len(embeddings) == 0:
265 |             raise ValueError(
266 |                 "No image embeddings were generated. Check the input images and grid sizes."
267 |             )
268 |         elif len(embeddings) == 1:
269 |             embeddings = embeddings[0]
270 |         else:
271 |             embeddings = torch.cat(embeddings, dim=0)
272 | 
273 |         encoding_2d = self.get_2d_learned_embeddings(
274 |             grid_thw,
275 |             device=embeddings.device,
276 |             bbox_size=self.config.image_embed_encoding_multiplier,
277 |         )
278 |         assert embeddings.shape[0] == encoding_2d.shape[0], (
279 |             f"Mismatch in image embedding seq len: {embeddings.shape} vs {encoding_2d.shape}"
280 |         )
281 |         assert embeddings.shape[1] == encoding_2d.shape[1], (
282 |             f"Mismatch in image embedding token counts: {embeddings.shape} vs {encoding_2d.shape}"
283 |         )
284 | 
285 |         embeddings = embeddings + encoding_2d
286 | 
287 |         return embeddings
288 | 
289 |     def embed_ids_boxes_images(
290 |         self,
291 |         input_ids,
292 |         image_embeddings,
293 |         encoder_chunk_size: int,
294 |         valid_batch_size: torch.Tensor | None = None,
295 |         input_boxes: torch.Tensor | None = None,
296 |         embed_boxes: torch.Tensor | None = None,
297 |     ):
298 |         """
299 |         Insert embedded image tiles at the corresponding positions in the full input sequence
300 | 
301 |         Positions to insert new tokens are indicated by the special image token index
302 |         """
303 |         # This is batched in the inner call
304 |         inputs_embeds = self.embedder.embed(
305 |             input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes
306 |         )
307 | 
308 |         if image_embeddings is not None:
309 |             special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
310 |             special_image_mask = special_image_mask.expand_as(inputs_embeds)
311 |             if inputs_embeds[special_image_mask].numel() != image_embeddings.numel():
312 |                 n_image_tokens = torch.sum((input_ids == self.config.image_token_id))
313 |                 n_image_features = image_embeddings.shape[0] * image_embeddings.shape[1]
314 |                 warnings.warn(
315 |                     f"Image features and image tokens do not match: tokens {n_image_tokens}, features {n_image_features}. This may lead to unexpected results"
316 |                 )
317 |             image_features = image_embeddings.to(inputs_embeds.dtype)
318 |             inputs_embeds = inputs_embeds.masked_scatter(
319 |                 special_image_mask, image_features
320 |             )
321 |         else:
322 |             assert (input_ids == self.config.image_token_id).sum() == 0, (
323 |                 "Image tokens were present in the input but no input images were provided"
324 |             )
325 | 
326 |         return inputs_embeds
327 | 
328 |     def get_2d_learned_embeddings(
329 |         self,
330 |         grid_thw,
331 |         device: str | torch.device = "cpu",
332 |         bbox_size: int = 256,
333 |     ):
334 |         all_embeddings = []
335 |         for grid_t, grid_h, grid_w in grid_thw:
336 |             llm_grid_h, llm_grid_w = (
337 |                 grid_h // self.config.merge_size,
338 |                 grid_w // self.config.merge_size,
339 |             )
340 | 
341 |             # Scale positions into the 0-bbox_size embedding index range
342 |             llm_grid_h = (
343 |                 torch.arange(llm_grid_h, device=device)
344 |                 / max(1, (llm_grid_h - 1))
345 |                 * bbox_size
346 |             )
347 |             llm_grid_w = (
348 |                 torch.arange(llm_grid_w, device=device)
349 |                 / max(1, (llm_grid_w - 1))
350 |                 * bbox_size
351 |             )
352 | 
353 |             llm_grid_w_idx = llm_grid_w.to(torch.long)
354 |             llm_grid_h_idx = llm_grid_h.to(torch.long)
355 | 
356 |             llm_grid_w = self.img_w_embed(llm_grid_w_idx)
357 |             llm_grid_h = self.img_h_embed(llm_grid_h_idx)
358 | 
359 |             full_grid = llm_grid_h[:, None] + llm_grid_w[None, :]
360 | 
361 |             flattened = full_grid.flatten(
362 |                 0, 1
363 |             )  # Flatten the first two dims so the result is seq_len x embed_dim
364 |             all_embeddings.append(flattened)
365 |         return torch.concat(
366 |             all_embeddings, dim=0
367 |         )  # Shape is num_image_tokens x embed_dim
368 | 
369 |     def get_logits(self, hidden_states):
370 |         assert hidden_states.shape[1] == 1, (
371 |             "Multi output predictions only applied on the last token"
372 |         )
373 | 
374 |         all_lm_logits = []
375 |         all_bbox_logits = []
376 | 
377 |         current_hidden = hidden_states
378 | 
379 |         # Loop includes the initial prediction (i=0) plus multi_output_distance additional predictions
380 |         for i in range(self.config.multi_output_distance + 1):
381 |             if i > 0:
382 |                 current_hidden = self.multi_output_projections[i - 1](current_hidden)
383 | 
384 |             lm_logits = self.lm_head(current_hidden)
385 |             bbox_logits = F.sigmoid(self.bbox_head(current_hidden))
386 | 
387 |             all_lm_logits.append(lm_logits)
388 |             all_bbox_logits.append(bbox_logits)
389 | 
390 |         # Concatenate along sequence dimension (dim=1)
391 |         final_lm_logits = torch.cat(all_lm_logits, dim=1)
392 |         final_bbox_logits = torch.cat(all_bbox_logits, dim=1)
393 | 
394 |         return final_lm_logits, final_bbox_logits
395 | 
396 |     def forward(
397 |         self,
398 |         input_ids=None,
399 |         image_embeddings=None,
400 |         labels=None,
401 |         image_tiles=None,
402 |         grid_thw=None,
403 |         inputs_embeds=None,
404 |         attention_mask=None,
405 |         position_ids=None,
406 |         cache_position=None,
407 |         past_key_values=None,
408 |         output_hidden_states=False,
409 |         output_attentions=False,
410 |         use_cache=False,
411 |         encoder_chunk_size=32768,
412 |         cache_idxs=None,
413 |         num_valid_tokens=None,
414 |         prefill=True,
415 |         text_lengths=None,
416 |         valid_batch_size: torch.Tensor = None,
417 |         input_boxes=None,
418 |         embed_boxes=None,
419 |         logits_to_keep=None,
420 |         **kwargs: KwargsForCausalLM,
421 |     ):
422 |         if any([
423 |             input_ids is None,
424 |             position_ids is None,
425 |             cache_position is None,
426 |             (
427 |                 prefill
428 |                 and not (
429 |                     (image_tiles is not None and grid_thw is not None)
430 |                     or image_embeddings is not None
431 |                 )
432 |             ),
433 |         ]):
434 |             raise ValueError(
435 |                 "`input_ids`, `position_ids`, and `cache_position` **must** be specified. "
436 |                 "For prefill, you must provide either (`image_tiles` and `grid_thw`) or `image_embeddings`."
437 |             )
438 | 
439 | 
440 |         inputs_embeds = self.embed_ids_boxes_images(
441 |             input_ids, image_embeddings, encoder_chunk_size, valid_batch_size, input_boxes, embed_boxes
442 |         )
443 | 
444 |         # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder
445 |         # Skipped during decoding since not required
446 |         if self.decoder.config._attn_implementation == "flash_attention_2" and prefill:
447 |             # Needed for CPU -> GPU
448 |             from surya.common.surya.flash_attn_utils import _get_unpad_data
449 |             batch_size, query_length, _ = inputs_embeds.shape
450 |             indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
451 |                 attention_mask
452 |             )
453 |             kwargs["batch_size"] = batch_size
454 |             kwargs["query_length"] = query_length
455 |             kwargs["indices_k"] = indices_k
456 |             kwargs["cu_seqlens_k"] = cu_seqlens_k
457 |             kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k
458 | 
459 |         causal_mask = self._update_causal_mask(
460 |             attention_mask,
461 |             inputs_embeds,
462 |             cache_position,
463 |             past_key_values,
464 |             output_attentions,
465 |         )
466 | 
467 |         attention_mask = causal_mask
468 |         outputs = self.decoder(
469 |             inputs_embeds=inputs_embeds,
470 |             attention_mask=attention_mask,
471 |             position_ids=position_ids,
472 |             cache_position=cache_position,
473 |             past_key_values=past_key_values,
474 |             return_dict=True,
475 |             use_cache=use_cache,
476 |             cache_idxs=cache_idxs,
477 |             num_valid_tokens=num_valid_tokens,
478 |             prefill=prefill,
479 |             text_lengths=text_lengths,
480 |             **kwargs,
481 |         )
482 | 
483 |         hidden_states = outputs.last_hidden_state
484 |         if logits_to_keep is not None:
485 |             hidden_states = hidden_states[:, -logits_to_keep:, :]
486 |         hidden_states = hidden_states.contiguous()
487 | 
488 |         loss = None
489 |         if labels is not None:
490 |             # Training, return full logits
491 |             lm_logits = self.lm_head(hidden_states)
492 |             bbox_logits = None
493 |             vocab_size = lm_logits.shape[-1]
494 |             labels = torch.roll(labels, shifts=-1, dims=-1)
495 |             loss = F.cross_entropy(
496 |                 lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean"
497 |             )
498 |         else:
499 |             lm_logits, bbox_logits = self.get_logits(hidden_states)
500 | 
501 |         return SuryaModelOutput(
502 |             loss=loss,
503 |             bbox_logits=bbox_logits,
504 |             lm_logits=lm_logits,
505 |             hidden_states=outputs.hidden_states if output_hidden_states else None,
506 |             attentions=outputs.attentions if output_attentions else None,
507 |             past_key_values=outputs.past_key_values,
508 |         )
509 | 
510 |     def _update_causal_mask(
511 |         self,
512 |         attention_mask: torch.Tensor,
513 |         input_tensor: torch.Tensor,
514 |         cache_position: torch.Tensor,
515 |         past_key_values: Cache,
516 |         output_attentions: bool,
517 |     ):
518 |         if self.decoder.config._attn_implementation == "flash_attention_2":
519 |             return attention_mask
520 | 
521 |         # We always pass in a 2D attention mask from the processor - In both static and dynamic cache cases
522 |         dtype, device = input_tensor.dtype, input_tensor.device
523 |         min_dtype = torch.finfo(dtype).min
524 |         sequence_length = input_tensor.shape[1]
525 |         target_length = (
526 |             attention_mask.shape[-1]
527 |             if isinstance(attention_mask, torch.Tensor)
528 |             else past_key_values.max_cache_len
529 |         )
530 | 
531 |         # If the provided `attention_mask` is 2D, we generate a 4D causal mask here.
532 |         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
533 |             attention_mask,
534 |             sequence_length=sequence_length,
535 |             target_length=target_length,
536 |             dtype=dtype,
537 |             device=device,
538 |             cache_position=cache_position,
539 |             batch_size=input_tensor.shape[0],
540 |             config=self.config,
541 |             past_key_values=past_key_values,
542 |         )
543 | 
544 |         if (
545 |             self.config._attn_implementation == "sdpa"
546 |             and attention_mask is not None
547 |             and attention_mask.device.type in ["cuda", "xpu"]
548 |             and not output_attentions
549 |         ):
550 |             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
551 |             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
552 |             # Details: https://github.com/pytorch/pytorch/issues/110213
553 |             causal_mask = AttentionMaskConverter._unmask_unattended(
554 |                 causal_mask, min_dtype
555 |             )
556 | 
557 |         return causal_mask
558 | 
559 |     @staticmethod
560 |     def _prepare_4d_causal_attention_mask_with_cache_position(
561 |         attention_mask: torch.Tensor,
562 |         sequence_length: int,
563 |         target_length: int,
564 |         dtype: torch.dtype,
565 |         device: torch.device,
566 |         cache_position: torch.Tensor,
567 |         batch_size: int,
568 |         config: SuryaModelConfig,
569 |         past_key_values: Cache,
570 |     ):
571 |         """
572 |         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
573 |         `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
574 | 
575 |         Args:
576 |             attention_mask (`torch.Tensor`):
577 |                 A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
578 |             sequence_length (`int`):
579 |                 The sequence length being processed.
580 |             target_length (`int`):
581 |                 The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not yet filled).
582 |             dtype (`torch.dtype`):
583 |                 The dtype to use for the 4D attention mask.
584 |             device (`torch.device`):
585 |                 The device to place the 4D attention mask on.
586 |             cache_position (`torch.Tensor`):
587 |                 Indices depicting the position of the input sequence tokens in the sequence. Shape `(batch_size, sequence_length)`.
588 |             batch_size (`int`):
589 |                 Batch size.
590 |             config (`SuryaModelConfig`):
591 |                 The model's configuration class
592 |             past_key_values (`Cache`):
593 |                 The cache class that is being used currently to generate
594 |         """
595 |         if attention_mask is not None and attention_mask.dim() == 4:
596 |             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
597 |             causal_mask = attention_mask
598 |         else:
599 |             min_dtype = torch.finfo(dtype).min
600 |             causal_mask = torch.full(
601 |                 (sequence_length, target_length),
602 |                 fill_value=min_dtype,
603 |                 dtype=dtype,
604 |                 device=device,
605 |             )
606 |             # Batch-aware diagonal attend mask
607 |             diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(
608 |                 0
609 |             ) > cache_position.unsqueeze(-1)
610 |             causal_mask = (
611 |                 causal_mask.unsqueeze(0) * diagonal_attend_mask
612 |             )  # (batch_size, seq_len, target_len)
613 |             causal_mask = causal_mask[
614 |                 :, None, :, :
615 |             ]  # (batch_size, 1, seq_len, target_len)
616 |             if attention_mask is not None:
617 |                 causal_mask = (
618 |                     causal_mask.clone()
619 |                 )  # copy to contiguous memory for in-place edit
620 |                 if attention_mask.shape[-1] > target_length:
621 |                     attention_mask = attention_mask[:, :target_length]
622 |                 mask_length = attention_mask.shape[-1]
623 |                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
624 |                     :, None, None, :
625 |                 ].to(causal_mask.device)
626 |                 padding_mask = padding_mask == 0
627 |                 causal_mask[:, :, :, :mask_length] = causal_mask[
628 |                     :, :, :, :mask_length
629 |                 ].masked_fill(padding_mask, min_dtype)
630 |         return causal_mask
631 | 
632 | class SuryaXLAModel(SuryaModel):
633 |     def get_image_embeddings(
634 |         self,
635 |         pixel_values: torch.Tensor,
636 |         grid_thw: torch.Tensor,
637 |         encoder_chunk_size: int,
638 |         valid_batch_size: torch.Tensor | None = None,
639 |         max_batch_size: int | None = None,
640 |     ):
641 |         # embed all images with the vision encoder after they have already been tiled and flattened into a single batch
642 |         unpadded_max_grid_size = (
643 |             (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).max().item()
644 |         )
645 |         max_grid_size = get_nearest_pad(
646 |             unpadded_max_grid_size,
647 |         )  # If rounding adds zero padding, we still allocate a bit of room for the extra grid_thw entry
648 | 
649 |         # Always need 2 items in each row batch
650 |         if max_grid_size == unpadded_max_grid_size:
651 |             max_grid_size += 16
652 | 
653 |         full_image_grid = torch.zeros(
654 |             (valid_batch_size, max_grid_size, pixel_values.shape[-1]),
655 |             dtype=pixel_values.dtype,
656 |         )
657 | 
658 |         # Roll out into a full grid
659 |         seq_len = 0
660 |         row_grids = []
661 |         for i in range(valid_batch_size):
662 |             curr_sample_len = grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]
663 |             full_image_grid[i, -curr_sample_len:] = pixel_values[
664 |                 seq_len : seq_len + curr_sample_len
665 |             ]
666 |             padded_len = max_grid_size - curr_sample_len
667 |             if padded_len > 0:
668 |                 row_grid = torch.tensor(
669 |                     [
670 |                         [1, 4, padded_len // 4],
671 |                         grid_thw[i].tolist(),
672 |                     ],
673 |                     dtype=torch.long,
674 |                 )
675 |             else:
676 |                 row_grid = torch.tensor(
677 |                     [
678 |                         grid_thw[i].tolist(),
679 |                     ],
680 |                     dtype=torch.long,
681 |                 )
682 | 
683 |             row_grids.append(row_grid)
684 |             seq_len += curr_sample_len
685 | 
686 |         # bsz, 2, 3
687 |         row_grids = torch.stack(row_grids, dim=0)
688 | 
689 |         if settings.FOUNDATION_STATIC_CACHE:
690 |             # Pad to max batch size, repeat the final row
691 |             row_grids = pad_to_batch_size_repeat(
692 |                 row_grids,
693 |                 batch_size=max_batch_size,
694 |             )
695 |             full_image_grid = pad_to_batch_size(
696 |                 full_image_grid,
697 |                 batch_size=max_batch_size,
698 |             )
699 | 
700 |         full_image_grid = full_image_grid.to(self.device)
701 | 
702 |         embeddings = self.vision_encoder.embed_images(
703 |             image_batch=full_image_grid, grid_thw=row_grids.to(self.device)
704 |         )
705 | 
706 |         encoding_2d = self.get_2d_learned_embeddings(
707 |             row_grids,
708 |             bbox_size=self.config.image_embed_encoding_multiplier,
709 |         )
710 |         embeddings += encoding_2d
711 | 
712 |         return embeddings
713 | 
714 |     def embed_ids_boxes_images(
715 |         self,
716 |         input_ids,
717 |         image_embeddings,
718 |         encoder_chunk_size: int,
719 |         valid_batch_size: torch.Tensor | None = None,
720 |         input_boxes: torch.Tensor | None = None,
721 |         embed_boxes: torch.Tensor | None = None,
722 |     ):
723 |         """
724 |         Insert embedded image tiles at their corresponding positions in the full input sequence.
725 | 
726 |         Positions to insert are indicated by the special image token index.
727 |         """
728 |         # This is batched in the inner call
729 |         inputs_embeds = self.embedder.embed(
730 |             input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes
731 |         )
732 | 
733 |         if image_embeddings is not None:
734 |             image_token_id_tensor = torch.tensor(
735 |                 self.config.image_token_id,
736 |                 device=inputs_embeds.device,
737 |                 dtype=torch.long,
738 |             )
739 |             mask = input_ids == image_token_id_tensor
740 |             last_image_token_pos = (
741 |                 mask.size(1)
742 |                 - 1
743 |                 - mask.flip(dims=[1]).long().argmax(dim=1, keepdim=True)
744 |             )
745 |             # Calculate start position to replace N positions ending at (and including) the last image token
746 |             start_positions = last_image_token_pos - image_embeddings[0].shape[0]
747 |             batch_size, insert_len = image_embeddings.shape[:2]
748 | 
749 |             # Create position indices for each insertion
750 |             pos_indices = torch.arange(
751 |                 insert_len, device=inputs_embeds.device
752 |             ).unsqueeze(0)
753 |             insert_positions = start_positions + pos_indices
754 | 
755 |             idx = insert_positions.unsqueeze(-1).expand(
756 |                 -1, -1, inputs_embeds.size(-1)
757 |             )  # [B,N,D]
758 |             inputs_embeds = inputs_embeds.scatter(1, idx, image_embeddings)
759 | 
760 |         inputs_embeds = inputs_embeds * (
761 |             input_ids != self.config.pad_token_id
762 |         ).unsqueeze(-1).to(inputs_embeds.dtype)
763 |         return inputs_embeds
764 | 
765 |     def get_2d_learned_embeddings(
766 |         self,
767 |         grid_thw,
768 |         bbox_size: int = 256,
769 |     ):
770 |         dev = grid_thw.device
771 |         all_row_coords = []
772 |         all_col_coords = []
773 |         for row_grid in grid_thw:
774 |             merge = self.config.merge_size
775 | 
776 |             # per-sample grid sizes after merge
777 |             H = (row_grid[:, 1] // merge).long()  # (B,)
778 |             W = (row_grid[:, 2] // merge).long()  # (B,)
779 | 
780 |             row_coords = torch.cat(
781 |                 [
782 |                     torch.linspace(0, bbox_size, steps=int(h), device=dev)
783 |                     .round()
784 |                     .repeat_interleave(w)  # repeat each row value w times
785 |                     for h, w in zip(H.tolist(), W.tolist())
786 |                 ]
787 |             )  # (full_grid_size,)
788 | 
789 |             col_coords = torch.cat(
790 |                 [
791 |                     torch.linspace(0, bbox_size, steps=int(w), device=dev)
792 |                     .round()
793 |                     .repeat(int(h))  # tile the column vector h times
794 |                     for h, w in zip(H.tolist(), W.tolist())
795 |                 ]
796 |             )  # (full_grid_size,)
797 |             all_row_coords.append(row_coords)
798 |             all_col_coords.append(col_coords)
799 |         row_coords = torch.stack(all_row_coords, dim=0).to(self.device)
800 |         col_coords = torch.stack(all_col_coords, dim=0).to(self.device)
801 | 
802 |         emb = self.img_h_embed(row_coords.long()) + self.img_w_embed(col_coords.long())
803 |         return emb
804 | 
```
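
A minimal sketch (not part of the repository; plain PyTorch with illustrative shapes) of the padding-aware 4D causal mask that `_prepare_4d_causal_attention_mask_with_cache_position` builds above: a causal pattern keyed off `cache_position`, with the 2D padding mask folded in additively.

```python
import torch

def sketch_4d_causal_mask(attention_mask_2d, cache_position, target_length, dtype=torch.float32):
    """0 where attention is allowed, dtype-min where blocked (future positions or padding)."""
    batch_size = attention_mask_2d.shape[0]
    query_len = cache_position.shape[0]
    min_dtype = torch.finfo(dtype).min

    # Block positions strictly after each query's cache position (the causal part).
    blocked_future = torch.arange(target_length).unsqueeze(0) > cache_position.unsqueeze(-1)
    causal = torch.full((query_len, target_length), min_dtype, dtype=dtype) * blocked_future

    # Lift to (batch, 1, query_len, target_len) and fold in the 2D padding mask (1 = real token).
    causal = causal.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1, -1).clone()
    pad_blocked = attention_mask_2d[:, None, None, :target_length] == 0
    causal.masked_fill_(pad_blocked, min_dtype)
    return causal

# One sequence with left padding at position 0, decoding 4 query positions against 4 keys.
mask_2d = torch.tensor([[0, 1, 1, 1]], dtype=torch.float32)
mask_4d = sketch_4d_causal_mask(mask_2d, cache_position=torch.arange(4), target_length=4)
print(mask_4d.shape)               # torch.Size([1, 1, 4, 4])
print((mask_4d == 0).int()[0, 0])  # lower-triangular pattern with column 0 fully blocked
```

The fully blocked first row in this toy example is exactly the case that `AttentionMaskConverter._unmask_unattended` patches up for the SDPA memory-efficient path.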

--------------------------------------------------------------------------------
/surya/common/surya/encoder/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | import math
  2 | from typing import Optional, Tuple
  3 | 
  4 | import torch
  5 | import torch.nn as nn
  6 | import torch.nn.functional as F
  7 | from transformers.activations import ACT2FN
  8 | 
  9 | from surya.common.pretrained import SuryaPreTrainedModel
 10 | from surya.common.surya.encoder.config import SuryaEncoderConfig
 11 | from surya.common.xla import get_nearest_pad
 12 | from surya.logging import get_logger
 13 | from surya.settings import settings
 14 | 
 15 | if settings.FOUNDATION_XLA:
 16 |     import torch_xla.experimental.custom_kernel
 17 | 
 18 | from surya.logging import get_logger
 19 | logger = get_logger()
 20 | 
 21 | 
 22 | class Qwen2_5_VLMLP(nn.Module):
 23 |     def __init__(self, config, bias: bool = False):
 24 |         super().__init__()
 25 |         self.hidden_size = config.hidden_size
 26 |         self.intermediate_size = config.intermediate_size
 27 |         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
 28 |         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
 29 |         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
 30 |         self.act_fn = ACT2FN[config.hidden_act]
 31 | 
 32 |     def forward(self, hidden_state):
 33 |         return self.down_proj(
 34 |             self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
 35 |         )
 36 | 
 37 | 
 38 | class Qwen2_5_VisionPatchEmbed(nn.Module):
 39 |     def __init__(
 40 |         self,
 41 |         patch_size: int = 14,
 42 |         temporal_patch_size: int = 2,
 43 |         in_channels: int = 3,
 44 |         embed_dim: int = 1152,
 45 |     ) -> None:
 46 |         super().__init__()
 47 |         self.patch_size = patch_size
 48 |         self.temporal_patch_size = temporal_patch_size
 49 |         self.in_channels = in_channels
 50 |         self.embed_dim = embed_dim
 51 | 
 52 |         kernel_size = [temporal_patch_size, patch_size, patch_size]
 53 |         self.proj = nn.Conv3d(
 54 |             in_channels,
 55 |             embed_dim,
 56 |             kernel_size=kernel_size,
 57 |             stride=kernel_size,
 58 |             bias=False,
 59 |         )
 60 | 
 61 |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 62 |         target_dtype = self.proj.weight.dtype
 63 |         bsz = hidden_states.shape[0]
 64 |         hidden_states = hidden_states.view(
 65 |             -1,
 66 |             self.in_channels,
 67 |             self.temporal_patch_size,
 68 |             self.patch_size,
 69 |             self.patch_size,
 70 |         )
 71 |         hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
 72 |             bsz, -1, self.embed_dim
 73 |         )
 74 |         return hidden_states
 75 | 
 76 | 
 77 | class Qwen2_5_VisionRotaryEmbedding(nn.Module):
 78 |     def __init__(self, dim: int, theta: float = 10000.0) -> None:
 79 |         super().__init__()
 80 |         self.inv_freq = 1.0 / (
 81 |             theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
 82 |         )
 83 | 
 84 |     def forward(self, seqlen: int) -> torch.Tensor:
 85 |         seq = torch.arange(seqlen, device="cpu", dtype=self.inv_freq.dtype)
 86 |         freqs = torch.outer(seq, self.inv_freq)
 87 |         return freqs
 88 | 
 89 | 
 90 | class Qwen2RMSNorm(nn.Module):
 91 |     def __init__(self, hidden_size, eps=1e-6):
 92 |         """
 93 |         Qwen2RMSNorm is equivalent to T5LayerNorm
 94 |         """
 95 |         super().__init__()
 96 |         self.weight = nn.Parameter(torch.ones(hidden_size))
 97 |         self.variance_epsilon = eps
 98 | 
 99 |     def forward(self, hidden_states):
100 |         input_dtype = hidden_states.dtype
101 |         hidden_states = hidden_states.to(torch.float32)
102 |         variance = hidden_states.pow(2).mean(-1, keepdim=True)
103 |         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
104 |         return self.weight * hidden_states.to(input_dtype)
105 | 
106 |     def extra_repr(self):
107 |         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
108 | 
109 | 
110 | class Qwen2_5_VLPatchMerger(nn.Module):
111 |     def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
112 |         super().__init__()
113 |         self.hidden_size = context_dim * (spatial_merge_size**2)
114 |         self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
115 |         self.mlp = nn.Sequential(
116 |             nn.Linear(self.hidden_size, self.hidden_size),
117 |             nn.GELU(),
118 |             nn.Linear(self.hidden_size, dim),
119 |         )
120 | 
121 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
122 |         bsz = x.shape[0]
123 |         x = self.mlp(self.ln_q(x).view(bsz, -1, self.hidden_size))
124 |         return x
125 | 
126 | 
127 | def apply_rotary_pos_emb_flashatt(
128 |     q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
129 | ) -> Tuple[torch.Tensor, torch.Tensor]:
130 |     from flash_attn.layers.rotary import apply_rotary_emb
131 | 
132 |     cos = cos.chunk(2, dim=-1)[0].contiguous()
133 |     sin = sin.chunk(2, dim=-1)[0].contiguous()
134 |     q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
135 |     k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
136 |     return q_embed, k_embed
137 | 
138 | 
139 | class Qwen2_5_VLVisionXLASdpaAttention(nn.Module):
140 |     def __init__(self, dim: int, num_heads: int = 16) -> None:
141 |         super().__init__()
142 |         self.num_heads = num_heads
143 |         self.qkv = nn.Linear(dim, dim * 3, bias=True)
144 |         self.proj = nn.Linear(dim, dim)
145 |         self.head_dim = dim // num_heads
146 | 
147 |     def forward(
148 |         self,
149 |         hidden_states: torch.Tensor,
150 |         cu_seqlens: torch.Tensor,
151 |         rotary_pos_emb: Optional[torch.Tensor] = None,
152 |         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
153 |     ) -> torch.Tensor:
154 |         bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]
155 |         q, k, v = (
156 |             self.qkv(hidden_states)
157 |             .reshape(bsz, seq_length, 3, self.num_heads, -1)
158 |             .permute(0, 2, 1, 3, 4)
159 |             .unbind(1)
160 |         )
161 |         if position_embeddings is None:
162 |             logger.warning_once(
163 |                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
164 |                 "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
165 |                 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
166 |                 "removed and `position_embeddings` will be mandatory."
167 |             )
168 |             emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
169 |             cos = emb.cos()
170 |             sin = emb.sin()
171 |         else:
172 |             cos, sin = position_embeddings
173 |         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
174 | 
175 |         attention_mask = torch.zeros([bsz, 1, seq_length, seq_length], dtype=torch.bool)
176 |         cu_seqlens_cpu = cu_seqlens.cpu()
177 |         for j in range(bsz):
178 |             batch_seqlens = cu_seqlens_cpu[j]
179 |             for i in range(1, len(batch_seqlens)):
180 |                 attention_mask[
181 |                     j,
182 |                     ...,
183 |                     batch_seqlens[i - 1] : batch_seqlens[i],
184 |                     batch_seqlens[i - 1] : batch_seqlens[i],
185 |                 ] = True
186 | 
187 |         attention_mask = attention_mask.to(q.device)
188 | 
189 |         q = q.transpose(1, 2)
190 |         k = k.transpose(1, 2)
191 |         v = v.transpose(1, 2)
192 | 
193 |         attn_output = F.scaled_dot_product_attention(
194 |             q,
195 |             k,
196 |             v,
197 |             attention_mask,
198 |             dropout_p=0.0,
199 |         )
200 |         attn_output = attn_output.transpose(1, 2)
201 |         attn_output = attn_output.reshape(bsz, seq_length, -1)
202 |         attn_output = self.proj(attn_output)
203 |         return attn_output
204 | 
205 | 
206 | class Qwen2_5_VLVisionXLAFlashAttention2(nn.Module):
207 |     def __init__(self, dim: int, num_heads: int = 16) -> None:
208 |         super().__init__()
209 |         self.num_heads = num_heads
210 |         self.qkv = nn.Linear(dim, dim * 3, bias=True)
211 |         self.proj = nn.Linear(dim, dim)
212 |         self.head_dim = dim // num_heads
213 | 
214 |     def forward(
215 |         self,
216 |         hidden_states: torch.Tensor,
217 |         cu_seqlens: torch.Tensor,
218 |         rotary_pos_emb: Optional[torch.Tensor] = None,
219 |         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
220 |     ) -> torch.Tensor:
221 |         # Note, this is faster than SDPA, but pretty memory inefficient
222 |         # It also has significant accuracy issues
223 | 
224 |         bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]
225 | 
226 |         # Single reshape to target layout - avoid multiple operations
227 |         q, k, v = (
228 |             self.qkv(hidden_states)
229 |             .reshape(bsz, seq_length, 3, self.num_heads, -1)
230 |             .permute(0, 2, 1, 3, 4)
231 |             .unbind(1)
232 |         )
233 | 
234 |         # Apply rotary embeddings if provided
235 |         if position_embeddings is not None:
236 |             cos, sin = position_embeddings
237 |             q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
238 | 
239 |         # Single reshape to flash attention format [batch, num_heads, seq_len, head_dim]
240 |         q = q.transpose(1, 2)  # [bsz, num_heads, seq_len, head_dim]
241 |         k = k.transpose(1, 2)
242 |         v = v.transpose(1, 2)
243 | 
244 |         total_seqlen = q.shape[2]
245 |         # Build an additive bias from cu_seqlens: mask out each sample's leading padding segment
246 |         additive_bias = torch.zeros((bsz, 1, total_seqlen, total_seqlen), dtype=q.dtype)
247 |         min_val = torch.finfo(q.dtype).min
248 | 
249 |         for i in range(bsz):
250 |             padding_end = cu_seqlens[i][1].item()
251 |             additive_bias[i, :, :, :padding_end] = min_val
252 | 
253 |         additive_bias = additive_bias.to(hidden_states.device)
254 | 
255 |         attn_scale = 1 / math.sqrt(self.head_dim)
256 |         attn_output = torch_xla.experimental.custom_kernel.flash_attention(
257 |             q, k, v, sm_scale=attn_scale, ab=additive_bias
258 |         )
259 |         attn_output = (
260 |             attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_length, -1)
261 |         )
262 |         attn_output = self.proj(attn_output)
263 |         return attn_output
264 | 
265 | 
266 | class Qwen2_5_VLVisionFlashAttention2(nn.Module):
267 |     def __init__(self, dim: int, num_heads: int = 16) -> None:
268 |         super().__init__()
269 |         self.num_heads = num_heads
270 |         self.qkv = nn.Linear(dim, dim * 3, bias=True)
271 |         self.proj = nn.Linear(dim, dim)
272 | 
273 |     def forward(
274 |         self,
275 |         hidden_states: torch.Tensor,
276 |         cu_seqlens: torch.Tensor,
277 |         rotary_pos_emb: Optional[torch.Tensor] = None,
278 |         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
279 |     ) -> torch.Tensor:
280 |         from flash_attn import flash_attn_varlen_func
281 | 
282 |         bsz = hidden_states.shape[0]
283 |         seq_length = hidden_states.shape[1]
284 |         q, k, v = (
285 |             self.qkv(hidden_states)
286 |             .reshape(bsz, seq_length, 3, self.num_heads, -1)
287 |             .permute(0, 2, 1, 3, 4)
288 |             .unbind(1)
289 |         )
290 |         if position_embeddings is None:
291 |             logger.warning_once(
292 |                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
293 |                 "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
294 |                 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
295 |                 "removed and `position_embeddings` will be mandatory."
296 |             )
297 |             emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
298 |             cos = emb.cos()
299 |             sin = emb.sin()
300 |         else:
301 |             cos, sin = position_embeddings
302 | 
303 |         q, k = apply_rotary_pos_emb_flashatt(q, k, cos.squeeze(0), sin.squeeze(0))
304 | 
305 |         q = q.squeeze(0)
306 |         k = k.squeeze(0)
307 |         v = v.squeeze(0)
308 |         cu_seqlens = cu_seqlens.squeeze(0)
309 | 
310 |         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
311 |         attn_output = flash_attn_varlen_func(
312 |             q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
313 |         ).reshape(bsz, seq_length, -1)
314 |         attn_output = self.proj(attn_output)
315 |         return attn_output
316 | 
317 | 
318 | def rotate_half(x):
319 |     """Rotates half the hidden dims of the input."""
320 |     x1 = x[..., : x.shape[-1] // 2]
321 |     x2 = x[..., x.shape[-1] // 2 :]
322 |     return torch.cat((-x2, x1), dim=-1)
323 | 
324 | 
325 | def apply_rotary_pos_emb_vision(
326 |     q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
327 | ) -> Tuple[torch.Tensor, torch.Tensor]:
328 |     orig_q_dtype = q.dtype
329 |     orig_k_dtype = k.dtype
330 |     q, k = q.float(), k.float()
331 |     cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
332 |     q_embed = (q * cos) + (rotate_half(q) * sin)
333 |     k_embed = (k * cos) + (rotate_half(k) * sin)
334 |     q_embed = q_embed.to(orig_q_dtype)
335 |     k_embed = k_embed.to(orig_k_dtype)
336 |     return q_embed, k_embed
337 | 
338 | 
339 | class Qwen2_5_VLVisionAttention(nn.Module):
340 |     def __init__(self, dim: int, num_heads: int = 16) -> None:
341 |         super().__init__()
342 |         self.num_heads = num_heads
343 |         self.head_dim = dim // num_heads
344 |         self.qkv = nn.Linear(dim, dim * 3, bias=True)
345 |         self.proj = nn.Linear(dim, dim)
346 | 
347 |     def forward(
348 |         self,
349 |         hidden_states: torch.Tensor,
350 |         cu_seqlens: torch.Tensor,
351 |         rotary_pos_emb: Optional[torch.Tensor] = None,
352 |         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
353 |     ) -> torch.Tensor:
354 |         bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]
355 |         q, k, v = (
356 |             self.qkv(hidden_states)
357 |             .reshape(bsz, seq_length, 3, self.num_heads, -1)
358 |             .permute(0, 2, 1, 3, 4)
359 |             .unbind(1)
360 |         )
361 |         if position_embeddings is None:
362 |             logger.warning_once(
363 |                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
364 |                 "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
365 |                 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
366 |                 "removed and `position_embeddings` will be mandatory."
367 |             )
368 |             emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
369 |             cos = emb.cos()
370 |             sin = emb.sin()
371 |         else:
372 |             cos, sin = position_embeddings
373 | 
374 |         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
375 | 
376 |         attention_mask = torch.full(
377 |             [bsz, 1, seq_length, seq_length],
378 |             torch.finfo(q.dtype).min,
379 |             device=q.device,
380 |             dtype=q.dtype,
381 |         )
382 |         for j in range(bsz):
383 |             batch_seqlens = cu_seqlens[j]
384 |             for i in range(1, len(batch_seqlens)):
385 |                 attention_mask[
386 |                     j,
387 |                     ...,
388 |                     batch_seqlens[i - 1] : batch_seqlens[i],
389 |                     batch_seqlens[i - 1] : batch_seqlens[i],
390 |                 ] = 0
391 | 
392 |         q = q.transpose(1, 2)
393 |         k = k.transpose(1, 2)
394 |         v = v.transpose(1, 2)
395 | 
396 |         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
397 |         attn_weights = attn_weights + attention_mask
398 |         attn_weights = nn.functional.softmax(
399 |             attn_weights, dim=-1, dtype=torch.float32
400 |         ).to(q.dtype)
401 |         attn_output = torch.matmul(attn_weights, v)
402 |         attn_output = attn_output.transpose(1, 2)
403 |         attn_output = attn_output.reshape(bsz, seq_length, -1)
404 |         attn_output = self.proj(attn_output)
405 |         return attn_output
406 | 
407 | 
408 | class Qwen2_5_VLVisionSdpaAttention(nn.Module):
409 |     def __init__(self, dim: int, num_heads: int = 16) -> None:
410 |         super().__init__()
411 |         self.num_heads = num_heads
412 |         self.qkv = nn.Linear(dim, dim * 3, bias=True)
413 |         self.proj = nn.Linear(dim, dim)
414 | 
415 |     def unpack_qkv_with_mask(self, q, k, v, cu_seqlens):
416 |         """
417 |         Unpacks q, k, v sequences into batch-major form and constructs an additive attention mask.
418 | 
419 |         Args:
420 |             q, k, v: Tensors of shape (total_seq_len, num_heads, head_dim)
421 |             cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths
422 | 
423 |         Returns:
424 |             batched_q: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
425 |             batched_k: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
426 |             batched_v: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
427 |             attention_mask: Tensor of shape (batch_size, 1, max_seq_len, max_seq_len)
428 |                             with 0 for valid tokens and -inf for padding (for additive attention)
429 |         """
430 |         device = q.device
431 |         dtype = q.dtype
432 | 
433 |         batch_size = cu_seqlens.shape[0] - 1
434 |         num_heads = q.shape[1]
435 |         head_dim = q.shape[2]
436 | 
437 |         seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # Keep as tensor
438 |         max_seq_len = seq_lengths.max().item()  # Use .max() on tensor
439 | 
440 |         if settings.FOUNDATION_STATIC_CACHE:
441 |             # Pad max_seq_len to the nearest multiple for compilation
442 |             max_seq_len = get_nearest_pad(max_seq_len, pad_multiple=16)
443 | 
444 |             # Pad batch_size to the nearest multiple for compilation
445 |             batch_size = get_nearest_pad(batch_size, pad_multiple=2)
446 | 
447 |             # Ensure seq_lengths is a tensor of the correct size
448 |             seq_lengths = F.pad(
449 |                 seq_lengths, (0, batch_size - seq_lengths.size(0)), "constant", 0
450 |             )
451 | 
452 |         # Some day you may look at this and think: "what if I used repeat_interleave or some other fancy torch op instead?"
453 |         # Don't do this - it's a path to madness. For some reason, this plain loop is optimal.
454 | 
455 |         batch_indices = []
456 |         position_indices = []
457 | 
458 |         for i, seq_len in enumerate(
459 |             seq_lengths.tolist()
460 |         ):  # Convert to list only for iteration
461 |             batch_indices.extend([i] * seq_len)
462 |             position_indices.extend(list(range(seq_len)))
463 | 
464 |         batch_indices = torch.tensor(batch_indices, device=device)
465 |         position_indices = torch.tensor(position_indices, device=device)
466 | 
467 |         batched_q = torch.zeros(
468 |             (batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype
469 |         )
470 |         batched_k = torch.zeros_like(batched_q)
471 |         batched_v = torch.zeros_like(batched_q)
472 | 
473 |         # Create additive attention mask
474 |         attention_mask = torch.full(
475 |             (batch_size, max_seq_len, max_seq_len),
476 |             fill_value=float("-inf"),
477 |             device=device,
478 |             dtype=dtype,
479 |         )
480 | 
481 |         # Create mask for valid positions
482 |         seq_range = torch.arange(max_seq_len, device=device)
483 |         valid_mask = seq_range.unsqueeze(0) < seq_lengths.unsqueeze(
484 |             1
485 |         )  # (batch_size, max_seq_len)
486 |         valid_2d = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(
487 |             1
488 |         )  # (batch_size, max_seq_len, max_seq_len)
489 | 
490 |         # Simply use boolean indexing to set valid positions to 0
491 |         attention_mask[valid_2d] = 0
492 | 
493 |         attention_mask = attention_mask.unsqueeze(
494 |             1
495 |         )  # (batch_size, 1, max_seq_len, max_seq_len)
496 | 
497 |         batched_q[batch_indices, position_indices] = q
498 |         batched_k[batch_indices, position_indices] = k
499 |         batched_v[batch_indices, position_indices] = v
500 | 
501 |         return (
502 |             batched_q,
503 |             batched_k,
504 |             batched_v,
505 |             attention_mask,
506 |             batch_indices,
507 |             position_indices,
508 |         )
509 | 
510 |     def forward(
511 |         self,
512 |         hidden_states: torch.Tensor,
513 |         cu_seqlens: torch.Tensor,
514 |         rotary_pos_emb: Optional[torch.Tensor] = None,
515 |         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
516 |     ) -> torch.Tensor:
517 |         hidden_states = hidden_states.squeeze(0)
518 |         cu_seqlens = cu_seqlens.squeeze(0)
519 | 
520 |         seq_length = hidden_states.shape[0]
521 |         q, k, v = (
522 |             self.qkv(hidden_states)
523 |             .reshape(seq_length, 3, self.num_heads, -1)
524 |             .permute(1, 0, 2, 3)
525 |             .unbind(0)
526 |         )
527 |         if position_embeddings is None:
528 |             logger.warning_once(
529 |                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
530 |                 "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
531 |                 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
532 |                 "removed and `position_embeddings` will be mandatory."
533 |             )
534 |             emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
535 |             cos = emb.cos()
536 |             sin = emb.sin()
537 |         else:
538 |             cos, sin = position_embeddings
539 |         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
540 |         q = q.squeeze(0)
541 |         k = k.squeeze(0)
542 | 
543 |         q, k, v, attention_mask, batch_indices, position_indices = (
544 |             self.unpack_qkv_with_mask(q, k, v, cu_seqlens)
545 |         )
546 |         batch_size, max_seqlen = q.shape[:2]
547 |         q = q.transpose(1, 2)
548 |         k = k.transpose(1, 2)
549 |         v = v.transpose(1, 2)
550 | 
551 |         attn_output = F.scaled_dot_product_attention(
552 |             q,
553 |             k,
554 |             v,
555 |             attention_mask,
556 |             dropout_p=0.0,
557 |         )
558 |         attn_output = attn_output.permute(0, 2, 1, 3).reshape(
559 |             batch_size, max_seqlen, -1
560 |         )  # Bring back to (batch_size, max_seqlen, hidden_dim)
561 |         attn_output = attn_output[batch_indices, position_indices]
562 |         attn_output = self.proj(attn_output)
563 | 
564 |         return attn_output.unsqueeze(0)
565 | 
566 | 
567 | QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
568 |     "eager": Qwen2_5_VLVisionAttention,
569 |     "flash_attention_2": Qwen2_5_VLVisionXLAFlashAttention2
570 |     if settings.FOUNDATION_XLA
571 |     else Qwen2_5_VLVisionFlashAttention2,
572 |     "sdpa": Qwen2_5_VLVisionXLASdpaAttention
573 |     if settings.FOUNDATION_XLA
574 |     else Qwen2_5_VLVisionSdpaAttention,
575 | }
576 | 
577 | 
578 | class Qwen2_5_VLVisionBlock(nn.Module):
579 |     def __init__(self, config, attn_implementation: str = "sdpa") -> None:
580 |         super().__init__()
581 |         self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
582 |         self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
583 |         self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
584 |             config.hidden_size, num_heads=config.num_heads
585 |         )
586 |         self.mlp = Qwen2_5_VLMLP(config, bias=True)
587 | 
588 |     def forward(
589 |         self,
590 |         hidden_states: torch.Tensor,
591 |         cu_seqlens: torch.Tensor,
592 |         rotary_pos_emb: Optional[torch.Tensor] = None,
593 |         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
594 |     ) -> torch.Tensor:
595 |         hidden_states = hidden_states + self.attn(
596 |             self.norm1(hidden_states),
597 |             cu_seqlens=cu_seqlens,
598 |             rotary_pos_emb=rotary_pos_emb,
599 |             position_embeddings=position_embeddings,
600 |         )
601 |         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
602 |         return hidden_states
603 | 
604 | 
605 | Qwen2_5_VL_START_DOCSTRING = r"""
606 |     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
607 |     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
608 |     etc.)
609 | 
610 |     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
611 |     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
612 |     and behavior.
613 | 
614 |     Parameters:
615 |         config ([`Qwen2_5_VLConfig`]):
616 |             Model configuration class with all the parameters of the model. Initializing with a config file does not
617 |             load the weights associated with the model, only the configuration. Check out the
618 |             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
619 | """
620 | 
621 | 
622 | class Qwen2_5_VLPreTrainedModel(SuryaPreTrainedModel):
623 |     config_class = SuryaEncoderConfig
624 |     base_model_prefix = "model"
625 |     supports_gradient_checkpointing = True
626 |     _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
627 |     _skip_keys_device_placement = "past_key_values"
628 |     _supports_flash_attn_2 = True
629 |     _supports_sdpa = True
630 |     _supports_cache_class = True
631 |     _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
632 | 
633 |     def _init_weights(self, module):
634 |         std = self.config.initializer_range
635 |         if isinstance(module, (nn.Linear, nn.Conv3d)):
636 |             module.weight.data.normal_(mean=0.0, std=std)
637 |             if module.bias is not None:
638 |                 module.bias.data.zero_()
639 |         elif isinstance(module, nn.Embedding):
640 |             module.weight.data.normal_(mean=0.0, std=std)
641 |             if module.padding_idx is not None:
642 |                 module.weight.data[module.padding_idx].zero_()
643 | 
644 | 
645 | class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
646 |     config_class = SuryaEncoderConfig
647 |     _no_split_modules = ["Qwen2_5_VLVisionBlock"]
648 | 
649 |     def __init__(self, config, *inputs, **kwargs) -> None:
650 |         super().__init__(config, *inputs, **kwargs)
651 |         self.spatial_merge_size = config.spatial_merge_size
652 |         self.patch_size = config.patch_size
653 |         self.fullatt_block_indexes = config.fullatt_block_indexes
654 |         self.window_size = config.window_size
655 |         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
656 | 
657 |         self.patch_embed = Qwen2_5_VisionPatchEmbed(
658 |             patch_size=config.patch_size,
659 |             temporal_patch_size=config.temporal_patch_size,
660 |             in_channels=config.in_channels,
661 |             embed_dim=config.hidden_size,
662 |         )
663 | 
664 |         head_dim = config.hidden_size // config.num_heads
665 |         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
666 | 
667 |         self.blocks = nn.ModuleList(
668 |             [
669 |                 Qwen2_5_VLVisionBlock(config, config._attn_implementation)
670 |                 for _ in range(config.depth)
671 |             ]
672 |         )
673 |         self.merger = Qwen2_5_VLPatchMerger(
674 |             dim=config.out_hidden_size,
675 |             context_dim=config.hidden_size,
676 |             spatial_merge_size=config.spatial_merge_size,
677 |         )
678 |         self.gradient_checkpointing = False
679 | 
680 |     def rot_pos_emb(self, grid_thw):
681 |         rotary_pos_emb = []
682 |         grid_thw_list = grid_thw.cpu().tolist()
683 |         for batch_item in grid_thw_list:
684 |             row_pos_ids = []
685 |             heights = [h for _, h, _ in batch_item]
686 |             widths = [w for _, _, w in batch_item]
687 |             max_grid_size = max(heights + widths)
688 |             for t, h, w in batch_item:
689 |                 hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
690 |                 hpos_ids = hpos_ids.reshape(
691 |                     h // self.spatial_merge_size,
692 |                     self.spatial_merge_size,
693 |                     w // self.spatial_merge_size,
694 |                     self.spatial_merge_size,
695 |                 )
696 |                 hpos_ids = hpos_ids.permute(0, 2, 1, 3)
697 |                 hpos_ids = hpos_ids.flatten()
698 | 
699 |                 wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
700 |                 wpos_ids = wpos_ids.reshape(
701 |                     h // self.spatial_merge_size,
702 |                     self.spatial_merge_size,
703 |                     w // self.spatial_merge_size,
704 |                     self.spatial_merge_size,
705 |                 )
706 |                 wpos_ids = wpos_ids.permute(0, 2, 1, 3)
707 |                 wpos_ids = wpos_ids.flatten()
708 |                 # shape: token_count, 2
709 |                 row_pos_ids.append(
710 |                     torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
711 |                 )
712 |             # shape: token_count, 2
713 |             pos_ids = torch.cat(row_pos_ids, dim=0)
714 |             rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
715 |             rotary_pos_emb_row = rotary_pos_emb_full[pos_ids].flatten(1)
716 |             rotary_pos_emb.append(rotary_pos_emb_row)
717 |         rotary_pos_emb = torch.stack(rotary_pos_emb, dim=0)
718 |         return rotary_pos_emb
719 | 
720 |     def forward(
721 |         self,
722 |         hidden_states: torch.Tensor,
723 |         grid_thw: torch.Tensor,
724 |     ) -> torch.Tensor:
725 |         """
726 |         Args:
727 |             hidden_states (`torch.Tensor` of shape `(bsz, seq_len, hidden_size)`):
728 |                 The packed, flattened image patches to embed.
729 |             grid_thw (`torch.Tensor` of shape `(bsz, num_images_or_videos, 3)`):
730 |                 The temporal, height, and width dimensions of the feature grid for each image.
731 | 
732 |         Returns:
733 |             `torch.Tensor`: hidden_states.
734 |         """
735 |         bsz, seq_len, _ = hidden_states.size()
736 |         hidden_states = self.patch_embed(hidden_states)  # (bsz, seq_len, hidden_dim)
737 |         rotary_pos_emb = self.rot_pos_emb(grid_thw)
738 | 
739 |         # hidden_states = hidden_states.reshape(bsz, seq_len, -1)
740 |         # rotary_pos_emb = rotary_pos_emb.reshape(bsz, seq_len, -1)
741 |         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to(
742 |             hidden_states.device
743 |         )
744 |         position_embeddings = (emb.cos(), emb.sin())
745 | 
746 |         cu_seqlens = (grid_thw[:, :, 1] * grid_thw[:, :, 2]).cumsum(
747 |             dim=1,
748 |             # Select dtype based on the following factors:
749 |             #  - FA2 requires that cu_seqlens_q must have dtype int32
750 |             #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
751 |             # See https://github.com/huggingface/transformers/pull/34852 for more information
752 |             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
753 |         )
754 |         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
755 |         for layer_num, blk in enumerate(self.blocks):
756 |             if self.gradient_checkpointing and self.training:
757 |                 hidden_states = self._gradient_checkpointing_func(
758 |                     blk.__call__,
759 |                     hidden_states,
760 |                     cu_seqlens,
761 |                     None,
762 |                     position_embeddings,
763 |                 )
764 |             else:
765 |                 hidden_states = blk(
766 |                     hidden_states,
767 |                     cu_seqlens=cu_seqlens,
768 |                     position_embeddings=position_embeddings,
769 |                 )
770 | 
771 |         hidden_states = self.merger(hidden_states)
772 |         return hidden_states
773 | 
774 | 
775 | class SuryaEncoderModel(Qwen2_5_VisionTransformerPretrainedModel):
776 |     @property
777 |     def image_size(self) -> int:
778 |         config: SuryaEncoderConfig = self.config
779 |         if isinstance(config.image_size, tuple) and len(config.image_size) == 2:
780 |             return config.image_size
781 |         elif isinstance(config.image_size, int):
782 |             return (config.image_size, config.image_size)
783 | 
784 |         raise ValueError(
785 |             f"The `image_size` for SuryaEncoderConfig should be a tuple of (int, int) or a single int but found {type(config.image_size)}"
786 |         )
787 | 
788 |     @property
789 |     def hidden_size(self) -> int:
790 |         config: SuryaEncoderConfig = self.config
791 |         return config.hidden_size
792 | 
793 |     def embed_images(
794 |         self,
795 |         image_batch: torch.Tensor,
796 |         grid_thw: torch.Tensor,
797 |     ) -> torch.Tensor:
798 |         return super().forward(
799 |             hidden_states=image_batch,
800 |             grid_thw=grid_thw,
801 |         )
802 | 
```
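
A minimal sketch (not from the repository; plain PyTorch with toy lengths) of the `cu_seqlens`-driven block-diagonal bias that the eager and SDPA vision attention classes above construct, so that patches from different packed images cannot attend to each other.

```python
import torch

def block_diagonal_bias(cu_seqlens, seq_length, dtype=torch.float32):
    """Additive bias (1, seq_length, seq_length): 0 inside each packed segment, dtype-min elsewhere."""
    bias = torch.full((1, seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        bias[:, start:end, start:end] = 0  # tokens of one segment attend only to that segment
    return bias

# Two packed image segments of 3 and 2 patches in a sequence of 5.
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)
bias = block_diagonal_bias(cu_seqlens, seq_length=5)
print((bias == 0).int()[0])
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)
```

The non-XLA flash attention variant avoids materializing this dense bias and instead passes `cu_seqlens` directly to `flash_attn_varlen_func`.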

--------------------------------------------------------------------------------
/surya/common/adetr/decoder.py:
--------------------------------------------------------------------------------

```python
  1 | from typing import Dict, Optional, Tuple, Union
  2 | 
  3 | import torch
  4 | import torch.utils.checkpoint
  5 | from torch import nn
  6 | from transformers import PretrainedConfig
  7 | 
  8 | from transformers.activations import ACT2FN
  9 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 10 | from transformers.modeling_outputs import BaseModelOutputWithNoAttention
 11 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 12 | 
 13 | from surya.common.pretrained import SuryaPreTrainedModel
 14 | from surya.common.xla import mark_step
 15 | 
 16 | _MAX_SQRT_GRADIENT = 1000.0
 17 | 
 18 | 
 19 | class WrappedEmbedding(nn.Embedding):
 20 |     def forward(self, input_ids, *args, **kwargs):
 21 |         return super().forward(input_ids)
 22 | 
 23 | 
 24 | class SuryaADETRDecoderRMSNorm(nn.Module):
 25 |     def __init__(self, dim: int, eps: float = 1e-6):
 26 |         super().__init__()
 27 |         self.eps = eps
 28 |         self.weight = nn.Parameter(torch.zeros(dim))
 29 | 
 30 |     def _norm(self, x):
 31 |         variance = x.pow(2).mean(-1, keepdim=True)
 32 | 
 33 |         # Add clipping to prevent division by zero
 34 |         variance = torch.clamp(variance, min=self.eps)
 35 |         return x * torch.rsqrt(variance)
 36 | 
 37 |     def forward(self, x):
 38 |         output = self._norm(x.float())
 39 |         # Llama does x.to(float16) * w whilst SuryaADETRDecoder is (x * w).to(float16)
 40 |         # See https://github.com/huggingface/transformers/pull/29402
 41 |         output = output * (1.0 + self.weight.float())
 42 |         # Clamp to float16 range
 43 |         f16_info = torch.finfo(x.dtype)
 44 |         output = output.clamp(min=f16_info.min, max=f16_info.max)
 45 |         output = torch.where(
 46 |             torch.isnan(output), torch.tensor(0.0, device=output.device), output
 47 |         )
 48 |         return output.type_as(x)
 49 | 
 50 |     def extra_repr(self):
 51 |         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 52 | 
 53 | 
 54 | ALL_LAYERNORM_LAYERS.append(SuryaADETRDecoderRMSNorm)
 55 | 
 56 | 
 57 | class SuryaADETRDecoderRotaryEmbedding(nn.Module):
 58 |     def __init__(self, dim, base=10000, device=None):
 59 |         super().__init__()
 60 |         self.dim = dim
 61 |         self.base = base
 62 |         inv_freq = 1.0 / (
 63 |             self.base
 64 |             ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
 65 |         )
 66 |         self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 67 | 
 68 |     @torch.no_grad()
 69 |     # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaADETRDecoder
 70 |     def forward(self, x, position_ids, seq_len=None):
 71 |         # x: [bs, num_attention_heads, seq_len, head_size]
 72 |         self.inv_freq.to(x.device)
 73 |         inv_freq_expanded = (
 74 |             self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
 75 |         )
 76 |         position_ids_expanded = position_ids[:, None, :].float()
 77 | 
 78 |         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
 79 |             1, 2
 80 |         )
 81 |         emb = torch.cat((freqs, freqs), dim=-1)
 82 |         cos = emb.cos()
 83 |         sin = emb.sin()
 84 |         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 85 | 
 86 | 
 87 | # Copied from transformers.models.llama.modeling_llama.rotate_half
 88 | def rotate_half(x):
 89 |     """Rotates half the hidden dims of the input."""
 90 |     x1 = x[..., : x.shape[-1] // 2]
 91 |     x2 = x[..., x.shape[-1] // 2 :]
 92 |     return torch.cat((-x2, x1), dim=-1)
 93 | 
 94 | 
 95 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 96 | def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
 97 |     """Applies Rotary Position Embedding to the query and key tensors.
 98 | 
 99 |     Args:
100 |         q (`torch.Tensor`): The query tensor.
101 |         k (`torch.Tensor`): The key tensor.
102 |         cos (`torch.Tensor`): The cosine part of the rotary embedding.
103 |         sin (`torch.Tensor`): The sine part of the rotary embedding.
104 |         unsqueeze_dim (`int`, *optional*, defaults to 1):
105 |             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
106 |             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
107 |             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
108 |             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
109 |             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
110 |             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
111 |     Returns:
112 |         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
113 |     """
114 |     cos = cos.unsqueeze(unsqueeze_dim)
115 |     sin = sin.unsqueeze(unsqueeze_dim)
116 |     q_embed = (q * cos) + (rotate_half(q) * sin)
117 |     k_embed = (k * cos) + (rotate_half(k) * sin)
118 |     return q_embed, k_embed
119 | 
120 | 
121 | # Copied from transformers.models.llama.modeling_llama.repeat_kv
122 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
123 |     """
124 |     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
125 |     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
126 |     """
127 |     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
128 |     if n_rep == 1:
129 |         return hidden_states
130 |     hidden_states = hidden_states[:, :, None, :, :].expand(
131 |         batch, num_key_value_heads, n_rep, slen, head_dim
132 |     )
133 |     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
134 | 
135 | 
136 | class SuryaADETRDecoderSdpaCrossAttention(nn.Module):
137 |     """Multi-headed attention from 'Attention Is All You Need' paper
138 |     Modified for GQA
139 |     """
140 | 
141 |     def __init__(self, config: PretrainedConfig):
142 |         super().__init__()
143 |         self.config = config
144 |         self.attention_dropout = config.attention_dropout
145 |         self.hidden_size = config.hidden_size
146 |         self.num_attention_heads = config.num_attention_heads
147 |         self.head_dim = config.head_dim
148 |         self.num_key_value_heads = config.num_key_value_heads
149 |         self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
150 | 
151 |         self.q_proj = nn.Linear(
152 |             self.hidden_size,
153 |             self.num_attention_heads * self.head_dim,
154 |             bias=config.attention_bias,
155 |         )
156 |         self.k_proj = nn.Linear(
157 |             self.config.encoder_hidden_size,
158 |             self.num_key_value_heads * self.head_dim,
159 |             bias=config.attention_bias,
160 |         )
161 |         self.v_proj = nn.Linear(
162 |             self.config.encoder_hidden_size,
163 |             self.num_key_value_heads * self.head_dim,
164 |             bias=config.attention_bias,
165 |         )
166 |         self.o_proj = nn.Linear(
167 |             self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
168 |         )
169 |         self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(
170 |             self.head_dim,
171 |             base=config.rope_theta,
172 |         )
173 | 
174 |     def forward(
175 |         self,
176 |         hidden_states: torch.Tensor,
177 |         encoder_hidden_states: torch.Tensor,
178 |         attention_mask: Optional[torch.Tensor] = None,
179 |         encoder_attention_mask: Optional[torch.Tensor] = None,
180 |         use_cache: bool = False,
181 |     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
182 |         # Encoder attention mask currently ignored
183 | 
184 |         bsz, q_len, _ = hidden_states.size()
185 |         _, v_len, _ = encoder_hidden_states.size()
186 | 
187 |         query_states = self.q_proj(hidden_states)
188 |         query_states = query_states.view(
189 |             bsz, q_len, self.num_attention_heads, self.head_dim
190 |         ).transpose(1, 2)
191 | 
192 |         if self.key_states is None:
193 |             key_states = self.k_proj(encoder_hidden_states)
194 |             value_states = self.v_proj(encoder_hidden_states)
195 |             key_states = key_states.view(
196 |                 bsz, v_len, self.num_key_value_heads, self.head_dim
197 |             ).transpose(1, 2)
198 |             value_states = value_states.view(
199 |                 bsz, v_len, self.num_key_value_heads, self.head_dim
200 |             ).transpose(1, 2)
201 |             if use_cache:
202 |                 self._update_cache(key_states, value_states)
203 |         else:
204 |             key_states = self.key_states
205 |             value_states = self.value_states
206 | 
207 |         key_states = repeat_kv(key_states, self.num_key_value_groups)
208 |         value_states = repeat_kv(value_states, self.num_key_value_groups)
209 | 
210 |         attn_output = torch.nn.functional.scaled_dot_product_attention(
211 |             query_states,
212 |             key_states,
213 |             value_states,
214 |             attn_mask=None,
215 |             dropout_p=self.attention_dropout if self.training else 0.0,
216 |             scale=self.head_dim**-0.5,
217 |         )
218 | 
219 |         attn_output = attn_output.transpose(1, 2).contiguous()
220 |         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
221 |         attn_output = self.o_proj(attn_output)
222 |         return attn_output
223 | 
224 |     def _clear_cache(self):
225 |         if self.value_states is not None:
226 |             del self.value_states
227 |         if self.key_states is not None:
228 |             del self.key_states
229 | 
230 |     def _setup_cache(self, batch_size, device, dtype=None):
231 |         # Setup initial caches
232 |         self.value_states = None
233 |         self.key_states = None
234 | 
235 |     @torch.no_grad()
236 |     def _update_cache(self, key_states, value_states, **cache_kwargs):
237 |         self.value_states = value_states
238 |         self.key_states = key_states
239 | 
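    # Rough lifecycle of the cross-attention cache above (a sketch of how these
    # methods fit together, inferred from this class; `attn` is hypothetical):
    #
    #   attn._setup_cache(batch_size, device)                 # key/value states reset to None
    #   out = attn(hidden, enc_hidden, use_cache=True)        # first call projects and stores K/V
    #   out = attn(next_hidden, enc_hidden, use_cache=True)   # later calls reuse the cached K/V
    #   attn._clear_cache()                                   # drop the cached tensors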
240 | 
241 | class SuryaADETRDecoderSdpaAttention(nn.Module):
242 |     """Multi-headed attention from 'Attention Is All You Need' paper"""
243 | 
244 |     def __init__(self, config: PretrainedConfig, static_cache=False, max_boxes=None):
245 |         super().__init__()
246 |         self.config = config
247 |         self.attention_dropout = config.attention_dropout
248 |         self.hidden_size = config.hidden_size
249 |         self.num_attention_heads = config.num_attention_heads
250 |         self.head_dim = config.head_dim
251 |         self.num_key_value_heads = config.num_key_value_heads
252 |         self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
253 | 
254 |         self.q_proj = nn.Linear(
255 |             self.hidden_size,
256 |             self.num_attention_heads * self.head_dim,
257 |             bias=config.attention_bias,
258 |         )
259 |         self.k_proj = nn.Linear(
260 |             self.hidden_size,
261 |             self.num_key_value_heads * self.head_dim,
262 |             bias=config.attention_bias,
263 |         )
264 |         self.v_proj = nn.Linear(
265 |             self.hidden_size,
266 |             self.num_key_value_heads * self.head_dim,
267 |             bias=config.attention_bias,
268 |         )
269 |         self.o_proj = nn.Linear(
270 |             self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
271 |         )
272 |         self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(
273 |             self.head_dim,
274 |             base=config.rope_theta,
275 |         )
276 | 
277 |         self.static_cache = static_cache
278 |         self.max_boxes = max_boxes
279 | 
280 |     def forward(
281 |         self,
282 |         hidden_states: torch.Tensor,
283 |         position_ids: Optional[torch.LongTensor] = None,
284 |         attention_mask: Optional[torch.Tensor] = None,
285 |         cache_position: Optional[torch.LongTensor] = None,
286 |         use_cache: bool = False,
287 |         window_attn: bool = False,
288 |     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
289 |         bsz, q_len, _ = hidden_states.size()
290 | 
291 |         query_states = self.q_proj(hidden_states)
292 |         key_states = self.k_proj(hidden_states)
293 |         value_states = self.v_proj(hidden_states)
294 | 
295 |         # Final is bsz, num_attention_heads, seq_len, head_dim
296 |         query_states = query_states.view(
297 |             bsz, q_len, self.num_attention_heads, self.head_dim
298 |         ).transpose(1, 2)
299 |         key_states = key_states.view(
300 |             bsz, q_len, self.num_key_value_heads, self.head_dim
301 |         ).transpose(1, 2)
302 |         value_states = value_states.view(
303 |             bsz, q_len, self.num_key_value_heads, self.head_dim
304 |         ).transpose(1, 2)
305 | 
306 |         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
307 |         query_states, key_states = apply_rotary_pos_emb(
308 |             query_states, key_states, cos, sin
309 |         )
310 | 
311 |         if use_cache and hasattr(self, "key_states"):
312 |             cache_kwargs = {
313 |                 "cache_position": cache_position,
314 |                 "window_attn": window_attn,
315 |             }
316 |             key_states, value_states = self._update_cache(
317 |                 key_states, value_states, **cache_kwargs
318 |             )
319 | 
320 |         key_states = repeat_kv(key_states, self.num_key_value_groups)
321 |         value_states = repeat_kv(value_states, self.num_key_value_groups)
322 | 
323 |         causal_mask = attention_mask
324 |         if attention_mask is not None:
325 |             # Mask is batch, head, seq_len, kv_len
326 |             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
327 |             if cache_position is not None and self.static_cache:
328 |                 current_pos = cache_position[-1]
329 |                 causal_mask[:, :, :, current_pos + 1 :] = torch.finfo(
330 |                     causal_mask.dtype
331 |                 ).min
332 | 
333 |         attn_output = torch.nn.functional.scaled_dot_product_attention(
334 |             query_states,
335 |             key_states,
336 |             value_states,
337 |             attn_mask=causal_mask,
338 |             dropout_p=self.attention_dropout if self.training else 0.0,
339 |             scale=self.head_dim**-0.5,
340 |         )
341 | 
342 |         attn_output = attn_output.transpose(1, 2).contiguous()
343 |         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
344 |         attn_output = self.o_proj(attn_output)
345 |         return attn_output
346 | 
347 |     def _setup_cache(self, batch_size, device, dtype=None):
348 |         if dtype is None and self.config.torch_dtype is not None:
349 |             dtype = self.config.torch_dtype
350 |         dtype = dtype if dtype is not None else torch.float32
351 | 
352 |         # Setup initial caches
353 |         self.value_states = None
354 |         self.key_states = None
355 | 
356 |         if self.static_cache:
357 |             cache_shape = (
358 |                 batch_size,
359 |                 self.num_key_value_heads,
360 |                 self.max_boxes,
361 |                 self.head_dim,
362 |             )
363 |             self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
364 |             self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)
365 | 
366 |     def _clear_cache(self):
367 |         if self.value_states is not None:
368 |             del self.value_states
369 |         if self.key_states is not None:
370 |             del self.key_states
371 | 
372 |     def _update_static_cache(self, key_states, value_states, **cache_kwargs):
373 |         cache_position = cache_kwargs.get("cache_position")
374 |         k_out, v_out = (
375 |             self.key_states.to(key_states.device),
376 |             self.value_states.to(value_states.device),
377 |         )
378 | 
379 |         k_out[:, :, cache_position] = key_states.to(k_out.dtype)
380 |         v_out[:, :, cache_position] = value_states.to(v_out.dtype)
381 | 
382 |         self.key_states, self.value_states = k_out, v_out
383 |         return k_out, v_out
384 | 
385 |     def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
386 |         k_out = key_states
387 |         if self.key_states is not None:
388 |             k_out = torch.cat([self.key_states, key_states], dim=2)
389 | 
390 |         v_out = value_states
391 |         if self.value_states is not None:
392 |             v_out = torch.cat([self.value_states, value_states], dim=2)
393 | 
394 |         self.key_states, self.value_states = k_out, v_out
395 |         return k_out, v_out
396 | 
397 |     @torch.no_grad()
398 |     def _update_cache(self, key_states, value_states, **cache_kwargs):
399 |         if self.static_cache:
400 |             return self._update_static_cache(key_states, value_states, **cache_kwargs)
401 | 
402 |         return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
403 | 
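    # Sketch of the two caching modes dispatched above:
    #   static  (static_cache=True):  buffers pre-allocated in _setup_cache with shape
    #       (batch, num_key_value_heads, max_boxes, head_dim); each step writes its new
    #       K/V in place at cache_position.
    #   dynamic (static_cache=False): each step's K/V are concatenated onto the cache
    #       along the sequence dimension (dim=2), so the cache grows one step at a time.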
404 | 
405 | class SuryaADETRDecoderMlp(nn.Module):
406 |     def __init__(self, config):
407 |         super().__init__()
408 |         self.config = config
409 |         self.hidden_size = config.hidden_size
410 |         self.intermediate_size = config.intermediate_size
411 |         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
412 |         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
413 |         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
414 |         if config.hidden_activation is None:
415 |             config.hidden_activation = "gelu_pytorch_tanh"
416 |         hidden_activation = config.hidden_activation
417 |         self.act_fn = ACT2FN[hidden_activation]
418 | 
419 |     def forward(self, x):
420 |         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
421 | 
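# The block above is a gated feed-forward: down_proj(act_fn(gate_proj(x)) * up_proj(x)).
# For x of shape (batch, seq_len, hidden_size), gate_proj(x) and up_proj(x) are both
# (batch, seq_len, intermediate_size), and the output returns to (batch, seq_len, hidden_size).
# With the default gelu_pytorch_tanh activation this is the GeGLU variant.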
422 | 
423 | class SuryaADETRDecoderLayer(nn.Module):
424 |     def __init__(self, config, layer_idx, static_cache=False, max_boxes=None):
425 |         super().__init__()
426 |         self.cross_pre_norm = SuryaADETRDecoderRMSNorm(
427 |             config.hidden_size, eps=config.rms_norm_eps
428 |         )
429 |         self.temporal_pre_norm = SuryaADETRDecoderRMSNorm(
430 |             config.hidden_size, eps=config.rms_norm_eps
431 |         )
432 | 
433 |         self.temporal_block = None
434 |         if layer_idx in config.self_attn_layers:
435 |             self.temporal_block = SuryaADETRDecoderSdpaAttention(
436 |                 config, static_cache=static_cache, max_boxes=max_boxes
437 |             )
438 | 
439 |         self.cross_attn_block = None
440 |         if layer_idx in config.cross_attn_layers:
441 |             self.cross_attn_block = SuryaADETRDecoderSdpaCrossAttention(config)
442 | 
443 |         self.window_attn = layer_idx not in config.global_attn_layers
444 |         self.channel_pre_norm = SuryaADETRDecoderRMSNorm(
445 |             config.hidden_size, eps=config.rms_norm_eps
446 |         )
447 |         self.mlp_block = SuryaADETRDecoderMlp(config)
448 | 
449 |         self.double_residual_flow = getattr(config, "double_residual_flow", False)
450 | 
451 |     def forward(
452 |         self,
453 |         activations: torch.Tensor,
454 |         position_ids: torch.Tensor,
455 |         attention_mask: torch.Tensor,
456 |         encoder_hidden_states: torch.Tensor = None,
457 |         encoder_attention_mask: torch.Tensor = None,
458 |         cache_position: torch.Tensor = None,
459 |         use_cache: bool = None,
460 |     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
461 |         if self.double_residual_flow:
462 |             return self.double_res_forward(
463 |                 activations,
464 |                 position_ids,
465 |                 attention_mask,
466 |                 encoder_hidden_states,
467 |                 encoder_attention_mask,
468 |                 cache_position,
469 |                 use_cache,
470 |             )
471 | 
472 |         hidden_states = activations
473 |         if self.cross_attn_block is not None:
474 |             # Do cross-attention on encoder outputs
475 |             cross_attn_inputs = self.cross_pre_norm(hidden_states)
476 |             cross_attn_path = self.cross_attn_block(
477 |                 cross_attn_inputs,
478 |                 encoder_hidden_states,
479 |                 attention_mask,
480 |                 encoder_attention_mask,
481 |                 use_cache=use_cache,
482 |             )
483 |             hidden_states = cross_attn_path + hidden_states
484 | 
485 |         if self.temporal_block is not None:
486 |             temporal_inputs = self.temporal_pre_norm(
487 |                 hidden_states
488 |             )  # RMSNorm introduces slight differences
489 |             temporal_path = self.temporal_block(
490 |                 temporal_inputs,
491 |                 position_ids,
492 |                 attention_mask,
493 |                 cache_position=cache_position,
494 |                 use_cache=use_cache,
495 |                 window_attn=self.window_attn,
496 |             )
497 | 
498 |             hidden_states = temporal_path + hidden_states
499 | 
500 |         block_input = hidden_states
501 |         hidden_states = self.channel_pre_norm(block_input)
502 |         hidden_states = self.mlp_block(hidden_states)
503 |         hidden_states = hidden_states + block_input
504 | 
505 |         return hidden_states
506 | 
507 |     def double_res_forward(
508 |         self,
509 |         activations: torch.Tensor,
510 |         position_ids: torch.Tensor,
511 |         attention_mask: torch.Tensor,
512 |         encoder_hidden_states: torch.Tensor = None,
513 |         encoder_attention_mask: torch.Tensor = None,
514 |         cache_position: torch.Tensor = None,
515 |         use_cache: bool = None,
516 |     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
517 |         raw_activations = activations
518 | 
519 |         if self.cross_attn_block is not None:
520 |             # Do cross-attention on encoder outputs
521 |             cross_attn_inputs = self.cross_pre_norm(activations)
522 |             cross_attn_path = self.cross_attn_block(
523 |                 cross_attn_inputs,
524 |                 encoder_hidden_states,
525 |                 attention_mask,
526 |                 encoder_attention_mask,
527 |                 use_cache=use_cache,
528 |             )
529 |             cross_attn_output = cross_attn_path + raw_activations
530 |         else:
531 |             cross_attn_output = raw_activations
532 | 
533 |         if self.temporal_block is not None:
534 |             inputs_normalized = self.temporal_pre_norm(
535 |                 cross_attn_output
536 |             )  # RMSNorm introduces slight differences
537 |             hidden_states = self.temporal_block(
538 |                 inputs_normalized,
539 |                 position_ids,
540 |                 attention_mask,
541 |                 cache_position=cache_position,
542 |                 use_cache=use_cache,
543 |                 window_attn=self.window_attn,
544 |             )
545 | 
546 |             residual = hidden_states + raw_activations
547 |         else:
548 |             residual = cross_attn_output
549 | 
550 |         hidden_states = self.channel_pre_norm(residual)
551 |         hidden_states = self.mlp_block(hidden_states)
552 | 
553 |         hidden_states = hidden_states + residual
554 |         return hidden_states
555 | 
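    # Summary of the two residual flows above (a restatement of the code, not new
    # behavior): in the standard forward each sub-block adds its output to the running
    # hidden state, while in the double-residual forward the self-attention output is
    # added back to the raw layer input (raw_activations) rather than to the
    # cross-attention output.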
556 | 
557 | class SuryaADETRDecoderPreTrainedModel(SuryaPreTrainedModel):
558 |     config_class = PretrainedConfig
559 |     base_model_prefix = "model"
560 |     supports_gradient_checkpointing = True
561 |     _no_split_modules = ["SuryaADETRDecoderLayer"]
562 |     _skip_keys_device_placement = ["cache"]
563 |     _supports_flash_attn_2 = False
564 |     _supports_sdpa = False  # we can't compare with eager for now
565 |     _supports_cache_class = True
566 |     _supports_quantized_cache = True
567 | 
568 |     def _init_weights(self, module):
569 |         if isinstance(module, SuryaADETRDecoderSdpaAttention):
570 |             torch.nn.init.normal_(
571 |                 module.q_proj.weight, mean=0.0, std=self.config.init_std
572 |             )
573 |             torch.nn.init.normal_(
574 |                 module.k_proj.weight, mean=0.0, std=self.config.init_std
575 |             )
576 |             torch.nn.init.normal_(
577 |                 module.v_proj.weight, mean=0.0, std=self.config.init_std
578 |             )
579 | 
580 |             torch.nn.init.normal_(
581 |                 module.o_proj.weight, mean=0.0, std=self.config.init_std
582 |             )
583 |         elif isinstance(module, nn.Linear):
584 |             torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
585 |             if getattr(module, "bias", None) is not None:
586 |                 torch.nn.init.zeros_(module.bias)
587 |         elif isinstance(module, nn.Embedding):
588 |             module.weight.data.normal_(mean=0.0, std=self.config.init_std)
589 |             if module.padding_idx is not None:
590 |                 module.weight.data[module.padding_idx].zero_()
591 | 
592 |     def _setup_cache(self, config, batch, device, dtype):
593 |         layers = getattr(self, "model", self).layers
594 |         for layer in layers:
595 |             if layer.temporal_block:
596 |                 layer.temporal_block._setup_cache(batch, device, dtype)
597 |             if layer.cross_attn_block:
598 |                 layer.cross_attn_block._setup_cache(batch, device, dtype)
599 | 
600 |     def _clear_cache(self):
601 |         layers = getattr(self, "model", self).layers
602 |         for layer in layers:
603 |             if layer.temporal_block:
604 |                 layer.temporal_block._clear_cache()
605 |             if layer.cross_attn_block:
606 |                 layer.cross_attn_block._clear_cache()
607 | 
608 |     def reset_cache(self, batch, device, dtype):
609 |         pass
610 | 
611 |     def _tie_weights(self):
612 |         pass
613 | 
614 |     def tie_weights(self):
615 |         pass
616 | 
617 | 
618 | class SuryaADETRDecoderModel(SuryaADETRDecoderPreTrainedModel):
619 |     """
620 |     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaADETRDecoderLayer`].
621 | 
622 |     Args:
623 |         config: PretrainedConfig
624 |     """
625 | 
626 |     def __init__(
627 |         self,
628 |         config: PretrainedConfig,
629 |         embedder: nn.Module = None,
630 |         max_boxes: int = None,
631 |         static_cache: bool = False,
632 |     ):
633 |         super().__init__(config)
634 |         self.padding_idx = config.pad_token_id
635 |         self.vocab_size = config.vocab_size
636 |         self.causal = config.causal
637 | 
638 |         self.embed_tokens = embedder
639 |         self.max_boxes = max_boxes
640 |         self.static_cache = static_cache
641 | 
642 |         self.layers = nn.ModuleList(
643 |             [
644 |                 SuryaADETRDecoderLayer(
645 |                     config, layer_idx, static_cache=static_cache, max_boxes=max_boxes
646 |                 )
647 |                 for layer_idx in range(config.num_hidden_layers)
648 |             ]
649 |         )
650 |         self.final_norm = SuryaADETRDecoderRMSNorm(
651 |             config.hidden_size, eps=config.rms_norm_eps
652 |         )
653 |         self.gradient_checkpointing = False
654 | 
655 |         self.register_buffer(
656 |             "normalizer",
657 |             torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32),
658 |             persistent=False,
659 |         )
660 |         # Initialize weights and apply final processing
661 |         self.post_init()
662 | 
663 |     # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
664 |     def get_input_embeddings(self):
665 |         return self.embed_tokens
666 | 
667 |     # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
668 |     def set_input_embeddings(self, value):
669 |         self.embed_tokens = value
670 | 
671 |     def forward(
672 |         self,
673 |         input_ids: torch.LongTensor = None,
674 |         input_boxes_counts: torch.LongTensor = None,
675 |         inputs_embeds: Optional[torch.FloatTensor] = None,
676 |         position_ids: Optional[torch.LongTensor] = None,
677 |         attention_mask: Optional[torch.Tensor] = None,
678 |         encoder_hidden_states: Optional[torch.FloatTensor] = None,
679 |         encoder_attention_mask: Optional[torch.FloatTensor] = None,
680 |         cache_position: Optional[torch.LongTensor] = None,
681 |         use_cache: Optional[bool] = None,
682 |         output_hidden_states: Optional[bool] = None,
683 |         return_dict: Optional[bool] = None,
684 |         prefill: bool = False,
685 |     ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
686 |         use_cache = use_cache if use_cache is not None else self.config.use_cache
687 |         return_dict = (
688 |             return_dict if return_dict is not None else self.config.use_return_dict
689 |         )
690 | 
691 |         if self.gradient_checkpointing and self.training and use_cache:
692 |             use_cache = False
693 | 
694 |         inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
695 |         hidden_states = inputs_embeds
696 | 
697 |         if use_cache and prefill:
698 |             self._setup_cache(
699 |                 self.config,
700 |                 hidden_states.shape[0],
701 |                 hidden_states.device,
702 |                 hidden_states.dtype,
703 |             )
704 | 
705 |         if cache_position is None:
706 |             cache_position = torch.arange(
707 |                 hidden_states.shape[1], device=hidden_states.device
708 |             )
709 |         if position_ids is None:
710 |             position_ids = cache_position.unsqueeze(0)
711 | 
712 |         causal_mask = self._update_causal_mask(
713 |             attention_mask, inputs_embeds, cache_position
714 |         )
715 | 
716 |         all_hidden_states = () if output_hidden_states else None
717 |         for i, residual_block in enumerate(self.layers):
718 |             if output_hidden_states:
719 |                 all_hidden_states += (hidden_states,)
720 |             if self.gradient_checkpointing and self.training:
721 |                 hidden_states = self._gradient_checkpointing_func(
722 |                     residual_block.__call__,
723 |                     hidden_states,
724 |                     position_ids,
725 |                     causal_mask,
726 |                     encoder_hidden_states,
727 |                     encoder_attention_mask,
728 |                     cache_position,
729 |                     use_cache,
730 |                 )
731 |             else:
732 |                 hidden_states = residual_block(
733 |                     hidden_states,
734 |                     position_ids,
735 |                     causal_mask,
736 |                     encoder_hidden_states,
737 |                     encoder_attention_mask,
738 |                     cache_position,
739 |                     use_cache,
740 |                 )
741 | 
742 |         hidden_states = self.final_norm(hidden_states)
743 | 
744 |         # add hidden states from the last decoder layer
745 |         if output_hidden_states:
746 |             all_hidden_states += (hidden_states,)
747 | 
748 |         if not return_dict:
749 |             return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
750 | 
751 |         return BaseModelOutputWithNoAttention(
752 |             last_hidden_state=hidden_states,
753 |             hidden_states=all_hidden_states,
754 |         )
755 | 
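    # Sketch of how this forward is typically driven during generation (hypothetical
    # variable names; the calling predictor code is not part of this file):
    #
    #   out = model(input_ids=prompt_ids, input_boxes_counts=box_counts,
    #               encoder_hidden_states=enc, use_cache=True, prefill=True)
    #       # prefill=True sets up the per-layer KV caches sized by max_boxes
    #   out = model(input_ids=next_ids, input_boxes_counts=box_counts,
    #               encoder_hidden_states=enc, use_cache=True,
    #               cache_position=torch.tensor([step]))
    #       # later steps pass a single-position cache_position so the static cache
    #       # is written at the right slot and the causal mask is trimmed accordingly
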
756 |     # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
757 |     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
758 |     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
759 |     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
760 |     # Ignore copy
761 |     def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
762 |         if not self.causal:
763 |             return None
764 | 
765 |         dtype, device = input_tensor.dtype, input_tensor.device
766 |         min_dtype = torch.finfo(dtype).min
767 |         sequence_length = input_tensor.shape[1]
768 |         target_length = max(self.max_boxes, sequence_length)
769 | 
770 |         diagonal = torch.full(
771 |             (sequence_length, target_length),
772 |             fill_value=min_dtype,
773 |             dtype=dtype,
774 |             device=device,
775 |         )
776 |         causal_mask = diagonal
777 |         if sequence_length != 1:
778 |             # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
779 |             # triu will be the min_dtype, everything else is 0 (attended to)
780 |             causal_mask = torch.triu(diagonal, diagonal=1)
781 | 
782 |         causal_mask *= torch.arange(
783 |             target_length, device=device
784 |         ) > cache_position.reshape(-1, 1)
785 |         causal_mask = causal_mask[None, None, :, :].expand(
786 |             input_tensor.shape[0], 1, -1, -1
787 |         )
788 |         if attention_mask is not None:
789 |             causal_mask = (
790 |                 causal_mask.clone()
791 |             )  # copy to contiguous memory for in-place edit
792 |             if attention_mask.dim() == 2:
793 |                 # Mask positions in the causal mask that are masked in the attention mask
794 |                 mask_length = attention_mask.shape[-1]
795 |                 padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
796 |                     :, None, None, :
797 |                 ].eq(0.0)
798 |                 causal_mask[..., :mask_length] = causal_mask[
799 |                     ..., :mask_length
800 |                 ].masked_fill(padding_mask, min_dtype)
801 | 
802 |         if attention_mask is not None and attention_mask.device.type == "cuda":
803 |             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
804 |             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
805 |             # Details: https://github.com/pytorch/pytorch/issues/110213
806 |             causal_mask = AttentionMaskConverter._unmask_unattended(
807 |                 causal_mask, min_dtype
808 |             )
809 | 
810 |         return causal_mask
811 | 
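# Small worked example of the mask built by _update_causal_mask (illustrative:
# sequence_length=3, target_length=5, cache_position=[0, 1, 2]; min_dtype shown as -inf):
#
#   [[  0, -inf, -inf, -inf, -inf],
#    [  0,    0, -inf, -inf, -inf],
#    [  0,    0,    0, -inf, -inf]]
#
# Each query position attends to itself and earlier positions; columns beyond the
# current cache_position stay masked until later decode steps reach them.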
```