This is page 4 of 4. Use http://codebase.md/datalab-to/surya?page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmark │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── CITATION.cff ├── CLA.md ├── detect_layout.py ├── detect_text.py ├── LICENSE ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ ├── fonts │ │ └── .gitignore │ └── images │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── arabic.jpg │ ├── benchmark_chart_small.png │ ├── benchmark_chart.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chi_hind.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── chinese.jpg │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── excerpt.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── funsd.png │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── hindi.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── japanese.jpg │ ├── latex_ocr.png │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── nyt.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── paper.jpg │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── pres.png │ ├── rec_acc_table.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── scanned.png │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ ├── textbook_text.jpg │ └── textbook.jpg ├── surya │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── adetr │ │ │ └── decoder.py │ │ ├── donut │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder │ │ │ │ └── __init__.py │ │ │ ├── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation │ │ ├── __init__.py │ │ ├── cache │ │ │ ├── 
__init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input │ │ ├── load.py │ │ └── processing.py │ ├── layout │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests │ ├── assets │ │ └── test_latex.png │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ``` # Files -------------------------------------------------------------------------------- /surya/common/adetr/decoder.py: -------------------------------------------------------------------------------- ```python from typing import Dict, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from transformers import PretrainedConfig from transformers.activations import ACT2FN from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import BaseModelOutputWithNoAttention from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from surya.common.pretrained import SuryaPreTrainedModel from surya.common.xla import mark_step _MAX_SQRT_GRADIENT = 1000.0 class WrappedEmbedding(nn.Embedding): def forward(self, input_ids, *args, **kwargs): return super().forward(input_ids) class SuryaADETRDecoderRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.zeros(dim)) def _norm(self, x): variance = x.pow(2).mean(-1, keepdim=True) # Add clipping to prevent division by zero variance = torch.clamp(variance, min=self.eps) return x * torch.rsqrt(variance) def forward(self, x): output = self._norm(x.float()) # Llama does x.to(float16) * w whilst SuryaADETRDecoder is (x * w).to(float16) # See https://github.com/huggingface/transformers/pull/29402 output = output * (1.0 + self.weight.float()) # Clamp to float16 range f16_info = torch.finfo(x.dtype) output = output.clamp(min=f16_info.min, max=f16_info.max) output = torch.where( torch.isnan(output), torch.tensor(0.0, device=output.device), output ) return output.type_as(x) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" ALL_LAYERNORM_LAYERS.append(SuryaADETRDecoderRMSNorm) class SuryaADETRDecoderRotaryEmbedding(nn.Module): def __init__(self, dim, base=10000, device=None): super().__init__() self.dim = dim self.base = base inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim) ) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() # Copied from 
transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaADETRDecoder
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        self.inv_freq.to(x.device)
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
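    For example (illustrative shapes), with n_rep=4 a key/value tensor of shape
    (2, 8, 128, 64) is expanded to (2, 32, 128, 64), so the grouped key/value heads
    line up one-to-one with the query heads before scaled dot-product attention.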
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class SuryaADETRDecoderSdpaCrossAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper Modified for GQA """ def __init__(self, config: PretrainedConfig): super().__init__() self.config = config self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.head_dim = config.head_dim self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads self.q_proj = nn.Linear( self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias, ) self.k_proj = nn.Linear( self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.v_proj = nn.Linear( self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.o_proj = nn.Linear( self.num_attention_heads * self.head_dim, self.hidden_size, bias=True ) self.rotary_emb = SuryaADETRDecoderRotaryEmbedding( self.head_dim, base=config.rope_theta, ) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # Encoder attention mask currently ignored bsz, q_len, _ = hidden_states.size() _, v_len, _ = encoder_hidden_states.size() query_states = self.q_proj(hidden_states) query_states = query_states.view( bsz, q_len, self.num_attention_heads, self.head_dim ).transpose(1, 2) if self.key_states is None: key_states = self.k_proj(encoder_hidden_states) value_states = self.v_proj(encoder_hidden_states) key_states = key_states.view( bsz, v_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( bsz, v_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) if use_cache: self._update_cache(key_states, value_states) else: key_states = self.key_states value_states = self.value_states key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=None, dropout_p=self.attention_dropout if self.training else 0.0, scale=self.head_dim**-0.5, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output def _clear_cache(self): if self.value_states is not None: del self.value_states if self.key_states is not None: del self.key_states def _setup_cache(self, batch_size, device, dtype=None): # Setup initial caches self.value_states = None self.key_states = None @torch.no_grad() def _update_cache(self, key_states, value_states, **cache_kwargs): self.value_states = value_states self.key_states = key_states class SuryaADETRDecoderSdpaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You 
Need' paper""" def __init__(self, config: PretrainedConfig, static_cache=False, max_boxes=None): super().__init__() self.config = config self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.head_dim = config.head_dim self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads self.q_proj = nn.Linear( self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias, ) self.k_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.v_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.o_proj = nn.Linear( self.num_attention_heads * self.head_dim, self.hidden_size, bias=True ) self.rotary_emb = SuryaADETRDecoderRotaryEmbedding( self.head_dim, base=config.rope_theta, ) self.static_cache = static_cache self.max_boxes = max_boxes def forward( self, hidden_states: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: bool = False, window_attn: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) # Final is bsz, num_attention_heads, seq_len, head_dim query_states = query_states.view( bsz, q_len, self.num_attention_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) if use_cache and hasattr(self, "key_states"): cache_kwargs = { "cache_position": cache_position, "window_attn": window_attn, } key_states, value_states = self._update_cache( key_states, value_states, **cache_kwargs ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: # Mask is batch, head, seq_len, kv_len causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] if cache_position is not None and self.static_cache: current_pos = cache_position[-1] causal_mask[:, :, :, current_pos + 1 :] = torch.finfo( causal_mask.dtype ).min attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, scale=self.head_dim**-0.5, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output def _setup_cache(self, batch_size, device, dtype=None): if dtype is None and self.config.torch_dtype is not None: dtype = self.config.torch_dtype dtype = dtype if dtype is not None else torch.float32 # Setup initial caches self.value_states = None self.key_states = None if self.static_cache: cache_shape = ( batch_size, self.num_key_value_heads, self.max_boxes, self.head_dim, ) self.value_states = torch.zeros(cache_shape, dtype=dtype, 
device=device) self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device) def _clear_cache(self): if self.value_states is not None: del self.value_states if self.key_states is not None: del self.key_states def _update_static_cache(self, key_states, value_states, **cache_kwargs): cache_position = cache_kwargs.get("cache_position") k_out, v_out = ( self.key_states.to(key_states.device), self.value_states.to(value_states.device), ) k_out[:, :, cache_position] = key_states.to(k_out.dtype) v_out[:, :, cache_position] = value_states.to(v_out.dtype) self.key_states, self.value_states = k_out, v_out return k_out, v_out def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs): k_out = key_states if self.key_states is not None: k_out = torch.cat([self.key_states, key_states], dim=2) v_out = value_states if self.value_states is not None: v_out = torch.cat([self.value_states, value_states], dim=2) self.key_states, self.value_states = k_out, v_out return k_out, v_out @torch.no_grad() def _update_cache(self, key_states, value_states, **cache_kwargs): if self.static_cache: return self._update_static_cache(key_states, value_states, **cache_kwargs) return self._update_dynamic_cache(key_states, value_states, **cache_kwargs) class SuryaADETRDecoderMlp(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) if config.hidden_activation is None: config.hidden_activation = "gelu_pytorch_tanh" hidden_activation = config.hidden_activation self.act_fn = ACT2FN[hidden_activation] def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) class SuryaADETRDecoderLayer(nn.Module): def __init__(self, config, layer_idx, static_cache=False, max_boxes=None): super().__init__() self.cross_pre_norm = SuryaADETRDecoderRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.temporal_pre_norm = SuryaADETRDecoderRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.temporal_block = None if layer_idx in config.self_attn_layers: self.temporal_block = SuryaADETRDecoderSdpaAttention( config, static_cache=static_cache, max_boxes=max_boxes ) self.cross_attn_block = None if layer_idx in config.cross_attn_layers: self.cross_attn_block = SuryaADETRDecoderSdpaCrossAttention(config) self.window_attn = layer_idx not in config.global_attn_layers self.channel_pre_norm = SuryaADETRDecoderRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.mlp_block = SuryaADETRDecoderMlp(config) self.double_residual_flow = getattr(config, "double_residual_flow", False) def forward( self, activations: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, encoder_hidden_states: torch.Tensor = None, encoder_attention_mask: torch.Tensor = None, cache_position: torch.Tensor = None, use_cache: bool = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: if self.double_residual_flow: return self.double_res_forward( activations, position_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache, ) hidden_states = activations if self.cross_attn_block is not None: # Do cross-attention on encoder outputs cross_attn_inputs = self.cross_pre_norm(hidden_states) cross_attn_path = 
self.cross_attn_block( cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache, ) hidden_states = cross_attn_path + hidden_states if self.temporal_block is not None: temporal_inputs = self.temporal_pre_norm( hidden_states ) # RMSNorm introduces slight slight differences temporal_path = self.temporal_block( temporal_inputs, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn, ) hidden_states = temporal_path + hidden_states block_input = hidden_states hidden_states = self.channel_pre_norm(block_input) hidden_states = self.mlp_block(hidden_states) hidden_states = hidden_states + block_input return hidden_states def double_res_forward( self, activations: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, encoder_hidden_states: torch.Tensor = None, encoder_attention_mask: torch.Tensor = None, cache_position: torch.Tensor = None, use_cache: bool = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: raw_activations = activations if self.cross_attn_block is not None: # Do cross-attention on encoder outputs cross_attn_inputs = self.cross_pre_norm(activations) cross_attn_path = self.cross_attn_block( cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache, ) cross_attn_output = cross_attn_path + raw_activations else: cross_attn_output = raw_activations if self.temporal_block is not None: inputs_normalized = self.temporal_pre_norm( cross_attn_output ) # RMSNorm introduces slight slight differences hidden_states = self.temporal_block( inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn, ) residual = hidden_states + raw_activations else: residual = cross_attn_output hidden_states = self.channel_pre_norm(residual) hidden_states = self.mlp_block(hidden_states) hidden_states = hidden_states + residual return hidden_states class SuryaADETRDecoderPreTrainedModel(SuryaPreTrainedModel): config_class = PretrainedConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["SuryaADETRDecoderLayer"] _skip_keys_device_placement = ["cache"] _supports_flash_attn_2 = False _supports_sdpa = False # we can't compare with eager for now _supports_cache_class = True _supports_quantized_cache = True def _init_weights(self, module): if isinstance(module, SuryaADETRDecoderSdpaAttention): torch.nn.init.normal_( module.q_proj.weight, mean=0.0, std=self.config.init_std ) torch.nn.init.normal_( module.k_proj.weight, mean=0.0, std=self.config.init_std ) torch.nn.init.normal_( module.v_proj.weight, mean=0.0, std=self.config.init_std ) torch.nn.init.normal_( module.o_proj.weight, mean=0.0, std=self.config.init_std ) elif isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) if getattr(module, "bias", None) is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.init_std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def _setup_cache(self, config, batch, device, dtype): layers = getattr(self, "model", self).layers for layer in layers: if layer.temporal_block: layer.temporal_block._setup_cache(batch, device, dtype) if layer.cross_attn_block: layer.cross_attn_block._setup_cache(batch, device, dtype) def _clear_cache(self): layers = getattr(self, "model", self).layers for layer in 
layers: if layer.temporal_block: layer.temporal_block._clear_cache() if layer.cross_attn_block: layer.cross_attn_block._clear_cache() def reset_cache(self, batch, device, dtype): pass def _tie_weights(self): pass def tie_weights(self): pass class SuryaADETRDecoderModel(SuryaADETRDecoderPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaADETRDecoderDecoderLayer`] Args: config: PretrainedConfig """ def __init__( self, config: PretrainedConfig, embedder: nn.Module = None, max_boxes: int = None, static_cache: bool = False, ): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.causal = config.causal self.embed_tokens = embedder self.max_boxes = max_boxes self.static_cache = static_cache self.layers = nn.ModuleList( [ SuryaADETRDecoderLayer( config, layer_idx, static_cache=static_cache, max_boxes=max_boxes ) for layer_idx in range(config.num_hidden_layers) ] ) self.final_norm = SuryaADETRDecoderRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.gradient_checkpointing = False self.register_buffer( "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False, ) # Initialize weights and apply final processing self.post_init() # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings def get_input_embeddings(self): return self.embed_tokens # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings def set_input_embeddings(self, value): self.embed_tokens = value def forward( self, input_ids: torch.LongTensor = None, input_boxes_counts: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, prefill: bool = False, ) -> Union[Tuple, BaseModelOutputWithNoAttention]: use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if self.gradient_checkpointing and self.training and use_cache: use_cache = False inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts) hidden_states = inputs_embeds if use_cache and prefill: self._setup_cache( self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype, ) if cache_position is None: cache_position = torch.arange( hidden_states.shape[1], device=hidden_states.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position ) all_hidden_states = () if output_hidden_states else None for i, residual_block in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache, ) else: hidden_states = residual_block( hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache, ) hidden_states = self.final_norm(hidden_states) # 
add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) return BaseModelOutputWithNoAttention( last_hidden_state=hidden_states, hidden_states=all_hidden_states, ) # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 # Ignore copy def _update_causal_mask(self, attention_mask, input_tensor, cache_position): if not self.causal: return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] target_length = max(self.max_boxes, sequence_length) diagonal = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device, ) causal_mask = diagonal if sequence_length != 1: # Select the upper triangular part of the matrix, but unmask current token (the diagonal) # triu will be the min_dtype, everything else is 0 (attended to) causal_mask = torch.triu(diagonal, diagonal=1) causal_mask *= torch.arange( target_length, device=device ) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand( input_tensor.shape[0], 1, -1, -1 ) if attention_mask is not None: causal_mask = ( causal_mask.clone() ) # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: # Mask positions in the causal mask that are masked in the attention mask mask_length = attention_mask.shape[-1] padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[ :, None, None, : ].eq(0.0) causal_mask[..., :mask_length] = causal_mask[ ..., :mask_length ].masked_fill(padding_mask, min_dtype) if attention_mask is not None and attention_mask.device.type == "cuda": # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
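            # Without this step, a query row that is entirely masked (all min_dtype) would make the
            # memory-efficient SDPA kernel produce NaNs; unmasking such rows keeps the kernel stable,
            # and those rows correspond to padded positions anyway.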
# Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended( causal_mask, min_dtype ) return causal_mask ``` -------------------------------------------------------------------------------- /surya/foundation/__init__.py: -------------------------------------------------------------------------------- ```python from __future__ import annotations from dataclasses import dataclass from typing import List, Optional, Tuple from collections import deque import cv2 import numpy as np import torch import math from PIL import Image from tqdm import tqdm import torch.nn.functional as F from surya.common.surya import SuryaModelOutput from surya.common.xla import mark_step from surya.common.predictor import BasePredictor from surya.foundation.loader import FoundationModelLoader from surya.foundation.util import ( detect_repeat_token, ) from surya.common.surya.schema import TaskNames from surya.foundation.cache.dynamic_ops import DynamicOpsCache from surya.foundation.cache.static_ops import StaticOpsCache from surya.settings import settings from surya.logging import get_logger, configure_logging configure_logging() logger = get_logger() @dataclass class ContinuousBatchInput: input_ids: torch.Tensor input_boxes: torch.Tensor position_ids: torch.Tensor # input_ids and position_ids may be padded, num_valid_tokens tracks the 'real' counts num_valid_tokens: torch.Tensor # count the number of predicted tokens for each batch element so far num_predicted_tokens: torch.Tensor needs_bbox_embedding: torch.Tensor @dataclass class ContinuousBatchOutput: input_ids: torch.Tensor preds: torch.Tensor bbox_preds: torch.Tensor scores: torch.Tensor token_probs: torch.Tensor @dataclass class FoundationPrompt: id: int task_name: TaskNames image: np.ndarray text: str math_mode: bool class FoundationPredictor(BasePredictor): model_loader_cls = FoundationModelLoader batch_size = ( settings.RECOGNITION_BATCH_SIZE ) # Default to the recognition batch size torch_dtype = None # No default, loader picks the dtype based on device properties - bf16/fp16 default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 64} encoder_chunk_size: int = 4096 # Default chunk size encoder_chunk_sizes = {"cpu": 4096, "mps": 4096, "cuda": 32768, "xla": 32768} extra_token_count = { "xla": 128 } # We have to pad the XLA cache since we don't use sliding window min_prefill_ratio: int = 1 if settings.FOUNDATION_XLA else 0.2 min_trim_length: int = 50 tasks = { TaskNames.ocr_with_boxes: { "needs_bboxes": True, "img_size": (1024, 512), "max_tokens": 224, }, TaskNames.ocr_without_boxes: { "needs_bboxes": False, "img_size": (1024, 512), "max_tokens": 224, }, TaskNames.block_without_boxes: { "needs_bboxes": False, "img_size": (1024, 512), "max_tokens": 768, }, TaskNames.layout: { "needs_bboxes": False, "img_size": (1024, 1024), "max_tokens": 200, }, TaskNames.table_structure: { "needs_bboxes": False, "img_size": (1024, 512), "max_tokens": 600, }, } def __init__( self, checkpoint=None, device=settings.TORCH_DEVICE_MODEL, dtype=None, attention_implementation: Optional[str] = None, ): super().__init__(checkpoint, device, dtype, attention_implementation) self.prompt_queue = deque() self.batch_prompt_mapping = None self.kv_cache = None self.beacon_token_interval = self.model.config.beacon_token_interval # Setup various tokens on-device self.device_pad_token = torch.tensor( self.processor.pad_token_id, device=self.model.device, dtype=torch.long ) self.device_beacon_token = torch.tensor( 
self.processor.beacon_token_id, device=self.model.device, dtype=torch.long ) self.special_token_ids = torch.tensor( [self.model.config.image_token_id] + self.model.config.register_token_ids, device=self.model.device, ) self.pad_to_multiple = ( settings.FOUNDATION_PAD_TO_NEAREST if settings.FOUNDATION_STATIC_CACHE else None ) def to(self, device_dtype: torch.device | str | None = None): super().to(device_dtype) self.special_token_ids = self.special_token_ids.to(device_dtype) def get_encoder_chunk_size(self) -> int: if settings.FOUNDATION_CHUNK_SIZE is not None: return settings.FOUNDATION_CHUNK_SIZE chunk_size = self.encoder_chunk_size if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes: if settings.TORCH_DEVICE_MODEL in self.encoder_chunk_sizes: chunk_size = self.encoder_chunk_sizes[settings.TORCH_DEVICE_MODEL] return chunk_size def setup_cache(self, batch_size: int, max_cache_len: int, max_sliding_window: int): kv_cache_cls = StaticOpsCache if settings.FOUNDATION_XLA else DynamicOpsCache self.kv_cache = kv_cache_cls( self.model.config, batch_size, max_cache_len, text_sliding_window=max_sliding_window, device=self.model.device, dtype=self.model.dtype, ) self.prompt_queue.clear() self.batch_prompt_mapping = {i: None for i in range(batch_size)} @property def num_empty_slots(self): return sum(v is None for v in self.batch_prompt_mapping.values()) @property def num_active_slots(self): return len(self.batch_prompt_mapping) - self.num_empty_slots def prepare_input( self, task_names: List[str], images: List[Image.Image], input_text: List[str | None], math_modes: List[bool], ): batch = [] for image, text, task_name, math_mode in zip( images, input_text, task_names, math_modes ): image_size = self.tasks[task_name]["img_size"] try: image = self.processor.scale_to_fit( image, image_size ) # Only resizes if out of bounds (max/min) except cv2.error: # The image is empty if it can't be resized, so just make a blank image image = np.zeros((image_size[1], image_size[0], 3), dtype=np.float32) # Task input is the same for all tasks for now text = text or "" # Remove input text that exceeds max generation tokens (likely invalid) if len(text) > self.tasks[task_name]["max_tokens"]: text = "" inputs = [ {"type": "image", "image": image, "rotated": False}, {"type": "text", "text": text.strip(), "math": math_mode}, ] batch.append({"task": task_name, "inputs": inputs}) return batch def process_outputs( self, outputs: SuryaModelOutput, max_lookahead_tokens: Optional[int] = None ) -> ContinuousBatchOutput: # Predictions are multi-token lm_logits = outputs["lm_logits"].float() # shape: [batch_size, seq_len, V] bbox_logits = outputs["bbox_logits"].float() # shape: [batch_size, seq_len, 6] if ( max_lookahead_tokens is not None and lm_logits.shape[1] > max_lookahead_tokens + 1 ): lm_logits = lm_logits[:, : max_lookahead_tokens + 1, :] bbox_logits = bbox_logits[:, : max_lookahead_tokens + 1, :] # Get predictions preds = torch.argmax(lm_logits, dim=-1) input_ids = preds.to(torch.long) # Confidence scores for all tokens token_probs = F.softmax(lm_logits, dim=-1) scores = torch.max(token_probs, dim=-1).values # shape: [B, T] # Update input boxes box_preds = bbox_logits * self.model.config.bbox_size box_preds = box_preds.to(torch.long) return ContinuousBatchOutput( input_ids=input_ids, preds=preds, bbox_preds=box_preds, scores=scores, token_probs=token_probs, ) # Always left pad with beacons, don't worry about attention masking def maybe_insert_beacon_tokens( self, input_ids: torch.Tensor, input_boxes: torch.Tensor, 
num_predicted_tokens: torch.Tensor, num_new_tokens: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, seq_len = ( input_ids.shape ) # seq_len can be >1 - In case of multi-token predictions # num_predicted tokens **does not include** the current new input_ids, this number is updated **after beacon tokens are inserted** token_positions = num_predicted_tokens + torch.arange( 1, seq_len + 1, device=input_ids.device ).unsqueeze(0) beacon_positions = token_positions % self.beacon_token_interval == 0 # If no beacons needed, return original input needs_beacon = beacon_positions.any(dim=1) # shape: [batch_size] if not needs_beacon.any(): if num_new_tokens is None: num_new_tokens = ( torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * seq_len ) return input_ids, input_boxes, num_new_tokens.squeeze(1) beacon_insert_pos = torch.zeros( batch_size, dtype=torch.long, device=input_ids.device ) for i in range(batch_size): if needs_beacon[i]: # Find first position that needs beacon beacon_insert_pos[i] = torch.where(beacon_positions[i])[0] # Padded input ids. new_input_ids = torch.full( (batch_size, seq_len + 1), self.device_pad_token, dtype=input_ids.dtype, device=input_ids.device, ) new_input_boxes = torch.full( (batch_size, seq_len + 1, 6), -100, dtype=input_boxes.dtype, device=input_boxes.device, ) # Fill in tokens for each sequence for i in range(batch_size): if needs_beacon[i]: insert_pos = beacon_insert_pos[i] new_input_ids[i, insert_pos] = self.device_beacon_token new_input_boxes[i, insert_pos, :] = -100 if insert_pos > 0: new_input_ids[i, :insert_pos] = input_ids[i, :insert_pos] new_input_boxes[i, :insert_pos] = input_boxes[i, :insert_pos] new_input_ids[i, insert_pos + 1 :] = input_ids[i, insert_pos:] new_input_boxes[i, insert_pos + 1 :] = input_boxes[i, insert_pos:] else: new_input_ids[i, 1:] = input_ids[i, :] new_input_boxes[i, 1:] = input_boxes[i, :] # Calculate valid token counts for both padded and non padded sequences valid_token_counts = torch.where( needs_beacon, torch.tensor(seq_len + 1, device=input_ids.device), torch.tensor(seq_len, device=input_ids.device), ) return new_input_ids, new_input_boxes, valid_token_counts def decode( self, current_inputs: Optional[ContinuousBatchInput] = None, max_lookahead_tokens: Optional[int] = None, ): # Note - If we want to use the outputs from the non-last token, we # need to set the cache position manually to ensure causality. 
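        # In outline, one decode step below: shift the cached attention mask by the number of
        # tokens accepted last step, run the model over the whole batch against the shared KV
        # cache, keep only the lookahead tokens whose running confidence stays at or above
        # FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE, insert beacon tokens when num_predicted_tokens
        # crosses the beacon interval, and left-pad while adjusting position_ids so the real
        # tokens keep their correct positions.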
The default # behavior only works for the last token currently input_ids = current_inputs.input_ids input_boxes = current_inputs.input_boxes embed_boxes = current_inputs.needs_bbox_embedding position_ids = current_inputs.position_ids num_predicted_tokens = current_inputs.num_predicted_tokens num_valid_tokens = current_inputs.num_valid_tokens batch_size = input_ids.shape[0] # Pre-shift the attention mask based on the cache update self.kv_cache.decode_attention_mask_update( num_valid_tokens=num_valid_tokens, cache_idxs=list(range(batch_size)) ) cache_position = self.get_cache_position( input_ids.shape[1], self.kv_cache.attention_mask, prefill=False ) with settings.INFERENCE_MODE(): outputs = self.model( input_ids=input_ids, attention_mask=self.kv_cache.attention_mask, position_ids=position_ids, cache_position=cache_position, use_cache=True, past_key_values=self.kv_cache, prefill=False, num_valid_tokens=num_valid_tokens, input_boxes=input_boxes, embed_boxes=embed_boxes, logits_to_keep=1, ) processed_output: ContinuousBatchOutput = self.process_outputs( outputs, max_lookahead_tokens=max_lookahead_tokens ) input_ids = processed_output.input_ids input_boxes = processed_output.bbox_preds # Update this **before** inserting beacon tokens tau = settings.FOUNDATION_MULTI_TOKEN_MIN_CONFIDENCE if max_lookahead_tokens is not None: num_new_tokens = torch.clamp( ( processed_output.scores.ge(tau) .to(torch.long) .cumprod(dim=1) .sum(dim=1, keepdim=True) ), min=1, ) else: num_new_tokens = input_ids.shape[1] num_predicted_tokens += num_new_tokens input_ids, input_boxes, num_valid_tokens = self.maybe_insert_beacon_tokens( input_ids, input_boxes, num_predicted_tokens, num_new_tokens ) position_ids = position_ids[:, -1:] + torch.arange( 1, input_ids.shape[1] + 1, device=input_ids.device ) # Some of the input sequences may now have left padding tokens, so we want to account for that # offset is a per-batch offset of the position_ids offset = (input_ids.shape[1] - num_valid_tokens).unsqueeze(1) position_ids -= offset new_input = ContinuousBatchInput( input_ids=input_ids, input_boxes=input_boxes, position_ids=position_ids, num_valid_tokens=num_valid_tokens, num_predicted_tokens=num_predicted_tokens, needs_bbox_embedding=current_inputs.needs_bbox_embedding, ) return new_input, processed_output def pad_and_shift_input_ids_position_ids( self, input_ids: torch.Tensor, bbox_preds: torch.Tensor, position_ids: torch.Tensor, new_seq_len: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Pads new_input_ids to match the new seq len with **left padding** and creates updated position_ids Returns: padded_input_ids (torch.Tensor): [batch_size, current_seq_len] updated_position_ids (torch.Tensor): [batch_size, current_seq_len] """ # No padding if new_seq_len == input_ids.shape[1]: return ( input_ids, bbox_preds, position_ids[:, -1:] + torch.arange(1, new_seq_len + 1, device=self.model.device), ) pad_len = new_seq_len - input_ids.shape[1] padded_input_ids = torch.nn.functional.pad( input_ids, (pad_len, 0), value=self.device_pad_token ) padded_bbox_preds = torch.nn.functional.pad( bbox_preds, (0, 0, pad_len, 0), value=-100 ) # Since we have **left padding**, offset the new position_ids by the amount of padding # This ensures that the **true tokens** get the correct position_ids # The position_ids assigned to pad tokens do not matter. 
They are not cached, and not used for outputs updated_position_ids = position_ids[:, -1:] + torch.arange( 1, new_seq_len + 1, device=self.model.device ) updated_position_ids -= pad_len return padded_input_ids, padded_bbox_preds, updated_position_ids def get_cache_position( self, seq_len: int, attention_mask: torch.Tensor, prefill: bool, ): batch_size, target_len = attention_mask.shape base_cache_position = ( torch.arange(seq_len, device=attention_mask.device) .unsqueeze(0) .expand(batch_size, -1) ) if prefill: return base_cache_position # This is a (batch_size) tensor, we can add the seq lens here cache_seqlens = ( attention_mask * torch.arange(attention_mask.size(1), device=attention_mask.device) ).argmax(dim=1).to(torch.int32) + 1 # Needs to be unsqueezed so broadcasting works return cache_seqlens.unsqueeze(1) + base_cache_position def prefill( self, current_inputs: Optional[ContinuousBatchInput] = None, max_lookahead_tokens: Optional[int] = None, ): logger.debug(f"Prefilling {self.num_empty_slots} slots") prompts: List[FoundationPrompt] = [ self.prompt_queue.popleft() for _ in range(min(self.num_empty_slots, len(self.prompt_queue))) ] non_active_idxs = [k for k, v in self.batch_prompt_mapping.items() if v is None] idxs_to_merge = non_active_idxs[: len(prompts)] for i, prompt in zip(idxs_to_merge, prompts): self.batch_prompt_mapping[i] = prompt.id needs_bbox_embedding = torch.tensor( [ p.task_name in [TaskNames.layout, TaskNames.table_structure] for p in prompts ], dtype=torch.bool, ) batch_input = self.prepare_input( task_names=[p.task_name for p in prompts], images=[p.image for p in prompts], input_text=[p.text for p in prompts], math_modes=[ p.math_mode for p in prompts ], # Pass math mode to the processor ) processed_inputs = self.processor( batch_input, padding_side="left", device=self.model.device, pad_to_multiple=self.pad_to_multiple, ) input_ids = processed_inputs["input_ids"].to(dtype=torch.long) attention_mask = processed_inputs["attention_mask"].to(dtype=torch.long) position_ids = processed_inputs["position_ids"].to(dtype=torch.long) valid_batch_size = len(idxs_to_merge) # Keep these off device until later image_tiles = processed_inputs["image_tiles"].to(dtype=self.model.dtype) grid_thw = processed_inputs["grid_thw"].to(dtype=torch.long) if settings.FOUNDATION_STATIC_CACHE: input_ids = self.pad_to_batch_size( input_ids, batch_size=self.kv_cache.max_batch_size ) attention_mask = self.pad_to_batch_size( attention_mask, batch_size=self.kv_cache.max_batch_size ) position_ids = self.pad_to_batch_size( position_ids, batch_size=self.kv_cache.max_batch_size ) needs_bbox_embedding = self.pad_to_batch_size( needs_bbox_embedding, batch_size=self.kv_cache.max_batch_size ) # Move to device after padding input_ids = input_ids.to(device=self.model.device) attention_mask = attention_mask.to(device=self.model.device) position_ids = position_ids.to(device=self.model.device) needs_bbox_embedding = needs_bbox_embedding.to(device=self.model.device) # Find text lengths of each # Oddly, this is optimal on GPU - causes a 30% slowdown if "optimized" # Be very careful with the type and device of this - can cause # a big slowdown if put on device is_special = ( (input_ids.unsqueeze(-1) == self.special_token_ids).any(-1).cpu() ) # (batch, seq_len) text_lengths = [] for i in range(input_ids.shape[0]): special_positions = is_special[i].nonzero(as_tuple=True)[0] if len(special_positions) > 0: # Assuming special tokens are contiguous at the start prefix_len = special_positions[-1].item() + 1 else: prefix_len 
= 0 text_lengths.append(input_ids.shape[1] - prefix_len) text_lengths = torch.tensor(text_lengths, dtype=torch.long) cache_position = self.get_cache_position( input_ids.shape[1], attention_mask, prefill=True ) with settings.INFERENCE_MODE(): image_embeddings = self.model.get_image_embeddings( pixel_values=image_tiles, grid_thw=grid_thw, encoder_chunk_size=self.get_encoder_chunk_size(), valid_batch_size=valid_batch_size, max_batch_size=self.kv_cache.max_batch_size, ) mark_step() outputs = self.model( input_ids=input_ids, image_embeddings=image_embeddings, attention_mask=attention_mask, position_ids=position_ids, cache_position=cache_position, inputs_embeds=None, past_key_values=self.kv_cache, use_cache=True, encoder_chunk_size=self.get_encoder_chunk_size(), cache_idxs=idxs_to_merge, prefill=True, num_valid_tokens=None, # Not required during prefill text_lengths=text_lengths, valid_batch_size=valid_batch_size, logits_to_keep=1, ) # Process outputs processed_outputs = self.process_outputs( outputs, max_lookahead_tokens=max_lookahead_tokens ) # Multi-token prediction predicted_tokens = processed_outputs.input_ids.shape[1] num_valid_tokens = ( torch.ones((input_ids.shape[0]), device=self.model.device, dtype=torch.long) * predicted_tokens ) num_predicted_tokens = ( torch.ones( (input_ids.shape[0], 1), device=self.model.device, dtype=torch.long ) * predicted_tokens ) self.kv_cache.prefill_attention_mask_update( attention_mask, idxs_to_merge, valid_batch_size, text_lengths ) self.kv_cache.update_text_counts(idxs_to_merge, valid_batch_size, text_lengths) full_batch = len(idxs_to_merge) == self.kv_cache.max_batch_size # If full batch, then we can ignore current_inputs if current_inputs is None or full_batch: new_seq_len = processed_outputs.input_ids.shape[1] # No padding tokens - So we can safely set position_ids this way position_ids = position_ids[:, -1:] + torch.arange( 1, new_seq_len + 1, device=position_ids.device ) new_input = ContinuousBatchInput( input_ids=processed_outputs.input_ids, input_boxes=processed_outputs.bbox_preds, position_ids=position_ids, num_valid_tokens=num_valid_tokens, num_predicted_tokens=num_predicted_tokens, needs_bbox_embedding=needs_bbox_embedding, ) return ( new_input, processed_outputs, range(processed_outputs.input_ids.shape[0]), ) # Merging inputs for next steps current_input_ids = current_inputs.input_ids current_position_ids = current_inputs.position_ids current_input_boxes = current_inputs.input_boxes current_needs_bbox_embedding = current_inputs.needs_bbox_embedding assert current_input_ids.shape[1] == current_position_ids.shape[1] input_ids, bbox_preds, position_ids = self.pad_and_shift_input_ids_position_ids( processed_outputs.input_ids, processed_outputs.bbox_preds, position_ids, new_seq_len=current_input_ids.shape[1], ) current_input_ids[idxs_to_merge] = input_ids[:valid_batch_size] current_input_boxes[idxs_to_merge] = bbox_preds[:valid_batch_size] current_position_ids[idxs_to_merge] = position_ids[:valid_batch_size] current_num_valid_tokens = current_inputs.num_valid_tokens current_num_valid_tokens[idxs_to_merge] = num_valid_tokens[:valid_batch_size] current_num_predicted_tokens = current_inputs.num_predicted_tokens current_num_predicted_tokens[idxs_to_merge] = num_predicted_tokens[ :valid_batch_size ] current_needs_bbox_embedding[idxs_to_merge] = needs_bbox_embedding[ :valid_batch_size ] new_input = ContinuousBatchInput( input_ids=current_input_ids, input_boxes=current_input_boxes, position_ids=current_position_ids, 
num_valid_tokens=current_num_valid_tokens, num_predicted_tokens=current_num_predicted_tokens, needs_bbox_embedding=current_needs_bbox_embedding, ) return new_input, processed_outputs, idxs_to_merge def get_max_image_token_count( self, images: list[np.ndarray], tasks: List[TaskNames] ) -> int: def compute_scaled_size( H: int, W: int, max_size: Tuple[int, int] ) -> Tuple[int, int]: max_W, max_H = max_size min_W, min_H = (168, 168) current_pixels = H * W max_pixels = max_H * max_W min_pixels = min_H * min_W current_pixels = max(1, current_pixels) # Avoid zero division if current_pixels > max_pixels: scale = (max_pixels / current_pixels) ** 0.5 return math.floor(H * scale), math.floor(W * scale) elif current_pixels < min_pixels: scale = (min_pixels / current_pixels) ** 0.5 return math.ceil(H * scale), math.ceil(W * scale) return H, W def get_tile_count(H: int, W: int, factor: int) -> int: H_bar = math.ceil(H / factor) * factor W_bar = math.ceil(W / factor) * factor grid_h = H_bar / self.processor.patch_size grid_w = W_bar // self.processor.patch_size return grid_h * grid_w max_tokens = 0 factor = self.processor.patch_size * self.processor.merge_size for image, task in zip(images, tasks): H, W = image.shape[:2] max_size = self.tasks[task]["img_size"] scaled_H, scaled_W = compute_scaled_size(H, W, max_size) token_count = get_tile_count(scaled_H, scaled_W, factor) / ( self.processor.merge_size**2 ) max_tokens = max(max_tokens, token_count) # Extra 10 to account for EOS/BOS/Rotation token etc. return 10 + self.processor.num_register_tokens + int(max_tokens) def prediction_loop( self, images: List[np.ndarray], input_texts: List[str], task_names: List[TaskNames], batch_size: int | None = None, max_tokens: int | None = None, max_sliding_window: int | None = None, math_mode: bool = True, drop_repeated_tokens: bool = True, max_lookahead_tokens: Optional[int] = None, top_k: int = 0, tqdm_desc: str = "Recognizing Text" ) -> tuple: allowed_tasks = self.tasks.keys() assert all([task_name in allowed_tasks for task_name in task_names]), ( f"One or more tasks in {task_names} is not supported. 
Supported tasks are {allowed_tasks}" ) predicted_tokens = [[] for _ in range(len(images))] scores = [[] for _ in range(len(images))] topk_probs = [[] for _ in range(len(images))] if batch_size is None: batch_size = self.get_batch_size() batch_size = min(len(images), batch_size) current_inputs = None max_image_tokens = self.get_max_image_token_count(images, task_names) if max_sliding_window is None: max_sliding_window = self.model.config.sliding_window self.setup_cache( batch_size, max_cache_len=max_image_tokens + max_sliding_window + self.extra_token_count.get(settings.TORCH_DEVICE_MODEL, 0), max_sliding_window=max_sliding_window, ) batch_max_tokens = {} for idx, (img, txt, task) in enumerate(zip(images, input_texts, task_names)): self.prompt_queue.append( FoundationPrompt( id=idx, task_name=task, text=txt, image=img, math_mode=math_mode ) ) batch_max_tokens[idx] = ( max_tokens or settings.FOUNDATION_MAX_TOKENS or self.tasks[task]["max_tokens"] ) overall_max_tokens = max(batch_max_tokens.values()) pbar = tqdm( total=len(self.prompt_queue), desc=tqdm_desc, disable=self.disable_tqdm, ) batch_bboxes = torch.zeros(len(images), overall_max_tokens, 6) batch_pos = [0] * len(images) while self.prompt_queue or self.num_active_slots > 0: if ( self.num_empty_slots / batch_size ) >= self.min_prefill_ratio and self.prompt_queue: updated_inputs, outputs, merge_idxs = self.prefill( current_inputs, max_lookahead_tokens=0 ) predicted_tokens_cpu = outputs.preds.cpu() scores_cpu = outputs.scores.cpu() bbox_preds_cpu = outputs.bbox_preds.cpu() if top_k > 0: batch_top_k_probs, batch_top_k_indices = torch.topk( outputs.token_probs, k=top_k, dim=-1 ) batch_top_k_probs_cpu = batch_top_k_probs.cpu() batch_top_k_indices_cpu = batch_top_k_indices.cpu() for temp_idx, b_idx in enumerate(merge_idxs): if self.batch_prompt_mapping[b_idx] is not None: p_idx = self.batch_prompt_mapping[b_idx] seq_len = predicted_tokens_cpu.shape[1] for t_idx in range(seq_len): token = predicted_tokens_cpu[temp_idx, t_idx].item() predicted_tokens[p_idx].append(token) batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[ temp_idx, t_idx ] batch_pos[p_idx] += 1 scores[p_idx].append(scores_cpu[temp_idx, t_idx].item()) if top_k > 0: top_k_scores = { batch_top_k_indices_cpu[temp_idx, t_idx][ k ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][ k ].item() for k in range(top_k) } topk_probs[p_idx].append(top_k_scores) if token in [ self.processor.eos_token_id, self.processor.no_output_token, ]: self.batch_prompt_mapping[b_idx] = None pbar.update(1) break else: updated_inputs, outputs = self.decode( current_inputs, max_lookahead_tokens=max_lookahead_tokens ) mark_step() predicted_tokens_cpu = outputs.preds.cpu() scores_cpu = outputs.scores.cpu() bbox_preds_cpu = outputs.bbox_preds.cpu() if top_k > 0: batch_top_k_probs, batch_top_k_indices = torch.topk( outputs.token_probs, k=top_k, dim=-1 ) batch_top_k_probs_cpu = batch_top_k_probs.cpu() batch_top_k_indices_cpu = batch_top_k_indices.cpu() for b_idx, p_idx in self.batch_prompt_mapping.items(): if p_idx is not None: seq_len = predicted_tokens_cpu.shape[1] num_tokens = updated_inputs.num_valid_tokens[b_idx].item() should_stop = False for t_idx in range(seq_len): # don't use multitoken prediction for lower confidence tokens if t_idx > 0 and num_tokens < seq_len: # roll so tokens are right aligned updated_inputs.input_ids[b_idx] = ( updated_inputs.input_ids[b_idx].roll( shifts=seq_len - num_tokens, dims=0 ) ) # don't need to roll position_ids because that's handled in `decode` (and when we do beacon 
tokens) break token = predicted_tokens_cpu[b_idx, t_idx].item() predicted_tokens[p_idx].append(token) batch_bboxes[p_idx, batch_pos[p_idx]] = bbox_preds_cpu[ b_idx, t_idx ] batch_pos[p_idx] += 1 scores[p_idx].append(scores_cpu[b_idx, t_idx].item()) if top_k > 0: top_k_scores = { batch_top_k_indices_cpu[temp_idx, t_idx][ k ].item(): batch_top_k_probs_cpu[temp_idx, t_idx][ k ].item() for k in range(top_k) } topk_probs[p_idx].append(top_k_scores) repeats = len(predicted_tokens[p_idx]) >= batch_max_tokens[ p_idx ] or ( drop_repeated_tokens and detect_repeat_token(predicted_tokens[p_idx]) and task_names[p_idx] in [ TaskNames.ocr_with_boxes, TaskNames.ocr_without_boxes, ] ) if ( token in [ self.processor.eos_token_id, self.processor.pad_token_id, ] or repeats ): should_stop = True break if should_stop: self.batch_prompt_mapping[b_idx] = None pbar.update(1) # Update inputs and mark XLA step current_inputs = updated_inputs pbar.close() del self.kv_cache self.kv_cache = None torch.cuda.empty_cache() return predicted_tokens, batch_bboxes, scores, topk_probs ``` -------------------------------------------------------------------------------- /surya/common/donut/encoder.py: -------------------------------------------------------------------------------- ```python import collections.abc import math from dataclasses import dataclass from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from transformers.activations import ACT2FN from transformers.pytorch_utils import ( find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, ) from transformers.utils import ModelOutput from transformers import DonutSwinConfig from surya.common.pretrained import SuryaPreTrainedModel from surya.common.xla import mark_step _EXPECTED_OUTPUT_SHAPE = [1, 49, 1024] @dataclass # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin class DonutSwinEncoderOutput(ModelOutput): last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None @dataclass class DonutSwinModelOutput(ModelOutput): last_hidden_state: torch.FloatTensor = None # Copied from transformers.models.swin.modeling_swin.window_partition def window_partition(input_feature, window_size): """ Partitions the given input into windows. """ batch_size, height, width, num_channels = input_feature.shape input_feature = input_feature.view( batch_size, height // window_size, window_size, width // window_size, window_size, num_channels, ) windows = ( input_feature.permute(0, 1, 3, 2, 4, 5) .contiguous() .view(-1, window_size, window_size, num_channels) ) return windows # Copied from transformers.models.swin.modeling_swin.window_reverse def window_reverse(windows, window_size, height, width): """ Merges windows to produce higher resolution features. """ num_channels = windows.shape[-1] windows = windows.view( -1, height // window_size, width // window_size, window_size, window_size, num_channels, ) windows = ( windows.permute(0, 1, 3, 2, 4, 5) .contiguous() .view(-1, height, width, num_channels) ) return windows # Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin class DonutSwinEmbeddings(nn.Module): """ Construct the patch and position embeddings. Optionally, also the mask token. 
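    Patch embeddings come from a strided Conv2d projection followed by LayerNorm; on top of that
    the module can add absolute position embeddings (optionally interpolated for larger inputs)
    and, when use_2d_embeddings is enabled, separate row and column embeddings for each patch.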
""" def __init__(self, config, use_mask_token=False): super().__init__() self.patch_embeddings = DonutSwinPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.patch_grid = self.patch_embeddings.grid_size self.mask_token = ( nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None ) self.position_embeddings = None self.row_embeddings = None self.column_embeddings = None if config.use_absolute_embeddings: self.position_embeddings = nn.Parameter( torch.zeros(1, num_patches + 1, config.embed_dim) ) if hasattr(config, "use_2d_embeddings") and config.use_2d_embeddings: self.row_embeddings = nn.Parameter( torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim) ) self.column_embeddings = nn.Parameter( torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim) ) self.norm = nn.LayerNorm(config.embed_dim) def interpolate_pos_encoding( self, embeddings: torch.Tensor, height: int, width: int ) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. Source: https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 if num_patches == num_positions and height == width: return self.position_embeddings class_pos_embed = self.position_embeddings[:, 0] patch_pos_embed = self.position_embeddings[:, 1:] dim = embeddings.shape[-1] h0 = height // self.config.patch_size w0 = width // self.config.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 h0, w0 = h0 + 0.1, w0 + 0.1 patch_pos_embed = patch_pos_embed.reshape( 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) def forward( self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, interpolate_pos_encoding: bool = False, ) -> Tuple[torch.Tensor]: _, num_channels, height, width = pixel_values.shape embeddings, output_dimensions = self.patch_embeddings(pixel_values) embeddings = self.norm(embeddings) batch_size, seq_len, _ = embeddings.size() if bool_masked_pos is not None: mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) # replace the masked visual tokens by mask_tokens mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) embeddings = embeddings * (1.0 - mask) + mask_tokens * mask if self.position_embeddings is not None: if interpolate_pos_encoding: embeddings = embeddings + self.interpolate_pos_encoding( embeddings, height, width ) else: embeddings = embeddings + self.position_embeddings[:, :seq_len] if self.row_embeddings is not None and self.column_embeddings is not None: # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ... 
row_embeddings = self.row_embeddings[ :, : output_dimensions[0], : ].repeat_interleave(output_dimensions[1], dim=1) column_embeddings = self.column_embeddings[ :, : output_dimensions[1], : ].repeat(1, output_dimensions[0], 1) embeddings = embeddings + row_embeddings + column_embeddings return embeddings, output_dimensions # Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin class DonutSwinPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer. """ def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.embed_dim image_size = ( image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) ) patch_size = ( patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) ) num_patches = (image_size[1] // patch_size[1]) * ( image_size[0] // patch_size[0] ) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.grid_size = ( image_size[0] // patch_size[0], image_size[1] // patch_size[1], ) self.projection = nn.Conv2d( num_channels, hidden_size, kernel_size=patch_size, stride=patch_size ) def maybe_pad(self, pixel_values, height, width): if width % self.patch_size[1] != 0: pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) pixel_values = nn.functional.pad(pixel_values, pad_values) if height % self.patch_size[0] != 0: pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) pixel_values = nn.functional.pad(pixel_values, pad_values) return pixel_values def forward( self, pixel_values: Optional[torch.FloatTensor] ) -> Tuple[torch.Tensor, Tuple[int]]: _, num_channels, height, width = pixel_values.shape # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) embeddings = self.projection(pixel_values) _, _, height, width = embeddings.shape output_dimensions = (height, width) embeddings = embeddings.flatten(2).transpose(1, 2) return embeddings, output_dimensions # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging class DonutSwinPatchMerging(nn.Module): """ Patch Merging Layer. Args: input_resolution (`Tuple[int]`): Resolution of input feature. dim (`int`): Number of input channels. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): Normalization layer class. 
""" def __init__( self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm, ) -> None: super().__init__() self.input_resolution = input_resolution self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def maybe_pad(self, input_feature, height, width): should_pad = (height % 2 == 1) or (width % 2 == 1) if should_pad: pad_values = (0, 0, 0, width % 2, 0, height % 2) input_feature = nn.functional.pad(input_feature, pad_values) return input_feature def forward( self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int] ) -> torch.Tensor: height, width = input_dimensions # `dim` is height * width batch_size, dim, num_channels = input_feature.shape input_feature = input_feature.view(batch_size, height, width, num_channels) # pad input to be disible by width and height, if needed input_feature = self.maybe_pad(input_feature, height, width) # [batch_size, height/2, width/2, num_channels] input_feature_0 = input_feature[:, 0::2, 0::2, :] # [batch_size, height/2, width/2, num_channels] input_feature_1 = input_feature[:, 1::2, 0::2, :] # [batch_size, height/2, width/2, num_channels] input_feature_2 = input_feature[:, 0::2, 1::2, :] # [batch_size, height/2, width/2, num_channels] input_feature_3 = input_feature[:, 1::2, 1::2, :] # batch_size height/2 width/2 4*num_channels input_feature = torch.cat( [input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1 ) input_feature = input_feature.view( batch_size, -1, 4 * num_channels ) # batch_size height/2*width/2 4*C input_feature = self.norm(input_feature) input_feature = self.reduction(input_feature) return input_feature # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin class DonutSwinSelfAttention(nn.Module): def __init__(self, config, dim, num_heads, num_kv_heads, window_size): super().__init__() if dim % num_heads != 0: raise ValueError( f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" ) self.num_attention_heads = num_heads self.num_kv_heads = num_kv_heads self.kv_repeats = self.num_attention_heads // self.num_kv_heads self.attention_head_size = int(dim / num_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.kv_head_size = self.num_kv_heads * self.attention_head_size self.window_size = ( window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) ) self.relative_position_bias_table = nn.Parameter( torch.zeros( (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads ) ) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) coords_flatten = torch.flatten(coords, 1) relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] += self.window_size[0] - 1 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) self.register_buffer("relative_position_index", relative_position_index) self.query = nn.Linear( self.all_head_size, self.all_head_size, bias=config.qkv_bias ) self.key = nn.Linear( self.all_head_size, self.kv_head_size, bias=config.qkv_bias ) self.value = nn.Linear( self.all_head_size, self.kv_head_size, 
bias=config.qkv_bias ) def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def transpose_kv_for_scores(self, x, repeats): new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size) x = x.view(new_x_shape) x = x.repeat( 1, 1, repeats, 1 ) # repeat the values for each key-value head to match query dim return x.permute(0, 2, 1, 3).contiguous() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape mixed_query_layer = self.query(hidden_states) # Final is (batch_size, num_attention_heads, seq_len, attention_head_size) key_layer = self.transpose_kv_for_scores( self.key(hidden_states), self.kv_repeats ) value_layer = self.transpose_kv_for_scores( self.value(hidden_states), self.kv_repeats ) query_layer = self.transpose_for_scores(mixed_query_layer) relative_position_bias = self.relative_position_bias_table[ self.relative_position_index.view(-1) ] relative_position_bias = relative_position_bias.view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1, ) relative_position_bias = ( relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) ) relative_position_bias = relative_position_bias.repeat(batch_size, 1, 1, 1) if attention_mask is None: attention_mask = relative_position_bias else: mask_shape = attention_mask.shape[0] repeat_count = batch_size // mask_shape attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1) attention_mask = attention_mask + relative_position_bias attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, attn_mask=attention_mask, dropout_p=0.0, scale=self.attention_head_size**-0.5, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch_size, dim, num_channels) outputs = (attn_output,) return outputs # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput class DonutSwinSelfOutput(nn.Module): def __init__(self, config, dim): super().__init__() self.dense = nn.Linear(dim, dim) def forward( self, hidden_states: torch.Tensor, input_tensor: torch.Tensor ) -> torch.Tensor: return self.dense(hidden_states) # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin class DonutSwinAttention(nn.Module): def __init__(self, config, dim, num_heads, num_kv_heads, window_size): super().__init__() self.self = DonutSwinSelfAttention( config, dim, num_heads, num_kv_heads, window_size ) self.output = DonutSwinSelfOutput(config, dim) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads, ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = ( self.self.attention_head_size * self.self.num_attention_heads ) self.pruned_heads = 
self.pruned_heads.union(heads) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, attention_mask, head_mask, output_attentions ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[ 1: ] # add attentions if we output them return outputs # Copied from transformers.models.swin.modeling_swin.SwinIntermediate class DonutSwinIntermediate(nn.Module): def __init__(self, config, dim): super().__init__() self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states # Copied from transformers.models.swin.modeling_swin.SwinOutput class DonutSwinOutput(nn.Module): def __init__(self, config, dim): super().__init__() self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.dense(hidden_states) # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin class DonutSwinLayer(nn.Module): def __init__( self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0 ): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.shift_size = shift_size self.window_size = config.window_size self.input_resolution = input_resolution self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attention = DonutSwinAttention( config, dim, num_heads, num_kv_heads, window_size=self.window_size ) self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.intermediate = DonutSwinIntermediate(config, dim) self.output = DonutSwinOutput(config, dim) def set_shift_and_window_size(self, input_resolution): if min(input_resolution) <= self.window_size: # if window size is larger than input resolution, we don't partition windows self.shift_size = int(0) self.window_size = ( torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) ) def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: # calculate attention mask for SW-MSA img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) height_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) width_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None), ) count = 0 for height_slice in height_slices: for width_slice in width_slices: img_mask[:, height_slice, width_slice, :] = count count += 1 mask_windows = window_partition(img_mask, self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill( attn_mask != 0, float(-100.0) ).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None return attn_mask def maybe_pad(self, hidden_states, height, width): pad_right = (self.window_size - width % self.window_size) % self.window_size pad_bottom = (self.window_size - height % self.window_size) % self.window_size pad_values = 
(0, 0, 0, pad_right, 0, pad_bottom) hidden_states = nn.functional.pad(hidden_states, pad_values) return hidden_states, pad_values def forward( self, hidden_states: torch.Tensor, input_dimensions: Tuple[int, int], head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, always_partition: Optional[bool] = False, ) -> Tuple[torch.Tensor, torch.Tensor]: if not always_partition: self.set_shift_and_window_size(input_dimensions) else: pass height, width = input_dimensions batch_size, _, channels = hidden_states.size() shortcut = hidden_states hidden_states = self.layernorm_before(hidden_states) hidden_states = hidden_states.view(batch_size, height, width, channels) # pad hidden_states to multiples of window size hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) _, height_pad, width_pad, _ = hidden_states.shape # cyclic shift if self.shift_size > 0: shifted_hidden_states = torch.roll( hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) ) else: shifted_hidden_states = hidden_states # partition windows hidden_states_windows = window_partition( shifted_hidden_states, self.window_size ) hidden_states_windows = hidden_states_windows.view( -1, self.window_size * self.window_size, channels ) attn_mask = self.get_attn_mask( height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device, ) attention_outputs = self.attention( hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions, ) attention_output = attention_outputs[0] attention_windows = attention_output.view( -1, self.window_size, self.window_size, channels ) shifted_windows = window_reverse( attention_windows, self.window_size, height_pad, width_pad ) # reverse cyclic shift if self.shift_size > 0: attention_windows = torch.roll( shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2) ) else: attention_windows = shifted_windows was_padded = pad_values[3] > 0 or pad_values[5] > 0 if was_padded: attention_windows = attention_windows[:, :height, :width, :].contiguous() attention_windows = attention_windows.view(batch_size, height * width, channels) hidden_states = shortcut + attention_windows layer_output = self.layernorm_after(hidden_states) layer_output = self.intermediate(layer_output) layer_output = hidden_states + self.output(layer_output) layer_outputs = ( (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) ) return layer_outputs # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin class DonutSwinStage(nn.Module): def __init__( self, config, layer_num, dim, input_resolution, depth, num_heads, num_kv_heads, downsample, ): super().__init__() self.config = config self.dim = dim self.blocks = nn.ModuleList( [ DonutSwinLayer( config=config, dim=dim, input_resolution=input_resolution, num_heads=num_heads, num_kv_heads=num_kv_heads, shift_size=0 if (i % 2 == 0) else config.window_size // 2, ) for i in range(depth) ] ) # patch merging layer if downsample is not None: self.downsample = downsample( input_resolution, dim=dim, norm_layer=nn.LayerNorm ) else: self.downsample = None self.pointing = False self.positional_encoding = None if config.use_positional_embeddings: self.positional_encoding = self.build_2d_sincos_position_embedding( input_resolution[1], input_resolution[0], embed_dim=dim, ) @staticmethod def build_2d_sincos_position_embedding( width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32, ): grid_w = 
torch.arange(int(width), dtype=dtype, device=device) grid_h = torch.arange(int(height), dtype=dtype, device=device) grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") if embed_dim % 4 != 0: raise ValueError( "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" ) pos_dim = embed_dim // 4 omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim omega = 1.0 / (temperature**omega) out_w = grid_w.flatten()[..., None] @ omega[None] out_h = grid_h.flatten()[..., None] @ omega[None] return torch.concat( [out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1 )[None, :, :] def forward( self, hidden_states: torch.Tensor, input_dimensions: Tuple[int, int], head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, always_partition: Optional[bool] = False, ) -> Tuple[torch.Tensor]: height, width = input_dimensions if self.positional_encoding is not None: hidden_states = hidden_states + self.positional_encoding.to( hidden_states.dtype ).to(hidden_states.device) for i, layer_module in enumerate(self.blocks): layer_head_mask = head_mask[i] if head_mask is not None else None layer_outputs = layer_module( hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition, ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = hidden_states if self.downsample is not None: height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 output_dimensions = (height, width, height_downsampled, width_downsampled) hidden_states = self.downsample( hidden_states_before_downsampling, input_dimensions ) else: output_dimensions = (height, width, height, width) stage_outputs = ( hidden_states, hidden_states_before_downsampling, output_dimensions, ) if output_attentions: stage_outputs += layer_outputs[1:] return stage_outputs # Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin class DonutSwinEncoder(nn.Module): def __init__(self, config, grid_size): super().__init__() self.num_layers = len(config.depths) self.config = config self.layers = nn.ModuleList( [ DonutSwinStage( config=config, layer_num=i_layer, dim=int(config.embed_dim * 2**i_layer), input_resolution=( grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer), ), depth=config.depths[i_layer], num_heads=config.num_heads[i_layer], num_kv_heads=config.num_kv_heads[i_layer] if hasattr(config, "num_kv_heads") else config.num_heads[i_layer], downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None, ) for i_layer in range(self.num_layers) ] ) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, input_dimensions: Tuple[int, int], head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, output_hidden_states_before_downsampling: Optional[bool] = False, always_partition: Optional[bool] = False, return_dict: Optional[bool] = True, ) -> Union[Tuple, DonutSwinEncoderOutput]: all_hidden_states = () if output_hidden_states else None all_reshaped_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None if output_hidden_states: batch_size, _, hidden_size = hidden_states.shape # rearrange b (h w) c -> b c h w reshaped_hidden_state = hidden_states.view( batch_size, *input_dimensions, hidden_size ) reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) all_hidden_states += (hidden_states,) all_reshaped_hidden_states += 
(reshaped_hidden_state,) for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition, ) else: layer_outputs = layer_module( hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition, ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = layer_outputs[1] output_dimensions = layer_outputs[2] input_dimensions = (output_dimensions[-2], output_dimensions[-1]) if output_hidden_states and output_hidden_states_before_downsampling: batch_size, _, hidden_size = hidden_states_before_downsampling.shape # rearrange b (h w) c -> b c h w # here we use the original (not downsampled) height and width reshaped_hidden_state = hidden_states_before_downsampling.view( batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size, ) reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) all_hidden_states += (hidden_states_before_downsampling,) all_reshaped_hidden_states += (reshaped_hidden_state,) elif output_hidden_states and not output_hidden_states_before_downsampling: batch_size, _, hidden_size = hidden_states.shape # rearrange b (h w) c -> b c h w reshaped_hidden_state = hidden_states.view( batch_size, *input_dimensions, hidden_size ) reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) all_hidden_states += (hidden_states,) all_reshaped_hidden_states += (reshaped_hidden_state,) if output_attentions: all_self_attentions += layer_outputs[3:] if not return_dict: return tuple( v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None ) return DonutSwinEncoderOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, reshaped_hidden_states=all_reshaped_hidden_states, ) # Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin class DonutSwinPreTrainedModel(SuryaPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" config_class = DonutSwinConfig base_model_prefix = "swin" main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["DonutSwinStage"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) ``` -------------------------------------------------------------------------------- /surya/ocr_error/model/encoder.py: -------------------------------------------------------------------------------- ```python from __future__ import annotations import math from typing import Optional, Set, List, Tuple, Union, Dict import numpy as np import torch from torch import nn from torch.nn import functional as F, MSELoss, CrossEntropyLoss, BCEWithLogitsLoss from transformers import apply_chunking_to_forward from transformers.activations import get_activation from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput from transformers.pytorch_utils import ( find_pruneable_heads_and_indices, prune_linear_layer, ) from transformers.utils import ( is_flash_attn_greater_or_equal_2_10, ) from surya.common.pretrained import SuryaPreTrainedModel from surya.common.s3 import S3DownloaderMixin from surya.ocr_error.model.config import DistilBertConfig def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): position_enc = np.array( [ [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos) ] ) out.requires_grad = False out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) out.detach_() class Embeddings(nn.Module): def __init__(self, config: DistilBertConfig): super().__init__() self.word_embeddings = nn.Embedding( config.vocab_size, config.dim, padding_idx=config.pad_token_id ) self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.dim ) self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12) self.dropout = nn.Dropout(config.dropout) self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False, ) def forward( self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Parameters: input_ids (torch.Tensor): torch.tensor(bs, max_seq_length) The token ids to embed. input_embeds (*optional*, torch.Tensor): The pre-computed word embeddings. Can only be passed if the input ids are `None`. 
Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type embeddings) """ if input_ids is not None: input_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) seq_length = input_embeds.size(1) # Setting the position-ids to the registered buffer in constructor, it helps # when tracing the model without passing position-ids, solves # isues similar to issue #5664 if hasattr(self, "position_ids"): position_ids = self.position_ids[:, :seq_length] else: position_ids = torch.arange( seq_length, dtype=torch.long, device=input_ids.device ) # (max_seq_length) position_ids = position_ids.unsqueeze(0).expand_as( input_ids ) # (bs, max_seq_length) position_embeddings = self.position_embeddings( position_ids ) # (bs, max_seq_length, dim) embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim) embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) return embeddings class MultiHeadSelfAttention(nn.Module): def __init__(self, config: DistilBertConfig): super().__init__() self.config = config self.n_heads = config.n_heads self.dim = config.dim self.dropout = nn.Dropout(p=config.attention_dropout) self.is_causal = False # Have an even number of multi heads that divide the dimensions if self.dim % self.n_heads != 0: # Raise value errors for even multi-head attention nodes raise ValueError( f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly" ) self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim) self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim) self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) self.pruned_heads: Set[int] = set() self.attention_head_size = self.dim // self.n_heads def prune_heads(self, heads: List[int]): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.n_heads, self.attention_head_size, self.pruned_heads ) # Prune linear layers self.q_lin = prune_linear_layer(self.q_lin, index) self.k_lin = prune_linear_layer(self.k_lin, index) self.v_lin = prune_linear_layer(self.v_lin, index) self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) # Update hyper params self.n_heads = self.n_heads - len(heads) self.dim = self.attention_head_size * self.n_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, ...]: """ Parameters: query: torch.tensor(bs, seq_length, dim) key: torch.tensor(bs, seq_length, dim) value: torch.tensor(bs, seq_length, dim) mask: torch.tensor(bs, seq_length) Returns: weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, seq_length, dim) Contextualized layer. 
Optional: only if `output_attentions=True` """ bs, q_length, dim = query.size() k_length = key.size(1) # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' # assert key.size() == value.size() dim_per_head = self.dim // self.n_heads mask_reshp = (bs, 1, 1, k_length) def shape(x: torch.Tensor) -> torch.Tensor: """separate heads""" return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) def unshape(x: torch.Tensor) -> torch.Tensor: """group heads""" return ( x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) ) q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) mask = ( (mask == 0).view(mask_reshp).expand_as(scores) ) # (bs, n_heads, q_length, k_length) scores = scores.masked_fill( mask, torch.tensor(torch.finfo(scores.dtype).min) ) # (bs, n_heads, q_length, k_length) weights = nn.functional.softmax( scores, dim=-1 ) # (bs, n_heads, q_length, k_length) weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) # Mask heads if we want to if head_mask is not None: weights = weights * head_mask context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) context = unshape(context) # (bs, q_length, dim) context = self.out_lin(context) # (bs, q_length, dim) if output_attentions: return (context, weights) else: return (context,) class DistilBertFlashAttention2(MultiHeadSelfAttention): """ DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, ...]: """ Parameters: query: torch.tensor(bs, seq_length, dim) key: torch.tensor(bs, seq_length, dim) value: torch.tensor(bs, seq_length, dim) mask: torch.tensor(bs, seq_length) Returns: weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, seq_length, dim) Contextualized layer. 
Optional: only if `output_attentions=True` """ batch_size, q_length, dim = query.size() dim_per_head = self.dim // self.n_heads def reshape(x: torch.Tensor) -> torch.Tensor: """separate heads""" return x.view(batch_size, -1, self.n_heads, dim_per_head) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim query_states = reshape(self.q_lin(query)) key_states = reshape(self.k_lin(key)) value_states = reshape(self.v_lin(value)) attn_dropout = self.config.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) if query_states.dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_lin.weight.dtype query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_weights = self._flash_attention_forward( query_states, key_states, value_states, mask, q_length, dropout=attn_dropout ) attn_weights_reshaped = attn_weights.reshape( batch_size, q_length, self.n_heads * dim_per_head ) attn_output = self.out_lin(attn_weights_reshaped) if output_attentions: return (attn_output, attn_weights) else: return (attn_output,) # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`float`): Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import pad_input if not self._flash_attn_uses_top_left_mask: causal = self.is_causal else: # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
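        # Note: `is_causal` is initialized to False in MultiHeadSelfAttention.__init__, so for
        # this bidirectional DistilBERT encoder `causal` ends up False on either branch; the
        # version check below only matters for the generic helper copied from the Llama code.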
causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] ( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens, ) = self._upad_input( query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, ) attn_output = pad_input( attn_output_unpad, indices_q, batch_size, query_length ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, ) return attn_output # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads def _upad_input( self, query_layer, key_layer, value_layer, attention_mask, query_length ): from flash_attn.bert_padding import index_first_axis, unpad_input indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), indices_k, ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. 
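            # i.e. the queries are assumed to correspond to the last `query_length` positions
            # of the padding mask, which only holds when shorter sequences are padded on the
            # left rather than the right.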
attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( query_layer, attention_mask ) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) class FFN(nn.Module): def __init__(self, config: DistilBertConfig): super().__init__() self.dropout = nn.Dropout(p=config.dropout) self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim) self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim) self.activation = get_activation(config.activation) def forward(self, input: torch.Tensor) -> torch.Tensor: return apply_chunking_to_forward( self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input ) def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: x = self.lin1(input) x = self.activation(x) x = self.lin2(x) x = self.dropout(x) return x DISTILBERT_ATTENTION_CLASSES = { "eager": MultiHeadSelfAttention, "flash_attention_2": DistilBertFlashAttention2, } class TransformerBlock(nn.Module): def __init__(self, config: DistilBertConfig): super().__init__() # Have an even number of Configure multi-heads if config.dim % config.n_heads != 0: raise ValueError( f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly" ) self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation]( config ) self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) self.ffn = FFN(config) self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) def forward( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, ...]: """ Parameters: x: torch.tensor(bs, seq_length, dim) attn_mask: torch.tensor(bs, seq_length) Returns: sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output: torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. """ # Self-Attention sa_output = self.attention( query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask, output_attentions=output_attentions, ) if output_attentions: sa_output, sa_weights = ( sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) ) else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples sa_output = sa_output[0] sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) # Feed Forward Network ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) ffn_output: torch.Tensor = self.output_layer_norm( ffn_output + sa_output ) # (bs, seq_length, dim) output = (ffn_output,) if output_attentions: output = (sa_weights,) + output return output class Transformer(nn.Module): def __init__(self, config: DistilBertConfig): super().__init__() self.n_layers = config.n_layers self.layer = nn.ModuleList( [TransformerBlock(config) for _ in range(config.n_layers)] ) self.gradient_checkpointing = False def forward( self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: Optional[bool] = None, ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore """ Parameters: x: torch.tensor(bs, seq_length, dim) Input sequence embedded. 
attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence. Returns: hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] Tuple of length n_layers with the hidden states from each layer. Optional: only if output_hidden_states=True all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] Tuple of length n_layers with the attention weights from each layer Optional: only if output_attentions=True """ all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_state = x for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_state,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_state, attn_mask, head_mask[i], output_attentions, ) else: layer_outputs = layer_module( hidden_state, attn_mask, head_mask[i], output_attentions, ) hidden_state = layer_outputs[-1] if output_attentions: if len(layer_outputs) != 2: raise ValueError( f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}" ) attentions = layer_outputs[0] all_attentions = all_attentions + (attentions,) else: if len(layer_outputs) != 1: raise ValueError( f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}" ) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_state,) if not return_dict: return tuple( v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None ) return BaseModelOutput( last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions, ) class DistilBertPreTrainedModel(SuryaPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" config_class = DistilBertConfig load_tf_weights = None base_model_prefix = "distilbert" supports_gradient_checkpointing = True _supports_flash_attn_2 = True def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: create_sinusoidal_embeddings( self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight, ) class DistilBertModel(DistilBertPreTrainedModel): def __init__(self, config: DistilBertConfig): super().__init__(config) self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" # Initialize weights and apply final processing self.post_init() def get_position_embeddings(self) -> nn.Embedding: """ Returns the position embeddings """ return self.embeddings.position_embeddings def resize_position_embeddings(self, new_num_position_embeddings: int): """ Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. Arguments: new_num_position_embeddings (`int`): The number of new position embedding matrix. If position embeddings are learned, increasing the size will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will add correct vectors at the end following the position encoding algorithm, whereas reducing the size will remove vectors from the end. """ num_position_embeds_diff = ( new_num_position_embeddings - self.config.max_position_embeddings ) # no resizing needs to be done if the length stays the same if num_position_embeds_diff == 0: return self.config.max_position_embeddings = new_num_position_embeddings old_position_embeddings_weight = ( self.embeddings.position_embeddings.weight.clone() ) self.embeddings.position_embeddings = nn.Embedding( self.config.max_position_embeddings, self.config.dim ) if self.config.sinusoidal_pos_embds: create_sinusoidal_embeddings( n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight, ) else: with torch.no_grad(): if num_position_embeds_diff > 0: self.embeddings.position_embeddings.weight[ :-num_position_embeds_diff ] = nn.Parameter(old_position_embeddings_weight) else: self.embeddings.position_embeddings.weight = nn.Parameter( old_position_embeddings_weight[:num_position_embeds_diff] ) # move position_embeddings to correct device self.embeddings.position_embeddings.to(self.device) def get_input_embeddings(self) -> nn.Embedding: return self.embeddings.word_embeddings def set_input_embeddings(self, new_embeddings: nn.Embedding): self.embeddings.word_embeddings = new_embeddings def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.transformer.layer[layer].attention.prune_heads(heads) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) if self._use_flash_attention_2: attention_mask = ( attention_mask if (attention_mask is not None and 0 in attention_mask) else None ) else: if attention_mask is None: attention_mask = torch.ones( input_shape, device=device ) # (bs, seq_length) return self.transformer( x=embeddings, attn_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) class DistilBertForSequenceClassification(S3DownloaderMixin, DistilBertPreTrainedModel): def __init__(self, config: DistilBertConfig, **kwargs): super().__init__(config, **kwargs) self.num_labels = config.num_labels self.config = config self.distilbert = DistilBertModel(config) self.pre_classifier = nn.Linear(config.dim, config.dim) self.classifier = nn.Linear(config.dim, config.num_labels) self.dropout = nn.Dropout(config.seq_classif_dropout) # Initialize weights and apply final processing self.post_init() def get_position_embeddings(self) -> nn.Embedding: """ Returns the position embeddings """ return self.distilbert.get_position_embeddings() def resize_position_embeddings(self, new_num_position_embeddings: int): """ Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. Arguments: new_num_position_embeddings (`int`): The number of new position embedding matrix. If position embeddings are learned, increasing the size will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will add correct vectors at the end following the position encoding algorithm, whereas reducing the size will remove vectors from the end. 
""" self.distilbert.resize_position_embeddings(new_num_position_embeddings) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) distilbert_output = self.distilbert( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_state = distilbert_output[0] # (bs, seq_len, dim) pooled_output = hidden_state[:, 0] # (bs, dim) pooled_output = self.pre_classifier(pooled_output) # (bs, dim) pooled_output = nn.ReLU()(pooled_output) # (bs, dim) pooled_output = self.dropout(pooled_output) # (bs, dim) logits = self.classifier(pooled_output) # (bs, num_labels) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + distilbert_output[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=distilbert_output.hidden_states, attentions=distilbert_output.attentions, ) ```