This is page 5 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmark │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── CITATION.cff ├── CLA.md ├── detect_layout.py ├── detect_text.py ├── LICENSE ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ ├── fonts │ │ └── .gitignore │ └── images │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── arabic.jpg │ ├── benchmark_chart_small.png │ ├── benchmark_chart.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chi_hind.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── chinese.jpg │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── excerpt.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── funsd.png │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── hindi.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── japanese.jpg │ ├── latex_ocr.png │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── nyt.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── paper.jpg │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── pres.png │ ├── rec_acc_table.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── scanned.png │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ ├── textbook_text.jpg │ └── textbook.jpg ├── surya │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── adetr │ │ │ └── decoder.py │ │ ├── donut │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder │ │ │ │ └── __init__.py │ │ │ ├── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation │ │ ├── __init__.py │ │ ├── cache 
│ │ │ ├── __init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input │ │ ├── load.py │ │ └── processing.py │ ├── layout │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests │ ├── assets │ │ └── test_latex.png │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ``` # Files -------------------------------------------------------------------------------- /surya/foundation/__init__.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Tuple 5 | from collections import deque 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import math 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import torch.nn.functional as F 14 | 15 | from surya.common.surya import SuryaModelOutput 16 | from surya.common.xla import mark_step 17 | from surya.common.predictor import BasePredictor 18 | 19 | from surya.foundation.loader import FoundationModelLoader 20 | from surya.foundation.util import ( 21 | detect_repeat_token, 22 | ) 23 | from surya.common.surya.schema import TaskNames 24 | from surya.foundation.cache.dynamic_ops import DynamicOpsCache 25 | from surya.foundation.cache.static_ops import StaticOpsCache 26 | 27 | from surya.settings import settings 28 | from surya.logging import get_logger, configure_logging 29 | 30 | configure_logging() 31 | logger = get_logger() 32 | 33 | 34 | @dataclass 35 | class ContinuousBatchInput: 36 | input_ids: torch.Tensor 37 | input_boxes: torch.Tensor 38 | position_ids: torch.Tensor 39 | # input_ids and position_ids may be padded, num_valid_tokens tracks the 'real' counts 40 | num_valid_tokens: torch.Tensor 41 | # count the number of predicted tokens for each batch element so far 42 | num_predicted_tokens: torch.Tensor 43 | needs_bbox_embedding: torch.Tensor 44 | 45 | 46 | @dataclass 47 | class ContinuousBatchOutput: 48 | input_ids: torch.Tensor 49 | preds: torch.Tensor 50 | bbox_preds: torch.Tensor 51 | scores: torch.Tensor 52 | token_probs: torch.Tensor 53 | 54 | 55 | @dataclass 56 | class FoundationPrompt: 57 | id: int 58 | task_name: TaskNames 59 | image: np.ndarray 60 | text: str 61 | math_mode: bool 62 | 63 | 64 | class FoundationPredictor(BasePredictor): 65 | model_loader_cls = FoundationModelLoader 66 | batch_size = ( 67 | settings.RECOGNITION_BATCH_SIZE 68 | ) # Default to the recognition batch size 
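    # The device-keyed dicts below (default_batch_sizes, encoder_chunk_sizes,
    # extra_token_count) hold per-device defaults; get_encoder_chunk_size() looks up
    # settings.TORCH_DEVICE_MODEL in them and falls back to the scalar class defaults.
    # Illustrative call pattern (a rough sketch, assuming default settings):
    #   predictor = FoundationPredictor()
    #   tokens, bboxes, scores, topk = predictor.prediction_loop(
    #       images, input_texts, task_names
    #   )
    # prediction_loop() (defined below) drives the continuous-batching prefill/decode
    # loop and returns per-image token ids, bounding boxes, scores, and top-k probs.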
69 | torch_dtype = None # No default, loader picks the dtype based on device properties - bf16/fp16 70 | default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 64} 71 | encoder_chunk_size: int = 4096 # Default chunk size 72 | encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768} 73 | extra_token_count = { 74 | "xla": 128 75 | } # We have to pad the XLA cache since we don't use sliding window 76 | min_prefill_ratio: int = 1 if settings.FOUNDATION_XLA else 0.2 77 | min_trim_length: int = 50 78 | tasks = { 79 | TaskNames.ocr_with_boxes: { 80 | "needs_bboxes": True, 81 | "img_size": (1024, 512), 82 | "max_tokens": 224, 83 | }, 84 | TaskNames.ocr_without_boxes: { 85 | "needs_bboxes": False, 86 | "img_size": (1024, 512), 87 | "max_tokens": 224, 88 | }, 89 | TaskNames.block_without_boxes: { 90 | "needs_bboxes": False, 91 | "img_size": (1024, 512), 92 | "max_tokens": 768, 93 | }, 94 | TaskNames.layout: { 95 | "needs_bboxes": False, 96 | "img_size": (1024, 1024), 97 | "max_tokens": 200, 98 | }, 99 | TaskNames.table_structure: { 100 | "needs_bboxes": False, 101 | "img_size": (1024, 512), 102 | "max_tokens": 600, 103 | }, 104 | } 105 | 106 | def __init__( 107 | self, 108 | checkpoint=None, 109 | device=settings.TORCH_DEVICE_MODEL, 110 | dtype=None, 111 | attention_implementation: Optional[str] = None, 112 | ): 113 | super().__init__(checkpoint, device, dtype, attention_implementation) 114 | self.prompt_queue = deque() 115 | self.batch_prompt_mapping = None 116 | self.kv_cache = None 117 | 118 | self.beacon_token_interval = self.model.config.beacon_token_interval 119 | 120 | # Setup various tokens on-device 121 | self.device_pad_token = torch.tensor( 122 | self.processor.pad_token_id, device=self.model.device, dtype=torch.long 123 | ) 124 | self.device_beacon_token = torch.tensor( 125 | self.processor.beacon_token_id, device=self.model.device, dtype=torch.long 126 | ) 127 | self.special_token_ids = torch.tensor( 128 | [self.model.config.image_token_id] + self.model.config.register_token_ids, 129 | device=self.model.device, 130 | ) 131 | 132 | self.pad_to_multiple = ( 133 | settings.FOUNDATION_PAD_TO_NEAREST 134 | if settings.FOUNDATION_STATIC_CACHE 135 | else None 136 | ) 137 | 138 | def to(self, device_dtype: torch.device | str | None = None): 139 | super().to(device_dtype) 140 | self.special_token_ids = self.special_token_ids.to(device_dtype) 141 | 142 | def get_encoder_chunk_size(self) -> int: 143 | if settings.FOUNDATION_CHUNK_SIZE is not None: 144 | return settings.FOUNDATION_CHUNK_SIZE 145 | 146 | chunk_size = self.encoder_chunk_size 147 | if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes: 148 | if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes: 149 | chunk_size = self.encoder_chunk_sizes[settings.TORCH_DEVICE_MODEL] 150 | return chunk_size 151 | 152 | def setup_cache(self, batch_size: int, max_cache_len: int, max_sliding_window: int): 153 | kv_cache_cls = StaticOpsCache if settings.FOUNDATION_XLA else DynamicOpsCache 154 | self.kv_cache = kv_cache_cls( 155 | self.model.config, 156 | batch_size, 157 | max_cache_len, 158 | text_sliding_window=max_sliding_window, 159 | device=self.model.device, 160 | dtype=self.model.dtype, 161 | ) 162 | self.prompt_queue.clear() 163 | self.batch_prompt_mapping = {i: None for i in range(batch_size)} 164 | 165 | @property 166 | def num_empty_slots(self): 167 | return sum(v is None for v in self.batch_prompt_mapping.values()) 168 | 169 | @property 170 | def num_active_slots(self): 171 | return 
len(self.batch_prompt_mapping) - self.num_empty_slots 172 | 173 | def prepare_input( 174 | self, 175 | task_names: List[str], 176 | images: List[Image.Image], 177 | input_text: List[str | None], 178 | math_modes: List[bool], 179 | ): 180 | batch = [] 181 | for image, text, task_name, math_mode in zip( 182 | images, input_text, task_names, math_modes 183 | ): 184 | image_size = self.tasks[task_name]["img_size"] 185 | 186 | try: 187 | image = self.processor.scale_to_fit( 188 | image, image_size 189 | ) # Only resizes if out of bounds (max/min) 190 | except cv2.error: 191 | # The image is empty if it can't be resized, so just make a blank image 192 | image = np.zeros((image_size[1], image_size[0], 3), dtype=np.float32) 193 | 194 | # Task input is the same for all tasks for now 195 | text = text or "" 196 | 197 | # Remove input text that exceeds max generation tokens (likely invalid) 198 | if len(text) > self.tasks[task_name]["max_tokens"]: 199 | text = "" 200 | inputs = [ 201 | {"type": "image", "image": image, "rotated": False}, 202 | {"type": "text", "text": text.strip(), "math": math_mode}, 203 | ] 204 | batch.append({"task": task_name, "inputs": inputs}) 205 | 206 | return batch 207 | 208 | def process_outputs( 209 | self, outputs: SuryaModelOutput, max_lookahead_tokens: Optional[int] = None 210 | ) -> ContinuousBatchOutput: 211 | # Predictions are multi-token 212 | lm_logits = outputs["lm_logits"].float() # shape: [batch_size, seq_len, V] 213 | bbox_logits = outputs["bbox_logits"].float() # shape: [batch_size, seq_len, 6] 214 | 215 | if ( 216 | max_lookahead_tokens is not None 217 | and lm_logits.shape[1] > max_lookahead_tokens + 1 218 | ): 219 | lm_logits = lm_logits[:, : max_lookahead_tokens + 1, :] 220 | bbox_logits = bbox_logits[:, : max_lookahead_tokens + 1, :] 221 | 222 | # Get predictions 223 | preds = torch.argmax(lm_logits, dim=-1) 224 | input_ids = preds.to(torch.long) 225 | 226 | # Confidence scores for all tokens 227 | token_probs = F.softmax(lm_logits, dim=-1) 228 | scores = torch.max(token_probs, dim=-1).values # shape: [B, T] 229 | 230 | # Update input boxes 231 | box_preds = bbox_logits * self.model.config.bbox_size 232 | box_preds = box_preds.to(torch.long) 233 | 234 | return ContinuousBatchOutput( 235 | input_ids=input_ids, 236 | preds=preds, 237 | bbox_preds=box_preds, 238 | scores=scores, 239 | token_probs=token_probs, 240 | ) 241 | 242 | # Always left pad with beacons, don't worry about attention masking 243 | def maybe_insert_beacon_tokens( 244 | self, 245 | input_ids: torch.Tensor, 246 | input_boxes: torch.Tensor, 247 | num_predicted_tokens: torch.Tensor, 248 | num_new_tokens: Optional[torch.Tensor] = None, 249 | ) -> Tuple[torch.Tensor, torch.Tensor]: 250 | batch_size, seq_len = ( 251 | input_ids.shape 252 | ) # seq_len can be >1 - In case of multi-token predictions 253 | 254 | # num_predicted tokens **does not include** the current new input_ids, this number is updated **after beacon tokens are inserted** 255 | token_positions = num_predicted_tokens + torch.arange( 256 | 1, seq_len + 1, device=input_ids.device 257 | ).unsqueeze(0) 258 | beacon_positions = token_positions % self.beacon_token_interval == 0 259 | 260 | # If no beacons needed, return original input 261 | needs_beacon = beacon_positions.any(dim=1) # shape: [batch_size] 262 | if not needs_beacon.any(): 263 | if num_new_tokens is None: 264 | num_new_tokens = ( 265 | torch.ones(batch_size, dtype=torch.long, device=input_ids.device) 266 | * seq_len 267 | ) 268 | return input_ids, input_boxes, 
num_new_tokens.squeeze(1) 269 | 270 | beacon_insert_pos = torch.zeros( 271 | batch_size, dtype=torch.long, device=input_ids.device 272 | ) 273 | for i in range(batch_size): 274 | if needs_beacon[i]: 275 | # Find first position that needs beacon 276 | beacon_insert_pos[i] = torch.where(beacon_positions[i])[0] 277 | 278 | # Padded input ids. 279 | new_input_ids = torch.full( 280 | (batch_size, seq_len + 1), 281 | self.device_pad_token, 282 | dtype=input_ids.dtype, 283 | device=input_ids.device, 284 | ) 285 | new_input_boxes = torch.full( 286 | (batch_size, seq_len + 1, 6), 287 | -100, 288 | dtype=input_boxes.dtype, 289 | device=input_boxes.device, 290 | ) 291 | # Fill in tokens for each sequence 292 | for i in range(batch_size): 293 | if needs_beacon[i]: 294 | insert_pos = beacon_insert_pos[i] 295 | new_input_ids[i, insert_pos] = self.device_beacon_token 296 | new_input_boxes[i, insert_pos, :] = -100 297 | if insert_pos > 0: 298 | new_input_ids[i, :insert_pos] = input_ids[i, :insert_pos] 299 | new_input_boxes[i, :insert_pos] = input_boxes[i, :insert_pos] 300 | new_input_ids[i, insert_pos + 1 :] = input_ids[i, insert_pos:] 301 | new_input_boxes[i, insert_pos + 1 :] = input_boxes[i, insert_pos:] 302 | else: 303 | new_input_ids[i, 1:] = input_ids[i, :] 304 | new_input_boxes[i, 1:] = input_boxes[i, :] 305 | 306 | # Calculate valid token counts for both padded and non padded sequences 307 | valid_token_counts = torch.where( 308 | needs_beacon, 309 | torch.tensor(seq_len + 1, device=input_ids.device), 310 | torch.tensor(seq_len, device=input_ids.device), 311 | ) 312 | 313 | return new_input_ids, new_input_boxes, valid_token_counts 314 | 315 | def decode( 316 | self, 317 | current_inputs: Optional[ContinuousBatchInput] = None, 318 | max_lookahead_tokens: Optional[int] = None, 319 | ): 320 | # Note - If we want to use the outputs from the non-last token, we 321 | # need to set the cache position manually to ensure causality. 
The default 322 | # behavior only works for the last token currently 323 | input_ids = current_inputs.input_ids 324 | input_boxes = current_inputs.input_boxes 325 | embed_boxes = current_inputs.needs_bbox_embedding 326 | 327 | position_ids = current_inputs.position_ids 328 | num_predicted_tokens = current_inputs.num_predicted_tokens 329 | num_valid_tokens = current_inputs.num_valid_tokens 330 | batch_size = input_ids.shape[0] 331 | 332 | # Pre-shift the attention mask based on the cache update 333 | self.kv_cache.decode_attention_mask_update( 334 | num_valid_tokens=num_valid_tokens, cache_idxs=list(range(batch_size)) 335 | ) 336 | 337 | cache_position = self.get_cache_position( 338 | input_ids.shape[1], self.kv_cache.attention_mask, prefill=False 339 | ) 340 | with settings.INFERENCE_MODE(): 341 | outputs = self.model( 342 | input_ids=input_ids, 343 | attention_mask=self.kv_cache.attention_mask, 344 | position_ids=position_ids, 345 | cache_position=cache_position, 346 | use_cache=True, 347 | past_key_values=self.kv_cache, 348 | prefill=False, 349 | num_valid_tokens=num_valid_tokens, 350 | input_boxes=input_boxes, 351 | embed_boxes=embed_boxes, 352 | logits_to_keep=1, 353 | ) 354 | 355 | processed_output: ContinuousBatchOutput = self.process_outputs( 356 | outputs, max_lookahead_tokens=max_lookahead_tokens 357 | ) 358 | 359 | input_ids = processed_output.input_ids 360 | input_boxes = processed_output.bbox_preds 361 | 362 | # Update this **before** inserting beacon tokens 363 | tau = settings.FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE 364 | if max_lookahead_tokens is not None: 365 | num_new_tokens = torch.clamp( 366 | ( 367 | processed_output.scores.ge(tau) 368 | .to(torch.long) 369 | .cumprod(dim=1) 370 | .sum(dim=1, keepdim=True) 371 | ), 372 | min=1, 373 | ) 374 | else: 375 | num_new_tokens = input_ids.shape[1] 376 | 377 | num_predicted_tokens += num_new_tokens 378 | input_ids, input_boxes, num_valid_tokens = self.maybe_insert_beacon_tokens( 379 | input_ids, input_boxes, num_predicted_tokens, num_new_tokens 380 | ) 381 | position_ids = position_ids[:, -1:] + torch.arange( 382 | 1, input_ids.shape[1] + 1, device=input_ids.device 383 | ) 384 | # Some of the input sequences may now have left padding tokens, so we want to account for that 385 | # offset is a per-batch offset of the position_ids 386 | offset = (input_ids.shape[1] - num_valid_tokens).unsqueeze(1) 387 | position_ids -= offset 388 | 389 | new_input = ContinuousBatchInput( 390 | input_ids=input_ids, 391 | input_boxes=input_boxes, 392 | position_ids=position_ids, 393 | num_valid_tokens=num_valid_tokens, 394 | num_predicted_tokens=num_predicted_tokens, 395 | needs_bbox_embedding=current_inputs.needs_bbox_embedding, 396 | ) 397 | 398 | return new_input, processed_output 399 | 400 | def pad_and_shift_input_ids_position_ids( 401 | self, 402 | input_ids: torch.Tensor, 403 | bbox_preds: torch.Tensor, 404 | position_ids: torch.Tensor, 405 | new_seq_len: int, 406 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 407 | """ 408 | Pads new_input_ids to match the new seq len with **left padding** 409 | and creates updated position_ids 410 | 411 | Returns: 412 | padded_input_ids (torch.Tensor): [batch_size, current_seq_len] 413 | updated_position_ids (torch.Tensor): [batch_size, current_seq_len] 414 | """ 415 | # No padding 416 | if new_seq_len == input_ids.shape[1]: 417 | return ( 418 | input_ids, 419 | bbox_preds, 420 | position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device), 421 | ) 422 | 423 | pad_len = 
new_seq_len - input_ids.shape[1] 424 | padded_input_ids = torch.nn.functional.pad( 425 | input_ids, (pad_len, 0), value=self.device_pad_token 426 | ) 427 | 428 | padded_bbox_preds = torch.nn.functional.pad( 429 | bbox_preds, (0, 0, pad_len, 0), value=-100 430 | ) 431 | 432 | # Since we have **left padding**, offset the new position_ids by the amount of padding 433 | # This ensures that the **true tokens** get the correct position_ids 434 | # The position_ids assigned to pad tokens do not matter. They are not cached, and not used for outputs 435 | updated_position_ids = position_ids[:, -1:] + torch.arange( 436 | 1, new_seq_len + 1, device=self.model.device 437 | ) 438 | updated_position_ids -= pad_len 439 | 440 | return padded_input_ids, padded_bbox_preds, updated_position_ids 441 | 442 | def get_cache_position( 443 | self, 444 | seq_len: int, 445 | attention_mask: torch.Tensor, 446 | prefill: bool, 447 | ): 448 | batch_size, target_len = attention_mask.shape 449 | base_cache_position = ( 450 | torch.arange(seq_len, device=attention_mask.device) 451 | .unsqueeze(0) 452 | .expand(batch_size, -1) 453 | ) 454 | if prefill: 455 | return base_cache_position 456 | 457 | # This is a (batch_size) tensor, we can add the seq lens here 458 | cache_seqlens = ( 459 | attention_mask 460 | * torch.arange(attention_mask.size(1), device=attention_mask.device) 461 | ).argmax(dim=1).to(torch.int32) + 1 462 | # Needs to be unsqueezed so broadcasting works 463 | return cache_seqlens.unsqueeze(1) + base_cache_position 464 | 465 | def prefill( 466 | self, 467 | current_inputs: Optional[ContinuousBatchInput] = None, 468 | max_lookahead_tokens: Optional[int] = None, 469 | ): 470 | logger.debug(f"Prefilling {self.num_empty_slots} slots") 471 | 472 | prompts: List[FoundationPrompt] = [ 473 | self.prompt_queue.popleft() 474 | for _ in range(min(self.num_empty_slots, len(self.prompt_queue))) 475 | ] 476 | non_active_idxs = [k for k, v in self.batch_prompt_mapping.items() if v is None] 477 | idxs_to_merge = non_active_idxs[: len(prompts)] 478 | 479 | for i, prompt in zip(idxs_to_merge, prompts): 480 | self.batch_prompt_mapping[i] = prompt.id 481 | 482 | needs_bbox_embedding = torch.tensor( 483 | [ 484 | p.task_name in [TaskNames.layout, TaskNames.table_structure] 485 | for p in prompts 486 | ], 487 | dtype=torch.bool, 488 | ) 489 | 490 | batch_input = self.prepare_input( 491 | task_names=[p.task_name for p in prompts], 492 | images=[p.image for p in prompts], 493 | input_text=[p.text for p in prompts], 494 | math_modes=[ 495 | p.math_mode for p in prompts 496 | ], # Pass math mode to the processor 497 | ) 498 | processed_inputs = self.processor( 499 | batch_input, 500 | padding_side="left", 501 | device=self.model.device, 502 | pad_to_multiple=self.pad_to_multiple, 503 | ) 504 | 505 | input_ids = processed_inputs["input_ids"].to(dtype=torch.long) 506 | attention_mask = processed_inputs["attention_mask"].to(dtype=torch.long) 507 | position_ids = processed_inputs["position_ids"].to(dtype=torch.long) 508 | valid_batch_size = len(idxs_to_merge) 509 | 510 | # Keep these off device until later 511 | image_tiles = processed_inputs["image_tiles"].to(dtype=self.model.dtype) 512 | grid_thw = processed_inputs["grid_thw"].to(dtype=torch.long) 513 | 514 | if settings.FOUNDATION_STATIC_CACHE: 515 | input_ids = self.pad_to_batch_size( 516 | input_ids, batch_size=self.kv_cache.max_batch_size 517 | ) 518 | attention_mask = self.pad_to_batch_size( 519 | attention_mask, batch_size=self.kv_cache.max_batch_size 520 | ) 521 | position_ids = 
self.pad_to_batch_size( 522 | position_ids, batch_size=self.kv_cache.max_batch_size 523 | ) 524 | needs_bbox_embedding = self.pad_to_batch_size( 525 | needs_bbox_embedding, batch_size=self.kv_cache.max_batch_size 526 | ) 527 | 528 | # Move to device after padding 529 | input_ids = input_ids.to(device=self.model.device) 530 | attention_mask = attention_mask.to(device=self.model.device) 531 | position_ids = position_ids.to(device=self.model.device) 532 | needs_bbox_embedding = needs_bbox_embedding.to(device=self.model.device) 533 | 534 | # Find text lengths of each 535 | # Oddly, this is optimal on GPU - causes a 30% slowdown if "optimized" 536 | # Be very careful with the type and device of this - can cause 537 | # a big slowdown if put on device 538 | is_special = ( 539 | (input_ids.unsqueeze(-1) == self.special_token_ids).any(-1).cpu() 540 | ) # (batch, seq_len) 541 | text_lengths = [] 542 | for i in range(input_ids.shape[0]): 543 | special_positions = is_special[i].nonzero(as_tuple=True)[0] 544 | if len(special_positions) > 0: 545 | # Assuming special tokens are contiguous at the start 546 | prefix_len = special_positions[-1].item() + 1 547 | else: 548 | prefix_len = 0 549 | text_lengths.append(input_ids.shape[1] - prefix_len) 550 | text_lengths = torch.tensor(text_lengths, dtype=torch.long) 551 | 552 | cache_position = self.get_cache_position( 553 | input_ids.shape[1], attention_mask, prefill=True 554 | ) 555 | with settings.INFERENCE_MODE(): 556 | image_embeddings = self.model.get_image_embeddings( 557 | pixel_values=image_tiles, 558 | grid_thw=grid_thw, 559 | encoder_chunk_size=self.get_encoder_chunk_size(), 560 | valid_batch_size=valid_batch_size, 561 | max_batch_size=self.kv_cache.max_batch_size, 562 | ) 563 | mark_step() 564 | 565 | outputs = self.model( 566 | input_ids=input_ids, 567 | image_embeddings=image_embeddings, 568 | attention_mask=attention_mask, 569 | position_ids=position_ids, 570 | cache_position=cache_position, 571 | inputs_embeds=None, 572 | past_key_values=self.kv_cache, 573 | use_cache=True, 574 | encoder_chunk_size=self.get_encoder_chunk_size(), 575 | cache_idxs=idxs_to_merge, 576 | prefill=True, 577 | num_valid_tokens=None, # Not required during prefill 578 | text_lengths=text_lengths, 579 | valid_batch_size=valid_batch_size, 580 | logits_to_keep=1, 581 | ) 582 | 583 | # Process outputs 584 | processed_outputs = self.process_outputs( 585 | outputs, max_lookahead_tokens=max_lookahead_tokens 586 | ) 587 | # Multi-token prediction 588 | predicted_tokens = processed_outputs.input_ids.shape[1] 589 | num_valid_tokens = ( 590 | torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long) 591 | * predicted_tokens 592 | ) 593 | num_predicted_tokens = ( 594 | torch.ones( 595 | (input_ids.shape[0], 1), device=self.model.device, dtype=torch.long 596 | ) 597 | * predicted_tokens 598 | ) 599 | 600 | self.kv_cache.prefill_attention_mask_update( 601 | attention_mask, idxs_to_merge, valid_batch_size, text_lengths 602 | ) 603 | self.kv_cache.update_text_counts(idxs_to_merge, valid_batch_size, text_lengths) 604 | 605 | full_batch = len(idxs_to_merge) == self.kv_cache.max_batch_size 606 | 607 | # If full batch, then we can ignore current_inputs 608 | if current_inputs is None or full_batch: 609 | new_seq_len = processed_outputs.input_ids.shape[1] 610 | # No padding tokens - So we can safely set position_ids this way 611 | position_ids = position_ids[:, -1:] + torch.arange( 612 | 1, new_seq_len + 1, device=position_ids.device 613 | ) 614 | new_input = 
ContinuousBatchInput( 615 | input_ids=processed_outputs.input_ids, 616 | input_boxes=processed_outputs.bbox_preds, 617 | position_ids=position_ids, 618 | num_valid_tokens=num_valid_tokens, 619 | num_predicted_tokens=num_predicted_tokens, 620 | needs_bbox_embedding=needs_bbox_embedding, 621 | ) 622 | 623 | return ( 624 | new_input, 625 | processed_outputs, 626 | range(processed_outputs.input_ids.shape[0]), 627 | ) 628 | 629 | # Merging inputs for next steps 630 | current_input_ids = current_inputs.input_ids 631 | current_position_ids = current_inputs.position_ids 632 | current_input_boxes = current_inputs.input_boxes 633 | 634 | current_needs_bbox_embedding = current_inputs.needs_bbox_embedding 635 | 636 | assert current_input_ids.shape[1] == current_position_ids.shape[1] 637 | input_ids, bbox_preds, position_ids = self.pad_and_shift_input_ids_position_ids( 638 | processed_outputs.input_ids, 639 | processed_outputs.bbox_preds, 640 | position_ids, 641 | new_seq_len=current_input_ids.shape[1], 642 | ) 643 | 644 | current_input_ids[idxs_to_merge] = input_ids[:valid_batch_size] 645 | current_input_boxes[idxs_to_merge] = bbox_preds[:valid_batch_size] 646 | current_position_ids[idxs_to_merge] = position_ids[:valid_batch_size] 647 | 648 | current_num_valid_tokens = current_inputs.num_valid_tokens 649 | current_num_valid_tokens[idxs_to_merge] = num_valid_tokens[:valid_batch_size] 650 | 651 | current_num_predicted_tokens = current_inputs.num_predicted_tokens 652 | current_num_predicted_tokens[idxs_to_merge] = num_predicted_tokens[ 653 | :valid_batch_size 654 | ] 655 | current_needs_bbox_embedding[idxs_to_merge] = needs_bbox_embedding[ 656 | :valid_batch_size 657 | ] 658 | 659 | new_input = ContinuousBatchInput( 660 | input_ids=current_input_ids, 661 | input_boxes=current_input_boxes, 662 | position_ids=current_position_ids, 663 | num_valid_tokens=current_num_valid_tokens, 664 | num_predicted_tokens=current_num_predicted_tokens, 665 | needs_bbox_embedding=current_needs_bbox_embedding, 666 | ) 667 | 668 | return new_input, processed_outputs, idxs_to_merge 669 | 670 | def get_max_image_token_count( 671 | self, images: list[np.ndarray], tasks: List[TaskNames] 672 | ) -> int: 673 | def compute_scaled_size( 674 | H: int, W: int, max_size: Tuple[int, int] 675 | ) -> Tuple[int, int]: 676 | max_W, max_H = max_size 677 | min_W, min_H = (168, 168) 678 | 679 | current_pixels = H * W 680 | max_pixels = max_H * max_W 681 | min_pixels = min_H * min_W 682 | current_pixels = max(1, current_pixels) # Avoid zero division 683 | 684 | if current_pixels > max_pixels: 685 | scale = (max_pixels / current_pixels) ** 0.5 686 | return math.floor(H * scale), math.floor(W * scale) 687 | elif current_pixels < min_pixels: 688 | scale = (min_pixels / current_pixels) ** 0.5 689 | return math.ceil(H * scale), math.ceil(W * scale) 690 | return H, W 691 | 692 | def get_tile_count(H: int, W: int, factor: int) -> int: 693 | H_bar = math.ceil(H / factor) * factor 694 | W_bar = math.ceil(W / factor) * factor 695 | grid_h = H_bar / self.processor.patch_size 696 | grid_w = W_bar // self.processor.patch_size 697 | return grid_h * grid_w 698 | 699 | max_tokens = 0 700 | factor = self.processor.patch_size * self.processor.merge_size 701 | 702 | for image, task in zip(images, tasks): 703 | H, W = image.shape[:2] 704 | max_size = self.tasks[task]["img_size"] 705 | scaled_H, scaled_W = compute_scaled_size(H, W, max_size) 706 | token_count = get_tile_count(scaled_H, scaled_W, factor) / ( 707 | self.processor.merge_size**2 708 | ) 709 | max_tokens 
= max(max_tokens, token_count) 710 | 711 | # Extra 10 to account for EOS/BOS/Rotation token etc. 712 | return 10 + self.processor.num_register_tokens + int(max_tokens) 713 | 714 | def prediction_loop( 715 | self, 716 | images: List[np.ndarray], 717 | input_texts: List[str], 718 | task_names: List[TaskNames], 719 | batch_size: int | None = None, 720 | max_tokens: int | None = None, 721 | max_sliding_window: int | None = None, 722 | math_mode: bool = True, 723 | drop_repeated_tokens: bool = True, 724 | max_lookahead_tokens: Optional[int] = None, 725 | top_k: int = 0, 726 | tqdm_desc: str = "Recognizing Text" 727 | ) -> tuple: 728 | allowed_tasks = self.tasks.keys() 729 | assert all([task_name in allowed_tasks for task_name in task_names]), ( 730 | f"One or more tasks in {task_names} is not supported. Supported tasks are {allowed_tasks}" 731 | ) 732 | 733 | predicted_tokens = [[] for _ in range(len(images))] 734 | scores = [[] for _ in range(len(images))] 735 | topk_probs = [[] for _ in range(len(images))] 736 | 737 | if batch_size is None: 738 | batch_size = self.get_batch_size() 739 | 740 | batch_size = min(len(images), batch_size) 741 | current_inputs = None 742 | 743 | max_image_tokens = self.get_max_image_token_count(images, task_names) 744 | if max_sliding_window is None: 745 | max_sliding_window = self.model.config.sliding_window 746 | self.setup_cache( 747 | batch_size, 748 | max_cache_len=max_image_tokens + max_sliding_window + self.extra_token_count.get(settings.TORCH_DEVICE_MODEL, 0), 749 | max_sliding_window=max_sliding_window, 750 | ) 751 | 752 | batch_max_tokens = {} 753 | for idx, (img, txt, task) in enumerate(zip(images, input_texts, task_names)): 754 | self.prompt_queue.append( 755 | FoundationPrompt( 756 | id=idx, task_name=task, text=txt, image=img, math_mode=math_mode 757 | ) 758 | ) 759 | batch_max_tokens[idx] = ( 760 | max_tokens 761 | or settings.FOUNDATION_MAX_TOKENS 762 | or self.tasks[task]["max_tokens"] 763 | ) 764 | 765 | overall_max_tokens = max(batch_max_tokens.values()) 766 | 767 | pbar = tqdm( 768 | total=len(self.prompt_queue), 769 | desc=tqdm_desc, 770 | disable=self.disable_tqdm, 771 | ) 772 | 773 | batch_bboxes = torch.zeros(len(images), overall_max_tokens, 6) 774 | batch_pos = [0] * len(images) 775 | 776 | while self.prompt_queue or self.num_active_slots > 0: 777 | if ( 778 | self.num_empty_slots / batch_size 779 | ) >= self.min_prefill_ratio and self.prompt_queue: 780 | updated_inputs, outputs, merge_idxs = self.prefill( 781 | current_inputs, max_lookahead_tokens=0 782 | ) 783 | 784 | predicted_tokens_cpu = outputs.preds.cpu() 785 | scores_cpu = outputs.scores.cpu() 786 | bbox_preds_cpu = outputs.bbox_preds.cpu() 787 | 788 | if top_k > 0: 789 | batch_top_k_probs, batch_top_k_indices = torch.topk( 790 | outputs.token_probs, k=top_k, dim=-1 791 | ) 792 | batch_top_k_probs_cpu = batch_top_k_probs.cpu() 793 | batch_top_k_indices_cpu = batch_top_k_indices.cpu() 794 | 795 | for temp_idx, b_idx in enumerate(merge_idxs): 796 | if self.batch_prompt_mapping[b_idx] is not None: 797 | p_idx = self.batch_prompt_mapping[b_idx] 798 | seq_len = predicted_tokens_cpu.shape[1] 799 | for t_idx in range(seq_len): 800 | token = predicted_tokens_cpu[temp_idx, t_idx].item() 801 | predicted_tokens[p_idx].append(token) 802 | batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[ 803 | temp_idx, t_idx 804 | ] 805 | batch_pos[p_idx] += 1 806 | scores[p_idx].append(scores_cpu[temp_idx, t_idx].item()) 807 | 808 | if top_k > 0: 809 | top_k_scores = { 810 | 
batch_top_k_indices_cpu[temp_idx, t_idx][ 811 | k 812 | ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][ 813 | k 814 | ].item() 815 | for k in range(top_k) 816 | } 817 | topk_probs[p_idx].append(top_k_scores) 818 | 819 | if token in [ 820 | self.processor.eos_token_id, 821 | self.processor.no_output_token, 822 | ]: 823 | self.batch_prompt_mapping[b_idx] = None 824 | pbar.update(1) 825 | break 826 | else: 827 | updated_inputs, outputs = self.decode( 828 | current_inputs, max_lookahead_tokens=max_lookahead_tokens 829 | ) 830 | mark_step() 831 | 832 | predicted_tokens_cpu = outputs.preds.cpu() 833 | scores_cpu = outputs.scores.cpu() 834 | bbox_preds_cpu = outputs.bbox_preds.cpu() 835 | 836 | if top_k > 0: 837 | batch_top_k_probs, batch_top_k_indices = torch.topk( 838 | outputs.token_probs, k=top_k, dim=-1 839 | ) 840 | batch_top_k_probs_cpu = batch_top_k_probs.cpu() 841 | batch_top_k_indices_cpu = batch_top_k_indices.cpu() 842 | 843 | for b_idx, p_idx in self.batch_prompt_mapping.items(): 844 | if p_idx is not None: 845 | seq_len = predicted_tokens_cpu.shape[1] 846 | num_tokens = updated_inputs.num_valid_tokens[b_idx].item() 847 | should_stop = False 848 | 849 | for t_idx in range(seq_len): 850 | # don't use multitoken prediction for lower confidence tokens 851 | if t_idx > 0 and num_tokens < seq_len: 852 | # roll so tokens are right aligned 853 | updated_inputs.input_ids[b_idx] = ( 854 | updated_inputs.input_ids[b_idx].roll( 855 | shifts=seq_len - num_tokens, dims=0 856 | ) 857 | ) 858 | # don't need to roll position_ids because that's handled in `decode` (and when we do beacon tokens) 859 | break 860 | 861 | token = predicted_tokens_cpu[b_idx, t_idx].item() 862 | predicted_tokens[p_idx].append(token) 863 | batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[ 864 | b_idx, t_idx 865 | ] 866 | batch_pos[p_idx] += 1 867 | scores[p_idx].append(scores_cpu[b_idx, t_idx].item()) 868 | 869 | if top_k > 0: 870 | top_k_scores = { 871 | batch_top_k_indices_cpu[temp_idx, t_idx][ 872 | k 873 | ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][ 874 | k 875 | ].item() 876 | for k in range(top_k) 877 | } 878 | topk_probs[p_idx].append(top_k_scores) 879 | 880 | repeats = len(predicted_tokens[p_idx]) >= batch_max_tokens[ 881 | p_idx 882 | ] or ( 883 | drop_repeated_tokens 884 | and detect_repeat_token(predicted_tokens[p_idx]) 885 | and task_names[p_idx] 886 | in [ 887 | TaskNames.ocr_with_boxes, 888 | TaskNames.ocr_without_boxes, 889 | ] 890 | ) 891 | if ( 892 | token 893 | in [ 894 | self.processor.eos_token_id, 895 | self.processor.pad_token_id, 896 | ] 897 | or repeats 898 | ): 899 | should_stop = True 900 | break 901 | 902 | if should_stop: 903 | self.batch_prompt_mapping[b_idx] = None 904 | pbar.update(1) 905 | 906 | # Update inputs and mark XLA step 907 | current_inputs = updated_inputs 908 | 909 | pbar.close() 910 | 911 | del self.kv_cache 912 | self.kv_cache = None 913 | torch.cuda.empty_cache() 914 | 915 | return predicted_tokens, batch_bboxes, scores, topk_probs 916 | ``` -------------------------------------------------------------------------------- /surya/common/donut/encoder.py: -------------------------------------------------------------------------------- ```python 1 | import collections.abc 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.utils.checkpoint 8 | from torch import nn 9 | 10 | from transformers.activations import ACT2FN 11 | from transformers.pytorch_utils import ( 12 | 
find_pruneable_heads_and_indices, 13 | meshgrid, 14 | prune_linear_layer, 15 | ) 16 | from transformers.utils import ModelOutput 17 | from transformers import DonutSwinConfig 18 | 19 | from surya.common.pretrained import SuryaPreTrainedModel 20 | from surya.common.xla import mark_step 21 | 22 | _EXPECTED_OUTPUT_SHAPE = [1, 49, 1024] 23 | 24 | 25 | @dataclass 26 | # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin 27 | class DonutSwinEncoderOutput(ModelOutput): 28 | last_hidden_state: torch.FloatTensor = None 29 | hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 30 | attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 31 | reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 32 | 33 | 34 | @dataclass 35 | class DonutSwinModelOutput(ModelOutput): 36 | last_hidden_state: torch.FloatTensor = None 37 | 38 | 39 | # Copied from transformers.models.swin.modeling_swin.window_partition 40 | def window_partition(input_feature, window_size): 41 | """ 42 | Partitions the given input into windows. 43 | """ 44 | batch_size, height, width, num_channels = input_feature.shape 45 | input_feature = input_feature.view( 46 | batch_size, 47 | height // window_size, 48 | window_size, 49 | width // window_size, 50 | window_size, 51 | num_channels, 52 | ) 53 | windows = ( 54 | input_feature.permute(0, 1, 3, 2, 4, 5) 55 | .contiguous() 56 | .view(-1, window_size, window_size, num_channels) 57 | ) 58 | return windows 59 | 60 | 61 | # Copied from transformers.models.swin.modeling_swin.window_reverse 62 | def window_reverse(windows, window_size, height, width): 63 | """ 64 | Merges windows to produce higher resolution features. 65 | """ 66 | num_channels = windows.shape[-1] 67 | windows = windows.view( 68 | -1, 69 | height // window_size, 70 | width // window_size, 71 | window_size, 72 | window_size, 73 | num_channels, 74 | ) 75 | windows = ( 76 | windows.permute(0, 1, 3, 2, 4, 5) 77 | .contiguous() 78 | .view(-1, height, width, num_channels) 79 | ) 80 | return windows 81 | 82 | 83 | # Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin 84 | class DonutSwinEmbeddings(nn.Module): 85 | """ 86 | Construct the patch and position embeddings. Optionally, also the mask token. 
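
    In `forward`, pixel values are turned into patch embeddings of shape
    (batch_size, num_patches, embed_dim); the second return value is the
    (height, width) of the resulting patch grid.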
87 | """ 88 | 89 | def __init__(self, config, use_mask_token=False): 90 | super().__init__() 91 | 92 | self.patch_embeddings = DonutSwinPatchEmbeddings(config) 93 | num_patches = self.patch_embeddings.num_patches 94 | self.patch_grid = self.patch_embeddings.grid_size 95 | self.mask_token = ( 96 | nn.Parameter(torch.zeros(1, 1, config.embed_dim)) 97 | if use_mask_token 98 | else None 99 | ) 100 | 101 | self.position_embeddings = None 102 | self.row_embeddings = None 103 | self.column_embeddings = None 104 | if config.use_absolute_embeddings: 105 | self.position_embeddings = nn.Parameter( 106 | torch.zeros(1, num_patches + 1, config.embed_dim) 107 | ) 108 | 109 | if hasattr(config, "use_2d_embeddings") and config.use_2d_embeddings: 110 | self.row_embeddings = nn.Parameter( 111 | torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim) 112 | ) 113 | self.column_embeddings = nn.Parameter( 114 | torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim) 115 | ) 116 | 117 | self.norm = nn.LayerNorm(config.embed_dim) 118 | 119 | def interpolate_pos_encoding( 120 | self, embeddings: torch.Tensor, height: int, width: int 121 | ) -> torch.Tensor: 122 | """ 123 | This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher 124 | resolution images. 125 | 126 | Source: 127 | https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 128 | """ 129 | 130 | num_patches = embeddings.shape[1] - 1 131 | num_positions = self.position_embeddings.shape[1] - 1 132 | if num_patches == num_positions and height == width: 133 | return self.position_embeddings 134 | class_pos_embed = self.position_embeddings[:, 0] 135 | patch_pos_embed = self.position_embeddings[:, 1:] 136 | dim = embeddings.shape[-1] 137 | h0 = height // self.config.patch_size 138 | w0 = width // self.config.patch_size 139 | # we add a small number to avoid floating point error in the interpolation 140 | # see discussion at https://github.com/facebookresearch/dino/issues/8 141 | h0, w0 = h0 + 0.1, w0 + 0.1 142 | patch_pos_embed = patch_pos_embed.reshape( 143 | 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim 144 | ) 145 | patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) 146 | patch_pos_embed = nn.functional.interpolate( 147 | patch_pos_embed, 148 | scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), 149 | mode="bicubic", 150 | align_corners=False, 151 | ) 152 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 153 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 154 | 155 | def forward( 156 | self, 157 | pixel_values: Optional[torch.FloatTensor], 158 | bool_masked_pos: Optional[torch.BoolTensor] = None, 159 | interpolate_pos_encoding: bool = False, 160 | ) -> Tuple[torch.Tensor]: 161 | _, num_channels, height, width = pixel_values.shape 162 | embeddings, output_dimensions = self.patch_embeddings(pixel_values) 163 | embeddings = self.norm(embeddings) 164 | batch_size, seq_len, _ = embeddings.size() 165 | 166 | if bool_masked_pos is not None: 167 | mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) 168 | # replace the masked visual tokens by mask_tokens 169 | mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) 170 | embeddings = embeddings * (1.0 - mask) + mask_tokens * mask 171 | 172 | if self.position_embeddings is not None: 173 | if interpolate_pos_encoding: 174 | embeddings = embeddings + self.interpolate_pos_encoding( 175 | embeddings, 
height, width 176 | ) 177 | else: 178 | embeddings = embeddings + self.position_embeddings[:, :seq_len] 179 | 180 | if self.row_embeddings is not None and self.column_embeddings is not None: 181 | # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ... 182 | row_embeddings = self.row_embeddings[ 183 | :, : output_dimensions[0], : 184 | ].repeat_interleave(output_dimensions[1], dim=1) 185 | column_embeddings = self.column_embeddings[ 186 | :, : output_dimensions[1], : 187 | ].repeat(1, output_dimensions[0], 1) 188 | 189 | embeddings = embeddings + row_embeddings + column_embeddings 190 | 191 | return embeddings, output_dimensions 192 | 193 | 194 | # Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin 195 | class DonutSwinPatchEmbeddings(nn.Module): 196 | """ 197 | This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial 198 | `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a 199 | Transformer. 200 | """ 201 | 202 | def __init__(self, config): 203 | super().__init__() 204 | image_size, patch_size = config.image_size, config.patch_size 205 | num_channels, hidden_size = config.num_channels, config.embed_dim 206 | image_size = ( 207 | image_size 208 | if isinstance(image_size, collections.abc.Iterable) 209 | else (image_size, image_size) 210 | ) 211 | patch_size = ( 212 | patch_size 213 | if isinstance(patch_size, collections.abc.Iterable) 214 | else (patch_size, patch_size) 215 | ) 216 | num_patches = (image_size[1] // patch_size[1]) * ( 217 | image_size[0] // patch_size[0] 218 | ) 219 | self.image_size = image_size 220 | self.patch_size = patch_size 221 | self.num_channels = num_channels 222 | self.num_patches = num_patches 223 | self.grid_size = ( 224 | image_size[0] // patch_size[0], 225 | image_size[1] // patch_size[1], 226 | ) 227 | 228 | self.projection = nn.Conv2d( 229 | num_channels, hidden_size, kernel_size=patch_size, stride=patch_size 230 | ) 231 | 232 | def maybe_pad(self, pixel_values, height, width): 233 | if width % self.patch_size[1] != 0: 234 | pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) 235 | pixel_values = nn.functional.pad(pixel_values, pad_values) 236 | if height % self.patch_size[0] != 0: 237 | pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) 238 | pixel_values = nn.functional.pad(pixel_values, pad_values) 239 | return pixel_values 240 | 241 | def forward( 242 | self, pixel_values: Optional[torch.FloatTensor] 243 | ) -> Tuple[torch.Tensor, Tuple[int]]: 244 | _, num_channels, height, width = pixel_values.shape 245 | # pad the input to be divisible by self.patch_size, if needed 246 | pixel_values = self.maybe_pad(pixel_values, height, width) 247 | embeddings = self.projection(pixel_values) 248 | _, _, height, width = embeddings.shape 249 | output_dimensions = (height, width) 250 | embeddings = embeddings.flatten(2).transpose(1, 2) 251 | 252 | return embeddings, output_dimensions 253 | 254 | 255 | # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging 256 | class DonutSwinPatchMerging(nn.Module): 257 | """ 258 | Patch Merging Layer. 259 | 260 | Args: 261 | input_resolution (`Tuple[int]`): 262 | Resolution of input feature. 263 | dim (`int`): 264 | Number of input channels. 265 | norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): 266 | Normalization layer class. 
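
    Shape sketch (illustrative): an input of shape (batch_size, H*W, C) with even
    H and W is split into four spatially interleaved sub-grids, concatenated to
    (batch_size, H/2 * W/2, 4*C), normalized, and projected by `reduction` to
    (batch_size, H/2 * W/2, 2*C).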
267 | """ 268 | 269 | def __init__( 270 | self, 271 | input_resolution: Tuple[int], 272 | dim: int, 273 | norm_layer: nn.Module = nn.LayerNorm, 274 | ) -> None: 275 | super().__init__() 276 | self.input_resolution = input_resolution 277 | self.dim = dim 278 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 279 | self.norm = norm_layer(4 * dim) 280 | 281 | def maybe_pad(self, input_feature, height, width): 282 | should_pad = (height % 2 == 1) or (width % 2 == 1) 283 | if should_pad: 284 | pad_values = (0, 0, 0, width % 2, 0, height % 2) 285 | input_feature = nn.functional.pad(input_feature, pad_values) 286 | 287 | return input_feature 288 | 289 | def forward( 290 | self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int] 291 | ) -> torch.Tensor: 292 | height, width = input_dimensions 293 | # `dim` is height * width 294 | batch_size, dim, num_channels = input_feature.shape 295 | 296 | input_feature = input_feature.view(batch_size, height, width, num_channels) 297 | # pad input to be disible by width and height, if needed 298 | input_feature = self.maybe_pad(input_feature, height, width) 299 | # [batch_size, height/2, width/2, num_channels] 300 | input_feature_0 = input_feature[:, 0::2, 0::2, :] 301 | # [batch_size, height/2, width/2, num_channels] 302 | input_feature_1 = input_feature[:, 1::2, 0::2, :] 303 | # [batch_size, height/2, width/2, num_channels] 304 | input_feature_2 = input_feature[:, 0::2, 1::2, :] 305 | # [batch_size, height/2, width/2, num_channels] 306 | input_feature_3 = input_feature[:, 1::2, 1::2, :] 307 | # batch_size height/2 width/2 4*num_channels 308 | input_feature = torch.cat( 309 | [input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1 310 | ) 311 | input_feature = input_feature.view( 312 | batch_size, -1, 4 * num_channels 313 | ) # batch_size height/2*width/2 4*C 314 | 315 | input_feature = self.norm(input_feature) 316 | input_feature = self.reduction(input_feature) 317 | 318 | return input_feature 319 | 320 | 321 | # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin 322 | class DonutSwinSelfAttention(nn.Module): 323 | def __init__(self, config, dim, num_heads, num_kv_heads, window_size): 324 | super().__init__() 325 | if dim % num_heads != 0: 326 | raise ValueError( 327 | f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" 328 | ) 329 | 330 | self.num_attention_heads = num_heads 331 | self.num_kv_heads = num_kv_heads 332 | self.kv_repeats = self.num_attention_heads // self.num_kv_heads 333 | self.attention_head_size = int(dim / num_heads) 334 | self.all_head_size = self.num_attention_heads * self.attention_head_size 335 | self.kv_head_size = self.num_kv_heads * self.attention_head_size 336 | self.window_size = ( 337 | window_size 338 | if isinstance(window_size, collections.abc.Iterable) 339 | else (window_size, window_size) 340 | ) 341 | 342 | self.relative_position_bias_table = nn.Parameter( 343 | torch.zeros( 344 | (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads 345 | ) 346 | ) 347 | 348 | # get pair-wise relative position index for each token inside the window 349 | coords_h = torch.arange(self.window_size[0]) 350 | coords_w = torch.arange(self.window_size[1]) 351 | coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) 352 | coords_flatten = torch.flatten(coords, 1) 353 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 354 | relative_coords = relative_coords.permute(1, 2, 
0).contiguous() 355 | relative_coords[:, :, 0] += self.window_size[0] - 1 356 | relative_coords[:, :, 1] += self.window_size[1] - 1 357 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 358 | relative_position_index = relative_coords.sum(-1) 359 | self.register_buffer("relative_position_index", relative_position_index) 360 | 361 | self.query = nn.Linear( 362 | self.all_head_size, self.all_head_size, bias=config.qkv_bias 363 | ) 364 | self.key = nn.Linear( 365 | self.all_head_size, self.kv_head_size, bias=config.qkv_bias 366 | ) 367 | self.value = nn.Linear( 368 | self.all_head_size, self.kv_head_size, bias=config.qkv_bias 369 | ) 370 | 371 | def transpose_for_scores(self, x): 372 | new_x_shape = x.size()[:-1] + ( 373 | self.num_attention_heads, 374 | self.attention_head_size, 375 | ) 376 | x = x.view(new_x_shape) 377 | return x.permute(0, 2, 1, 3) 378 | 379 | def transpose_kv_for_scores(self, x, repeats): 380 | new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size) 381 | x = x.view(new_x_shape) 382 | x = x.repeat( 383 | 1, 1, repeats, 1 384 | ) # repeat the values for each key-value head to match query dim 385 | return x.permute(0, 2, 1, 3).contiguous() 386 | 387 | def forward( 388 | self, 389 | hidden_states: torch.Tensor, 390 | attention_mask: Optional[torch.FloatTensor] = None, 391 | head_mask: Optional[torch.FloatTensor] = None, 392 | output_attentions: Optional[bool] = False, 393 | ) -> Tuple[torch.Tensor]: 394 | batch_size, dim, num_channels = hidden_states.shape 395 | mixed_query_layer = self.query(hidden_states) 396 | 397 | # Final is (batch_size, num_attention_heads, seq_len, attention_head_size) 398 | key_layer = self.transpose_kv_for_scores( 399 | self.key(hidden_states), self.kv_repeats 400 | ) 401 | value_layer = self.transpose_kv_for_scores( 402 | self.value(hidden_states), self.kv_repeats 403 | ) 404 | query_layer = self.transpose_for_scores(mixed_query_layer) 405 | 406 | relative_position_bias = self.relative_position_bias_table[ 407 | self.relative_position_index.view(-1) 408 | ] 409 | relative_position_bias = relative_position_bias.view( 410 | self.window_size[0] * self.window_size[1], 411 | self.window_size[0] * self.window_size[1], 412 | -1, 413 | ) 414 | relative_position_bias = ( 415 | relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) 416 | ) 417 | relative_position_bias = relative_position_bias.repeat(batch_size, 1, 1, 1) 418 | 419 | if attention_mask is None: 420 | attention_mask = relative_position_bias 421 | else: 422 | mask_shape = attention_mask.shape[0] 423 | repeat_count = batch_size // mask_shape 424 | attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1) 425 | attention_mask = attention_mask + relative_position_bias 426 | 427 | attn_output = torch.nn.functional.scaled_dot_product_attention( 428 | query_layer, 429 | key_layer, 430 | value_layer, 431 | attn_mask=attention_mask, 432 | dropout_p=0.0, 433 | scale=self.attention_head_size**-0.5, 434 | ) 435 | 436 | attn_output = attn_output.transpose(1, 2).contiguous() 437 | attn_output = attn_output.view(batch_size, dim, num_channels) 438 | 439 | outputs = (attn_output,) 440 | return outputs 441 | 442 | 443 | # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput 444 | class DonutSwinSelfOutput(nn.Module): 445 | def __init__(self, config, dim): 446 | super().__init__() 447 | self.dense = nn.Linear(dim, dim) 448 | 449 | def forward( 450 | self, hidden_states: torch.Tensor, input_tensor: torch.Tensor 451 | ) -> torch.Tensor: 452 | return 
self.dense(hidden_states) 453 | 454 | 455 | # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin 456 | class DonutSwinAttention(nn.Module): 457 | def __init__(self, config, dim, num_heads, num_kv_heads, window_size): 458 | super().__init__() 459 | self.self = DonutSwinSelfAttention( 460 | config, dim, num_heads, num_kv_heads, window_size 461 | ) 462 | self.output = DonutSwinSelfOutput(config, dim) 463 | self.pruned_heads = set() 464 | 465 | def prune_heads(self, heads): 466 | if len(heads) == 0: 467 | return 468 | heads, index = find_pruneable_heads_and_indices( 469 | heads, 470 | self.self.num_attention_heads, 471 | self.self.attention_head_size, 472 | self.pruned_heads, 473 | ) 474 | 475 | # Prune linear layers 476 | self.self.query = prune_linear_layer(self.self.query, index) 477 | self.self.key = prune_linear_layer(self.self.key, index) 478 | self.self.value = prune_linear_layer(self.self.value, index) 479 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 480 | 481 | # Update hyper params and store pruned heads 482 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 483 | self.self.all_head_size = ( 484 | self.self.attention_head_size * self.self.num_attention_heads 485 | ) 486 | self.pruned_heads = self.pruned_heads.union(heads) 487 | 488 | def forward( 489 | self, 490 | hidden_states: torch.Tensor, 491 | attention_mask: Optional[torch.FloatTensor] = None, 492 | head_mask: Optional[torch.FloatTensor] = None, 493 | output_attentions: Optional[bool] = False, 494 | ) -> Tuple[torch.Tensor]: 495 | self_outputs = self.self( 496 | hidden_states, attention_mask, head_mask, output_attentions 497 | ) 498 | attention_output = self.output(self_outputs[0], hidden_states) 499 | outputs = (attention_output,) + self_outputs[ 500 | 1: 501 | ] # add attentions if we output them 502 | return outputs 503 | 504 | 505 | # Copied from transformers.models.swin.modeling_swin.SwinIntermediate 506 | class DonutSwinIntermediate(nn.Module): 507 | def __init__(self, config, dim): 508 | super().__init__() 509 | self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) 510 | if isinstance(config.hidden_act, str): 511 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 512 | else: 513 | self.intermediate_act_fn = config.hidden_act 514 | 515 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 516 | hidden_states = self.dense(hidden_states) 517 | hidden_states = self.intermediate_act_fn(hidden_states) 518 | return hidden_states 519 | 520 | 521 | # Copied from transformers.models.swin.modeling_swin.SwinOutput 522 | class DonutSwinOutput(nn.Module): 523 | def __init__(self, config, dim): 524 | super().__init__() 525 | self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) 526 | 527 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 528 | return self.dense(hidden_states) 529 | 530 | 531 | # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin 532 | class DonutSwinLayer(nn.Module): 533 | def __init__( 534 | self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0 535 | ): 536 | super().__init__() 537 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 538 | self.shift_size = shift_size 539 | self.window_size = config.window_size 540 | self.input_resolution = input_resolution 541 | self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) 542 | self.attention = DonutSwinAttention( 543 | config, dim, num_heads, num_kv_heads, 
window_size=self.window_size 544 | ) 545 | self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) 546 | self.intermediate = DonutSwinIntermediate(config, dim) 547 | self.output = DonutSwinOutput(config, dim) 548 | 549 | def set_shift_and_window_size(self, input_resolution): 550 | if min(input_resolution) <= self.window_size: 551 | # if window size is larger than input resolution, we don't partition windows 552 | self.shift_size = int(0) 553 | self.window_size = ( 554 | torch.min(torch.tensor(input_resolution)) 555 | if torch.jit.is_tracing() 556 | else min(input_resolution) 557 | ) 558 | 559 | def get_attn_mask(self, height, width, dtype, device): 560 | if self.shift_size > 0: 561 | # calculate attention mask for SW-MSA 562 | img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) 563 | height_slices = ( 564 | slice(0, -self.window_size), 565 | slice(-self.window_size, -self.shift_size), 566 | slice(-self.shift_size, None), 567 | ) 568 | width_slices = ( 569 | slice(0, -self.window_size), 570 | slice(-self.window_size, -self.shift_size), 571 | slice(-self.shift_size, None), 572 | ) 573 | count = 0 574 | for height_slice in height_slices: 575 | for width_slice in width_slices: 576 | img_mask[:, height_slice, width_slice, :] = count 577 | count += 1 578 | 579 | mask_windows = window_partition(img_mask, self.window_size) 580 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 581 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 582 | attn_mask = attn_mask.masked_fill( 583 | attn_mask != 0, float(-100.0) 584 | ).masked_fill(attn_mask == 0, float(0.0)) 585 | else: 586 | attn_mask = None 587 | return attn_mask 588 | 589 | def maybe_pad(self, hidden_states, height, width): 590 | pad_right = (self.window_size - width % self.window_size) % self.window_size 591 | pad_bottom = (self.window_size - height % self.window_size) % self.window_size 592 | pad_values = (0, 0, 0, pad_right, 0, pad_bottom) 593 | hidden_states = nn.functional.pad(hidden_states, pad_values) 594 | return hidden_states, pad_values 595 | 596 | def forward( 597 | self, 598 | hidden_states: torch.Tensor, 599 | input_dimensions: Tuple[int, int], 600 | head_mask: Optional[torch.FloatTensor] = None, 601 | output_attentions: Optional[bool] = False, 602 | always_partition: Optional[bool] = False, 603 | ) -> Tuple[torch.Tensor, torch.Tensor]: 604 | if not always_partition: 605 | self.set_shift_and_window_size(input_dimensions) 606 | else: 607 | pass 608 | height, width = input_dimensions 609 | batch_size, _, channels = hidden_states.size() 610 | shortcut = hidden_states 611 | 612 | hidden_states = self.layernorm_before(hidden_states) 613 | 614 | hidden_states = hidden_states.view(batch_size, height, width, channels) 615 | 616 | # pad hidden_states to multiples of window size 617 | hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) 618 | 619 | _, height_pad, width_pad, _ = hidden_states.shape 620 | # cyclic shift 621 | if self.shift_size > 0: 622 | shifted_hidden_states = torch.roll( 623 | hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) 624 | ) 625 | else: 626 | shifted_hidden_states = hidden_states 627 | 628 | # partition windows 629 | hidden_states_windows = window_partition( 630 | shifted_hidden_states, self.window_size 631 | ) 632 | hidden_states_windows = hidden_states_windows.view( 633 | -1, self.window_size * self.window_size, channels 634 | ) 635 | attn_mask = self.get_attn_mask( 636 | height_pad, 637 | width_pad, 
638 | dtype=hidden_states.dtype, 639 | device=hidden_states_windows.device, 640 | ) 641 | 642 | attention_outputs = self.attention( 643 | hidden_states_windows, 644 | attn_mask, 645 | head_mask, 646 | output_attentions=output_attentions, 647 | ) 648 | 649 | attention_output = attention_outputs[0] 650 | 651 | attention_windows = attention_output.view( 652 | -1, self.window_size, self.window_size, channels 653 | ) 654 | shifted_windows = window_reverse( 655 | attention_windows, self.window_size, height_pad, width_pad 656 | ) 657 | 658 | # reverse cyclic shift 659 | if self.shift_size > 0: 660 | attention_windows = torch.roll( 661 | shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2) 662 | ) 663 | else: 664 | attention_windows = shifted_windows 665 | 666 | was_padded = pad_values[3] > 0 or pad_values[5] > 0 667 | if was_padded: 668 | attention_windows = attention_windows[:, :height, :width, :].contiguous() 669 | 670 | attention_windows = attention_windows.view(batch_size, height * width, channels) 671 | 672 | hidden_states = shortcut + attention_windows 673 | 674 | layer_output = self.layernorm_after(hidden_states) 675 | layer_output = self.intermediate(layer_output) 676 | layer_output = hidden_states + self.output(layer_output) 677 | 678 | layer_outputs = ( 679 | (layer_output, attention_outputs[1]) 680 | if output_attentions 681 | else (layer_output,) 682 | ) 683 | return layer_outputs 684 | 685 | 686 | # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin 687 | class DonutSwinStage(nn.Module): 688 | def __init__( 689 | self, 690 | config, 691 | layer_num, 692 | dim, 693 | input_resolution, 694 | depth, 695 | num_heads, 696 | num_kv_heads, 697 | downsample, 698 | ): 699 | super().__init__() 700 | self.config = config 701 | self.dim = dim 702 | self.blocks = nn.ModuleList( 703 | [ 704 | DonutSwinLayer( 705 | config=config, 706 | dim=dim, 707 | input_resolution=input_resolution, 708 | num_heads=num_heads, 709 | num_kv_heads=num_kv_heads, 710 | shift_size=0 if (i % 2 == 0) else config.window_size // 2, 711 | ) 712 | for i in range(depth) 713 | ] 714 | ) 715 | 716 | # patch merging layer 717 | if downsample is not None: 718 | self.downsample = downsample( 719 | input_resolution, dim=dim, norm_layer=nn.LayerNorm 720 | ) 721 | else: 722 | self.downsample = None 723 | 724 | self.pointing = False 725 | 726 | self.positional_encoding = None 727 | if config.use_positional_embeddings: 728 | self.positional_encoding = self.build_2d_sincos_position_embedding( 729 | input_resolution[1], 730 | input_resolution[0], 731 | embed_dim=dim, 732 | ) 733 | 734 | @staticmethod 735 | def build_2d_sincos_position_embedding( 736 | width, 737 | height, 738 | embed_dim=256, 739 | temperature=10000.0, 740 | device="cpu", 741 | dtype=torch.float32, 742 | ): 743 | grid_w = torch.arange(int(width), dtype=dtype, device=device) 744 | grid_h = torch.arange(int(height), dtype=dtype, device=device) 745 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") 746 | if embed_dim % 4 != 0: 747 | raise ValueError( 748 | "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" 749 | ) 750 | pos_dim = embed_dim // 4 751 | omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim 752 | omega = 1.0 / (temperature**omega) 753 | 754 | out_w = grid_w.flatten()[..., None] @ omega[None] 755 | out_h = grid_h.flatten()[..., None] @ omega[None] 756 | 757 | return torch.concat( 758 | [out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1 759 | )[None, :, 
:] 760 | 761 | def forward( 762 | self, 763 | hidden_states: torch.Tensor, 764 | input_dimensions: Tuple[int, int], 765 | head_mask: Optional[torch.FloatTensor] = None, 766 | output_attentions: Optional[bool] = False, 767 | always_partition: Optional[bool] = False, 768 | ) -> Tuple[torch.Tensor]: 769 | height, width = input_dimensions 770 | 771 | if self.positional_encoding is not None: 772 | hidden_states = hidden_states + self.positional_encoding.to( 773 | hidden_states.dtype 774 | ).to(hidden_states.device) 775 | 776 | for i, layer_module in enumerate(self.blocks): 777 | layer_head_mask = head_mask[i] if head_mask is not None else None 778 | 779 | layer_outputs = layer_module( 780 | hidden_states, 781 | input_dimensions, 782 | layer_head_mask, 783 | output_attentions, 784 | always_partition, 785 | ) 786 | 787 | hidden_states = layer_outputs[0] 788 | 789 | hidden_states_before_downsampling = hidden_states 790 | if self.downsample is not None: 791 | height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 792 | output_dimensions = (height, width, height_downsampled, width_downsampled) 793 | hidden_states = self.downsample( 794 | hidden_states_before_downsampling, input_dimensions 795 | ) 796 | else: 797 | output_dimensions = (height, width, height, width) 798 | 799 | stage_outputs = ( 800 | hidden_states, 801 | hidden_states_before_downsampling, 802 | output_dimensions, 803 | ) 804 | 805 | if output_attentions: 806 | stage_outputs += layer_outputs[1:] 807 | return stage_outputs 808 | 809 | 810 | # Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin 811 | class DonutSwinEncoder(nn.Module): 812 | def __init__(self, config, grid_size): 813 | super().__init__() 814 | self.num_layers = len(config.depths) 815 | self.config = config 816 | self.layers = nn.ModuleList( 817 | [ 818 | DonutSwinStage( 819 | config=config, 820 | layer_num=i_layer, 821 | dim=int(config.embed_dim * 2**i_layer), 822 | input_resolution=( 823 | grid_size[0] // (2**i_layer), 824 | grid_size[1] // (2**i_layer), 825 | ), 826 | depth=config.depths[i_layer], 827 | num_heads=config.num_heads[i_layer], 828 | num_kv_heads=config.num_kv_heads[i_layer] 829 | if hasattr(config, "num_kv_heads") 830 | else config.num_heads[i_layer], 831 | downsample=DonutSwinPatchMerging 832 | if (i_layer < self.num_layers - 1) 833 | else None, 834 | ) 835 | for i_layer in range(self.num_layers) 836 | ] 837 | ) 838 | 839 | self.gradient_checkpointing = False 840 | 841 | def forward( 842 | self, 843 | hidden_states: torch.Tensor, 844 | input_dimensions: Tuple[int, int], 845 | head_mask: Optional[torch.FloatTensor] = None, 846 | output_attentions: Optional[bool] = False, 847 | output_hidden_states: Optional[bool] = False, 848 | output_hidden_states_before_downsampling: Optional[bool] = False, 849 | always_partition: Optional[bool] = False, 850 | return_dict: Optional[bool] = True, 851 | ) -> Union[Tuple, DonutSwinEncoderOutput]: 852 | all_hidden_states = () if output_hidden_states else None 853 | all_reshaped_hidden_states = () if output_hidden_states else None 854 | all_self_attentions = () if output_attentions else None 855 | 856 | if output_hidden_states: 857 | batch_size, _, hidden_size = hidden_states.shape 858 | # rearrange b (h w) c -> b c h w 859 | reshaped_hidden_state = hidden_states.view( 860 | batch_size, *input_dimensions, hidden_size 861 | ) 862 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) 863 | all_hidden_states += (hidden_states,) 864 | all_reshaped_hidden_states += 
(reshaped_hidden_state,) 865 | 866 | for i, layer_module in enumerate(self.layers): 867 | layer_head_mask = head_mask[i] if head_mask is not None else None 868 | 869 | if self.gradient_checkpointing and self.training: 870 | layer_outputs = self._gradient_checkpointing_func( 871 | layer_module.__call__, 872 | hidden_states, 873 | input_dimensions, 874 | layer_head_mask, 875 | output_attentions, 876 | always_partition, 877 | ) 878 | else: 879 | layer_outputs = layer_module( 880 | hidden_states, 881 | input_dimensions, 882 | layer_head_mask, 883 | output_attentions, 884 | always_partition, 885 | ) 886 | 887 | hidden_states = layer_outputs[0] 888 | hidden_states_before_downsampling = layer_outputs[1] 889 | output_dimensions = layer_outputs[2] 890 | input_dimensions = (output_dimensions[-2], output_dimensions[-1]) 891 | 892 | if output_hidden_states and output_hidden_states_before_downsampling: 893 | batch_size, _, hidden_size = hidden_states_before_downsampling.shape 894 | # rearrange b (h w) c -> b c h w 895 | # here we use the original (not downsampled) height and width 896 | reshaped_hidden_state = hidden_states_before_downsampling.view( 897 | batch_size, 898 | *(output_dimensions[0], output_dimensions[1]), 899 | hidden_size, 900 | ) 901 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) 902 | all_hidden_states += (hidden_states_before_downsampling,) 903 | all_reshaped_hidden_states += (reshaped_hidden_state,) 904 | elif output_hidden_states and not output_hidden_states_before_downsampling: 905 | batch_size, _, hidden_size = hidden_states.shape 906 | # rearrange b (h w) c -> b c h w 907 | reshaped_hidden_state = hidden_states.view( 908 | batch_size, *input_dimensions, hidden_size 909 | ) 910 | reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) 911 | all_hidden_states += (hidden_states,) 912 | all_reshaped_hidden_states += (reshaped_hidden_state,) 913 | 914 | if output_attentions: 915 | all_self_attentions += layer_outputs[3:] 916 | 917 | if not return_dict: 918 | return tuple( 919 | v 920 | for v in [hidden_states, all_hidden_states, all_self_attentions] 921 | if v is not None 922 | ) 923 | 924 | return DonutSwinEncoderOutput( 925 | last_hidden_state=hidden_states, 926 | hidden_states=all_hidden_states, 927 | attentions=all_self_attentions, 928 | reshaped_hidden_states=all_reshaped_hidden_states, 929 | ) 930 | 931 | 932 | # Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin 933 | class DonutSwinPreTrainedModel(SuryaPreTrainedModel): 934 | """ 935 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 936 | models. 
937 | """ 938 | 939 | config_class = DonutSwinConfig 940 | base_model_prefix = "swin" 941 | main_input_name = "pixel_values" 942 | supports_gradient_checkpointing = True 943 | _no_split_modules = ["DonutSwinStage"] 944 | 945 | def _init_weights(self, module): 946 | """Initialize the weights""" 947 | if isinstance(module, (nn.Linear, nn.Conv2d)): 948 | # Slightly different from the TF version which uses truncated_normal for initialization 949 | # cf https://github.com/pytorch/pytorch/pull/5617 950 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 951 | if module.bias is not None: 952 | module.bias.data.zero_() 953 | elif isinstance(module, nn.LayerNorm): 954 | module.bias.data.zero_() 955 | module.weight.data.fill_(1.0) 956 | ``` -------------------------------------------------------------------------------- /surya/ocr_error/model/encoder.py: -------------------------------------------------------------------------------- ```python 1 | from __future__ import annotations 2 | 3 | import math 4 | from typing import Optional, Set, List, Tuple, Union, Dict 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F, MSELoss, CrossEntropyLoss, BCEWithLogitsLoss 10 | from transformers import apply_chunking_to_forward 11 | from transformers.activations import get_activation 12 | from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput 13 | from transformers.pytorch_utils import ( 14 | find_pruneable_heads_and_indices, 15 | prune_linear_layer, 16 | ) 17 | 18 | from transformers.utils import ( 19 | is_flash_attn_greater_or_equal_2_10, 20 | ) 21 | 22 | from surya.common.pretrained import SuryaPreTrainedModel 23 | 24 | from surya.common.s3 import S3DownloaderMixin 25 | from surya.ocr_error.model.config import DistilBertConfig 26 | 27 | 28 | def _get_unpad_data(attention_mask): 29 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 30 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 31 | max_seqlen_in_batch = seqlens_in_batch.max().item() 32 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 33 | return ( 34 | indices, 35 | cu_seqlens, 36 | max_seqlen_in_batch, 37 | ) 38 | 39 | 40 | def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): 41 | position_enc = np.array( 42 | [ 43 | [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] 44 | for pos in range(n_pos) 45 | ] 46 | ) 47 | out.requires_grad = False 48 | out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) 49 | out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) 50 | out.detach_() 51 | 52 | 53 | class Embeddings(nn.Module): 54 | def __init__(self, config: DistilBertConfig): 55 | super().__init__() 56 | self.word_embeddings = nn.Embedding( 57 | config.vocab_size, config.dim, padding_idx=config.pad_token_id 58 | ) 59 | self.position_embeddings = nn.Embedding( 60 | config.max_position_embeddings, config.dim 61 | ) 62 | 63 | self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12) 64 | self.dropout = nn.Dropout(config.dropout) 65 | self.register_buffer( 66 | "position_ids", 67 | torch.arange(config.max_position_embeddings).expand((1, -1)), 68 | persistent=False, 69 | ) 70 | 71 | def forward( 72 | self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None 73 | ) -> torch.Tensor: 74 | """ 75 | Parameters: 76 | input_ids (torch.Tensor): 77 | torch.tensor(bs, max_seq_length) The token ids to embed. 
78 | input_embeds (*optional*, torch.Tensor): 79 | The pre-computed word embeddings. Can only be passed if the input ids are `None`. 80 | 81 | 82 | Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type 83 | embeddings) 84 | """ 85 | if input_ids is not None: 86 | input_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) 87 | 88 | seq_length = input_embeds.size(1) 89 | 90 | # Setting the position-ids to the registered buffer in constructor, it helps 91 | # when tracing the model without passing position-ids, solves 92 | # isues similar to issue #5664 93 | if hasattr(self, "position_ids"): 94 | position_ids = self.position_ids[:, :seq_length] 95 | else: 96 | position_ids = torch.arange( 97 | seq_length, dtype=torch.long, device=input_ids.device 98 | ) # (max_seq_length) 99 | position_ids = position_ids.unsqueeze(0).expand_as( 100 | input_ids 101 | ) # (bs, max_seq_length) 102 | 103 | position_embeddings = self.position_embeddings( 104 | position_ids 105 | ) # (bs, max_seq_length, dim) 106 | 107 | embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim) 108 | embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) 109 | embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) 110 | return embeddings 111 | 112 | 113 | class MultiHeadSelfAttention(nn.Module): 114 | def __init__(self, config: DistilBertConfig): 115 | super().__init__() 116 | self.config = config 117 | 118 | self.n_heads = config.n_heads 119 | self.dim = config.dim 120 | self.dropout = nn.Dropout(p=config.attention_dropout) 121 | self.is_causal = False 122 | 123 | # Have an even number of multi heads that divide the dimensions 124 | if self.dim % self.n_heads != 0: 125 | # Raise value errors for even multi-head attention nodes 126 | raise ValueError( 127 | f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly" 128 | ) 129 | 130 | self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim) 131 | self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim) 132 | self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) 133 | self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) 134 | 135 | self.pruned_heads: Set[int] = set() 136 | self.attention_head_size = self.dim // self.n_heads 137 | 138 | def prune_heads(self, heads: List[int]): 139 | if len(heads) == 0: 140 | return 141 | heads, index = find_pruneable_heads_and_indices( 142 | heads, self.n_heads, self.attention_head_size, self.pruned_heads 143 | ) 144 | # Prune linear layers 145 | self.q_lin = prune_linear_layer(self.q_lin, index) 146 | self.k_lin = prune_linear_layer(self.k_lin, index) 147 | self.v_lin = prune_linear_layer(self.v_lin, index) 148 | self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) 149 | # Update hyper params 150 | self.n_heads = self.n_heads - len(heads) 151 | self.dim = self.attention_head_size * self.n_heads 152 | self.pruned_heads = self.pruned_heads.union(heads) 153 | 154 | def forward( 155 | self, 156 | query: torch.Tensor, 157 | key: torch.Tensor, 158 | value: torch.Tensor, 159 | mask: torch.Tensor, 160 | head_mask: Optional[torch.Tensor] = None, 161 | output_attentions: bool = False, 162 | ) -> Tuple[torch.Tensor, ...]: 163 | """ 164 | Parameters: 165 | query: torch.tensor(bs, seq_length, dim) 166 | key: torch.tensor(bs, seq_length, dim) 167 | value: torch.tensor(bs, seq_length, dim) 168 | mask: torch.tensor(bs, seq_length) 169 | 170 | Returns: 171 | weights: 
torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, 172 | seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` 173 | """ 174 | bs, q_length, dim = query.size() 175 | k_length = key.size(1) 176 | # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' 177 | # assert key.size() == value.size() 178 | 179 | dim_per_head = self.dim // self.n_heads 180 | 181 | mask_reshp = (bs, 1, 1, k_length) 182 | 183 | def shape(x: torch.Tensor) -> torch.Tensor: 184 | """separate heads""" 185 | return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) 186 | 187 | def unshape(x: torch.Tensor) -> torch.Tensor: 188 | """group heads""" 189 | return ( 190 | x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) 191 | ) 192 | 193 | q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) 194 | k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) 195 | v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) 196 | 197 | q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) 198 | scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) 199 | mask = ( 200 | (mask == 0).view(mask_reshp).expand_as(scores) 201 | ) # (bs, n_heads, q_length, k_length) 202 | scores = scores.masked_fill( 203 | mask, torch.tensor(torch.finfo(scores.dtype).min) 204 | ) # (bs, n_heads, q_length, k_length) 205 | 206 | weights = nn.functional.softmax( 207 | scores, dim=-1 208 | ) # (bs, n_heads, q_length, k_length) 209 | weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) 210 | 211 | # Mask heads if we want to 212 | if head_mask is not None: 213 | weights = weights * head_mask 214 | 215 | context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) 216 | context = unshape(context) # (bs, q_length, dim) 217 | context = self.out_lin(context) # (bs, q_length, dim) 218 | 219 | if output_attentions: 220 | return (context, weights) 221 | else: 222 | return (context,) 223 | 224 | 225 | class DistilBertFlashAttention2(MultiHeadSelfAttention): 226 | """ 227 | DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module 228 | stay untouched. The only required change is in the forward pass, which needs to correctly call the public 229 | API of flash attention and deal with padding tokens in case the input contains any of them. 230 | """ 231 | 232 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 233 | def __init__(self, *args, **kwargs): 234 | super().__init__(*args, **kwargs) 235 | 236 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 237 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 238 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
239 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 240 | 241 | def forward( 242 | self, 243 | query: torch.Tensor, 244 | key: torch.Tensor, 245 | value: torch.Tensor, 246 | mask: torch.Tensor, 247 | head_mask: Optional[torch.Tensor] = None, 248 | output_attentions: bool = False, 249 | ) -> Tuple[torch.Tensor, ...]: 250 | """ 251 | Parameters: 252 | query: torch.tensor(bs, seq_length, dim) 253 | key: torch.tensor(bs, seq_length, dim) 254 | value: torch.tensor(bs, seq_length, dim) 255 | mask: torch.tensor(bs, seq_length) 256 | 257 | Returns: 258 | weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, 259 | seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` 260 | """ 261 | batch_size, q_length, dim = query.size() 262 | 263 | dim_per_head = self.dim // self.n_heads 264 | 265 | def reshape(x: torch.Tensor) -> torch.Tensor: 266 | """separate heads""" 267 | return x.view(batch_size, -1, self.n_heads, dim_per_head) 268 | 269 | # Flash attention requires the input to have the shape 270 | # batch_size x seq_length x head_dim x hidden_dim 271 | query_states = reshape(self.q_lin(query)) 272 | key_states = reshape(self.k_lin(key)) 273 | value_states = reshape(self.v_lin(value)) 274 | 275 | attn_dropout = self.config.attention_dropout if self.training else 0.0 276 | 277 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 278 | # therefore the input hidden states gets silently casted in float32. Hence, we need 279 | # cast them back in the correct dtype just to be sure everything works as expected. 280 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 281 | # in fp32. (LlamaRMSNorm handles it correctly) 282 | 283 | if query_states.dtype == torch.float32: 284 | if torch.is_autocast_enabled(): 285 | target_dtype = torch.get_autocast_gpu_dtype() 286 | # Handle the case where the model is quantized 287 | elif hasattr(self.config, "_pre_quantization_dtype"): 288 | target_dtype = self.config._pre_quantization_dtype 289 | else: 290 | target_dtype = self.q_lin.weight.dtype 291 | 292 | query_states = query_states.to(target_dtype) 293 | key_states = key_states.to(target_dtype) 294 | value_states = value_states.to(target_dtype) 295 | 296 | attn_weights = self._flash_attention_forward( 297 | query_states, key_states, value_states, mask, q_length, dropout=attn_dropout 298 | ) 299 | 300 | attn_weights_reshaped = attn_weights.reshape( 301 | batch_size, q_length, self.n_heads * dim_per_head 302 | ) 303 | attn_output = self.out_lin(attn_weights_reshaped) 304 | 305 | if output_attentions: 306 | return (attn_output, attn_weights) 307 | else: 308 | return (attn_output,) 309 | 310 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False 311 | def _flash_attention_forward( 312 | self, 313 | query_states, 314 | key_states, 315 | value_states, 316 | attention_mask, 317 | query_length, 318 | dropout=0.0, 319 | softmax_scale=None, 320 | ): 321 | """ 322 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 323 | first unpad the input, then computes the attention scores and pad the final attention scores. 
324 | 325 | Args: 326 | query_states (`torch.Tensor`): 327 | Input query states to be passed to Flash Attention API 328 | key_states (`torch.Tensor`): 329 | Input key states to be passed to Flash Attention API 330 | value_states (`torch.Tensor`): 331 | Input value states to be passed to Flash Attention API 332 | attention_mask (`torch.Tensor`): 333 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 334 | position of padding tokens and 1 for the position of non-padding tokens. 335 | dropout (`float`): 336 | Attention dropout 337 | softmax_scale (`float`, *optional*): 338 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 339 | """ 340 | from flash_attn import flash_attn_func, flash_attn_varlen_func 341 | from flash_attn.bert_padding import pad_input 342 | 343 | if not self._flash_attn_uses_top_left_mask: 344 | causal = self.is_causal 345 | else: 346 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 347 | causal = self.is_causal and query_length != 1 348 | 349 | # Contains at least one padding token in the sequence 350 | if attention_mask is not None: 351 | batch_size = query_states.shape[0] 352 | ( 353 | query_states, 354 | key_states, 355 | value_states, 356 | indices_q, 357 | cu_seq_lens, 358 | max_seq_lens, 359 | ) = self._upad_input( 360 | query_states, key_states, value_states, attention_mask, query_length 361 | ) 362 | 363 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 364 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 365 | 366 | attn_output_unpad = flash_attn_varlen_func( 367 | query_states, 368 | key_states, 369 | value_states, 370 | cu_seqlens_q=cu_seqlens_q, 371 | cu_seqlens_k=cu_seqlens_k, 372 | max_seqlen_q=max_seqlen_in_batch_q, 373 | max_seqlen_k=max_seqlen_in_batch_k, 374 | dropout_p=dropout, 375 | softmax_scale=softmax_scale, 376 | causal=causal, 377 | ) 378 | 379 | attn_output = pad_input( 380 | attn_output_unpad, indices_q, batch_size, query_length 381 | ) 382 | else: 383 | attn_output = flash_attn_func( 384 | query_states, 385 | key_states, 386 | value_states, 387 | dropout, 388 | softmax_scale=softmax_scale, 389 | causal=causal, 390 | ) 391 | 392 | return attn_output 393 | 394 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads 395 | def _upad_input( 396 | self, query_layer, key_layer, value_layer, attention_mask, query_length 397 | ): 398 | from flash_attn.bert_padding import index_first_axis, unpad_input 399 | 400 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 401 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 402 | 403 | key_layer = index_first_axis( 404 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 405 | indices_k, 406 | ) 407 | value_layer = index_first_axis( 408 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 409 | indices_k, 410 | ) 411 | if query_length == kv_seq_len: 412 | query_layer = index_first_axis( 413 | query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), 414 | indices_k, 415 | ) 416 | cu_seqlens_q = cu_seqlens_k 417 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 418 | indices_q = indices_k 419 | elif query_length == 1: 420 | max_seqlen_in_batch_q = 1 421 | cu_seqlens_q = torch.arange( 422 | batch_size + 1, dtype=torch.int32, device=query_layer.device 423 
| ) # There is a memcpy here, that is very bad. 424 | indices_q = cu_seqlens_q[:-1] 425 | query_layer = query_layer.squeeze(1) 426 | else: 427 | # The -q_len: slice assumes left padding. 428 | attention_mask = attention_mask[:, -query_length:] 429 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( 430 | query_layer, attention_mask 431 | ) 432 | 433 | return ( 434 | query_layer, 435 | key_layer, 436 | value_layer, 437 | indices_q, 438 | (cu_seqlens_q, cu_seqlens_k), 439 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 440 | ) 441 | 442 | 443 | class FFN(nn.Module): 444 | def __init__(self, config: DistilBertConfig): 445 | super().__init__() 446 | self.dropout = nn.Dropout(p=config.dropout) 447 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 448 | self.seq_len_dim = 1 449 | self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim) 450 | self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim) 451 | self.activation = get_activation(config.activation) 452 | 453 | def forward(self, input: torch.Tensor) -> torch.Tensor: 454 | return apply_chunking_to_forward( 455 | self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input 456 | ) 457 | 458 | def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: 459 | x = self.lin1(input) 460 | x = self.activation(x) 461 | x = self.lin2(x) 462 | x = self.dropout(x) 463 | return x 464 | 465 | 466 | DISTILBERT_ATTENTION_CLASSES = { 467 | "eager": MultiHeadSelfAttention, 468 | "flash_attention_2": DistilBertFlashAttention2, 469 | } 470 | 471 | 472 | class TransformerBlock(nn.Module): 473 | def __init__(self, config: DistilBertConfig): 474 | super().__init__() 475 | 476 | # Have an even number of Configure multi-heads 477 | if config.dim % config.n_heads != 0: 478 | raise ValueError( 479 | f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly" 480 | ) 481 | 482 | self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation]( 483 | config 484 | ) 485 | self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) 486 | 487 | self.ffn = FFN(config) 488 | self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) 489 | 490 | def forward( 491 | self, 492 | x: torch.Tensor, 493 | attn_mask: Optional[torch.Tensor] = None, 494 | head_mask: Optional[torch.Tensor] = None, 495 | output_attentions: bool = False, 496 | ) -> Tuple[torch.Tensor, ...]: 497 | """ 498 | Parameters: 499 | x: torch.tensor(bs, seq_length, dim) 500 | attn_mask: torch.tensor(bs, seq_length) 501 | 502 | Returns: 503 | sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output: 504 | torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. 
505 | """ 506 | # Self-Attention 507 | sa_output = self.attention( 508 | query=x, 509 | key=x, 510 | value=x, 511 | mask=attn_mask, 512 | head_mask=head_mask, 513 | output_attentions=output_attentions, 514 | ) 515 | if output_attentions: 516 | sa_output, sa_weights = ( 517 | sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) 518 | ) 519 | else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples 520 | sa_output = sa_output[0] 521 | 522 | sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) 523 | 524 | # Feed Forward Network 525 | ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) 526 | ffn_output: torch.Tensor = self.output_layer_norm( 527 | ffn_output + sa_output 528 | ) # (bs, seq_length, dim) 529 | 530 | output = (ffn_output,) 531 | if output_attentions: 532 | output = (sa_weights,) + output 533 | return output 534 | 535 | 536 | class Transformer(nn.Module): 537 | def __init__(self, config: DistilBertConfig): 538 | super().__init__() 539 | self.n_layers = config.n_layers 540 | self.layer = nn.ModuleList( 541 | [TransformerBlock(config) for _ in range(config.n_layers)] 542 | ) 543 | self.gradient_checkpointing = False 544 | 545 | def forward( 546 | self, 547 | x: torch.Tensor, 548 | attn_mask: Optional[torch.Tensor] = None, 549 | head_mask: Optional[torch.Tensor] = None, 550 | output_attentions: bool = False, 551 | output_hidden_states: bool = False, 552 | return_dict: Optional[bool] = None, 553 | ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore 554 | """ 555 | Parameters: 556 | x: torch.tensor(bs, seq_length, dim) Input sequence embedded. 557 | attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence. 558 | 559 | Returns: 560 | hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) 561 | layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] 562 | Tuple of length n_layers with the hidden states from each layer. 
563 | Optional: only if output_hidden_states=True 564 | all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] 565 | Tuple of length n_layers with the attention weights from each layer 566 | Optional: only if output_attentions=True 567 | """ 568 | all_hidden_states = () if output_hidden_states else None 569 | all_attentions = () if output_attentions else None 570 | 571 | hidden_state = x 572 | for i, layer_module in enumerate(self.layer): 573 | if output_hidden_states: 574 | all_hidden_states = all_hidden_states + (hidden_state,) 575 | 576 | if self.gradient_checkpointing and self.training: 577 | layer_outputs = self._gradient_checkpointing_func( 578 | layer_module.__call__, 579 | hidden_state, 580 | attn_mask, 581 | head_mask[i], 582 | output_attentions, 583 | ) 584 | else: 585 | layer_outputs = layer_module( 586 | hidden_state, 587 | attn_mask, 588 | head_mask[i], 589 | output_attentions, 590 | ) 591 | 592 | hidden_state = layer_outputs[-1] 593 | 594 | if output_attentions: 595 | if len(layer_outputs) != 2: 596 | raise ValueError( 597 | f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}" 598 | ) 599 | 600 | attentions = layer_outputs[0] 601 | all_attentions = all_attentions + (attentions,) 602 | else: 603 | if len(layer_outputs) != 1: 604 | raise ValueError( 605 | f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}" 606 | ) 607 | 608 | # Add last layer 609 | if output_hidden_states: 610 | all_hidden_states = all_hidden_states + (hidden_state,) 611 | 612 | if not return_dict: 613 | return tuple( 614 | v 615 | for v in [hidden_state, all_hidden_states, all_attentions] 616 | if v is not None 617 | ) 618 | return BaseModelOutput( 619 | last_hidden_state=hidden_state, 620 | hidden_states=all_hidden_states, 621 | attentions=all_attentions, 622 | ) 623 | 624 | 625 | class DistilBertPreTrainedModel(SuryaPreTrainedModel): 626 | """ 627 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 628 | models. 
629 | """ 630 | 631 | config_class = DistilBertConfig 632 | load_tf_weights = None 633 | base_model_prefix = "distilbert" 634 | supports_gradient_checkpointing = True 635 | _supports_flash_attn_2 = True 636 | 637 | def _init_weights(self, module: nn.Module): 638 | """Initialize the weights.""" 639 | if isinstance(module, nn.Linear): 640 | # Slightly different from the TF version which uses truncated_normal for initialization 641 | # cf https://github.com/pytorch/pytorch/pull/5617 642 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 643 | if module.bias is not None: 644 | module.bias.data.zero_() 645 | elif isinstance(module, nn.Embedding): 646 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 647 | if module.padding_idx is not None: 648 | module.weight.data[module.padding_idx].zero_() 649 | elif isinstance(module, nn.LayerNorm): 650 | module.bias.data.zero_() 651 | module.weight.data.fill_(1.0) 652 | elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: 653 | create_sinusoidal_embeddings( 654 | self.config.max_position_embeddings, 655 | self.config.dim, 656 | module.position_embeddings.weight, 657 | ) 658 | 659 | 660 | class DistilBertModel(DistilBertPreTrainedModel): 661 | def __init__(self, config: DistilBertConfig): 662 | super().__init__(config) 663 | 664 | self.embeddings = Embeddings(config) # Embeddings 665 | self.transformer = Transformer(config) # Encoder 666 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 667 | 668 | # Initialize weights and apply final processing 669 | self.post_init() 670 | 671 | def get_position_embeddings(self) -> nn.Embedding: 672 | """ 673 | Returns the position embeddings 674 | """ 675 | return self.embeddings.position_embeddings 676 | 677 | def resize_position_embeddings(self, new_num_position_embeddings: int): 678 | """ 679 | Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. 680 | 681 | Arguments: 682 | new_num_position_embeddings (`int`): 683 | The number of new position embedding matrix. If position embeddings are learned, increasing the size 684 | will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the 685 | end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the 686 | size will add correct vectors at the end following the position encoding algorithm, whereas reducing 687 | the size will remove vectors from the end. 
688 | """ 689 | num_position_embeds_diff = ( 690 | new_num_position_embeddings - self.config.max_position_embeddings 691 | ) 692 | 693 | # no resizing needs to be done if the length stays the same 694 | if num_position_embeds_diff == 0: 695 | return 696 | 697 | self.config.max_position_embeddings = new_num_position_embeddings 698 | 699 | old_position_embeddings_weight = ( 700 | self.embeddings.position_embeddings.weight.clone() 701 | ) 702 | 703 | self.embeddings.position_embeddings = nn.Embedding( 704 | self.config.max_position_embeddings, self.config.dim 705 | ) 706 | 707 | if self.config.sinusoidal_pos_embds: 708 | create_sinusoidal_embeddings( 709 | n_pos=self.config.max_position_embeddings, 710 | dim=self.config.dim, 711 | out=self.position_embeddings.weight, 712 | ) 713 | else: 714 | with torch.no_grad(): 715 | if num_position_embeds_diff > 0: 716 | self.embeddings.position_embeddings.weight[ 717 | :-num_position_embeds_diff 718 | ] = nn.Parameter(old_position_embeddings_weight) 719 | else: 720 | self.embeddings.position_embeddings.weight = nn.Parameter( 721 | old_position_embeddings_weight[:num_position_embeds_diff] 722 | ) 723 | # move position_embeddings to correct device 724 | self.embeddings.position_embeddings.to(self.device) 725 | 726 | def get_input_embeddings(self) -> nn.Embedding: 727 | return self.embeddings.word_embeddings 728 | 729 | def set_input_embeddings(self, new_embeddings: nn.Embedding): 730 | self.embeddings.word_embeddings = new_embeddings 731 | 732 | def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]): 733 | """ 734 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 735 | class PreTrainedModel 736 | """ 737 | for layer, heads in heads_to_prune.items(): 738 | self.transformer.layer[layer].attention.prune_heads(heads) 739 | 740 | def forward( 741 | self, 742 | input_ids: Optional[torch.Tensor] = None, 743 | attention_mask: Optional[torch.Tensor] = None, 744 | head_mask: Optional[torch.Tensor] = None, 745 | inputs_embeds: Optional[torch.Tensor] = None, 746 | output_attentions: Optional[bool] = None, 747 | output_hidden_states: Optional[bool] = None, 748 | return_dict: Optional[bool] = None, 749 | ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: 750 | output_attentions = ( 751 | output_attentions 752 | if output_attentions is not None 753 | else self.config.output_attentions 754 | ) 755 | output_hidden_states = ( 756 | output_hidden_states 757 | if output_hidden_states is not None 758 | else self.config.output_hidden_states 759 | ) 760 | return_dict = ( 761 | return_dict if return_dict is not None else self.config.use_return_dict 762 | ) 763 | 764 | if input_ids is not None and inputs_embeds is not None: 765 | raise ValueError( 766 | "You cannot specify both input_ids and inputs_embeds at the same time" 767 | ) 768 | elif input_ids is not None: 769 | self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) 770 | input_shape = input_ids.size() 771 | elif inputs_embeds is not None: 772 | input_shape = inputs_embeds.size()[:-1] 773 | else: 774 | raise ValueError("You have to specify either input_ids or inputs_embeds") 775 | 776 | device = input_ids.device if input_ids is not None else inputs_embeds.device 777 | 778 | # Prepare head mask if needed 779 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 780 | 781 | embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) 782 | 783 | if self._use_flash_attention_2: 784 | 
attention_mask = ( 785 | attention_mask 786 | if (attention_mask is not None and 0 in attention_mask) 787 | else None 788 | ) 789 | else: 790 | if attention_mask is None: 791 | attention_mask = torch.ones( 792 | input_shape, device=device 793 | ) # (bs, seq_length) 794 | 795 | return self.transformer( 796 | x=embeddings, 797 | attn_mask=attention_mask, 798 | head_mask=head_mask, 799 | output_attentions=output_attentions, 800 | output_hidden_states=output_hidden_states, 801 | return_dict=return_dict, 802 | ) 803 | 804 | 805 | class DistilBertForSequenceClassification(S3DownloaderMixin, DistilBertPreTrainedModel): 806 | def __init__(self, config: DistilBertConfig, **kwargs): 807 | super().__init__(config, **kwargs) 808 | self.num_labels = config.num_labels 809 | self.config = config 810 | 811 | self.distilbert = DistilBertModel(config) 812 | self.pre_classifier = nn.Linear(config.dim, config.dim) 813 | self.classifier = nn.Linear(config.dim, config.num_labels) 814 | self.dropout = nn.Dropout(config.seq_classif_dropout) 815 | 816 | # Initialize weights and apply final processing 817 | self.post_init() 818 | 819 | def get_position_embeddings(self) -> nn.Embedding: 820 | """ 821 | Returns the position embeddings 822 | """ 823 | return self.distilbert.get_position_embeddings() 824 | 825 | def resize_position_embeddings(self, new_num_position_embeddings: int): 826 | """ 827 | Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. 828 | 829 | Arguments: 830 | new_num_position_embeddings (`int`): 831 | The number of new position embedding matrix. If position embeddings are learned, increasing the size 832 | will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the 833 | end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the 834 | size will add correct vectors at the end following the position encoding algorithm, whereas reducing 835 | the size will remove vectors from the end. 836 | """ 837 | self.distilbert.resize_position_embeddings(new_num_position_embeddings) 838 | 839 | def forward( 840 | self, 841 | input_ids: Optional[torch.Tensor] = None, 842 | attention_mask: Optional[torch.Tensor] = None, 843 | head_mask: Optional[torch.Tensor] = None, 844 | inputs_embeds: Optional[torch.Tensor] = None, 845 | labels: Optional[torch.LongTensor] = None, 846 | output_attentions: Optional[bool] = None, 847 | output_hidden_states: Optional[bool] = None, 848 | return_dict: Optional[bool] = None, 849 | ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]: 850 | r""" 851 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 852 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 853 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 854 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
855 | """ 856 | return_dict = ( 857 | return_dict if return_dict is not None else self.config.use_return_dict 858 | ) 859 | 860 | distilbert_output = self.distilbert( 861 | input_ids=input_ids, 862 | attention_mask=attention_mask, 863 | head_mask=head_mask, 864 | inputs_embeds=inputs_embeds, 865 | output_attentions=output_attentions, 866 | output_hidden_states=output_hidden_states, 867 | return_dict=return_dict, 868 | ) 869 | hidden_state = distilbert_output[0] # (bs, seq_len, dim) 870 | pooled_output = hidden_state[:, 0] # (bs, dim) 871 | pooled_output = self.pre_classifier(pooled_output) # (bs, dim) 872 | pooled_output = nn.ReLU()(pooled_output) # (bs, dim) 873 | pooled_output = self.dropout(pooled_output) # (bs, dim) 874 | logits = self.classifier(pooled_output) # (bs, num_labels) 875 | 876 | loss = None 877 | if labels is not None: 878 | if self.config.problem_type is None: 879 | if self.num_labels == 1: 880 | self.config.problem_type = "regression" 881 | elif self.num_labels > 1 and ( 882 | labels.dtype == torch.long or labels.dtype == torch.int 883 | ): 884 | self.config.problem_type = "single_label_classification" 885 | else: 886 | self.config.problem_type = "multi_label_classification" 887 | 888 | if self.config.problem_type == "regression": 889 | loss_fct = MSELoss() 890 | if self.num_labels == 1: 891 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 892 | else: 893 | loss = loss_fct(logits, labels) 894 | elif self.config.problem_type == "single_label_classification": 895 | loss_fct = CrossEntropyLoss() 896 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 897 | elif self.config.problem_type == "multi_label_classification": 898 | loss_fct = BCEWithLogitsLoss() 899 | loss = loss_fct(logits, labels) 900 | 901 | if not return_dict: 902 | output = (logits,) + distilbert_output[1:] 903 | return ((loss,) + output) if loss is not None else output 904 | 905 | return SequenceClassifierOutput( 906 | loss=loss, 907 | logits=logits, 908 | hidden_states=distilbert_output.hidden_states, 909 | attentions=distilbert_output.attentions, 910 | ) 911 | ```
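
The encoder file above ends with `DistilBertForSequenceClassification`, the classification head used by the OCR-error model in this repository: token and position embeddings, a stack of `TransformerBlock`s (eager or flash attention), and a classifier applied to the pooled first token. The sketch below is a minimal illustration of how those pieces fit together on dummy token ids; the hyper-parameter values are illustrative assumptions rather than the shipped checkpoint's settings, and it assumes `DistilBertConfig` accepts the standard DistilBERT fields as keyword arguments with sensible defaults.

```python
# Minimal sketch (illustrative config, not the shipped OCR-error settings).
import torch

from surya.ocr_error.model.config import DistilBertConfig
from surya.ocr_error.model.encoder import DistilBertForSequenceClassification

# Assumed hyper-parameters; dim must be divisible by n_heads.
config = DistilBertConfig(
    vocab_size=30522,
    dim=768,
    n_heads=12,
    n_layers=6,
    hidden_dim=3072,
    num_labels=2,  # e.g. clean OCR text vs. garbled OCR text
)

model = DistilBertForSequenceClassification(config).eval()

# Dummy batch: 4 sequences of 128 token ids plus an all-ones attention mask.
input_ids = torch.randint(0, config.vocab_size, (4, 128))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    output = model(input_ids=input_ids, attention_mask=attention_mask)

# output.logits has shape (batch_size, num_labels); argmax gives the predicted label.
print(output.logits.argmax(dim=-1))
```

Note that `dim` must be divisible by `n_heads`, otherwise the `MultiHeadSelfAttention` constructor raises a `ValueError`.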
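
The DonutSwin encoder earlier on this page optionally adds a fixed 2D sine-cosine positional encoding per stage via `DonutSwinStage.build_2d_sincos_position_embedding`. Because it is a static method, it can be sanity-checked in isolation; the snippet below is a toy shape check, with the import path `surya.common.donut.encoder` assumed from the directory tree and an arbitrary 12 x 8 patch grid.

```python
# Toy shape check of the 2D sin-cos positional encoding built by DonutSwinStage.
# Import path assumed from the repository tree (surya/common/donut/encoder.py).
import torch

from surya.common.donut.encoder import DonutSwinStage

# Grid of 12 x 8 patch positions with a 256-dim embedding (embed_dim must be divisible by 4).
pos = DonutSwinStage.build_2d_sincos_position_embedding(width=12, height=8, embed_dim=256)
print(pos.shape)  # torch.Size([1, 96, 256]): one embedding per patch position
print(pos.dtype)  # torch.float32 by default
```

The returned tensor has shape `(1, width * height, embed_dim)`, matching the `(batch, sequence, channels)` layout that `DonutSwinStage.forward` adds it to before running the window-attention blocks.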