# datalab-to/surya

This is page 4 of 4. Use http://codebase.md/datalab-to/surya?lines=false&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/surya/common/adetr/decoder.py:
--------------------------------------------------------------------------------

```python
from typing import Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import PretrainedConfig

from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.xla import mark_step

_MAX_SQRT_GRADIENT = 1000.0


class WrappedEmbedding(nn.Embedding):
    def forward(self, input_ids, *args, **kwargs):
        return super().forward(input_ids)


class SuryaADETRDecoderRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)

        # Add clipping to prevent division by zero
        variance = torch.clamp(variance, min=self.eps)
        return x * torch.rsqrt(variance)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst SuryaADETRDecoder is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        # Clamp to float16 range
        f16_info = torch.finfo(x.dtype)
        output = output.clamp(min=f16_info.min, max=f16_info.max)
        output = torch.where(
            torch.isnan(output), torch.tensor(0.0, device=output.device), output
        )
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


ALL_LAYERNORM_LAYERS.append(SuryaADETRDecoderRMSNorm)


class SuryaADETRDecoderRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
        )
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaADETRDecoder
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        self.inv_freq.to(x.device)
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class SuryaADETRDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper
    Modified for GQA
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(
            self.hidden_size,
            self.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            self.config.encoder_hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.config.encoder_hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
        )
        self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Encoder attention mask currently ignored

        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(
            bsz, q_len, self.num_attention_heads, self.head_dim
        ).transpose(1, 2)

        if self.key_states is None:
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(
                bsz, v_len, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                bsz, v_len, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            key_states = self.key_states
            value_states = self.value_states

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _clear_cache(self):
        if self.value_states is not None:
            del self.value_states
        if self.key_states is not None:
            del self.key_states

    def _setup_cache(self, batch_size, device, dtype=None):
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        self.value_states = value_states
        self.key_states = key_states


class SuryaADETRDecoderSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: PretrainedConfig, static_cache=False, max_boxes=None):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(
            self.hidden_size,
            self.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
        )
        self.rotary_emb = SuryaADETRDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

        self.static_cache = static_cache
        self.max_boxes = max_boxes

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        window_attn: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Final is bsz, num_attention_heads, seq_len, head_dim
        query_states = query_states.view(
            bsz, q_len, self.num_attention_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if use_cache and hasattr(self, "key_states"):
            cache_kwargs = {
                "cache_position": cache_position,
                "window_attn": window_attn,
            }
            key_states, value_states = self._update_cache(
                key_states, value_states, **cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Mask is batch, head, seq_len, kv_len
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
            if cache_position is not None and self.static_cache:
                current_pos = cache_position[-1]
                causal_mask[:, :, :, current_pos + 1 :] = torch.finfo(
                    causal_mask.dtype
                ).min

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        if dtype is None and self.config.torch_dtype is not None:
            dtype = self.config.torch_dtype
        dtype = dtype if dtype is not None else torch.float32

        # Setup initial caches
        self.value_states = None
        self.key_states = None

        if self.static_cache:
            cache_shape = (
                batch_size,
                self.num_key_value_heads,
                self.max_boxes,
                self.head_dim,
            )
            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

    def _clear_cache(self):
        if self.value_states is not None:
            del self.value_states
        if self.key_states is not None:
            del self.key_states

    def _update_static_cache(self, key_states, value_states, **cache_kwargs):
        cache_position = cache_kwargs.get("cache_position")
        k_out, v_out = (
            self.key_states.to(key_states.device),
            self.value_states.to(value_states.device),
        )

        k_out[:, :, cache_position] = key_states.to(k_out.dtype)
        v_out[:, :, cache_position] = value_states.to(v_out.dtype)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
        k_out = key_states
        if self.key_states is not None:
            k_out = torch.cat([self.key_states, key_states], dim=2)

        v_out = value_states
        if self.value_states is not None:
            v_out = torch.cat([self.value_states, value_states], dim=2)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        if self.static_cache:
            return self._update_static_cache(key_states, value_states, **cache_kwargs)

        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)


class SuryaADETRDecoderMlp(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        hidden_activation = config.hidden_activation
        self.act_fn = ACT2FN[hidden_activation]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class SuryaADETRDecoderLayer(nn.Module):
    def __init__(self, config, layer_idx, static_cache=False, max_boxes=None):
        super().__init__()
        self.cross_pre_norm = SuryaADETRDecoderRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.temporal_pre_norm = SuryaADETRDecoderRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.temporal_block = None
        if layer_idx in config.self_attn_layers:
            self.temporal_block = SuryaADETRDecoderSdpaAttention(
                config, static_cache=static_cache, max_boxes=max_boxes
            )

        self.cross_attn_block = None
        if layer_idx in config.cross_attn_layers:
            self.cross_attn_block = SuryaADETRDecoderSdpaCrossAttention(config)

        self.window_attn = layer_idx not in config.global_attn_layers
        self.channel_pre_norm = SuryaADETRDecoderRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.mlp_block = SuryaADETRDecoderMlp(config)

        self.double_residual_flow = getattr(config, "double_residual_flow", False)

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        use_cache: bool = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        if self.double_residual_flow:
            return self.double_res_forward(
                activations,
                position_ids,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cache_position,
                use_cache,
            )

        hidden_states = activations
        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(hidden_states)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs,
                encoder_hidden_states,
                attention_mask,
                encoder_attention_mask,
                use_cache=use_cache,
            )
            hidden_states = cross_attn_path + hidden_states

        if self.temporal_block is not None:
            temporal_inputs = self.temporal_pre_norm(
                hidden_states
            )  # RMSNorm introduces slight differences
            temporal_path = self.temporal_block(
                temporal_inputs,
                position_ids,
                attention_mask,
                cache_position=cache_position,
                use_cache=use_cache,
                window_attn=self.window_attn,
            )

            hidden_states = temporal_path + hidden_states

        block_input = hidden_states
        hidden_states = self.channel_pre_norm(block_input)
        hidden_states = self.mlp_block(hidden_states)
        hidden_states = hidden_states + block_input

        return hidden_states

    def double_res_forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        use_cache: bool = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        raw_activations = activations

        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(activations)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs,
                encoder_hidden_states,
                attention_mask,
                encoder_attention_mask,
                use_cache=use_cache,
            )
            cross_attn_output = cross_attn_path + raw_activations
        else:
            cross_attn_output = raw_activations

        if self.temporal_block is not None:
            inputs_normalized = self.temporal_pre_norm(
                cross_attn_output
            )  # RMSNorm introduces slight differences
            hidden_states = self.temporal_block(
                inputs_normalized,
                position_ids,
                attention_mask,
                cache_position=cache_position,
                use_cache=use_cache,
                window_attn=self.window_attn,
            )

            residual = hidden_states + raw_activations
        else:
            residual = cross_attn_output

        hidden_states = self.channel_pre_norm(residual)
        hidden_states = self.mlp_block(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states


class SuryaADETRDecoderPreTrainedModel(SuryaPreTrainedModel):
    config_class = PretrainedConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuryaADETRDecoderLayer"]
    _skip_keys_device_placement = ["cache"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False  # we can't compare with eager for now
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        if isinstance(module, SuryaADETRDecoderSdpaAttention):
            torch.nn.init.normal_(
                module.q_proj.weight, mean=0.0, std=self.config.init_std
            )
            torch.nn.init.normal_(
                module.k_proj.weight, mean=0.0, std=self.config.init_std
            )
            torch.nn.init.normal_(
                module.v_proj.weight, mean=0.0, std=self.config.init_std
            )

            torch.nn.init.normal_(
                module.o_proj.weight, mean=0.0, std=self.config.init_std
            )
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, config, batch, device, dtype):
        layers = getattr(self, "model", self).layers
        for layer in layers:
            if layer.temporal_block:
                layer.temporal_block._setup_cache(batch, device, dtype)
            if layer.cross_attn_block:
                layer.cross_attn_block._setup_cache(batch, device, dtype)

    def _clear_cache(self):
        layers = getattr(self, "model", self).layers
        for layer in layers:
            if layer.temporal_block:
                layer.temporal_block._clear_cache()
            if layer.cross_attn_block:
                layer.cross_attn_block._clear_cache()

    def reset_cache(self, batch, device, dtype):
        pass

    def _tie_weights(self):
        pass

    def tie_weights(self):
        pass


class SuryaADETRDecoderModel(SuryaADETRDecoderPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaADETRDecoderLayer`]

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: PretrainedConfig,
        embedder: nn.Module = None,
        max_boxes: int = None,
        static_cache: bool = False,
    ):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.causal = config.causal

        self.embed_tokens = embedder
        self.max_boxes = max_boxes
        self.static_cache = static_cache

        self.layers = nn.ModuleList(
            [
                SuryaADETRDecoderLayer(
                    config, layer_idx, static_cache=static_cache, max_boxes=max_boxes
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.final_norm = SuryaADETRDecoderRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.gradient_checkpointing = False

        self.register_buffer(
            "normalizer",
            torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32),
            persistent=False,
        )
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_boxes_counts: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        prefill: bool = False,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
        hidden_states = inputs_embeds

        if use_cache and prefill:
            self._setup_cache(
                self.config,
                hidden_states.shape[0],
                hidden_states.device,
                hidden_states.dtype,
            )

        if cache_position is None:
            cache_position = torch.arange(
                hidden_states.shape[1], device=hidden_states.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position
        )

        all_hidden_states = () if output_hidden_states else None
        for i, residual_block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    residual_block.__call__,
                    hidden_states,
                    position_ids,
                    causal_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    cache_position,
                    use_cache,
                )
            else:
                hidden_states = residual_block(
                    hidden_states,
                    position_ids,
                    causal_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    cache_position,
                    use_cache,
                )

        hidden_states = self.final_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    # Ignore copy
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        if not self.causal:
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = max(self.max_boxes, sequence_length)

        diagonal = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        causal_mask = diagonal
        if sequence_length != 1:
            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
            # triu will be the min_dtype, everything else is 0 (attended to)
            causal_mask = torch.triu(diagonal, diagonal=1)

        causal_mask *= torch.arange(
            target_length, device=device
        ) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(
            input_tensor.shape[0], 1, -1, -1
        )
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                # Mask positions in the causal mask that are masked in the attention mask
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
                    :, None, None, :
                ].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[
                    ..., :mask_length
                ].masked_fill(padding_mask, min_dtype)

        if attention_mask is not None and attention_mask.device.type == "cuda":
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

```
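
The `SuryaADETRDecoderRMSNorm` comment above notes that Llama casts the normalized activations to half precision before applying the weight, while this decoder scales by `(1.0 + weight)` in float32 and casts afterwards. The following is a minimal sketch with assumed toy shapes and a bfloat16 target dtype (not code from the repository) showing that the two orderings can produce slightly different results.

```python
import torch

# Minimal sketch, assumed toy shapes: compare the two cast/scale orderings
# referenced in the RMSNorm comment above.
x = torch.randn(2, 8, dtype=torch.float32)
w = torch.randn(8, dtype=torch.float32) * 0.1  # stand-in for the learned weight

# Shared RMS normalization in float32, with the same eps-style clamp as the module
norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).clamp(min=1e-6))

llama_style = norm.to(torch.bfloat16) * (1.0 + w).to(torch.bfloat16)  # cast, then scale
surya_style = (norm * (1.0 + w)).to(torch.bfloat16)                   # scale, then cast

print((llama_style.float() - surya_style.float()).abs().max())  # usually a small nonzero difference
```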
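
The docstrings for `apply_rotary_pos_emb` and `repeat_kv` above spell out the expected shapes and the grouped-query head expansion. The sketch below, using assumed toy dimensions rather than the model's real configuration, shows one plausible way the helpers fit together.

```python
import torch

from surya.common.adetr.decoder import (
    SuryaADETRDecoderRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
)

# Assumed toy dimensions, for illustration only
batch, heads, kv_heads, seq, head_dim = 2, 8, 2, 16, 64

q = torch.randn(batch, heads, seq, head_dim)     # [bs, num_attention_heads, seq, head_dim]
k = torch.randn(batch, kv_heads, seq, head_dim)  # [bs, num_key_value_heads, seq, head_dim]
position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1)

rope = SuryaADETRDecoderRotaryEmbedding(head_dim, base=10000)
cos, sin = rope(q, position_ids)                 # each [bs, seq, head_dim]

# unsqueeze_dim=1 broadcasts cos/sin across the head dimension of q and k
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)

# GQA: repeat the 2 key/value heads to match the 8 query heads before attention
k_full = repeat_kv(k_rot, heads // kv_heads)     # [bs, num_attention_heads, seq, head_dim]
```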

--------------------------------------------------------------------------------
/surya/foundation/__init__.py:
--------------------------------------------------------------------------------

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Tuple
from collections import deque

import cv2
import numpy as np
import torch
import math
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F

from surya.common.surya import SuryaModelOutput
from surya.common.xla import mark_step
from surya.common.predictor import BasePredictor

from surya.foundation.loader import FoundationModelLoader
from surya.foundation.util import (
    detect_repeat_token,
)
from surya.common.surya.schema import TaskNames
from surya.foundation.cache.dynamic_ops import DynamicOpsCache
from surya.foundation.cache.static_ops import StaticOpsCache

from surya.settings import settings
from surya.logging import get_logger, configure_logging

configure_logging()
logger = get_logger()


@dataclass
class ContinuousBatchInput:
    input_ids: torch.Tensor
    input_boxes: torch.Tensor
    position_ids: torch.Tensor
    # input_ids and position_ids may be padded, num_valid_tokens tracks the 'real' counts
    num_valid_tokens: torch.Tensor
    # count the number of predicted tokens for each batch element so far
    num_predicted_tokens: torch.Tensor
    needs_bbox_embedding: torch.Tensor


@dataclass
class ContinuousBatchOutput:
    input_ids: torch.Tensor
    preds: torch.Tensor
    bbox_preds: torch.Tensor
    scores: torch.Tensor
    token_probs: torch.Tensor


@dataclass
class FoundationPrompt:
    id: int
    task_name: TaskNames
    image: np.ndarray
    text: str
    math_mode: bool


class FoundationPredictor(BasePredictor):
    model_loader_cls = FoundationModelLoader
    batch_size = (
        settings.RECOGNITION_BATCH_SIZE
    )  # Default to the recognition batch size
    torch_dtype = None  # No default, loader picks the dtype based on device properties - bf16/fp16
    default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 64}
    encoder_chunk_size: int = 4096  # Default chunk size
    encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768}
    extra_token_count = {
        "xla": 128
    }  # We have to pad the XLA cache since we don't use sliding window
    min_prefill_ratio: float = 1 if settings.FOUNDATION_XLA else 0.2
    min_trim_length: int = 50
    tasks = {
        TaskNames.ocr_with_boxes: {
            "needs_bboxes": True,
            "img_size": (1024, 512),
            "max_tokens": 224,
        },
        TaskNames.ocr_without_boxes: {
            "needs_bboxes": False,
            "img_size": (1024, 512),
            "max_tokens": 224,
        },
        TaskNames.block_without_boxes: {
            "needs_bboxes": False,
            "img_size": (1024, 512),
            "max_tokens": 768,
        },
        TaskNames.layout: {
            "needs_bboxes": False,
            "img_size": (1024, 1024),
            "max_tokens": 200,
        },
        TaskNames.table_structure: {
            "needs_bboxes": False,
            "img_size": (1024, 512),
            "max_tokens": 600,
        },
    }

    def __init__(
        self,
        checkpoint=None,
        device=settings.TORCH_DEVICE_MODEL,
        dtype=None,
        attention_implementation: Optional[str] = None,
    ):
        super().__init__(checkpoint, device, dtype, attention_implementation)
        self.prompt_queue = deque()
        self.batch_prompt_mapping = None
        self.kv_cache = None

        self.beacon_token_interval = self.model.config.beacon_token_interval

        # Setup various tokens on-device
        self.device_pad_token = torch.tensor(
            self.processor.pad_token_id, device=self.model.device, dtype=torch.long
        )
        self.device_beacon_token = torch.tensor(
            self.processor.beacon_token_id, device=self.model.device, dtype=torch.long
        )
        self.special_token_ids = torch.tensor(
            [self.model.config.image_token_id] + self.model.config.register_token_ids,
            device=self.model.device,
        )

        self.pad_to_multiple = (
            settings.FOUNDATION_PAD_TO_NEAREST
            if settings.FOUNDATION_STATIC_CACHE
            else None
        )

    def to(self, device_dtype: torch.device | str | None = None):
        super().to(device_dtype)
        self.special_token_ids = self.special_token_ids.to(device_dtype)

    def get_encoder_chunk_size(self) -> int:
        if settings.FOUNDATION_CHUNK_SIZE is not None:
            return settings.FOUNDATION_CHUNK_SIZE

        chunk_size = self.encoder_chunk_size
        if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes:
            chunk_size = self.encoder_chunk_sizes[settings.TORCH_DEVICE_MODEL]
        return chunk_size

    def setup_cache(self, batch_size: int, max_cache_len: int, max_sliding_window: int):
        kv_cache_cls = StaticOpsCache if settings.FOUNDATION_XLA else DynamicOpsCache
        self.kv_cache = kv_cache_cls(
            self.model.config,
            batch_size,
            max_cache_len,
            text_sliding_window=max_sliding_window,
            device=self.model.device,
            dtype=self.model.dtype,
        )
        self.prompt_queue.clear()
        self.batch_prompt_mapping = {i: None for i in range(batch_size)}

    @property
    def num_empty_slots(self):
        return sum(v is None for v in self.batch_prompt_mapping.values())

    @property
    def num_active_slots(self):
        return len(self.batch_prompt_mapping) - self.num_empty_slots

    def prepare_input(
        self,
        task_names: List[str],
        images: List[Image.Image],
        input_text: List[str | None],
        math_modes: List[bool],
    ):
        batch = []
        for image, text, task_name, math_mode in zip(
            images, input_text, task_names, math_modes
        ):
            image_size = self.tasks[task_name]["img_size"]

            try:
                image = self.processor.scale_to_fit(
                    image, image_size
                )  # Only resizes if out of bounds (max/min)
            except cv2.error:
                # The image is empty if it can't be resized, so just make a blank image
                image = np.zeros((image_size[1], image_size[0], 3), dtype=np.float32)

            # Task input is the same for all tasks for now
            text = text or ""

            # Remove input text that exceeds max generation tokens (likely invalid)
            if len(text) > self.tasks[task_name]["max_tokens"]:
                text = ""
            inputs = [
                {"type": "image", "image": image, "rotated": False},
                {"type": "text", "text": text.strip(), "math": math_mode},
            ]
            batch.append({"task": task_name, "inputs": inputs})

        return batch

    def process_outputs(
        self, outputs: SuryaModelOutput, max_lookahead_tokens: Optional[int] = None
    ) -> ContinuousBatchOutput:
        # Predictions are multi-token
        lm_logits = outputs["lm_logits"].float()  # shape: [batch_size, seq_len, V]
        bbox_logits = outputs["bbox_logits"].float()  # shape: [batch_size, seq_len, 6]

        if (
            max_lookahead_tokens is not None
            and lm_logits.shape[1] > max_lookahead_tokens + 1
        ):
            lm_logits = lm_logits[:, : max_lookahead_tokens + 1, :]
            bbox_logits = bbox_logits[:, : max_lookahead_tokens + 1, :]

        # Get predictions
        preds = torch.argmax(lm_logits, dim=-1)
        input_ids = preds.to(torch.long)

        # Confidence scores for all tokens
        token_probs = F.softmax(lm_logits, dim=-1)
        scores = torch.max(token_probs, dim=-1).values  # shape: [B, T]

        # Update input boxes
        box_preds = bbox_logits * self.model.config.bbox_size
        box_preds = box_preds.to(torch.long)

        return ContinuousBatchOutput(
            input_ids=input_ids,
            preds=preds,
            bbox_preds=box_preds,
            scores=scores,
            token_probs=token_probs,
        )

    # Always left pad with beacons, don't worry about attention masking
    def maybe_insert_beacon_tokens(
        self,
        input_ids: torch.Tensor,
        input_boxes: torch.Tensor,
        num_predicted_tokens: torch.Tensor,
        num_new_tokens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size, seq_len = (
            input_ids.shape
        )  # seq_len can be >1 - In case of multi-token predictions

        # num_predicted_tokens **does not include** the current new input_ids; this number is updated **after beacon tokens are inserted**
        token_positions = num_predicted_tokens + torch.arange(
            1, seq_len + 1, device=input_ids.device
        ).unsqueeze(0)
        beacon_positions = token_positions % self.beacon_token_interval == 0

        # If no beacons needed, return original input
        needs_beacon = beacon_positions.any(dim=1)  # shape: [batch_size]
        if not needs_beacon.any():
            if num_new_tokens is None:
                num_new_tokens = (
                    torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
                    * seq_len
                )
            return input_ids, input_boxes, num_new_tokens.squeeze(1)

        beacon_insert_pos = torch.zeros(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        for i in range(batch_size):
            if needs_beacon[i]:
                # Find first position that needs beacon
                beacon_insert_pos[i] = torch.where(beacon_positions[i])[0]

        # Padded input ids.
        new_input_ids = torch.full(
            (batch_size, seq_len + 1),
            self.device_pad_token,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        new_input_boxes = torch.full(
            (batch_size, seq_len + 1, 6),
            -100,
            dtype=input_boxes.dtype,
            device=input_boxes.device,
        )
        # Fill in tokens for each sequence
        for i in range(batch_size):
            if needs_beacon[i]:
                insert_pos = beacon_insert_pos[i]
                new_input_ids[i, insert_pos] = self.device_beacon_token
                new_input_boxes[i, insert_pos, :] = -100
                if insert_pos > 0:
                    new_input_ids[i, :insert_pos] = input_ids[i, :insert_pos]
                    new_input_boxes[i, :insert_pos] = input_boxes[i, :insert_pos]
                new_input_ids[i, insert_pos + 1 :] = input_ids[i, insert_pos:]
                new_input_boxes[i, insert_pos + 1 :] = input_boxes[i, insert_pos:]
            else:
                new_input_ids[i, 1:] = input_ids[i, :]
                new_input_boxes[i, 1:] = input_boxes[i, :]

        # Calculate valid token counts for both padded and non padded sequences
        valid_token_counts = torch.where(
            needs_beacon,
            torch.tensor(seq_len + 1, device=input_ids.device),
            torch.tensor(seq_len, device=input_ids.device),
        )

        return new_input_ids, new_input_boxes, valid_token_counts
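
    # Illustrative walk-through of maybe_insert_beacon_tokens (assumed values, for
    # clarity only): with beacon_token_interval = 10, num_predicted_tokens = [[8]]
    # and seq_len = 3, token_positions = [[9, 10, 11]], so beacon_positions is
    # [[False, True, False]] and a beacon is spliced in at index 1. That row becomes
    # [tok0, <beacon>, tok1, tok2] (length seq_len + 1) with 4 valid tokens; rows
    # that need no beacon keep a left pad token in slot 0 and count seq_len valid tokens.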

    def decode(
        self,
        current_inputs: Optional[ContinuousBatchInput] = None,
        max_lookahead_tokens: Optional[int] = None,
    ):
        # Note - If we want to use the outputs from the non-last token, we
        # need to set the cache position manually to ensure causality. The default
        # behavior only works for the last token currently
        input_ids = current_inputs.input_ids
        input_boxes = current_inputs.input_boxes
        embed_boxes = current_inputs.needs_bbox_embedding

        position_ids = current_inputs.position_ids
        num_predicted_tokens = current_inputs.num_predicted_tokens
        num_valid_tokens = current_inputs.num_valid_tokens
        batch_size = input_ids.shape[0]

        # Pre-shift the attention mask based on the cache update
        self.kv_cache.decode_attention_mask_update(
            num_valid_tokens=num_valid_tokens, cache_idxs=list(range(batch_size))
        )

        cache_position = self.get_cache_position(
            input_ids.shape[1], self.kv_cache.attention_mask, prefill=False
        )
        with settings.INFERENCE_MODE():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=self.kv_cache.attention_mask,
                position_ids=position_ids,
                cache_position=cache_position,
                use_cache=True,
                past_key_values=self.kv_cache,
                prefill=False,
                num_valid_tokens=num_valid_tokens,
                input_boxes=input_boxes,
                embed_boxes=embed_boxes,
                logits_to_keep=1,
            )

        processed_output: ContinuousBatchOutput = self.process_outputs(
            outputs, max_lookahead_tokens=max_lookahead_tokens
        )

        input_ids = processed_output.input_ids
        input_boxes = processed_output.bbox_preds

        # Update this **before** inserting beacon tokens
        tau = settings.FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE
        if max_lookahead_tokens is not None:
            num_new_tokens = torch.clamp(
                (
                    processed_output.scores.ge(tau)
                    .to(torch.long)
                    .cumprod(dim=1)
                    .sum(dim=1, keepdim=True)
                ),
                min=1,
            )
        else:
            num_new_tokens = input_ids.shape[1]

        num_predicted_tokens += num_new_tokens
        input_ids, input_boxes, num_valid_tokens = self.maybe_insert_beacon_tokens(
            input_ids, input_boxes, num_predicted_tokens, num_new_tokens
        )
        position_ids = position_ids[:, -1:] + torch.arange(
            1, input_ids.shape[1] + 1, device=input_ids.device
        )
        # Some of the input sequences may now have left padding tokens, so we want to account for that
        # offset is a per-batch offset of the position_ids
        offset = (input_ids.shape[1] - num_valid_tokens).unsqueeze(1)
        position_ids -= offset

        new_input = ContinuousBatchInput(
            input_ids=input_ids,
            input_boxes=input_boxes,
            position_ids=position_ids,
            num_valid_tokens=num_valid_tokens,
            num_predicted_tokens=num_predicted_tokens,
            needs_bbox_embedding=current_inputs.needs_bbox_embedding,
        )

        return new_input, processed_output
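
    # Illustrative example of the multi-token acceptance rule in decode() above
    # (assumed numbers): with FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE = 0.5 and
    # per-token scores [0.9, 0.8, 0.4, 0.95], scores.ge(tau) gives [1, 1, 0, 1],
    # the cumulative product is [1, 1, 0, 0], and num_new_tokens = 2: lookahead
    # tokens are accepted left to right until the first low-confidence prediction,
    # with a minimum of one accepted token.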

    def pad_and_shift_input_ids_position_ids(
        self,
        input_ids: torch.Tensor,
        bbox_preds: torch.Tensor,
        position_ids: torch.Tensor,
        new_seq_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Pads new_input_ids to match the new seq len with **left padding**
        and creates updated position_ids

        Returns:
            padded_input_ids (torch.Tensor): [batch_size, current_seq_len]
            updated_position_ids (torch.Tensor): [batch_size, current_seq_len]
        """
        # No padding
        if new_seq_len == input_ids.shape[1]:
            return (
                input_ids,
                bbox_preds,
                position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device),
            )

        pad_len = new_seq_len - input_ids.shape[1]
        padded_input_ids = torch.nn.functional.pad(
            input_ids, (pad_len, 0), value=self.device_pad_token
        )

        padded_bbox_preds = torch.nn.functional.pad(
            bbox_preds, (0, 0, pad_len, 0), value=-100
        )

        # Since we have **left padding**, offset the new position_ids by the amount of padding
        # This ensures that the **true tokens** get the correct position_ids
        # The position_ids assigned to pad tokens do not matter. They are not cached, and not used for outputs
        updated_position_ids = position_ids[:, -1:] + torch.arange(
            1, new_seq_len + 1, device=self.model.device
        )
        updated_position_ids -= pad_len

        return padded_input_ids, padded_bbox_preds, updated_position_ids
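
    # Illustrative example for the left padding above (assumed numbers): if the last
    # real position id was 100, new_seq_len = 3 and pad_len = 1, the updated ids are
    # [101, 102, 103] - 1 = [100, 101, 102]. The first slot belongs to the pad token
    # and is never cached or read; the true tokens continue at 101 and 102 with no gap.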

    def get_cache_position(
        self,
        seq_len: int,
        attention_mask: torch.Tensor,
        prefill: bool,
    ):
        batch_size, target_len = attention_mask.shape
        base_cache_position = (
            torch.arange(seq_len, device=attention_mask.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        if prefill:
            return base_cache_position

        # This is a (batch_size) tensor, we can add the seq lens here
        cache_seqlens = (
            attention_mask
            * torch.arange(attention_mask.size(1), device=attention_mask.device)
        ).argmax(dim=1).to(torch.int32) + 1
        # Needs to be unsqueezed so broadcasting works
        return cache_seqlens.unsqueeze(1) + base_cache_position
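
    # Illustrative example for the decode branch above (assumed mask): for an
    # attention_mask row of [0, 0, 1, 1, 1, 0, 0], mask * arange(7) = [0, 0, 2, 3, 4, 0, 0],
    # argmax = 4 and cache_seqlens = 5, so the new tokens are written at cache
    # positions 5, 6, ... (cache_seqlens + base_cache_position).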

    def prefill(
        self,
        current_inputs: Optional[ContinuousBatchInput] = None,
        max_lookahead_tokens: Optional[int] = None,
    ):
        logger.debug(f"Prefilling {self.num_empty_slots} slots")

        prompts: List[FoundationPrompt] = [
            self.prompt_queue.popleft()
            for _ in range(min(self.num_empty_slots, len(self.prompt_queue)))
        ]
        non_active_idxs = [k for k, v in self.batch_prompt_mapping.items() if v is None]
        idxs_to_merge = non_active_idxs[: len(prompts)]

        for i, prompt in zip(idxs_to_merge, prompts):
            self.batch_prompt_mapping[i] = prompt.id

        needs_bbox_embedding = torch.tensor(
            [
                p.task_name in [TaskNames.layout, TaskNames.table_structure]
                for p in prompts
            ],
            dtype=torch.bool,
        )

        batch_input = self.prepare_input(
            task_names=[p.task_name for p in prompts],
            images=[p.image for p in prompts],
            input_text=[p.text for p in prompts],
            math_modes=[
                p.math_mode for p in prompts
            ],  # Pass math mode to the processor
        )
        processed_inputs = self.processor(
            batch_input,
            padding_side="left",
            device=self.model.device,
            pad_to_multiple=self.pad_to_multiple,
        )

        input_ids = processed_inputs["input_ids"].to(dtype=torch.long)
        attention_mask = processed_inputs["attention_mask"].to(dtype=torch.long)
        position_ids = processed_inputs["position_ids"].to(dtype=torch.long)
        valid_batch_size = len(idxs_to_merge)

        # Keep these off device until later
        image_tiles = processed_inputs["image_tiles"].to(dtype=self.model.dtype)
        grid_thw = processed_inputs["grid_thw"].to(dtype=torch.long)

        if settings.FOUNDATION_STATIC_CACHE:
            input_ids = self.pad_to_batch_size(
                input_ids, batch_size=self.kv_cache.max_batch_size
            )
            attention_mask = self.pad_to_batch_size(
                attention_mask, batch_size=self.kv_cache.max_batch_size
            )
            position_ids = self.pad_to_batch_size(
                position_ids, batch_size=self.kv_cache.max_batch_size
            )
            needs_bbox_embedding = self.pad_to_batch_size(
                needs_bbox_embedding, batch_size=self.kv_cache.max_batch_size
            )

        # Move to device after padding
        input_ids = input_ids.to(device=self.model.device)
        attention_mask = attention_mask.to(device=self.model.device)
        position_ids = position_ids.to(device=self.model.device)
        needs_bbox_embedding = needs_bbox_embedding.to(device=self.model.device)

        # Find text lengths of each
        # Oddly, this is optimal on GPU - causes a 30% slowdown if "optimized"
        # Be very careful with the type and device of this - can cause
        # a big slowdown if put on device
        is_special = (
            (input_ids.unsqueeze(-1) == self.special_token_ids).any(-1).cpu()
        )  # (batch, seq_len)
        text_lengths = []
        for i in range(input_ids.shape[0]):
            special_positions = is_special[i].nonzero(as_tuple=True)[0]
            if len(special_positions) > 0:
                # Assuming special tokens are contiguous at the start
                prefix_len = special_positions[-1].item() + 1
            else:
                prefix_len = 0
            text_lengths.append(input_ids.shape[1] - prefix_len)
        text_lengths = torch.tensor(text_lengths, dtype=torch.long)

        cache_position = self.get_cache_position(
            input_ids.shape[1], attention_mask, prefill=True
        )
        with settings.INFERENCE_MODE():
            image_embeddings = self.model.get_image_embeddings(
                pixel_values=image_tiles,
                grid_thw=grid_thw,
                encoder_chunk_size=self.get_encoder_chunk_size(),
                valid_batch_size=valid_batch_size,
                max_batch_size=self.kv_cache.max_batch_size,
            )
            mark_step()

            outputs = self.model(
                input_ids=input_ids,
                image_embeddings=image_embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
                cache_position=cache_position,
                inputs_embeds=None,
                past_key_values=self.kv_cache,
                use_cache=True,
                encoder_chunk_size=self.get_encoder_chunk_size(),
                cache_idxs=idxs_to_merge,
                prefill=True,
                num_valid_tokens=None,  # Not required during prefill
                text_lengths=text_lengths,
                valid_batch_size=valid_batch_size,
                logits_to_keep=1,
            )

        # Process outputs
        processed_outputs = self.process_outputs(
            outputs, max_lookahead_tokens=max_lookahead_tokens
        )
        # Multi-token prediction
        predicted_tokens = processed_outputs.input_ids.shape[1]
        num_valid_tokens = (
            torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long)
            * predicted_tokens
        )
        num_predicted_tokens = (
            torch.ones(
                (input_ids.shape[0], 1), device=self.model.device, dtype=torch.long
            )
            * predicted_tokens
        )

        self.kv_cache.prefill_attention_mask_update(
            attention_mask, idxs_to_merge, valid_batch_size, text_lengths
        )
        self.kv_cache.update_text_counts(idxs_to_merge, valid_batch_size, text_lengths)

        full_batch = len(idxs_to_merge) == self.kv_cache.max_batch_size

        # If full batch, then we can ignore current_inputs
        if current_inputs is None or full_batch:
            new_seq_len = processed_outputs.input_ids.shape[1]
            # No padding tokens - So we can safely set position_ids this way
            position_ids = position_ids[:, -1:] + torch.arange(
                1, new_seq_len + 1, device=position_ids.device
            )
            new_input = ContinuousBatchInput(
                input_ids=processed_outputs.input_ids,
                input_boxes=processed_outputs.bbox_preds,
                position_ids=position_ids,
                num_valid_tokens=num_valid_tokens,
                num_predicted_tokens=num_predicted_tokens,
                needs_bbox_embedding=needs_bbox_embedding,
            )

            return (
                new_input,
                processed_outputs,
                range(processed_outputs.input_ids.shape[0]),
            )

        # Merging inputs for next steps
        current_input_ids = current_inputs.input_ids
        current_position_ids = current_inputs.position_ids
        current_input_boxes = current_inputs.input_boxes

        current_needs_bbox_embedding = current_inputs.needs_bbox_embedding

        assert current_input_ids.shape[1] == current_position_ids.shape[1]
        input_ids, bbox_preds, position_ids = self.pad_and_shift_input_ids_position_ids(
            processed_outputs.input_ids,
            processed_outputs.bbox_preds,
            position_ids,
            new_seq_len=current_input_ids.shape[1],
        )

        current_input_ids[idxs_to_merge] = input_ids[:valid_batch_size]
        current_input_boxes[idxs_to_merge] = bbox_preds[:valid_batch_size]
        current_position_ids[idxs_to_merge] = position_ids[:valid_batch_size]

        current_num_valid_tokens = current_inputs.num_valid_tokens
        current_num_valid_tokens[idxs_to_merge] = num_valid_tokens[:valid_batch_size]

        current_num_predicted_tokens = current_inputs.num_predicted_tokens
        current_num_predicted_tokens[idxs_to_merge] = num_predicted_tokens[
            :valid_batch_size
        ]
        current_needs_bbox_embedding[idxs_to_merge] = needs_bbox_embedding[
            :valid_batch_size
        ]

        new_input = ContinuousBatchInput(
            input_ids=current_input_ids,
            input_boxes=current_input_boxes,
            position_ids=current_position_ids,
            num_valid_tokens=current_num_valid_tokens,
            num_predicted_tokens=current_num_predicted_tokens,
            needs_bbox_embedding=current_needs_bbox_embedding,
        )

        return new_input, processed_outputs, idxs_to_merge

    def get_max_image_token_count(
        self, images: list[np.ndarray], tasks: List[TaskNames]
    ) -> int:
        def compute_scaled_size(
            H: int, W: int, max_size: Tuple[int, int]
        ) -> Tuple[int, int]:
            max_W, max_H = max_size
            min_W, min_H = (168, 168)

            current_pixels = H * W
            max_pixels = max_H * max_W
            min_pixels = min_H * min_W
            current_pixels = max(1, current_pixels)  # Avoid zero division

            if current_pixels > max_pixels:
                scale = (max_pixels / current_pixels) ** 0.5
                return math.floor(H * scale), math.floor(W * scale)
            elif current_pixels < min_pixels:
                scale = (min_pixels / current_pixels) ** 0.5
                return math.ceil(H * scale), math.ceil(W * scale)
            return H, W
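        # Example with a hypothetical max_size of (1024, 1024): a 2000x1000 image has
        # 2,000,000 px > 1,048,576 px, so scale = sqrt(1048576 / 2000000) ~= 0.724 and
        # the image is downscaled to roughly 1448 x 724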

        def get_tile_count(H: int, W: int, factor: int) -> int:
            H_bar = math.ceil(H / factor) * factor
            W_bar = math.ceil(W / factor) * factor
            grid_h = H_bar // self.processor.patch_size
            grid_w = W_bar // self.processor.patch_size
            return grid_h * grid_w

        max_tokens = 0
        factor = self.processor.patch_size * self.processor.merge_size

        for image, task in zip(images, tasks):
            H, W = image.shape[:2]
            max_size = self.tasks[task]["img_size"]
            scaled_H, scaled_W = compute_scaled_size(H, W, max_size)
            token_count = get_tile_count(scaled_H, scaled_W, factor) / (
                self.processor.merge_size**2
            )
            max_tokens = max(max_tokens, token_count)

        # Extra 10 to account for EOS/BOS/Rotation token etc.
        return 10 + self.processor.num_register_tokens + int(max_tokens)

    def prediction_loop(
        self,
        images: List[np.ndarray],
        input_texts: List[str],
        task_names: List[TaskNames],
        batch_size: int | None = None,
        max_tokens: int | None = None,
        max_sliding_window: int | None = None,
        math_mode: bool = True,
        drop_repeated_tokens: bool = True,
        max_lookahead_tokens: Optional[int] = None,
        top_k: int = 0,
        tqdm_desc: str = "Recognizing Text"
    ) -> tuple:
        allowed_tasks = self.tasks.keys()
        assert all([task_name in allowed_tasks for task_name in task_names]), (
            f"One or more tasks in {task_names} is not supported. Supported tasks are {allowed_tasks}"
        )

        predicted_tokens = [[] for _ in range(len(images))]
        scores = [[] for _ in range(len(images))]
        topk_probs = [[] for _ in range(len(images))]

        if batch_size is None:
            batch_size = self.get_batch_size()

        batch_size = min(len(images), batch_size)
        current_inputs = None

        max_image_tokens = self.get_max_image_token_count(images, task_names)
        if max_sliding_window is None:
            max_sliding_window = self.model.config.sliding_window
        self.setup_cache(
            batch_size,
            max_cache_len=max_image_tokens + max_sliding_window + self.extra_token_count.get(settings.TORCH_DEVICE_MODEL, 0),
            max_sliding_window=max_sliding_window,
        )

        batch_max_tokens = {}
        for idx, (img, txt, task) in enumerate(zip(images, input_texts, task_names)):
            self.prompt_queue.append(
                FoundationPrompt(
                    id=idx, task_name=task, text=txt, image=img, math_mode=math_mode
                )
            )
            batch_max_tokens[idx] = (
                max_tokens
                or settings.FOUNDATION_MAX_TOKENS
                or self.tasks[task]["max_tokens"]
            )

        overall_max_tokens = max(batch_max_tokens.values())

        pbar = tqdm(
            total=len(self.prompt_queue),
            desc=tqdm_desc,
            disable=self.disable_tqdm,
        )

        batch_bboxes = torch.zeros(len(images), overall_max_tokens, 6)
        batch_pos = [0] * len(images)

        while self.prompt_queue or self.num_active_slots > 0:
            if (
                self.num_empty_slots / batch_size
            ) >= self.min_prefill_ratio and self.prompt_queue:
                updated_inputs, outputs, merge_idxs = self.prefill(
                    current_inputs, max_lookahead_tokens=0
                )

                predicted_tokens_cpu = outputs.preds.cpu()
                scores_cpu = outputs.scores.cpu()
                bbox_preds_cpu = outputs.bbox_preds.cpu()

                if top_k > 0:
                    batch_top_k_probs, batch_top_k_indices = torch.topk(
                        outputs.token_probs, k=top_k, dim=-1
                    )
                    batch_top_k_probs_cpu = batch_top_k_probs.cpu()
                    batch_top_k_indices_cpu = batch_top_k_indices.cpu()

                for temp_idx, b_idx in enumerate(merge_idxs):
                    if self.batch_prompt_mapping[b_idx] is not None:
                        p_idx = self.batch_prompt_mapping[b_idx]
                        seq_len = predicted_tokens_cpu.shape[1]
                        for t_idx in range(seq_len):
                            token = predicted_tokens_cpu[temp_idx, t_idx].item()
                            predicted_tokens[p_idx].append(token)
                            batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[
                                temp_idx, t_idx
                            ]
                            batch_pos[p_idx] += 1
                            scores[p_idx].append(scores_cpu[temp_idx, t_idx].item())

                            if top_k > 0:
                                top_k_scores = {
                                    batch_top_k_indices_cpu[temp_idx, t_idx][
                                        k
                                    ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][
                                        k
                                    ].item()
                                    for k in range(top_k)
                                }
                                topk_probs[p_idx].append(top_k_scores)

                            if token in [
                                self.processor.eos_token_id,
                                self.processor.no_output_token,
                            ]:
                                self.batch_prompt_mapping[b_idx] = None
                                pbar.update(1)
                                break
            else:
                updated_inputs, outputs = self.decode(
                    current_inputs, max_lookahead_tokens=max_lookahead_tokens
                )
                mark_step()

                predicted_tokens_cpu = outputs.preds.cpu()
                scores_cpu = outputs.scores.cpu()
                bbox_preds_cpu = outputs.bbox_preds.cpu()

                if top_k > 0:
                    batch_top_k_probs, batch_top_k_indices = torch.topk(
                        outputs.token_probs, k=top_k, dim=-1
                    )
                    batch_top_k_probs_cpu = batch_top_k_probs.cpu()
                    batch_top_k_indices_cpu = batch_top_k_indices.cpu()

                for b_idx, p_idx in self.batch_prompt_mapping.items():
                    if p_idx is not None:
                        seq_len = predicted_tokens_cpu.shape[1]
                        num_tokens = updated_inputs.num_valid_tokens[b_idx].item()
                        should_stop = False

                        for t_idx in range(seq_len):
                            # don't use multitoken prediction for lower confidence tokens
                            if t_idx > 0 and num_tokens < seq_len:
                                # roll so tokens are right aligned
                                updated_inputs.input_ids[b_idx] = (
                                    updated_inputs.input_ids[b_idx].roll(
                                        shifts=seq_len - num_tokens, dims=0
                                    )
                                )
                                # don't need to roll position_ids because that's handled in `decode` (and when we do beacon tokens)
                                break

                            token = predicted_tokens_cpu[b_idx, t_idx].item()
                            predicted_tokens[p_idx].append(token)
                            batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[
                                b_idx, t_idx
                            ]
                            batch_pos[p_idx] += 1
                            scores[p_idx].append(scores_cpu[b_idx, t_idx].item())

                            if top_k > 0:
                                top_k_scores = {
                                    batch_top_k_indices_cpu[b_idx, t_idx][
                                        k
                                    ].item(): batch_top_k_probs_cpu[b_idx, t_idx][
                                        k
                                    ].item()
                                    for k in range(top_k)
                                }
                                topk_probs[p_idx].append(top_k_scores)

                            repeats = len(predicted_tokens[p_idx]) >= batch_max_tokens[
                                p_idx
                            ] or (
                                drop_repeated_tokens
                                and detect_repeat_token(predicted_tokens[p_idx])
                                and task_names[p_idx]
                                in [
                                    TaskNames.ocr_with_boxes,
                                    TaskNames.ocr_without_boxes,
                                ]
                            )
                            if (
                                token
                                in [
                                    self.processor.eos_token_id,
                                    self.processor.pad_token_id,
                                ]
                                or repeats
                            ):
                                should_stop = True
                                break

                        if should_stop:
                            self.batch_prompt_mapping[b_idx] = None
                            pbar.update(1)

            # Update inputs and mark XLA step
            current_inputs = updated_inputs

        pbar.close()

        del self.kv_cache
        self.kv_cache = None
        torch.cuda.empty_cache()

        return predicted_tokens, batch_bboxes, scores, topk_probs

```
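
A toy illustration of the left-padding arithmetic in `pad_and_shift_input_ids_position_ids` above. This is a standalone sketch, not library code, and the tensor values are invented: a freshly prefilled row is left-padded up to the running decode width, and its position ids are advanced so that only the real (non-pad) tokens receive meaningful positions.

```python
import torch

# Hypothetical values: a prefilled row produced 2 new tokens, but the running
# decode batch is 4 tokens wide, so we left-pad by 2 and offset position ids.
new_tokens = torch.tensor([[101, 102]])       # (batch=1, 2 new tokens)
last_position = torch.tensor([[57]])          # last position id of this row's prompt
target_width = 4
pad_len = target_width - new_tokens.shape[1]  # 2

padded = torch.nn.functional.pad(new_tokens, (pad_len, 0), value=-100)
# Count 1..target_width past the last position, then subtract pad_len so the
# real tokens land on positions 58 and 59; the pad positions are never cached.
position_ids = last_position + torch.arange(1, target_width + 1) - pad_len

print(padded)        # tensor([[-100, -100,  101,  102]])
print(position_ids)  # tensor([[56, 57, 58, 59]])
```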

--------------------------------------------------------------------------------
/surya/common/donut/encoder.py:
--------------------------------------------------------------------------------

```python
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.pytorch_utils import (
    find_pruneable_heads_and_indices,
    meshgrid,
    prune_linear_layer,
)
from transformers.utils import ModelOutput
from transformers import DonutSwinConfig

from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.xla import mark_step

_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]


@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
class DonutSwinEncoderOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class DonutSwinModelOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None


# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size,
        height // window_size,
        window_size,
        width // window_size,
        window_size,
        num_channels,
    )
    windows = (
        input_feature.permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, window_size, window_size, num_channels)
    )
    return windows


# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
    num_channels = windows.shape[-1]
    windows = windows.view(
        -1,
        height // window_size,
        width // window_size,
        window_size,
        window_size,
        num_channels,
    )
    windows = (
        windows.permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, height, width, num_channels)
    )
    return windows


# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
class DonutSwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()
        self.config = config  # referenced by interpolate_pos_encoding below

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = (
            nn.Parameter(torch.zeros(1, 1, config.embed_dim))
            if use_mask_token
            else None
        )

        self.position_embeddings = None
        self.row_embeddings = None
        self.column_embeddings = None
        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(
                torch.zeros(1, num_patches + 1, config.embed_dim)
            )

        if hasattr(config, "use_2d_embeddings") and config.use_2d_embeddings:
            self.row_embeddings = nn.Parameter(
                torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim)
            )
            self.column_embeddings = nn.Parameter(
                torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim)
            )

        self.norm = nn.LayerNorm(config.embed_dim)

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        Interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor]:
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(
                    embeddings, height, width
                )
            else:
                embeddings = embeddings + self.position_embeddings[:, :seq_len]

        if self.row_embeddings is not None and self.column_embeddings is not None:
            # Row embeddings repeat each row index across the width (0, 0, ..., 1, 1, ...);
            # column embeddings tile the column indices along the height (0, 1, 2, ..., 0, 1, 2, ...)
            row_embeddings = self.row_embeddings[
                :, : output_dimensions[0], :
            ].repeat_interleave(output_dimensions[1], dim=1)
            column_embeddings = self.column_embeddings[
                :, : output_dimensions[1], :
            ].repeat(1, output_dimensions[0], 1)

            embeddings = embeddings + row_embeddings + column_embeddings

        return embeddings, output_dimensions


# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
class DonutSwinPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )
        patch_size = (
            patch_size
            if isinstance(patch_size, collections.abc.Iterable)
            else (patch_size, patch_size)
        )
        num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        )

        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(
        self, pixel_values: Optional[torch.FloatTensor]
    ) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, output_dimensions


# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class DonutSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(
        self,
        input_resolution: Tuple[int],
        dim: int,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)

        return input_feature

    def forward(
        self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
    ) -> torch.Tensor:
        height, width = input_dimensions
        # `dim` is height * width
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)
        # pad input so that height and width are even, if needed
        input_feature = self.maybe_pad(input_feature, height, width)
        # [batch_size, height/2, width/2, num_channels]
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]
        # batch_size height/2 width/2 4*num_channels
        input_feature = torch.cat(
            [input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
        )
        input_feature = input_feature.view(
            batch_size, -1, 4 * num_channels
        )  # batch_size height/2*width/2 4*C

        input_feature = self.norm(input_feature)
        input_feature = self.reduction(input_feature)

        return input_feature


# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin
class DonutSwinSelfAttention(nn.Module):
    def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.kv_repeats = self.num_attention_heads // self.num_kv_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.kv_head_size = self.num_kv_heads * self.attention_head_size
        self.window_size = (
            window_size
            if isinstance(window_size, collections.abc.Iterable)
            else (window_size, window_size)
        )

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
            )
        )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.query = nn.Linear(
            self.all_head_size, self.all_head_size, bias=config.qkv_bias
        )
        self.key = nn.Linear(
            self.all_head_size, self.kv_head_size, bias=config.qkv_bias
        )
        self.value = nn.Linear(
            self.all_head_size, self.kv_head_size, bias=config.qkv_bias
        )

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def transpose_kv_for_scores(self, x, repeats):
        new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        x = x.repeat(
            1, 1, repeats, 1
        )  # repeat the values for each key-value head to match query dim
        return x.permute(0, 2, 1, 3).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        batch_size, dim, num_channels = hidden_states.shape
        mixed_query_layer = self.query(hidden_states)

        # Final is (batch_size, num_attention_heads, seq_len, attention_head_size)
        key_layer = self.transpose_kv_for_scores(
            self.key(hidden_states), self.kv_repeats
        )
        value_layer = self.transpose_kv_for_scores(
            self.value(hidden_states), self.kv_repeats
        )
        query_layer = self.transpose_for_scores(mixed_query_layer)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ]
        relative_position_bias = relative_position_bias.view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )
        relative_position_bias = (
            relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
        )
        relative_position_bias = relative_position_bias.repeat(batch_size, 1, 1, 1)

        if attention_mask is None:
            attention_mask = relative_position_bias
        else:
            mask_shape = attention_mask.shape[0]
            repeat_count = batch_size // mask_shape
            attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)
            attention_mask = attention_mask + relative_position_bias

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=0.0,
            scale=self.attention_head_size**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, dim, num_channels)

        outputs = (attn_output,)
        return outputs


# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
class DonutSwinSelfOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        return self.dense(hidden_states)


# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
class DonutSwinAttention(nn.Module):
    def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
        super().__init__()
        self.self = DonutSwinSelfAttention(
            config, dim, num_heads, num_kv_heads, window_size
        )
        self.output = DonutSwinSelfOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = (
            self.self.attention_head_size * self.self.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states, attention_mask, head_mask, output_attentions
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs


# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
class DonutSwinIntermediate(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.swin.modeling_swin.SwinOutput
class DonutSwinOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dense(hidden_states)


# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
class DonutSwinLayer(nn.Module):
    def __init__(
        self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0
    ):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(
            config, dim, num_heads, num_kv_heads, window_size=self.window_size
        )
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DonutSwinIntermediate(config, dim)
        self.output = DonutSwinOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = int(0)
            self.window_size = (
                torch.min(torch.tensor(input_resolution))
                if torch.jit.is_tracing()
                else min(input_resolution)
            )

    def get_attn_mask(self, height, width, dtype, device):
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(
                attn_mask != 0, float(-100.0)
            ).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)
        else:
            pass
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)

        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(
                hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(
            shifted_hidden_states, self.window_size
        )
        hidden_states_windows = hidden_states_windows.view(
            -1, self.window_size * self.window_size, channels
        )
        attn_mask = self.get_attn_mask(
            height_pad,
            width_pad,
            dtype=hidden_states.dtype,
            device=hidden_states_windows.device,
        )

        attention_outputs = self.attention(
            hidden_states_windows,
            attn_mask,
            head_mask,
            output_attentions=output_attentions,
        )

        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(
            -1, self.window_size, self.window_size, channels
        )
        shifted_windows = window_reverse(
            attention_windows, self.window_size, height_pad, width_pad
        )

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(
                shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            attention_windows = shifted_windows

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        hidden_states = shortcut + attention_windows

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (
            (layer_output, attention_outputs[1])
            if output_attentions
            else (layer_output,)
        )
        return layer_outputs


# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
class DonutSwinStage(nn.Module):
    def __init__(
        self,
        config,
        layer_num,
        dim,
        input_resolution,
        depth,
        num_heads,
        num_kv_heads,
        downsample,
    ):
        super().__init__()
        self.config = config
        self.dim = dim
        self.blocks = nn.ModuleList(
            [
                DonutSwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    shift_size=0 if (i % 2 == 0) else config.window_size // 2,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=nn.LayerNorm
            )
        else:
            self.downsample = None

        self.pointing = False

        self.positional_encoding = None
        if config.use_positional_embeddings:
            self.positional_encoding = self.build_2d_sincos_position_embedding(
                input_resolution[1],
                input_resolution[0],
                embed_dim=dim,
            )

    @staticmethod
    def build_2d_sincos_position_embedding(
        width,
        height,
        embed_dim=256,
        temperature=10000.0,
        device="cpu",
        dtype=torch.float32,
    ):
        grid_w = torch.arange(int(width), dtype=dtype, device=device)
        grid_h = torch.arange(int(height), dtype=dtype, device=device)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
        if embed_dim % 4 != 0:
            raise ValueError(
                "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
            )
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
        omega = 1.0 / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        return torch.concat(
            [out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1
        )[None, :, :]

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        height, width = input_dimensions

        if self.positional_encoding is not None:
            hidden_states = hidden_states + self.positional_encoding.to(
                hidden_states.dtype
            ).to(hidden_states.device)

        for i, layer_module in enumerate(self.blocks):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states,
                input_dimensions,
                layer_head_mask,
                output_attentions,
                always_partition,
            )

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(
                hidden_states_before_downsampling, input_dimensions
            )
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (
            hidden_states,
            hidden_states_before_downsampling,
            output_dimensions,
        )

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
class DonutSwinEncoder(nn.Module):
    def __init__(self, config, grid_size):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config
        self.layers = nn.ModuleList(
            [
                DonutSwinStage(
                    config=config,
                    layer_num=i_layer,
                    dim=int(config.embed_dim * 2**i_layer),
                    input_resolution=(
                        grid_size[0] // (2**i_layer),
                        grid_size[1] // (2**i_layer),
                    ),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    num_kv_heads=config.num_kv_heads[i_layer]
                    if hasattr(config, "num_kv_heads")
                    else config.num_heads[i_layer],
                    downsample=DonutSwinPatchMerging
                    if (i_layer < self.num_layers - 1)
                    else None,
                )
                for i_layer in range(self.num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, DonutSwinEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(
                batch_size, *input_dimensions, hidden_size
            )
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size,
                    *(output_dimensions[0], output_dimensions[1]),
                    hidden_size,
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(
                    batch_size, *input_dimensions, hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return DonutSwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
class DonutSwinPreTrainedModel(SuryaPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DonutSwinConfig
    base_model_prefix = "swin"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DonutSwinStage"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

```
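
Since `window_partition` and `window_reverse` above are exact inverses for inputs whose height and width are multiples of the window size, a round-trip check is a convenient sanity test. This is a standalone sketch, not a test shipped with the repo; it only assumes the import path shown in the file header.

```python
import torch

from surya.common.donut.encoder import window_partition, window_reverse

# Partition an image-shaped tensor into non-overlapping windows, then merge the
# windows back; the round trip should reproduce the input exactly.
batch, height, width, channels, window = 2, 8, 8, 3, 4
x = torch.randn(batch, height, width, channels)

windows = window_partition(x, window)                      # (8, 4, 4, 3)
restored = window_reverse(windows, window, height, width)  # (2, 8, 8, 3)

assert windows.shape == (batch * (height // window) * (width // window), window, window, channels)
assert torch.equal(restored, x)
```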

--------------------------------------------------------------------------------
/surya/ocr_error/model/encoder.py:
--------------------------------------------------------------------------------

```python
from __future__ import annotations

import math
from typing import Optional, Set, List, Tuple, Union, Dict

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F, MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
from transformers import apply_chunking_to_forward
from transformers.activations import get_activation
from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from transformers.pytorch_utils import (
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)

from transformers.utils import (
    is_flash_attn_greater_or_equal_2_10,
)

from surya.common.pretrained import SuryaPreTrainedModel

from surya.common.s3 import S3DownloaderMixin
from surya.ocr_error.model.config import DistilBertConfig


def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
    position_enc = np.array(
        [
            [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
            for pos in range(n_pos)
        ]
    )
    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()


class Embeddings(nn.Module):
    def __init__(self, config: DistilBertConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.dim, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.dim
        )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Parameters:
            input_ids (torch.Tensor):
                torch.tensor(bs, max_seq_length) The token ids to embed.
            input_embeds (*optional*, torch.Tensor):
                The pre-computed word embeddings. Can only be passed if the input ids are `None`.


        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
        embeddings)
        """
        if input_ids is not None:
            input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)

        seq_length = input_embeds.size(1)

        # Use the position_ids buffer registered in the constructor; this helps
        # when tracing the model without passing position ids explicitly and
        # avoids issues similar to issue #5664
        if hasattr(self, "position_ids"):
            position_ids = self.position_ids[:, :seq_length]
        else:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            )  # (max_seq_length)
            position_ids = position_ids.unsqueeze(0).expand_as(
                input_ids
            )  # (bs, max_seq_length)

        position_embeddings = self.position_embeddings(
            position_ids
        )  # (bs, max_seq_length, dim)

        embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, config: DistilBertConfig):
        super().__init__()
        self.config = config

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)
        self.is_causal = False

        # The number of attention heads must evenly divide the hidden dimension
        if self.dim % self.n_heads != 0:
            raise ValueError(
                f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly"
            )

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads: Set[int] = set()
        self.attention_head_size = self.dim // self.n_heads

    def prune_heads(self, heads: List[int]):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.attention_head_size, self.pruned_heads
        )
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = self.attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            context: torch.tensor(bs, seq_length, dim) Contextualized layer.
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights.
                Only returned if `output_attentions=True`.
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x: torch.Tensor) -> torch.Tensor:
            """group heads"""
            return (
                x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
            )

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (
            (mask == 0).view(mask_reshp).expand_as(scores)
        )  # (bs, n_heads, q_length, k_length)
        scores = scores.masked_fill(
            mask, torch.tensor(torch.finfo(scores.dtype).min)
        )  # (bs, n_heads, q_length, k_length)

        weights = nn.functional.softmax(
            scores, dim=-1
        )  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)
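
    # Putting the pieces together, each head computes
    #   softmax(Q K^T / sqrt(dim_per_head)) V.
    # With hypothetical sizes dim = 768 and n_heads = 12, each head works on
    # dim_per_head = 64 channels and `scores` has shape (bs, 12, q_length, k_length).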


class DistilBertFlashAttention2(MultiHeadSelfAttention):
    """
    DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` since the weights of the
    module stay untouched. The only change is in the forward pass, which needs to call the public flash attention
    API correctly and handle any padding tokens in the input.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            context: torch.tensor(bs, seq_length, dim) Contextualized layer.
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights.
                Only returned if `output_attentions=True`.
        """
        batch_size, q_length, dim = query.size()

        dim_per_head = self.dim // self.n_heads

        def reshape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(batch_size, -1, self.n_heads, dim_per_head)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x n_heads x head_dim
        query_states = reshape(self.q_lin(query))
        key_states = reshape(self.k_lin(key))
        value_states = reshape(self.v_lin(value))

        attn_dropout = self.config.attention_dropout if self.training else 0.0

        # In PEFT, the layer norms are usually cast to float32 for training stability,
        # so the input hidden states get silently cast to float32 as well. We therefore
        # cast them back to the correct dtype to make sure everything works as expected.
        # This can slow down training and inference, so it is recommended not to cast the
        # LayerNorms to fp32. (LlamaRMSNorm handles this correctly.)

        if query_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_lin.weight.dtype

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_weights = self._flash_attention_forward(
            query_states, key_states, value_states, mask, q_length, dropout=attn_dropout
        )

        attn_weights_reshaped = attn_weights.reshape(
            batch_size, q_length, self.n_heads * dim_per_head
        )
        attn_output = self.out_lin(attn_weights_reshaped)

        if output_attentions:
            return (attn_output, attn_weights)
        else:
            return (attn_output,)

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False
    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
        the input is first unpadded, attention is computed on the packed sequences, and the output is padded back to
        the original shape.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
        """
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import pad_input

        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output
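
    # Note on tensor layouts (flash-attn 2.x): flash_attn_func takes dense
    # (batch, seqlen, n_heads, head_dim) tensors, while flash_attn_varlen_func
    # takes unpadded (total_tokens, n_heads, head_dim) tensors plus the
    # cu_seqlens / max_seqlen metadata produced by _upad_input below.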

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads
    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        from flash_attn.bert_padding import index_first_axis, unpad_input

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


class FFN(nn.Module):
    def __init__(self, config: DistilBertConfig):
        super().__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        self.activation = get_activation(config.activation)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return apply_chunking_to_forward(
            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input
        )

    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x
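
    # apply_chunking_to_forward splits `input` along the sequence dimension
    # (self.seq_len_dim) into chunks of config.chunk_size_feed_forward, runs
    # ff_chunk on each chunk, and concatenates the results. With a chunk size
    # of 0 it is equivalent to calling ff_chunk(input) directly; a non-zero
    # chunk size trades some speed for lower peak activation memory.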


DISTILBERT_ATTENTION_CLASSES = {
    "eager": MultiHeadSelfAttention,
    "flash_attention_2": DistilBertFlashAttention2,
}


class TransformerBlock(nn.Module):
    def __init__(self, config: DistilBertConfig):
        super().__init__()

        # The number of attention heads must evenly divide the hidden dimension
        if config.dim % config.n_heads != 0:
            raise ValueError(
                f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly"
            )

        self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](
            config
        )
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim)
            attn_mask: torch.tensor(bs, seq_length)

        Returns:
            ffn_output: torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights.
                Only returned if `output_attentions=True`.
        """
        # Self-Attention
        sa_output = self.attention(
            query=x,
            key=x,
            value=x,
            mask=attn_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            sa_output, sa_weights = (
                sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
            )
        else:  # the attention module returns a 1-tuple when `output_attentions=False`
            sa_output = sa_output[0]

        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
        ffn_output: torch.Tensor = self.output_layer_norm(
            ffn_output + sa_output
        )  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (sa_weights,) + output
        return output
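
    # The block follows the original post-LayerNorm transformer layout:
    # x -> self-attention -> residual add + LayerNorm -> FFN -> residual add + LayerNorm.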


class Transformer(nn.Module):
    def __init__(self, config: DistilBertConfig):
        super().__init__()
        self.n_layers = config.n_layers
        self.layer = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

        Returns:
            hidden_state: torch.tensor(bs, seq_length, dim)
                Sequence of hidden states in the last (top) layer.
            all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers + 1 with the embedding output and the hidden states from each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_state,
                    attn_mask,
                    head_mask[i],
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_state,
                    attn_mask,
                    head_mask[i],
                    output_attentions,
                )

            hidden_state = layer_outputs[-1]

            if output_attentions:
                if len(layer_outputs) != 2:
                    raise ValueError(
                        f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}"
                    )

                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                if len(layer_outputs) != 1:
                    raise ValueError(
                        f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}"
                    )

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_state, all_hidden_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_state,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )


class DistilBertPreTrainedModel(SuryaPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DistilBertConfig
    load_tf_weights = None
    base_model_prefix = "distilbert"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                self.config.max_position_embeddings,
                self.config.dim,
                module.position_embeddings.weight,
            )


class DistilBertModel(DistilBertPreTrainedModel):
    def __init__(self, config: DistilBertConfig):
        super().__init__(config)

        self.embeddings = Embeddings(config)  # Embeddings
        self.transformer = Transformer(config)  # Encoder
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings
        """
        return self.embeddings.position_embeddings

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The new number of position embeddings. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        num_position_embeds_diff = (
            new_num_position_embeddings - self.config.max_position_embeddings
        )

        # no resizing needs to be done if the length stays the same
        if num_position_embeds_diff == 0:
            return

        self.config.max_position_embeddings = new_num_position_embeddings

        old_position_embeddings_weight = (
            self.embeddings.position_embeddings.weight.clone()
        )

        self.embeddings.position_embeddings = nn.Embedding(
            self.config.max_position_embeddings, self.config.dim
        )

        if self.config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=self.config.max_position_embeddings,
                dim=self.config.dim,
                out=self.embeddings.position_embeddings.weight,
            )
        else:
            with torch.no_grad():
                if num_position_embeds_diff > 0:
                    self.embeddings.position_embeddings.weight[
                        :-num_position_embeds_diff
                    ] = nn.Parameter(old_position_embeddings_weight)
                else:
                    self.embeddings.position_embeddings.weight = nn.Parameter(
                        old_position_embeddings_weight[:num_position_embeds_diff]
                    )
        # move position_embeddings to correct device
        self.embeddings.position_embeddings.to(self.device)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.transformer.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)

        if self._use_flash_attention_2:
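            # Flash attention consumes the (bs, seq_length) padding mask directly;
            # if no position is padded (the mask contains no zeros) the mask is dropped
            # so the dense flash_attn_func path can be used instead of the varlen path.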
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        else:
            if attention_mask is None:
                attention_mask = torch.ones(
                    input_shape, device=device
                )  # (bs, seq_length)

        return self.transformer(
            x=embeddings,
            attn_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class DistilBertForSequenceClassification(S3DownloaderMixin, DistilBertPreTrainedModel):
    def __init__(self, config: DistilBertConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings
        """
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The new number of position embeddings. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

```
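
For orientation, here is a minimal sketch of exercising the sequence classifier defined above in isolation. It assumes this `DistilBertConfig` accepts the standard Hugging Face DistilBERT keyword arguments (`vocab_size`, `dim`, `hidden_dim`, `n_layers`, `n_heads`, plus `num_labels`) and that the model can be built directly from a config; the values are illustrative, not those of any shipped checkpoint, which would normally be loaded via `from_pretrained`.

```python
import torch

from surya.ocr_error.model.config import DistilBertConfig
from surya.ocr_error.model.encoder import DistilBertForSequenceClassification

# Hypothetical small config; field names follow the standard DistilBertConfig.
config = DistilBertConfig(
    vocab_size=30522,
    dim=256,
    hidden_dim=1024,
    n_layers=2,
    n_heads=4,
    num_labels=2,
)
model = DistilBertForSequenceClassification(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 32))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

print(out.logits.shape)  # torch.Size([1, 2]) -- one score per label
```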