This is page 4 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmark │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── CITATION.cff ├── CLA.md ├── detect_layout.py ├── detect_text.py ├── LICENSE ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ ├── fonts │ │ └── .gitignore │ └── images │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── arabic.jpg │ ├── benchmark_chart_small.png │ ├── benchmark_chart.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chi_hind.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── chinese.jpg │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── excerpt.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── funsd.png │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── hindi.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── japanese.jpg │ ├── latex_ocr.png │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── nyt.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── paper.jpg │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── pres.png │ ├── rec_acc_table.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── scanned.png │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ ├── textbook_text.jpg │ └── textbook.jpg ├── surya │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── adetr │ │ │ └── decoder.py │ │ ├── donut │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder │ │ │ │ └── __init__.py │ │ │ ├── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation │ │ ├── __init__.py │ │ ├── cache 
│ │ │ ├── __init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input │ │ ├── load.py │ │ └── processing.py │ ├── layout │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests │ ├── assets │ │ └── test_latex.png │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ``` # Files -------------------------------------------------------------------------------- /surya/ocr_error/tokenizer.py: -------------------------------------------------------------------------------- ```python 1 | import collections 2 | import os 3 | import json 4 | import unicodedata 5 | from typing import List, Optional, Tuple 6 | 7 | from tokenizers import normalizers 8 | 9 | from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace 10 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 11 | 12 | from surya.common.s3 import S3DownloaderMixin 13 | 14 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 15 | 16 | # Copied from transformers.models.bert.tokenization_bert.load_vocab 17 | def load_vocab(vocab_file): 18 | """Loads a vocabulary file into a dictionary.""" 19 | vocab = collections.OrderedDict() 20 | with open(vocab_file, "r", encoding="utf-8") as reader: 21 | tokens = reader.readlines() 22 | for index, token in enumerate(tokens): 23 | token = token.rstrip("\n") 24 | vocab[token] = index 25 | return vocab 26 | 27 | 28 | # Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize 29 | def whitespace_tokenize(text): 30 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 31 | text = text.strip() 32 | if not text: 33 | return [] 34 | tokens = text.split() 35 | return tokens 36 | 37 | 38 | class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer): 39 | r""" 40 | Construct a DistilBERT tokenizer. Based on WordPiece. 41 | 42 | This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to 43 | this superclass for more information regarding those methods. 44 | 45 | Args: 46 | vocab_file (`str`): 47 | File containing the vocabulary. 48 | do_lower_case (`bool`, *optional*, defaults to `True`): 49 | Whether or not to lowercase the input when tokenizing. 50 | do_basic_tokenize (`bool`, *optional*, defaults to `True`): 51 | Whether or not to do basic tokenization before WordPiece. 
52 | never_split (`Iterable`, *optional*): 53 | Collection of tokens which will never be split during tokenization. Only has an effect when 54 | `do_basic_tokenize=True` 55 | unk_token (`str`, *optional*, defaults to `"[UNK]"`): 56 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 57 | token instead. 58 | sep_token (`str`, *optional*, defaults to `"[SEP]"`): 59 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 60 | sequence classification or for a text and a question for question answering. It is also used as the last 61 | token of a sequence built with special tokens. 62 | pad_token (`str`, *optional*, defaults to `"[PAD]"`): 63 | The token used for padding, for example when batching sequences of different lengths. 64 | cls_token (`str`, *optional*, defaults to `"[CLS]"`): 65 | The classifier token which is used when doing sequence classification (classification of the whole sequence 66 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 67 | mask_token (`str`, *optional*, defaults to `"[MASK]"`): 68 | The token used for masking values. This is the token used when training this model with masked language 69 | modeling. This is the token which the model will try to predict. 70 | tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): 71 | Whether or not to tokenize Chinese characters. 72 | 73 | This should likely be deactivated for Japanese (see this 74 | [issue](https://github.com/huggingface/transformers/issues/328)). 75 | strip_accents (`bool`, *optional*): 76 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 77 | value for `lowercase` (as in the original BERT). 78 | """ 79 | 80 | vocab_files_names = VOCAB_FILES_NAMES 81 | model_input_names = ["input_ids", "attention_mask"] 82 | 83 | def __init__( 84 | self, 85 | vocab_file, 86 | do_lower_case=True, 87 | do_basic_tokenize=True, 88 | never_split=None, 89 | unk_token="[UNK]", 90 | sep_token="[SEP]", 91 | pad_token="[PAD]", 92 | cls_token="[CLS]", 93 | mask_token="[MASK]", 94 | tokenize_chinese_chars=True, 95 | strip_accents=None, 96 | **kwargs, 97 | ): 98 | if not os.path.isfile(vocab_file): 99 | raise ValueError( 100 | f"Can't find a vocabulary file at path '{vocab_file}'. 
To load the vocabulary from a Google pretrained" 101 | " model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" 102 | ) 103 | self.vocab = load_vocab(vocab_file) 104 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 105 | self.do_basic_tokenize = do_basic_tokenize 106 | if do_basic_tokenize: 107 | self.basic_tokenizer = BasicTokenizer( 108 | do_lower_case=do_lower_case, 109 | never_split=never_split, 110 | tokenize_chinese_chars=tokenize_chinese_chars, 111 | strip_accents=strip_accents, 112 | ) 113 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) 114 | 115 | super().__init__( 116 | do_lower_case=do_lower_case, 117 | do_basic_tokenize=do_basic_tokenize, 118 | never_split=never_split, 119 | unk_token=unk_token, 120 | sep_token=sep_token, 121 | pad_token=pad_token, 122 | cls_token=cls_token, 123 | mask_token=mask_token, 124 | tokenize_chinese_chars=tokenize_chinese_chars, 125 | strip_accents=strip_accents, 126 | **kwargs, 127 | ) 128 | 129 | @property 130 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case 131 | def do_lower_case(self): 132 | return self.basic_tokenizer.do_lower_case 133 | 134 | @property 135 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size 136 | def vocab_size(self): 137 | return len(self.vocab) 138 | 139 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab 140 | def get_vocab(self): 141 | return dict(self.vocab, **self.added_tokens_encoder) 142 | 143 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize 144 | def _tokenize(self, text, split_special_tokens=False): 145 | split_tokens = [] 146 | if self.do_basic_tokenize: 147 | for token in self.basic_tokenizer.tokenize( 148 | text, never_split=self.all_special_tokens if not split_special_tokens else None 149 | ): 150 | # If the token is part of the never_split set 151 | if token in self.basic_tokenizer.never_split: 152 | split_tokens.append(token) 153 | else: 154 | split_tokens += self.wordpiece_tokenizer.tokenize(token) 155 | else: 156 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 157 | return split_tokens 158 | 159 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id 160 | def _convert_token_to_id(self, token): 161 | """Converts a token (str) in an id using the vocab.""" 162 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 163 | 164 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token 165 | def _convert_id_to_token(self, index): 166 | """Converts an index (integer) in a token (str) using the vocab.""" 167 | return self.ids_to_tokens.get(index, self.unk_token) 168 | 169 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string 170 | def convert_tokens_to_string(self, tokens): 171 | """Converts a sequence of tokens (string) in a single string.""" 172 | out_string = " ".join(tokens).replace(" ##", "").strip() 173 | return out_string 174 | 175 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens 176 | def build_inputs_with_special_tokens( 177 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 178 | ) -> List[int]: 179 | """ 180 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 181 | adding special tokens. 
A BERT sequence has the following format: 182 | 183 | - single sequence: `[CLS] X [SEP]` 184 | - pair of sequences: `[CLS] A [SEP] B [SEP]` 185 | 186 | Args: 187 | token_ids_0 (`List[int]`): 188 | List of IDs to which the special tokens will be added. 189 | token_ids_1 (`List[int]`, *optional*): 190 | Optional second list of IDs for sequence pairs. 191 | 192 | Returns: 193 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 194 | """ 195 | if token_ids_1 is None: 196 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 197 | cls = [self.cls_token_id] 198 | sep = [self.sep_token_id] 199 | return cls + token_ids_0 + sep + token_ids_1 + sep 200 | 201 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask 202 | def get_special_tokens_mask( 203 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 204 | ) -> List[int]: 205 | """ 206 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 207 | special tokens using the tokenizer `prepare_for_model` method. 208 | 209 | Args: 210 | token_ids_0 (`List[int]`): 211 | List of IDs. 212 | token_ids_1 (`List[int]`, *optional*): 213 | Optional second list of IDs for sequence pairs. 214 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 215 | Whether or not the token list is already formatted with special tokens for the model. 216 | 217 | Returns: 218 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 219 | """ 220 | 221 | if already_has_special_tokens: 222 | return super().get_special_tokens_mask( 223 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 224 | ) 225 | 226 | if token_ids_1 is not None: 227 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 228 | return [1] + ([0] * len(token_ids_0)) + [1] 229 | 230 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences 231 | def create_token_type_ids_from_sequences( 232 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 233 | ) -> List[int]: 234 | """ 235 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence 236 | pair mask has the following format: 237 | 238 | ``` 239 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 240 | | first sequence | second sequence | 241 | ``` 242 | 243 | If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). 244 | 245 | Args: 246 | token_ids_0 (`List[int]`): 247 | List of IDs. 248 | token_ids_1 (`List[int]`, *optional*): 249 | Optional second list of IDs for sequence pairs. 250 | 251 | Returns: 252 | `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
253 | """ 254 | sep = [self.sep_token_id] 255 | cls = [self.cls_token_id] 256 | if token_ids_1 is None: 257 | return len(cls + token_ids_0 + sep) * [0] 258 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 259 | 260 | # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary 261 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 262 | index = 0 263 | if os.path.isdir(save_directory): 264 | vocab_file = os.path.join( 265 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 266 | ) 267 | else: 268 | vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory 269 | with open(vocab_file, "w", encoding="utf-8") as writer: 270 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 271 | if index != token_index: 272 | # logger.warning( 273 | # f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." 274 | # " Please check that the vocabulary is not corrupted!" 275 | # ) 276 | index = token_index 277 | writer.write(token + "\n") 278 | index += 1 279 | return (vocab_file,) 280 | 281 | 282 | # Copied from transformers.models.bert.tokenization_bert.BasicTokenizer 283 | class BasicTokenizer(object): 284 | """ 285 | Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). 286 | 287 | Args: 288 | do_lower_case (`bool`, *optional*, defaults to `True`): 289 | Whether or not to lowercase the input when tokenizing. 290 | never_split (`Iterable`, *optional*): 291 | Collection of tokens which will never be split during tokenization. Only has an effect when 292 | `do_basic_tokenize=True` 293 | tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): 294 | Whether or not to tokenize Chinese characters. 295 | 296 | This should likely be deactivated for Japanese (see this 297 | [issue](https://github.com/huggingface/transformers/issues/328)). 298 | strip_accents (`bool`, *optional*): 299 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 300 | value for `lowercase` (as in the original BERT). 301 | do_split_on_punc (`bool`, *optional*, defaults to `True`): 302 | In some instances we want to skip the basic punctuation splitting so that later tokenization can capture 303 | the full context of the words, such as contractions. 304 | """ 305 | 306 | def __init__( 307 | self, 308 | do_lower_case=True, 309 | never_split=None, 310 | tokenize_chinese_chars=True, 311 | strip_accents=None, 312 | do_split_on_punc=True, 313 | ): 314 | if never_split is None: 315 | never_split = [] 316 | self.do_lower_case = do_lower_case 317 | self.never_split = set(never_split) 318 | self.tokenize_chinese_chars = tokenize_chinese_chars 319 | self.strip_accents = strip_accents 320 | self.do_split_on_punc = do_split_on_punc 321 | 322 | def tokenize(self, text, never_split=None): 323 | """ 324 | Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. 325 | 326 | Args: 327 | never_split (`List[str]`, *optional*) 328 | Kept for backward compatibility purposes. Now implemented directly at the base class level (see 329 | [`PreTrainedTokenizer.tokenize`]) List of token not to split. 330 | """ 331 | # union() returns a new set by concatenating the two sets. 
332 | never_split = self.never_split.union(set(never_split)) if never_split else self.never_split 333 | text = self._clean_text(text) 334 | 335 | # This was added on November 1st, 2018 for the multilingual and Chinese 336 | # models. This is also applied to the English models now, but it doesn't 337 | # matter since the English models were not trained on any Chinese data 338 | # and generally don't have any Chinese data in them (there are Chinese 339 | # characters in the vocabulary because Wikipedia does have some Chinese 340 | # words in the English Wikipedia.). 341 | if self.tokenize_chinese_chars: 342 | text = self._tokenize_chinese_chars(text) 343 | # prevents treating the same character with different unicode codepoints as different characters 344 | unicode_normalized_text = unicodedata.normalize("NFC", text) 345 | orig_tokens = whitespace_tokenize(unicode_normalized_text) 346 | split_tokens = [] 347 | for token in orig_tokens: 348 | if token not in never_split: 349 | if self.do_lower_case: 350 | token = token.lower() 351 | if self.strip_accents is not False: 352 | token = self._run_strip_accents(token) 353 | elif self.strip_accents: 354 | token = self._run_strip_accents(token) 355 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 356 | 357 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 358 | return output_tokens 359 | 360 | def _run_strip_accents(self, text): 361 | """Strips accents from a piece of text.""" 362 | text = unicodedata.normalize("NFD", text) 363 | output = [] 364 | for char in text: 365 | cat = unicodedata.category(char) 366 | if cat == "Mn": 367 | continue 368 | output.append(char) 369 | return "".join(output) 370 | 371 | def _run_split_on_punc(self, text, never_split=None): 372 | """Splits punctuation on a piece of text.""" 373 | if not self.do_split_on_punc or (never_split is not None and text in never_split): 374 | return [text] 375 | chars = list(text) 376 | i = 0 377 | start_new_word = True 378 | output = [] 379 | while i < len(chars): 380 | char = chars[i] 381 | if _is_punctuation(char): 382 | output.append([char]) 383 | start_new_word = True 384 | else: 385 | if start_new_word: 386 | output.append([]) 387 | start_new_word = False 388 | output[-1].append(char) 389 | i += 1 390 | 391 | return ["".join(x) for x in output] 392 | 393 | def _tokenize_chinese_chars(self, text): 394 | """Adds whitespace around any CJK character.""" 395 | output = [] 396 | for char in text: 397 | cp = ord(char) 398 | if self._is_chinese_char(cp): 399 | output.append(" ") 400 | output.append(char) 401 | output.append(" ") 402 | else: 403 | output.append(char) 404 | return "".join(output) 405 | 406 | def _is_chinese_char(self, cp): 407 | """Checks whether CP is the codepoint of a CJK character.""" 408 | # This defines a "chinese character" as anything in the CJK Unicode block: 409 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 410 | # 411 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 412 | # despite its name. The modern Korean Hangul alphabet is a different block, 413 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 414 | # space-separated words, so they are not treated specially and handled 415 | # like the all of the other languages. 
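        # Illustrative examples: ord("中") == 0x4E2D falls in the 0x4E00-0x9FFF range
        # below, so it is treated as a CJK character, while ord("カ") == 0x30AB (Katakana)
        # and ord("한") == 0xD55C (Hangul) fall outside every listed range and are
        # handled like any other character.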
416 | if ( 417 | (cp >= 0x4E00 and cp <= 0x9FFF) 418 | or (cp >= 0x3400 and cp <= 0x4DBF) # 419 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 420 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 421 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 422 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 423 | or (cp >= 0xF900 and cp <= 0xFAFF) 424 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 425 | ): # 426 | return True 427 | 428 | return False 429 | 430 | def _clean_text(self, text): 431 | """Performs invalid character removal and whitespace cleanup on text.""" 432 | output = [] 433 | for char in text: 434 | cp = ord(char) 435 | if cp == 0 or cp == 0xFFFD or _is_control(char): 436 | continue 437 | if _is_whitespace(char): 438 | output.append(" ") 439 | else: 440 | output.append(char) 441 | return "".join(output) 442 | 443 | 444 | # Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer 445 | class WordpieceTokenizer(object): 446 | """Runs WordPiece tokenization.""" 447 | 448 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 449 | self.vocab = vocab 450 | self.unk_token = unk_token 451 | self.max_input_chars_per_word = max_input_chars_per_word 452 | 453 | def tokenize(self, text): 454 | """ 455 | Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform 456 | tokenization using the given vocabulary. 457 | 458 | For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. 459 | 460 | Args: 461 | text: A single token or whitespace separated tokens. This should have 462 | already been passed through *BasicTokenizer*. 463 | 464 | Returns: 465 | A list of wordpiece tokens. 466 | """ 467 | 468 | output_tokens = [] 469 | for token in whitespace_tokenize(text): 470 | chars = list(token) 471 | if len(chars) > self.max_input_chars_per_word: 472 | output_tokens.append(self.unk_token) 473 | continue 474 | 475 | is_bad = False 476 | start = 0 477 | sub_tokens = [] 478 | while start < len(chars): 479 | end = len(chars) 480 | cur_substr = None 481 | while start < end: 482 | substr = "".join(chars[start:end]) 483 | if start > 0: 484 | substr = "##" + substr 485 | if substr in self.vocab: 486 | cur_substr = substr 487 | break 488 | end -= 1 489 | if cur_substr is None: 490 | is_bad = True 491 | break 492 | sub_tokens.append(cur_substr) 493 | start = end 494 | 495 | if is_bad: 496 | output_tokens.append(self.unk_token) 497 | else: 498 | output_tokens.extend(sub_tokens) 499 | return output_tokens ``` -------------------------------------------------------------------------------- /surya/detection/model/encoderdecoder.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | This is an implementation of efficientvit, with some modifications (decode head, etc). 
3 | 4 | Original paper at https://arxiv.org/abs/2205.14756 5 | 6 | Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py 7 | Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit 8 | """ 9 | 10 | from __future__ import annotations 11 | 12 | from typing import Optional, Union, Tuple, List, Any 13 | from functools import partial 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from transformers.modeling_outputs import SemanticSegmenterOutput 20 | 21 | from surya.common.pretrained import SuryaPreTrainedModel 22 | from surya.common.s3 import S3DownloaderMixin 23 | from surya.detection.model.config import EfficientViTConfig 24 | 25 | 26 | def val2list(x: Union[List, Tuple, Any], repeat_time=1): 27 | if isinstance(x, (list, tuple)): 28 | return list(x) 29 | return [x for _ in range(repeat_time)] 30 | 31 | 32 | def val2tuple(x: Union[List, Tuple, Any], min_len: int = 1, idx_repeat: int = -1): 33 | # repeat elements if necessary 34 | x = val2list(x) 35 | if len(x) > 0: 36 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] 37 | 38 | return tuple(x) 39 | 40 | 41 | def get_same_padding( 42 | kernel_size: Union[int, Tuple[int, ...]], 43 | ) -> Union[int, Tuple[int, ...]]: 44 | if isinstance(kernel_size, tuple): 45 | return tuple([get_same_padding(ks) for ks in kernel_size]) 46 | else: 47 | assert kernel_size % 2 > 0, "kernel size should be odd number" 48 | return kernel_size // 2 49 | 50 | 51 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int: 52 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 53 | return padding 54 | 55 | 56 | class ConvNormAct(nn.Module): 57 | def __init__( 58 | self, 59 | in_channels: int, 60 | out_channels: int, 61 | kernel_size=3, 62 | stride=1, 63 | dilation=1, 64 | groups=1, 65 | bias=False, 66 | dropout=0.0, 67 | norm_layer=nn.BatchNorm2d, 68 | act_layer=nn.ReLU, 69 | ): 70 | super(ConvNormAct, self).__init__() 71 | self.dropout = nn.Dropout(dropout, inplace=False) 72 | padding = get_padding(kernel_size, stride, dilation) 73 | self.conv = nn.Conv2d( 74 | in_channels, 75 | out_channels, 76 | kernel_size=kernel_size, 77 | stride=stride, 78 | dilation=dilation, 79 | groups=groups, 80 | bias=bias, 81 | padding=padding, 82 | ) 83 | self.norm = ( 84 | norm_layer(num_features=out_channels) if norm_layer else nn.Identity() 85 | ) 86 | self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity() 87 | 88 | def forward(self, x): 89 | x = self.conv(x) 90 | x = self.norm(x) 91 | x = self.act(x) 92 | return x 93 | 94 | 95 | class DSConv(nn.Module): 96 | def __init__( 97 | self, 98 | in_channels: int, 99 | out_channels: int, 100 | kernel_size=3, 101 | stride=1, 102 | use_bias=False, 103 | norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), 104 | act_layer=(nn.ReLU6, None), 105 | ): 106 | super(DSConv, self).__init__() 107 | use_bias = val2tuple(use_bias, 2) 108 | norm_layer = val2tuple(norm_layer, 2) 109 | act_layer = val2tuple(act_layer, 2) 110 | 111 | self.depth_conv = ConvNormAct( 112 | in_channels, 113 | in_channels, 114 | kernel_size, 115 | stride, 116 | groups=in_channels, 117 | norm_layer=norm_layer[0], 118 | act_layer=act_layer[0], 119 | bias=use_bias[0], 120 | ) 121 | self.point_conv = ConvNormAct( 122 | in_channels, 123 | out_channels, 124 | 1, 125 | norm_layer=norm_layer[1], 126 | act_layer=act_layer[1], 127 | bias=use_bias[1], 128 | ) 129 | 130 | def 
forward(self, x): 131 | x = self.depth_conv(x) 132 | x = self.point_conv(x) 133 | return x 134 | 135 | 136 | class ConvBlock(nn.Module): 137 | def __init__( 138 | self, 139 | in_channels: int, 140 | out_channels: int, 141 | kernel_size=3, 142 | stride=1, 143 | mid_channels=None, 144 | expand_ratio=1, 145 | use_bias=False, 146 | norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), 147 | act_layer=(nn.ReLU6, None), 148 | ): 149 | super(ConvBlock, self).__init__() 150 | use_bias = val2tuple(use_bias, 2) 151 | norm_layer = val2tuple(norm_layer, 2) 152 | act_layer = val2tuple(act_layer, 2) 153 | mid_channels = mid_channels or round(in_channels * expand_ratio) 154 | 155 | self.conv1 = ConvNormAct( 156 | in_channels, 157 | mid_channels, 158 | kernel_size, 159 | stride, 160 | norm_layer=norm_layer[0], 161 | act_layer=act_layer[0], 162 | bias=use_bias[0], 163 | ) 164 | self.conv2 = ConvNormAct( 165 | mid_channels, 166 | out_channels, 167 | kernel_size, 168 | 1, 169 | norm_layer=norm_layer[1], 170 | act_layer=act_layer[1], 171 | bias=use_bias[1], 172 | ) 173 | 174 | def forward(self, x): 175 | x = self.conv1(x) 176 | x = self.conv2(x) 177 | return x 178 | 179 | 180 | class MBConv(nn.Module): 181 | def __init__( 182 | self, 183 | in_channels: int, 184 | out_channels: int, 185 | kernel_size=3, 186 | stride=1, 187 | mid_channels=None, 188 | expand_ratio=6, 189 | use_bias=False, 190 | norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d), 191 | act_layer=(nn.ReLU6, nn.ReLU6, None), 192 | ): 193 | super(MBConv, self).__init__() 194 | use_bias = val2tuple(use_bias, 3) 195 | norm_layer = val2tuple(norm_layer, 3) 196 | act_layer = val2tuple(act_layer, 3) 197 | mid_channels = mid_channels or round(in_channels * expand_ratio) 198 | 199 | self.inverted_conv = ConvNormAct( 200 | in_channels, 201 | mid_channels, 202 | 1, 203 | stride=1, 204 | norm_layer=norm_layer[0], 205 | act_layer=act_layer[0], 206 | bias=use_bias[0], 207 | ) 208 | self.depth_conv = ConvNormAct( 209 | mid_channels, 210 | mid_channels, 211 | kernel_size, 212 | stride=stride, 213 | groups=mid_channels, 214 | norm_layer=norm_layer[1], 215 | act_layer=act_layer[1], 216 | bias=use_bias[1], 217 | ) 218 | self.point_conv = ConvNormAct( 219 | mid_channels, 220 | out_channels, 221 | 1, 222 | norm_layer=norm_layer[2], 223 | act_layer=act_layer[2], 224 | bias=use_bias[2], 225 | ) 226 | 227 | def forward(self, x): 228 | x = self.inverted_conv(x) 229 | x = self.depth_conv(x) 230 | x = self.point_conv(x) 231 | return x 232 | 233 | 234 | class FusedMBConv(nn.Module): 235 | def __init__( 236 | self, 237 | in_channels: int, 238 | out_channels: int, 239 | kernel_size=3, 240 | stride=1, 241 | mid_channels=None, 242 | expand_ratio=6, 243 | groups=1, 244 | use_bias=False, 245 | norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), 246 | act_layer=(nn.ReLU6, None), 247 | ): 248 | super(FusedMBConv, self).__init__() 249 | use_bias = val2tuple(use_bias, 2) 250 | norm_layer = val2tuple(norm_layer, 2) 251 | act_layer = val2tuple(act_layer, 2) 252 | mid_channels = mid_channels or round(in_channels * expand_ratio) 253 | 254 | self.spatial_conv = ConvNormAct( 255 | in_channels, 256 | mid_channels, 257 | kernel_size, 258 | stride=stride, 259 | groups=groups, 260 | norm_layer=norm_layer[0], 261 | act_layer=act_layer[0], 262 | bias=use_bias[0], 263 | ) 264 | self.point_conv = ConvNormAct( 265 | mid_channels, 266 | out_channels, 267 | 1, 268 | norm_layer=norm_layer[1], 269 | act_layer=act_layer[1], 270 | bias=use_bias[1], 271 | ) 272 | 273 | def forward(self, x): 274 | x = 
self.spatial_conv(x) 275 | x = self.point_conv(x) 276 | return x 277 | 278 | 279 | class LiteMLA(nn.Module): 280 | """Lightweight multi-scale linear attention""" 281 | 282 | def __init__( 283 | self, 284 | in_channels: int, 285 | out_channels: int, 286 | heads: Union[int, None] = None, 287 | heads_ratio: float = 1.0, 288 | dim=8, 289 | use_bias=False, 290 | norm_layer=(None, nn.BatchNorm2d), 291 | act_layer=(None, None), 292 | kernel_func=nn.ReLU, 293 | scales=(5,), 294 | eps=1e-5, 295 | ): 296 | super(LiteMLA, self).__init__() 297 | self.eps = eps 298 | heads = heads or int(in_channels // dim * heads_ratio) 299 | total_dim = heads * dim 300 | use_bias = val2tuple(use_bias, 2) 301 | norm_layer = val2tuple(norm_layer, 2) 302 | act_layer = val2tuple(act_layer, 2) 303 | 304 | self.dim = dim 305 | self.qkv = ConvNormAct( 306 | in_channels, 307 | 3 * total_dim, 308 | 1, 309 | bias=use_bias[0], 310 | norm_layer=norm_layer[0], 311 | act_layer=act_layer[0], 312 | ) 313 | self.aggreg = nn.ModuleList( 314 | [ 315 | nn.Sequential( 316 | nn.Conv2d( 317 | 3 * total_dim, 318 | 3 * total_dim, 319 | scale, 320 | padding=get_same_padding(scale), 321 | groups=3 * total_dim, 322 | bias=use_bias[0], 323 | ), 324 | nn.Conv2d( 325 | 3 * total_dim, 326 | 3 * total_dim, 327 | 1, 328 | groups=3 * heads, 329 | bias=use_bias[0], 330 | ), 331 | ) 332 | for scale in scales 333 | ] 334 | ) 335 | self.kernel_func = kernel_func(inplace=False) 336 | 337 | self.proj = ConvNormAct( 338 | total_dim * (1 + len(scales)), 339 | out_channels, 340 | 1, 341 | bias=use_bias[1], 342 | norm_layer=norm_layer[1], 343 | act_layer=act_layer[1], 344 | ) 345 | 346 | def _attn(self, q, k, v): 347 | dtype = v.dtype 348 | q, k, v = q.float(), k.float(), v.float() 349 | kv = k.transpose(-1, -2) @ v 350 | out = q @ kv 351 | out = out[..., :-1] / (out[..., -1:] + self.eps) 352 | return out.to(dtype) 353 | 354 | def forward(self, x): 355 | # Shape is B, C, H, W 356 | B, _, H, W = x.shape 357 | 358 | # generate multi-scale q, k, v 359 | qkv = self.qkv(x) 360 | multi_scale_qkv = [qkv] 361 | for op in self.aggreg: 362 | multi_scale_qkv.append(op(qkv)) 363 | multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1) 364 | multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose( 365 | -1, -2 366 | ) 367 | # Shape for each is B, C, HW, head_dim 368 | q, k, v = multi_scale_qkv.chunk(3, dim=-1) 369 | 370 | # lightweight global attention 371 | q = self.kernel_func(q) 372 | k = self.kernel_func(k) 373 | v = F.pad(v, (0, 1), mode="constant", value=1.0) 374 | 375 | out = self._attn(q, k, v) 376 | 377 | # final projection 378 | out = out.transpose(-1, -2).reshape(B, -1, H, W) 379 | out = self.proj(out) 380 | return out 381 | 382 | 383 | class EfficientVitBlock(nn.Module): 384 | def __init__( 385 | self, 386 | in_channels, 387 | heads_ratio=1.0, 388 | head_dim=32, 389 | expand_ratio=4, 390 | norm_layer=nn.BatchNorm2d, 391 | act_layer=nn.Hardswish, 392 | ): 393 | super(EfficientVitBlock, self).__init__() 394 | self.context_module = ResidualBlock( 395 | LiteMLA( 396 | in_channels=in_channels, 397 | out_channels=in_channels, 398 | heads_ratio=heads_ratio, 399 | dim=head_dim, 400 | norm_layer=(None, norm_layer), 401 | ), 402 | nn.Identity(), 403 | ) 404 | self.local_module = ResidualBlock( 405 | MBConv( 406 | in_channels=in_channels, 407 | out_channels=in_channels, 408 | expand_ratio=expand_ratio, 409 | use_bias=(True, True, False), 410 | norm_layer=(None, None, norm_layer), 411 | act_layer=(act_layer, act_layer, None), 412 | ), 413 | 
nn.Identity(), 414 | ) 415 | 416 | def forward(self, x): 417 | x = self.context_module(x) 418 | x = self.local_module(x) 419 | return x 420 | 421 | 422 | class ResidualBlock(nn.Module): 423 | def __init__( 424 | self, 425 | main: Optional[nn.Module], 426 | shortcut: Optional[nn.Module] = None, 427 | pre_norm: Optional[nn.Module] = None, 428 | ): 429 | super(ResidualBlock, self).__init__() 430 | self.pre_norm = pre_norm if pre_norm is not None else nn.Identity() 431 | self.main = main 432 | self.shortcut = shortcut 433 | 434 | def forward(self, x): 435 | res = self.main(self.pre_norm(x)) 436 | if self.shortcut is not None: 437 | res = res + self.shortcut(x) 438 | return res 439 | 440 | 441 | def build_local_block( 442 | in_channels: int, 443 | out_channels: int, 444 | stride: int, 445 | kernel_size: int, 446 | expand_ratio: float, 447 | norm_layer: str, 448 | act_layer: str, 449 | fewer_norm: bool = False, 450 | block_type: str = "default", 451 | ): 452 | assert block_type in ["default", "large", "fused"] 453 | if expand_ratio == 1: 454 | if block_type == "default": 455 | block = DSConv( 456 | in_channels=in_channels, 457 | out_channels=out_channels, 458 | stride=stride, 459 | kernel_size=kernel_size, 460 | use_bias=(True, False) if fewer_norm else False, 461 | norm_layer=(None, norm_layer) if fewer_norm else norm_layer, 462 | act_layer=(act_layer, None), 463 | ) 464 | else: 465 | block = ConvBlock( 466 | in_channels=in_channels, 467 | out_channels=out_channels, 468 | stride=stride, 469 | kernel_size=kernel_size, 470 | use_bias=(True, False) if fewer_norm else False, 471 | norm_layer=(None, norm_layer) if fewer_norm else norm_layer, 472 | act_layer=(act_layer, None), 473 | ) 474 | else: 475 | if block_type == "default": 476 | block = MBConv( 477 | in_channels=in_channels, 478 | out_channels=out_channels, 479 | stride=stride, 480 | kernel_size=kernel_size, 481 | expand_ratio=expand_ratio, 482 | use_bias=(True, True, False) if fewer_norm else False, 483 | norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer, 484 | act_layer=(act_layer, act_layer, None), 485 | ) 486 | else: 487 | block = FusedMBConv( 488 | in_channels=in_channels, 489 | out_channels=out_channels, 490 | stride=stride, 491 | kernel_size=kernel_size, 492 | expand_ratio=expand_ratio, 493 | use_bias=(True, False) if fewer_norm else False, 494 | norm_layer=(None, norm_layer) if fewer_norm else norm_layer, 495 | act_layer=(act_layer, None), 496 | ) 497 | return block 498 | 499 | 500 | class Stem(nn.Sequential): 501 | def __init__( 502 | self, 503 | in_chs, 504 | out_chs, 505 | depth, 506 | stride, 507 | norm_layer, 508 | act_layer, 509 | block_type="default", 510 | ): 511 | super().__init__() 512 | self.stride = stride 513 | 514 | self.add_module( 515 | "in_conv", 516 | ConvNormAct( 517 | in_chs, 518 | out_chs, 519 | kernel_size=stride + 1, 520 | stride=stride, 521 | norm_layer=norm_layer, 522 | act_layer=act_layer, 523 | ), 524 | ) 525 | stem_block = 0 526 | for _ in range(depth): 527 | self.add_module( 528 | f"res{stem_block}", 529 | ResidualBlock( 530 | build_local_block( 531 | in_channels=out_chs, 532 | out_channels=out_chs, 533 | stride=1, 534 | kernel_size=3, 535 | expand_ratio=1, 536 | norm_layer=norm_layer, 537 | act_layer=act_layer, 538 | block_type=block_type, 539 | ), 540 | nn.Identity(), 541 | ), 542 | ) 543 | stem_block += 1 544 | 545 | 546 | class EfficientVitLargeStage(nn.Module): 547 | def __init__( 548 | self, 549 | in_chs, 550 | out_chs, 551 | depth, 552 | stride, 553 | norm_layer, 554 | act_layer, 555 
| head_dim, 556 | vit_stage=False, 557 | fewer_norm=False, 558 | ): 559 | super(EfficientVitLargeStage, self).__init__() 560 | blocks = [ 561 | ResidualBlock( 562 | build_local_block( 563 | in_channels=in_chs, 564 | out_channels=out_chs, 565 | stride=stride, 566 | kernel_size=stride + 1, 567 | expand_ratio=24 if vit_stage else 16, 568 | norm_layer=norm_layer, 569 | act_layer=act_layer, 570 | fewer_norm=vit_stage or fewer_norm, 571 | block_type="default" if fewer_norm else "fused", 572 | ), 573 | None, 574 | ) 575 | ] 576 | in_chs = out_chs 577 | 578 | if vit_stage: 579 | # for stage 4 580 | for _ in range(depth): 581 | blocks.append( 582 | EfficientVitBlock( 583 | in_channels=in_chs, 584 | head_dim=head_dim, 585 | expand_ratio=6, 586 | norm_layer=norm_layer, 587 | act_layer=act_layer, 588 | ) 589 | ) 590 | else: 591 | # for stage 1, 2, 3 592 | for i in range(depth): 593 | blocks.append( 594 | ResidualBlock( 595 | build_local_block( 596 | in_channels=in_chs, 597 | out_channels=out_chs, 598 | stride=1, 599 | kernel_size=3, 600 | expand_ratio=4, 601 | norm_layer=norm_layer, 602 | act_layer=act_layer, 603 | fewer_norm=fewer_norm, 604 | block_type="default" if fewer_norm else "fused", 605 | ), 606 | nn.Identity(), 607 | ) 608 | ) 609 | 610 | self.blocks = nn.Sequential(*blocks) 611 | 612 | def forward(self, x): 613 | return self.blocks(x) 614 | 615 | 616 | class EfficientVitLarge(nn.Module): 617 | def __init__( 618 | self, 619 | config: EfficientViTConfig, 620 | norm_layer=nn.BatchNorm2d, 621 | act_layer=nn.Hardswish, 622 | ): 623 | super(EfficientVitLarge, self).__init__() 624 | self.grad_checkpointing = False 625 | self.num_classes = config.num_classes 626 | self.norm_eps = config.layer_norm_eps 627 | norm_layer = partial(norm_layer, eps=self.norm_eps) 628 | 629 | # input stem 630 | self.stem = Stem( 631 | config.num_channels, 632 | config.widths[0], 633 | config.depths[0], 634 | config.strides[0], 635 | norm_layer, 636 | act_layer, 637 | block_type="large", 638 | ) 639 | stride = config.strides[0] 640 | 641 | # stages 642 | self.feature_info = [] 643 | self.stages = nn.Sequential() 644 | in_channels = config.widths[0] 645 | for i, (w, d, s) in enumerate( 646 | zip(config.widths[1:], config.depths[1:], config.strides[1:]) 647 | ): 648 | self.stages.append( 649 | EfficientVitLargeStage( 650 | in_channels, 651 | w, 652 | depth=d, 653 | stride=s, 654 | norm_layer=norm_layer, 655 | act_layer=act_layer, 656 | head_dim=config.head_dim, 657 | vit_stage=i >= 3, 658 | fewer_norm=i >= 2, 659 | ) 660 | ) 661 | stride *= s 662 | in_channels = w 663 | self.feature_info += [ 664 | dict(num_chs=in_channels, reduction=stride, module=f"stages.{i}") 665 | ] 666 | 667 | self.num_features = in_channels 668 | 669 | @torch.jit.ignore 670 | def set_grad_checkpointing(self, enable=True): 671 | self.grad_checkpointing = enable 672 | 673 | def forward(self, x): 674 | x = self.stem(x) 675 | encoder_hidden_states = [] 676 | for i, module in enumerate(self.stages): 677 | x = module(x) 678 | encoder_hidden_states.append(x) 679 | 680 | return encoder_hidden_states 681 | 682 | 683 | class EfficientViTPreTrainedModel(SuryaPreTrainedModel): 684 | """ 685 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 686 | models. 
687 | """ 688 | 689 | config_class = EfficientViTConfig 690 | base_model_prefix = "efficientvit" 691 | main_input_name = "pixel_values" 692 | 693 | def _init_weights(self, module): 694 | """Initialize the weights""" 695 | if isinstance(module, (nn.Linear, nn.Conv2d)): 696 | # Slightly different from the TF version which uses truncated_normal for initialization 697 | # cf https://github.com/pytorch/pytorch/pull/5617 698 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 699 | if module.bias is not None: 700 | module.bias.data.zero_() 701 | elif isinstance(module, nn.Embedding): 702 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 703 | if module.padding_idx is not None: 704 | module.weight.data[module.padding_idx].zero_() 705 | elif isinstance(module, nn.LayerNorm): 706 | module.bias.data.zero_() 707 | module.weight.data.fill_(1.0) 708 | 709 | 710 | class DecodeMLP(nn.Module): 711 | def __init__(self, input_dim, output_dim): 712 | super().__init__() 713 | self.proj = nn.Linear(input_dim, output_dim) 714 | 715 | def forward(self, hidden_states: torch.Tensor): 716 | # Input is B, C, H, W 717 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 718 | # Output is B, HW, C 719 | hidden_states = self.proj(hidden_states) 720 | return hidden_states 721 | 722 | 723 | class DecodeHead(EfficientViTPreTrainedModel): 724 | def __init__(self, config: EfficientViTConfig): 725 | super().__init__(config) 726 | 727 | # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size 728 | mlps = [] 729 | for width in config.widths[1:]: 730 | mlp = DecodeMLP( 731 | input_dim=width, output_dim=config.decoder_layer_hidden_size 732 | ) 733 | mlps.append(mlp) 734 | self.linear_c = nn.ModuleList(mlps) 735 | 736 | # the following 3 layers implement the ConvModule of the original implementation 737 | self.linear_fuse = nn.Conv2d( 738 | in_channels=config.decoder_layer_hidden_size * config.num_stages, 739 | out_channels=config.decoder_hidden_size, 740 | kernel_size=1, 741 | bias=False, 742 | ) 743 | self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size) 744 | self.activation = nn.ReLU() 745 | 746 | self.dropout = nn.Dropout(config.classifier_dropout_prob) 747 | self.classifier = nn.Conv2d( 748 | config.decoder_hidden_size, config.num_labels, kernel_size=1 749 | ) 750 | 751 | self.config = config 752 | 753 | def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor: 754 | batch_size = encoder_hidden_states[-1].shape[0] 755 | 756 | all_hidden_states = () 757 | for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c): 758 | height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3] 759 | encoder_hidden_state = mlp(encoder_hidden_state) # Output is B, HW, C 760 | # Permute to B, C, HW 761 | encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1) 762 | encoder_hidden_state = encoder_hidden_state.reshape( 763 | batch_size, -1, height, width 764 | ) 765 | # upsample 766 | encoder_hidden_state = nn.functional.interpolate( 767 | encoder_hidden_state, 768 | size=encoder_hidden_states[0].size()[2:], 769 | mode="bilinear", 770 | align_corners=False, 771 | ) 772 | all_hidden_states += (encoder_hidden_state,) 773 | 774 | hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1)) 775 | hidden_states = self.batch_norm(hidden_states) 776 | hidden_states = self.activation(hidden_states) 777 | 778 | # logits are of shape (batch_size, num_labels, 
height/4, width/4) 779 | logits = self.classifier(hidden_states) 780 | 781 | return logits 782 | 783 | 784 | class EfficientViTForSemanticSegmentation( 785 | S3DownloaderMixin, EfficientViTPreTrainedModel 786 | ): 787 | def __init__(self, config, **kwargs): 788 | super().__init__(config) 789 | self.vit = EfficientVitLarge(config) 790 | self.decode_head = DecodeHead(config) 791 | 792 | # Initialize weights and apply final processing 793 | self.post_init() 794 | 795 | def forward( 796 | self, pixel_values: torch.FloatTensor 797 | ) -> Union[Tuple, SemanticSegmenterOutput]: 798 | # Pixel values should be B,C,H,W 799 | encoder_hidden_states = self.vit( 800 | pixel_values, 801 | ) 802 | 803 | logits = self.decode_head(encoder_hidden_states) 804 | 805 | # Apply sigmoid to get 0-1 output 806 | logits = torch.special.expit(logits) 807 | 808 | return SemanticSegmenterOutput( 809 | loss=None, logits=logits, hidden_states=encoder_hidden_states 810 | ) 811 | 812 | 813 | class EfficientViTForSemanticLayoutSegmentation(EfficientViTPreTrainedModel): 814 | def __init__(self, config, **kwargs): 815 | super().__init__(config, **kwargs) 816 | self.vit = EfficientVitLarge(config) 817 | self.decode_head = DecodeHead(config) 818 | 819 | # Initialize weights and apply final processing 820 | self.post_init() 821 | 822 | def forward( 823 | self, pixel_values: torch.FloatTensor 824 | ) -> Union[Tuple, SemanticSegmenterOutput]: 825 | # Pixel values should be B,C,H,W 826 | encoder_hidden_states = self.vit( 827 | pixel_values, 828 | ) 829 | 830 | logits = self.decode_head(encoder_hidden_states) 831 | 832 | # Apply sigmoid to get 0-1 output 833 | logits = torch.special.expit(logits) 834 | 835 | return SemanticSegmenterOutput( 836 | loss=None, logits=logits, hidden_states=encoder_hidden_states 837 | ) 838 | ``` -------------------------------------------------------------------------------- /surya/common/surya/processor/tokenizer.py: -------------------------------------------------------------------------------- ```python 1 | import html 2 | import re 3 | from typing import List, Union, Dict, Optional, Tuple, Iterable 4 | import numpy as np 5 | import torch 6 | from tokenizers import AddedToken 7 | import json 8 | import os 9 | from transformers import PreTrainedTokenizer, Qwen2Tokenizer as Qwen2OriginalTokenizer 10 | 11 | 12 | from surya.common.s3 import S3DownloaderMixin 13 | from surya.common.surya.schema import TASK_NAMES, TaskNames 14 | from surya.logging import get_logger 15 | from surya.settings import settings 16 | 17 | logger = get_logger() 18 | 19 | 20 | def create_token_regex(tokens): 21 | escaped_tokens = [re.escape(token) for token in tokens] 22 | escaped_tokens.sort(key=len, reverse=True) 23 | pattern = r"^(" + "|".join(escaped_tokens) + r")" 24 | regex = re.compile(pattern) 25 | return regex 26 | 27 | 28 | class InnerOCRTokenizer: 29 | def __init__( 30 | self, 31 | special_tokens: Dict[str, list] | None = None, 32 | qwen_tokenizer: Qwen2OriginalTokenizer | None = None, 33 | **kwargs, 34 | ): 35 | self.qwen_tokenizer = qwen_tokenizer 36 | self.qwen_token_offset = len(qwen_tokenizer) 37 | 38 | all_special_tokens = special_tokens.get("all", []) 39 | self.SPECIAL_TOKEN_MAPPING = {} 40 | 41 | idx = 0 42 | for tag in all_special_tokens: 43 | if tag in self.SPECIAL_TOKEN_MAPPING: 44 | continue 45 | self.SPECIAL_TOKEN_MAPPING[tag] = ( 46 | idx + self.qwen_token_offset 47 | ) # Assign token ID 48 | idx += 1 49 | 50 | self.REVERSE_SPECIAL_TOKEN_MAPPING = { 51 | v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items() 
52 | } 53 | self.SPECIAL_TOKEN_OFFSET = idx 54 | self.FORMAT_TAG_PATTERN = create_token_regex(special_tokens["formatting"]) 55 | self.MATH_TAG_PATTERN = create_token_regex(special_tokens["math_external"]) 56 | self.LAYOUT_TAG_PATTERN = create_token_regex(special_tokens["layout"]) 57 | self.TABLE_STRUCTURE_TAG_PATTERN = create_token_regex( 58 | special_tokens["table_structure"] 59 | ) 60 | self.SYSTEM_TAG_PATTERN = create_token_regex(special_tokens.get("system", [])) 61 | if not special_tokens.get("system", []): 62 | logger.warning("Warning: No system tokens found in special_tokens") 63 | 64 | self.MATH_TAG_START = "<math" 65 | self.MATH_END_TAG = "</math>" 66 | 67 | super().__init__(**kwargs) 68 | 69 | @property 70 | def vocab_size(self): 71 | return ( 72 | 65536 + self.SPECIAL_TOKEN_OFFSET 73 | ) # The highest codepoint is 65535, but we add 1 to account for the 0-indexing 74 | 75 | def _tokenize(self, text: str) -> List[int]: 76 | tokens = [] 77 | in_math = False 78 | text = html.unescape(text) # Unescape html entities like < in equations 79 | while text: 80 | # Look for EOS, PAD, etc. tokens 81 | match = self.SYSTEM_TAG_PATTERN.search(text) 82 | if match: 83 | tag = match.group(1) 84 | tokens.append( 85 | self.SPECIAL_TOKEN_MAPPING[tag] 86 | ) # These are already offset 87 | text = text[match.end() :] 88 | continue 89 | 90 | # Look for layout tokens 91 | match = self.LAYOUT_TAG_PATTERN.search(text) 92 | if match: 93 | tag = match.group(1) 94 | tokens.append( 95 | self.SPECIAL_TOKEN_MAPPING[tag] 96 | ) # Layout tokens are already offset 97 | text = text[match.end() :] 98 | continue 99 | 100 | match = self.TABLE_STRUCTURE_TAG_PATTERN.search(text) 101 | if match: 102 | tag = match.group(1) 103 | tokens.append(self.SPECIAL_TOKEN_MAPPING[tag]) 104 | text = text[match.end() :] 105 | continue 106 | 107 | # Check for math tags 108 | match = self.MATH_TAG_PATTERN.search(text) 109 | if match: 110 | # We found a tag 111 | tag = match.group(1) 112 | if tag.startswith(self.MATH_TAG_START): 113 | in_math = True 114 | elif tag == self.MATH_END_TAG: 115 | in_math = False 116 | tokens.append( 117 | self.SPECIAL_TOKEN_MAPPING[tag] # Special tokens are already offset 118 | ) # Use special token ID 119 | text = text[match.end() :] 120 | continue 121 | 122 | # Tokenize math content with qwen2 tokenizer 123 | if in_math: 124 | # If we're in a math block, check to see if we have a special math tag in the text 125 | math_end_position = text.find(self.MATH_END_TAG) 126 | math_str = text[:math_end_position] # Gets the math content 127 | tokens += self.qwen_tokenizer(math_str)["input_ids"] 128 | text = text[math_end_position:] 129 | continue 130 | 131 | # Check for formatting tags 132 | match = self.FORMAT_TAG_PATTERN.search(text) 133 | if match: 134 | # We found a tag 135 | tag = match.group(1) 136 | tokens.append( 137 | self.SPECIAL_TOKEN_MAPPING[tag] # Special tokens are already offset 138 | ) # Use special token ID 139 | text = text[match.end() :] 140 | continue 141 | 142 | # General case, utf-16 tokenization 143 | utf_16_tokens = self.text_to_utf16_numbers(text[0]) 144 | tokens += [ 145 | t + self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset 146 | for t in utf_16_tokens 147 | ] 148 | text = text[1:] 149 | 150 | return tokens 151 | 152 | def text_to_utf16_numbers(self, text: str): 153 | """Converts text to UTF-16 encoded numbers.""" 154 | utf16_bytes = text.encode( 155 | "utf-16le" 156 | ) # Little-endian to simplify byte order handling 157 | numbers = [] 158 | 159 | for i in range(0, len(utf16_bytes), 2): 
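            # Illustrative example: "é".encode("utf-16le") is b"\xe9\x00", so the byte
            # pair (0xE9, 0x00) combines into the single code unit 0x00E9 below; characters
            # outside the BMP yield two code units (a surrogate pair).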
160 | # Combine two adjacent bytes into a single number 161 | number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8) 162 | numbers.append(number) 163 | 164 | return numbers 165 | 166 | def utf16_numbers_to_text(self, numbers): 167 | """Converts UTF-16 numbers back to text.""" 168 | byte_array = bytearray() 169 | for number in numbers: 170 | byte_array.append(number & 0xFF) # Lower byte 171 | byte_array.append((number >> 8) & 0xFF) # Upper byte 172 | 173 | try: 174 | text = byte_array.decode("utf-16le", errors="ignore") 175 | except Exception as e: 176 | logger.warning(f"Error decoding utf16: {e}") 177 | text = "" 178 | 179 | return text 180 | 181 | def __call__( 182 | self, texts: Union[str, List[str]], **kwargs 183 | ) -> Dict[str, List[List[int]]]: 184 | """Tokenizes text and returns input IDs.""" 185 | tokenized = [] 186 | 187 | if isinstance(texts, str): 188 | texts = [texts] 189 | 190 | for text in texts: 191 | tokens = self._tokenize(text) 192 | tokenized.append(tokens) 193 | 194 | return {"input_ids": tokenized} 195 | 196 | def decode(self, token_ids, **kwargs): 197 | """Decodes token IDs back to text.""" 198 | if isinstance(token_ids, (np.ndarray, torch.Tensor)): 199 | token_ids = token_ids.tolist() 200 | 201 | decoded_text = "" 202 | token_buffer = [] 203 | decode_qwen = [False] 204 | 205 | def decode_buffer(): 206 | nonlocal decoded_text, token_buffer, decode_qwen 207 | if token_buffer: 208 | if decode_qwen[0]: 209 | decoded_text += self.qwen_tokenizer.decode(token_buffer) 210 | else: 211 | token_buffer = [ 212 | t - self.SPECIAL_TOKEN_OFFSET - self.qwen_token_offset 213 | for t in token_buffer 214 | ] 215 | decoded_text += self.utf16_numbers_to_text(token_buffer) 216 | 217 | token_buffer = [] 218 | decode_qwen[0] = False 219 | 220 | for t in token_ids: 221 | if t < self.qwen_token_offset: 222 | # This is for math tags 223 | if token_buffer and token_buffer[-1] >= self.qwen_token_offset: 224 | decode_buffer() 225 | token_buffer.append(t) 226 | decode_qwen[0] = True 227 | elif t >= self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset: 228 | if token_buffer and token_buffer[-1] < self.qwen_token_offset: 229 | decode_buffer() 230 | token_buffer.append(t) # We shift this down later on 231 | decode_qwen[0] = False 232 | elif t in self.REVERSE_SPECIAL_TOKEN_MAPPING: 233 | decode_buffer() 234 | decoded_text += self.REVERSE_SPECIAL_TOKEN_MAPPING[t] 235 | decode_qwen[0] = False 236 | else: 237 | raise ValueError( 238 | f'Unexpected token value while decoding, got "{t}" in token_ids {token_ids}' 239 | ) 240 | 241 | # Detokenize remaining tokens 242 | decode_buffer() 243 | 244 | return decoded_text 245 | 246 | 247 | class Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer): 248 | pass 249 | 250 | class GreedyMathUTF16Tokenizer(S3DownloaderMixin, PreTrainedTokenizer): 251 | """ 252 | HuggingFace slow tokenizer implementing: 253 | - UTF-16 code units as the base [0..65535] 254 | - Math tokens as greedy-longest-match ids after UTF-16 255 | - Literal special tokens after math tokens 256 | Absolute ID layout: 257 | [0 .. 65535] : UTF-16 units 258 | [65536 .. 65536+M-1] : math tokens 259 | [65536+M .. 
65536+M+S-1] : special tokens 260 | """ 261 | 262 | vocab_files_names = { 263 | "vocab_file": "vocab_math.json", # {"\\frac": 0, "\\alpha": 1, ...} raw contiguous ids 0..M-1 264 | "specials_file": "specials.json", # [flat list for legacy] 265 | "specials_dict_file": "specials_dict.json", # category dict (preferred) 266 | } 267 | model_input_names = ["input_ids", "attention_mask"] 268 | is_fast = False 269 | 270 | # ---------- helpers ---------- 271 | @staticmethod 272 | def _to_utf16_units(s: str) -> List[int]: 273 | b = s.encode("utf-16le") 274 | return [int.from_bytes(b[i : i + 2], "little") for i in range(0, len(b), 2)] 275 | 276 | @staticmethod 277 | def _from_utf16_units(units: List[int]) -> str: 278 | b = bytearray() 279 | for u in units: 280 | b += int(u).to_bytes(2, "little") 281 | return b.decode("utf-16le", errors="ignore") 282 | 283 | class _TrieNode: 284 | __slots__ = ("child", "id", "leaf") 285 | 286 | def __init__(self): 287 | self.child: Dict[str, "GreedyMathUTF16Tokenizer._TrieNode"] = {} 288 | self.id: Optional[int] = None 289 | self.leaf: bool = False 290 | 291 | @classmethod 292 | def _build_trie( 293 | cls, token_to_id: Dict[str, int] 294 | ) -> "GreedyMathUTF16Tokenizer._TrieNode": 295 | root = cls._TrieNode() 296 | for tok, tid in token_to_id.items(): 297 | node = root 298 | for ch in tok: 299 | node = node.child.setdefault(ch, cls._TrieNode()) 300 | node.leaf = True 301 | node.id = tid 302 | return root 303 | 304 | @classmethod 305 | def _encode_math_greedy( 306 | cls, 307 | s: str, 308 | trie: "GreedyMathUTF16Tokenizer._TrieNode", 309 | math_base: int, 310 | debug: bool = False, 311 | ) -> List[int]: 312 | i, n = 0, len(s) 313 | out: List[int] = [] 314 | while i < n: 315 | node = trie 316 | j = i 317 | last_id = None 318 | last_j = i 319 | while j < n and (ch := s[j]) in node.child: 320 | node = node.child[ch] 321 | j += 1 322 | if node.leaf: 323 | last_id, last_j = node.id, j 324 | if last_id is not None: 325 | if debug: 326 | print(f"[MATH] matched {s[i:last_j]!r} -> {last_id}") 327 | out.append(math_base + last_id) 328 | i = last_j 329 | else: 330 | units = cls._to_utf16_units(s[i]) 331 | if debug: 332 | print(f"[MATH] fallback {s[i]!r} -> utf16 {units}") 333 | out.extend(units) 334 | i += 1 335 | return out 336 | 337 | # ---------- init ---------- 338 | def __init__( 339 | self, 340 | vocab_file: Optional[str] = None, 341 | specials_file: Optional[str] = None, 342 | specials_dict_file: Optional[str] = None, 343 | *, 344 | # You can also pass programmatically instead of files: 345 | math_vocab: Optional[Dict[str, int]] = None, 346 | special_tokens: Optional[List[str]] = None, 347 | special_tokens_dict: Optional[Dict[str, List[str]]] = None, 348 | debug: bool = False, 349 | # Standard HF special token kwargs: 350 | bos_token: Optional[str] = None, 351 | eos_token: Optional[str] = None, 352 | pad_token: Optional[str] = None, 353 | unk_token: Optional[str] = None, 354 | **kwargs, 355 | ): 356 | # Load math vocab 357 | if vocab_file and os.path.isfile(vocab_file): 358 | with open(vocab_file, "r", encoding="utf-8") as f: 359 | mv = json.load(f) 360 | else: 361 | mv = math_vocab or {} 362 | 363 | # Make math ids contiguous if needed 364 | if mv: 365 | max_id = max(mv.values()) 366 | if set(mv.values()) != set(range(max_id + 1)): 367 | items = sorted(mv.items(), key=lambda kv: kv[1]) 368 | mv = {tok: i for i, (tok, _) in enumerate(items)} 369 | 370 | # Load special tokens (prefer category dict; fallback to flat list or defaults) 371 | sp_dict = None 372 | if 
specials_dict_file and os.path.isfile(specials_dict_file): 373 | with open(specials_dict_file, "r", encoding="utf-8") as f: 374 | sp_dict = json.load(f) 375 | elif special_tokens_dict is not None: 376 | sp_dict = dict(special_tokens_dict) 377 | 378 | if sp_dict is None: 379 | # Legacy path: flat list from file or provided/default list 380 | if specials_file and os.path.isfile(specials_file): 381 | with open(specials_file, "r", encoding="utf-8") as f: 382 | sp_list_flat = json.load(f) 383 | else: 384 | sp_list_flat = special_tokens or SPECIAL_TOKENS 385 | sp_dict = {"all": list(sp_list_flat)} 386 | 387 | # Ensure "all" exists and is unique/preserved in order. 388 | if "all" not in sp_dict or not isinstance(sp_dict["all"], list): 389 | order = [ 390 | "system", 391 | "formatting", 392 | "math_external", 393 | "script", 394 | "layout", 395 | "reasoning", 396 | "table_structure", 397 | "reserved", 398 | ] 399 | seen = set() 400 | all_tokens: List[str] = [] 401 | for k in order: 402 | if k in sp_dict and isinstance(sp_dict[k], list): 403 | for t in sp_dict[k]: 404 | if t not in seen: 405 | all_tokens.append(t) 406 | seen.add(t) 407 | sp_dict["all"] = all_tokens 408 | 409 | # Keep a copy of categories (if present) for downstream processor logic. 410 | self.special_tokens = sp_dict 411 | sp_list = list(sp_dict.get("all", [])) 412 | # Regex list should favor longest-first to avoid partial matches. 413 | specials_for_regex = sorted(sp_list, key=len, reverse=True) 414 | 415 | self.debug = debug 416 | self.UTF16_SPACE = 65536 417 | self.math_token_to_rawid = dict(mv) # 0..M-1 418 | self.math_vocab_size = len(self.math_token_to_rawid) 419 | self.MATH_BASE = self.UTF16_SPACE 420 | self.SPECIAL_BASE = self.UTF16_SPACE + self.math_vocab_size 421 | 422 | # Maps 423 | self.math_absid_to_token = { 424 | self.MATH_BASE + rid: tok for tok, rid in self.math_token_to_rawid.items() 425 | } 426 | self.special_tokens_list = sp_list # ID assignment order 427 | self.special_to_absid = { 428 | tok: self.SPECIAL_BASE + i for i, tok in enumerate(self.special_tokens_list) 429 | } 430 | self.absid_to_special = {v: k for k, v in self.special_to_absid.items()} 431 | 432 | # Public attributes for legacy/processor: 433 | # All specials mapping (token -> absolute id) 434 | self.SPECIAL_TOKEN_MAPPING: Dict[str, int] = dict(self.special_to_absid) 435 | # Subset used heavily by processor for quick access 436 | self.reverse_special_token_mapping = { 437 | v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items() 438 | } 439 | self.LAYOUT_LABEL2ID = { 440 | k: v 441 | for k, v in self.SPECIAL_TOKEN_MAPPING.items() 442 | if k in self.special_tokens["layout"] 443 | } 444 | self.TABLE_STRUCTURE_LABEL2ID = { 445 | k: v 446 | for k, v in self.SPECIAL_TOKEN_MAPPING.items() 447 | if k in self.special_tokens["table_structure"] 448 | } 449 | if not self.special_tokens.get("system", []): 450 | print("Warning: No system tokens found in special_tokens") 451 | 452 | self.MATH_TAG_START = "<math" 453 | self.MATH_END_TAG = "</math>" 454 | 455 | sys_list = self.special_tokens.get("system", []) 456 | self.system_tokens: Dict[str, int] = { 457 | t: self.special_to_absid[t] for t in sys_list if t in self.special_to_absid 458 | } 459 | 460 | # Regex for literal specials 461 | self.specials_pattern = ( 462 | re.compile(r"(" + "|".join(re.escape(k) for k in specials_for_regex) + r")") 463 | if specials_for_regex 464 | else None 465 | ) 466 | 467 | # Trie for math greedy match 468 | self.trie = self._build_trie(self.math_token_to_rawid) 469 | 470 | # Tell HF 
about special tokens (metadata) 471 | kwargs.setdefault("bos_token", bos_token) 472 | kwargs.setdefault("eos_token", eos_token or "</S>") 473 | kwargs.setdefault("pad_token", pad_token or "<PAD>") 474 | kwargs.setdefault("unk_token", unk_token) 475 | 476 | super().__init__( 477 | vocab_file=vocab_file, 478 | specials_file=specials_file, 479 | specials_dict_file=specials_dict_file, 480 | **kwargs, 481 | ) 482 | 483 | # ---------- required HF surface ---------- 484 | @property 485 | def vocab_size(self) -> int: 486 | return self.UTF16_SPACE + self.math_vocab_size + len(self.special_tokens_list) 487 | 488 | def get_vocab(self) -> Dict[str, int]: 489 | # Compact vocab: just math+specials with ABSOLUTE ids. 490 | v = {tok: self.MATH_BASE + rid for tok, rid in self.math_token_to_rawid.items()} 491 | v.update(self.special_to_absid) 492 | return v 493 | 494 | def __len__(self) -> int: 495 | return self.vocab_size 496 | 497 | # Core encode/decode on ABSOLUTE ids 498 | def _encode_core(self, text: str) -> List[int]: 499 | text = html.unescape(text) 500 | ids: List[int] = [] 501 | in_math = False 502 | chunks = self.specials_pattern.split(text) if self.specials_pattern else [text] 503 | for chunk in chunks: 504 | if chunk in self.special_to_absid: 505 | ids.append(self.special_to_absid[chunk]) 506 | if chunk.startswith("<math"): 507 | in_math = True 508 | elif chunk.startswith("</math>"): 509 | in_math = False 510 | if self.debug: 511 | print(f"[TAG] {chunk!r} -> {self.special_to_absid[chunk]}") 512 | continue 513 | 514 | if in_math: 515 | ids.extend( 516 | self._encode_math_greedy( 517 | chunk, self.trie, self.MATH_BASE, debug=self.debug 518 | ) 519 | ) 520 | else: 521 | units = self._to_utf16_units(chunk) 522 | if self.debug and units: 523 | print( 524 | f"[TEXT] utf16 {chunk[:32]!r} -> {units[:8]}{'...' 
if len(units) > 8 else ''}" 525 | ) 526 | ids.extend(units) 527 | return ids 528 | 529 | def _decode_core(self, ids: Iterable[int]) -> str: 530 | out: List[str] = [] 531 | buf: List[int] = [] 532 | 533 | def flush(): 534 | if buf: 535 | out.append(self._from_utf16_units(buf)) 536 | buf.clear() 537 | 538 | for tid in ids: 539 | if tid >= self.MATH_BASE and tid < self.SPECIAL_BASE: 540 | flush() 541 | out.append(self.math_absid_to_token.get(tid, "")) 542 | elif tid >= self.SPECIAL_BASE: 543 | flush() 544 | out.append(self.absid_to_special.get(tid, "")) 545 | else: 546 | buf.append(int(tid)) 547 | flush() 548 | return "".join(out) 549 | 550 | # ---- Tokenizer interface ---- 551 | def _tokenize(self, text: str, **kwargs) -> List[str]: 552 | ids = self._encode_core(text) 553 | toks: List[str] = [] 554 | for i in ids: 555 | if i < self.MATH_BASE: 556 | toks.append(f"<U+{i:04X}>") 557 | elif i < self.SPECIAL_BASE: 558 | toks.append(self.math_absid_to_token.get(i, "<UNK_MATH>")) 559 | else: 560 | toks.append(self.absid_to_special.get(i, "<UNK_SPECIAL>")) 561 | return toks 562 | 563 | def _convert_token_to_id(self, token: str) -> int: 564 | if token.startswith("<U+") and token.endswith(">"): 565 | try: 566 | return int(token[3:-1], 16) # UTF-16 unit 567 | except Exception: 568 | return self.unk_token_id if self.unk_token_id is not None else 0 569 | # math or specials 570 | if token in self.math_token_to_rawid: 571 | return self.MATH_BASE + self.math_token_to_rawid[token] 572 | if token in self.special_to_absid: 573 | return self.special_to_absid[token] 574 | # rare path: single-char token -> its UTF-16 unit 575 | if len(token) == 1: 576 | u = self._to_utf16_units(token) 577 | if len(u) == 1: 578 | return u[0] 579 | return self.unk_token_id if self.unk_token_id is not None else 0 580 | 581 | def _convert_id_to_token(self, index: int) -> str: 582 | if index < self.MATH_BASE: 583 | return f"<U+{index:04X}>" 584 | if index < self.SPECIAL_BASE: 585 | return self.math_absid_to_token.get(index, "<UNK_MATH>") 586 | return self.absid_to_special.get(index, "<UNK_SPECIAL>") 587 | 588 | def convert_tokens_to_string(self, tokens: List[str]) -> str: 589 | ids = [self._convert_token_to_id(t) for t in tokens] 590 | return self._decode_core(ids) 591 | 592 | def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str: 593 | # Accept int, list, tuple, numpy, torch 594 | if hasattr(token_ids, "tolist"): 595 | token_ids = token_ids.tolist() 596 | elif isinstance(token_ids, int): 597 | token_ids = [token_ids] 598 | else: 599 | token_ids = list(token_ids) 600 | token_ids = [int(i) for i in token_ids] # normalize early 601 | 602 | if skip_special_tokens: 603 | token_ids = [i for i in token_ids if i < self.SPECIAL_BASE] 604 | return self._decode_core(token_ids) 605 | 606 | # HF plumbing 607 | def build_inputs_with_special_tokens( 608 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 609 | ) -> List[int]: 610 | out = ( 611 | list(token_ids_0) 612 | if token_ids_1 is None 613 | else list(token_ids_0) + list(token_ids_1) 614 | ) 615 | # if self.eos_token_id is not None and (not out or out[-1] != self.eos_token_id): 616 | # out.append(self.eos_token_id) 617 | return out 618 | 619 | def get_special_tokens_mask( 620 | self, 621 | token_ids_0: List[int], 622 | token_ids_1: Optional[List[int]] = None, 623 | already_has_special_tokens: bool = False, 624 | ) -> List[int]: 625 | def mask(seq: List[int]) -> List[int]: 626 | return [1 if i >= self.SPECIAL_BASE else 0 for i in seq] 627 | 628 | 
return ( 629 | mask(token_ids_0) 630 | if token_ids_1 is None 631 | else mask(token_ids_0) + mask(token_ids_1) 632 | ) 633 | 634 | def create_token_type_ids_from_sequences( 635 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 636 | ) -> List[int]: 637 | return [0] * ( 638 | len(token_ids_0) 639 | if token_ids_1 is None 640 | else len(token_ids_0) + len(token_ids_1) 641 | ) 642 | 643 | # Save/load raw assets 644 | def save_vocabulary( 645 | self, save_directory: str, filename_prefix: Optional[str] = None 646 | ) -> Tuple[str, str]: 647 | os.makedirs(save_directory, exist_ok=True) 648 | pre = (filename_prefix + "-") if filename_prefix else "" 649 | vocab_path = os.path.join( 650 | save_directory, pre + self.vocab_files_names["vocab_file"] 651 | ) 652 | specials_path = os.path.join( 653 | save_directory, pre + self.vocab_files_names["specials_file"] 654 | ) 655 | specials_dict_path = os.path.join( 656 | save_directory, pre + self.vocab_files_names["specials_dict_file"] 657 | ) 658 | with open(vocab_path, "w", encoding="utf-8") as f: 659 | json.dump(self.math_token_to_rawid, f, ensure_ascii=False, indent=2) 660 | # Save both the flat list ("all") and the category dict (preferred) 661 | with open(specials_path, "w", encoding="utf-8") as f: 662 | json.dump(self.special_tokens_list, f, ensure_ascii=False, indent=2) 663 | with open(specials_dict_path, "w", encoding="utf-8") as f: 664 | json.dump(self.special_tokens, f, ensure_ascii=False, indent=2) 665 | return (vocab_path, specials_path) 666 | 667 | 668 | class SuryaOCRTokenizer(S3DownloaderMixin, PreTrainedTokenizer): 669 | def __init__( 670 | self, 671 | special_tokens: Dict[str, list] | None = None, 672 | model_checkpoint: str = settings.FOUNDATION_MODEL_CHECKPOINT, 673 | **kwargs, 674 | ): 675 | if special_tokens is None: 676 | special_tokens = dict() 677 | 678 | self.special_tokens = special_tokens 679 | 680 | self.ocr_tokenizer = GreedyMathUTF16Tokenizer.from_pretrained( 681 | model_checkpoint, 682 | ) 683 | self.system_tokens = { 684 | v: self.ocr_tokenizer(v)["input_ids"][0] 685 | for v in special_tokens.get("system", []) 686 | } 687 | self.SPECIAL_TOKEN_MAPPING = self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING 688 | 689 | super().__init__(**kwargs) 690 | 691 | def get_vocab(self) -> Dict[str, int]: 692 | return self.ocr_tokenizer.get_vocab() 693 | 694 | def _add_tokens( 695 | self, 696 | new_tokens: Union[List[str], List[AddedToken]], 697 | special_tokens: bool = False, 698 | ) -> int: 699 | return self.ocr_tokenizer._add_tokens( 700 | new_tokens, special_tokens=special_tokens 701 | ) 702 | 703 | @property 704 | def vocab_size(self): 705 | return self.ocr_tokenizer.vocab_size 706 | 707 | def _tokenize(self, text: str, **kwargs): 708 | # task = kwargs.get("task", TaskNames.ocr_with_boxes) 709 | # assert task in TASK_NAMES, f"Invalid task: {task}" 710 | 711 | tokens = self.ocr_tokenizer(text)["input_ids"] 712 | 713 | return tokens 714 | 715 | def __call__( 716 | self, 717 | texts: Union[str, List[str]], 718 | tasks: Union[str, List[str]] = None, 719 | **kwargs, 720 | ) -> Dict[str, List[List[int]]]: 721 | """Tokenizes text and returns input IDs.""" 722 | tokenized = [] 723 | 724 | if isinstance(texts, str): 725 | texts = [texts] 726 | assert isinstance(tasks, str), "Tasks must be a string if texts is a string" 727 | tasks = [tasks] 728 | 729 | if isinstance(texts, list): 730 | assert isinstance(tasks, list), "Tasks must be a list if texts is a list" 731 | 732 | for text, task in zip(texts, tasks): 733 | tokens = 
self._tokenize(text, task=task) 734 | tokenized.append(tokens) 735 | 736 | return {"input_ids": tokenized} 737 | 738 | def decode(self, token_ids, **kwargs): 739 | if isinstance(token_ids, (np.ndarray, torch.Tensor)): 740 | token_ids = token_ids.tolist() 741 | 742 | decoded_text = self.ocr_tokenizer.decode(token_ids, skip_special_tokens=False) 743 | # replace all <SCRIPT-...> tokens with empty strings 744 | decoded_text = re.sub(r"<SCRIPT-.*?>", "", decoded_text) 745 | # replace </S> with empty string 746 | decoded_text = re.sub(r"</S>", "", decoded_text) 747 | return decoded_text 748 | ``` -------------------------------------------------------------------------------- /surya/common/surya/__init__.py: -------------------------------------------------------------------------------- ```python 1 | import warnings 2 | from typing import Optional, Tuple, TypedDict 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from transformers.modeling_outputs import CausalLMOutputWithPast 9 | from transformers.cache_utils import Cache 10 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 11 | 12 | from surya.common.pretrained import SuryaPreTrainedModel 13 | from surya.common.s3 import S3DownloaderMixin 14 | from surya.common.surya.config import SuryaModelConfig 15 | from surya.common.surya.decoder import SuryaDecoderModel 16 | from surya.common.surya.embedder import SimpleTokenEmbedder 17 | from surya.common.surya.encoder import SuryaEncoderModel 18 | from surya.common.util import pad_to_batch_size, pad_to_batch_size_repeat 19 | from surya.common.xla import get_nearest_pad 20 | from surya.settings import settings 21 | 22 | from surya.logging import get_logger 23 | 24 | logger = get_logger() 25 | 26 | 27 | @dataclass 28 | class SuryaModelOutput(CausalLMOutputWithPast): 29 | bbox_logits: torch.FloatTensor = None 30 | lm_logits: torch.FloatTensor = None 31 | 32 | 33 | class FlashAttentionKwargs(TypedDict, total=False): 34 | """ 35 | Keyword arguments for Flash Attention with Compile. 36 | 37 | Attributes: 38 | cu_seq_lens_q (`torch.LongTensor`, *optional*): 39 | Gets cumulative sequence length for query state. 40 | cu_seq_lens_k (`torch.LongTensor`, *optional*): 41 | Gets cumulative sequence length for key state. 42 | max_length_q (`int`, *optional*): 43 | Maximum sequence length for query state. 44 | max_length_k (`int`, *optional*): 45 | Maximum sequence length for key state. 46 | """ 47 | 48 | cu_seq_lens_q: Optional[torch.LongTensor] 49 | cu_seq_lens_k: Optional[torch.LongTensor] 50 | max_length_q: Optional[int] 51 | max_length_k: Optional[int] 52 | 53 | 54 | class KwargsForCausalLM(FlashAttentionKwargs): ...
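# Annotation (illustrative sketch, not part of the original file): the cu_seq_lens_* /
# max_length_* kwargs above follow the flash-attention "varlen" packing convention.
# For three packed sequences of lengths 3, 5 and 2 (example values, not from the repo),
# the cumulative-length tensor and max length would look like:
#   cu_seq_lens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)  # prefix sums of lengths
#   max_length = 5                                                # longest packed sequence
# SuryaModel.forward (later in this file) computes the key-side values via
# _get_unpad_data(attention_mask) when the decoder uses flash_attention_2 during prefill.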
55 | 56 | 57 | class DistanceProjection(nn.Module): 58 | def __init__(self, in_features: int, out_features: int): 59 | super().__init__() 60 | self.fc1 = nn.Linear(in_features, out_features) 61 | self.act = nn.SiLU() 62 | self.fc2 = nn.Linear(out_features, out_features) 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | x = self.fc1(x) 66 | x = self.act(x) 67 | x = self.fc2(x) 68 | return x 69 | 70 | def init_weights(self): 71 | nn.init.xavier_uniform_(self.fc1.weight) 72 | nn.init.xavier_uniform_(self.fc2.weight) 73 | nn.init.zeros_(self.fc1.bias) 74 | nn.init.zeros_(self.fc2.bias) 75 | 76 | 77 | class BboxHead(nn.Module): 78 | def __init__(self, in_features: int, out_features: int): 79 | super().__init__() 80 | self.proj_layers = nn.ModuleList( 81 | [nn.Linear(in_features, in_features) for _ in range(6)] 82 | ) 83 | self.act = nn.SiLU() 84 | self.out_proj = nn.Linear(in_features, out_features) 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | for layer in self.proj_layers: 88 | x = layer(x) 89 | x = self.act(x) 90 | 91 | x = self.out_proj(x) 92 | return x 93 | 94 | 95 | class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel): 96 | config_class = SuryaModelConfig 97 | supports_gradient_checkpointing = True 98 | _skip_keys_device_placement = ["past_key_values"] 99 | _supports_flash_attn_2 = True 100 | _supports_sdpa = True 101 | _supports_flex_attn = True 102 | _supports_cache_class = True 103 | _supports_quantized_cache = True 104 | _supports_static_cache = True 105 | _supports_attention_backend = True 106 | main_input_name = "input_ids" 107 | _tied_weights_keys = ["lm_head.weight"] 108 | 109 | def __init__( 110 | self, 111 | config: SuryaModelConfig, 112 | embedder: SimpleTokenEmbedder = None, 113 | vision_encoder: SuryaEncoderModel = None, 114 | decoder: SuryaDecoderModel = None, 115 | **kwargs, 116 | ): 117 | super().__init__(config, **kwargs) 118 | 119 | if vision_encoder is None: 120 | vision_encoder = SuryaEncoderModel(config.vision_encoder) 121 | 122 | if decoder is None: 123 | decoder = SuryaDecoderModel(config.decoder) 124 | 125 | if embedder is None: 126 | embedder = SimpleTokenEmbedder(config) 127 | 128 | self.vision_encoder = vision_encoder 129 | self.decoder = decoder 130 | self.embedder = embedder 131 | 132 | # Simple encoding for image patches 133 | self.img_w_embed = nn.Embedding( 134 | self.config.image_embed_encoding_size, 135 | self.config.hidden_size, 136 | ) 137 | 138 | self.img_h_embed = nn.Embedding( 139 | self.config.image_embed_encoding_size, 140 | self.config.hidden_size, 141 | ) 142 | 143 | # Tying configs 144 | self.vision_encoder.config = self.config.vision_encoder 145 | self.decoder.config = self.config.decoder 146 | 147 | self.bbox_head = BboxHead(config.hidden_size, 6) 148 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) 149 | 150 | if ( 151 | self.config.multi_output_distance is not None 152 | and self.config.multi_output_distance > 0 153 | ): 154 | self.multi_output_projections = nn.ModuleList( 155 | [ 156 | DistanceProjection( 157 | in_features=config.hidden_size, out_features=config.hidden_size 158 | ) 159 | for _ in range(self.config.multi_output_distance) 160 | ] 161 | ) 162 | 163 | def tie_weights(self): 164 | self._tie_weights() 165 | 166 | def _tie_weights(self): 167 | # Tie weights of lm head and token embedder 168 | self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed) 169 | 170 | def get_output_embeddings(self) -> nn.Module: 171 | return self.lm_head 172 | 173 | def 
get_input_embeddings(self) -> nn.Module: 174 | return self.embedder.token_embed 175 | 176 | def set_output_embeddings(self, new_embeddings: nn.Module): 177 | self.lm_head = new_embeddings 178 | 179 | def set_input_embeddings(self, new_embeddings: nn.Module): 180 | self.embedder.token_embed = new_embeddings 181 | 182 | def maybe_static_pad_image_inputs( 183 | self, 184 | chunk_pixels: torch.Tensor, 185 | chunk_grid_thw: torch.Tensor, 186 | actual_chunk_len: int, 187 | encoder_chunk_size: int, 188 | ) -> Tuple[torch.Tensor, torch.Tensor]: 189 | valid_embed_len = actual_chunk_len // ( 190 | self.vision_encoder.spatial_merge_size**2 191 | ) 192 | if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size: 193 | padding_len = encoder_chunk_size - actual_chunk_len 194 | chunk_pixels = F.pad( 195 | chunk_pixels, 196 | (0, 0, 0, padding_len), 197 | mode="constant", 198 | value=0.0, 199 | ) 200 | 201 | padding_grid = torch.tensor( 202 | [[1, 2, padding_len // 2]], 203 | device=chunk_grid_thw.device, 204 | dtype=chunk_grid_thw.dtype, 205 | ) 206 | chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0) 207 | 208 | return chunk_pixels, chunk_grid_thw, valid_embed_len 209 | 210 | def get_image_embeddings( 211 | self, 212 | pixel_values: torch.Tensor, 213 | grid_thw: torch.Tensor, 214 | encoder_chunk_size: int, 215 | valid_batch_size: torch.Tensor | None = None, 216 | max_batch_size: int | None = None, 217 | ): 218 | # embed all images with the vision encoder after they have already been tiled and flattened into a single batch 219 | chunks = [0] 220 | grid_chunks = [0] 221 | curr_chunk_len = 0 222 | curr_seq_len = 0 223 | for i in range(len(grid_thw)): 224 | curr_chunk_len += (grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]).item() 225 | if curr_chunk_len > encoder_chunk_size: 226 | chunks.append(curr_chunk_len + curr_seq_len) 227 | curr_seq_len += curr_chunk_len 228 | curr_chunk_len = 0 229 | grid_chunks.append(i + 1) 230 | 231 | if curr_chunk_len > 0: 232 | chunks.append(pixel_values.shape[0]) 233 | grid_chunks.append(len(grid_thw)) 234 | 235 | assert curr_chunk_len + curr_seq_len == pixel_values.shape[0], ( 236 | f"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {pixel_values.shape[0]}" 237 | ) 238 | 239 | logger.debug( 240 | f"Chunking encoder sequence into {len(chunks) - 1} chunks of size {encoder_chunk_size} with lengths {chunks} and grids {grid_chunks}" 241 | ) 242 | embeddings = [] 243 | for i in range(len(chunks) - 1): 244 | start = chunks[i] 245 | end = chunks[i + 1] 246 | grid_start = grid_chunks[i] 247 | grid_end = grid_chunks[i + 1] 248 | 249 | chunk_pixels = pixel_values[start:end] 250 | chunk_grid_thw = grid_thw[grid_start:grid_end] 251 | actual_chunk_len = end - start 252 | chunk_pixels, chunk_grid_thw, valid_embed_len = ( 253 | self.maybe_static_pad_image_inputs( 254 | chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size 255 | ) 256 | ) 257 | 258 | chunk_embeddings = self.vision_encoder.embed_images( 259 | image_batch=chunk_pixels.unsqueeze(0).to(device=self.device), 260 | grid_thw=chunk_grid_thw.unsqueeze(0).to(device=self.device), 261 | ) 262 | embeddings.append(chunk_embeddings[:valid_embed_len].squeeze(0)) 263 | 264 | if len(embeddings) == 0: 265 | raise ValueError( 266 | "No image embeddings were generated. Check the input images and grid sizes." 
267 | ) 268 | elif len(embeddings) == 1: 269 | embeddings = embeddings[0] 270 | else: 271 | embeddings = torch.cat(embeddings, dim=0) 272 | 273 | encoding_2d = self.get_2d_learned_embeddings( 274 | grid_thw, 275 | device=embeddings.device, 276 | bbox_size=self.config.image_embed_encoding_multiplier, 277 | ) 278 | assert embeddings.shape[0] == encoding_2d.shape[0], ( 279 | f"Mismatch in image embedding seq len: {embeddings.shape} vs {encoding_2d.shape}" 280 | ) 281 | assert embeddings.shape[1] == encoding_2d.shape[1], ( 282 | f"Mismatch in image embedding token counts: {embeddings.shape} vs {encoding_2d.shape}" 283 | ) 284 | 285 | embeddings = embeddings + encoding_2d 286 | 287 | return embeddings 288 | 289 | def embed_ids_boxes_images( 290 | self, 291 | input_ids, 292 | image_embeddings, 293 | encoder_chunk_size: int, 294 | valid_batch_size: torch.Tensor | None = None, 295 | input_boxes: torch.Tensor | None = None, 296 | embed_boxes: torch.Tensor | None = None, 297 | ): 298 | """ 299 | Insert embedded image tiles into the corresponding positions into the full input sequence 300 | 301 | Positions to insert new tokens are indicated by the special image token index 302 | """ 303 | # This is batched in the inner call 304 | inputs_embeds = self.embedder.embed( 305 | input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes 306 | ) 307 | 308 | if image_embeddings is not None: 309 | special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) 310 | special_image_mask = special_image_mask.expand_as(inputs_embeds) 311 | if inputs_embeds[special_image_mask].numel() != image_embeddings.numel(): 312 | n_image_tokens = torch.sum((input_ids == self.config.image_token_id)) 313 | n_image_features = image_embeddings.shape[0] * image_embeddings.shape[1] 314 | warnings.warn( 315 | f"Image features and image tokens do not match: tokens {n_image_tokens}, features {n_image_features}. 
This may lead to unexpected results" 316 | ) 317 | image_features = image_embeddings.to(inputs_embeds.dtype) 318 | inputs_embeds = inputs_embeds.masked_scatter( 319 | special_image_mask, image_features 320 | ) 321 | else: 322 | assert (input_ids == self.config.image_token_id).sum() == 0, ( 323 | "Image tokens were present in the input but no input images were provided" 324 | ) 325 | 326 | return inputs_embeds 327 | 328 | def get_2d_learned_embeddings( 329 | self, 330 | grid_thw, 331 | device: str | torch.device = "cpu", 332 | bbox_size: int = 256, 333 | ): 334 | all_embeddings = [] 335 | for grid_t, grid_h, grid_w in grid_thw: 336 | llm_grid_h, llm_grid_w = ( 337 | grid_h // self.config.merge_size, 338 | grid_w // self.config.merge_size, 339 | ) 340 | 341 | # Scale to 0-1024 342 | llm_grid_h = ( 343 | torch.arange(llm_grid_h, device=device) 344 | / max(1, (llm_grid_h - 1)) 345 | * bbox_size 346 | ) 347 | llm_grid_w = ( 348 | torch.arange(llm_grid_w, device=device) 349 | / max(1, (llm_grid_w - 1)) 350 | * bbox_size 351 | ) 352 | 353 | llm_grid_w_idx = llm_grid_w.to(torch.long) 354 | llm_grid_h_idx = llm_grid_h.to(torch.long) 355 | 356 | llm_grid_w = self.img_w_embed(llm_grid_w_idx) 357 | llm_grid_h = self.img_h_embed(llm_grid_h_idx) 358 | 359 | full_grid = llm_grid_h[:, None] + llm_grid_w[None, :] 360 | 361 | flattened = full_grid.flatten( 362 | 0, 1 363 | ) # Flatten first dimension, so they are seq_len x embed_dim 364 | all_embeddings.append(flattened) 365 | return torch.concat( 366 | all_embeddings, dim=0 367 | ) # Shape is num_image_tokens x embed_dim 368 | 369 | def get_logits(self, hidden_states): 370 | assert hidden_states.shape[1] == 1, ( 371 | "Multi output predictions only applied on the last token" 372 | ) 373 | 374 | all_lm_logits = [] 375 | all_bbox_logits = [] 376 | 377 | current_hidden = hidden_states 378 | 379 | # Loop includes initial prediction (i=0) plus multi_out_distance additional predictions 380 | for i in range(self.config.multi_output_distance + 1): 381 | if i > 0: 382 | current_hidden = self.multi_output_projections[i - 1](current_hidden) 383 | 384 | lm_logits = self.lm_head(current_hidden) 385 | bbox_logits = F.sigmoid(self.bbox_head(current_hidden)) 386 | 387 | all_lm_logits.append(lm_logits) 388 | all_bbox_logits.append(bbox_logits) 389 | 390 | # Concatenate along sequence dimension (dim=1) 391 | final_lm_logits = torch.cat(all_lm_logits, dim=1) 392 | final_bbox_logits = torch.cat(all_bbox_logits, dim=1) 393 | 394 | return final_lm_logits, final_bbox_logits 395 | 396 | def forward( 397 | self, 398 | input_ids=None, 399 | image_embeddings=None, 400 | labels=None, 401 | image_tiles=None, 402 | grid_thw=None, 403 | inputs_embeds=None, 404 | attention_mask=None, 405 | position_ids=None, 406 | cache_position=None, 407 | past_key_values=None, 408 | output_hidden_states=False, 409 | output_attentions=False, 410 | use_cache=False, 411 | encoder_chunk_size=32768, 412 | cache_idxs=None, 413 | num_valid_tokens=None, 414 | prefill=True, 415 | text_lengths=None, 416 | valid_batch_size: torch.Tensor = None, 417 | input_boxes=None, 418 | embed_boxes=None, 419 | logits_to_keep=None, 420 | **kwargs: KwargsForCausalLM, 421 | ): 422 | if any([ 423 | input_ids is None, 424 | position_ids is None, 425 | cache_position is None, 426 | ( 427 | prefill 428 | and not ( 429 | (image_tiles is not None and grid_thw is not None) 430 | or image_embeddings is not None 431 | ) 432 | ), 433 | ]): 434 | raise ValueError( 435 | "`input_ids`, `position_ids`, and `cache_position` **must** be 
specified. " 436 | "For prefill, you must provide either (`image_tiles` and `grid_thw`) or `image_embeddings`." 437 | ) 438 | 439 | 440 | inputs_embeds = self.embed_ids_boxes_images( 441 | input_ids, image_embeddings, encoder_chunk_size, valid_batch_size, input_boxes, embed_boxes 442 | ) 443 | 444 | # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder 445 | # Skipped during decoding since not required 446 | if self.decoder.config._attn_implementation == "flash_attention_2" and prefill: 447 | # Needed for CPU -> GPU 448 | from surya.common.surya.flash_attn_utils import _get_unpad_data 449 | batch_size, query_length, _ = inputs_embeds.shape 450 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( 451 | attention_mask 452 | ) 453 | kwargs["batch_size"] = batch_size 454 | kwargs["query_length"] = query_length 455 | kwargs["indices_k"] = indices_k 456 | kwargs["cu_seqlens_k"] = cu_seqlens_k 457 | kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k 458 | 459 | causal_mask = self._update_causal_mask( 460 | attention_mask, 461 | inputs_embeds, 462 | cache_position, 463 | past_key_values, 464 | output_attentions, 465 | ) 466 | 467 | attention_mask = causal_mask 468 | outputs = self.decoder( 469 | inputs_embeds=inputs_embeds, 470 | attention_mask=attention_mask, 471 | position_ids=position_ids, 472 | cache_position=cache_position, 473 | past_key_values=past_key_values, 474 | return_dict=True, 475 | use_cache=use_cache, 476 | cache_idxs=cache_idxs, 477 | num_valid_tokens=num_valid_tokens, 478 | prefill=prefill, 479 | text_lengths=text_lengths, 480 | **kwargs, 481 | ) 482 | 483 | hidden_states = outputs.last_hidden_state 484 | if logits_to_keep is not None: 485 | hidden_states = hidden_states[:, -logits_to_keep:, :] 486 | hidden_states = hidden_states.contiguous() 487 | 488 | loss = None 489 | if labels is not None: 490 | # Training, return full logits 491 | lm_logits = self.lm_head(hidden_states) 492 | bbox_logits = None 493 | vocab_size = lm_logits.shape[-1] 494 | labels = torch.roll(labels, shifts=-1, dims=-1) 495 | loss = F.cross_entropy( 496 | lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean" 497 | ) 498 | else: 499 | lm_logits, bbox_logits = self.get_logits(hidden_states) 500 | 501 | return SuryaModelOutput( 502 | loss=loss, 503 | bbox_logits=bbox_logits, 504 | lm_logits=lm_logits, 505 | hidden_states=outputs.hidden_states if output_hidden_states else None, 506 | attentions=outputs.attentions if output_attentions else None, 507 | past_key_values=outputs.past_key_values, 508 | ) 509 | 510 | def _update_causal_mask( 511 | self, 512 | attention_mask: torch.Tensor, 513 | input_tensor: torch.Tensor, 514 | cache_position: torch.Tensor, 515 | past_key_values: Cache, 516 | output_attentions: bool, 517 | ): 518 | if self.decoder.config._attn_implementation == "flash_attention_2": 519 | return attention_mask 520 | 521 | # We always pass in a 2D attention mask from the processor - In both static and dynamic cache cases 522 | dtype, device = input_tensor.dtype, input_tensor.device 523 | min_dtype = torch.finfo(dtype).min 524 | sequence_length = input_tensor.shape[1] 525 | target_length = ( 526 | attention_mask.shape[-1] 527 | if isinstance(attention_mask, torch.Tensor) 528 | else past_key_values.max_cache_len 529 | ) 530 | 531 | # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
532 | causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( 533 | attention_mask, 534 | sequence_length=sequence_length, 535 | target_length=target_length, 536 | dtype=dtype, 537 | device=device, 538 | cache_position=cache_position, 539 | batch_size=input_tensor.shape[0], 540 | config=self.config, 541 | past_key_values=past_key_values, 542 | ) 543 | 544 | if ( 545 | self.config._attn_implementation == "sdpa" 546 | and attention_mask is not None 547 | and attention_mask.device.type in ["cuda", "xpu"] 548 | and not output_attentions 549 | ): 550 | # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when 551 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 552 | # Details: https://github.com/pytorch/pytorch/issues/110213 553 | causal_mask = AttentionMaskConverter._unmask_unattended( 554 | causal_mask, min_dtype 555 | ) 556 | 557 | return causal_mask 558 | 559 | @staticmethod 560 | def _prepare_4d_causal_attention_mask_with_cache_position( 561 | attention_mask: torch.Tensor, 562 | sequence_length: int, 563 | target_length: int, 564 | dtype: torch.dtype, 565 | device: torch.device, 566 | cache_position: torch.Tensor, 567 | batch_size: int, 568 | config: SuryaModelConfig, 569 | past_key_values: Cache, 570 | ): 571 | """ 572 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 573 | `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 574 | 575 | Args: 576 | attention_mask (`torch.Tensor`): 577 | A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. 578 | sequence_length (`int`): 579 | The sequence length being processed. 580 | target_length (`int`): 581 | The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. 582 | dtype (`torch.dtype`): 583 | The dtype to use for the 4D attention mask. 584 | device (`torch.device`): 585 | The device to place the 4D attention mask on. 586 | cache_position (`torch.Tensor`): 587 | Indices depicting the position of the input sequence tokens in the sequence. Shape `(batch_size, sequence_length)`. 588 | batch_size (`int`): 589 | Batch size. 590 | config (`SuryaModelConfig`): 591 | The model's configuration class 592 | past_key_values (`Cache`): 593 | The cache class that is being used currently to generate 594 | """ 595 | if attention_mask is not None and attention_mask.dim() == 4: 596 | # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
597 | causal_mask = attention_mask 598 | else: 599 | min_dtype = torch.finfo(dtype).min 600 | causal_mask = torch.full( 601 | (sequence_length, target_length), 602 | fill_value=min_dtype, 603 | dtype=dtype, 604 | device=device, 605 | ) 606 | # Batch-aware diagonal attend mask 607 | diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze( 608 | 0 609 | ) > cache_position.unsqueeze(-1) 610 | causal_mask = ( 611 | causal_mask.unsqueeze(0) * diagonal_attend_mask 612 | ) # (batch_size, seq_len, target_len) 613 | causal_mask = causal_mask[ 614 | :, None, :, : 615 | ] # (batch_size, 1, seq_len, target_len) 616 | if attention_mask is not None: 617 | causal_mask = ( 618 | causal_mask.clone() 619 | ) # copy to contiguous memory for in-place edit 620 | if attention_mask.shape[-1] > target_length: 621 | attention_mask = attention_mask[:, :target_length] 622 | mask_length = attention_mask.shape[-1] 623 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ 624 | :, None, None, : 625 | ].to(causal_mask.device) 626 | padding_mask = padding_mask == 0 627 | causal_mask[:, :, :, :mask_length] = causal_mask[ 628 | :, :, :, :mask_length 629 | ].masked_fill(padding_mask, min_dtype) 630 | return causal_mask 631 | 632 | class SuryaXLAModel(SuryaModel): 633 | def get_image_embeddings( 634 | self, 635 | pixel_values: torch.Tensor, 636 | grid_thw: torch.Tensor, 637 | encoder_chunk_size: int, 638 | valid_batch_size: torch.Tensor | None = None, 639 | max_batch_size: int | None = None, 640 | ): 641 | # embed all images with the vision encoder after they have already been tiled and flattened into a single batch 642 | unpadded_max_grid_size = ( 643 | (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).max().item() 644 | ) 645 | max_grid_size = get_nearest_pad( 646 | unpadded_max_grid_size, 647 | ) # If we need zero padding, we still need to allocate a bit of room for the extra grid_thw 648 | 649 | # Always need 2 items in each row batch 650 | if max_grid_size == unpadded_max_grid_size: 651 | max_grid_size += 16 652 | 653 | full_image_grid = torch.zeros( 654 | (valid_batch_size, max_grid_size, pixel_values.shape[-1]), 655 | dtype=pixel_values.dtype, 656 | ) 657 | 658 | # Roll out into a full grid 659 | seq_len = 0 660 | row_grids = [] 661 | for i in range(valid_batch_size): 662 | curr_sample_len = grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2] 663 | full_image_grid[i, -curr_sample_len:] = pixel_values[ 664 | seq_len : seq_len + curr_sample_len 665 | ] 666 | padded_len = max_grid_size - curr_sample_len 667 | if padded_len > 0: 668 | row_grid = torch.tensor( 669 | [ 670 | [1, 4, padded_len // 4], 671 | grid_thw[i].tolist(), 672 | ], 673 | dtype=torch.long, 674 | ) 675 | else: 676 | row_grid = torch.tensor( 677 | [ 678 | grid_thw[i].tolist(), 679 | ], 680 | dtype=torch.long, 681 | ) 682 | 683 | row_grids.append(row_grid) 684 | seq_len += curr_sample_len 685 | 686 | # bsz, 2, 3 687 | row_grids = torch.stack(row_grids, dim=0) 688 | 689 | if settings.FOUNDATION_STATIC_CACHE: 690 | # Pad to max batch size, repeat the final row 691 | row_grids = pad_to_batch_size_repeat( 692 | row_grids, 693 | batch_size=max_batch_size, 694 | ) 695 | full_image_grid = pad_to_batch_size( 696 | full_image_grid, 697 | batch_size=max_batch_size, 698 | ) 699 | 700 | full_image_grid = full_image_grid.to(self.device) 701 | 702 | embeddings = self.vision_encoder.embed_images( 703 | image_batch=full_image_grid, grid_thw=row_grids.to(self.device) 704 | ) 705 | 706 | encoding_2d = self.get_2d_learned_embeddings( 707 | 
row_grids, 708 | bbox_size=self.config.image_embed_encoding_multiplier, 709 | ) 710 | embeddings += encoding_2d 711 | 712 | return embeddings 713 | 714 | def embed_ids_boxes_images( 715 | self, 716 | input_ids, 717 | image_embeddings, 718 | encoder_chunk_size: int, 719 | valid_batch_size: torch.Tensor | None = None, 720 | input_boxes: torch.Tensor | None = None, 721 | embed_boxes: torch.Tensor | None = None, 722 | ): 723 | """ 724 | Insert embedded image tiles into the corresponding positions into the full input sequence 725 | 726 | Positions to insert new tokens are indicated by the special image token index 727 | """ 728 | # This is batched in the inner call 729 | inputs_embeds = self.embedder.embed( 730 | input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes 731 | ) 732 | 733 | if image_embeddings is not None: 734 | image_token_id_tensor = torch.tensor( 735 | self.config.image_token_id, 736 | device=inputs_embeds.device, 737 | dtype=torch.long, 738 | ) 739 | mask = input_ids == image_token_id_tensor 740 | last_image_token_pos = ( 741 | mask.size(1) 742 | - 1 743 | - mask.flip(dims=[1]).long().argmax(dim=1, keepdim=True) 744 | ) 745 | # Calculate start position to replace N positions ending at (and including) the last image token 746 | start_positions = last_image_token_pos - image_embeddings[0].shape[0] 747 | batch_size, insert_len = image_embeddings.shape[:2] 748 | 749 | # Create position indices for each insertion 750 | pos_indices = torch.arange( 751 | insert_len, device=inputs_embeds.device 752 | ).unsqueeze(0) 753 | insert_positions = start_positions + pos_indices 754 | 755 | idx = insert_positions.unsqueeze(-1).expand( 756 | -1, -1, inputs_embeds.size(-1) 757 | ) # [B,N,D] 758 | inputs_embeds = inputs_embeds.scatter(1, idx, image_embeddings) 759 | 760 | inputs_embeds = inputs_embeds * ( 761 | input_ids != self.config.pad_token_id 762 | ).unsqueeze(-1).to(inputs_embeds.dtype) 763 | return inputs_embeds 764 | 765 | def get_2d_learned_embeddings( 766 | self, 767 | grid_thw, 768 | bbox_size: int = 256, 769 | ): 770 | dev = grid_thw.device 771 | all_row_coords = [] 772 | all_col_coords = [] 773 | for row_grid in grid_thw: 774 | merge = self.config.merge_size 775 | 776 | # per-sample grid sizes after merge 777 | H = (row_grid[:, 1] // merge).long() # (B,) 778 | W = (row_grid[:, 2] // merge).long() # (B,) 779 | 780 | row_coords = torch.cat( 781 | [ 782 | torch.linspace(0, bbox_size, steps=int(h), device=dev) 783 | .round() 784 | .repeat_interleave(w) # repeat each row value w times 785 | for h, w in zip(H.tolist(), W.tolist()) 786 | ] 787 | ) # (full_grid_size,) 788 | 789 | col_coords = torch.cat( 790 | [ 791 | torch.linspace(0, bbox_size, steps=int(w), device=dev) 792 | .round() 793 | .repeat(int(h)) # tile the column vector h times 794 | for h, w in zip(H.tolist(), W.tolist()) 795 | ] 796 | ) # (full_grid_size,) 797 | all_row_coords.append(row_coords) 798 | all_col_coords.append(col_coords) 799 | row_coords = torch.stack(all_row_coords, dim=0).to(self.device) 800 | col_coords = torch.stack(all_col_coords, dim=0).to(self.device) 801 | 802 | emb = self.img_h_embed(row_coords.long()) + self.img_w_embed(col_coords.long()) 803 | return emb 804 | ``` -------------------------------------------------------------------------------- /surya/common/surya/encoder/__init__.py: -------------------------------------------------------------------------------- ```python 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import 
torch.nn.functional as F 7 | from transformers.activations import ACT2FN 8 | 9 | from surya.common.pretrained import SuryaPreTrainedModel 10 | from surya.common.surya.encoder.config import SuryaEncoderConfig 11 | from surya.common.xla import get_nearest_pad 12 | from surya.logging import get_logger 13 | from surya.settings import settings 14 | 15 | if settings.FOUNDATION_XLA: 16 | import torch_xla.experimental.custom_kernel 17 | 18 | from surya.logging import get_logger 19 | logger = get_logger() 20 | 21 | 22 | class Qwen2_5_VLMLP(nn.Module): 23 | def __init__(self, config, bias: bool = False): 24 | super().__init__() 25 | self.hidden_size = config.hidden_size 26 | self.intermediate_size = config.intermediate_size 27 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) 28 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) 29 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) 30 | self.act_fn = ACT2FN[config.hidden_act] 31 | 32 | def forward(self, hidden_state): 33 | return self.down_proj( 34 | self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) 35 | ) 36 | 37 | 38 | class Qwen2_5_VisionPatchEmbed(nn.Module): 39 | def __init__( 40 | self, 41 | patch_size: int = 14, 42 | temporal_patch_size: int = 2, 43 | in_channels: int = 3, 44 | embed_dim: int = 1152, 45 | ) -> None: 46 | super().__init__() 47 | self.patch_size = patch_size 48 | self.temporal_patch_size = temporal_patch_size 49 | self.in_channels = in_channels 50 | self.embed_dim = embed_dim 51 | 52 | kernel_size = [temporal_patch_size, patch_size, patch_size] 53 | self.proj = nn.Conv3d( 54 | in_channels, 55 | embed_dim, 56 | kernel_size=kernel_size, 57 | stride=kernel_size, 58 | bias=False, 59 | ) 60 | 61 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 62 | target_dtype = self.proj.weight.dtype 63 | bsz = hidden_states.shape[0] 64 | hidden_states = hidden_states.view( 65 | -1, 66 | self.in_channels, 67 | self.temporal_patch_size, 68 | self.patch_size, 69 | self.patch_size, 70 | ) 71 | hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( 72 | bsz, -1, self.embed_dim 73 | ) 74 | return hidden_states 75 | 76 | 77 | class Qwen2_5_VisionRotaryEmbedding(nn.Module): 78 | def __init__(self, dim: int, theta: float = 10000.0) -> None: 79 | super().__init__() 80 | self.inv_freq = 1.0 / ( 81 | theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim) 82 | ) 83 | 84 | def forward(self, seqlen: int) -> torch.Tensor: 85 | seq = torch.arange(seqlen, device="cpu", dtype=self.inv_freq.dtype) 86 | freqs = torch.outer(seq, self.inv_freq) 87 | return freqs 88 | 89 | 90 | class Qwen2RMSNorm(nn.Module): 91 | def __init__(self, hidden_size, eps=1e-6): 92 | """ 93 | Qwen2RMSNorm is equivalent to T5LayerNorm 94 | """ 95 | super().__init__() 96 | self.weight = nn.Parameter(torch.ones(hidden_size)) 97 | self.variance_epsilon = eps 98 | 99 | def forward(self, hidden_states): 100 | input_dtype = hidden_states.dtype 101 | hidden_states = hidden_states.to(torch.float32) 102 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 103 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 104 | return self.weight * hidden_states.to(input_dtype) 105 | 106 | def extra_repr(self): 107 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" 108 | 109 | 110 | class Qwen2_5_VLPatchMerger(nn.Module): 111 | def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: 112 | 
super().__init__() 113 | self.hidden_size = context_dim * (spatial_merge_size**2) 114 | self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) 115 | self.mlp = nn.Sequential( 116 | nn.Linear(self.hidden_size, self.hidden_size), 117 | nn.GELU(), 118 | nn.Linear(self.hidden_size, dim), 119 | ) 120 | 121 | def forward(self, x: torch.Tensor) -> torch.Tensor: 122 | bsz = x.shape[0] 123 | x = self.mlp(self.ln_q(x).view(bsz, -1, self.hidden_size)) 124 | return x 125 | 126 | 127 | def apply_rotary_pos_emb_flashatt( 128 | q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor 129 | ) -> Tuple[torch.Tensor, torch.Tensor]: 130 | from flash_attn.layers.rotary import apply_rotary_emb 131 | 132 | cos = cos.chunk(2, dim=-1)[0].contiguous() 133 | sin = sin.chunk(2, dim=-1)[0].contiguous() 134 | q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) 135 | k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) 136 | return q_embed, k_embed 137 | 138 | 139 | class Qwen2_5_VLVisionXLASdpaAttention(nn.Module): 140 | def __init__(self, dim: int, num_heads: int = 16) -> None: 141 | super().__init__() 142 | self.num_heads = num_heads 143 | self.qkv = nn.Linear(dim, dim * 3, bias=True) 144 | self.proj = nn.Linear(dim, dim) 145 | self.head_dim = dim // num_heads 146 | 147 | def forward( 148 | self, 149 | hidden_states: torch.Tensor, 150 | cu_seqlens: torch.Tensor, 151 | rotary_pos_emb: Optional[torch.Tensor] = None, 152 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 153 | ) -> torch.Tensor: 154 | bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1] 155 | q, k, v = ( 156 | self.qkv(hidden_states) 157 | .reshape(bsz, seq_length, 3, self.num_heads, -1) 158 | .permute(0, 2, 1, 3, 4) 159 | .unbind(1) 160 | ) 161 | if position_embeddings is None: 162 | logger.warning_once( 163 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 164 | "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " 165 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " 166 | "removed and `position_embeddings` will be mandatory." 
167 | ) 168 | emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) 169 | cos = emb.cos() 170 | sin = emb.sin() 171 | else: 172 | cos, sin = position_embeddings 173 | q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) 174 | 175 | attention_mask = torch.zeros([bsz, 1, seq_length, seq_length], dtype=torch.bool) 176 | cu_seqlens_cpu = cu_seqlens.cpu() 177 | for j in range(bsz): 178 | batch_seqlens = cu_seqlens_cpu[j] 179 | for i in range(1, len(batch_seqlens)): 180 | attention_mask[ 181 | j, 182 | ..., 183 | batch_seqlens[i - 1] : batch_seqlens[i], 184 | batch_seqlens[i - 1] : batch_seqlens[i], 185 | ] = True 186 | 187 | attention_mask = attention_mask.to(q.device) 188 | 189 | q = q.transpose(1, 2) 190 | k = k.transpose(1, 2) 191 | v = v.transpose(1, 2) 192 | 193 | attn_output = F.scaled_dot_product_attention( 194 | q, 195 | k, 196 | v, 197 | attention_mask, 198 | dropout_p=0.0, 199 | ) 200 | attn_output = attn_output.transpose(1, 2) 201 | attn_output = attn_output.reshape(bsz, seq_length, -1) 202 | attn_output = self.proj(attn_output) 203 | return attn_output 204 | 205 | 206 | class Qwen2_5_VLVisionXLAFlashAttention2(nn.Module): 207 | def __init__(self, dim: int, num_heads: int = 16) -> None: 208 | super().__init__() 209 | self.num_heads = num_heads 210 | self.qkv = nn.Linear(dim, dim * 3, bias=True) 211 | self.proj = nn.Linear(dim, dim) 212 | self.head_dim = dim // num_heads 213 | 214 | def forward( 215 | self, 216 | hidden_states: torch.Tensor, 217 | cu_seqlens: torch.Tensor, 218 | rotary_pos_emb: Optional[torch.Tensor] = None, 219 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 220 | ) -> torch.Tensor: 221 | # Note, this is faster than SDPA, but pretty memory inefficient 222 | # It also has significant accuracy issues 223 | 224 | bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1] 225 | 226 | # Single reshape to target layout - avoid multiple operations 227 | q, k, v = ( 228 | self.qkv(hidden_states) 229 | .reshape(bsz, seq_length, 3, self.num_heads, -1) 230 | .permute(0, 2, 1, 3, 4) 231 | .unbind(1) 232 | ) 233 | 234 | # Apply rotary embeddings if provided 235 | if position_embeddings is not None: 236 | cos, sin = position_embeddings 237 | q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) 238 | 239 | # Single reshape to flash attention format [batch, num_heads, seq_len, head_dim] 240 | q = q.transpose(1, 2) # [bsz, num_heads, seq_len, head_dim] 241 | k = k.transpose(1, 2) 242 | v = v.transpose(1, 2) 243 | 244 | total_seqlen = q.shape[2] 245 | # from cu_seqlens to segment ids for each position in dim 0 246 | additive_bias = torch.zeros((bsz, 1, total_seqlen, total_seqlen), dtype=q.dtype) 247 | min_val = torch.finfo(q.dtype).min 248 | 249 | for i in range(bsz): 250 | padding_end = cu_seqlens[i][1].item() 251 | additive_bias[i, :, :, :padding_end] = min_val 252 | 253 | additive_bias = additive_bias.to(hidden_states.device) 254 | 255 | attn_scale = 1 / math.sqrt(self.head_dim) 256 | attn_output = torch_xla.experimental.custom_kernel.flash_attention( 257 | q, k, v, sm_scale=attn_scale, ab=additive_bias 258 | ) 259 | attn_output = ( 260 | attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_length, -1) 261 | ) 262 | attn_output = self.proj(attn_output) 263 | return attn_output 264 | 265 | 266 | class Qwen2_5_VLVisionFlashAttention2(nn.Module): 267 | def __init__(self, dim: int, num_heads: int = 16) -> None: 268 | super().__init__() 269 | self.num_heads = num_heads 270 | self.qkv = nn.Linear(dim, dim * 3, bias=True) 271 | self.proj = 
nn.Linear(dim, dim) 272 | 273 | def forward( 274 | self, 275 | hidden_states: torch.Tensor, 276 | cu_seqlens: torch.Tensor, 277 | rotary_pos_emb: Optional[torch.Tensor] = None, 278 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 279 | ) -> torch.Tensor: 280 | from flash_attn import flash_attn_varlen_func 281 | 282 | bsz = hidden_states.shape[0] 283 | seq_length = hidden_states.shape[1] 284 | q, k, v = ( 285 | self.qkv(hidden_states) 286 | .reshape(bsz, seq_length, 3, self.num_heads, -1) 287 | .permute(0, 2, 1, 3, 4) 288 | .unbind(1) 289 | ) 290 | if position_embeddings is None: 291 | logger.warning_once( 292 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 293 | "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " 294 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " 295 | "removed and `position_embeddings` will be mandatory." 296 | ) 297 | emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) 298 | cos = emb.cos() 299 | sin = emb.sin() 300 | else: 301 | cos, sin = position_embeddings 302 | 303 | q, k = apply_rotary_pos_emb_flashatt(q, k, cos.squeeze(0), sin.squeeze(0)) 304 | 305 | q = q.squeeze(0) 306 | k = k.squeeze(0) 307 | v = v.squeeze(0) 308 | cu_seqlens = cu_seqlens.squeeze(0) 309 | 310 | max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() 311 | attn_output = flash_attn_varlen_func( 312 | q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen 313 | ).reshape(bsz, seq_length, -1) 314 | attn_output = self.proj(attn_output) 315 | return attn_output 316 | 317 | 318 | def rotate_half(x): 319 | """Rotates half the hidden dims of the input.""" 320 | x1 = x[..., : x.shape[-1] // 2] 321 | x2 = x[..., x.shape[-1] // 2 :] 322 | return torch.cat((-x2, x1), dim=-1) 323 | 324 | 325 | def apply_rotary_pos_emb_vision( 326 | q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor 327 | ) -> Tuple[torch.Tensor, torch.Tensor]: 328 | orig_q_dtype = q.dtype 329 | orig_k_dtype = k.dtype 330 | q, k = q.float(), k.float() 331 | cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() 332 | q_embed = (q * cos) + (rotate_half(q) * sin) 333 | k_embed = (k * cos) + (rotate_half(k) * sin) 334 | q_embed = q_embed.to(orig_q_dtype) 335 | k_embed = k_embed.to(orig_k_dtype) 336 | return q_embed, k_embed 337 | 338 | 339 | class Qwen2_5_VLVisionAttention(nn.Module): 340 | def __init__(self, dim: int, num_heads: int = 16) -> None: 341 | super().__init__() 342 | self.num_heads = num_heads 343 | self.head_dim = dim // num_heads 344 | self.qkv = nn.Linear(dim, dim * 3, bias=True) 345 | self.proj = nn.Linear(dim, dim) 346 | 347 | def forward( 348 | self, 349 | hidden_states: torch.Tensor, 350 | cu_seqlens: torch.Tensor, 351 | rotary_pos_emb: Optional[torch.Tensor] = None, 352 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 353 | ) -> torch.Tensor: 354 | bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1] 355 | q, k, v = ( 356 | self.qkv(hidden_states) 357 | .reshape(bsz, seq_length, 3, self.num_heads, -1) 358 | .permute(0, 2, 1, 3, 4) 359 | .unbind(1) 360 | ) 361 | if position_embeddings is None: 362 | logger.warning_once( 363 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 364 | "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " 365 | "`position_embeddings` (Tuple of tensors, 
containing cos and sin). In v4.54 `rotary_pos_emb` will be " 366 | "removed and `position_embeddings` will be mandatory." 367 | ) 368 | emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) 369 | cos = emb.cos() 370 | sin = emb.sin() 371 | else: 372 | cos, sin = position_embeddings 373 | 374 | q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) 375 | 376 | attention_mask = torch.full( 377 | [bsz, 1, seq_length, seq_length], 378 | torch.finfo(q.dtype).min, 379 | device=q.device, 380 | dtype=q.dtype, 381 | ) 382 | for j in range(bsz): 383 | batch_seqlens = cu_seqlens[j] 384 | for i in range(1, len(batch_seqlens)): 385 | attention_mask[ 386 | j, 387 | ..., 388 | batch_seqlens[i - 1] : batch_seqlens[i], 389 | batch_seqlens[i - 1] : batch_seqlens[i], 390 | ] = 0 391 | 392 | q = q.transpose(1, 2) 393 | k = k.transpose(1, 2) 394 | v = v.transpose(1, 2) 395 | 396 | attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim) 397 | attn_weights = attn_weights + attention_mask 398 | attn_weights = nn.functional.softmax( 399 | attn_weights, dim=-1, dtype=torch.float32 400 | ).to(q.dtype) 401 | attn_output = torch.matmul(attn_weights, v) 402 | attn_output = attn_output.transpose(1, 2) 403 | attn_output = attn_output.reshape(bsz, seq_length, -1) 404 | attn_output = self.proj(attn_output) 405 | return attn_output 406 | 407 | 408 | class Qwen2_5_VLVisionSdpaAttention(nn.Module): 409 | def __init__(self, dim: int, num_heads: int = 16) -> None: 410 | super().__init__() 411 | self.num_heads = num_heads 412 | self.qkv = nn.Linear(dim, dim * 3, bias=True) 413 | self.proj = nn.Linear(dim, dim) 414 | 415 | def unpack_qkv_with_mask(self, q, k, v, cu_seqlens): 416 | """ 417 | Unpacks q, k, v sequences into batch-major form and constructs an additive attention mask. 418 | 419 | Args: 420 | q, k, v: Tensors of shape (total_seq_len, num_heads, head_dim) 421 | cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths 422 | 423 | Returns: 424 | batched_q: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim) 425 | batched_k: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim) 426 | batched_v: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim) 427 | attention_mask: Tensor of shape (batch_size, 1, max_seq_len, max_seq_len) 428 | with 0 for valid tokens and -inf for padding (for additive attention) 429 | """ 430 | device = q.device 431 | dtype = q.dtype 432 | 433 | batch_size = cu_seqlens.shape[0] - 1 434 | num_heads = q.shape[1] 435 | head_dim = q.shape[2] 436 | 437 | seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] # Keep as tensor 438 | max_seq_len = seq_lengths.max().item() # Use .max() on tensor 439 | 440 | if settings.FOUNDATION_STATIC_CACHE: 441 | # Pad max_seq_len to the nearest multiple for compilation 442 | max_seq_len = get_nearest_pad(max_seq_len, pad_multiple=16) 443 | 444 | # Pad batch_size to the nearest multiple for compilation 445 | batch_size = get_nearest_pad(batch_size, pad_multiple=2) 446 | 447 | # Ensure seq_lengths is a tensor of the correct size 448 | seq_lengths = F.pad( 449 | seq_lengths, (0, batch_size - seq_lengths.size(0)), "constant", 0 450 | ) 451 | 452 | # some day, you may look at this, and think: "what if I used repeat_interlave or some other fancy torch instead"? 453 | # don't do this - it's a path to madness. 
For some reason, this loop is optimal 454 | 455 | batch_indices = [] 456 | position_indices = [] 457 | 458 | for i, seq_len in enumerate( 459 | seq_lengths.tolist() 460 | ): # Convert to list only for iteration 461 | batch_indices.extend([i] * seq_len) 462 | position_indices.extend(list(range(seq_len))) 463 | 464 | batch_indices = torch.tensor(batch_indices, device=device) 465 | position_indices = torch.tensor(position_indices, device=device) 466 | 467 | batched_q = torch.zeros( 468 | (batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype 469 | ) 470 | batched_k = torch.zeros_like(batched_q) 471 | batched_v = torch.zeros_like(batched_q) 472 | 473 | # Create additive attention mask 474 | attention_mask = torch.full( 475 | (batch_size, max_seq_len, max_seq_len), 476 | fill_value=float("-inf"), 477 | device=device, 478 | dtype=dtype, 479 | ) 480 | 481 | # Create mask for valid positions 482 | seq_range = torch.arange(max_seq_len, device=device) 483 | valid_mask = seq_range.unsqueeze(0) < seq_lengths.unsqueeze( 484 | 1 485 | ) # (batch_size, max_seq_len) 486 | valid_2d = valid_mask.unsqueeze(2) & valid_mask.unsqueeze( 487 | 1 488 | ) # (batch_size, max_seq_len, max_seq_len) 489 | 490 | # Simply use boolean indexing to set valid positions to 0 491 | attention_mask[valid_2d] = 0 492 | 493 | attention_mask = attention_mask.unsqueeze( 494 | 1 495 | ) # (batch_size, 1, max_seq_len, max_seq_len) 496 | 497 | batched_q[batch_indices, position_indices] = q 498 | batched_k[batch_indices, position_indices] = k 499 | batched_v[batch_indices, position_indices] = v 500 | 501 | return ( 502 | batched_q, 503 | batched_k, 504 | batched_v, 505 | attention_mask, 506 | batch_indices, 507 | position_indices, 508 | ) 509 | 510 | def forward( 511 | self, 512 | hidden_states: torch.Tensor, 513 | cu_seqlens: torch.Tensor, 514 | rotary_pos_emb: Optional[torch.Tensor] = None, 515 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 516 | ) -> torch.Tensor: 517 | hidden_states = hidden_states.squeeze(0) 518 | cu_seqlens = cu_seqlens.squeeze(0) 519 | 520 | seq_length = hidden_states.shape[0] 521 | q, k, v = ( 522 | self.qkv(hidden_states) 523 | .reshape(seq_length, 3, self.num_heads, -1) 524 | .permute(1, 0, 2, 3) 525 | .unbind(0) 526 | ) 527 | if position_embeddings is None: 528 | logger.warning_once( 529 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 530 | "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " 531 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " 532 | "removed and `position_embeddings` will be mandatory." 
533 | ) 534 | emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) 535 | cos = emb.cos() 536 | sin = emb.sin() 537 | else: 538 | cos, sin = position_embeddings 539 | q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) 540 | q = q.squeeze(0) 541 | k = k.squeeze(0) 542 | 543 | q, k, v, attention_mask, batch_indices, position_indices = ( 544 | self.unpack_qkv_with_mask(q, k, v, cu_seqlens) 545 | ) 546 | batch_size, max_seqlen = q.shape[:2] 547 | q = q.transpose(1, 2) 548 | k = k.transpose(1, 2) 549 | v = v.transpose(1, 2) 550 | 551 | attn_output = F.scaled_dot_product_attention( 552 | q, 553 | k, 554 | v, 555 | attention_mask, 556 | dropout_p=0.0, 557 | ) 558 | attn_output = attn_output.permute(0, 2, 1, 3).reshape( 559 | batch_size, max_seqlen, -1 560 | ) # Bring back to (batch_size, max_seqlen, hidden_dim) 561 | attn_output = attn_output[batch_indices, position_indices] 562 | attn_output = self.proj(attn_output) 563 | 564 | return attn_output.unsqueeze(0) 565 | 566 | 567 | QWEN2_5_VL_VISION_ATTENTION_CLASSES = { 568 | "eager": Qwen2_5_VLVisionAttention, 569 | "flash_attention_2": Qwen2_5_VLVisionXLAFlashAttention2 570 | if settings.FOUNDATION_XLA 571 | else Qwen2_5_VLVisionFlashAttention2, 572 | "sdpa": Qwen2_5_VLVisionXLASdpaAttention 573 | if settings.FOUNDATION_XLA 574 | else Qwen2_5_VLVisionSdpaAttention, 575 | } 576 | 577 | 578 | class Qwen2_5_VLVisionBlock(nn.Module): 579 | def __init__(self, config, attn_implementation: str = "sdpa") -> None: 580 | super().__init__() 581 | self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) 582 | self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) 583 | self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( 584 | config.hidden_size, num_heads=config.num_heads 585 | ) 586 | self.mlp = Qwen2_5_VLMLP(config, bias=True) 587 | 588 | def forward( 589 | self, 590 | hidden_states: torch.Tensor, 591 | cu_seqlens: torch.Tensor, 592 | rotary_pos_emb: Optional[torch.Tensor] = None, 593 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 594 | ) -> torch.Tensor: 595 | hidden_states = hidden_states + self.attn( 596 | self.norm1(hidden_states), 597 | cu_seqlens=cu_seqlens, 598 | rotary_pos_emb=rotary_pos_emb, 599 | position_embeddings=position_embeddings, 600 | ) 601 | hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) 602 | return hidden_states 603 | 604 | 605 | Qwen2_5_VL_START_DOCSTRING = r""" 606 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 607 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 608 | etc.) 609 | 610 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 611 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 612 | and behavior. 613 | 614 | Parameters: 615 | config ([`Qwen2_5_VLConfig`]): 616 | Model configuration class with all the parameters of the model. Initializing with a config file does not 617 | load the weights associated with the model, only the configuration. Check out the 618 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
619 | """ 620 | 621 | 622 | class Qwen2_5_VLPreTrainedModel(SuryaPreTrainedModel): 623 | config_class = SuryaEncoderConfig 624 | base_model_prefix = "model" 625 | supports_gradient_checkpointing = True 626 | _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] 627 | _skip_keys_device_placement = "past_key_values" 628 | _supports_flash_attn_2 = True 629 | _supports_sdpa = True 630 | _supports_cache_class = True 631 | _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` 632 | 633 | def _init_weights(self, module): 634 | std = self.config.initializer_range 635 | if isinstance(module, (nn.Linear, nn.Conv3d)): 636 | module.weight.data.normal_(mean=0.0, std=std) 637 | if module.bias is not None: 638 | module.bias.data.zero_() 639 | elif isinstance(module, nn.Embedding): 640 | module.weight.data.normal_(mean=0.0, std=std) 641 | if module.padding_idx is not None: 642 | module.weight.data[module.padding_idx].zero_() 643 | 644 | 645 | class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): 646 | config_class = SuryaEncoderConfig 647 | _no_split_modules = ["Qwen2_5_VLVisionBlock"] 648 | 649 | def __init__(self, config, *inputs, **kwargs) -> None: 650 | super().__init__(config, *inputs, **kwargs) 651 | self.spatial_merge_size = config.spatial_merge_size 652 | self.patch_size = config.patch_size 653 | self.fullatt_block_indexes = config.fullatt_block_indexes 654 | self.window_size = config.window_size 655 | self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size 656 | 657 | self.patch_embed = Qwen2_5_VisionPatchEmbed( 658 | patch_size=config.patch_size, 659 | temporal_patch_size=config.temporal_patch_size, 660 | in_channels=config.in_channels, 661 | embed_dim=config.hidden_size, 662 | ) 663 | 664 | head_dim = config.hidden_size // config.num_heads 665 | self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) 666 | 667 | self.blocks = nn.ModuleList( 668 | [ 669 | Qwen2_5_VLVisionBlock(config, config._attn_implementation) 670 | for _ in range(config.depth) 671 | ] 672 | ) 673 | self.merger = Qwen2_5_VLPatchMerger( 674 | dim=config.out_hidden_size, 675 | context_dim=config.hidden_size, 676 | spatial_merge_size=config.spatial_merge_size, 677 | ) 678 | self.gradient_checkpointing = False 679 | 680 | def rot_pos_emb(self, grid_thw): 681 | rotary_pos_emb = [] 682 | grid_thw_list = grid_thw.cpu().tolist() 683 | for batch_item in grid_thw_list: 684 | row_pos_ids = [] 685 | heights = [h for _, h, _ in batch_item] 686 | widths = [w for _, _, w in batch_item] 687 | max_grid_size = max(heights + widths) 688 | for t, h, w in batch_item: 689 | hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) 690 | hpos_ids = hpos_ids.reshape( 691 | h // self.spatial_merge_size, 692 | self.spatial_merge_size, 693 | w // self.spatial_merge_size, 694 | self.spatial_merge_size, 695 | ) 696 | hpos_ids = hpos_ids.permute(0, 2, 1, 3) 697 | hpos_ids = hpos_ids.flatten() 698 | 699 | wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) 700 | wpos_ids = wpos_ids.reshape( 701 | h // self.spatial_merge_size, 702 | self.spatial_merge_size, 703 | w // self.spatial_merge_size, 704 | self.spatial_merge_size, 705 | ) 706 | wpos_ids = wpos_ids.permute(0, 2, 1, 3) 707 | wpos_ids = wpos_ids.flatten() 708 | # shape: token_count, 2 709 | row_pos_ids.append( 710 | torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) 711 | ) 712 | # shape: token_count, 2 713 | pos_ids = torch.cat(row_pos_ids, dim=0) 714 | rotary_pos_emb_full = 
self.rotary_pos_emb(max_grid_size) 715 | rotary_pos_emb_row = rotary_pos_emb_full[pos_ids].flatten(1) 716 | rotary_pos_emb.append(rotary_pos_emb_row) 717 | rotary_pos_emb = torch.stack(rotary_pos_emb, dim=0) 718 | return rotary_pos_emb 719 | 720 | def forward( 721 | self, 722 | hidden_states: torch.Tensor, 723 | grid_thw: torch.Tensor, 724 | ) -> torch.Tensor: 725 | """ 726 | Args: 727 | hidden_states (`torch.Tensor` of shape `(bsz, seq_len, hidden_size)`): 728 | Packed, patchified pixel values for the images, projected by `patch_embed` before the vision blocks. 729 | grid_thw (`torch.Tensor` of shape `(bsz, num_images_or_videos, 3)`): 730 | The temporal, height and width of the patch grid for each image or video. 731 | 732 | Returns: 733 | `torch.Tensor`: The merged image features produced by the vision blocks and patch merger. 734 | """ 735 | bsz, seq_len, _ = hidden_states.size() 736 | hidden_states = self.patch_embed(hidden_states) # (bsz, seq_len, hidden_dim) 737 | rotary_pos_emb = self.rot_pos_emb(grid_thw) 738 | 739 | # hidden_states = hidden_states.reshape(bsz, seq_len, -1) 740 | # rotary_pos_emb = rotary_pos_emb.reshape(bsz, seq_len, -1) 741 | emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to( 742 | hidden_states.device 743 | ) 744 | position_embeddings = (emb.cos(), emb.sin()) 745 | 746 | cu_seqlens = (grid_thw[:, :, 1] * grid_thw[:, :, 2]).cumsum( 747 | dim=1, 748 | # Select dtype based on the following factors: 749 | # - FA2 requires that cu_seqlens_q must have dtype int32 750 | # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw 751 | # See https://github.com/huggingface/transformers/pull/34852 for more information 752 | dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, 753 | ) 754 | cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) 755 | for layer_num, blk in enumerate(self.blocks): 756 | if self.gradient_checkpointing and self.training: 757 | hidden_states = self._gradient_checkpointing_func( 758 | blk.__call__, 759 | hidden_states, 760 | cu_seqlens, 761 | None, 762 | position_embeddings, 763 | ) 764 | else: 765 | hidden_states = blk( 766 | hidden_states, 767 | cu_seqlens=cu_seqlens, 768 | position_embeddings=position_embeddings, 769 | ) 770 | 771 | hidden_states = self.merger(hidden_states) 772 | return hidden_states 773 | 774 | 775 | class SuryaEncoderModel(Qwen2_5_VisionTransformerPretrainedModel): 776 | @property 777 | def image_size(self) -> Tuple[int, int]: 778 | config: SuryaEncoderConfig = self.config 779 | if isinstance(config.image_size, tuple) and len(config.image_size) == 2: 780 | return config.image_size 781 | elif isinstance(config.image_size, int): 782 | return (config.image_size, config.image_size) 783 | 784 | raise ValueError( 785 | f"The `image_size` for SuryaEncoderConfig should be a tuple of (int, int) or a single int but found {type(config.image_size)}" 786 | ) 787 | 788 | @property 789 | def hidden_size(self) -> int: 790 | config: SuryaEncoderConfig = self.config 791 | return config.hidden_size 792 | 793 | def embed_images( 794 | self, 795 | image_batch: torch.Tensor, 796 | grid_thw: torch.Tensor, 797 | ) -> torch.Tensor: 798 | return super().forward( 799 | hidden_states=image_batch, 800 | grid_thw=grid_thw, 801 | ) 802 | ``` -------------------------------------------------------------------------------- /surya/common/adetr/decoder.py: -------------------------------------------------------------------------------- ```python 1 | from typing import Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.utils.checkpoint 5 | from torch import nn 6 | from transformers import PretrainedConfig 7 | 8 | from transformers.activations 
import ACT2FN 9 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 10 | from transformers.modeling_outputs import BaseModelOutputWithNoAttention 11 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 12 | 13 | from surya.common.pretrained import SuryaPreTrainedModel 14 | from surya.common.xla import mark_step 15 | 16 | _MAX_SQRT_GRADIENT = 1000.0 17 | 18 | 19 | class WrappedEmbedding(nn.Embedding): 20 | def forward(self, input_ids, *args, **kwargs): 21 | return super().forward(input_ids) 22 | 23 | 24 | class SuryaADETRDecoderRMSNorm(nn.Module): 25 | def __init__(self, dim: int, eps: float = 1e-6): 26 | super().__init__() 27 | self.eps = eps 28 | self.weight = nn.Parameter(torch.zeros(dim)) 29 | 30 | def _norm(self, x): 31 | variance = x.pow(2).mean(-1, keepdim=True) 32 | 33 | # Add clipping to prevent division by zero 34 | variance = torch.clamp(variance, min=self.eps) 35 | return x * torch.rsqrt(variance) 36 | 37 | def forward(self, x): 38 | output = self._norm(x.float()) 39 | # Llama does x.to(float16) * w whilst SuryaADETRDecoder is (x * w).to(float16) 40 | # See https://github.com/huggingface/transformers/pull/29402 41 | output = output * (1.0 + self.weight.float()) 42 | # Clamp to float16 range 43 | f16_info = torch.finfo(x.dtype) 44 | output = output.clamp(min=f16_info.min, max=f16_info.max) 45 | output = torch.where( 46 | torch.isnan(output), torch.tensor(0.0, device=output.device), output 47 | ) 48 | return output.type_as(x) 49 | 50 | def extra_repr(self): 51 | return f"{tuple(self.weight.shape)}, eps={self.eps}" 52 | 53 | 54 | ALL_LAYERNORM_LAYERS.append(SuryaADETRDecoderRMSNorm) 55 | 56 | 57 | class SuryaADETRDecoderRotaryEmbedding(nn.Module): 58 | def __init__(self, dim, base=10000, device=None): 59 | super().__init__() 60 | self.dim = dim 61 | self.base = base 62 | inv_freq = 1.0 / ( 63 | self.base 64 | ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) 65 | ) 66 | self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) 67 | 68 | @torch.no_grad() 69 | # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaADETRDecoder 70 | def forward(self, x, position_ids, seq_len=None): 71 | # x: [bs, num_attention_heads, seq_len, head_size] 72 | self.inv_freq.to(x.device) 73 | inv_freq_expanded = ( 74 | self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 75 | ) 76 | position_ids_expanded = position_ids[:, None, :].float() 77 | 78 | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( 79 | 1, 2 80 | ) 81 | emb = torch.cat((freqs, freqs), dim=-1) 82 | cos = emb.cos() 83 | sin = emb.sin() 84 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 85 | 86 | 87 | # Copied from transformers.models.llama.modeling_llama.rotate_half 88 | def rotate_half(x): 89 | """Rotates half the hidden dims of the input.""" 90 | x1 = x[..., : x.shape[-1] // 2] 91 | x2 = x[..., x.shape[-1] // 2 :] 92 | return torch.cat((-x2, x1), dim=-1) 93 | 94 | 95 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 96 | def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): 97 | """Applies Rotary Position Embedding to the query and key tensors. 98 | 99 | Args: 100 | q (`torch.Tensor`): The query tensor. 101 | k (`torch.Tensor`): The key tensor. 102 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 103 | sin (`torch.Tensor`): The sine part of the rotary embedding. 
104 | unsqueeze_dim (`int`, *optional*, defaults to 1): 105 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 106 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 107 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 108 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 109 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 110 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 111 | Returns: 112 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 113 | """ 114 | cos = cos.unsqueeze(unsqueeze_dim) 115 | sin = sin.unsqueeze(unsqueeze_dim) 116 | q_embed = (q * cos) + (rotate_half(q) * sin) 117 | k_embed = (k * cos) + (rotate_half(k) * sin) 118 | return q_embed, k_embed 119 | 120 | 121 | # Copied from transformers.models.llama.modeling_llama.repeat_kv 122 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 123 | """ 124 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 125 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 126 | """ 127 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 128 | if n_rep == 1: 129 | return hidden_states 130 | hidden_states = hidden_states[:, :, None, :, :].expand( 131 | batch, num_key_value_heads, n_rep, slen, head_dim 132 | ) 133 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 134 | 135 | 136 | class SuryaADETRDecoderSdpaCrossAttention(nn.Module): 137 | """Multi-headed attention from 'Attention Is All You Need' paper 138 | Modified for GQA 139 | """ 140 | 141 | def __init__(self, config: PretrainedConfig): 142 | super().__init__() 143 | self.config = config 144 | self.attention_dropout = config.attention_dropout 145 | self.hidden_size = config.hidden_size 146 | self.num_attention_heads = config.num_attention_heads 147 | self.head_dim = config.head_dim 148 | self.num_key_value_heads = config.num_key_value_heads 149 | self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads 150 | 151 | self.q_proj = nn.Linear( 152 | self.hidden_size, 153 | self.num_attention_heads * self.head_dim, 154 | bias=config.attention_bias, 155 | ) 156 | self.k_proj = nn.Linear( 157 | self.config.encoder_hidden_size, 158 | self.num_key_value_heads * self.head_dim, 159 | bias=config.attention_bias, 160 | ) 161 | self.v_proj = nn.Linear( 162 | self.config.encoder_hidden_size, 163 | self.num_key_value_heads * self.head_dim, 164 | bias=config.attention_bias, 165 | ) 166 | self.o_proj = nn.Linear( 167 | self.num_attention_heads * self.head_dim, self.hidden_size, bias=True 168 | ) 169 | self.rotary_emb = SuryaADETRDecoderRotaryEmbedding( 170 | self.head_dim, 171 | base=config.rope_theta, 172 | ) 173 | 174 | def forward( 175 | self, 176 | hidden_states: torch.Tensor, 177 | encoder_hidden_states: torch.Tensor, 178 | attention_mask: Optional[torch.Tensor] = None, 179 | encoder_attention_mask: Optional[torch.Tensor] = None, 180 | use_cache: bool = False, 181 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 182 | # Encoder attention mask currently ignored 183 | 184 | bsz, q_len, _ = hidden_states.size() 185 | _, v_len, _ 
= encoder_hidden_states.size() 186 | 187 | query_states = self.q_proj(hidden_states) 188 | query_states = query_states.view( 189 | bsz, q_len, self.num_attention_heads, self.head_dim 190 | ).transpose(1, 2) 191 | 192 | if self.key_states is None: 193 | key_states = self.k_proj(encoder_hidden_states) 194 | value_states = self.v_proj(encoder_hidden_states) 195 | key_states = key_states.view( 196 | bsz, v_len, self.num_key_value_heads, self.head_dim 197 | ).transpose(1, 2) 198 | value_states = value_states.view( 199 | bsz, v_len, self.num_key_value_heads, self.head_dim 200 | ).transpose(1, 2) 201 | if use_cache: 202 | self._update_cache(key_states, value_states) 203 | else: 204 | key_states = self.key_states 205 | value_states = self.value_states 206 | 207 | key_states = repeat_kv(key_states, self.num_key_value_groups) 208 | value_states = repeat_kv(value_states, self.num_key_value_groups) 209 | 210 | attn_output = torch.nn.functional.scaled_dot_product_attention( 211 | query_states, 212 | key_states, 213 | value_states, 214 | attn_mask=None, 215 | dropout_p=self.attention_dropout if self.training else 0.0, 216 | scale=self.head_dim**-0.5, 217 | ) 218 | 219 | attn_output = attn_output.transpose(1, 2).contiguous() 220 | attn_output = attn_output.view(bsz, q_len, self.hidden_size) 221 | attn_output = self.o_proj(attn_output) 222 | return attn_output 223 | 224 | def _clear_cache(self): 225 | if self.value_states is not None: 226 | del self.value_states 227 | if self.key_states is not None: 228 | del self.key_states 229 | 230 | def _setup_cache(self, batch_size, device, dtype=None): 231 | # Setup initial caches 232 | self.value_states = None 233 | self.key_states = None 234 | 235 | @torch.no_grad() 236 | def _update_cache(self, key_states, value_states, **cache_kwargs): 237 | self.value_states = value_states 238 | self.key_states = key_states 239 | 240 | 241 | class SuryaADETRDecoderSdpaAttention(nn.Module): 242 | """Multi-headed attention from 'Attention Is All You Need' paper""" 243 | 244 | def __init__(self, config: PretrainedConfig, static_cache=False, max_boxes=None): 245 | super().__init__() 246 | self.config = config 247 | self.attention_dropout = config.attention_dropout 248 | self.hidden_size = config.hidden_size 249 | self.num_attention_heads = config.num_attention_heads 250 | self.head_dim = config.head_dim 251 | self.num_key_value_heads = config.num_key_value_heads 252 | self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads 253 | 254 | self.q_proj = nn.Linear( 255 | self.hidden_size, 256 | self.num_attention_heads * self.head_dim, 257 | bias=config.attention_bias, 258 | ) 259 | self.k_proj = nn.Linear( 260 | self.hidden_size, 261 | self.num_key_value_heads * self.head_dim, 262 | bias=config.attention_bias, 263 | ) 264 | self.v_proj = nn.Linear( 265 | self.hidden_size, 266 | self.num_key_value_heads * self.head_dim, 267 | bias=config.attention_bias, 268 | ) 269 | self.o_proj = nn.Linear( 270 | self.num_attention_heads * self.head_dim, self.hidden_size, bias=True 271 | ) 272 | self.rotary_emb = SuryaADETRDecoderRotaryEmbedding( 273 | self.head_dim, 274 | base=config.rope_theta, 275 | ) 276 | 277 | self.static_cache = static_cache 278 | self.max_boxes = max_boxes 279 | 280 | def forward( 281 | self, 282 | hidden_states: torch.Tensor, 283 | position_ids: Optional[torch.LongTensor] = None, 284 | attention_mask: Optional[torch.Tensor] = None, 285 | cache_position: Optional[torch.LongTensor] = None, 286 | use_cache: bool = False, 287 | window_attn: bool = False, 
288 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 289 | bsz, q_len, _ = hidden_states.size() 290 | 291 | query_states = self.q_proj(hidden_states) 292 | key_states = self.k_proj(hidden_states) 293 | value_states = self.v_proj(hidden_states) 294 | 295 | # Final is bsz, num_attention_heads, seq_len, head_dim 296 | query_states = query_states.view( 297 | bsz, q_len, self.num_attention_heads, self.head_dim 298 | ).transpose(1, 2) 299 | key_states = key_states.view( 300 | bsz, q_len, self.num_key_value_heads, self.head_dim 301 | ).transpose(1, 2) 302 | value_states = value_states.view( 303 | bsz, q_len, self.num_key_value_heads, self.head_dim 304 | ).transpose(1, 2) 305 | 306 | cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) 307 | query_states, key_states = apply_rotary_pos_emb( 308 | query_states, key_states, cos, sin 309 | ) 310 | 311 | if use_cache and hasattr(self, "key_states"): 312 | cache_kwargs = { 313 | "cache_position": cache_position, 314 | "window_attn": window_attn, 315 | } 316 | key_states, value_states = self._update_cache( 317 | key_states, value_states, **cache_kwargs 318 | ) 319 | 320 | key_states = repeat_kv(key_states, self.num_key_value_groups) 321 | value_states = repeat_kv(value_states, self.num_key_value_groups) 322 | 323 | causal_mask = attention_mask 324 | if attention_mask is not None: 325 | # Mask is batch, head, seq_len, kv_len 326 | causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] 327 | if cache_position is not None and self.static_cache: 328 | current_pos = cache_position[-1] 329 | causal_mask[:, :, :, current_pos + 1 :] = torch.finfo( 330 | causal_mask.dtype 331 | ).min 332 | 333 | attn_output = torch.nn.functional.scaled_dot_product_attention( 334 | query_states, 335 | key_states, 336 | value_states, 337 | attn_mask=causal_mask, 338 | dropout_p=self.attention_dropout if self.training else 0.0, 339 | scale=self.head_dim**-0.5, 340 | ) 341 | 342 | attn_output = attn_output.transpose(1, 2).contiguous() 343 | attn_output = attn_output.view(bsz, q_len, self.hidden_size) 344 | attn_output = self.o_proj(attn_output) 345 | return attn_output 346 | 347 | def _setup_cache(self, batch_size, device, dtype=None): 348 | if dtype is None and self.config.torch_dtype is not None: 349 | dtype = self.config.torch_dtype 350 | dtype = dtype if dtype is not None else torch.float32 351 | 352 | # Setup initial caches 353 | self.value_states = None 354 | self.key_states = None 355 | 356 | if self.static_cache: 357 | cache_shape = ( 358 | batch_size, 359 | self.num_key_value_heads, 360 | self.max_boxes, 361 | self.head_dim, 362 | ) 363 | self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device) 364 | self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device) 365 | 366 | def _clear_cache(self): 367 | if self.value_states is not None: 368 | del self.value_states 369 | if self.key_states is not None: 370 | del self.key_states 371 | 372 | def _update_static_cache(self, key_states, value_states, **cache_kwargs): 373 | cache_position = cache_kwargs.get("cache_position") 374 | k_out, v_out = ( 375 | self.key_states.to(key_states.device), 376 | self.value_states.to(value_states.device), 377 | ) 378 | 379 | k_out[:, :, cache_position] = key_states.to(k_out.dtype) 380 | v_out[:, :, cache_position] = value_states.to(v_out.dtype) 381 | 382 | self.key_states, self.value_states = k_out, v_out 383 | return k_out, v_out 384 | 385 | def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs): 386 
| k_out = key_states 387 | if self.key_states is not None: 388 | k_out = torch.cat([self.key_states, key_states], dim=2) 389 | 390 | v_out = value_states 391 | if self.value_states is not None: 392 | v_out = torch.cat([self.value_states, value_states], dim=2) 393 | 394 | self.key_states, self.value_states = k_out, v_out 395 | return k_out, v_out 396 | 397 | @torch.no_grad() 398 | def _update_cache(self, key_states, value_states, **cache_kwargs): 399 | if self.static_cache: 400 | return self._update_static_cache(key_states, value_states, **cache_kwargs) 401 | 402 | return self._update_dynamic_cache(key_states, value_states, **cache_kwargs) 403 | 404 | 405 | class SuryaADETRDecoderMlp(nn.Module): 406 | def __init__(self, config): 407 | super().__init__() 408 | self.config = config 409 | self.hidden_size = config.hidden_size 410 | self.intermediate_size = config.intermediate_size 411 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 412 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 413 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 414 | if config.hidden_activation is None: 415 | config.hidden_activation = "gelu_pytorch_tanh" 416 | hidden_activation = config.hidden_activation 417 | self.act_fn = ACT2FN[hidden_activation] 418 | 419 | def forward(self, x): 420 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 421 | 422 | 423 | class SuryaADETRDecoderLayer(nn.Module): 424 | def __init__(self, config, layer_idx, static_cache=False, max_boxes=None): 425 | super().__init__() 426 | self.cross_pre_norm = SuryaADETRDecoderRMSNorm( 427 | config.hidden_size, eps=config.rms_norm_eps 428 | ) 429 | self.temporal_pre_norm = SuryaADETRDecoderRMSNorm( 430 | config.hidden_size, eps=config.rms_norm_eps 431 | ) 432 | 433 | self.temporal_block = None 434 | if layer_idx in config.self_attn_layers: 435 | self.temporal_block = SuryaADETRDecoderSdpaAttention( 436 | config, static_cache=static_cache, max_boxes=max_boxes 437 | ) 438 | 439 | self.cross_attn_block = None 440 | if layer_idx in config.cross_attn_layers: 441 | self.cross_attn_block = SuryaADETRDecoderSdpaCrossAttention(config) 442 | 443 | self.window_attn = layer_idx not in config.global_attn_layers 444 | self.channel_pre_norm = SuryaADETRDecoderRMSNorm( 445 | config.hidden_size, eps=config.rms_norm_eps 446 | ) 447 | self.mlp_block = SuryaADETRDecoderMlp(config) 448 | 449 | self.double_residual_flow = getattr(config, "double_residual_flow", False) 450 | 451 | def forward( 452 | self, 453 | activations: torch.Tensor, 454 | position_ids: torch.Tensor, 455 | attention_mask: torch.Tensor, 456 | encoder_hidden_states: torch.Tensor = None, 457 | encoder_attention_mask: torch.Tensor = None, 458 | cache_position: torch.Tensor = None, 459 | use_cache: bool = None, 460 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 461 | if self.double_residual_flow: 462 | return self.double_res_forward( 463 | activations, 464 | position_ids, 465 | attention_mask, 466 | encoder_hidden_states, 467 | encoder_attention_mask, 468 | cache_position, 469 | use_cache, 470 | ) 471 | 472 | hidden_states = activations 473 | if self.cross_attn_block is not None: 474 | # Do cross-attention on encoder outputs 475 | cross_attn_inputs = self.cross_pre_norm(hidden_states) 476 | cross_attn_path = self.cross_attn_block( 477 | cross_attn_inputs, 478 | encoder_hidden_states, 479 | attention_mask, 480 | encoder_attention_mask, 481 | use_cache=use_cache, 482 | ) 483 | 
hidden_states = cross_attn_path + hidden_states 484 | 485 | if self.temporal_block is not None: 486 | temporal_inputs = self.temporal_pre_norm( 487 | hidden_states 488 | ) # RMSNorm introduces slight slight differences 489 | temporal_path = self.temporal_block( 490 | temporal_inputs, 491 | position_ids, 492 | attention_mask, 493 | cache_position=cache_position, 494 | use_cache=use_cache, 495 | window_attn=self.window_attn, 496 | ) 497 | 498 | hidden_states = temporal_path + hidden_states 499 | 500 | block_input = hidden_states 501 | hidden_states = self.channel_pre_norm(block_input) 502 | hidden_states = self.mlp_block(hidden_states) 503 | hidden_states = hidden_states + block_input 504 | 505 | return hidden_states 506 | 507 | def double_res_forward( 508 | self, 509 | activations: torch.Tensor, 510 | position_ids: torch.Tensor, 511 | attention_mask: torch.Tensor, 512 | encoder_hidden_states: torch.Tensor = None, 513 | encoder_attention_mask: torch.Tensor = None, 514 | cache_position: torch.Tensor = None, 515 | use_cache: bool = None, 516 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 517 | raw_activations = activations 518 | 519 | if self.cross_attn_block is not None: 520 | # Do cross-attention on encoder outputs 521 | cross_attn_inputs = self.cross_pre_norm(activations) 522 | cross_attn_path = self.cross_attn_block( 523 | cross_attn_inputs, 524 | encoder_hidden_states, 525 | attention_mask, 526 | encoder_attention_mask, 527 | use_cache=use_cache, 528 | ) 529 | cross_attn_output = cross_attn_path + raw_activations 530 | else: 531 | cross_attn_output = raw_activations 532 | 533 | if self.temporal_block is not None: 534 | inputs_normalized = self.temporal_pre_norm( 535 | cross_attn_output 536 | ) # RMSNorm introduces slight slight differences 537 | hidden_states = self.temporal_block( 538 | inputs_normalized, 539 | position_ids, 540 | attention_mask, 541 | cache_position=cache_position, 542 | use_cache=use_cache, 543 | window_attn=self.window_attn, 544 | ) 545 | 546 | residual = hidden_states + raw_activations 547 | else: 548 | residual = cross_attn_output 549 | 550 | hidden_states = self.channel_pre_norm(residual) 551 | hidden_states = self.mlp_block(hidden_states) 552 | 553 | hidden_states = hidden_states + residual 554 | return hidden_states 555 | 556 | 557 | class SuryaADETRDecoderPreTrainedModel(SuryaPreTrainedModel): 558 | config_class = PretrainedConfig 559 | base_model_prefix = "model" 560 | supports_gradient_checkpointing = True 561 | _no_split_modules = ["SuryaADETRDecoderLayer"] 562 | _skip_keys_device_placement = ["cache"] 563 | _supports_flash_attn_2 = False 564 | _supports_sdpa = False # we can't compare with eager for now 565 | _supports_cache_class = True 566 | _supports_quantized_cache = True 567 | 568 | def _init_weights(self, module): 569 | if isinstance(module, SuryaADETRDecoderSdpaAttention): 570 | torch.nn.init.normal_( 571 | module.q_proj.weight, mean=0.0, std=self.config.init_std 572 | ) 573 | torch.nn.init.normal_( 574 | module.k_proj.weight, mean=0.0, std=self.config.init_std 575 | ) 576 | torch.nn.init.normal_( 577 | module.v_proj.weight, mean=0.0, std=self.config.init_std 578 | ) 579 | 580 | torch.nn.init.normal_( 581 | module.o_proj.weight, mean=0.0, std=self.config.init_std 582 | ) 583 | elif isinstance(module, nn.Linear): 584 | torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) 585 | if getattr(module, "bias", None) is not None: 586 | torch.nn.init.zeros_(module.bias) 587 | elif isinstance(module, nn.Embedding): 588 | 
module.weight.data.normal_(mean=0.0, std=self.config.init_std) 589 | if module.padding_idx is not None: 590 | module.weight.data[module.padding_idx].zero_() 591 | 592 | def _setup_cache(self, config, batch, device, dtype): 593 | layers = getattr(self, "model", self).layers 594 | for layer in layers: 595 | if layer.temporal_block: 596 | layer.temporal_block._setup_cache(batch, device, dtype) 597 | if layer.cross_attn_block: 598 | layer.cross_attn_block._setup_cache(batch, device, dtype) 599 | 600 | def _clear_cache(self): 601 | layers = getattr(self, "model", self).layers 602 | for layer in layers: 603 | if layer.temporal_block: 604 | layer.temporal_block._clear_cache() 605 | if layer.cross_attn_block: 606 | layer.cross_attn_block._clear_cache() 607 | 608 | def reset_cache(self, batch, device, dtype): 609 | pass 610 | 611 | def _tie_weights(self): 612 | pass 613 | 614 | def tie_weights(self): 615 | pass 616 | 617 | 618 | class SuryaADETRDecoderModel(SuryaADETRDecoderPreTrainedModel): 619 | """ 620 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaADETRDecoderDecoderLayer`] 621 | 622 | Args: 623 | config: PretrainedConfig 624 | """ 625 | 626 | def __init__( 627 | self, 628 | config: PretrainedConfig, 629 | embedder: nn.Module = None, 630 | max_boxes: int = None, 631 | static_cache: bool = False, 632 | ): 633 | super().__init__(config) 634 | self.padding_idx = config.pad_token_id 635 | self.vocab_size = config.vocab_size 636 | self.causal = config.causal 637 | 638 | self.embed_tokens = embedder 639 | self.max_boxes = max_boxes 640 | self.static_cache = static_cache 641 | 642 | self.layers = nn.ModuleList( 643 | [ 644 | SuryaADETRDecoderLayer( 645 | config, layer_idx, static_cache=static_cache, max_boxes=max_boxes 646 | ) 647 | for layer_idx in range(config.num_hidden_layers) 648 | ] 649 | ) 650 | self.final_norm = SuryaADETRDecoderRMSNorm( 651 | config.hidden_size, eps=config.rms_norm_eps 652 | ) 653 | self.gradient_checkpointing = False 654 | 655 | self.register_buffer( 656 | "normalizer", 657 | torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), 658 | persistent=False, 659 | ) 660 | # Initialize weights and apply final processing 661 | self.post_init() 662 | 663 | # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings 664 | def get_input_embeddings(self): 665 | return self.embed_tokens 666 | 667 | # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings 668 | def set_input_embeddings(self, value): 669 | self.embed_tokens = value 670 | 671 | def forward( 672 | self, 673 | input_ids: torch.LongTensor = None, 674 | input_boxes_counts: torch.LongTensor = None, 675 | inputs_embeds: Optional[torch.FloatTensor] = None, 676 | position_ids: Optional[torch.LongTensor] = None, 677 | attention_mask: Optional[torch.Tensor] = None, 678 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 679 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 680 | cache_position: Optional[torch.LongTensor] = None, 681 | use_cache: Optional[bool] = None, 682 | output_hidden_states: Optional[bool] = None, 683 | return_dict: Optional[bool] = None, 684 | prefill: bool = False, 685 | ) -> Union[Tuple, BaseModelOutputWithNoAttention]: 686 | use_cache = use_cache if use_cache is not None else self.config.use_cache 687 | return_dict = ( 688 | return_dict if return_dict is not None else self.config.use_return_dict 689 | ) 690 | 691 | if self.gradient_checkpointing and self.training and 
use_cache: 692 | use_cache = False 693 | 694 | inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts) 695 | hidden_states = inputs_embeds 696 | 697 | if use_cache and prefill: 698 | self._setup_cache( 699 | self.config, 700 | hidden_states.shape[0], 701 | hidden_states.device, 702 | hidden_states.dtype, 703 | ) 704 | 705 | if cache_position is None: 706 | cache_position = torch.arange( 707 | hidden_states.shape[1], device=hidden_states.device 708 | ) 709 | if position_ids is None: 710 | position_ids = cache_position.unsqueeze(0) 711 | 712 | causal_mask = self._update_causal_mask( 713 | attention_mask, inputs_embeds, cache_position 714 | ) 715 | 716 | all_hidden_states = () if output_hidden_states else None 717 | for i, residual_block in enumerate(self.layers): 718 | if output_hidden_states: 719 | all_hidden_states += (hidden_states,) 720 | if self.gradient_checkpointing and self.training: 721 | hidden_states = self._gradient_checkpointing_func( 722 | residual_block.__call__, 723 | hidden_states, 724 | position_ids, 725 | causal_mask, 726 | encoder_hidden_states, 727 | encoder_attention_mask, 728 | cache_position, 729 | use_cache, 730 | ) 731 | else: 732 | hidden_states = residual_block( 733 | hidden_states, 734 | position_ids, 735 | causal_mask, 736 | encoder_hidden_states, 737 | encoder_attention_mask, 738 | cache_position, 739 | use_cache, 740 | ) 741 | 742 | hidden_states = self.final_norm(hidden_states) 743 | 744 | # add hidden states from the last decoder layer 745 | if output_hidden_states: 746 | all_hidden_states += (hidden_states,) 747 | 748 | if not return_dict: 749 | return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) 750 | 751 | return BaseModelOutputWithNoAttention( 752 | last_hidden_state=hidden_states, 753 | hidden_states=all_hidden_states, 754 | ) 755 | 756 | # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static 757 | # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. 758 | # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using 759 | # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 760 | # Ignore copy 761 | def _update_causal_mask(self, attention_mask, input_tensor, cache_position): 762 | if not self.causal: 763 | return None 764 | 765 | dtype, device = input_tensor.dtype, input_tensor.device 766 | min_dtype = torch.finfo(dtype).min 767 | sequence_length = input_tensor.shape[1] 768 | target_length = max(self.max_boxes, sequence_length) 769 | 770 | diagonal = torch.full( 771 | (sequence_length, target_length), 772 | fill_value=min_dtype, 773 | dtype=dtype, 774 | device=device, 775 | ) 776 | causal_mask = diagonal 777 | if sequence_length != 1: 778 | # Select the upper triangular part of the matrix, but unmask current token (the diagonal) 779 | # triu will be the min_dtype, everything else is 0 (attended to) 780 | causal_mask = torch.triu(diagonal, diagonal=1) 781 | 782 | causal_mask *= torch.arange( 783 | target_length, device=device 784 | ) > cache_position.reshape(-1, 1) 785 | causal_mask = causal_mask[None, None, :, :].expand( 786 | input_tensor.shape[0], 1, -1, -1 787 | ) 788 | if attention_mask is not None: 789 | causal_mask = ( 790 | causal_mask.clone() 791 | ) # copy to contiguous memory for in-place edit 792 | if attention_mask.dim() == 2: 793 | # Mask positions in the causal mask that are masked in the attention mask 794 | mask_length = attention_mask.shape[-1] 795 | padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[ 796 | :, None, None, : 797 | ].eq(0.0) 798 | causal_mask[..., :mask_length] = causal_mask[ 799 | ..., :mask_length 800 | ].masked_fill(padding_mask, min_dtype) 801 | 802 | if attention_mask is not None and attention_mask.device.type == "cuda": 803 | # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when 804 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 805 | # Details: https://github.com/pytorch/pytorch/issues/110213 806 | causal_mask = AttentionMaskConverter._unmask_unattended( 807 | causal_mask, min_dtype 808 | ) 809 | 810 | return causal_mask 811 | ```
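# Notes and worked sketches

The rotary-embedding helpers above (`rotate_half` and `apply_rotary_pos_emb_vision` in the encoder, `apply_rotary_pos_emb` in the decoder) all apply the same update, `q * cos + rotate_half(q) * sin`, with `cos`/`sin` built from an inverse-frequency table as in `SuryaADETRDecoderRotaryEmbedding`. The sketch below reproduces that formula on made-up shapes and checks that it preserves per-token norms; the base of 10000 and all dimensions are illustrative, not values taken from the configs.

```python
# Minimal RoPE sketch mirroring rotate_half / apply_rotary_pos_emb above.
# All shapes and the base (10000) are illustrative assumptions.
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq, head_dim = 2, 4, 6, 16
q = torch.randn(batch, heads, seq, head_dim)

# One angle per (position, dim/2) pair, duplicated to cover the full head_dim,
# exactly like emb = torch.cat((freqs, freqs), dim=-1) in the code above.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)  # (seq, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)                   # (seq, head_dim)
cos, sin = emb.cos(), emb.sin()                           # broadcasts over batch/heads

q_rot = q * cos + rotate_half(q) * sin

# Each (x1, x2) pair is rotated by its angle, so per-token norms are unchanged.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
```

In the actual modules, `cos`/`sin` carry batch (and sometimes head) dimensions, which is why `apply_rotary_pos_emb` unsqueezes them (`unsqueeze_dim`) before broadcasting against `q` and `k`.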
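`repeat_kv` states that it is equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`; a quick check of that claim on illustrative GQA shapes:

```python
# Hypothetical shapes; repeat_kv copied from the decoder code above.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 7, 32)  # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))
```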
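`Qwen2_5_VLVisionSdpaAttention.unpack_qkv_with_mask` pads the concatenated variable-length sequences described by `cu_seqlens` into a `(batch, max_seq_len, num_heads, head_dim)` tensor and builds an additive mask so SDPA only attends within each original sequence. Below is a standalone sketch of the same padding-plus-mask idea, without the static-cache padding; the function name `pad_varlen` and all shapes are illustrative:

```python
# Sketch of the varlen -> padded-batch conversion used by unpack_qkv_with_mask.
import torch
import torch.nn.functional as F

def pad_varlen(x: torch.Tensor, cu_seqlens: torch.Tensor):
    # x: (total_tokens, num_heads, head_dim); cu_seqlens: (batch_size + 1,)
    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    batch_size, max_len = seq_lens.numel(), int(seq_lens.max())
    num_heads, head_dim = x.shape[1], x.shape[2]

    # Scatter each token into its (batch, position) slot.
    batch_idx, pos_idx = [], []
    for i, n in enumerate(seq_lens.tolist()):
        batch_idx.extend([i] * n)
        pos_idx.extend(range(n))
    batch_idx = torch.tensor(batch_idx, device=x.device)
    pos_idx = torch.tensor(pos_idx, device=x.device)

    padded = x.new_zeros(batch_size, max_len, num_heads, head_dim)
    padded[batch_idx, pos_idx] = x

    # Additive mask: 0 where both query and key positions are real, -inf otherwise.
    valid = torch.arange(max_len, device=x.device)[None, :] < seq_lens[:, None]
    mask = torch.full(
        (batch_size, 1, max_len, max_len), float("-inf"), dtype=x.dtype, device=x.device
    )
    mask[(valid[:, :, None] & valid[:, None, :]).unsqueeze(1)] = 0.0
    return padded, mask, batch_idx, pos_idx

# Two sequences of length 3 and 2, 4 heads, head_dim 8 (q == k == v for brevity).
q = torch.randn(5, 4, 8)
cu = torch.tensor([0, 3, 5])
q_pad, mask, batch_idx, pos_idx = pad_varlen(q, cu)
out = F.scaled_dot_product_attention(
    q_pad.transpose(1, 2), q_pad.transpose(1, 2), q_pad.transpose(1, 2), mask
)
# Padded query rows hold garbage and are dropped by re-indexing, as in the real forward.
unpadded = out.transpose(1, 2)[batch_idx, pos_idx]  # (total_tokens, num_heads, head_dim)
```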
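In `Qwen2_5_VisionTransformerPretrainedModel.forward`, `cu_seqlens` is the cumulative count of patches per image (`height * width` from `grid_thw`), padded with a leading zero. A small worked example with made-up grid sizes:

```python
# Illustrative grid_thw values; mirrors the cu_seqlens computation in forward above.
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[[1, 4, 6],    # image 1: 4x6 patch grid -> 24 tokens
                          [1, 2, 8]]])  # image 2: 2x8 patch grid -> 16 tokens

cu_seqlens = (grid_thw[:, :, 1] * grid_thw[:, :, 2]).cumsum(dim=1, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([[ 0, 24, 40]], dtype=torch.int32)
```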
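In the decoder, `_update_static_cache` writes new key/value states into fixed-size, preallocated buffers at the slots given by `cache_position`, while `_update_dynamic_cache` simply concatenates along the sequence dimension. A sketch of the static path on illustrative shapes:

```python
# Hypothetical sizes; shows the in-place scatter used by _update_static_cache.
import torch

batch, kv_heads, max_len, head_dim = 2, 4, 16, 32
key_cache = torch.zeros(batch, kv_heads, max_len, head_dim)  # preallocated once

new_keys = torch.randn(batch, kv_heads, 3, head_dim)  # 3 tokens written this step
cache_position = torch.tensor([5, 6, 7])              # slots they occupy

key_cache[:, :, cache_position] = new_keys            # scatter along the sequence dim
assert torch.equal(key_cache[:, :, 5:8], new_keys)
```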
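`SuryaADETRDecoderModel._update_causal_mask` builds an additive mask of shape `(seq_len, target_length)` that is `0` where a query may attend and dtype-min elsewhere, then broadcasts it over batch and heads. The core construction is sketched below with illustrative sizes, leaving out the 2D padding-mask and CUDA-specific branches:

```python
# Core of the causal-mask construction above; sizes are illustrative.
import torch

dtype, device = torch.float32, "cpu"
sequence_length, target_length = 4, 6           # target_length covers the static cache
cache_position = torch.arange(sequence_length)  # positions written this step

min_dtype = torch.finfo(dtype).min
mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
    mask = torch.triu(mask, diagonal=1)  # keep the diagonal (current token) attendable
# Columns at or before each query's cache position become 0 (attendable);
# later columns keep the min_dtype fill.
mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
mask = mask[None, None, :, :]  # (1, 1, seq_len, target_length), broadcast over batch/heads
```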