# datalab-to/surya
This is page 5 of 5. Use http://codebase.md/datalab-to/surya?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/surya/foundation/__init__.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | 
  3 | from dataclasses import dataclass
  4 | from typing import List, Optional, Tuple
  5 | from collections import deque
  6 | 
  7 | import cv2
  8 | import numpy as np
  9 | import torch
 10 | import math
 11 | from PIL import Image
 12 | from tqdm import tqdm
 13 | import torch.nn.functional as F
 14 | 
 15 | from surya.common.surya import SuryaModelOutput
 16 | from surya.common.xla import mark_step
 17 | from surya.common.predictor import BasePredictor
 18 | 
 19 | from surya.foundation.loader import FoundationModelLoader
 20 | from surya.foundation.util import (
 21 |     detect_repeat_token,
 22 | )
 23 | from surya.common.surya.schema import TaskNames
 24 | from surya.foundation.cache.dynamic_ops import DynamicOpsCache
 25 | from surya.foundation.cache.static_ops import StaticOpsCache
 26 | 
 27 | from surya.settings import settings
 28 | from surya.logging import get_logger, configure_logging
 29 | 
 30 | configure_logging()
 31 | logger = get_logger()
 32 | 
 33 | 
 34 | @dataclass
 35 | class ContinuousBatchInput:
 36 |     input_ids: torch.Tensor
 37 |     input_boxes: torch.Tensor
 38 |     position_ids: torch.Tensor
 39 |     # input_ids and position_ids may be padded, num_valid_tokens tracks the 'real' counts
 40 |     num_valid_tokens: torch.Tensor
 41 |     # count the number of predicted tokens for each batch element so far
 42 |     num_predicted_tokens: torch.Tensor
 43 |     needs_bbox_embedding: torch.Tensor
 44 | 
 45 | 
 46 | @dataclass
 47 | class ContinuousBatchOutput:
 48 |     input_ids: torch.Tensor
 49 |     preds: torch.Tensor
 50 |     bbox_preds: torch.Tensor
 51 |     scores: torch.Tensor
 52 |     token_probs: torch.Tensor
 53 | 
 54 | 
 55 | @dataclass
 56 | class FoundationPrompt:
 57 |     id: int
 58 |     task_name: TaskNames
 59 |     image: np.ndarray
 60 |     text: str
 61 |     math_mode: bool
 62 | 
 63 | 
 64 | class FoundationPredictor(BasePredictor):
 65 |     model_loader_cls = FoundationModelLoader
 66 |     batch_size = (
 67 |         settings.RECOGNITION_BATCH_SIZE
 68 |     )  # Default to the recognition batch size
 69 |     torch_dtype = None  # No default, loader picks the dtype based on device properties - bf16/fp16
 70 |     default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 64}
 71 |     encoder_chunk_size: int = 4096  # Default chunk size
 72 |     encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768}
 73 |     extra_token_count = {
 74 |         "xla": 128
 75 |     }  # We have to pad the XLA cache since we don't use sliding window
 76 |     min_prefill_ratio: float = 1 if settings.FOUNDATION_XLA else 0.2
 77 |     min_trim_length: int = 50
 78 |     tasks = {
 79 |         TaskNames.ocr_with_boxes: {
 80 |             "needs_bboxes": True,
 81 |             "img_size": (1024, 512),
 82 |             "max_tokens": 224,
 83 |         },
 84 |         TaskNames.ocr_without_boxes: {
 85 |             "needs_bboxes": False,
 86 |             "img_size": (1024, 512),
 87 |             "max_tokens": 224,
 88 |         },
 89 |         TaskNames.block_without_boxes: {
 90 |             "needs_bboxes": False,
 91 |             "img_size": (1024, 512),
 92 |             "max_tokens": 768,
 93 |         },
 94 |         TaskNames.layout: {
 95 |             "needs_bboxes": False,
 96 |             "img_size": (1024, 1024),
 97 |             "max_tokens": 200,
 98 |         },
 99 |         TaskNames.table_structure: {
100 |             "needs_bboxes": False,
101 |             "img_size": (1024, 512),
102 |             "max_tokens": 600,
103 |         },
104 |     }
105 | 
106 |     def __init__(
107 |         self,
108 |         checkpoint=None,
109 |         device=settings.TORCH_DEVICE_MODEL,
110 |         dtype=None,
111 |         attention_implementation: Optional[str] = None,
112 |     ):
113 |         super().__init__(checkpoint, device, dtype, attention_implementation)
114 |         self.prompt_queue = deque()
115 |         self.batch_prompt_mapping = None
116 |         self.kv_cache = None
117 | 
118 |         self.beacon_token_interval = self.model.config.beacon_token_interval
119 | 
120 |         # Setup various tokens on-device
121 |         self.device_pad_token = torch.tensor(
122 |             self.processor.pad_token_id, device=self.model.device, dtype=torch.long
123 |         )
124 |         self.device_beacon_token = torch.tensor(
125 |             self.processor.beacon_token_id, device=self.model.device, dtype=torch.long
126 |         )
127 |         self.special_token_ids = torch.tensor(
128 |             [self.model.config.image_token_id] + self.model.config.register_token_ids,
129 |             device=self.model.device,
130 |         )
131 | 
132 |         self.pad_to_multiple = (
133 |             settings.FOUNDATION_PAD_TO_NEAREST
134 |             if settings.FOUNDATION_STATIC_CACHE
135 |             else None
136 |         )
137 | 
138 |     def to(self, device_dtype: torch.device | str | None = None):
139 |         super().to(device_dtype)
140 |         self.special_token_ids = self.special_token_ids.to(device_dtype)
141 | 
142 |     def get_encoder_chunk_size(self) -> int:
143 |         if settings.FOUNDATION_CHUNK_SIZE is not None:
144 |             return settings.FOUNDATION_CHUNK_SIZE
145 | 
146 |         chunk_size = self.encoder_chunk_size
147 |         if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes:
148 |             # Use the device-specific chunk size when one is defined
149 |             chunk_size = self.encoder_chunk_sizes[settings.TORCH_DEVICE_MODEL]
150 |         return chunk_size
151 | 
152 |     def setup_cache(self, batch_size: int, max_cache_len: int, max_sliding_window: int):
153 |         kv_cache_cls = StaticOpsCache if settings.FOUNDATION_XLA else DynamicOpsCache
154 |         self.kv_cache = kv_cache_cls(
155 |             self.model.config,
156 |             batch_size,
157 |             max_cache_len,
158 |             text_sliding_window=max_sliding_window,
159 |             device=self.model.device,
160 |             dtype=self.model.dtype,
161 |         )
162 |         self.prompt_queue.clear()
163 |         self.batch_prompt_mapping = {i: None for i in range(batch_size)}
164 | 
165 |     @property
166 |     def num_empty_slots(self):
167 |         return sum(v is None for v in self.batch_prompt_mapping.values())
168 | 
169 |     @property
170 |     def num_active_slots(self):
171 |         return len(self.batch_prompt_mapping) - self.num_empty_slots
172 | 
173 |     def prepare_input(
174 |         self,
175 |         task_names: List[str],
176 |         images: List[Image.Image],
177 |         input_text: List[str | None],
178 |         math_modes: List[bool],
179 |     ):
180 |         batch = []
181 |         for image, text, task_name, math_mode in zip(
182 |             images, input_text, task_names, math_modes
183 |         ):
184 |             image_size = self.tasks[task_name]["img_size"]
185 | 
186 |             try:
187 |                 image = self.processor.scale_to_fit(
188 |                     image, image_size
189 |                 )  # Only resizes if out of bounds (max/min)
190 |             except cv2.error:
191 |                 # The image is empty if it can't be resized, so just make a blank image
192 |                 image = np.zeros((image_size[1], image_size[0], 3), dtype=np.float32)
193 | 
194 |             # Task input is the same for all tasks for now
195 |             text = text or ""
196 | 
197 |             # Remove input text that exceeds max generation tokens (likely invalid)
198 |             if len(text) > self.tasks[task_name]["max_tokens"]:
199 |                 text = ""
200 |             inputs = [
201 |                 {"type": "image", "image": image, "rotated": False},
202 |                 {"type": "text", "text": text.strip(), "math": math_mode},
203 |             ]
204 |             batch.append({"task": task_name, "inputs": inputs})
205 | 
206 |         return batch
207 | 
208 |     def process_outputs(
209 |         self, outputs: SuryaModelOutput, max_lookahead_tokens: Optional[int] = None
210 |     ) -> ContinuousBatchOutput:
211 |         # Predictions are multi-token
212 |         lm_logits = outputs["lm_logits"].float()  # shape: [batch_size, seq_len, V]
213 |         bbox_logits = outputs["bbox_logits"].float()  # shape: [batch_size, seq_len, 6]
214 | 
215 |         if (
216 |             max_lookahead_tokens is not None
217 |             and lm_logits.shape[1] > max_lookahead_tokens + 1
218 |         ):
219 |             lm_logits = lm_logits[:, : max_lookahead_tokens + 1, :]
220 |             bbox_logits = bbox_logits[:, : max_lookahead_tokens + 1, :]
221 | 
222 |         # Get predictions
223 |         preds = torch.argmax(lm_logits, dim=-1)
224 |         input_ids = preds.to(torch.long)
225 | 
226 |         # Confidence scores for all tokens
227 |         token_probs = F.softmax(lm_logits, dim=-1)
228 |         scores = torch.max(token_probs, dim=-1).values  # shape: [B, T]
229 | 
230 |         # Update input boxes
231 |         box_preds = bbox_logits * self.model.config.bbox_size
232 |         box_preds = box_preds.to(torch.long)
233 | 
234 |         return ContinuousBatchOutput(
235 |             input_ids=input_ids,
236 |             preds=preds,
237 |             bbox_preds=box_preds,
238 |             scores=scores,
239 |             token_probs=token_probs,
240 |         )
241 | 
242 |     # Always left pad with beacons, don't worry about attention masking
243 |     def maybe_insert_beacon_tokens(
244 |         self,
245 |         input_ids: torch.Tensor,
246 |         input_boxes: torch.Tensor,
247 |         num_predicted_tokens: torch.Tensor,
248 |         num_new_tokens: Optional[torch.Tensor] = None,
249 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
250 |         batch_size, seq_len = (
251 |             input_ids.shape
252 |         )  # seq_len can be >1 - In case of multi-token predictions
253 | 
254 |         # num_predicted_tokens **does not include** the current new input_ids; it is updated **after** beacon tokens are inserted
255 |         token_positions = num_predicted_tokens + torch.arange(
256 |             1, seq_len + 1, device=input_ids.device
257 |         ).unsqueeze(0)
258 |         beacon_positions = token_positions % self.beacon_token_interval == 0
259 | 
260 |         # If no beacons needed, return original input
261 |         needs_beacon = beacon_positions.any(dim=1)  # shape: [batch_size]
262 |         if not needs_beacon.any():
263 |             if num_new_tokens is None:
264 |                 num_new_tokens = (
265 |                     torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
266 |                     * seq_len
267 |                 )
268 |             return input_ids, input_boxes, num_new_tokens.squeeze(1)
269 | 
270 |         beacon_insert_pos = torch.zeros(
271 |             batch_size, dtype=torch.long, device=input_ids.device
272 |         )
273 |         for i in range(batch_size):
274 |             if needs_beacon[i]:
275 |                 # Find first position that needs beacon
276 |                 beacon_insert_pos[i] = torch.where(beacon_positions[i])[0]
277 | 
278 |         # Padded input ids.
279 |         new_input_ids = torch.full(
280 |             (batch_size, seq_len + 1),
281 |             self.device_pad_token,
282 |             dtype=input_ids.dtype,
283 |             device=input_ids.device,
284 |         )
285 |         new_input_boxes = torch.full(
286 |             (batch_size, seq_len + 1, 6),
287 |             -100,
288 |             dtype=input_boxes.dtype,
289 |             device=input_boxes.device,
290 |         )
291 |         # Fill in tokens for each sequence
292 |         for i in range(batch_size):
293 |             if needs_beacon[i]:
294 |                 insert_pos = beacon_insert_pos[i]
295 |                 new_input_ids[i, insert_pos] = self.device_beacon_token
296 |                 new_input_boxes[i, insert_pos, :] = -100
297 |                 if insert_pos > 0:
298 |                     new_input_ids[i, :insert_pos] = input_ids[i, :insert_pos]
299 |                     new_input_boxes[i, :insert_pos] = input_boxes[i, :insert_pos]
300 |                 new_input_ids[i, insert_pos + 1 :] = input_ids[i, insert_pos:]
301 |                 new_input_boxes[i, insert_pos + 1 :] = input_boxes[i, insert_pos:]
302 |             else:
303 |                 new_input_ids[i, 1:] = input_ids[i, :]
304 |                 new_input_boxes[i, 1:] = input_boxes[i, :]
305 | 
306 |         # Calculate valid token counts for both padded and non-padded sequences
307 |         valid_token_counts = torch.where(
308 |             needs_beacon,
309 |             torch.tensor(seq_len + 1, device=input_ids.device),
310 |             torch.tensor(seq_len, device=input_ids.device),
311 |         )
312 | 
313 |         return new_input_ids, new_input_boxes, valid_token_counts
314 | 
315 |     def decode(
316 |         self,
317 |         current_inputs: Optional[ContinuousBatchInput] = None,
318 |         max_lookahead_tokens: Optional[int] = None,
319 |     ):
320 |         # Note - If we want to use the outputs from the non-last token, we
321 |         # need to set the cache position manually to ensure causality. The default
322 |         # behavior only works for the last token currently
323 |         input_ids = current_inputs.input_ids
324 |         input_boxes = current_inputs.input_boxes
325 |         embed_boxes = current_inputs.needs_bbox_embedding
326 | 
327 |         position_ids = current_inputs.position_ids
328 |         num_predicted_tokens = current_inputs.num_predicted_tokens
329 |         num_valid_tokens = current_inputs.num_valid_tokens
330 |         batch_size = input_ids.shape[0]
331 | 
332 |         # Pre-shift the attention mask based on the cache update
333 |         self.kv_cache.decode_attention_mask_update(
334 |             num_valid_tokens=num_valid_tokens, cache_idxs=list(range(batch_size))
335 |         )
336 | 
337 |         cache_position = self.get_cache_position(
338 |             input_ids.shape[1], self.kv_cache.attention_mask, prefill=False
339 |         )
340 |         with settings.INFERENCE_MODE():
341 |             outputs = self.model(
342 |                 input_ids=input_ids,
343 |                 attention_mask=self.kv_cache.attention_mask,
344 |                 position_ids=position_ids,
345 |                 cache_position=cache_position,
346 |                 use_cache=True,
347 |                 past_key_values=self.kv_cache,
348 |                 prefill=False,
349 |                 num_valid_tokens=num_valid_tokens,
350 |                 input_boxes=input_boxes,
351 |                 embed_boxes=embed_boxes,
352 |                 logits_to_keep=1,
353 |             )
354 | 
355 |         processed_output: ContinuousBatchOutput = self.process_outputs(
356 |             outputs, max_lookahead_tokens=max_lookahead_tokens
357 |         )
358 | 
359 |         input_ids = processed_output.input_ids
360 |         input_boxes = processed_output.bbox_preds
361 | 
362 |         # Update this **before** inserting beacon tokens
363 |         tau = settings.FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE
364 |         if max_lookahead_tokens is not None:
365 |             num_new_tokens = torch.clamp(
366 |                 (
367 |                     processed_output.scores.ge(tau)
368 |                     .to(torch.long)
369 |                     .cumprod(dim=1)
370 |                     .sum(dim=1, keepdim=True)
371 |                 ),
372 |                 min=1,
373 |             )
374 |         else:
375 |             num_new_tokens = torch.full((batch_size, 1), input_ids.shape[1], dtype=torch.long, device=input_ids.device)  # keep as a tensor, matching the lookahead branch
376 | 
377 |         num_predicted_tokens += num_new_tokens
378 |         input_ids, input_boxes, num_valid_tokens = self.maybe_insert_beacon_tokens(
379 |             input_ids, input_boxes, num_predicted_tokens, num_new_tokens
380 |         )
381 |         position_ids = position_ids[:, -1:] + torch.arange(
382 |             1, input_ids.shape[1] + 1, device=input_ids.device
383 |         )
384 |         # Some of the input sequences may now have left padding tokens, so we want to account for that
385 |         # offset is a per-batch offset of the position_ids
386 |         offset = (input_ids.shape[1] - num_valid_tokens).unsqueeze(1)
387 |         position_ids -= offset
388 | 
389 |         new_input = ContinuousBatchInput(
390 |             input_ids=input_ids,
391 |             input_boxes=input_boxes,
392 |             position_ids=position_ids,
393 |             num_valid_tokens=num_valid_tokens,
394 |             num_predicted_tokens=num_predicted_tokens,
395 |             needs_bbox_embedding=current_inputs.needs_bbox_embedding,
396 |         )
397 | 
398 |         return new_input, processed_output
399 | 
400 |     def pad_and_shift_input_ids_position_ids(
401 |         self,
402 |         input_ids: torch.Tensor,
403 |         bbox_preds: torch.Tensor,
404 |         position_ids: torch.Tensor,
405 |         new_seq_len: int,
406 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
407 |         """
408 |         Pads new_input_ids to match the new seq len with **left padding**
409 |         and creates updated position_ids
410 | 
411 |         Returns:
412 |             padded_input_ids (torch.Tensor): [batch_size, current_seq_len]
413 |             updated_position_ids (torch.Tensor): [batch_size, current_seq_len]
414 |         """
415 |         # No padding
416 |         if new_seq_len == input_ids.shape[1]:
417 |             return (
418 |                 input_ids,
419 |                 bbox_preds,
420 |                 position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device),
421 |             )
422 | 
423 |         pad_len = new_seq_len - input_ids.shape[1]
424 |         padded_input_ids = torch.nn.functional.pad(
425 |             input_ids, (pad_len, 0), value=self.device_pad_token
426 |         )
427 | 
428 |         padded_bbox_preds = torch.nn.functional.pad(
429 |             bbox_preds, (0, 0, pad_len, 0), value=-100
430 |         )
431 | 
432 |         # Since we have **left padding**, offset the new position_ids by the amount of padding
433 |         # This ensures that the **true tokens** get the correct position_ids
434 |         # The position_ids assigned to pad tokens do not matter. They are not cached, and not used for outputs
435 |         updated_position_ids = position_ids[:, -1:] + torch.arange(
436 |             1, new_seq_len + 1, device=self.model.device
437 |         )
438 |         updated_position_ids -= pad_len
439 | 
440 |         return padded_input_ids, padded_bbox_preds, updated_position_ids
441 | 
442 |     def get_cache_position(
443 |         self,
444 |         seq_len: int,
445 |         attention_mask: torch.Tensor,
446 |         prefill: bool,
447 |     ):
448 |         batch_size, target_len = attention_mask.shape
449 |         base_cache_position = (
450 |             torch.arange(seq_len, device=attention_mask.device)
451 |             .unsqueeze(0)
452 |             .expand(batch_size, -1)
453 |         )
454 |         if prefill:
455 |             return base_cache_position
456 | 
457 |         # This is a (batch_size) tensor, we can add the seq lens here
458 |         cache_seqlens = (
459 |             attention_mask
460 |             * torch.arange(attention_mask.size(1), device=attention_mask.device)
461 |         ).argmax(dim=1).to(torch.int32) + 1
462 |         # Needs to be unsqueezed so broadcasting works
463 |         return cache_seqlens.unsqueeze(1) + base_cache_position
464 | 
465 |     def prefill(
466 |         self,
467 |         current_inputs: Optional[ContinuousBatchInput] = None,
468 |         max_lookahead_tokens: Optional[int] = None,
469 |     ):
470 |         logger.debug(f"Prefilling {self.num_empty_slots} slots")
471 | 
472 |         prompts: List[FoundationPrompt] = [
473 |             self.prompt_queue.popleft()
474 |             for _ in range(min(self.num_empty_slots, len(self.prompt_queue)))
475 |         ]
476 |         non_active_idxs = [k for k, v in self.batch_prompt_mapping.items() if v is None]
477 |         idxs_to_merge = non_active_idxs[: len(prompts)]
478 | 
479 |         for i, prompt in zip(idxs_to_merge, prompts):
480 |             self.batch_prompt_mapping[i] = prompt.id
481 | 
482 |         needs_bbox_embedding = torch.tensor(
483 |             [
484 |                 p.task_name in [TaskNames.layout, TaskNames.table_structure]
485 |                 for p in prompts
486 |             ],
487 |             dtype=torch.bool,
488 |         )
489 | 
490 |         batch_input = self.prepare_input(
491 |             task_names=[p.task_name for p in prompts],
492 |             images=[p.image for p in prompts],
493 |             input_text=[p.text for p in prompts],
494 |             math_modes=[
495 |                 p.math_mode for p in prompts
496 |             ],  # Pass math mode to the processor
497 |         )
498 |         processed_inputs = self.processor(
499 |             batch_input,
500 |             padding_side="left",
501 |             device=self.model.device,
502 |             pad_to_multiple=self.pad_to_multiple,
503 |         )
504 | 
505 |         input_ids = processed_inputs["input_ids"].to(dtype=torch.long)
506 |         attention_mask = processed_inputs["attention_mask"].to(dtype=torch.long)
507 |         position_ids = processed_inputs["position_ids"].to(dtype=torch.long)
508 |         valid_batch_size = len(idxs_to_merge)
509 | 
510 |         # Keep these off device until later
511 |         image_tiles = processed_inputs["image_tiles"].to(dtype=self.model.dtype)
512 |         grid_thw = processed_inputs["grid_thw"].to(dtype=torch.long)
513 | 
514 |         if settings.FOUNDATION_STATIC_CACHE:
515 |             input_ids = self.pad_to_batch_size(
516 |                 input_ids, batch_size=self.kv_cache.max_batch_size
517 |             )
518 |             attention_mask = self.pad_to_batch_size(
519 |                 attention_mask, batch_size=self.kv_cache.max_batch_size
520 |             )
521 |             position_ids = self.pad_to_batch_size(
522 |                 position_ids, batch_size=self.kv_cache.max_batch_size
523 |             )
524 |             needs_bbox_embedding = self.pad_to_batch_size(
525 |                 needs_bbox_embedding, batch_size=self.kv_cache.max_batch_size
526 |             )
527 | 
528 |         # Move to device after padding
529 |         input_ids = input_ids.to(device=self.model.device)
530 |         attention_mask = attention_mask.to(device=self.model.device)
531 |         position_ids = position_ids.to(device=self.model.device)
532 |         needs_bbox_embedding = needs_bbox_embedding.to(device=self.model.device)
533 | 
534 |         # Find the text length of each prompt
535 |         # Oddly, this is optimal on GPU - causes a 30% slowdown if "optimized"
536 |         # Be very careful with the type and device of this - can cause
537 |         # a big slowdown if put on device
538 |         is_special = (
539 |             (input_ids.unsqueeze(-1) == self.special_token_ids).any(-1).cpu()
540 |         )  # (batch, seq_len)
541 |         text_lengths = []
542 |         for i in range(input_ids.shape[0]):
543 |             special_positions = is_special[i].nonzero(as_tuple=True)[0]
544 |             if len(special_positions) > 0:
545 |                 # Assuming special tokens are contiguous at the start
546 |                 prefix_len = special_positions[-1].item() + 1
547 |             else:
548 |                 prefix_len = 0
549 |             text_lengths.append(input_ids.shape[1] - prefix_len)
550 |         text_lengths = torch.tensor(text_lengths, dtype=torch.long)
551 | 
552 |         cache_position = self.get_cache_position(
553 |             input_ids.shape[1], attention_mask, prefill=True
554 |         )
555 |         with settings.INFERENCE_MODE():
556 |             image_embeddings = self.model.get_image_embeddings(
557 |                 pixel_values=image_tiles,
558 |                 grid_thw=grid_thw,
559 |                 encoder_chunk_size=self.get_encoder_chunk_size(),
560 |                 valid_batch_size=valid_batch_size,
561 |                 max_batch_size=self.kv_cache.max_batch_size,
562 |             )
563 |             mark_step()
564 | 
565 |             outputs = self.model(
566 |                 input_ids=input_ids,
567 |                 image_embeddings=image_embeddings,
568 |                 attention_mask=attention_mask,
569 |                 position_ids=position_ids,
570 |                 cache_position=cache_position,
571 |                 inputs_embeds=None,
572 |                 past_key_values=self.kv_cache,
573 |                 use_cache=True,
574 |                 encoder_chunk_size=self.get_encoder_chunk_size(),
575 |                 cache_idxs=idxs_to_merge,
576 |                 prefill=True,
577 |                 num_valid_tokens=None,  # Not required during prefill
578 |                 text_lengths=text_lengths,
579 |                 valid_batch_size=valid_batch_size,
580 |                 logits_to_keep=1,
581 |             )
582 | 
583 |         # Process outputs
584 |         processed_outputs = self.process_outputs(
585 |             outputs, max_lookahead_tokens=max_lookahead_tokens
586 |         )
587 |         # Multi-token prediction
588 |         predicted_tokens = processed_outputs.input_ids.shape[1]
589 |         num_valid_tokens = (
590 |             torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long)
591 |             * predicted_tokens
592 |         )
593 |         num_predicted_tokens = (
594 |             torch.ones(
595 |                 (input_ids.shape[0], 1), device=self.model.device, dtype=torch.long
596 |             )
597 |             * predicted_tokens
598 |         )
599 | 
600 |         self.kv_cache.prefill_attention_mask_update(
601 |             attention_mask, idxs_to_merge, valid_batch_size, text_lengths
602 |         )
603 |         self.kv_cache.update_text_counts(idxs_to_merge, valid_batch_size, text_lengths)
604 | 
605 |         full_batch = len(idxs_to_merge) == self.kv_cache.max_batch_size
606 | 
607 |         # If full batch, then we can ignore current_inputs
608 |         if current_inputs is None or full_batch:
609 |             new_seq_len = processed_outputs.input_ids.shape[1]
610 |             # No padding tokens - So we can safely set position_ids this way
611 |             position_ids = position_ids[:, -1:] + torch.arange(
612 |                 1, new_seq_len + 1, device=position_ids.device
613 |             )
614 |             new_input = ContinuousBatchInput(
615 |                 input_ids=processed_outputs.input_ids,
616 |                 input_boxes=processed_outputs.bbox_preds,
617 |                 position_ids=position_ids,
618 |                 num_valid_tokens=num_valid_tokens,
619 |                 num_predicted_tokens=num_predicted_tokens,
620 |                 needs_bbox_embedding=needs_bbox_embedding,
621 |             )
622 | 
623 |             return (
624 |                 new_input,
625 |                 processed_outputs,
626 |                 range(processed_outputs.input_ids.shape[0]),
627 |             )
628 | 
629 |         # Merging inputs for next steps
630 |         current_input_ids = current_inputs.input_ids
631 |         current_position_ids = current_inputs.position_ids
632 |         current_input_boxes = current_inputs.input_boxes
633 | 
634 |         current_needs_bbox_embedding = current_inputs.needs_bbox_embedding
635 | 
636 |         assert current_input_ids.shape[1] == current_position_ids.shape[1]
637 |         input_ids, bbox_preds, position_ids = self.pad_and_shift_input_ids_position_ids(
638 |             processed_outputs.input_ids,
639 |             processed_outputs.bbox_preds,
640 |             position_ids,
641 |             new_seq_len=current_input_ids.shape[1],
642 |         )
643 | 
644 |         current_input_ids[idxs_to_merge] = input_ids[:valid_batch_size]
645 |         current_input_boxes[idxs_to_merge] = bbox_preds[:valid_batch_size]
646 |         current_position_ids[idxs_to_merge] = position_ids[:valid_batch_size]
647 | 
648 |         current_num_valid_tokens = current_inputs.num_valid_tokens
649 |         current_num_valid_tokens[idxs_to_merge] = num_valid_tokens[:valid_batch_size]
650 | 
651 |         current_num_predicted_tokens = current_inputs.num_predicted_tokens
652 |         current_num_predicted_tokens[idxs_to_merge] = num_predicted_tokens[
653 |             :valid_batch_size
654 |         ]
655 |         current_needs_bbox_embedding[idxs_to_merge] = needs_bbox_embedding[
656 |             :valid_batch_size
657 |         ]
658 | 
659 |         new_input = ContinuousBatchInput(
660 |             input_ids=current_input_ids,
661 |             input_boxes=current_input_boxes,
662 |             position_ids=current_position_ids,
663 |             num_valid_tokens=current_num_valid_tokens,
664 |             num_predicted_tokens=current_num_predicted_tokens,
665 |             needs_bbox_embedding=current_needs_bbox_embedding,
666 |         )
667 | 
668 |         return new_input, processed_outputs, idxs_to_merge
669 | 
670 |     def get_max_image_token_count(
671 |         self, images: list[np.ndarray], tasks: List[TaskNames]
672 |     ) -> int:
673 |         def compute_scaled_size(
674 |             H: int, W: int, max_size: Tuple[int, int]
675 |         ) -> Tuple[int, int]:
676 |             max_W, max_H = max_size
677 |             min_W, min_H = (168, 168)
678 | 
679 |             current_pixels = H * W
680 |             max_pixels = max_H * max_W
681 |             min_pixels = min_H * min_W
682 |             current_pixels = max(1, current_pixels)  # Avoid zero division
683 | 
684 |             if current_pixels > max_pixels:
685 |                 scale = (max_pixels / current_pixels) ** 0.5
686 |                 return math.floor(H * scale), math.floor(W * scale)
687 |             elif current_pixels < min_pixels:
688 |                 scale = (min_pixels / current_pixels) ** 0.5
689 |                 return math.ceil(H * scale), math.ceil(W * scale)
690 |             return H, W
691 | 
692 |         def get_tile_count(H: int, W: int, factor: int) -> int:
693 |             H_bar = math.ceil(H / factor) * factor
694 |             W_bar = math.ceil(W / factor) * factor
695 |             grid_h = H_bar // self.processor.patch_size
696 |             grid_w = W_bar // self.processor.patch_size
697 |             return grid_h * grid_w
698 | 
699 |         max_tokens = 0
700 |         factor = self.processor.patch_size * self.processor.merge_size
701 | 
702 |         for image, task in zip(images, tasks):
703 |             H, W = image.shape[:2]
704 |             max_size = self.tasks[task]["img_size"]
705 |             scaled_H, scaled_W = compute_scaled_size(H, W, max_size)
706 |             token_count = get_tile_count(scaled_H, scaled_W, factor) / (
707 |                 self.processor.merge_size**2
708 |             )
709 |             max_tokens = max(max_tokens, token_count)
710 | 
711 |         # Extra 10 to account for EOS/BOS/Rotation token etc.
712 |         return 10 + self.processor.num_register_tokens + int(max_tokens)
713 | 
714 |     def prediction_loop(
715 |         self,
716 |         images: List[np.ndarray],
717 |         input_texts: List[str],
718 |         task_names: List[TaskNames],
719 |         batch_size: int | None = None,
720 |         max_tokens: int | None = None,
721 |         max_sliding_window: int | None = None,
722 |         math_mode: bool = True,
723 |         drop_repeated_tokens: bool = True,
724 |         max_lookahead_tokens: Optional[int] = None,
725 |         top_k: int = 0,
726 |         tqdm_desc: str = "Recognizing Text"
727 |     ) -> tuple:
728 |         allowed_tasks = self.tasks.keys()
729 |         assert all([task_name in allowed_tasks for task_name in task_names]), (
730 |             f"One or more tasks in {task_names} is not supported. Supported tasks are {allowed_tasks}"
731 |         )
732 | 
733 |         predicted_tokens = [[] for _ in range(len(images))]
734 |         scores = [[] for _ in range(len(images))]
735 |         topk_probs = [[] for _ in range(len(images))]
736 | 
737 |         if batch_size is None:
738 |             batch_size = self.get_batch_size()
739 | 
740 |         batch_size = min(len(images), batch_size)
741 |         current_inputs = None
742 | 
743 |         max_image_tokens = self.get_max_image_token_count(images, task_names)
744 |         if max_sliding_window is None:
745 |             max_sliding_window = self.model.config.sliding_window
746 |         self.setup_cache(
747 |             batch_size,
748 |             max_cache_len=max_image_tokens + max_sliding_window + self.extra_token_count.get(settings.TORCH_DEVICE_MODEL, 0),
749 |             max_sliding_window=max_sliding_window,
750 |         )
751 | 
752 |         batch_max_tokens = {}
753 |         for idx, (img, txt, task) in enumerate(zip(images, input_texts, task_names)):
754 |             self.prompt_queue.append(
755 |                 FoundationPrompt(
756 |                     id=idx, task_name=task, text=txt, image=img, math_mode=math_mode
757 |                 )
758 |             )
759 |             batch_max_tokens[idx] = (
760 |                 max_tokens
761 |                 or settings.FOUNDATION_MAX_TOKENS
762 |                 or self.tasks[task]["max_tokens"]
763 |             )
764 | 
765 |         overall_max_tokens = max(batch_max_tokens.values())
766 | 
767 |         pbar = tqdm(
768 |             total=len(self.prompt_queue),
769 |             desc=tqdm_desc,
770 |             disable=self.disable_tqdm,
771 |         )
772 | 
773 |         batch_bboxes = torch.zeros(len(images), overall_max_tokens, 6)
774 |         batch_pos = [0] * len(images)
775 | 
776 |         while self.prompt_queue or self.num_active_slots > 0:
777 |             if (
778 |                 self.num_empty_slots / batch_size
779 |             ) >= self.min_prefill_ratio and self.prompt_queue:
780 |                 updated_inputs, outputs, merge_idxs = self.prefill(
781 |                     current_inputs, max_lookahead_tokens=0
782 |                 )
783 | 
784 |                 predicted_tokens_cpu = outputs.preds.cpu()
785 |                 scores_cpu = outputs.scores.cpu()
786 |                 bbox_preds_cpu = outputs.bbox_preds.cpu()
787 | 
788 |                 if top_k > 0:
789 |                     batch_top_k_probs, batch_top_k_indices = torch.topk(
790 |                         outputs.token_probs, k=top_k, dim=-1
791 |                     )
792 |                     batch_top_k_probs_cpu = batch_top_k_probs.cpu()
793 |                     batch_top_k_indices_cpu = batch_top_k_indices.cpu()
794 | 
795 |                 for temp_idx, b_idx in enumerate(merge_idxs):
796 |                     if self.batch_prompt_mapping[b_idx] is not None:
797 |                         p_idx = self.batch_prompt_mapping[b_idx]
798 |                         seq_len = predicted_tokens_cpu.shape[1]
799 |                         for t_idx in range(seq_len):
800 |                             token = predicted_tokens_cpu[temp_idx, t_idx].item()
801 |                             predicted_tokens[p_idx].append(token)
802 |                             batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[
803 |                                 temp_idx, t_idx
804 |                             ]
805 |                             batch_pos[p_idx] += 1
806 |                             scores[p_idx].append(scores_cpu[temp_idx, t_idx].item())
807 | 
808 |                             if top_k > 0:
809 |                                 top_k_scores = {
810 |                                     batch_top_k_indices_cpu[temp_idx, t_idx][
811 |                                         k
812 |                                     ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][
813 |                                         k
814 |                                     ].item()
815 |                                     for k in range(top_k)
816 |                                 }
817 |                                 topk_probs[p_idx].append(top_k_scores)
818 | 
819 |                             if token in [
820 |                                 self.processor.eos_token_id,
821 |                                 self.processor.no_output_token,
822 |                             ]:
823 |                                 self.batch_prompt_mapping[b_idx] = None
824 |                                 pbar.update(1)
825 |                                 break
826 |             else:
827 |                 updated_inputs, outputs = self.decode(
828 |                     current_inputs, max_lookahead_tokens=max_lookahead_tokens
829 |                 )
830 |                 mark_step()
831 | 
832 |                 predicted_tokens_cpu = outputs.preds.cpu()
833 |                 scores_cpu = outputs.scores.cpu()
834 |                 bbox_preds_cpu = outputs.bbox_preds.cpu()
835 | 
836 |                 if top_k > 0:
837 |                     batch_top_k_probs, batch_top_k_indices = torch.topk(
838 |                         outputs.token_probs, k=top_k, dim=-1
839 |                     )
840 |                     batch_top_k_probs_cpu = batch_top_k_probs.cpu()
841 |                     batch_top_k_indices_cpu = batch_top_k_indices.cpu()
842 | 
843 |                 for b_idx, p_idx in self.batch_prompt_mapping.items():
844 |                     if p_idx is not None:
845 |                         seq_len = predicted_tokens_cpu.shape[1]
846 |                         num_tokens = updated_inputs.num_valid_tokens[b_idx].item()
847 |                         should_stop = False
848 | 
849 |                         for t_idx in range(seq_len):
850 |                             # don't use multitoken prediction for lower confidence tokens
851 |                             if t_idx > 0 and num_tokens < seq_len:
852 |                                 # roll so tokens are right aligned
853 |                                 updated_inputs.input_ids[b_idx] = (
854 |                                     updated_inputs.input_ids[b_idx].roll(
855 |                                         shifts=seq_len - num_tokens, dims=0
856 |                                     )
857 |                                 )
858 |                                 # don't need to roll position_ids because that's handled in `decode` (and when we do beacon tokens)
859 |                                 break
860 | 
861 |                             token = predicted_tokens_cpu[b_idx, t_idx].item()
862 |                             predicted_tokens[p_idx].append(token)
863 |                             batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[
864 |                                 b_idx, t_idx
865 |                             ]
866 |                             batch_pos[p_idx] += 1
867 |                             scores[p_idx].append(scores_cpu[b_idx, t_idx].item())
868 | 
869 |                             if top_k > 0:
870 |                                 top_k_scores = {
871 |                                     batch_top_k_indices_cpu[b_idx, t_idx][
872 |                                         k
873 |                                     ].item(): batch_top_k_probs_cpu[b_idx, t_idx][
874 |                                         k
875 |                                     ].item()
876 |                                     for k in range(top_k)
877 |                                 }
878 |                                 topk_probs[p_idx].append(top_k_scores)
879 | 
880 |                             repeats = len(predicted_tokens[p_idx]) >= batch_max_tokens[
881 |                                 p_idx
882 |                             ] or (
883 |                                 drop_repeated_tokens
884 |                                 and detect_repeat_token(predicted_tokens[p_idx])
885 |                                 and task_names[p_idx]
886 |                                 in [
887 |                                     TaskNames.ocr_with_boxes,
888 |                                     TaskNames.ocr_without_boxes,
889 |                                 ]
890 |                             )
891 |                             if (
892 |                                 token
893 |                                 in [
894 |                                     self.processor.eos_token_id,
895 |                                     self.processor.pad_token_id,
896 |                                 ]
897 |                                 or repeats
898 |                             ):
899 |                                 should_stop = True
900 |                                 break
901 | 
902 |                         if should_stop:
903 |                             self.batch_prompt_mapping[b_idx] = None
904 |                             pbar.update(1)
905 | 
906 |             # Update inputs and mark XLA step
907 |             current_inputs = updated_inputs
908 | 
909 |         pbar.close()
910 | 
911 |         del self.kv_cache
912 |         self.kv_cache = None
913 |         torch.cuda.empty_cache()
914 | 
915 |         return predicted_tokens, batch_bboxes, scores, topk_probs
916 | 
```
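
The file above implements the continuous-batching foundation predictor shared by the task-specific predictors. As a quick orientation, the sketch below shows how `prediction_loop` could be driven directly, based only on the signatures visible in this file. It assumes the `BasePredictor` constructor (defined elsewhere in the repository) resolves the default checkpoint, device, and dtype when none are given, and `page.png` is a placeholder path.

```python
# Minimal, illustrative driver for FoundationPredictor.prediction_loop (not part of the repository).
# Assumption: BasePredictor.__init__ loads the default checkpoint; inferred from the code above,
# not verified here.
import numpy as np
from PIL import Image

from surya.foundation import FoundationPredictor
from surya.common.surya.schema import TaskNames

predictor = FoundationPredictor()  # checkpoint/device/dtype resolved by BasePredictor

# prediction_loop expects numpy arrays (H, W, C), one per prompt
image = np.asarray(Image.open("page.png").convert("RGB"))

predicted_tokens, batch_bboxes, scores, topk_probs = predictor.prediction_loop(
    images=[image],
    input_texts=[""],                       # no text prompt for plain OCR
    task_names=[TaskNames.ocr_with_boxes],  # bbox-aware OCR task from the `tasks` table above
    top_k=0,                                # skip per-token top-k probabilities
)

# predicted_tokens[0] is the list of generated token ids for the single prompt,
# batch_bboxes[0, :len(predicted_tokens[0])] holds the matching bbox predictions
# (scaled by model.config.bbox_size in process_outputs), and scores[0] the per-token confidences.
```

In practice the task-specific predictors in this repository presumably wrap this loop and decode the token ids with the processor; the snippet only illustrates the loop's inputs and outputs.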

--------------------------------------------------------------------------------
/surya/common/donut/encoder.py:
--------------------------------------------------------------------------------

```python
  1 | import collections.abc
  2 | import math
  3 | from dataclasses import dataclass
  4 | from typing import Optional, Tuple, Union
  5 | 
  6 | import torch
  7 | import torch.utils.checkpoint
  8 | from torch import nn
  9 | 
 10 | from transformers.activations import ACT2FN
 11 | from transformers.pytorch_utils import (
 12 |     find_pruneable_heads_and_indices,
 13 |     meshgrid,
 14 |     prune_linear_layer,
 15 | )
 16 | from transformers.utils import ModelOutput
 17 | from transformers import DonutSwinConfig
 18 | 
 19 | from surya.common.pretrained import SuryaPreTrainedModel
 20 | from surya.common.xla import mark_step
 21 | 
 22 | _EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]
 23 | 
 24 | 
 25 | @dataclass
 26 | # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
 27 | class DonutSwinEncoderOutput(ModelOutput):
 28 |     last_hidden_state: torch.FloatTensor = None
 29 |     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
 30 |     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
 31 |     reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
 32 | 
 33 | 
 34 | @dataclass
 35 | class DonutSwinModelOutput(ModelOutput):
 36 |     last_hidden_state: torch.FloatTensor = None
 37 | 
 38 | 
 39 | # Copied from transformers.models.swin.modeling_swin.window_partition
 40 | def window_partition(input_feature, window_size):
 41 |     """
 42 |     Partitions the given input into windows.
 43 |     """
 44 |     batch_size, height, width, num_channels = input_feature.shape
 45 |     input_feature = input_feature.view(
 46 |         batch_size,
 47 |         height // window_size,
 48 |         window_size,
 49 |         width // window_size,
 50 |         window_size,
 51 |         num_channels,
 52 |     )
 53 |     windows = (
 54 |         input_feature.permute(0, 1, 3, 2, 4, 5)
 55 |         .contiguous()
 56 |         .view(-1, window_size, window_size, num_channels)
 57 |     )
 58 |     return windows
 59 | 
 60 | 
 61 | # Copied from transformers.models.swin.modeling_swin.window_reverse
 62 | def window_reverse(windows, window_size, height, width):
 63 |     """
 64 |     Merges windows to produce higher resolution features.
 65 |     """
 66 |     num_channels = windows.shape[-1]
 67 |     windows = windows.view(
 68 |         -1,
 69 |         height // window_size,
 70 |         width // window_size,
 71 |         window_size,
 72 |         window_size,
 73 |         num_channels,
 74 |     )
 75 |     windows = (
 76 |         windows.permute(0, 1, 3, 2, 4, 5)
 77 |         .contiguous()
 78 |         .view(-1, height, width, num_channels)
 79 |     )
 80 |     return windows
 81 | 
 82 | 
 83 | # Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
 84 | class DonutSwinEmbeddings(nn.Module):
 85 |     """
 86 |     Construct the patch and position embeddings. Optionally, also the mask token.
 87 |     """
 88 | 
 89 |     def __init__(self, config, use_mask_token=False):
 90 |         super().__init__()
 91 |         self.config = config  # referenced by interpolate_pos_encoding
 92 |         self.patch_embeddings = DonutSwinPatchEmbeddings(config)
 93 |         num_patches = self.patch_embeddings.num_patches
 94 |         self.patch_grid = self.patch_embeddings.grid_size
 95 |         self.mask_token = (
 96 |             nn.Parameter(torch.zeros(1, 1, config.embed_dim))
 97 |             if use_mask_token
 98 |             else None
 99 |         )
100 | 
101 |         self.position_embeddings = None
102 |         self.row_embeddings = None
103 |         self.column_embeddings = None
104 |         if config.use_absolute_embeddings:
105 |             self.position_embeddings = nn.Parameter(
106 |                 torch.zeros(1, num_patches + 1, config.embed_dim)
107 |             )
108 | 
109 |         if hasattr(config, "use_2d_embeddings") and config.use_2d_embeddings:
110 |             self.row_embeddings = nn.Parameter(
111 |                 torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim)
112 |             )
113 |             self.column_embeddings = nn.Parameter(
114 |                 torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim)
115 |             )
116 | 
117 |         self.norm = nn.LayerNorm(config.embed_dim)
118 | 
119 |     def interpolate_pos_encoding(
120 |         self, embeddings: torch.Tensor, height: int, width: int
121 |     ) -> torch.Tensor:
122 |         """
123 |         This method interpolates the pre-trained position encodings so the model can be used on
124 |         higher-resolution images.
125 | 
126 |         Source:
127 |         https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
128 |         """
129 | 
130 |         num_patches = embeddings.shape[1] - 1
131 |         num_positions = self.position_embeddings.shape[1] - 1
132 |         if num_patches == num_positions and height == width:
133 |             return self.position_embeddings
134 |         class_pos_embed = self.position_embeddings[:, 0]
135 |         patch_pos_embed = self.position_embeddings[:, 1:]
136 |         dim = embeddings.shape[-1]
137 |         h0 = height // self.config.patch_size
138 |         w0 = width // self.config.patch_size
139 |         # we add a small number to avoid floating point error in the interpolation
140 |         # see discussion at https://github.com/facebookresearch/dino/issues/8
141 |         h0, w0 = h0 + 0.1, w0 + 0.1
142 |         patch_pos_embed = patch_pos_embed.reshape(
143 |             1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
144 |         )
145 |         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
146 |         patch_pos_embed = nn.functional.interpolate(
147 |             patch_pos_embed,
148 |             scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
149 |             mode="bicubic",
150 |             align_corners=False,
151 |         )
152 |         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
153 |         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
154 | 
155 |     def forward(
156 |         self,
157 |         pixel_values: Optional[torch.FloatTensor],
158 |         bool_masked_pos: Optional[torch.BoolTensor] = None,
159 |         interpolate_pos_encoding: bool = False,
160 |     ) -> Tuple[torch.Tensor]:
161 |         _, num_channels, height, width = pixel_values.shape
162 |         embeddings, output_dimensions = self.patch_embeddings(pixel_values)
163 |         embeddings = self.norm(embeddings)
164 |         batch_size, seq_len, _ = embeddings.size()
165 | 
166 |         if bool_masked_pos is not None:
167 |             mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
168 |             # replace the masked visual tokens by mask_tokens
169 |             mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
170 |             embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
171 | 
172 |         if self.position_embeddings is not None:
173 |             if interpolate_pos_encoding:
174 |                 embeddings = embeddings + self.interpolate_pos_encoding(
175 |                     embeddings, height, width
176 |                 )
177 |             else:
178 |                 embeddings = embeddings + self.position_embeddings[:, :seq_len]
179 | 
180 |         if self.row_embeddings is not None and self.column_embeddings is not None:
181 |             # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ...
182 |             row_embeddings = self.row_embeddings[
183 |                 :, : output_dimensions[0], :
184 |             ].repeat_interleave(output_dimensions[1], dim=1)
185 |             column_embeddings = self.column_embeddings[
186 |                 :, : output_dimensions[1], :
187 |             ].repeat(1, output_dimensions[0], 1)
188 | 
189 |             embeddings = embeddings + row_embeddings + column_embeddings
190 | 
191 |         return embeddings, output_dimensions
192 | 
193 | 
194 | # Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
195 | class DonutSwinPatchEmbeddings(nn.Module):
196 |     """
197 |     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
198 |     `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
199 |     Transformer.
200 |     """
201 | 
202 |     def __init__(self, config):
203 |         super().__init__()
204 |         image_size, patch_size = config.image_size, config.patch_size
205 |         num_channels, hidden_size = config.num_channels, config.embed_dim
206 |         image_size = (
207 |             image_size
208 |             if isinstance(image_size, collections.abc.Iterable)
209 |             else (image_size, image_size)
210 |         )
211 |         patch_size = (
212 |             patch_size
213 |             if isinstance(patch_size, collections.abc.Iterable)
214 |             else (patch_size, patch_size)
215 |         )
216 |         num_patches = (image_size[1] // patch_size[1]) * (
217 |             image_size[0] // patch_size[0]
218 |         )
219 |         self.image_size = image_size
220 |         self.patch_size = patch_size
221 |         self.num_channels = num_channels
222 |         self.num_patches = num_patches
223 |         self.grid_size = (
224 |             image_size[0] // patch_size[0],
225 |             image_size[1] // patch_size[1],
226 |         )
227 | 
228 |         self.projection = nn.Conv2d(
229 |             num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
230 |         )
231 | 
232 |     def maybe_pad(self, pixel_values, height, width):
233 |         if width % self.patch_size[1] != 0:
234 |             pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
235 |             pixel_values = nn.functional.pad(pixel_values, pad_values)
236 |         if height % self.patch_size[0] != 0:
237 |             pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
238 |             pixel_values = nn.functional.pad(pixel_values, pad_values)
239 |         return pixel_values
240 | 
241 |     def forward(
242 |         self, pixel_values: Optional[torch.FloatTensor]
243 |     ) -> Tuple[torch.Tensor, Tuple[int]]:
244 |         _, num_channels, height, width = pixel_values.shape
245 |         # pad the input to be divisible by self.patch_size, if needed
246 |         pixel_values = self.maybe_pad(pixel_values, height, width)
247 |         embeddings = self.projection(pixel_values)
248 |         _, _, height, width = embeddings.shape
249 |         output_dimensions = (height, width)
250 |         embeddings = embeddings.flatten(2).transpose(1, 2)
251 | 
252 |         return embeddings, output_dimensions
253 | 
254 | 
255 | # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
256 | class DonutSwinPatchMerging(nn.Module):
257 |     """
258 |     Patch Merging Layer.
259 | 
260 |     Args:
261 |         input_resolution (`Tuple[int]`):
262 |             Resolution of input feature.
263 |         dim (`int`):
264 |             Number of input channels.
265 |         norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
266 |             Normalization layer class.
267 |     """
268 | 
269 |     def __init__(
270 |         self,
271 |         input_resolution: Tuple[int],
272 |         dim: int,
273 |         norm_layer: nn.Module = nn.LayerNorm,
274 |     ) -> None:
275 |         super().__init__()
276 |         self.input_resolution = input_resolution
277 |         self.dim = dim
278 |         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
279 |         self.norm = norm_layer(4 * dim)
280 | 
281 |     def maybe_pad(self, input_feature, height, width):
282 |         should_pad = (height % 2 == 1) or (width % 2 == 1)
283 |         if should_pad:
284 |             pad_values = (0, 0, 0, width % 2, 0, height % 2)
285 |             input_feature = nn.functional.pad(input_feature, pad_values)
286 | 
287 |         return input_feature
288 | 
289 |     def forward(
290 |         self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
291 |     ) -> torch.Tensor:
292 |         height, width = input_dimensions
293 |         # `dim` is height * width
294 |         batch_size, dim, num_channels = input_feature.shape
295 | 
296 |         input_feature = input_feature.view(batch_size, height, width, num_channels)
297 |         # pad input so height and width are divisible by 2, if needed
298 |         input_feature = self.maybe_pad(input_feature, height, width)
299 |         # [batch_size, height/2, width/2, num_channels]
300 |         input_feature_0 = input_feature[:, 0::2, 0::2, :]
301 |         # [batch_size, height/2, width/2, num_channels]
302 |         input_feature_1 = input_feature[:, 1::2, 0::2, :]
303 |         # [batch_size, height/2, width/2, num_channels]
304 |         input_feature_2 = input_feature[:, 0::2, 1::2, :]
305 |         # [batch_size, height/2, width/2, num_channels]
306 |         input_feature_3 = input_feature[:, 1::2, 1::2, :]
307 |         # batch_size height/2 width/2 4*num_channels
308 |         input_feature = torch.cat(
309 |             [input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
310 |         )
311 |         input_feature = input_feature.view(
312 |             batch_size, -1, 4 * num_channels
313 |         )  # batch_size height/2*width/2 4*C
314 | 
315 |         input_feature = self.norm(input_feature)
316 |         input_feature = self.reduction(input_feature)
317 | 
318 |         return input_feature
319 | 
320 | 
321 | # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin
322 | class DonutSwinSelfAttention(nn.Module):
323 |     def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
324 |         super().__init__()
325 |         if dim % num_heads != 0:
326 |             raise ValueError(
327 |                 f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
328 |             )
329 | 
330 |         self.num_attention_heads = num_heads
331 |         self.num_kv_heads = num_kv_heads
332 |         self.kv_repeats = self.num_attention_heads // self.num_kv_heads
333 |         self.attention_head_size = int(dim / num_heads)
334 |         self.all_head_size = self.num_attention_heads * self.attention_head_size
335 |         self.kv_head_size = self.num_kv_heads * self.attention_head_size
336 |         self.window_size = (
337 |             window_size
338 |             if isinstance(window_size, collections.abc.Iterable)
339 |             else (window_size, window_size)
340 |         )
341 | 
342 |         self.relative_position_bias_table = nn.Parameter(
343 |             torch.zeros(
344 |                 (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
345 |             )
346 |         )
347 | 
348 |         # get pair-wise relative position index for each token inside the window
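    |         # Every pairwise offset (dy, dx) between two positions in a window lies in
    |         # [-(Wh-1), Wh-1] x [-(Ww-1), Ww-1]; the shifts and scaling below map each
    |         # offset to a unique row of relative_position_bias_table.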
349 |         coords_h = torch.arange(self.window_size[0])
350 |         coords_w = torch.arange(self.window_size[1])
351 |         coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
352 |         coords_flatten = torch.flatten(coords, 1)
353 |         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
354 |         relative_coords = relative_coords.permute(1, 2, 0).contiguous()
355 |         relative_coords[:, :, 0] += self.window_size[0] - 1
356 |         relative_coords[:, :, 1] += self.window_size[1] - 1
357 |         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
358 |         relative_position_index = relative_coords.sum(-1)
359 |         self.register_buffer("relative_position_index", relative_position_index)
360 | 
361 |         self.query = nn.Linear(
362 |             self.all_head_size, self.all_head_size, bias=config.qkv_bias
363 |         )
364 |         self.key = nn.Linear(
365 |             self.all_head_size, self.kv_head_size, bias=config.qkv_bias
366 |         )
367 |         self.value = nn.Linear(
368 |             self.all_head_size, self.kv_head_size, bias=config.qkv_bias
369 |         )
370 | 
371 |     def transpose_for_scores(self, x):
372 |         new_x_shape = x.size()[:-1] + (
373 |             self.num_attention_heads,
374 |             self.attention_head_size,
375 |         )
376 |         x = x.view(new_x_shape)
377 |         return x.permute(0, 2, 1, 3)
378 | 
379 |     def transpose_kv_for_scores(self, x, repeats):
380 |         new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size)
381 |         x = x.view(new_x_shape)
382 |         x = x.repeat(
383 |             1, 1, repeats, 1
384 |         )  # repeat the values for each key-value head to match query dim
385 |         return x.permute(0, 2, 1, 3).contiguous()
386 | 
387 |     def forward(
388 |         self,
389 |         hidden_states: torch.Tensor,
390 |         attention_mask: Optional[torch.FloatTensor] = None,
391 |         head_mask: Optional[torch.FloatTensor] = None,
392 |         output_attentions: Optional[bool] = False,
393 |     ) -> Tuple[torch.Tensor]:
394 |         batch_size, dim, num_channels = hidden_states.shape
395 |         mixed_query_layer = self.query(hidden_states)
396 | 
397 |         # Final is (batch_size, num_attention_heads, seq_len, attention_head_size)
398 |         key_layer = self.transpose_kv_for_scores(
399 |             self.key(hidden_states), self.kv_repeats
400 |         )
401 |         value_layer = self.transpose_kv_for_scores(
402 |             self.value(hidden_states), self.kv_repeats
403 |         )
404 |         query_layer = self.transpose_for_scores(mixed_query_layer)
405 | 
406 |         relative_position_bias = self.relative_position_bias_table[
407 |             self.relative_position_index.view(-1)
408 |         ]
409 |         relative_position_bias = relative_position_bias.view(
410 |             self.window_size[0] * self.window_size[1],
411 |             self.window_size[0] * self.window_size[1],
412 |             -1,
413 |         )
414 |         relative_position_bias = (
415 |             relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
416 |         )
417 |         relative_position_bias = relative_position_bias.repeat(batch_size, 1, 1, 1)
418 | 
419 |         if attention_mask is None:
420 |             attention_mask = relative_position_bias
421 |         else:
422 |             mask_shape = attention_mask.shape[0]
423 |             repeat_count = batch_size // mask_shape
424 |             attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)
425 |             attention_mask = attention_mask + relative_position_bias
426 | 
427 |         attn_output = torch.nn.functional.scaled_dot_product_attention(
428 |             query_layer,
429 |             key_layer,
430 |             value_layer,
431 |             attn_mask=attention_mask,
432 |             dropout_p=0.0,
433 |             scale=self.attention_head_size**-0.5,
434 |         )
435 | 
436 |         attn_output = attn_output.transpose(1, 2).contiguous()
437 |         attn_output = attn_output.view(batch_size, dim, num_channels)
438 | 
439 |         outputs = (attn_output,)
440 |         return outputs
441 | 
442 | 
443 | # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
444 | class DonutSwinSelfOutput(nn.Module):
445 |     def __init__(self, config, dim):
446 |         super().__init__()
447 |         self.dense = nn.Linear(dim, dim)
448 | 
449 |     def forward(
450 |         self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
451 |     ) -> torch.Tensor:
452 |         return self.dense(hidden_states)
453 | 
454 | 
455 | # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
456 | class DonutSwinAttention(nn.Module):
457 |     def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
458 |         super().__init__()
459 |         self.self = DonutSwinSelfAttention(
460 |             config, dim, num_heads, num_kv_heads, window_size
461 |         )
462 |         self.output = DonutSwinSelfOutput(config, dim)
463 |         self.pruned_heads = set()
464 | 
465 |     def prune_heads(self, heads):
466 |         if len(heads) == 0:
467 |             return
468 |         heads, index = find_pruneable_heads_and_indices(
469 |             heads,
470 |             self.self.num_attention_heads,
471 |             self.self.attention_head_size,
472 |             self.pruned_heads,
473 |         )
474 | 
475 |         # Prune linear layers
476 |         self.self.query = prune_linear_layer(self.self.query, index)
477 |         self.self.key = prune_linear_layer(self.self.key, index)
478 |         self.self.value = prune_linear_layer(self.self.value, index)
479 |         self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
480 | 
481 |         # Update hyper params and store pruned heads
482 |         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
483 |         self.self.all_head_size = (
484 |             self.self.attention_head_size * self.self.num_attention_heads
485 |         )
486 |         self.pruned_heads = self.pruned_heads.union(heads)
487 | 
488 |     def forward(
489 |         self,
490 |         hidden_states: torch.Tensor,
491 |         attention_mask: Optional[torch.FloatTensor] = None,
492 |         head_mask: Optional[torch.FloatTensor] = None,
493 |         output_attentions: Optional[bool] = False,
494 |     ) -> Tuple[torch.Tensor]:
495 |         self_outputs = self.self(
496 |             hidden_states, attention_mask, head_mask, output_attentions
497 |         )
498 |         attention_output = self.output(self_outputs[0], hidden_states)
499 |         outputs = (attention_output,) + self_outputs[
500 |             1:
501 |         ]  # add attentions if we output them
502 |         return outputs
503 | 
504 | 
505 | # Copied from transformers.models.swin.modeling_swin.SwinIntermediate
506 | class DonutSwinIntermediate(nn.Module):
507 |     def __init__(self, config, dim):
508 |         super().__init__()
509 |         self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
510 |         if isinstance(config.hidden_act, str):
511 |             self.intermediate_act_fn = ACT2FN[config.hidden_act]
512 |         else:
513 |             self.intermediate_act_fn = config.hidden_act
514 | 
515 |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
516 |         hidden_states = self.dense(hidden_states)
517 |         hidden_states = self.intermediate_act_fn(hidden_states)
518 |         return hidden_states
519 | 
520 | 
521 | # Copied from transformers.models.swin.modeling_swin.SwinOutput
522 | class DonutSwinOutput(nn.Module):
523 |     def __init__(self, config, dim):
524 |         super().__init__()
525 |         self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
526 | 
527 |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
528 |         return self.dense(hidden_states)
529 | 
530 | 
531 | # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
532 | class DonutSwinLayer(nn.Module):
533 |     def __init__(
534 |         self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0
535 |     ):
536 |         super().__init__()
537 |         self.chunk_size_feed_forward = config.chunk_size_feed_forward
538 |         self.shift_size = shift_size
539 |         self.window_size = config.window_size
540 |         self.input_resolution = input_resolution
541 |         self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
542 |         self.attention = DonutSwinAttention(
543 |             config, dim, num_heads, num_kv_heads, window_size=self.window_size
544 |         )
545 |         self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
546 |         self.intermediate = DonutSwinIntermediate(config, dim)
547 |         self.output = DonutSwinOutput(config, dim)
548 | 
549 |     def set_shift_and_window_size(self, input_resolution):
550 |         if min(input_resolution) <= self.window_size:
551 |             # if window size is larger than input resolution, we don't partition windows
552 |             self.shift_size = int(0)
553 |             self.window_size = (
554 |                 torch.min(torch.tensor(input_resolution))
555 |                 if torch.jit.is_tracing()
556 |                 else min(input_resolution)
557 |             )
558 | 
559 |     def get_attn_mask(self, height, width, dtype, device):
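    |         # For shifted windows (SW-MSA), torch.roll makes some windows contain tokens
    |         # that wrap around from the opposite edge of the feature map. Each region is
    |         # assigned an id below, and pairs of tokens with different ids get a -100
    |         # bias so they do not attend to each other.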
560 |         if self.shift_size > 0:
561 |             # calculate attention mask for SW-MSA
562 |             img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
563 |             height_slices = (
564 |                 slice(0, -self.window_size),
565 |                 slice(-self.window_size, -self.shift_size),
566 |                 slice(-self.shift_size, None),
567 |             )
568 |             width_slices = (
569 |                 slice(0, -self.window_size),
570 |                 slice(-self.window_size, -self.shift_size),
571 |                 slice(-self.shift_size, None),
572 |             )
573 |             count = 0
574 |             for height_slice in height_slices:
575 |                 for width_slice in width_slices:
576 |                     img_mask[:, height_slice, width_slice, :] = count
577 |                     count += 1
578 | 
579 |             mask_windows = window_partition(img_mask, self.window_size)
580 |             mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
581 |             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
582 |             attn_mask = attn_mask.masked_fill(
583 |                 attn_mask != 0, float(-100.0)
584 |             ).masked_fill(attn_mask == 0, float(0.0))
585 |         else:
586 |             attn_mask = None
587 |         return attn_mask
588 | 
589 |     def maybe_pad(self, hidden_states, height, width):
590 |         pad_right = (self.window_size - width % self.window_size) % self.window_size
591 |         pad_bottom = (self.window_size - height % self.window_size) % self.window_size
592 |         pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
593 |         hidden_states = nn.functional.pad(hidden_states, pad_values)
594 |         return hidden_states, pad_values
595 | 
596 |     def forward(
597 |         self,
598 |         hidden_states: torch.Tensor,
599 |         input_dimensions: Tuple[int, int],
600 |         head_mask: Optional[torch.FloatTensor] = None,
601 |         output_attentions: Optional[bool] = False,
602 |         always_partition: Optional[bool] = False,
603 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
604 |         if not always_partition:
605 |             self.set_shift_and_window_size(input_dimensions)
606 |         else:
607 |             pass
608 |         height, width = input_dimensions
609 |         batch_size, _, channels = hidden_states.size()
610 |         shortcut = hidden_states
611 | 
612 |         hidden_states = self.layernorm_before(hidden_states)
613 | 
614 |         hidden_states = hidden_states.view(batch_size, height, width, channels)
615 | 
616 |         # pad hidden_states to multiples of window size
617 |         hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
618 | 
619 |         _, height_pad, width_pad, _ = hidden_states.shape
620 |         # cyclic shift
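    |         # Shifting the feature map moves window boundaries between successive layers,
    |         # letting neighboring windows exchange information.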
621 |         if self.shift_size > 0:
622 |             shifted_hidden_states = torch.roll(
623 |                 hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
624 |             )
625 |         else:
626 |             shifted_hidden_states = hidden_states
627 | 
628 |         # partition windows
629 |         hidden_states_windows = window_partition(
630 |             shifted_hidden_states, self.window_size
631 |         )
632 |         hidden_states_windows = hidden_states_windows.view(
633 |             -1, self.window_size * self.window_size, channels
634 |         )
635 |         attn_mask = self.get_attn_mask(
636 |             height_pad,
637 |             width_pad,
638 |             dtype=hidden_states.dtype,
639 |             device=hidden_states_windows.device,
640 |         )
641 | 
642 |         attention_outputs = self.attention(
643 |             hidden_states_windows,
644 |             attn_mask,
645 |             head_mask,
646 |             output_attentions=output_attentions,
647 |         )
648 | 
649 |         attention_output = attention_outputs[0]
650 | 
651 |         attention_windows = attention_output.view(
652 |             -1, self.window_size, self.window_size, channels
653 |         )
654 |         shifted_windows = window_reverse(
655 |             attention_windows, self.window_size, height_pad, width_pad
656 |         )
657 | 
658 |         # reverse cyclic shift
659 |         if self.shift_size > 0:
660 |             attention_windows = torch.roll(
661 |                 shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
662 |             )
663 |         else:
664 |             attention_windows = shifted_windows
665 | 
666 |         was_padded = pad_values[3] > 0 or pad_values[5] > 0
667 |         if was_padded:
668 |             attention_windows = attention_windows[:, :height, :width, :].contiguous()
669 | 
670 |         attention_windows = attention_windows.view(batch_size, height * width, channels)
671 | 
672 |         hidden_states = shortcut + attention_windows
673 | 
674 |         layer_output = self.layernorm_after(hidden_states)
675 |         layer_output = self.intermediate(layer_output)
676 |         layer_output = hidden_states + self.output(layer_output)
677 | 
678 |         layer_outputs = (
679 |             (layer_output, attention_outputs[1])
680 |             if output_attentions
681 |             else (layer_output,)
682 |         )
683 |         return layer_outputs
684 | 
685 | 
686 | # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
687 | class DonutSwinStage(nn.Module):
688 |     def __init__(
689 |         self,
690 |         config,
691 |         layer_num,
692 |         dim,
693 |         input_resolution,
694 |         depth,
695 |         num_heads,
696 |         num_kv_heads,
697 |         downsample,
698 |     ):
699 |         super().__init__()
700 |         self.config = config
701 |         self.dim = dim
702 |         self.blocks = nn.ModuleList(
703 |             [
704 |                 DonutSwinLayer(
705 |                     config=config,
706 |                     dim=dim,
707 |                     input_resolution=input_resolution,
708 |                     num_heads=num_heads,
709 |                     num_kv_heads=num_kv_heads,
710 |                     shift_size=0 if (i % 2 == 0) else config.window_size // 2,
711 |                 )
712 |                 for i in range(depth)
713 |             ]
714 |         )
715 | 
716 |         # patch merging layer
717 |         if downsample is not None:
718 |             self.downsample = downsample(
719 |                 input_resolution, dim=dim, norm_layer=nn.LayerNorm
720 |             )
721 |         else:
722 |             self.downsample = None
723 | 
724 |         self.pointing = False
725 | 
726 |         self.positional_encoding = None
727 |         if config.use_positional_embeddings:
728 |             self.positional_encoding = self.build_2d_sincos_position_embedding(
729 |                 input_resolution[1],
730 |                 input_resolution[0],
731 |                 embed_dim=dim,
732 |             )
733 | 
734 |     @staticmethod
735 |     def build_2d_sincos_position_embedding(
736 |         width,
737 |         height,
738 |         embed_dim=256,
739 |         temperature=10000.0,
740 |         device="cpu",
741 |         dtype=torch.float32,
742 |     ):
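    |         # Builds a fixed (1, width * height, embed_dim) tensor of 2D sin/cos position
    |         # embeddings; the channels form four embed_dim // 4 wide groups: sin(x),
    |         # cos(x), sin(y), cos(y), with frequencies controlled by `temperature`.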
743 |         grid_w = torch.arange(int(width), dtype=dtype, device=device)
744 |         grid_h = torch.arange(int(height), dtype=dtype, device=device)
745 |         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
746 |         if embed_dim % 4 != 0:
747 |             raise ValueError(
748 |                 "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
749 |             )
750 |         pos_dim = embed_dim // 4
751 |         omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
752 |         omega = 1.0 / (temperature**omega)
753 | 
754 |         out_w = grid_w.flatten()[..., None] @ omega[None]
755 |         out_h = grid_h.flatten()[..., None] @ omega[None]
756 | 
757 |         return torch.concat(
758 |             [out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1
759 |         )[None, :, :]
760 | 
761 |     def forward(
762 |         self,
763 |         hidden_states: torch.Tensor,
764 |         input_dimensions: Tuple[int, int],
765 |         head_mask: Optional[torch.FloatTensor] = None,
766 |         output_attentions: Optional[bool] = False,
767 |         always_partition: Optional[bool] = False,
768 |     ) -> Tuple[torch.Tensor]:
769 |         height, width = input_dimensions
770 | 
771 |         if self.positional_encoding is not None:
772 |             hidden_states = hidden_states + self.positional_encoding.to(
773 |                 hidden_states.dtype
774 |             ).to(hidden_states.device)
775 | 
776 |         for i, layer_module in enumerate(self.blocks):
777 |             layer_head_mask = head_mask[i] if head_mask is not None else None
778 | 
779 |             layer_outputs = layer_module(
780 |                 hidden_states,
781 |                 input_dimensions,
782 |                 layer_head_mask,
783 |                 output_attentions,
784 |                 always_partition,
785 |             )
786 | 
787 |             hidden_states = layer_outputs[0]
788 | 
789 |         hidden_states_before_downsampling = hidden_states
790 |         if self.downsample is not None:
791 |             height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
792 |             output_dimensions = (height, width, height_downsampled, width_downsampled)
793 |             hidden_states = self.downsample(
794 |                 hidden_states_before_downsampling, input_dimensions
795 |             )
796 |         else:
797 |             output_dimensions = (height, width, height, width)
798 | 
799 |         stage_outputs = (
800 |             hidden_states,
801 |             hidden_states_before_downsampling,
802 |             output_dimensions,
803 |         )
804 | 
805 |         if output_attentions:
806 |             stage_outputs += layer_outputs[1:]
807 |         return stage_outputs
808 | 
809 | 
810 | # Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
811 | class DonutSwinEncoder(nn.Module):
812 |     def __init__(self, config, grid_size):
813 |         super().__init__()
814 |         self.num_layers = len(config.depths)
815 |         self.config = config
816 |         self.layers = nn.ModuleList(
817 |             [
818 |                 DonutSwinStage(
819 |                     config=config,
820 |                     layer_num=i_layer,
821 |                     dim=int(config.embed_dim * 2**i_layer),
822 |                     input_resolution=(
823 |                         grid_size[0] // (2**i_layer),
824 |                         grid_size[1] // (2**i_layer),
825 |                     ),
826 |                     depth=config.depths[i_layer],
827 |                     num_heads=config.num_heads[i_layer],
828 |                     num_kv_heads=config.num_kv_heads[i_layer]
829 |                     if hasattr(config, "num_kv_heads")
830 |                     else config.num_heads[i_layer],
831 |                     downsample=DonutSwinPatchMerging
832 |                     if (i_layer < self.num_layers - 1)
833 |                     else None,
834 |                 )
835 |                 for i_layer in range(self.num_layers)
836 |             ]
837 |         )
838 | 
839 |         self.gradient_checkpointing = False
840 | 
841 |     def forward(
842 |         self,
843 |         hidden_states: torch.Tensor,
844 |         input_dimensions: Tuple[int, int],
845 |         head_mask: Optional[torch.FloatTensor] = None,
846 |         output_attentions: Optional[bool] = False,
847 |         output_hidden_states: Optional[bool] = False,
848 |         output_hidden_states_before_downsampling: Optional[bool] = False,
849 |         always_partition: Optional[bool] = False,
850 |         return_dict: Optional[bool] = True,
851 |     ) -> Union[Tuple, DonutSwinEncoderOutput]:
852 |         all_hidden_states = () if output_hidden_states else None
853 |         all_reshaped_hidden_states = () if output_hidden_states else None
854 |         all_self_attentions = () if output_attentions else None
855 | 
856 |         if output_hidden_states:
857 |             batch_size, _, hidden_size = hidden_states.shape
858 |             # rearrange b (h w) c -> b c h w
859 |             reshaped_hidden_state = hidden_states.view(
860 |                 batch_size, *input_dimensions, hidden_size
861 |             )
862 |             reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
863 |             all_hidden_states += (hidden_states,)
864 |             all_reshaped_hidden_states += (reshaped_hidden_state,)
865 | 
866 |         for i, layer_module in enumerate(self.layers):
867 |             layer_head_mask = head_mask[i] if head_mask is not None else None
868 | 
869 |             if self.gradient_checkpointing and self.training:
870 |                 layer_outputs = self._gradient_checkpointing_func(
871 |                     layer_module.__call__,
872 |                     hidden_states,
873 |                     input_dimensions,
874 |                     layer_head_mask,
875 |                     output_attentions,
876 |                     always_partition,
877 |                 )
878 |             else:
879 |                 layer_outputs = layer_module(
880 |                     hidden_states,
881 |                     input_dimensions,
882 |                     layer_head_mask,
883 |                     output_attentions,
884 |                     always_partition,
885 |                 )
886 | 
887 |             hidden_states = layer_outputs[0]
888 |             hidden_states_before_downsampling = layer_outputs[1]
889 |             output_dimensions = layer_outputs[2]
890 |             input_dimensions = (output_dimensions[-2], output_dimensions[-1])
891 | 
892 |             if output_hidden_states and output_hidden_states_before_downsampling:
893 |                 batch_size, _, hidden_size = hidden_states_before_downsampling.shape
894 |                 # rearrange b (h w) c -> b c h w
895 |                 # here we use the original (not downsampled) height and width
896 |                 reshaped_hidden_state = hidden_states_before_downsampling.view(
897 |                     batch_size,
898 |                     *(output_dimensions[0], output_dimensions[1]),
899 |                     hidden_size,
900 |                 )
901 |                 reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
902 |                 all_hidden_states += (hidden_states_before_downsampling,)
903 |                 all_reshaped_hidden_states += (reshaped_hidden_state,)
904 |             elif output_hidden_states and not output_hidden_states_before_downsampling:
905 |                 batch_size, _, hidden_size = hidden_states.shape
906 |                 # rearrange b (h w) c -> b c h w
907 |                 reshaped_hidden_state = hidden_states.view(
908 |                     batch_size, *input_dimensions, hidden_size
909 |                 )
910 |                 reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
911 |                 all_hidden_states += (hidden_states,)
912 |                 all_reshaped_hidden_states += (reshaped_hidden_state,)
913 | 
914 |             if output_attentions:
915 |                 all_self_attentions += layer_outputs[3:]
916 | 
917 |         if not return_dict:
918 |             return tuple(
919 |                 v
920 |                 for v in [hidden_states, all_hidden_states, all_self_attentions]
921 |                 if v is not None
922 |             )
923 | 
924 |         return DonutSwinEncoderOutput(
925 |             last_hidden_state=hidden_states,
926 |             hidden_states=all_hidden_states,
927 |             attentions=all_self_attentions,
928 |             reshaped_hidden_states=all_reshaped_hidden_states,
929 |         )
930 | 
931 | 
932 | # Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
933 | class DonutSwinPreTrainedModel(SuryaPreTrainedModel):
934 |     """
935 |     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
936 |     models.
937 |     """
938 | 
939 |     config_class = DonutSwinConfig
940 |     base_model_prefix = "swin"
941 |     main_input_name = "pixel_values"
942 |     supports_gradient_checkpointing = True
943 |     _no_split_modules = ["DonutSwinStage"]
944 | 
945 |     def _init_weights(self, module):
946 |         """Initialize the weights"""
947 |         if isinstance(module, (nn.Linear, nn.Conv2d)):
948 |             # Slightly different from the TF version which uses truncated_normal for initialization
949 |             # cf https://github.com/pytorch/pytorch/pull/5617
950 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
951 |             if module.bias is not None:
952 |                 module.bias.data.zero_()
953 |         elif isinstance(module, nn.LayerNorm):
954 |             module.bias.data.zero_()
955 |             module.weight.data.fill_(1.0)
956 | 
```
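
The encoder above follows the Swin recipe: non-overlapping patches are embedded, attention runs inside local windows, and `DonutSwinPatchMerging` halves the spatial grid between stages. The snippet below is a minimal standalone sketch (not part of the repository; the image size, patch size, and embedding dimension are illustrative assumptions) that traces how `DonutSwinPatchEmbeddings` and `DonutSwinPatchMerging` change tensor shapes.

```python
import torch
from torch import nn

batch, channels, height, width = 1, 3, 448, 448   # assumed illustrative sizes
patch, embed_dim = 4, 128

# DonutSwinPatchEmbeddings: a Conv2d with kernel == stride == patch size turns the
# image into a (batch, num_patches, embed_dim) sequence.
projection = nn.Conv2d(channels, embed_dim, kernel_size=patch, stride=patch)
feature_map = projection(torch.randn(batch, channels, height, width))
grid_h, grid_w = feature_map.shape[2], feature_map.shape[3]      # 112 x 112
embeddings = feature_map.flatten(2).transpose(1, 2)              # (1, 12544, 128)

# DonutSwinPatchMerging: concatenate each 2x2 neighborhood along channels, then
# project 4*C -> 2*C, halving the spatial resolution for the next stage.
x = embeddings.view(batch, grid_h, grid_w, embed_dim)
merged = torch.cat(
    [x[:, 0::2, 0::2], x[:, 1::2, 0::2], x[:, 0::2, 1::2], x[:, 1::2, 1::2]], dim=-1
)                                                                # (1, 56, 56, 512)
merged = merged.view(batch, -1, 4 * embed_dim)
reduction = nn.Linear(4 * embed_dim, 2 * embed_dim, bias=False)
print(reduction(merged).shape)                                   # torch.Size([1, 3136, 256])
```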

--------------------------------------------------------------------------------
/surya/ocr_error/model/encoder.py:
--------------------------------------------------------------------------------

```python
  1 | from __future__ import annotations
  2 | 
  3 | import math
  4 | from typing import Optional, Set, List, Tuple, Union, Dict
  5 | 
  6 | import numpy as np
  7 | import torch
  8 | from torch import nn
  9 | from torch.nn import functional as F, MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
 10 | from transformers import apply_chunking_to_forward
 11 | from transformers.activations import get_activation
 12 | from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput
 13 | from transformers.pytorch_utils import (
 14 |     find_pruneable_heads_and_indices,
 15 |     prune_linear_layer,
 16 | )
 17 | 
 18 | from transformers.utils import (
 19 |     is_flash_attn_greater_or_equal_2_10,
 20 | )
 21 | 
 22 | from surya.common.pretrained import SuryaPreTrainedModel
 23 | 
 24 | from surya.common.s3 import S3DownloaderMixin
 25 | from surya.ocr_error.model.config import DistilBertConfig
 26 | 
 27 | 
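    | # Helper for the flash-attention varlen path: given a (batch, seq_len) padding mask,
    | # it returns the flat indices of non-padded tokens, the cumulative sequence lengths
    | # (cu_seqlens) expected by flash_attn_varlen_func, and the longest unpadded sequence.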
 28 | def _get_unpad_data(attention_mask):
 29 |     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 30 |     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
 31 |     max_seqlen_in_batch = seqlens_in_batch.max().item()
 32 |     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
 33 |     return (
 34 |         indices,
 35 |         cu_seqlens,
 36 |         max_seqlen_in_batch,
 37 |     )
 38 | 
 39 | 
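    | # Fills `out` in place with fixed sinusoidal position embeddings (sin on even
    | # channels, cos on odd channels) and detaches it so it is never trained.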
 40 | def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
 41 |     position_enc = np.array(
 42 |         [
 43 |             [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
 44 |             for pos in range(n_pos)
 45 |         ]
 46 |     )
 47 |     out.requires_grad = False
 48 |     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
 49 |     out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
 50 |     out.detach_()
 51 | 
 52 | 
 53 | class Embeddings(nn.Module):
 54 |     def __init__(self, config: DistilBertConfig):
 55 |         super().__init__()
 56 |         self.word_embeddings = nn.Embedding(
 57 |             config.vocab_size, config.dim, padding_idx=config.pad_token_id
 58 |         )
 59 |         self.position_embeddings = nn.Embedding(
 60 |             config.max_position_embeddings, config.dim
 61 |         )
 62 | 
 63 |         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
 64 |         self.dropout = nn.Dropout(config.dropout)
 65 |         self.register_buffer(
 66 |             "position_ids",
 67 |             torch.arange(config.max_position_embeddings).expand((1, -1)),
 68 |             persistent=False,
 69 |         )
 70 | 
 71 |     def forward(
 72 |         self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None
 73 |     ) -> torch.Tensor:
 74 |         """
 75 |         Parameters:
 76 |             input_ids (torch.Tensor):
 77 |                 torch.tensor(bs, max_seq_length) The token ids to embed.
 78 |             input_embeds (*optional*, torch.Tensor):
 79 |                 The pre-computed word embeddings. Can only be passed if the input ids are `None`.
 80 | 
 81 | 
 82 |         Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
 83 |         embeddings)
 84 |         """
 85 |         if input_ids is not None:
 86 |             input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
 87 | 
 88 |         seq_length = input_embeds.size(1)
 89 | 
 90 |         # Using the position_ids buffer registered in the constructor helps when
 91 |         # tracing the model without passing position_ids explicitly; this avoids
 92 |         # issues similar to issue #5664
 93 |         if hasattr(self, "position_ids"):
 94 |             position_ids = self.position_ids[:, :seq_length]
 95 |         else:
 96 |             position_ids = torch.arange(
 97 |                 seq_length, dtype=torch.long, device=input_ids.device
 98 |             )  # (max_seq_length)
 99 |             position_ids = position_ids.unsqueeze(0).expand_as(
100 |                 input_ids
101 |             )  # (bs, max_seq_length)
102 | 
103 |         position_embeddings = self.position_embeddings(
104 |             position_ids
105 |         )  # (bs, max_seq_length, dim)
106 | 
107 |         embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
108 |         embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
109 |         embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
110 |         return embeddings
111 | 
112 | 
113 | class MultiHeadSelfAttention(nn.Module):
114 |     def __init__(self, config: DistilBertConfig):
115 |         super().__init__()
116 |         self.config = config
117 | 
118 |         self.n_heads = config.n_heads
119 |         self.dim = config.dim
120 |         self.dropout = nn.Dropout(p=config.attention_dropout)
121 |         self.is_causal = False
122 | 
123 |         # The number of attention heads must divide the hidden dimension evenly
124 |         if self.dim % self.n_heads != 0:
125 |             # otherwise the per-head dimension would not be an integer
126 |             raise ValueError(
127 |                 f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly"
128 |             )
129 | 
130 |         self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
131 |         self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
132 |         self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
133 |         self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
134 | 
135 |         self.pruned_heads: Set[int] = set()
136 |         self.attention_head_size = self.dim // self.n_heads
137 | 
138 |     def prune_heads(self, heads: List[int]):
139 |         if len(heads) == 0:
140 |             return
141 |         heads, index = find_pruneable_heads_and_indices(
142 |             heads, self.n_heads, self.attention_head_size, self.pruned_heads
143 |         )
144 |         # Prune linear layers
145 |         self.q_lin = prune_linear_layer(self.q_lin, index)
146 |         self.k_lin = prune_linear_layer(self.k_lin, index)
147 |         self.v_lin = prune_linear_layer(self.v_lin, index)
148 |         self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
149 |         # Update hyper params
150 |         self.n_heads = self.n_heads - len(heads)
151 |         self.dim = self.attention_head_size * self.n_heads
152 |         self.pruned_heads = self.pruned_heads.union(heads)
153 | 
154 |     def forward(
155 |         self,
156 |         query: torch.Tensor,
157 |         key: torch.Tensor,
158 |         value: torch.Tensor,
159 |         mask: torch.Tensor,
160 |         head_mask: Optional[torch.Tensor] = None,
161 |         output_attentions: bool = False,
162 |     ) -> Tuple[torch.Tensor, ...]:
163 |         """
164 |         Parameters:
165 |             query: torch.tensor(bs, seq_length, dim)
166 |             key: torch.tensor(bs, seq_length, dim)
167 |             value: torch.tensor(bs, seq_length, dim)
168 |             mask: torch.tensor(bs, seq_length)
169 | 
170 |         Returns:
171 |             context: torch.tensor(bs, seq_length, dim) Contextualized layer.
172 |             weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights. Optional: only if `output_attentions=True`
173 |         """
174 |         bs, q_length, dim = query.size()
175 |         k_length = key.size(1)
176 |         # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
177 |         # assert key.size() == value.size()
178 | 
179 |         dim_per_head = self.dim // self.n_heads
180 | 
181 |         mask_reshp = (bs, 1, 1, k_length)
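    |         # The (bs, k_length) padding mask is reshaped so it broadcasts across heads
    |         # and query positions when masking the attention scores below.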
182 | 
183 |         def shape(x: torch.Tensor) -> torch.Tensor:
184 |             """separate heads"""
185 |             return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
186 | 
187 |         def unshape(x: torch.Tensor) -> torch.Tensor:
188 |             """group heads"""
189 |             return (
190 |                 x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
191 |             )
192 | 
193 |         q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
194 |         k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
195 |         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
196 | 
197 |         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
198 |         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
199 |         mask = (
200 |             (mask == 0).view(mask_reshp).expand_as(scores)
201 |         )  # (bs, n_heads, q_length, k_length)
202 |         scores = scores.masked_fill(
203 |             mask, torch.tensor(torch.finfo(scores.dtype).min)
204 |         )  # (bs, n_heads, q_length, k_length)
205 | 
206 |         weights = nn.functional.softmax(
207 |             scores, dim=-1
208 |         )  # (bs, n_heads, q_length, k_length)
209 |         weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)
210 | 
211 |         # Mask heads if we want to
212 |         if head_mask is not None:
213 |             weights = weights * head_mask
214 | 
215 |         context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
216 |         context = unshape(context)  # (bs, q_length, dim)
217 |         context = self.out_lin(context)  # (bs, q_length, dim)
218 | 
219 |         if output_attentions:
220 |             return (context, weights)
221 |         else:
222 |             return (context,)
223 | 
224 | 
225 | class DistilBertFlashAttention2(MultiHeadSelfAttention):
226 |     """
227 |     DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module
228 |     stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
229 |     API of flash attention and deal with padding tokens in case the input contains any of them.
230 |     """
231 | 
232 |     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
233 |     def __init__(self, *args, **kwargs):
234 |         super().__init__(*args, **kwargs)
235 | 
236 |         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
237 |         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
238 |         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
239 |         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
240 | 
241 |     def forward(
242 |         self,
243 |         query: torch.Tensor,
244 |         key: torch.Tensor,
245 |         value: torch.Tensor,
246 |         mask: torch.Tensor,
247 |         head_mask: Optional[torch.Tensor] = None,
248 |         output_attentions: bool = False,
249 |     ) -> Tuple[torch.Tensor, ...]:
250 |         """
251 |         Parameters:
252 |             query: torch.tensor(bs, seq_length, dim)
253 |             key: torch.tensor(bs, seq_length, dim)
254 |             value: torch.tensor(bs, seq_length, dim)
255 |             mask: torch.tensor(bs, seq_length)
256 | 
257 |         Returns:
258 |             context: torch.tensor(bs, seq_length, dim) Contextualized layer.
259 |             weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights. Optional: only if `output_attentions=True`
260 |         """
261 |         batch_size, q_length, dim = query.size()
262 | 
263 |         dim_per_head = self.dim // self.n_heads
264 | 
265 |         def reshape(x: torch.Tensor) -> torch.Tensor:
266 |             """separate heads"""
267 |             return x.view(batch_size, -1, self.n_heads, dim_per_head)
268 | 
269 |         # Flash attention requires the input to have the shape
270 |         # batch_size x seq_length x num_heads x head_dim
271 |         query_states = reshape(self.q_lin(query))
272 |         key_states = reshape(self.k_lin(key))
273 |         value_states = reshape(self.v_lin(value))
274 | 
275 |         attn_dropout = self.config.attention_dropout if self.training else 0.0
276 | 
277 |         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
278 |         # therefore the input hidden states get silently cast to float32. Hence, we need to
279 |         # cast them back to the correct dtype just to be sure everything works as expected.
280 |         # This might slow down training & inference, so it is recommended not to cast the LayerNorms
281 |         # in fp32. (LlamaRMSNorm handles it correctly)
282 | 
283 |         if query_states.dtype == torch.float32:
284 |             if torch.is_autocast_enabled():
285 |                 target_dtype = torch.get_autocast_gpu_dtype()
286 |             # Handle the case where the model is quantized
287 |             elif hasattr(self.config, "_pre_quantization_dtype"):
288 |                 target_dtype = self.config._pre_quantization_dtype
289 |             else:
290 |                 target_dtype = self.q_lin.weight.dtype
291 | 
292 |             query_states = query_states.to(target_dtype)
293 |             key_states = key_states.to(target_dtype)
294 |             value_states = value_states.to(target_dtype)
295 | 
296 |         attn_weights = self._flash_attention_forward(
297 |             query_states, key_states, value_states, mask, q_length, dropout=attn_dropout
298 |         )
299 | 
300 |         attn_weights_reshaped = attn_weights.reshape(
301 |             batch_size, q_length, self.n_heads * dim_per_head
302 |         )
303 |         attn_output = self.out_lin(attn_weights_reshaped)
304 | 
305 |         if output_attentions:
306 |             return (attn_output, attn_weights)
307 |         else:
308 |             return (attn_output,)
309 | 
310 |     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False
311 |     def _flash_attention_forward(
312 |         self,
313 |         query_states,
314 |         key_states,
315 |         value_states,
316 |         attention_mask,
317 |         query_length,
318 |         dropout=0.0,
319 |         softmax_scale=None,
320 |     ):
321 |         """
322 |         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
323 |         first unpad the input, then computes the attention scores and pad the final attention scores.
324 | 
325 |         Args:
326 |             query_states (`torch.Tensor`):
327 |                 Input query states to be passed to Flash Attention API
328 |             key_states (`torch.Tensor`):
329 |                 Input key states to be passed to Flash Attention API
330 |             value_states (`torch.Tensor`):
331 |                 Input value states to be passed to Flash Attention API
332 |             attention_mask (`torch.Tensor`):
333 |                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
334 |                 position of padding tokens and 1 for the position of non-padding tokens.
335 |             dropout (`float`):
336 |                 Attention dropout
337 |             softmax_scale (`float`, *optional*):
338 |                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
339 |         """
340 |         from flash_attn import flash_attn_func, flash_attn_varlen_func
341 |         from flash_attn.bert_padding import pad_input
342 | 
343 |         if not self._flash_attn_uses_top_left_mask:
344 |             causal = self.is_causal
345 |         else:
346 |             # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
347 |             causal = self.is_causal and query_length != 1
348 | 
349 |         # If a padding mask is provided, unpad the sequences before the varlen flash-attention call
350 |         if attention_mask is not None:
351 |             batch_size = query_states.shape[0]
352 |             (
353 |                 query_states,
354 |                 key_states,
355 |                 value_states,
356 |                 indices_q,
357 |                 cu_seq_lens,
358 |                 max_seq_lens,
359 |             ) = self._upad_input(
360 |                 query_states, key_states, value_states, attention_mask, query_length
361 |             )
362 | 
363 |             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
364 |             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
365 | 
366 |             attn_output_unpad = flash_attn_varlen_func(
367 |                 query_states,
368 |                 key_states,
369 |                 value_states,
370 |                 cu_seqlens_q=cu_seqlens_q,
371 |                 cu_seqlens_k=cu_seqlens_k,
372 |                 max_seqlen_q=max_seqlen_in_batch_q,
373 |                 max_seqlen_k=max_seqlen_in_batch_k,
374 |                 dropout_p=dropout,
375 |                 softmax_scale=softmax_scale,
376 |                 causal=causal,
377 |             )
378 | 
379 |             attn_output = pad_input(
380 |                 attn_output_unpad, indices_q, batch_size, query_length
381 |             )
382 |         else:
383 |             attn_output = flash_attn_func(
384 |                 query_states,
385 |                 key_states,
386 |                 value_states,
387 |                 dropout,
388 |                 softmax_scale=softmax_scale,
389 |                 causal=causal,
390 |             )
391 | 
392 |         return attn_output
393 | 
394 |     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads
395 |     def _upad_input(
396 |         self, query_layer, key_layer, value_layer, attention_mask, query_length
397 |     ):
398 |         from flash_attn.bert_padding import index_first_axis, unpad_input
399 | 
400 |         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
401 |         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
402 | 
403 |         key_layer = index_first_axis(
404 |             key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
405 |             indices_k,
406 |         )
407 |         value_layer = index_first_axis(
408 |             value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
409 |             indices_k,
410 |         )
411 |         if query_length == kv_seq_len:
412 |             query_layer = index_first_axis(
413 |                 query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim),
414 |                 indices_k,
415 |             )
416 |             cu_seqlens_q = cu_seqlens_k
417 |             max_seqlen_in_batch_q = max_seqlen_in_batch_k
418 |             indices_q = indices_k
419 |         elif query_length == 1:
420 |             max_seqlen_in_batch_q = 1
421 |             cu_seqlens_q = torch.arange(
422 |                 batch_size + 1, dtype=torch.int32, device=query_layer.device
423 |             )  # There is a memcpy here, that is very bad.
424 |             indices_q = cu_seqlens_q[:-1]
425 |             query_layer = query_layer.squeeze(1)
426 |         else:
427 |             # The -q_len: slice assumes left padding.
428 |             attention_mask = attention_mask[:, -query_length:]
429 |             query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
430 |                 query_layer, attention_mask
431 |             )
432 | 
433 |         return (
434 |             query_layer,
435 |             key_layer,
436 |             value_layer,
437 |             indices_q,
438 |             (cu_seqlens_q, cu_seqlens_k),
439 |             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
440 |         )
441 | 
442 | 
443 | class FFN(nn.Module):
444 |     def __init__(self, config: DistilBertConfig):
445 |         super().__init__()
446 |         self.dropout = nn.Dropout(p=config.dropout)
447 |         self.chunk_size_feed_forward = config.chunk_size_feed_forward
448 |         self.seq_len_dim = 1
449 |         self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
450 |         self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
451 |         self.activation = get_activation(config.activation)
452 | 
453 |     def forward(self, input: torch.Tensor) -> torch.Tensor:
454 |         return apply_chunking_to_forward(
455 |             self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input
456 |         )
457 | 
458 |     def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
459 |         x = self.lin1(input)
460 |         x = self.activation(x)
461 |         x = self.lin2(x)
462 |         x = self.dropout(x)
463 |         return x
464 | 
465 | 
466 | DISTILBERT_ATTENTION_CLASSES = {
467 |     "eager": MultiHeadSelfAttention,
468 |     "flash_attention_2": DistilBertFlashAttention2,
469 | }
470 | 
471 | 
472 | class TransformerBlock(nn.Module):
473 |     def __init__(self, config: DistilBertConfig):
474 |         super().__init__()
475 | 
476 |         # The number of attention heads must divide the hidden dimension evenly
477 |         if config.dim % config.n_heads != 0:
478 |             raise ValueError(
479 |                 f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly"
480 |             )
481 | 
482 |         self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](
483 |             config
484 |         )
485 |         self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
486 | 
487 |         self.ffn = FFN(config)
488 |         self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
489 | 
490 |     def forward(
491 |         self,
492 |         x: torch.Tensor,
493 |         attn_mask: Optional[torch.Tensor] = None,
494 |         head_mask: Optional[torch.Tensor] = None,
495 |         output_attentions: bool = False,
496 |     ) -> Tuple[torch.Tensor, ...]:
497 |         """
498 |         Parameters:
499 |             x: torch.tensor(bs, seq_length, dim)
500 |             attn_mask: torch.tensor(bs, seq_length)
501 | 
502 |         Returns:
503 |             ffn_output: torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
504 |             sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights. Optional: only if `output_attentions=True`
505 |         """
506 |         # Self-Attention
507 |         sa_output = self.attention(
508 |             query=x,
509 |             key=x,
510 |             value=x,
511 |             mask=attn_mask,
512 |             head_mask=head_mask,
513 |             output_attentions=output_attentions,
514 |         )
515 |         if output_attentions:
516 |             sa_output, sa_weights = (
517 |                 sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
518 |             )
519 |         else:  # the attention layer returns a tuple even when weights are not requested, so unpack the hidden state
520 |             sa_output = sa_output[0]
521 | 
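    |         # Post-LN residual: add the attention output back to its input, then LayerNorm (BERT-style).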
522 |         sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)
523 | 
524 |         # Feed Forward Network
525 |         ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
526 |         ffn_output: torch.Tensor = self.output_layer_norm(
527 |             ffn_output + sa_output
528 |         )  # (bs, seq_length, dim)
529 | 
530 |         output = (ffn_output,)
531 |         if output_attentions:
532 |             output = (sa_weights,) + output
533 |         return output
534 | 
535 | 
536 | class Transformer(nn.Module):
537 |     def __init__(self, config: DistilBertConfig):
538 |         super().__init__()
539 |         self.n_layers = config.n_layers
540 |         self.layer = nn.ModuleList(
541 |             [TransformerBlock(config) for _ in range(config.n_layers)]
542 |         )
543 |         self.gradient_checkpointing = False
544 | 
545 |     def forward(
546 |         self,
547 |         x: torch.Tensor,
548 |         attn_mask: Optional[torch.Tensor] = None,
549 |         head_mask: Optional[torch.Tensor] = None,
550 |         output_attentions: bool = False,
551 |         output_hidden_states: bool = False,
552 |         return_dict: Optional[bool] = None,
553 |     ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:  # docstyle-ignore
554 |         """
555 |         Parameters:
556 |             x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
557 |             attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.
558 | 
559 |         Returns:
560 |             hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) layer.
561 |             all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
562 |                 Tuple of length n_layers with the hidden states from each layer.
563 |                 Optional: only if output_hidden_states=True
564 |             all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
565 |                 Tuple of length n_layers with the attention weights from each layer
566 |                 Optional: only if output_attentions=True
567 |         """
568 |         all_hidden_states = () if output_hidden_states else None
569 |         all_attentions = () if output_attentions else None
570 | 
571 |         hidden_state = x
572 |         for i, layer_module in enumerate(self.layer):
573 |             if output_hidden_states:
574 |                 all_hidden_states = all_hidden_states + (hidden_state,)
575 | 
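    |             # Gradient checkpointing recomputes this block's forward pass during
    |             # backprop instead of caching activations, trading compute for memory.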
576 |             if self.gradient_checkpointing and self.training:
577 |                 layer_outputs = self._gradient_checkpointing_func(
578 |                     layer_module.__call__,
579 |                     hidden_state,
580 |                     attn_mask,
581 |                     head_mask[i],
582 |                     output_attentions,
583 |                 )
584 |             else:
585 |                 layer_outputs = layer_module(
586 |                     hidden_state,
587 |                     attn_mask,
588 |                     head_mask[i],
589 |                     output_attentions,
590 |                 )
591 | 
592 |             hidden_state = layer_outputs[-1]
593 | 
594 |             if output_attentions:
595 |                 if len(layer_outputs) != 2:
596 |                     raise ValueError(
597 |                         f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}"
598 |                     )
599 | 
600 |                 attentions = layer_outputs[0]
601 |                 all_attentions = all_attentions + (attentions,)
602 |             else:
603 |                 if len(layer_outputs) != 1:
604 |                     raise ValueError(
605 |                         f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}"
606 |                     )
607 | 
608 |         # Add last layer
609 |         if output_hidden_states:
610 |             all_hidden_states = all_hidden_states + (hidden_state,)
611 | 
612 |         if not return_dict:
613 |             return tuple(
614 |                 v
615 |                 for v in [hidden_state, all_hidden_states, all_attentions]
616 |                 if v is not None
617 |             )
618 |         return BaseModelOutput(
619 |             last_hidden_state=hidden_state,
620 |             hidden_states=all_hidden_states,
621 |             attentions=all_attentions,
622 |         )
623 | 
624 | 
625 | class DistilBertPreTrainedModel(SuryaPreTrainedModel):
626 |     """
627 |     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
628 |     models.
629 |     """
630 | 
631 |     config_class = DistilBertConfig
632 |     load_tf_weights = None
633 |     base_model_prefix = "distilbert"
634 |     supports_gradient_checkpointing = True
635 |     _supports_flash_attn_2 = True
636 | 
637 |     def _init_weights(self, module: nn.Module):
638 |         """Initialize the weights."""
639 |         if isinstance(module, nn.Linear):
640 |             # Slightly different from the TF version which uses truncated_normal for initialization
641 |             # cf https://github.com/pytorch/pytorch/pull/5617
642 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
643 |             if module.bias is not None:
644 |                 module.bias.data.zero_()
645 |         elif isinstance(module, nn.Embedding):
646 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
647 |             if module.padding_idx is not None:
648 |                 module.weight.data[module.padding_idx].zero_()
649 |         elif isinstance(module, nn.LayerNorm):
650 |             module.bias.data.zero_()
651 |             module.weight.data.fill_(1.0)
652 |         elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
653 |             create_sinusoidal_embeddings(
654 |                 self.config.max_position_embeddings,
655 |                 self.config.dim,
656 |                 module.position_embeddings.weight,
657 |             )
658 | 
659 | 
660 | class DistilBertModel(DistilBertPreTrainedModel):
661 |     def __init__(self, config: DistilBertConfig):
662 |         super().__init__(config)
663 | 
664 |         self.embeddings = Embeddings(config)  # Embeddings
665 |         self.transformer = Transformer(config)  # Encoder
666 |         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
667 | 
668 |         # Initialize weights and apply final processing
669 |         self.post_init()
670 | 
671 |     def get_position_embeddings(self) -> nn.Embedding:
672 |         """
673 |         Returns the position embeddings
674 |         """
675 |         return self.embeddings.position_embeddings
676 | 
677 |     def resize_position_embeddings(self, new_num_position_embeddings: int):
678 |         """
679 |         Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
680 | 
681 |         Arguments:
682 |             new_num_position_embeddings (`int`):
683 |                 The new size of the position embedding matrix. If position embeddings are learned, increasing the size
684 |                 will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
685 |                 end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
686 |                 size will add correct vectors at the end following the position encoding algorithm, whereas reducing
687 |                 the size will remove vectors from the end.
688 |         """
689 |         num_position_embeds_diff = (
690 |             new_num_position_embeddings - self.config.max_position_embeddings
691 |         )
692 | 
693 |         # no resizing needs to be done if the length stays the same
694 |         if num_position_embeds_diff == 0:
695 |             return
696 | 
697 |         self.config.max_position_embeddings = new_num_position_embeddings
698 | 
699 |         old_position_embeddings_weight = (
700 |             self.embeddings.position_embeddings.weight.clone()
701 |         )
702 | 
703 |         self.embeddings.position_embeddings = nn.Embedding(
704 |             self.config.max_position_embeddings, self.config.dim
705 |         )
706 | 
707 |         if self.config.sinusoidal_pos_embds:
708 |             create_sinusoidal_embeddings(
709 |                 n_pos=self.config.max_position_embeddings,
710 |                 dim=self.config.dim,
711 |                 out=self.embeddings.position_embeddings.weight,
712 |             )
713 |         else:
714 |             with torch.no_grad():
715 |                 if num_position_embeds_diff > 0:
716 |                     self.embeddings.position_embeddings.weight[
717 |                         :-num_position_embeds_diff
718 |                     ] = nn.Parameter(old_position_embeddings_weight)
719 |                 else:
720 |                     self.embeddings.position_embeddings.weight = nn.Parameter(
721 |                         old_position_embeddings_weight[:num_position_embeds_diff]
722 |                     )
723 |         # move position_embeddings to correct device
724 |         self.embeddings.position_embeddings.to(self.device)
725 | 
726 |     def get_input_embeddings(self) -> nn.Embedding:
727 |         return self.embeddings.word_embeddings
728 | 
729 |     def set_input_embeddings(self, new_embeddings: nn.Embedding):
730 |         self.embeddings.word_embeddings = new_embeddings
731 | 
732 |     def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]):
733 |         """
734 |         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
735 |         class PreTrainedModel
736 |         """
737 |         for layer, heads in heads_to_prune.items():
738 |             self.transformer.layer[layer].attention.prune_heads(heads)
739 | 
740 |     def forward(
741 |         self,
742 |         input_ids: Optional[torch.Tensor] = None,
743 |         attention_mask: Optional[torch.Tensor] = None,
744 |         head_mask: Optional[torch.Tensor] = None,
745 |         inputs_embeds: Optional[torch.Tensor] = None,
746 |         output_attentions: Optional[bool] = None,
747 |         output_hidden_states: Optional[bool] = None,
748 |         return_dict: Optional[bool] = None,
749 |     ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
750 |         output_attentions = (
751 |             output_attentions
752 |             if output_attentions is not None
753 |             else self.config.output_attentions
754 |         )
755 |         output_hidden_states = (
756 |             output_hidden_states
757 |             if output_hidden_states is not None
758 |             else self.config.output_hidden_states
759 |         )
760 |         return_dict = (
761 |             return_dict if return_dict is not None else self.config.use_return_dict
762 |         )
763 | 
764 |         if input_ids is not None and inputs_embeds is not None:
765 |             raise ValueError(
766 |                 "You cannot specify both input_ids and inputs_embeds at the same time"
767 |             )
768 |         elif input_ids is not None:
769 |             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
770 |             input_shape = input_ids.size()
771 |         elif inputs_embeds is not None:
772 |             input_shape = inputs_embeds.size()[:-1]
773 |         else:
774 |             raise ValueError("You have to specify either input_ids or inputs_embeds")
775 | 
776 |         device = input_ids.device if input_ids is not None else inputs_embeds.device
777 | 
778 |         # Prepare head mask if needed
779 |         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
780 | 
781 |         embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)
782 | 
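    |         # Flash Attention 2 operates on unpadded sequences: keep the mask only when
    |         # it actually marks padding (contains a 0); otherwise drop it to skip the
    |         # unpad/pad round trip. The eager path needs a dense (bs, seq_length) mask.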
783 |         if self._use_flash_attention_2:
784 |             attention_mask = (
785 |                 attention_mask
786 |                 if (attention_mask is not None and 0 in attention_mask)
787 |                 else None
788 |             )
789 |         else:
790 |             if attention_mask is None:
791 |                 attention_mask = torch.ones(
792 |                     input_shape, device=device
793 |                 )  # (bs, seq_length)
794 | 
795 |         return self.transformer(
796 |             x=embeddings,
797 |             attn_mask=attention_mask,
798 |             head_mask=head_mask,
799 |             output_attentions=output_attentions,
800 |             output_hidden_states=output_hidden_states,
801 |             return_dict=return_dict,
802 |         )
803 | 
804 | 
805 | class DistilBertForSequenceClassification(S3DownloaderMixin, DistilBertPreTrainedModel):
806 |     def __init__(self, config: DistilBertConfig, **kwargs):
807 |         super().__init__(config, **kwargs)
808 |         self.num_labels = config.num_labels
809 |         self.config = config
810 | 
811 |         self.distilbert = DistilBertModel(config)
812 |         self.pre_classifier = nn.Linear(config.dim, config.dim)
813 |         self.classifier = nn.Linear(config.dim, config.num_labels)
814 |         self.dropout = nn.Dropout(config.seq_classif_dropout)
815 | 
816 |         # Initialize weights and apply final processing
817 |         self.post_init()
818 | 
819 |     def get_position_embeddings(self) -> nn.Embedding:
820 |         """
821 |         Returns the position embeddings
822 |         """
823 |         return self.distilbert.get_position_embeddings()
824 | 
825 |     def resize_position_embeddings(self, new_num_position_embeddings: int):
826 |         """
827 |         Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.
828 | 
829 |         Arguments:
830 |             new_num_position_embeddings (`int`):
831 |                 The new size of the position embedding matrix. If position embeddings are learned, increasing the size
832 |                 will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
833 |                 end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
834 |                 size will add correct vectors at the end following the position encoding algorithm, whereas reducing
835 |                 the size will remove vectors from the end.
836 |         """
837 |         self.distilbert.resize_position_embeddings(new_num_position_embeddings)
838 | 
839 |     def forward(
840 |         self,
841 |         input_ids: Optional[torch.Tensor] = None,
842 |         attention_mask: Optional[torch.Tensor] = None,
843 |         head_mask: Optional[torch.Tensor] = None,
844 |         inputs_embeds: Optional[torch.Tensor] = None,
845 |         labels: Optional[torch.LongTensor] = None,
846 |         output_attentions: Optional[bool] = None,
847 |         output_hidden_states: Optional[bool] = None,
848 |         return_dict: Optional[bool] = None,
849 |     ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
850 |         r"""
851 |         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
852 |             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
853 |             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
854 |             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
855 |         """
856 |         return_dict = (
857 |             return_dict if return_dict is not None else self.config.use_return_dict
858 |         )
859 | 
860 |         distilbert_output = self.distilbert(
861 |             input_ids=input_ids,
862 |             attention_mask=attention_mask,
863 |             head_mask=head_mask,
864 |             inputs_embeds=inputs_embeds,
865 |             output_attentions=output_attentions,
866 |             output_hidden_states=output_hidden_states,
867 |             return_dict=return_dict,
868 |         )
869 |         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
870 |         pooled_output = hidden_state[:, 0]  # (bs, dim)
871 |         pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
872 |         pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
873 |         pooled_output = self.dropout(pooled_output)  # (bs, dim)
874 |         logits = self.classifier(pooled_output)  # (bs, num_labels)
875 | 
876 |         loss = None
877 |         if labels is not None:
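    |             # Infer the loss type once from num_labels and the label dtype,
    |             # then cache the choice on the config for subsequent calls.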
878 |             if self.config.problem_type is None:
879 |                 if self.num_labels == 1:
880 |                     self.config.problem_type = "regression"
881 |                 elif self.num_labels > 1 and (
882 |                     labels.dtype == torch.long or labels.dtype == torch.int
883 |                 ):
884 |                     self.config.problem_type = "single_label_classification"
885 |                 else:
886 |                     self.config.problem_type = "multi_label_classification"
887 | 
888 |             if self.config.problem_type == "regression":
889 |                 loss_fct = MSELoss()
890 |                 if self.num_labels == 1:
891 |                     loss = loss_fct(logits.squeeze(), labels.squeeze())
892 |                 else:
893 |                     loss = loss_fct(logits, labels)
894 |             elif self.config.problem_type == "single_label_classification":
895 |                 loss_fct = CrossEntropyLoss()
896 |                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
897 |             elif self.config.problem_type == "multi_label_classification":
898 |                 loss_fct = BCEWithLogitsLoss()
899 |                 loss = loss_fct(logits, labels)
900 | 
901 |         if not return_dict:
902 |             output = (logits,) + distilbert_output[1:]
903 |             return ((loss,) + output) if loss is not None else output
904 | 
905 |         return SequenceClassifierOutput(
906 |             loss=loss,
907 |             logits=logits,
908 |             hidden_states=distilbert_output.hidden_states,
909 |             attentions=distilbert_output.attentions,
910 |         )
911 | 
```
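
For orientation, here is a minimal, hypothetical sketch of how the `DistilBertForSequenceClassification` head defined above could be exercised end to end. It assumes `DistilBertConfig` mirrors the standard Hugging Face DistilBERT config fields (e.g. `vocab_size`, `n_layers`, `dim`) and that the model can be constructed directly from a config rather than from pretrained weights; neither assumption is taken from the repo.

```python
# Illustrative only; class names refer to the definitions in the file above.
import torch

# Assumption: DistilBertConfig accepts the standard Hugging Face DistilBERT
# fields (n_layers, n_heads, dim, hidden_dim, vocab_size, num_labels).
config = DistilBertConfig(n_layers=2, n_heads=4, dim=64, hidden_dim=256, num_labels=2)
model = DistilBertForSequenceClassification(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # dummy token ids
attention_mask = torch.ones_like(input_ids)               # no padding

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

print(out.logits.shape)  # expected: torch.Size([1, 2]), one score per label
```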