# datalab-to/surya
This is page 3 of 4. Use http://codebase.md/datalab-to/surya?lines=false&page={x} to view the full context.

# Directory Structure

```
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── breaking-bug-report.md
│   │   ├── feature_request.md
│   │   └── output-bug-report.md
│   └── workflows
│       ├── benchmarks.yml
│       ├── ci.yml
│       ├── cla.yml
│       ├── publish.yml
│       └── scripts.yml
├── .gitignore
├── .pre-commit-config.yaml
├── benchmark
│   ├── detection.py
│   ├── layout.py
│   ├── ordering.py
│   ├── recognition.py
│   ├── table_recognition.py
│   ├── texify.py
│   └── utils
│       ├── __init__.py
│       ├── bbox.py
│       ├── metrics.py
│       ├── scoring.py
│       ├── tatr.py
│       ├── tesseract.py
│       ├── textract.py
│       └── verify_benchmark_scores.py
├── CITATION.cff
├── CLA.md
├── detect_layout.py
├── detect_text.py
├── LICENSE
├── ocr_app.py
├── ocr_latex.py
├── ocr_text.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── README.md
├── signatures
│   └── version1
│       └── cla.json
├── static
│   ├── fonts
│   │   └── .gitignore
│   └── images
│       ├── arabic_layout.jpg
│       ├── arabic_reading.jpg
│       ├── arabic_text.jpg
│       ├── arabic.jpg
│       ├── benchmark_chart_small.png
│       ├── benchmark_chart.png
│       ├── benchmark_layout_chart.png
│       ├── benchmark_rec_chart.png
│       ├── benchmark_tablerec_acc.png
│       ├── benchmark_tablerec_speed.png
│       ├── chi_hind_layout.jpg
│       ├── chi_hind_orig.jpg
│       ├── chi_hind_reading.jpg
│       ├── chi_hind_text.jpg
│       ├── chi_hind.jpg
│       ├── chinese_layout.jpg
│       ├── chinese_reading.jpg
│       ├── chinese_text.jpg
│       ├── chinese.jpg
│       ├── excerpt_layout.png
│       ├── excerpt_reading.jpg
│       ├── excerpt_text.png
│       ├── excerpt.png
│       ├── funsd_layout.jpg
│       ├── funsd_reading.jpg
│       ├── funsd_text.jpg
│       ├── funsd.png
│       ├── gcloud_full_langs.png
│       ├── gcloud_rec_bench.png
│       ├── hindi_layout.jpg
│       ├── hindi_reading.jpg
│       ├── hindi_text.jpg
│       ├── hindi.jpg
│       ├── japanese_layout.jpg
│       ├── japanese_reading.jpg
│       ├── japanese_tablerec.png
│       ├── japanese_text.jpg
│       ├── japanese.jpg
│       ├── latex_ocr.png
│       ├── nyt_layout.jpg
│       ├── nyt_order.jpg
│       ├── nyt_text.jpg
│       ├── nyt.jpg
│       ├── paper_layout.jpg
│       ├── paper_reading.jpg
│       ├── paper_tablerec.png
│       ├── paper_text.jpg
│       ├── paper.jpg
│       ├── pres_layout.jpg
│       ├── pres_reading.jpg
│       ├── pres_tablerec.png
│       ├── pres_text.jpg
│       ├── pres.png
│       ├── rec_acc_table.png
│       ├── scanned_layout.jpg
│       ├── scanned_reading.jpg
│       ├── scanned_tablerec.png
│       ├── scanned_tablerec2.png
│       ├── scanned_text.jpg
│       ├── scanned.png
│       ├── surya_rec_perf.png
│       ├── table_rec.png
│       ├── textbook_layout.jpg
│       ├── textbook_order.jpg
│       ├── textbook_text.jpg
│       └── textbook.jpg
├── surya
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── adetr
│   │   │   └── decoder.py
│   │   ├── donut
│   │   │   ├── encoder.py
│   │   │   └── processor.py
│   │   ├── load.py
│   │   ├── polygon.py
│   │   ├── predictor.py
│   │   ├── pretrained.py
│   │   ├── s3.py
│   │   ├── surya
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── decoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── embedder
│   │   │   │   └── __init__.py
│   │   │   ├── encoder
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── flash_attn_utils.py
│   │   │   ├── processor
│   │   │   │   ├── __init__.py
│   │   │   │   ├── schema.py
│   │   │   │   └── tokenizer.py
│   │   │   └── schema.py
│   │   ├── util.py
│   │   └── xla.py
│   ├── debug
│   │   ├── draw.py
│   │   ├── fonts.py
│   │   ├── katex.js
│   │   ├── render_html.py
│   │   └── text.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── heatmap.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoderdecoder.py
│   │   ├── parallel.py
│   │   ├── processor.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── foundation
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_ops.py
│   │   │   └── static_ops.py
│   │   ├── loader.py
│   │   └── util.py
│   ├── input
│   │   ├── load.py
│   │   └── processing.py
│   ├── layout
│   │   ├── __init__.py
│   │   ├── label.py
│   │   └── schema.py
│   ├── logging.py
│   ├── models.py
│   ├── ocr_error
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   └── encoder.py
│   │   ├── schema.py
│   │   └── tokenizer.py
│   ├── recognition
│   │   ├── __init__.py
│   │   ├── languages.py
│   │   ├── postprocessing.py
│   │   ├── schema.py
│   │   └── util.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── detect_layout.py
│   │   ├── detect_text.py
│   │   ├── finetune_ocr.py
│   │   ├── hf_to_s3.py
│   │   ├── ocr_latex.py
│   │   ├── ocr_text.py
│   │   ├── run_streamlit_app.py
│   │   ├── run_texify_app.py
│   │   ├── streamlit_app.py
│   │   ├── table_recognition.py
│   │   └── texify_app.py
│   ├── settings.py
│   └── table_rec
│       ├── __init__.py
│       ├── loader.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   └── encoderdecoder.py
│       ├── processor.py
│       ├── schema.py
│       └── shaper.py
├── table_recognition.py
├── tests
│   ├── assets
│   │   └── test_latex.png
│   ├── conftest.py
│   ├── test_detection.py
│   ├── test_foundation.py
│   ├── test_latex_ocr.py
│   ├── test_layout.py
│   ├── test_ocr_errors.py
│   ├── test_recognition.py
│   └── test_table_rec.py
└── texify_app.py
```

# Files

--------------------------------------------------------------------------------
/surya/common/surya/processor/__init__.py:
--------------------------------------------------------------------------------

```python
import math

import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence

from typing import List, Optional, Tuple

from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer

from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.processor.schema import (
    TextInput,
    ImageInput,
    ProcessorOutput,
)
from surya.common.surya.schema import TaskNames
from surya.logging import get_logger
from surya.settings import settings

logger = get_logger()

# Task agnostic tokens - Every task will use these in some form or another
EOS_TOKEN = "</S>"
EOI_TOKEN = "<EOI>"  # This is end of INPUT, not image. Images are always followed by a task specific BOS token, so that serves as a delimiter anyways.
IMAGE_TOKEN = "<IMAGE>"
PAD_TOKEN = "<PAD>"
NO_OUTPUT_TOKEN = "<NOP>"
IMAGE_ROTATED_TOKEN = "<ROT>"
REGISTER_TOKENS = ["<REG1>", "<REG2>", "<REG3>", "<REG4>"]
BEACON_TOKEN = "<BEACON>"
NOMATH_TOKEN = "<NO-MATH>"

# Task specific tokens
OCR_WITH_BOXES_BOS_TOKEN = "<OCR-WB>"
OCR_WITHOUT_BOXES_BOS_TOKEN = "<OCR-WOB>"
BLOCK_WITHOUT_BOXES_TOKEN = "<BLOCKS-WOB>"
LAYOUT_BOS_TOKEN = "<LAYOUT>"
TABLE_STRUCTURE_BOS_TOKEN = "<TABLE-STRUCTURE>"


class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin):
    attributes = ["image_processor", "ocr_tokenizer"]
    image_processor_class = "BaseImageProcessor"
    ocr_tokenizer_class = "PreTrainedTokenizer"
    rescale_factor = 1 / 255.0
    image_mean = (0.485, 0.456, 0.406)
    image_std = (0.229, 0.224, 0.225)

    def __init__(
        self,
        ocr_tokenizer: PreTrainedTokenizer,
        blank_bbox_token_id: int,
        num_register_tokens: int,
        patch_size: int,
        merge_size: int,
        num_beacon_tokens: int,
        beacon_token_interval: int,
        model_device: str,
        **kwargs,
    ):
        self.ocr_tokenizer = ocr_tokenizer
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.num_register_tokens = num_register_tokens
        self.num_beacon_tokens = num_beacon_tokens
        self.beacon_token_interval = beacon_token_interval

        self.tokenizer_vocab_size = 0
        for attr in self.attributes:
            if "tokenizer" in attr:
                self.tokenizer_vocab_size += getattr(self, attr).vocab_size

        self.offsets = {"ocr": 0}

        # Create special token mapping
        self.special_token_mapping = self.ocr_tokenizer.system_tokens

        self.register_token_ids = [
            self.special_token_mapping.get(r) for r in REGISTER_TOKENS
        ]
        self.beacon_token_id = self.special_token_mapping.get(BEACON_TOKEN)
        self.image_token_id = self.special_token_mapping.get(IMAGE_TOKEN)
        self.pad_token_id = self.special_token_mapping.get(PAD_TOKEN)
        self.eos_token_id = self.special_token_mapping.get(EOS_TOKEN)
        self.eoi_token_id = self.special_token_mapping.get(EOI_TOKEN)
        self.no_output_token = self.special_token_mapping.get(NO_OUTPUT_TOKEN)
        self.image_rotated_token = self.special_token_mapping.get(IMAGE_ROTATED_TOKEN)
        self.nomath_token = self.special_token_mapping.get(NOMATH_TOKEN)

        self.bos_token_id = {
            TaskNames.ocr_with_boxes: self.special_token_mapping.get(
                OCR_WITH_BOXES_BOS_TOKEN
            ),
            TaskNames.ocr_without_boxes: self.special_token_mapping.get(
                OCR_WITHOUT_BOXES_BOS_TOKEN
            ),
            TaskNames.block_without_boxes: self.special_token_mapping.get(
                BLOCK_WITHOUT_BOXES_TOKEN
            ),
            TaskNames.layout: self.special_token_mapping.get(LAYOUT_BOS_TOKEN),
            TaskNames.table_structure: self.special_token_mapping.get(
                TABLE_STRUCTURE_BOS_TOKEN
            ),
        }

        if self.image_token_id is None:
            logger.warning("Warning: Image token not found in special tokens")

        self.blank_bbox_token_id = blank_bbox_token_id
        self.bbox_pad_token_id = self.blank_bbox_token_id

        self.ignore_bbox_token_ids = [
            v
            for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
            if k not in self.ocr_tokenizer.special_tokens["math_external"]
        ]
        math_end_token = "</math>"
        self.math_start_token_ids = [
            v
            for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
            if k in self.ocr_tokenizer.special_tokens["math_external"]
            and k != math_end_token
        ]
        self.math_end_token_ids = [
            v
            for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items()
            if k == math_end_token
        ]

        if self.num_register_tokens > len(self.register_token_ids):
            raise ValueError(
                "The number of register tokens requested exceeds the number of register tokens defined in the special token mapping."
            )

        self.image_mean = np.array(self.image_mean, dtype=np.float32)
        self.image_std = np.array(self.image_std, dtype=np.float32)
        self.model_device = model_device

    @property
    def vocab_size(self):
        return self.tokenizer_vocab_size

    def image_processor(self, image: Image.Image) -> np.ndarray:
        # Convert to array
        image = np.asarray(image, dtype=np.float32)
        return image

    @staticmethod
    def scale_to_fit(
        img: np.ndarray,
        max_size: Tuple[int, int],
        min_size: Tuple[int, int] = (168, 168),
    ):
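        # Illustrative example (values are hypothetical, not from the repo config):
        # a 2000x1000 image has 2,000,000 pixels. With max_size=(1024, 1024) the
        # budget is 1,048,576 pixels, scale_factor = (1048576 / 2000000) ** 0.5
        # ~= 0.724, and the image is resized to 1448x724, preserving the aspect
        # ratio. With max_size=(1536, 1536) (budget 2,359,296) it is returned unchanged.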
        # Get current dimensions
        height, width = img.shape[:2]

        # Check for empty or invalid image
        if width == 0 or height == 0:
            return img

        max_width, max_height = max_size
        min_width, min_height = min_size

        # Calculate pixel counts
        current_pixels = width * height
        max_pixels = max_width * max_height
        min_pixels = min_width * min_height

        if current_pixels > max_pixels:
            scale_factor = (max_pixels / current_pixels) ** 0.5

            new_width = math.floor(width * scale_factor)
            new_height = math.floor(height * scale_factor)
        elif current_pixels == 0:
            return img
        elif current_pixels < min_pixels:
            scale_factor = (min_pixels / current_pixels) ** 0.5

            new_width = math.ceil(width * scale_factor)
            new_height = math.ceil(height * scale_factor)
        else:
            return img

        return cv2.resize(
            img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4
        )

    def _image_processor(self, image: np.ndarray):
        image = image.astype(np.float64) * self.rescale_factor
        image = (image.astype(np.float32) - self.image_mean) / self.image_std
        return image

    def _process_and_tile(
        self, image: np.ndarray
    ) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """
        Resizes the input image to the closest multiple of tile_size while preserving the aspect ratio
        and returns a tensor of image tiles.
        """
        extra_multiplier = (
            4 if settings.FOUNDATION_XLA else 1
        )  # Needed to force same size grid_thws per row with padding

        factor = (
            self.patch_size * self.merge_size * extra_multiplier
        )  # Make a multiple of window size

        height, width = image.shape[:2]

        h_bar = math.ceil(height / factor) * factor
        w_bar = math.ceil(width / factor) * factor
        if h_bar != height or w_bar != width:
            if height == 0 or width == 0:
                image = np.zeros((h_bar, w_bar, 3), dtype=np.uint8)
            else:
                image = cv2.resize(image, (w_bar, h_bar), interpolation=cv2.INTER_CUBIC)

        # Handle scaling and normalization
        image = self._image_processor(image)
        height, width = image.shape[:2]

        # Numpy array to torch tensor
        img_tensor = torch.from_numpy(image.transpose(2, 0, 1))
        patches = img_tensor.unsqueeze(0)

        channel = patches.shape[1]
        grid_t = patches.shape[0]
        grid_h, grid_w = height // self.patch_size, width // self.patch_size

        patches = patches.reshape(
            grid_t,
            1,
            channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * 1 * self.patch_size * self.patch_size
        )

        return flatten_patches, (grid_t, grid_h, grid_w)

    # Handle image input dictionaries - Process image, tile accordingly, and setup the input ids and boxes correspondingly
    def _process_image_input(self, image_input: ImageInput) -> ProcessorOutput:
        rotated = image_input.get("rotated", False)
        image = image_input.get("image", None)

        assert image is not None, (
            "A PIL Image must be provided when the input type is `image`"
        )
        image_tiles, grid_thw = self._process_and_tile(image)

        num_tokens = image_tiles.shape[0] / self.merge_size**2
        assert num_tokens.is_integer(), (
            f"Expected number of tokens to be an integer, got {num_tokens}"
        )

        input_ids = [self.image_token_id] * int(num_tokens)
        input_ids += self.register_token_ids[: self.num_register_tokens]

        # Handle the image being rotated in the imdataset
        if rotated:
            input_ids = [self.image_rotated_token] + input_ids

        return ProcessorOutput(
            input_ids=input_ids,
            image_tiles=image_tiles,
            grid_thw=grid_thw,
        )

    def _process_text_input(self, text_input: TextInput, task: str) -> ProcessorOutput:
        input_text = text_input.get("text", None)
        math_mode = text_input.get("math", False)

        input_ids = self.ocr_tokenizer(input_text, tasks=task)["input_ids"][0]
        input_ids = [self.offsets["ocr"] + id for id in input_ids]

        # nomath token does not work for layout
        if not math_mode and task != "layout":
            input_ids.insert(0, self.nomath_token)

        return ProcessorOutput(
            input_ids=input_ids,
            image_tiles=None,
            grid_thw=None,
        )

    def _process_input(self, input_dict: dict, task: str):
        input_type = input_dict["type"]
        if input_type == "image":
            return self._process_image_input(input_dict)
        elif input_type == "text":
            return self._process_text_input(input_dict, task)

        raise NotImplementedError(f"Input of type `{input_type}` is not implemented")

    # Preprocessing for the OCR task
    # The task is expected to have - image_dict, user_input_dict, output_dict
    # user_input_dict is allowed to have an empty input, which is fine, but it needs to be present
    def _process_ocr_with_boxes(
        self,
        mixed_input: List[dict],
        bos_token_id: int,
        task: str = TaskNames.ocr_with_boxes,
    ):
        processed_input_ids = []
        all_image_tiles = []
        all_grid_thw = []

        # 1. Process the image input
        for i, input_dict in enumerate(mixed_input):
            processor_output = self._process_input(input_dict, task)
            input_ids = processor_output["input_ids"]
            image_tiles = processor_output["image_tiles"]
            grid_thw = processor_output["grid_thw"]

            # Special handling of some delimiter tokens
            if i == 1:
                assert input_dict["type"] == "text", (
                    "Expected text input for model input."
                )
                # Case for input - Add task specific bos token + end_of_input token
                # We do not want the model to learn how to predict inputs. Hence IGNORE_INDEX for these
                input_ids = [bos_token_id] + input_ids + [self.eoi_token_id]
            if i == 2:
                assert input_dict["type"] == "text", (
                    "Expected text for final model input"
                )
                input_ids = input_ids + [self.eos_token_id]
            elif i > 2:
                raise ValueError(f"Too many inputs received. Expected is 2 for inference, 3 for training. Received: {len(mixed_input)}")

            # Some input types don't return any image tiles, accounting for that
            if image_tiles is not None:
                all_image_tiles.append(image_tiles)
                all_grid_thw.append(grid_thw)

            processed_input_ids.extend(input_ids)

        return (
            torch.tensor(processed_input_ids, dtype=torch.long),
            all_image_tiles,
            all_grid_thw,
        )

    def _process_layout(self, mixed_input: List[dict], bos_token_id: int):
        return self._process_ocr_with_boxes(
            mixed_input, bos_token_id=bos_token_id, task="layout"
        )

    def _process_table_structure(self, mixed_input: List[dict], bos_token_id: int):
        return self._process_ocr_with_boxes(
            mixed_input, bos_token_id=bos_token_id, task="table_structure"
        )

    def _process_ocr_without_boxes(
        self,
        mixed_input: List[dict],
        bos_token_id: int,
        task: str = "ocr_without_boxes",
    ):
        # Boxes are set to None, so this will work
        # TODO: improve this behavior
        return self._process_ocr_with_boxes(
            mixed_input, bos_token_id=bos_token_id, task=task
        )

    def _process_block_without_boxes(
        self,
        mixed_input: List[dict],
        bos_token_id: int,
        task: str = "block_without_boxes",
    ):
        return self._process_ocr_with_boxes(
            mixed_input, bos_token_id=bos_token_id, task=task
        )

    def align_long_axis(self, image: np.ndarray) -> Tuple[np.ndarray, bool]:
        height, width, _ = image.shape
        if height > width:  # Rotate vertical lines
            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
            return image, True

        return image, False

    def __call__(
        self,
        mixed_batch: List[dict],
        padding_side: Optional[str] = "left",
        device: Optional[torch.device] = None,
        pad_to_multiple: Optional[int] = None,
    ):
        all_image_tiles = []
        all_input_ids = []
        all_grid_thw = []

        for b in mixed_batch:
            mixed_input = b["inputs"]
            task = b["task"]
            assert task in self.bos_token_id, f"Task {task} has no bos token defined."

            # Select the correct processing function based on the task type
            input_ids, image_tiles, grid_thw = getattr(self, f"_process_{task}")(
                mixed_input, self.bos_token_id[task]
            )

            all_input_ids.append(input_ids)
            all_image_tiles.extend(image_tiles)
            all_grid_thw.extend(grid_thw)

        batched_input_ids = pad_sequence(
            all_input_ids,
            batch_first=True,
            padding_side=padding_side,
            padding_value=self.pad_token_id,
        )

        if pad_to_multiple is not None:
            current_len = batched_input_ids.shape[1]
            # Calculate the next multiple of pad_to_multiple
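            # (Illustrative: current_len=37 with pad_to_multiple=16 gives
            # padded_len=48, so 11 extra pad tokens are prepended below.)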
            padded_len = (
                (current_len + pad_to_multiple - 1) // pad_to_multiple
            ) * pad_to_multiple

            if padded_len > current_len:
                pad_len = padded_len - current_len
                batched_input_ids = torch.nn.functional.pad(
                    batched_input_ids, (pad_len, 0), value=self.pad_token_id
                )

        attention_mask = batched_input_ids.ne(self.pad_token_id)

        # Generate position IDs that are independent of left and right padding;
        # this should ensure the same results for either padding side. The exact position ids for the pad tokens themselves don't matter since they are masked.
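        # Concretely (illustrative): a left-padded mask [0, 0, 1, 1, 1] gives
        # cumsum - 1 = [-1, -1, 0, 1, 2]; clamping negatives yields [0, 0, 0, 1, 2]
        # and the mask multiply leaves it unchanged. A right-padded mask
        # [1, 1, 1, 0, 0] gives cumsum - 1 = [0, 1, 2, 2, 2] and the mask multiply
        # zeroes the pad slots to [0, 1, 2, 0, 0]. The real tokens receive
        # positions 0, 1, 2 either way.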
        position_ids = attention_mask.cumsum(dim=-1) - 1
        position_ids[position_ids < 0] = (
            0  # For left padding, the position ids for padding will become -1 because of the shift; Setting to 0
        )
        position_ids = (
            attention_mask.to(torch.long) * position_ids
        )  # Ensure right pad ids get set to zero

        batched_image_tiles = torch.cat(all_image_tiles, dim=0)
        batched_grid_thw = torch.from_numpy(np.array(all_grid_thw))

        # Pin memory for CUDA
        if device == torch.device("cuda"):
            batched_image_tiles = batched_image_tiles.pin_memory()
            batched_grid_thw = batched_grid_thw.pin_memory()
            attention_mask = attention_mask.pin_memory()
            batched_input_ids = batched_input_ids.pin_memory()
            position_ids = position_ids.pin_memory()

        return BatchFeature(
            {
                "input_ids": batched_input_ids,
                "image_tiles": batched_image_tiles,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "grid_thw": batched_grid_thw,
            }
        )

    # Decode model outputs; Strips special tokens
    def decode(self, tokens: List[int], task: str):
        filtered_tokens = [
            t
            for t in tokens
            if t not in self.special_token_mapping.values() and t != -100
        ]  # Skip special tokens and loss ignore index
        return self.ocr_tokenizer.decode(filtered_tokens, task=task)

```
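
A minimal usage sketch for the processor above (not from the repository): it assumes a `SuryaOCRProcessor` instance named `processor` has already been built by the model loader, since the tokenizer and token/patch parameters passed to `__init__` are checkpoint-specific. The batch format follows `__call__`: each entry carries a `task` name and an `inputs` list of image/text dicts, with the image pre-converted via `processor.image_processor`.

```python
from PIL import Image

from surya.common.surya.schema import TaskNames

# Assumption: `processor` is a loaded SuryaOCRProcessor (e.g. built by the
# foundation model loader); its constructor arguments are checkpoint-specific.
image = Image.open("line_crop.png").convert("RGB")
image_array = processor.image_processor(image)  # PIL -> float32 numpy array

batch = [
    {
        "task": TaskNames.ocr_with_boxes,
        "inputs": [
            {"type": "image", "image": image_array, "rotated": False},
            {"type": "text", "text": "", "math": True},  # empty user input is allowed
        ],
    }
]

features = processor(batch, padding_side="left")
print(features["input_ids"].shape, features["image_tiles"].shape)
```

At inference the OCR-style tasks expect exactly two inputs per entry (the image and a possibly empty text input); a third input is only used as the training target.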

--------------------------------------------------------------------------------
/surya/recognition/__init__.py:
--------------------------------------------------------------------------------

```python
from __future__ import annotations

import re
from typing import List

import numpy as np
import torch
from PIL import Image
import torch.nn.functional as F

from surya.common.polygon import PolygonBox
from surya.common.surya.processor import NOMATH_TOKEN
from surya.common.predictor import BasePredictor
from surya.detection import DetectionPredictor
from surya.foundation import FoundationPredictor

from surya.input.processing import (
    convert_if_not_rgb,
    slice_polys_from_image,
    slice_bboxes_from_image,
)
from surya.recognition.postprocessing import fix_unbalanced_tags
from surya.recognition.util import (
    sort_text_lines,
    clean_close_polygons,
    unwrap_math,
    clean_math_tags,
    filter_blacklist_tags,
    words_from_chars
)
from surya.foundation.util import detect_repeat_token, prediction_to_polygon_batch
from surya.recognition.schema import TextLine, OCRResult, TextChar
from surya.common.surya.schema import TaskNames
from surya.settings import settings
from surya.logging import get_logger, configure_logging

configure_logging()
logger = get_logger()

class RecognitionPredictor(BasePredictor):
    batch_size = settings.RECOGNITION_BATCH_SIZE
    default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 128}

    # Override base init - Do not load model
    def __init__(self, foundation_predictor: FoundationPredictor):
        self.foundation_predictor = foundation_predictor
        self.processor = self.foundation_predictor.processor
        self.bbox_size = self.foundation_predictor.model.config.bbox_size
        self.tasks = self.foundation_predictor.tasks

    # Special handling for disable tqdm to pass into foundation predictor
    # Make sure they are kept in sync
    @property
    def disable_tqdm(self) -> bool:
        return super().disable_tqdm

    @disable_tqdm.setter
    def disable_tqdm(self, value: bool) -> None:
        self._disable_tqdm = bool(value)
        self.foundation_predictor.disable_tqdm = bool(value)

    def detect_and_slice_bboxes(
        self,
        images: List[Image.Image],
        task_names: List[str],
        det_predictor: DetectionPredictor,
        detection_batch_size: int | None = None,
        highres_images: List[Image.Image] | None = None,
    ):
        det_predictions = det_predictor(images, batch_size=detection_batch_size)

        all_slices = []
        slice_map = []
        all_polygons = []
        all_task_names = []
        all_res_scales = []

        for idx, (det_pred, image, highres_image, task_name) in enumerate(
            zip(det_predictions, images, highres_images, task_names)
        ):
            polygons = [p.polygon for p in det_pred.bboxes]
            if highres_image:
                width_scaler = highres_image.size[0] / image.size[0]
                height_scaler = highres_image.size[1] / image.size[1]
                scaled_polygons = [
                    [
                        [int(p[0] * width_scaler), int(p[1] * height_scaler)]
                        for p in polygon
                    ]
                    for polygon in polygons
                ]
                highres_image = self.processor.image_processor(highres_image)
                slices = slice_polys_from_image(highres_image, scaled_polygons)
                res_scales = [(width_scaler, height_scaler) for _ in range(len(slices))]
            else:
                image = self.processor.image_processor(image)
                slices = slice_polys_from_image(image, polygons)
                res_scales = [(1, 1) for _ in range(len(slices))]

            slice_map.append(len(slices))
            all_slices.extend(slices)
            all_polygons.extend(polygons)
            all_task_names.extend([task_name] * len(slices))
            all_res_scales.extend(res_scales)

        assert (
            len(all_slices)
            == sum(slice_map)
            == len(all_polygons)
            == len(all_task_names)
            == len(all_res_scales)
        )

        return {
            "slices": all_slices,
            "slice_map": slice_map,
            "polygons": all_polygons,
            "task_names": all_task_names,
            "input_text": [None] * len(all_slices),
            "res_scales": all_res_scales,
        }

    def slice_bboxes(
        self,
        images: List[Image.Image],
        task_names: List[str],
        bboxes: List[List[List[int]]] | None = None,
        polygons: List[List[List[List[int]]]] | None = None,
        input_text: List[List[str | None]] | None = None,
    ) -> dict:
        assert bboxes is not None or polygons is not None
        slice_map = []
        all_slices = []
        all_polygons = []
        all_text = []
        all_task_names = []

        for idx, image in enumerate(images):
            image = self.processor.image_processor(image)
            if polygons is not None:
                polys = polygons[idx]
                slices = slice_polys_from_image(image, polys)
            else:
                slices = slice_bboxes_from_image(image, bboxes[idx])
                polys = [
                    [
                        [bbox[0], bbox[1]],
                        [bbox[2], bbox[1]],
                        [bbox[2], bbox[3]],
                        [bbox[0], bbox[3]],
                    ]
                    for bbox in bboxes[idx]
                ]
            slice_map.append(len(slices))
            all_slices.extend(slices)
            all_polygons.extend(polys)
            all_task_names.extend([task_names[idx]] * len(slices))

            if input_text is None:
                all_text.extend([None] * len(slices))
            else:
                all_text.extend(input_text[idx])

        assert (
            len(all_slices)
            == sum(slice_map)
            == len(all_polygons)
            == len(all_text)
            == len(all_task_names)
        ), (
            f"Mismatch in lengths: {len(all_slices)}, {sum(slice_map)}, {len(all_polygons)}, {len(all_text)}, {len(all_task_names)}"
        )

        return {
            "slices": all_slices,
            "slice_map": slice_map,
            "polygons": all_polygons,
            "input_text": all_text,
            "task_names": all_task_names,
            "res_scales": [(1, 1) for _ in range(len(all_slices))],
        }

    def get_bboxes_text(
        self,
        flat: dict,
        predicted_tokens: list,
        scores: list,
        predicted_polygons: list,
        drop_repeated_text: bool = False,
    ) -> list:
        char_predictions = []
        needs_boxes = [
            self.tasks[task_name]["needs_bboxes"] for task_name in flat["task_names"]
        ]

        for slice_idx, (
            slice_image,
            image_tokens,
            image_polygons,
            image_scores,
            needs_box,
        ) in enumerate(
            zip(
                flat["slices"],
                predicted_tokens,
                predicted_polygons,
                scores,
                needs_boxes,
            )
        ):
            blank_bbox = [[0, 0], [0, 1], [1, 1], [1, 0]]
            if self.processor.no_output_token in image_tokens:
                char_predictions.append(None)
                continue

            # If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely
            if drop_repeated_text and detect_repeat_token(image_tokens):
                char_predictions.append(
                    [
                        TextChar(
                            text="",
                            polygon=blank_bbox,
                            confidence=0,
                            bbox_valid=False,
                        )
                    ]
                )
                continue

            image_polygons = image_polygons[: len(image_tokens)].cpu().numpy().tolist()

            detokenize_sequences = []
            detokenize_sequence = []
            past_char_qwen_token = False

            def _add_detokenize_sequence(
                special_token: bool,
                past_special_token: bool,
                force: bool = False,
            ):
                nonlocal detokenize_sequence, detokenize_sequences

                if (
                    special_token
                    or past_special_token
                    or force
                ) and detokenize_sequence:
                    chars = [dt[0] for dt in detokenize_sequence]
                    scores = [dt[1] for dt in detokenize_sequence]
                    bboxes = [dt[2] for dt in detokenize_sequence]

                    if past_special_token:
                        detokenize_sequences.append((chars, scores, None, "special"))
                    else:
                        detokenize_sequences.append((chars, scores, bboxes, "ocr"))

                    detokenize_sequence = []

            # Split up into sequences to detokenize separately
            past_special_token = False
            for bbox, char_id, score in zip(image_polygons, image_tokens, image_scores):
                if char_id in [
                    self.processor.eos_token_id,
                    self.processor.pad_token_id,
                ]:
                    break

                special_token = (
                    char_id >= self.processor.ocr_tokenizer.ocr_tokenizer.SPECIAL_BASE
                )
                _add_detokenize_sequence(
                    special_token, past_special_token
                )
                detokenize_sequence.append((char_id, score, bbox))
                past_special_token = special_token

            _add_detokenize_sequence(
                False, past_special_token, force=True
            )

            img_chars = []
            for sequence in detokenize_sequences:
                token_ids, seq_score, bboxes, token_type = sequence
                if token_type == "ocr":
                    text = self.processor.ocr_tokenizer.decode(
                        token_ids, task=TaskNames.ocr_with_boxes
                    )
                    bboxes = clean_close_polygons(
                        bboxes
                    )  # clean out bboxes that are close, like what happens with multiple utf-16 tokens per char
                    bbox_idx = 0
                    for text_idx, text_line in enumerate(text):
                        img_chars.append(
                            TextChar(
                                text=text_line,
                                polygon=bboxes[bbox_idx],
                                confidence=seq_score[bbox_idx],
                                bbox_valid=True,
                            )
                        )

                        # Ensure we don't exceed the bbox count
                        # Use the last bbox for the rest of the text
                        if bbox_idx < len(bboxes) - 1:
                            bbox_idx += 1
                elif token_type == "special":
                    text = self.processor.ocr_tokenizer.decode(
                        token_ids, task="ocr_without_boxes"
                    )
                    if text in [NOMATH_TOKEN] or re.match(r"<SCRIPT-\w+>", text):
                        continue

                    img_chars.append(
                        TextChar(
                            text=text,
                            polygon=blank_bbox,
                            confidence=seq_score[0],
                            bbox_valid=False,
                        )
                    )
                else:
                    text = self.processor.ocr_tokenizer.decode(
                        token_ids, task=TaskNames.block_without_boxes
                    )
                    img_chars.append(
                        TextChar(
                            text=text,
                            polygon=blank_bbox,
                            confidence=seq_score[0],
                            bbox_valid=False,
                        )
                    )

            char_predictions.append(img_chars)

        return char_predictions

    def __call__(
        self,
        images: List[Image.Image],
        task_names: List[str] | None = None,
        det_predictor: DetectionPredictor | None = None,
        detection_batch_size: int | None = None,
        recognition_batch_size: int | None = None,
        highres_images: List[Image.Image] | None = None,
        bboxes: List[List[List[int]]] | None = None,
        polygons: List[List[List[List[int]]]] | None = None,
        input_text: List[List[str | None]] | None = None,
        sort_lines: bool = False,
        math_mode: bool = True,
        return_words: bool = False,
        drop_repeated_text: bool = False,
        max_sliding_window: int | None = None,
        max_tokens: int | None = None,
        filter_tag_list: List[str] | None = None
    ) -> List[OCRResult]:
        if task_names is None:
            task_names = [TaskNames.ocr_with_boxes] * len(images)
        if recognition_batch_size is None:
            recognition_batch_size = self.get_batch_size()

        assert len(images) == len(task_names), (
            "You need to pass in one task name for each image"
        )

        images = convert_if_not_rgb(images)
        if highres_images is not None:
            assert len(images) == len(highres_images), (
                "You need to pass in one highres image for each image"
            )

        highres_images = (
            convert_if_not_rgb(highres_images)
            if highres_images is not None
            else [None] * len(images)
        )

        if bboxes is None and polygons is None:
            assert det_predictor is not None, (
                "You need to pass in a detection predictor if you don't provide bboxes or polygons"
            )

            # Detect then slice
            flat = self.detect_and_slice_bboxes(
                images,
                task_names,
                det_predictor,
                detection_batch_size=detection_batch_size,
                highres_images=highres_images,
            )
        else:
            if bboxes is not None:
                assert len(images) == len(bboxes), (
                    "You need to pass in one list of bboxes for each image"
                )
            if polygons is not None:
                assert len(images) == len(polygons), (
                    "You need to pass in one list of polygons for each image"
                )

            flat = self.slice_bboxes(
                images,
                bboxes=bboxes,
                polygons=polygons,
                input_text=input_text,
                task_names=task_names,
            )

        # No images passed, or no boxes passed, or no text detected in the images
        if len(flat["slices"]) == 0:
            return [
                OCRResult(
                    text_lines=[], image_bbox=[0, 0, im.size[0], im.size[1]]
                )
                for im in images
            ]

        # Sort by image sizes. Negative so that longer images come first, fits in with continuous batching better
        sorted_pairs = sorted(
            enumerate(flat["slices"]),
            key=lambda x: -(x[1].shape[0] * x[1].shape[1])  # height * width
        )
        indices, sorted_slices = zip(*sorted_pairs)

        # Reorder input_text and task_names based on the new order
        flat["slices"] = list(sorted_slices)
        flat["input_text"] = [flat["input_text"][i] for i in indices]
        flat["task_names"] = [flat["task_names"][i] for i in indices]

        # Make predictions
        predicted_tokens, batch_bboxes, scores, _ = self.foundation_predictor.prediction_loop(
            images=flat["slices"],
            input_texts=flat["input_text"],
            task_names=flat["task_names"],
            batch_size=recognition_batch_size,
            math_mode=math_mode,
            drop_repeated_tokens=True,
            max_lookahead_tokens=self.foundation_predictor.model.config.multi_output_distance,
            max_sliding_window=max_sliding_window,
            max_tokens=max_tokens,
            tqdm_desc="Recognizing Text"
        )

        # Get text and bboxes in structured form
        bbox_size = self.bbox_size
        image_sizes = [img.shape for img in flat["slices"]]
        predicted_polygons = prediction_to_polygon_batch(
            batch_bboxes, image_sizes, bbox_size, bbox_size // 2
        )
        char_predictions = self.get_bboxes_text(
            flat,
            predicted_tokens,
            scores,
            predicted_polygons,
            drop_repeated_text=drop_repeated_text,
        )

        char_predictions = sorted(zip(indices, char_predictions), key=lambda x: x[0])
        char_predictions = [pred for _, pred in char_predictions]

        predictions_by_image = []
        slice_start = 0
        for idx, image in enumerate(images):
            slice_end = slice_start + flat["slice_map"][idx]
            image_lines = char_predictions[slice_start:slice_end]
            polygons = flat["polygons"][slice_start:slice_end]
            res_scales = flat["res_scales"][slice_start:slice_end]
            slice_start = slice_end

            lines = []
            for text_line, polygon, res_scale in zip(image_lines, polygons, res_scales):
                # Special case when input text is good
                if not text_line:
                    lines.append(
                        TextLine(
                            text="",
                            polygon=polygon,
                            chars=[],
                            confidence=1,
                            original_text_good=True,
                        )
                    )
                else:
                    confidence = (
                        float(np.mean([char.confidence for char in text_line]))
                        if len(text_line) > 0
                        else 0
                    )
                    poly_box = PolygonBox(polygon=polygon)
                    for char in text_line:
                        char.rescale(
                            res_scale, (1, 1)
                        )  # Rescale from highres if needed
                        char.shift(
                            poly_box.bbox[0], poly_box.bbox[1]
                        )  # Ensure character boxes match line boxes (relative to page)
                        char.clamp(poly_box.bbox)

                    text_line = fix_unbalanced_tags(
                        text_line, self.processor.ocr_tokenizer.special_tokens
                    )
                    text_line = filter_blacklist_tags(text_line, filter_tag_list)
                    text = "".join([char.text for char in text_line])
                    text = unwrap_math(text)
                    text = clean_math_tags(text)
                    lines.append(
                        TextLine(
                            text=text,
                            polygon=polygon,
                            chars=text_line,
                            confidence=confidence,
                            words=words_from_chars(text_line, poly_box)
                            if return_words
                            else [],
                        )
                    )

            if sort_lines:
                lines = sort_text_lines(lines)
            predictions_by_image.append(
                OCRResult(
                    text_lines=lines, image_bbox=[0, 0, image.size[0], image.size[1]]
                )
            )

        return predictions_by_image

```
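
A hedged end-to-end sketch of how `RecognitionPredictor` is wired together with the foundation and detection predictors; the constructor arguments of `FoundationPredictor` and `DetectionPredictor` (checkpoint, device, dtype) are not shown on this page, so the defaults are assumed to work here.

```python
from PIL import Image

from surya.detection import DetectionPredictor
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor

# Assumption: the default constructors load the pretrained checkpoints onto
# the default device; pass explicit arguments if your setup differs.
foundation_predictor = FoundationPredictor()
recognition_predictor = RecognitionPredictor(foundation_predictor)
detection_predictor = DetectionPredictor()

image = Image.open("page.png")  # any PIL image; RGB conversion happens internally
predictions = recognition_predictor(
    [image],
    det_predictor=detection_predictor,  # detect lines first, then recognize each slice
    return_words=True,  # also populate word-level boxes on each line
)

for line in predictions[0].text_lines:
    print(round(line.confidence, 3), line.text)
```

If `bboxes` or `polygons` are passed instead, the detection predictor can be omitted and the given regions are sliced and recognized directly.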

--------------------------------------------------------------------------------
/surya/common/surya/decoder/__init__.py:
--------------------------------------------------------------------------------

```python
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import (
    Cache,
)
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
from transformers.utils import (
    logging,
)

from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.surya.decoder.config import SuryaDecoderConfig


logger = logging.get_logger(__name__)


class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query.dtype
    )
    attn_weights = nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class Qwen2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=True
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
        num_valid_tokens: Optional[List[int]] = None,
        text_lengths: Optional[List[int]] = None,
        prefill: bool = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
                "cache_idxs": cache_idxs,
                "num_valid_tokens": num_valid_tokens,
                "prefill": prefill,
                "text_lengths": text_lengths,
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get(
                "output_attentions", False
            ):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            elif self.config._attn_implementation == "flash_attention_2":
                # Needed for CPU -> GPU
                from surya.common.surya.flash_attn_utils import (
                    flash_attn_decode,
                    flash_attn_prefill,
                )

                if prefill:
                    attention_interface = flash_attn_prefill
                else:
                    attention_interface = flash_attn_decode
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[
                    self.config._attn_implementation
                ]

        """
        IMPORTANT:
        We sometimes use a custom sliding window impl. during training

        We force this to None to ensure that the HF attention integrations do not
        perform any special handling - FA2 in particular will ignore the 4D mask, and use this instead
        to infer the final mask

        SDPA ignores this completely, and is fully dependent on the 4D mask - (https://github.com/huggingface/transformers/blob/b9faf2f93085e3cf2c65184a69d1d9e502f95786/src/transformers/integrations/sdpa_attention.py#L23)
        """
        sliding_window = None

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=sliding_window,  # main diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: SuryaDecoderConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
        num_valid_tokens: Optional[List[int]] = None,
        text_lengths: Optional[List[int]] = None,
        prefill: bool = False,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            cache_idxs=cache_idxs,
            num_valid_tokens=num_valid_tokens,
            text_lengths=text_lengths,
            prefill=prefill,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, config: SuryaDecoderConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class Qwen2PreTrainedModel(SuryaPreTrainedModel):
    config_class = SuryaDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class SuryaDecoderModel(Qwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`].
    This variant has been modified to remove the embedding layer completely; it only supports `inputs_embeds` as input.

    Args:
        config: Qwen2Config
    """

    def __init__(self, config: SuryaDecoderConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.layers = nn.ModuleList(
            [
                Qwen2DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_idxs: Optional[List[int]] = None,
        num_valid_tokens: Optional[List[int]] = None,
        text_lengths: Optional[List[int]] = None,
        prefill: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            raise ValueError("You must specify inputs_embeds")

        if cache_position is None:
            raise ValueError("You must specify cache_position")

        if position_ids is None:
            raise ValueError("You must specify position_ids")

        hidden_states = inputs_embeds
        causal_mask = (
            attention_mask  # We make the 4D mask in the combined model when needed
        )

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                cache_idxs=cache_idxs,
                num_valid_tokens=num_valid_tokens,
                prefill=prefill,
                text_lengths=text_lengths,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
        return output if return_dict else output.to_tuple()

```
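
The core RoPE block in `Qwen2RotaryEmbedding.forward` can be reproduced standalone: the inverse frequencies are outer-multiplied with the position ids, duplicated, and turned into cos/sin tables. A minimal sketch follows; `dim` and `base` are illustrative values, not the actual decoder config.

```python
import torch

dim, base = 8, 10000.0  # illustrative values, not SuryaDecoderConfig defaults
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))    # (dim/2,)

position_ids = torch.arange(6, dtype=torch.float32)[None, :]          # (1, seq)
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)                               # (1, seq, dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)  # torch.Size([1, 6, 8])
```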

--------------------------------------------------------------------------------
/surya/ocr_error/tokenizer.py:
--------------------------------------------------------------------------------

```python
import collections
import os
import json
import unicodedata
from typing import List, Optional, Tuple

from tokenizers import normalizers

from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from surya.common.s3 import S3DownloaderMixin

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Copied from transformers.models.bert.tokenization_bert.load_vocab
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
    r"""
    Construct a DistilBERT tokenizer. Based on WordPiece.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))

        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

    @property
    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    @property
    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
    def vocab_size(self):
        return len(self.vocab)

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
    def _tokenize(self, text, split_special_tokens=False):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(
                text, never_split=self.all_special_tokens if not split_special_tokens else None
            ):
                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # logger.warning(
                    #     f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                    #     " Please check that the vocabulary is not corrupted!"
                    # )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordpieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*):
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # prevents treating the same character with different unicode codepoints as different characters
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
```
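
As a quick illustration of the greedy longest-match-first WordPiece algorithm above, here is a hedged sketch with a toy vocabulary (the vocabulary entries are invented for the example; the module path comes from the file header).

```python
from surya.ocr_error.tokenizer import WordpieceTokenizer

# Toy vocabulary, invented for the example.
toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
wp = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")

print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(wp.tokenize("xyzzy"))      # ['[UNK]'] - no matching word pieces
```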

--------------------------------------------------------------------------------
/surya/detection/model/encoderdecoder.py:
--------------------------------------------------------------------------------

```python
"""
This is an implementation of EfficientViT, with some modifications (decode head, etc.).

Original paper at https://arxiv.org/abs/2205.14756

Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py
Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit
"""

from __future__ import annotations

from typing import Optional, Union, Tuple, List, Any
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_outputs import SemanticSegmenterOutput

from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.s3 import S3DownloaderMixin
from surya.detection.model.config import EfficientViTConfig


def val2list(x: Union[List, Tuple, Any], repeat_time=1):
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def val2tuple(x: Union[List, Tuple, Any], min_len: int = 1, idx_repeat: int = -1):
    # repeat elements if necessary
    x = val2list(x)
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)


def get_same_padding(
    kernel_size: Union[int, Tuple[int, ...]],
) -> Union[int, Tuple[int, ...]]:
    if isinstance(kernel_size, tuple):
        return tuple([get_same_padding(ks) for ks in kernel_size])
    else:
        assert kernel_size % 2 > 0, "kernel size should be an odd number"
        return kernel_size // 2


def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


class ConvNormAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        dropout=0.0,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.ReLU,
    ):
        super(ConvNormAct, self).__init__()
        self.dropout = nn.Dropout(dropout, inplace=False)
        padding = get_padding(kernel_size, stride, dilation)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )
        self.norm = (
            norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
        )
        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


class DSConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super(DSConv, self).__init__()
        use_bias = val2tuple(use_bias, 2)
        norm_layer = val2tuple(norm_layer, 2)
        act_layer = val2tuple(act_layer, 2)

        self.depth_conv = ConvNormAct(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            groups=in_channels,
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
            bias=use_bias[0],
        )
        self.point_conv = ConvNormAct(
            in_channels,
            out_channels,
            1,
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
            bias=use_bias[1],
        )

    def forward(self, x):
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super(ConvBlock, self).__init__()
        use_bias = val2tuple(use_bias, 2)
        norm_layer = val2tuple(norm_layer, 2)
        act_layer = val2tuple(act_layer, 2)
        mid_channels = mid_channels or round(in_channels * expand_ratio)

        self.conv1 = ConvNormAct(
            in_channels,
            mid_channels,
            kernel_size,
            stride,
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
            bias=use_bias[0],
        )
        self.conv2 = ConvNormAct(
            mid_channels,
            out_channels,
            kernel_size,
            1,
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
            bias=use_bias[1],
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class MBConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, nn.ReLU6, None),
    ):
        super(MBConv, self).__init__()
        use_bias = val2tuple(use_bias, 3)
        norm_layer = val2tuple(norm_layer, 3)
        act_layer = val2tuple(act_layer, 3)
        mid_channels = mid_channels or round(in_channels * expand_ratio)

        self.inverted_conv = ConvNormAct(
            in_channels,
            mid_channels,
            1,
            stride=1,
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
            bias=use_bias[0],
        )
        self.depth_conv = ConvNormAct(
            mid_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            groups=mid_channels,
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
            bias=use_bias[1],
        )
        self.point_conv = ConvNormAct(
            mid_channels,
            out_channels,
            1,
            norm_layer=norm_layer[2],
            act_layer=act_layer[2],
            bias=use_bias[2],
        )

    def forward(self, x):
        x = self.inverted_conv(x)
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x


class FusedMBConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super(FusedMBConv, self).__init__()
        use_bias = val2tuple(use_bias, 2)
        norm_layer = val2tuple(norm_layer, 2)
        act_layer = val2tuple(act_layer, 2)
        mid_channels = mid_channels or round(in_channels * expand_ratio)

        self.spatial_conv = ConvNormAct(
            in_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            groups=groups,
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
            bias=use_bias[0],
        )
        self.point_conv = ConvNormAct(
            mid_channels,
            out_channels,
            1,
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
            bias=use_bias[1],
        )

    def forward(self, x):
        x = self.spatial_conv(x)
        x = self.point_conv(x)
        return x


class LiteMLA(nn.Module):
    """Lightweight multi-scale linear attention"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: Union[int, None] = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm_layer=(None, nn.BatchNorm2d),
        act_layer=(None, None),
        kernel_func=nn.ReLU,
        scales=(5,),
        eps=1e-5,
    ):
        super(LiteMLA, self).__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        use_bias = val2tuple(use_bias, 2)
        norm_layer = val2tuple(norm_layer, 2)
        act_layer = val2tuple(act_layer, 2)

        self.dim = dim
        self.qkv = ConvNormAct(
            in_channels,
            3 * total_dim,
            1,
            bias=use_bias[0],
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
        )
        self.aggreg = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        scale,
                        padding=get_same_padding(scale),
                        groups=3 * total_dim,
                        bias=use_bias[0],
                    ),
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        1,
                        groups=3 * heads,
                        bias=use_bias[0],
                    ),
                )
                for scale in scales
            ]
        )
        self.kernel_func = kernel_func(inplace=False)

        self.proj = ConvNormAct(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            bias=use_bias[1],
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
        )

    def _attn(self, q, k, v):
        dtype = v.dtype
        q, k, v = q.float(), k.float(), v.float()
        kv = k.transpose(-1, -2) @ v
        out = q @ kv
        out = out[..., :-1] / (out[..., -1:] + self.eps)
        return out.to(dtype)

    def forward(self, x):
        # Shape is B, C, H, W
        B, _, H, W = x.shape

        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        multi_scale_qkv = [qkv]
        for op in self.aggreg:
            multi_scale_qkv.append(op(qkv))
        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
        multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(
            -1, -2
        )
        # Shape for each of q, k, v is B, num_heads, H*W, head_dim
        q, k, v = multi_scale_qkv.chunk(3, dim=-1)

        # lightweight global attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        v = F.pad(v, (0, 1), mode="constant", value=1.0)

        out = self._attn(q, k, v)

        # final projection
        out = out.transpose(-1, -2).reshape(B, -1, H, W)
        out = self.proj(out)
        return out


class EfficientVitBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        heads_ratio=1.0,
        head_dim=32,
        expand_ratio=4,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super(EfficientVitBlock, self).__init__()
        self.context_module = ResidualBlock(
            LiteMLA(
                in_channels=in_channels,
                out_channels=in_channels,
                heads_ratio=heads_ratio,
                dim=head_dim,
                norm_layer=(None, norm_layer),
            ),
            nn.Identity(),
        )
        self.local_module = ResidualBlock(
            MBConv(
                in_channels=in_channels,
                out_channels=in_channels,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False),
                norm_layer=(None, None, norm_layer),
                act_layer=(act_layer, act_layer, None),
            ),
            nn.Identity(),
        )

    def forward(self, x):
        x = self.context_module(x)
        x = self.local_module(x)
        return x


class ResidualBlock(nn.Module):
    def __init__(
        self,
        main: Optional[nn.Module],
        shortcut: Optional[nn.Module] = None,
        pre_norm: Optional[nn.Module] = None,
    ):
        super(ResidualBlock, self).__init__()
        self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
        self.main = main
        self.shortcut = shortcut

    def forward(self, x):
        res = self.main(self.pre_norm(x))
        if self.shortcut is not None:
            res = res + self.shortcut(x)
        return res


def build_local_block(
    in_channels: int,
    out_channels: int,
    stride: int,
    kernel_size: int,
    expand_ratio: float,
    norm_layer: str,
    act_layer: str,
    fewer_norm: bool = False,
    block_type: str = "default",
):
    assert block_type in ["default", "large", "fused"]
    if expand_ratio == 1:
        if block_type == "default":
            block = DSConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                kernel_size=kernel_size,
                use_bias=(True, False) if fewer_norm else False,
                norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
                act_layer=(act_layer, None),
            )
        else:
            block = ConvBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                kernel_size=kernel_size,
                use_bias=(True, False) if fewer_norm else False,
                norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
                act_layer=(act_layer, None),
            )
    else:
        if block_type == "default":
            block = MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                kernel_size=kernel_size,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
                act_layer=(act_layer, act_layer, None),
            )
        else:
            block = FusedMBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                kernel_size=kernel_size,
                expand_ratio=expand_ratio,
                use_bias=(True, False) if fewer_norm else False,
                norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
                act_layer=(act_layer, None),
            )
    return block


class Stem(nn.Sequential):
    def __init__(
        self,
        in_chs,
        out_chs,
        depth,
        stride,
        norm_layer,
        act_layer,
        block_type="default",
    ):
        super().__init__()
        self.stride = stride

        self.add_module(
            "in_conv",
            ConvNormAct(
                in_chs,
                out_chs,
                kernel_size=stride + 1,
                stride=stride,
                norm_layer=norm_layer,
                act_layer=act_layer,
            ),
        )
        stem_block = 0
        for _ in range(depth):
            self.add_module(
                f"res{stem_block}",
                ResidualBlock(
                    build_local_block(
                        in_channels=out_chs,
                        out_channels=out_chs,
                        stride=1,
                        kernel_size=3,
                        expand_ratio=1,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                        block_type=block_type,
                    ),
                    nn.Identity(),
                ),
            )
            stem_block += 1


class EfficientVitLargeStage(nn.Module):
    def __init__(
        self,
        in_chs,
        out_chs,
        depth,
        stride,
        norm_layer,
        act_layer,
        head_dim,
        vit_stage=False,
        fewer_norm=False,
    ):
        super(EfficientVitLargeStage, self).__init__()
        blocks = [
            ResidualBlock(
                build_local_block(
                    in_channels=in_chs,
                    out_channels=out_chs,
                    stride=stride,
                    kernel_size=stride + 1,
                    expand_ratio=24 if vit_stage else 16,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    fewer_norm=vit_stage or fewer_norm,
                    block_type="default" if fewer_norm else "fused",
                ),
                None,
            )
        ]
        in_chs = out_chs

        if vit_stage:
            # for stage 4
            for _ in range(depth):
                blocks.append(
                    EfficientVitBlock(
                        in_channels=in_chs,
                        head_dim=head_dim,
                        expand_ratio=6,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                    )
                )
        else:
            # for stage 1, 2, 3
            for i in range(depth):
                blocks.append(
                    ResidualBlock(
                        build_local_block(
                            in_channels=in_chs,
                            out_channels=out_chs,
                            stride=1,
                            kernel_size=3,
                            expand_ratio=4,
                            norm_layer=norm_layer,
                            act_layer=act_layer,
                            fewer_norm=fewer_norm,
                            block_type="default" if fewer_norm else "fused",
                        ),
                        nn.Identity(),
                    )
                )

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)


class EfficientVitLarge(nn.Module):
    def __init__(
        self,
        config: EfficientViTConfig,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super(EfficientVitLarge, self).__init__()
        self.grad_checkpointing = False
        self.num_classes = config.num_classes
        self.norm_eps = config.layer_norm_eps
        norm_layer = partial(norm_layer, eps=self.norm_eps)

        # input stem
        self.stem = Stem(
            config.num_channels,
            config.widths[0],
            config.depths[0],
            config.strides[0],
            norm_layer,
            act_layer,
            block_type="large",
        )
        stride = config.strides[0]

        # stages
        self.feature_info = []
        self.stages = nn.Sequential()
        in_channels = config.widths[0]
        for i, (w, d, s) in enumerate(
            zip(config.widths[1:], config.depths[1:], config.strides[1:])
        ):
            self.stages.append(
                EfficientVitLargeStage(
                    in_channels,
                    w,
                    depth=d,
                    stride=s,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    head_dim=config.head_dim,
                    vit_stage=i >= 3,
                    fewer_norm=i >= 2,
                )
            )
            stride *= s
            in_channels = w
            self.feature_info += [
                dict(num_chs=in_channels, reduction=stride, module=f"stages.{i}")
            ]

        self.num_features = in_channels

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x):
        x = self.stem(x)
        encoder_hidden_states = []
        for i, module in enumerate(self.stages):
            x = module(x)
            encoder_hidden_states.append(x)

        return encoder_hidden_states


class EfficientViTPreTrainedModel(SuryaPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientViTConfig
    base_model_prefix = "efficientvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class DecodeMLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_states: torch.Tensor):
        # Input is B, C, H, W
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        # Output is B, HW, C
        hidden_states = self.proj(hidden_states)
        return hidden_states


class DecodeHead(EfficientViTPreTrainedModel):
    def __init__(self, config: EfficientViTConfig):
        super().__init__(config)

        # linear layers which unify the channel dimension of each encoder block to config.decoder_layer_hidden_size
        mlps = []
        for width in config.widths[1:]:
            mlp = DecodeMLP(
                input_dim=width, output_dim=config.decoder_layer_hidden_size
            )
            mlps.append(mlp)
        self.linear_c = nn.ModuleList(mlps)

        # the following 3 layers implement the ConvModule of the original implementation
        self.linear_fuse = nn.Conv2d(
            in_channels=config.decoder_layer_hidden_size * config.num_stages,
            out_channels=config.decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
        self.activation = nn.ReLU()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Conv2d(
            config.decoder_hidden_size, config.num_labels, kernel_size=1
        )

        self.config = config

    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
        batch_size = encoder_hidden_states[-1].shape[0]

        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)  # Output is B, HW, C
            # Permute to B, C, HW
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(
                batch_size, -1, height, width
            )
            # upsample
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state,
                size=encoder_hidden_states[0].size()[2:],
                mode="bilinear",
                align_corners=False,
            )
            all_hidden_states += (encoder_hidden_state,)

        hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)

        # logits are of shape (batch_size, num_labels, height/4, width/4)
        logits = self.classifier(hidden_states)

        return logits


class EfficientViTForSemanticSegmentation(
    S3DownloaderMixin, EfficientViTPreTrainedModel
):
    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.vit = EfficientVitLarge(config)
        self.decode_head = DecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self, pixel_values: torch.FloatTensor
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        # Pixel values should be B,C,H,W
        encoder_hidden_states = self.vit(
            pixel_values,
        )

        logits = self.decode_head(encoder_hidden_states)

        # Apply sigmoid to get 0-1 output
        logits = torch.special.expit(logits)

        return SemanticSegmenterOutput(
            loss=None, logits=logits, hidden_states=encoder_hidden_states
        )


class EfficientViTForSemanticLayoutSegmentation(EfficientViTPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.vit = EfficientVitLarge(config)
        self.decode_head = DecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self, pixel_values: torch.FloatTensor
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        # Pixel values should be B,C,H,W
        encoder_hidden_states = self.vit(
            pixel_values,
        )

        logits = self.decode_head(encoder_hidden_states)

        # Apply sigmoid to get 0-1 output
        logits = torch.special.expit(logits)

        return SemanticSegmenterOutput(
            loss=None, logits=logits, hidden_states=encoder_hidden_states
        )

```
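
A standalone sketch of the linear-attention trick used by `LiteMLA._attn`: `v` is padded with a ones channel so the last output channel accumulates the normalizer, which is then divided out, keeping the cost linear in the number of tokens. Shapes below are toy assumptions.

```python
import torch
import torch.nn.functional as F

B, heads, N, d = 1, 2, 16, 8                   # assumed toy shapes
q = torch.relu(torch.randn(B, heads, N, d))    # ReLU kernel, as in LiteMLA
k = torch.relu(torch.randn(B, heads, N, d))
v = torch.randn(B, heads, N, d)

v_pad = F.pad(v, (0, 1), value=1.0)            # append a ones channel
kv = k.transpose(-1, -2) @ v_pad               # (d, d+1): linear in N
out = q @ kv
out = out[..., :-1] / (out[..., -1:] + 1e-5)   # divide out the accumulated normalizer
print(out.shape)                               # torch.Size([1, 2, 16, 8])
```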

--------------------------------------------------------------------------------
/surya/common/surya/processor/tokenizer.py:
--------------------------------------------------------------------------------

```python
import html
import re
from typing import List, Union, Dict, Optional, Tuple, Iterable
import numpy as np
import torch
from tokenizers import AddedToken
import json
import os
from transformers import PreTrainedTokenizer, Qwen2Tokenizer as Qwen2OriginalTokenizer


from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.schema import TASK_NAMES, TaskNames
from surya.logging import get_logger
from surya.settings import settings

logger = get_logger()


def create_token_regex(tokens):
    escaped_tokens = [re.escape(token) for token in tokens]
    escaped_tokens.sort(key=len, reverse=True)
    pattern = r"^(" + "|".join(escaped_tokens) + r")"
    regex = re.compile(pattern)
    return regex
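
# Example (illustrative): create_token_regex(["<b>", "</b>", "<br>"]) compiles
# r"^(</b>|<br>|<b>)" - longer tags are tried first and the pattern is anchored
# at the string start, so _tokenize below can consume one leading tag per match.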


class InnerOCRTokenizer:
    def __init__(
        self,
        special_tokens: Dict[str, list] | None = None,
        qwen_tokenizer: Qwen2OriginalTokenizer | None = None,
        **kwargs,
    ):
        self.qwen_tokenizer = qwen_tokenizer
        self.qwen_token_offset = len(qwen_tokenizer)

        all_special_tokens = special_tokens.get("all", [])
        self.SPECIAL_TOKEN_MAPPING = {}

        idx = 0
        for tag in all_special_tokens:
            if tag in self.SPECIAL_TOKEN_MAPPING:
                continue
            self.SPECIAL_TOKEN_MAPPING[tag] = (
                idx + self.qwen_token_offset
            )  # Assign token ID
            idx += 1

        self.REVERSE_SPECIAL_TOKEN_MAPPING = {
            v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items()
        }
        self.SPECIAL_TOKEN_OFFSET = idx
        self.FORMAT_TAG_PATTERN = create_token_regex(special_tokens["formatting"])
        self.MATH_TAG_PATTERN = create_token_regex(special_tokens["math_external"])
        self.LAYOUT_TAG_PATTERN = create_token_regex(special_tokens["layout"])
        self.TABLE_STRUCTURE_TAG_PATTERN = create_token_regex(
            special_tokens["table_structure"]
        )
        self.SYSTEM_TAG_PATTERN = create_token_regex(special_tokens.get("system", []))
        if not special_tokens.get("system", []):
            logger.warning("Warning: No system tokens found in special_tokens")

        self.MATH_TAG_START = "<math"
        self.MATH_END_TAG = "</math>"

        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        return (
            65536 + self.SPECIAL_TOKEN_OFFSET
        )  # 65536 UTF-16 code units (0..65535) plus the special token offset

    def _tokenize(self, text: str) -> List[int]:
        tokens = []
        in_math = False
        text = html.unescape(text)  # Unescape html entities like &lt; in equations
        while text:
            # Look for EOS, PAD, etc. tokens
            match = self.SYSTEM_TAG_PATTERN.search(text)
            if match:
                tag = match.group(1)
                tokens.append(
                    self.SPECIAL_TOKEN_MAPPING[tag]
                )  # These are already offset
                text = text[match.end() :]
                continue

            # Look for layout tokens
            match = self.LAYOUT_TAG_PATTERN.search(text)
            if match:
                tag = match.group(1)
                tokens.append(
                    self.SPECIAL_TOKEN_MAPPING[tag]
                )  # Layout tokens are already offset
                text = text[match.end() :]
                continue

            match = self.TABLE_STRUCTURE_TAG_PATTERN.search(text)
            if match:
                tag = match.group(1)
                tokens.append(self.SPECIAL_TOKEN_MAPPING[tag])
                text = text[match.end() :]
                continue

            # Check for math tags
            match = self.MATH_TAG_PATTERN.search(text)
            if match:
                # We found a tag
                tag = match.group(1)
                if tag.startswith(self.MATH_TAG_START):
                    in_math = True
                elif tag == self.MATH_END_TAG:
                    in_math = False
                tokens.append(
                    self.SPECIAL_TOKEN_MAPPING[tag]  # Special tokens are already offset
                )  # Use special token ID
                text = text[match.end() :]
                continue

            # Tokenize math content with qwen2 tokenizer
            if in_math:
                # If we're in a math block, tokenize everything up to the closing math tag
                # (or the rest of the text if the closing tag is missing) with the qwen2 tokenizer
                math_end_position = text.find(self.MATH_END_TAG)
                if math_end_position == -1:
                    math_end_position = len(text)
                math_str = text[:math_end_position]  # Gets the math content
                tokens += self.qwen_tokenizer(math_str)["input_ids"]
                text = text[math_end_position:]
                continue

            # Check for formatting tags
            match = self.FORMAT_TAG_PATTERN.search(text)
            if match:
                # We found a tag
                tag = match.group(1)
                tokens.append(
                    self.SPECIAL_TOKEN_MAPPING[tag]  # Special tokens are already offset
                )  # Use special token ID
                text = text[match.end() :]
                continue

            # General case, utf-16 tokenization
            utf_16_tokens = self.text_to_utf16_numbers(text[0])
            tokens += [
                t + self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset
                for t in utf_16_tokens
            ]
            text = text[1:]

        return tokens

    def text_to_utf16_numbers(self, text: str):
        """Converts text to UTF-16 encoded numbers."""
        utf16_bytes = text.encode(
            "utf-16le"
        )  # Little-endian to simplify byte order handling
        numbers = []

        for i in range(0, len(utf16_bytes), 2):
            # Combine two adjacent bytes into a single number
            number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8)
            numbers.append(number)

        return numbers

    def utf16_numbers_to_text(self, numbers):
        """Converts UTF-16 numbers back to text."""
        byte_array = bytearray()
        for number in numbers:
            byte_array.append(number & 0xFF)  # Lower byte
            byte_array.append((number >> 8) & 0xFF)  # Upper byte

        try:
            text = byte_array.decode("utf-16le", errors="ignore")
        except Exception as e:
            logger.warning(f"Error decoding utf16: {e}")
            text = ""

        return text

    def __call__(
        self, texts: Union[str, List[str]], **kwargs
    ) -> Dict[str, List[List[int]]]:
        """Tokenizes text and returns input IDs."""
        tokenized = []

        if isinstance(texts, str):
            texts = [texts]

        for text in texts:
            tokens = self._tokenize(text)
            tokenized.append(tokens)

        return {"input_ids": tokenized}

    def decode(self, token_ids, **kwargs):
        """Decodes token IDs back to text."""
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        decoded_text = ""
        token_buffer = []
        decode_qwen = [False]

        def decode_buffer():
            nonlocal decoded_text, token_buffer, decode_qwen
            if token_buffer:
                if decode_qwen[0]:
                    decoded_text += self.qwen_tokenizer.decode(token_buffer)
                else:
                    token_buffer = [
                        t - self.SPECIAL_TOKEN_OFFSET - self.qwen_token_offset
                        for t in token_buffer
                    ]
                    decoded_text += self.utf16_numbers_to_text(token_buffer)

            token_buffer = []
            decode_qwen[0] = False

        for t in token_ids:
            if t < self.qwen_token_offset:
                # Qwen token IDs (tokenized math content)
                if token_buffer and token_buffer[-1] >= self.qwen_token_offset:
                    decode_buffer()
                token_buffer.append(t)
                decode_qwen[0] = True
            elif t >= self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset:
                if token_buffer and token_buffer[-1] < self.qwen_token_offset:
                    decode_buffer()
                token_buffer.append(t)  # We shift this down later on
                decode_qwen[0] = False
            elif t in self.REVERSE_SPECIAL_TOKEN_MAPPING:
                decode_buffer()
                decoded_text += self.REVERSE_SPECIAL_TOKEN_MAPPING[t]
                decode_qwen[0] = False
            else:
                raise ValueError(
                    f'Unexpected token value while decoding, got "{t}" in token_ids {token_ids}'
                )

        # Detokenize remaining tokens
        decode_buffer()

        return decoded_text


class Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer):
    pass

class GreedyMathUTF16Tokenizer(S3DownloaderMixin, PreTrainedTokenizer):
    """
    HuggingFace slow tokenizer implementing:
      - UTF-16 code units as the base [0..65535]
      - Math tokens as greedy-longest-match ids after UTF-16
      - Literal special tokens after math tokens
    Absolute ID layout:
      [0 .. 65535]                      : UTF-16 units
      [65536 .. 65536+M-1]              : math tokens
      [65536+M .. 65536+M+S-1]          : special tokens
    """

    vocab_files_names = {
        "vocab_file": "vocab_math.json",  # {"\\frac": 0, "\\alpha": 1, ...} raw contiguous ids 0..M-1
        "specials_file": "specials.json",  # [flat list for legacy]
        "specials_dict_file": "specials_dict.json",  # category dict (preferred)
    }
    model_input_names = ["input_ids", "attention_mask"]
    is_fast = False

    # ---------- helpers ----------
    @staticmethod
    def _to_utf16_units(s: str) -> List[int]:
        b = s.encode("utf-16le")
        return [int.from_bytes(b[i : i + 2], "little") for i in range(0, len(b), 2)]

    @staticmethod
    def _from_utf16_units(units: List[int]) -> str:
        b = bytearray()
        for u in units:
            b += int(u).to_bytes(2, "little")
        return b.decode("utf-16le", errors="ignore")

    class _TrieNode:
        __slots__ = ("child", "id", "leaf")

        def __init__(self):
            self.child: Dict[str, "GreedyMathUTF16Tokenizer._TrieNode"] = {}
            self.id: Optional[int] = None
            self.leaf: bool = False

    @classmethod
    def _build_trie(
        cls, token_to_id: Dict[str, int]
    ) -> "GreedyMathUTF16Tokenizer._TrieNode":
        root = cls._TrieNode()
        for tok, tid in token_to_id.items():
            node = root
            for ch in tok:
                node = node.child.setdefault(ch, cls._TrieNode())
            node.leaf = True
            node.id = tid
        return root

    @classmethod
    def _encode_math_greedy(
        cls,
        s: str,
        trie: "GreedyMathUTF16Tokenizer._TrieNode",
        math_base: int,
        debug: bool = False,
    ) -> List[int]:
        i, n = 0, len(s)
        out: List[int] = []
        while i < n:
            node = trie
            j = i
            last_id = None
            last_j = i
            while j < n and (ch := s[j]) in node.child:
                node = node.child[ch]
                j += 1
                if node.leaf:
                    last_id, last_j = node.id, j
            if last_id is not None:
                if debug:
                    print(f"[MATH] matched {s[i:last_j]!r} -> {last_id}")
                out.append(math_base + last_id)
                i = last_j
            else:
                units = cls._to_utf16_units(s[i])
                if debug:
                    print(f"[MATH] fallback {s[i]!r} -> utf16 {units}")
                out.extend(units)
                i += 1
        return out

    # ---------- init ----------
    def __init__(
        self,
        vocab_file: Optional[str] = None,
        specials_file: Optional[str] = None,
        specials_dict_file: Optional[str] = None,
        *,
        # You can also pass programmatically instead of files:
        math_vocab: Optional[Dict[str, int]] = None,
        special_tokens: Optional[List[str]] = None,
        special_tokens_dict: Optional[Dict[str, List[str]]] = None,
        debug: bool = False,
        # Standard HF special token kwargs:
        bos_token: Optional[str] = None,
        eos_token: Optional[str] = None,
        pad_token: Optional[str] = None,
        unk_token: Optional[str] = None,
        **kwargs,
    ):
        # Load math vocab
        if vocab_file and os.path.isfile(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                mv = json.load(f)
        else:
            mv = math_vocab or {}

        # Make math ids contiguous if needed
        if mv:
            max_id = max(mv.values())
            if set(mv.values()) != set(range(max_id + 1)):
                items = sorted(mv.items(), key=lambda kv: kv[1])
                mv = {tok: i for i, (tok, _) in enumerate(items)}

        # Load special tokens (prefer category dict; fallback to flat list or defaults)
        sp_dict = None
        if specials_dict_file and os.path.isfile(specials_dict_file):
            with open(specials_dict_file, "r", encoding="utf-8") as f:
                sp_dict = json.load(f)
        elif special_tokens_dict is not None:
            sp_dict = dict(special_tokens_dict)

        if sp_dict is None:
            # Legacy path: flat list from file or provided/default list
            if specials_file and os.path.isfile(specials_file):
                with open(specials_file, "r", encoding="utf-8") as f:
                    sp_list_flat = json.load(f)
            else:
                sp_list_flat = special_tokens or SPECIAL_TOKENS
            sp_dict = {"all": list(sp_list_flat)}

        # Ensure "all" exists and is unique/preserved in order.
        if "all" not in sp_dict or not isinstance(sp_dict["all"], list):
            order = [
                "system",
                "formatting",
                "math_external",
                "script",
                "layout",
                "reasoning",
                "table_structure",
                "reserved",
            ]
            seen = set()
            all_tokens: List[str] = []
            for k in order:
                if k in sp_dict and isinstance(sp_dict[k], list):
                    for t in sp_dict[k]:
                        if t not in seen:
                            all_tokens.append(t)
                            seen.add(t)
            sp_dict["all"] = all_tokens

        # Keep a copy of categories (if present) for downstream processor logic.
        self.special_tokens = sp_dict
        sp_list = list(sp_dict.get("all", []))
        # Regex list should favor longest-first to avoid partial matches.
        specials_for_regex = sorted(sp_list, key=len, reverse=True)

        self.debug = debug
        self.UTF16_SPACE = 65536
        self.math_token_to_rawid = dict(mv)  # 0..M-1
        self.math_vocab_size = len(self.math_token_to_rawid)
        self.MATH_BASE = self.UTF16_SPACE
        self.SPECIAL_BASE = self.UTF16_SPACE + self.math_vocab_size

        # Maps
        self.math_absid_to_token = {
            self.MATH_BASE + rid: tok for tok, rid in self.math_token_to_rawid.items()
        }
        self.special_tokens_list = sp_list  # ID assignment order
        self.special_to_absid = {
            tok: self.SPECIAL_BASE + i for i, tok in enumerate(self.special_tokens_list)
        }
        self.absid_to_special = {v: k for k, v in self.special_to_absid.items()}

        # Public attributes for legacy/processor:
        # All specials mapping (token -> absolute id)
        self.SPECIAL_TOKEN_MAPPING: Dict[str, int] = dict(self.special_to_absid)
        # Subset used heavily by processor for quick access
        self.reverse_special_token_mapping = {
            v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items()
        }
        self.LAYOUT_LABEL2ID = {
            k: v
            for k, v in self.SPECIAL_TOKEN_MAPPING.items()
            if k in self.special_tokens["layout"]
        }
        self.TABLE_STRUCTURE_LABEL2ID = {
            k: v
            for k, v in self.SPECIAL_TOKEN_MAPPING.items()
            if k in self.special_tokens["table_structure"]
        }
        if not self.special_tokens.get("system", []):
            logger.warning("No system tokens found in special_tokens")

        self.MATH_TAG_START = "<math"
        self.MATH_END_TAG = "</math>"

        sys_list = self.special_tokens.get("system", [])
        self.system_tokens: Dict[str, int] = {
            t: self.special_to_absid[t] for t in sys_list if t in self.special_to_absid
        }

        # Regex for literal specials
        self.specials_pattern = (
            re.compile(r"(" + "|".join(re.escape(k) for k in specials_for_regex) + r")")
            if specials_for_regex
            else None
        )

        # Trie for math greedy match
        self.trie = self._build_trie(self.math_token_to_rawid)

        # Tell HF about special tokens (metadata)
        kwargs.setdefault("bos_token", bos_token)
        kwargs.setdefault("eos_token", eos_token or "</S>")
        kwargs.setdefault("pad_token", pad_token or "<PAD>")
        kwargs.setdefault("unk_token", unk_token)

        super().__init__(
            vocab_file=vocab_file,
            specials_file=specials_file,
            specials_dict_file=specials_dict_file,
            **kwargs,
        )

    # ---------- required HF surface ----------
    @property
    def vocab_size(self) -> int:
        return self.UTF16_SPACE + self.math_vocab_size + len(self.special_tokens_list)

    def get_vocab(self) -> Dict[str, int]:
        # Compact vocab: just math+specials with ABSOLUTE ids.
        v = {tok: self.MATH_BASE + rid for tok, rid in self.math_token_to_rawid.items()}
        v.update(self.special_to_absid)
        return v

    def __len__(self) -> int:
        return self.vocab_size

    # Core encode/decode on ABSOLUTE ids
    def _encode_core(self, text: str) -> List[int]:
        text = html.unescape(text)
        ids: List[int] = []
        in_math = False
        chunks = self.specials_pattern.split(text) if self.specials_pattern else [text]
        for chunk in chunks:
            if chunk in self.special_to_absid:
                ids.append(self.special_to_absid[chunk])
                if chunk.startswith("<math"):
                    in_math = True
                elif chunk.startswith("</math>"):
                    in_math = False
                if self.debug:
                    print(f"[TAG] {chunk!r} -> {self.special_to_absid[chunk]}")
                continue

            if in_math:
                ids.extend(
                    self._encode_math_greedy(
                        chunk, self.trie, self.MATH_BASE, debug=self.debug
                    )
                )
            else:
                units = self._to_utf16_units(chunk)
                if self.debug and units:
                    print(
                        f"[TEXT] utf16 {chunk[:32]!r} -> {units[:8]}{'...' if len(units) > 8 else ''}"
                    )
                ids.extend(units)
        return ids

    def _decode_core(self, ids: Iterable[int]) -> str:
        out: List[str] = []
        buf: List[int] = []

        def flush():
            if buf:
                out.append(self._from_utf16_units(buf))
                buf.clear()

        for tid in ids:
            if tid >= self.MATH_BASE and tid < self.SPECIAL_BASE:
                flush()
                out.append(self.math_absid_to_token.get(tid, ""))
            elif tid >= self.SPECIAL_BASE:
                flush()
                out.append(self.absid_to_special.get(tid, ""))
            else:
                buf.append(int(tid))
        flush()
        return "".join(out)

    # ---- Tokenizer interface ----
    def _tokenize(self, text: str, **kwargs) -> List[str]:
        ids = self._encode_core(text)
        toks: List[str] = []
        for i in ids:
            if i < self.MATH_BASE:
                toks.append(f"<U+{i:04X}>")
            elif i < self.SPECIAL_BASE:
                toks.append(self.math_absid_to_token.get(i, "<UNK_MATH>"))
            else:
                toks.append(self.absid_to_special.get(i, "<UNK_SPECIAL>"))
        return toks

    def _convert_token_to_id(self, token: str) -> int:
        if token.startswith("<U+") and token.endswith(">"):
            try:
                return int(token[3:-1], 16)  # UTF-16 unit
            except Exception:
                return self.unk_token_id if self.unk_token_id is not None else 0
        # math or specials
        if token in self.math_token_to_rawid:
            return self.MATH_BASE + self.math_token_to_rawid[token]
        if token in self.special_to_absid:
            return self.special_to_absid[token]
        # rare path: single-char token -> its UTF-16 unit
        if len(token) == 1:
            u = self._to_utf16_units(token)
            if len(u) == 1:
                return u[0]
        return self.unk_token_id if self.unk_token_id is not None else 0

    def _convert_id_to_token(self, index: int) -> str:
        if index < self.MATH_BASE:
            return f"<U+{index:04X}>"
        if index < self.SPECIAL_BASE:
            return self.math_absid_to_token.get(index, "<UNK_MATH>")
        return self.absid_to_special.get(index, "<UNK_SPECIAL>")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        ids = [self._convert_token_to_id(t) for t in tokens]
        return self._decode_core(ids)

    def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str:
        # Accept int, list, tuple, numpy, torch
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        elif isinstance(token_ids, int):
            token_ids = [token_ids]
        else:
            token_ids = list(token_ids)
        token_ids = [int(i) for i in token_ids]  # normalize early

        if skip_special_tokens:
            token_ids = [i for i in token_ids if i < self.SPECIAL_BASE]
        return self._decode_core(token_ids)

    # HF plumbing
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        out = (
            list(token_ids_0)
            if token_ids_1 is None
            else list(token_ids_0) + list(token_ids_1)
        )
        # if self.eos_token_id is not None and (not out or out[-1] != self.eos_token_id):
        #     out.append(self.eos_token_id)
        return out

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        def mask(seq: List[int]) -> List[int]:
            return [1 if i >= self.SPECIAL_BASE else 0 for i in seq]

        return (
            mask(token_ids_0)
            if token_ids_1 is None
            else mask(token_ids_0) + mask(token_ids_1)
        )

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        return [0] * (
            len(token_ids_0)
            if token_ids_1 is None
            else len(token_ids_0) + len(token_ids_1)
        )

    # Save/load raw assets
    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str, str]:
        os.makedirs(save_directory, exist_ok=True)
        pre = (filename_prefix + "-") if filename_prefix else ""
        vocab_path = os.path.join(
            save_directory, pre + self.vocab_files_names["vocab_file"]
        )
        specials_path = os.path.join(
            save_directory, pre + self.vocab_files_names["specials_file"]
        )
        specials_dict_path = os.path.join(
            save_directory, pre + self.vocab_files_names["specials_dict_file"]
        )
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(self.math_token_to_rawid, f, ensure_ascii=False, indent=2)
        # Save both the flat list ("all") and the category dict (preferred)
        with open(specials_path, "w", encoding="utf-8") as f:
            json.dump(self.special_tokens_list, f, ensure_ascii=False, indent=2)
        with open(specials_dict_path, "w", encoding="utf-8") as f:
            json.dump(self.special_tokens, f, ensure_ascii=False, indent=2)
        return (vocab_path, specials_path)


class SuryaOCRTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
    def __init__(
        self,
        special_tokens: Dict[str, list] | None = None,
        model_checkpoint: str = settings.FOUNDATION_MODEL_CHECKPOINT,
        **kwargs,
    ):
        if special_tokens is None:
            special_tokens = dict()

        self.special_tokens = special_tokens

        self.ocr_tokenizer = GreedyMathUTF16Tokenizer.from_pretrained(
            model_checkpoint,
        )
        self.system_tokens = {
            v: self.ocr_tokenizer(v)["input_ids"][0]
            for v in special_tokens.get("system", [])
        }
        self.SPECIAL_TOKEN_MAPPING = self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING

        super().__init__(**kwargs)

    def get_vocab(self) -> Dict[str, int]:
        return self.ocr_tokenizer.get_vocab()

    def _add_tokens(
        self,
        new_tokens: Union[List[str], List[AddedToken]],
        special_tokens: bool = False,
    ) -> int:
        return self.ocr_tokenizer._add_tokens(
            new_tokens, special_tokens=special_tokens
        )

    @property
    def vocab_size(self):
        return self.ocr_tokenizer.vocab_size

    def _tokenize(self, text: str, **kwargs):
        # task = kwargs.get("task", TaskNames.ocr_with_boxes)
        # assert task in TASK_NAMES, f"Invalid task: {task}"

        tokens = self.ocr_tokenizer(text)["input_ids"]

        return tokens

    def __call__(
        self,
        texts: Union[str, List[str]],
        tasks: Union[str, List[str]] = None,
        **kwargs,
    ) -> Dict[str, List[List[int]]]:
        """Tokenizes text and returns input IDs."""
        tokenized = []

        if isinstance(texts, str):
            texts = [texts]
            assert isinstance(tasks, str), "Tasks must be a string if texts is a string"
            tasks = [tasks]

        if isinstance(texts, list):
            assert isinstance(tasks, list), "Tasks must be a list if texts is a list"

        for text, task in zip(texts, tasks):
            tokens = self._tokenize(text, task=task)
            tokenized.append(tokens)

        return {"input_ids": tokenized}

    def decode(self, token_ids, **kwargs):
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        decoded_text = self.ocr_tokenizer.decode(token_ids, skip_special_tokens=False)
        # replace all <SCRIPT-...> tokens with empty strings
        decoded_text = re.sub(r"<SCRIPT-.*?>", "", decoded_text)
        # replace </S> with empty string
        decoded_text = re.sub(r"</S>", "", decoded_text)
        return decoded_text

```
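
A hedged aside (illustrative only, not repo code): GreedyMathUTF16Tokenizer lays absolute IDs out as UTF-16 code units first, then math tokens, then literal special tokens. The standalone sketch below mirrors that layout with a tiny hypothetical vocabulary; it simplifies by matching math tokens everywhere instead of only inside `<math>` spans:

```python
# Hypothetical vocabularies for illustration only
MATH_VOCAB = {"\\frac": 0, "\\alpha": 1}
SPECIALS = ["<math>", "</math>", "</S>"]

MATH_BASE = 65536                        # IDs 0..65535 are UTF-16 code units
SPECIAL_BASE = MATH_BASE + len(MATH_VOCAB)
SPECIAL_TO_ID = {tok: SPECIAL_BASE + i for i, tok in enumerate(SPECIALS)}


def utf16_units(s: str) -> list:
    b = s.encode("utf-16le")
    return [int.from_bytes(b[i : i + 2], "little") for i in range(0, len(b), 2)]


def encode(text: str) -> list:
    ids, i = [], 0
    while i < len(text):
        special = next((t for t in SPECIALS if text.startswith(t, i)), None)
        if special:  # literal special tokens win first
            ids.append(SPECIAL_TO_ID[special])
            i += len(special)
            continue
        match = max(
            (t for t in MATH_VOCAB if text.startswith(t, i)), key=len, default=None
        )
        if match:  # greedy longest match against the math vocab
            ids.append(MATH_BASE + MATH_VOCAB[match])
            i += len(match)
            continue
        ids.extend(utf16_units(text[i]))  # fallback: one character -> UTF-16 unit(s)
        i += 1
    return ids


print(encode("x = <math>\\alpha</math></S>"))
```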

--------------------------------------------------------------------------------
/surya/common/surya/__init__.py:
--------------------------------------------------------------------------------

```python
import warnings
from typing import Optional, Tuple, TypedDict
from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.s3 import S3DownloaderMixin
from surya.common.surya.config import SuryaModelConfig
from surya.common.surya.decoder import SuryaDecoderModel
from surya.common.surya.embedder import SimpleTokenEmbedder
from surya.common.surya.encoder import SuryaEncoderModel
from surya.common.util import pad_to_batch_size, pad_to_batch_size_repeat
from surya.common.xla import get_nearest_pad
from surya.settings import settings

from surya.logging import get_logger

logger = get_logger()


@dataclass
class SuryaModelOutput(CausalLMOutputWithPast):
    bbox_logits: torch.FloatTensor = None
    lm_logits: torch.FloatTensor = None


class FlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for Flash Attention with Compile.

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`, *optional*):
            Cumulative sequence lengths for the query state.
        cu_seq_lens_k (`torch.LongTensor`, *optional*):
            Cumulative sequence lengths for the key state.
        max_length_q (`int`, *optional*):
            Maximum sequence length for query state.
        max_length_k (`int`, *optional*):
            Maximum sequence length for key state.
    """

    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class KwargsForCausalLM(FlashAttentionKwargs): ...


class DistanceProjection(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(out_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

    def init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)


class BboxHead(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.proj_layers = nn.ModuleList(
            [nn.Linear(in_features, in_features) for _ in range(6)]
        )
        self.act = nn.SiLU()
        self.out_proj = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.proj_layers:
            x = layer(x)
            x = self.act(x)

        x = self.out_proj(x)
        return x


class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
    config_class = SuryaModelConfig
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True
    main_input_name = "input_ids"
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(
        self,
        config: SuryaModelConfig,
        embedder: SimpleTokenEmbedder = None,
        vision_encoder: SuryaEncoderModel = None,
        decoder: SuryaDecoderModel = None,
        **kwargs,
    ):
        super().__init__(config, **kwargs)

        if vision_encoder is None:
            vision_encoder = SuryaEncoderModel(config.vision_encoder)

        if decoder is None:
            decoder = SuryaDecoderModel(config.decoder)

        if embedder is None:
            embedder = SimpleTokenEmbedder(config)

        self.vision_encoder = vision_encoder
        self.decoder = decoder
        self.embedder = embedder

        # Simple encoding for image patches
        self.img_w_embed = nn.Embedding(
            self.config.image_embed_encoding_size,
            self.config.hidden_size,
        )

        self.img_h_embed = nn.Embedding(
            self.config.image_embed_encoding_size,
            self.config.hidden_size,
        )

        # Tying configs
        self.vision_encoder.config = self.config.vision_encoder
        self.decoder.config = self.config.decoder

        self.bbox_head = BboxHead(config.hidden_size, 6)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        if (
            self.config.multi_output_distance is not None
            and self.config.multi_output_distance > 0
        ):
            self.multi_output_projections = nn.ModuleList(
                [
                    DistanceProjection(
                        in_features=config.hidden_size, out_features=config.hidden_size
                    )
                    for _ in range(self.config.multi_output_distance)
                ]
            )

    def tie_weights(self):
        self._tie_weights()

    def _tie_weights(self):
        # Tie weights of lm head and token embedder
        self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def get_input_embeddings(self) -> nn.Module:
        return self.embedder.token_embed

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def set_input_embeddings(self, new_embeddings: nn.Module):
        self.embedder.token_embed = new_embeddings

    def maybe_static_pad_image_inputs(
        self,
        chunk_pixels: torch.Tensor,
        chunk_grid_thw: torch.Tensor,
        actual_chunk_len: int,
        encoder_chunk_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        valid_embed_len = actual_chunk_len // (
            self.vision_encoder.spatial_merge_size**2
        )
        if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size:
            padding_len = encoder_chunk_size - actual_chunk_len
            chunk_pixels = F.pad(
                chunk_pixels,
                (0, 0, 0, padding_len),
                mode="constant",
                value=0.0,
            )

            padding_grid = torch.tensor(
                [[1, 2, padding_len // 2]],
                device=chunk_grid_thw.device,
                dtype=chunk_grid_thw.dtype,
            )
            chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0)

        return chunk_pixels, chunk_grid_thw, valid_embed_len

    def get_image_embeddings(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
        encoder_chunk_size: int,
        valid_batch_size: torch.Tensor | None = None,
        max_batch_size: int | None = None,
    ):
        # embed all images with the vision encoder after they have already been tiled and flattened into a single batch
        chunks = [0]
        grid_chunks = [0]
        curr_chunk_len = 0
        curr_seq_len = 0
        for i in range(len(grid_thw)):
            curr_chunk_len += (grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]).item()
            if curr_chunk_len > encoder_chunk_size:
                chunks.append(curr_chunk_len + curr_seq_len)
                curr_seq_len += curr_chunk_len
                curr_chunk_len = 0
                grid_chunks.append(i + 1)

        if curr_chunk_len > 0:
            chunks.append(pixel_values.shape[0])
            grid_chunks.append(len(grid_thw))

        assert curr_chunk_len + curr_seq_len == pixel_values.shape[0], (
            f"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {pixel_values.shape[0]}"
        )

        logger.debug(
            f"Chunking encoder sequence into {len(chunks) - 1} chunks of size {encoder_chunk_size} with lengths {chunks} and grids {grid_chunks}"
        )
        embeddings = []
        for i in range(len(chunks) - 1):
            start = chunks[i]
            end = chunks[i + 1]
            grid_start = grid_chunks[i]
            grid_end = grid_chunks[i + 1]

            chunk_pixels = pixel_values[start:end]
            chunk_grid_thw = grid_thw[grid_start:grid_end]
            actual_chunk_len = end - start
            chunk_pixels, chunk_grid_thw, valid_embed_len = (
                self.maybe_static_pad_image_inputs(
                    chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size
                )
            )

            chunk_embeddings = self.vision_encoder.embed_images(
                image_batch=chunk_pixels.unsqueeze(0).to(device=self.device),
                grid_thw=chunk_grid_thw.unsqueeze(0).to(device=self.device),
            )
            embeddings.append(chunk_embeddings[:valid_embed_len].squeeze(0))

        if len(embeddings) == 0:
            raise ValueError(
                "No image embeddings were generated. Check the input images and grid sizes."
            )
        elif len(embeddings) == 1:
            embeddings = embeddings[0]
        else:
            embeddings = torch.cat(embeddings, dim=0)

        encoding_2d = self.get_2d_learned_embeddings(
            grid_thw,
            device=embeddings.device,
            bbox_size=self.config.image_embed_encoding_multiplier,
        )
        assert embeddings.shape[0] == encoding_2d.shape[0], (
            f"Mismatch in image embedding seq len: {embeddings.shape} vs {encoding_2d.shape}"
        )
        assert embeddings.shape[1] == encoding_2d.shape[1], (
            f"Mismatch in image embedding token counts: {embeddings.shape} vs {encoding_2d.shape}"
        )

        embeddings = embeddings + encoding_2d

        return embeddings

    def embed_ids_boxes_images(
        self,
        input_ids,
        image_embeddings,
        encoder_chunk_size: int,
        valid_batch_size: torch.Tensor | None = None,
        input_boxes: torch.Tensor | None = None,
        embed_boxes: torch.Tensor | None = None,
    ):
        """
        Insert embedded image tiles into the corresponding positions into the full input sequence

        Positions to insert new tokens are indicated by the special image token index
        """
        # This is batched in the inner call
        inputs_embeds = self.embedder.embed(
            input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes
        )

        if image_embeddings is not None:
            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds)
            if inputs_embeds[special_image_mask].numel() != image_embeddings.numel():
                n_image_tokens = torch.sum((input_ids == self.config.image_token_id))
                n_image_features = image_embeddings.shape[0] * image_embeddings.shape[1]
                warnings.warn(
                    f"Image features and image tokens do not match: tokens {n_image_tokens}, features {n_image_features}. This may lead to unexpected results"
                )
            image_features = image_embeddings.to(inputs_embeds.dtype)
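            # e.g. for a toy row [TXT, IMG, IMG, TXT] with two image-embedding rows,
            # masked_scatter writes those rows, in order, into the two IMG slots and
            # leaves the text-token embeddings untouched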
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features
            )
        else:
            assert (input_ids == self.config.image_token_id).sum() == 0, (
                "Image tokens were present in the input but no input images were provided"
            )

        return inputs_embeds

    def get_2d_learned_embeddings(
        self,
        grid_thw,
        device: str | torch.device = "cpu",
        bbox_size: int = 256,
    ):
        all_embeddings = []
        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.config.merge_size,
                grid_w // self.config.merge_size,
            )

            # Scale to 0-bbox_size
            llm_grid_h = (
                torch.arange(llm_grid_h, device=device)
                / max(1, (llm_grid_h - 1))
                * bbox_size
            )
            llm_grid_w = (
                torch.arange(llm_grid_w, device=device)
                / max(1, (llm_grid_w - 1))
                * bbox_size
            )

            llm_grid_w_idx = llm_grid_w.to(torch.long)
            llm_grid_h_idx = llm_grid_h.to(torch.long)

            llm_grid_w = self.img_w_embed(llm_grid_w_idx)
            llm_grid_h = self.img_h_embed(llm_grid_h_idx)

            full_grid = llm_grid_h[:, None] + llm_grid_w[None, :]
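            # (the broadcasted sum above is (H, 1, D) + (1, W, D) -> (H, W, D): each
            # patch gets a learned row embedding plus a learned column embedding)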

            flattened = full_grid.flatten(
                0, 1
            )  # Flatten first dimension, so they are seq_len x embed_dim
            all_embeddings.append(flattened)
        return torch.concat(
            all_embeddings, dim=0
        )  # Shape is num_image_tokens x embed_dim

    def get_logits(self, hidden_states):
        assert hidden_states.shape[1] == 1, (
            "Multi output predictions only applied on the last token"
        )

        all_lm_logits = []
        all_bbox_logits = []

        current_hidden = hidden_states

        # Loop includes initial prediction (i=0) plus multi_out_distance additional predictions
        for i in range(self.config.multi_output_distance + 1):
            if i > 0:
                current_hidden = self.multi_output_projections[i - 1](current_hidden)

            lm_logits = self.lm_head(current_hidden)
            bbox_logits = F.sigmoid(self.bbox_head(current_hidden))

            all_lm_logits.append(lm_logits)
            all_bbox_logits.append(bbox_logits)

        # Concatenate along sequence dimension (dim=1)
        final_lm_logits = torch.cat(all_lm_logits, dim=1)
        final_bbox_logits = torch.cat(all_bbox_logits, dim=1)

        return final_lm_logits, final_bbox_logits

    def forward(
        self,
        input_ids=None,
        image_embeddings=None,
        labels=None,
        image_tiles=None,
        grid_thw=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        cache_position=None,
        past_key_values=None,
        output_hidden_states=False,
        output_attentions=False,
        use_cache=False,
        encoder_chunk_size=32768,
        cache_idxs=None,
        num_valid_tokens=None,
        prefill=True,
        text_lengths=None,
        valid_batch_size: torch.Tensor = None,
        input_boxes=None,
        embed_boxes=None,
        logits_to_keep=None,
        **kwargs: KwargsForCausalLM,
    ):
        if any([
            input_ids is None,
            position_ids is None,
            cache_position is None,
            (
                prefill
                and not (
                    (image_tiles is not None and grid_thw is not None)
                    or image_embeddings is not None
                )
            ),
        ]):
            raise ValueError(
                "`input_ids`, `position_ids`, and `cache_position` **must** be specified. "
                "For prefill, you must provide either (`image_tiles` and `grid_thw`) or `image_embeddings`."
            )

        inputs_embeds = self.embed_ids_boxes_images(
            input_ids,
            image_embeddings,
            encoder_chunk_size,
            valid_batch_size,
            input_boxes,
            embed_boxes,
        )

        # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder
        # Skipped during decoding since not required
        if self.decoder.config._attn_implementation == "flash_attention_2" and prefill:
            # Needed for CPU -> GPU
            from surya.common.surya.flash_attn_utils import _get_unpad_data
            batch_size, query_length, _ = inputs_embeds.shape
            indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
                attention_mask
            )
            kwargs["batch_size"] = batch_size
            kwargs["query_length"] = query_length
            kwargs["indices_k"] = indices_k
            kwargs["cu_seqlens_k"] = cu_seqlens_k
            kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        attention_mask = causal_mask
        outputs = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            return_dict=True,
            use_cache=use_cache,
            cache_idxs=cache_idxs,
            num_valid_tokens=num_valid_tokens,
            prefill=prefill,
            text_lengths=text_lengths,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        if logits_to_keep is not None:
            hidden_states = hidden_states[:, -logits_to_keep:, :]
        hidden_states = hidden_states.contiguous()

        loss = None
        if labels is not None:
            # Training, return full logits
            lm_logits = self.lm_head(hidden_states)
            bbox_logits = None
            vocab_size = lm_logits.shape[-1]
            labels = torch.roll(labels, shifts=-1, dims=-1)
            loss = F.cross_entropy(
                lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean"
            )
        else:
            lm_logits, bbox_logits = self.get_logits(hidden_states)

        return SuryaModelOutput(
            loss=loss,
            bbox_logits=bbox_logits,
            lm_logits=lm_logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
            past_key_values=outputs.past_key_values,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.decoder.config._attn_implementation == "flash_attention_2":
            return attention_mask

        # We always pass in a 2D attention mask from the processor - In both static and dynamic cache cases
        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else past_key_values.max_cache_len
        )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: SuryaModelConfig,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence. Shape `(batch_size, sequence_length)`.
            batch_size (`torch.Tensor`):
                Batch size.
            config (`SuryaModelConfig`):
                The model's configuration class.
            past_key_values (`Cache`):
                The cache class currently being used for generation.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            # Batch-aware diagonal attend mask
            diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(
                0
            ) > cache_position.unsqueeze(-1)
            causal_mask = (
                causal_mask.unsqueeze(0) * diagonal_attend_mask
            )  # (batch_size, seq_len, target_len)
            causal_mask = causal_mask[
                :, None, :, :
            ]  # (batch_size, 1, seq_len, target_len)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                    :, None, None, :
                ].to(causal_mask.device)
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        return causal_mask

class SuryaXLAModel(SuryaModel):
    def get_image_embeddings(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
        encoder_chunk_size: int,
        valid_batch_size: torch.Tensor | None = None,
        max_batch_size: int | None = None,
    ):
        # embed all images with the vision encoder after they have already been tiled and flattened into a single batch
        unpadded_max_grid_size = (
            (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).max().item()
        )
        max_grid_size = get_nearest_pad(
            unpadded_max_grid_size,
        )  # If we need zero padding, we still need to allocate a bit of room for the extra grid_thw

        # Always need 2 items in each row batch
        if max_grid_size == unpadded_max_grid_size:
            max_grid_size += 16

        full_image_grid = torch.zeros(
            (valid_batch_size, max_grid_size, pixel_values.shape[-1]),
            dtype=pixel_values.dtype,
        )

        # Roll out into a full grid
        seq_len = 0
        row_grids = []
        for i in range(valid_batch_size):
            curr_sample_len = grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]
            full_image_grid[i, -curr_sample_len:] = pixel_values[
                seq_len : seq_len + curr_sample_len
            ]
            padded_len = max_grid_size - curr_sample_len
            if padded_len > 0:
                row_grid = torch.tensor(
                    [
                        [1, 4, padded_len // 4],
                        grid_thw[i].tolist(),
                    ],
                    dtype=torch.long,
                )
            else:
                row_grid = torch.tensor(
                    [
                        grid_thw[i].tolist(),
                    ],
                    dtype=torch.long,
                )

            row_grids.append(row_grid)
            seq_len += curr_sample_len

        # bsz, 2, 3
        row_grids = torch.stack(row_grids, dim=0)

        if settings.FOUNDATION_STATIC_CACHE:
            # Pad to max batch size, repeat the final row
            row_grids = pad_to_batch_size_repeat(
                row_grids,
                batch_size=max_batch_size,
            )
            full_image_grid = pad_to_batch_size(
                full_image_grid,
                batch_size=max_batch_size,
            )

        full_image_grid = full_image_grid.to(self.device)

        embeddings = self.vision_encoder.embed_images(
            image_batch=full_image_grid, grid_thw=row_grids.to(self.device)
        )

        encoding_2d = self.get_2d_learned_embeddings(
            row_grids,
            bbox_size=self.config.image_embed_encoding_multiplier,
        )
        embeddings += encoding_2d

        return embeddings

    def embed_ids_boxes_images(
        self,
        input_ids,
        image_embeddings,
        encoder_chunk_size: int,
        valid_batch_size: torch.Tensor | None = None,
        input_boxes: torch.Tensor | None = None,
        embed_boxes: torch.Tensor | None = None,
    ):
        """
        Insert embedded image tiles into the corresponding positions into the full input sequence

        Positions to insert new tokens are indicated by the special image token index
        """
        # This is batched in the inner call
        inputs_embeds = self.embedder.embed(
            input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes
        )

        if image_embeddings is not None:
            image_token_id_tensor = torch.tensor(
                self.config.image_token_id,
                device=inputs_embeds.device,
                dtype=torch.long,
            )
            mask = input_ids == image_token_id_tensor
            last_image_token_pos = (
                mask.size(1)
                - 1
                - mask.flip(dims=[1]).long().argmax(dim=1, keepdim=True)
            )
            # Calculate start position to replace N positions ending at (and including) the last image token
            start_positions = last_image_token_pos - image_embeddings[0].shape[0]
            batch_size, insert_len = image_embeddings.shape[:2]

            # Create position indices for each insertion
            pos_indices = torch.arange(
                insert_len, device=inputs_embeds.device
            ).unsqueeze(0)
            insert_positions = start_positions + pos_indices

            idx = insert_positions.unsqueeze(-1).expand(
                -1, -1, inputs_embeds.size(-1)
            )  # [B,N,D]
            inputs_embeds = inputs_embeds.scatter(1, idx, image_embeddings)

        inputs_embeds = inputs_embeds * (
            input_ids != self.config.pad_token_id
        ).unsqueeze(-1).to(inputs_embeds.dtype)
        return inputs_embeds

    def get_2d_learned_embeddings(
        self,
        grid_thw,
        bbox_size: int = 256,
    ):
        dev = grid_thw.device
        all_row_coords = []
        all_col_coords = []
        for row_grid in grid_thw:
            merge = self.config.merge_size

            # per-sample grid sizes after merge
            H = (row_grid[:, 1] // merge).long()  # (B,)
            W = (row_grid[:, 2] // merge).long()  # (B,)

            row_coords = torch.cat(
                [
                    torch.linspace(0, bbox_size, steps=int(h), device=dev)
                    .round()
                    .repeat_interleave(w)  # repeat each row value w times
                    for h, w in zip(H.tolist(), W.tolist())
                ]
            )  # (full_grid_size,)

            col_coords = torch.cat(
                [
                    torch.linspace(0, bbox_size, steps=int(w), device=dev)
                    .round()
                    .repeat(int(h))  # tile the column vector h times
                    for h, w in zip(H.tolist(), W.tolist())
                ]
            )  # (full_grid_size,)
            all_row_coords.append(row_coords)
            all_col_coords.append(col_coords)
        row_coords = torch.stack(all_row_coords, dim=0).to(self.device)
        col_coords = torch.stack(all_col_coords, dim=0).to(self.device)

        emb = self.img_h_embed(row_coords.long()) + self.img_w_embed(col_coords.long())
        return emb

```
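
A hedged aside (toy shapes, not repo code): `_prepare_4d_causal_attention_mask_with_cache_position` derives the mask from `cache_position`, so a query token at cache position p may attend only to key positions <= p, with disallowed slots filled by the dtype minimum. A minimal sketch of that diagonal rule:

```python
import torch

target_length = 6                          # e.g. the static-cache length
cache_position = torch.tensor([2, 3, 4])   # cache positions of the current query tokens
min_dtype = torch.finfo(torch.float32).min

# True where the key position lies strictly in the future of the query's cache position
future = torch.arange(target_length).unsqueeze(0) > cache_position.unsqueeze(-1)
mask = torch.zeros(future.shape, dtype=torch.float32).masked_fill(future, min_dtype)
print(mask)  # row 0 masks key positions 3..5, row 1 masks 4..5, row 2 masks only 5
```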

--------------------------------------------------------------------------------
/surya/common/surya/encoder/__init__.py:
--------------------------------------------------------------------------------

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN

from surya.common.pretrained import SuryaPreTrainedModel
from surya.common.surya.encoder.config import SuryaEncoderConfig
from surya.common.xla import get_nearest_pad
from surya.logging import get_logger
from surya.settings import settings

if settings.FOUNDATION_XLA:
    import torch_xla.experimental.custom_kernel

logger = get_logger()


class Qwen2_5_VLMLP(nn.Module):
    def __init__(self, config, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        return self.down_proj(
            self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
        )


class Qwen2_5_VisionPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
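        # hidden_states arrives as flattened patches of shape
        # (bsz, seq_len, in_channels * temporal_patch_size * patch_size**2); the Conv3d
        # with stride equal to its kernel embeds each patch independently, so the spatial
        # output dims collapse to 1 and the result is reshaped to (bsz, seq_len, embed_dim).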
        target_dtype = self.proj.weight.dtype
        bsz = hidden_states.shape[0]
        hidden_states = hidden_states.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
            bsz, -1, self.embed_dim
        )
        return hidden_states


class Qwen2_5_VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.inv_freq = 1.0 / (
            theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
        )

    def forward(self, seqlen: int) -> torch.Tensor:
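        # Returns a (seqlen, dim // 2) table of rotation angles: the outer product of
        # positions and inverse frequencies.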
        seq = torch.arange(seqlen, device="cpu", dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
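        # y = weight * x / sqrt(mean(x**2, dim=-1) + eps); statistics are computed in float32
        # and the result is cast back to the input dtype.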
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Qwen2_5_VLPatchMerger(nn.Module):
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
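        # x: (bsz, seq_len, context_dim). After RMSNorm, the view regroups every
        # spatial_merge_size**2 consecutive tokens into one vector of size
        # context_dim * spatial_merge_size**2, which the MLP projects to `dim`.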
        bsz = x.shape[0]
        x = self.mlp(self.ln_q(x).view(bsz, -1, self.hidden_size))
        return x


def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    from flash_attn.layers.rotary import apply_rotary_emb

    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()
    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
    return q_embed, k_embed


class Qwen2_5_VLVisionXLASdpaAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.head_dim = dim // num_heads

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(bsz, seq_length, 3, self.num_heads, -1)
            .permute(0, 2, 1, 3, 4)
            .unbind(1)
        )
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        attention_mask = torch.zeros([bsz, 1, seq_length, seq_length], dtype=torch.bool)
        cu_seqlens_cpu = cu_seqlens.cpu()
        for j in range(bsz):
            batch_seqlens = cu_seqlens_cpu[j]
            for i in range(1, len(batch_seqlens)):
                attention_mask[
                    j,
                    ...,
                    batch_seqlens[i - 1] : batch_seqlens[i],
                    batch_seqlens[i - 1] : batch_seqlens[i],
                ] = True

        attention_mask = attention_mask.to(q.device)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attention_mask,
            dropout_p=0.0,
        )
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output


class Qwen2_5_VLVisionXLAFlashAttention2(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.head_dim = dim // num_heads

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # Note, this is faster than SDPA, but pretty memory inefficient
        # It also has significant accuracy issues

        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]

        # Single reshape to target layout - avoid multiple operations
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(bsz, seq_length, 3, self.num_heads, -1)
            .permute(0, 2, 1, 3, 4)
            .unbind(1)
        )

        # Apply rotary embeddings if provided
        if position_embeddings is not None:
            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        # Single reshape to flash attention format [batch, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)  # [bsz, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        total_seqlen = q.shape[2]
        # from cu_seqlens to segment ids for each position in dim 0
        additive_bias = torch.zeros((bsz, 1, total_seqlen, total_seqlen), dtype=q.dtype)
        min_val = torch.finfo(q.dtype).min

        for i in range(bsz):
            padding_end = cu_seqlens[i][1].item()
            additive_bias[i, :, :, :padding_end] = min_val

        additive_bias = additive_bias.to(hidden_states.device)

        attn_scale = 1 / math.sqrt(self.head_dim)
        attn_output = torch_xla.experimental.custom_kernel.flash_attention(
            q, k, v, sm_scale=attn_scale, ab=additive_bias
        )
        attn_output = (
            attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_length, -1)
        )
        attn_output = self.proj(attn_output)
        return attn_output


class Qwen2_5_VLVisionFlashAttention2(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        from flash_attn import flash_attn_varlen_func

        bsz = hidden_states.shape[0]
        seq_length = hidden_states.shape[1]
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(bsz, seq_length, 3, self.num_heads, -1)
            .permute(0, 2, 1, 3, 4)
            .unbind(1)
        )
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings

        q, k = apply_rotary_pos_emb_flashatt(q, k, cos.squeeze(0), sin.squeeze(0))

        q = q.squeeze(0)
        k = k.squeeze(0)
        v = v.squeeze(0)
        cu_seqlens = cu_seqlens.squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
        ).reshape(bsz, seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)
    return q_embed, k_embed


class Qwen2_5_VLVisionAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1]
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(bsz, seq_length, 3, self.num_heads, -1)
            .permute(0, 2, 1, 3, 4)
            .unbind(1)
        )
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings

        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        attention_mask = torch.full(
            [bsz, 1, seq_length, seq_length],
            torch.finfo(q.dtype).min,
            device=q.device,
            dtype=q.dtype,
        )
        for j in range(bsz):
            batch_seqlens = cu_seqlens[j]
            for i in range(1, len(batch_seqlens)):
                attention_mask[
                    j,
                    ...,
                    batch_seqlens[i - 1] : batch_seqlens[i],
                    batch_seqlens[i - 1] : batch_seqlens[i],
                ] = 0

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output


class Qwen2_5_VLVisionSdpaAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def unpack_qkv_with_mask(self, q, k, v, cu_seqlens):
        """
        Unpacks q, k, v sequences into batch-major form and constructs an additive attention mask.

        Args:
            q, k, v: Tensors of shape (total_seq_len, num_heads, head_dim)
            cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths

        Returns:
            batched_q: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
            batched_k: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
            batched_v: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim)
            attention_mask: Tensor of shape (batch_size, 1, max_seq_len, max_seq_len)
                            with 0 for valid tokens and -inf for padding (for additive attention)
        """
        device = q.device
        dtype = q.dtype

        batch_size = cu_seqlens.shape[0] - 1
        num_heads = q.shape[1]
        head_dim = q.shape[2]

        seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # Keep as tensor
        max_seq_len = seq_lengths.max().item()  # Use .max() on tensor

        if settings.FOUNDATION_STATIC_CACHE:
            # Pad max_seq_len to the nearest multiple for compilation
            max_seq_len = get_nearest_pad(max_seq_len, pad_multiple=16)

            # Pad batch_size to the nearest multiple for compilation
            batch_size = get_nearest_pad(batch_size, pad_multiple=2)

            # Ensure seq_lengths is a tensor of the correct size
            seq_lengths = F.pad(
                seq_lengths, (0, batch_size - seq_lengths.size(0)), "constant", 0
            )

        # Some day you may look at this and think: "what if I used repeat_interleave or some other fancy torch op instead?"
        # Don't do that - it's a path to madness. For some reason, this loop is optimal.

        batch_indices = []
        position_indices = []

        for i, seq_len in enumerate(
            seq_lengths.tolist()
        ):  # Convert to list only for iteration
            batch_indices.extend([i] * seq_len)
            position_indices.extend(list(range(seq_len)))

        batch_indices = torch.tensor(batch_indices, device=device)
        position_indices = torch.tensor(position_indices, device=device)

        batched_q = torch.zeros(
            (batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype
        )
        batched_k = torch.zeros_like(batched_q)
        batched_v = torch.zeros_like(batched_q)

        # Create additive attention mask
        attention_mask = torch.full(
            (batch_size, max_seq_len, max_seq_len),
            fill_value=float("-inf"),
            device=device,
            dtype=dtype,
        )

        # Create mask for valid positions
        seq_range = torch.arange(max_seq_len, device=device)
        valid_mask = seq_range.unsqueeze(0) < seq_lengths.unsqueeze(
            1
        )  # (batch_size, max_seq_len)
        valid_2d = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(
            1
        )  # (batch_size, max_seq_len, max_seq_len)

        # Simply use boolean indexing to set valid positions to 0
        attention_mask[valid_2d] = 0

        attention_mask = attention_mask.unsqueeze(
            1
        )  # (batch_size, 1, max_seq_len, max_seq_len)

        batched_q[batch_indices, position_indices] = q
        batched_k[batch_indices, position_indices] = k
        batched_v[batch_indices, position_indices] = v

        return (
            batched_q,
            batched_k,
            batched_v,
            attention_mask,
            batch_indices,
            position_indices,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        hidden_states = hidden_states.squeeze(0)
        cu_seqlens = cu_seqlens.squeeze(0)

        seq_length = hidden_states.shape[0]
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(seq_length, 3, self.num_heads, -1)
            .permute(1, 0, 2, 3)
            .unbind(0)
        )
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
        q = q.squeeze(0)
        k = k.squeeze(0)

        q, k, v, attention_mask, batch_indices, position_indices = (
            self.unpack_qkv_with_mask(q, k, v, cu_seqlens)
        )
        batch_size, max_seqlen = q.shape[:2]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attention_mask,
            dropout_p=0.0,
        )
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(
            batch_size, max_seqlen, -1
        )  # Bring back to (batch_size, max_seqlen, hidden_dim)
        attn_output = attn_output[batch_indices, position_indices]
        attn_output = self.proj(attn_output)

        return attn_output.unsqueeze(0)


QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
    "eager": Qwen2_5_VLVisionAttention,
    "flash_attention_2": Qwen2_5_VLVisionXLAFlashAttention2
    if settings.FOUNDATION_XLA
    else Qwen2_5_VLVisionFlashAttention2,
    "sdpa": Qwen2_5_VLVisionXLASdpaAttention
    if settings.FOUNDATION_XLA
    else Qwen2_5_VLVisionSdpaAttention,
}


class Qwen2_5_VLVisionBlock(nn.Module):
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
        self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
        self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
            config.hidden_size, num_heads=config.num_heads
        )
        self.mlp = Qwen2_5_VLMLP(config, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


Qwen2_5_VL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`Qwen2_5_VLConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


class Qwen2_5_VLPreTrainedModel(SuryaPreTrainedModel):
    config_class = SuryaEncoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
    config_class = SuryaEncoderConfig
    _no_split_modules = ["Qwen2_5_VLVisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen2_5_VLVisionBlock(config, config._attn_implementation)
                for _ in range(config.depth)
            ]
        )
        self.merger = Qwen2_5_VLPatchMerger(
            dim=config.out_hidden_size,
            context_dim=config.hidden_size,
            spatial_merge_size=config.spatial_merge_size,
        )
        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw):
        rotary_pos_emb = []
        grid_thw_list = grid_thw.cpu().tolist()
        for batch_item in grid_thw_list:
            row_pos_ids = []
            heights = [h for _, h, _ in batch_item]
            widths = [w for _, _, w in batch_item]
            max_grid_size = max(heights + widths)
            for t, h, w in batch_item:
                hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
                hpos_ids = hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                hpos_ids = hpos_ids.permute(0, 2, 1, 3)
                hpos_ids = hpos_ids.flatten()

                wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
                wpos_ids = wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                wpos_ids = wpos_ids.permute(0, 2, 1, 3)
                wpos_ids = wpos_ids.flatten()
                # shape: token_count, 2
                row_pos_ids.append(
                    torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
                )
            # shape: token_count, 2
            pos_ids = torch.cat(row_pos_ids, dim=0)
            rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
            rotary_pos_emb_row = rotary_pos_emb_full[pos_ids].flatten(1)
            rotary_pos_emb.append(rotary_pos_emb_row)
        rotary_pos_emb = torch.stack(rotary_pos_emb, dim=0)
        return rotary_pos_emb

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(bsz, seq_len, hidden_size)`):
                Flattened image patch values to be embedded and encoded.
            grid_thw (`torch.Tensor` of shape `(bsz, num_images_or_videos, 3)`):
                The temporal, height, and width of the patch grid for each image or video.

        Returns:
            `torch.Tensor`: The encoded and spatially merged patch features.
        """
        bsz, seq_len, _ = hidden_states.size()
        hidden_states = self.patch_embed(hidden_states)  # (bsz, seq_len, hidden_dim)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # hidden_states = hidden_states.reshape(bsz, seq_len, -1)
        # rotary_pos_emb = rotary_pos_emb.reshape(bsz, seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to(
            hidden_states.device
        )
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = (grid_thw[:, :, 1] * grid_thw[:, :, 2]).cumsum(
            dim=1,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
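        # Example (illustrative): a single image with grid_thw = [[[1, 4, 6]]] packs
        # 4 * 6 = 24 patch tokens, so cu_seqlens becomes [[0, 24]] and attention stays
        # within that image's tokens.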
        for layer_num, blk in enumerate(self.blocks):
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__,
                    hidden_states,
                    cu_seqlens,
                    None,
                    position_embeddings,
                )
            else:
                hidden_states = blk(
                    hidden_states,
                    cu_seqlens=cu_seqlens,
                    position_embeddings=position_embeddings,
                )

        hidden_states = self.merger(hidden_states)
        return hidden_states


class SuryaEncoderModel(Qwen2_5_VisionTransformerPretrainedModel):
    @property
    def image_size(self) -> Tuple[int, int]:
        config: SuryaEncoderConfig = self.config
        if isinstance(config.image_size, tuple) and len(config.image_size) == 2:
            return config.image_size
        elif isinstance(config.image_size, int):
            return (config.image_size, config.image_size)

        raise ValueError(
            f"The `image_size` for SwinConfig should be a tuple of (int, int) or a single int but found {type(config.image_size)}"
        )

    @property
    def hidden_size(self) -> int:
        config: SuryaEncoderConfig = self.config
        return config.hidden_size

    def embed_images(
        self,
        image_batch: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        return super().forward(
            hidden_states=image_batch,
            grid_thw=grid_thw,
        )

```
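
A self-contained sketch (not part of the repository; the toy `cu_seqlens` below is hypothetical) of how the vision attention variants above turn cumulative sequence lengths into a block-diagonal mask, so packed images only attend to their own patch tokens. The eager path applies the same block structure as an additive 0/-inf mask, the XLA SDPA path as a boolean mask.

```python
import torch

# Hypothetical packed batch: two images contributing 3 and 2 patch tokens
cu_seqlens = torch.tensor([0, 3, 5])
seq_length = int(cu_seqlens[-1])

# True where attention is allowed: one diagonal block per packed image
mask = torch.zeros(seq_length, seq_length, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    mask[start:end, start:end] = True

print(mask.int())  # a 3x3 block of ones, then a 2x2 block of ones on the diagonal
```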