This is page 3 of 4. Use http://codebase.md/datalab-to/surya?page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmark │ ├── detection.py │ ├── layout.py │ ├── ordering.py │ ├── recognition.py │ ├── table_recognition.py │ ├── texify.py │ └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── CITATION.cff ├── CLA.md ├── detect_layout.py ├── detect_text.py ├── LICENSE ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ ├── fonts │ │ └── .gitignore │ └── images │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── arabic.jpg │ ├── benchmark_chart_small.png │ ├── benchmark_chart.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chi_hind.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── chinese.jpg │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── excerpt.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── funsd.png │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── hindi.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── japanese.jpg │ ├── latex_ocr.png │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── nyt.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── paper.jpg │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── pres.png │ ├── rec_acc_table.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── scanned.png │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ ├── textbook_text.jpg │ └── textbook.jpg ├── surya │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── adetr │ │ │ └── decoder.py │ │ ├── donut │ │ │ ├── encoder.py │ │ │ └── processor.py │ │ ├── load.py │ │ ├── polygon.py │ │ ├── predictor.py │ │ ├── pretrained.py │ │ ├── s3.py │ │ ├── surya │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── decoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── embedder │ │ │ │ └── __init__.py │ │ │ ├── encoder │ │ │ │ ├── __init__.py │ │ │ │ └── config.py │ │ │ ├── flash_attn_utils.py │ │ │ ├── processor │ │ │ │ ├── __init__.py │ │ │ │ ├── schema.py │ │ │ │ └── tokenizer.py │ │ │ └── schema.py │ │ ├── util.py │ │ └── xla.py │ ├── debug │ │ ├── draw.py │ │ ├── fonts.py │ │ ├── katex.js │ │ ├── render_html.py │ │ └── text.py │ ├── detection │ │ ├── __init__.py │ │ ├── heatmap.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoderdecoder.py │ │ ├── parallel.py │ │ ├── processor.py │ │ ├── schema.py │ │ └── util.py │ ├── foundation │ │ ├── __init__.py │ │ ├── cache │ │ │ ├── __init__.py │ │ │ ├── dynamic_ops.py │ │ │ └── static_ops.py │ │ ├── loader.py │ │ └── util.py │ ├── input │ │ ├── load.py │ │ └── processing.py │ ├── layout │ │ ├── __init__.py │ │ ├── label.py │ │ └── schema.py │ ├── logging.py │ ├── models.py │ ├── ocr_error │ │ ├── __init__.py │ │ ├── loader.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── encoder.py │ │ ├── schema.py │ │ └── tokenizer.py │ ├── recognition │ │ ├── __init__.py │ │ ├── languages.py │ │ ├── postprocessing.py │ │ ├── schema.py │ │ └── util.py │ ├── scripts │ │ ├── __init__.py │ │ ├── config.py │ │ ├── detect_layout.py │ │ ├── detect_text.py │ │ ├── finetune_ocr.py │ │ ├── hf_to_s3.py │ │ ├── ocr_latex.py │ │ ├── ocr_text.py │ │ ├── run_streamlit_app.py │ │ ├── run_texify_app.py │ │ ├── streamlit_app.py │ │ ├── table_recognition.py │ │ └── texify_app.py │ ├── settings.py │ └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests │ ├── assets │ │ └── test_latex.png │ ├── conftest.py │ ├── test_detection.py │ ├── test_foundation.py │ ├── test_latex_ocr.py │ ├── test_layout.py │ ├── test_ocr_errors.py │ ├── test_recognition.py │ └── test_table_rec.py └── texify_app.py ``` # Files -------------------------------------------------------------------------------- /surya/common/surya/processor/__init__.py: -------------------------------------------------------------------------------- ```python import math import cv2 import numpy as np import torch from PIL import Image from torch.nn.utils.rnn import pad_sequence from typing import List, Optional, Tuple from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils import PreTrainedTokenizer from surya.common.s3 import S3DownloaderMixin from surya.common.surya.processor.schema import ( TextInput, ImageInput, ProcessorOutput, ) from surya.common.surya.schema import TaskNames from surya.logging import get_logger from surya.settings import settings logger = get_logger() # Task agnostic tokens - Every task will use these in some form or another EOS_TOKEN = "</S>" EOI_TOKEN = "<EOI>" # This is end of INPUT, not image. Images are always followed by a task specific BOS token, so that serves as a delimiter anyways. IMAGE_TOKEN = "<IMAGE>" PAD_TOKEN = "<PAD>" NO_OUTPUT_TOKEN = "<NOP>" IMAGE_ROTATED_TOKEN = "<ROT>" REGISTER_TOKENS = ["<REG1>", "<REG2>", "<REG3>", "<REG4>"] BEACON_TOKEN = "<BEACON>" NOMATH_TOKEN = "<NO-MATH>" # Task specific tokens OCR_WITH_BOXES_BOS_TOKEN = "<OCR-WB>" OCR_WITHOUT_BOXES_BOS_TOKEN = "<OCR-WOB>" BLOCK_WITHOUT_BOXES_TOKEN = "<BLOCKS-WOB>" LAYOUT_BOS_TOKEN = "<LAYOUT>" TABLE_STRUCTURE_BOS_TOKEN = "<TABLE-STRUCTURE>" class SuryaOCRProcessor(S3DownloaderMixin, ProcessorMixin): attributes = ["image_processor", "ocr_tokenizer"] image_processor_class = "BaseImageProcessor" ocr_tokenizer_class = "PreTrainedTokenizer" rescale_factor = 1 / 255.0 image_mean = (0.485, 0.456, 0.406) image_std = (0.229, 0.224, 0.225) def __init__( self, ocr_tokenizer: PreTrainedTokenizer, blank_bbox_token_id: int, num_register_tokens: int, patch_size: int, merge_size: int, num_beacon_tokens: int, beacon_token_interval: int, model_device: str, **kwargs, ): self.ocr_tokenizer = ocr_tokenizer self.patch_size = patch_size self.merge_size = merge_size self.num_register_tokens = num_register_tokens self.num_beacon_tokens = num_beacon_tokens self.beacon_token_interval = beacon_token_interval self.tokenizer_vocab_size = 0 for attr in self.attributes: if "tokenizer" in attr: self.tokenizer_vocab_size += getattr(self, attr).vocab_size self.offsets = {"ocr": 0} # Create special token mapping self.special_token_mapping = self.ocr_tokenizer.system_tokens self.register_token_ids = [ self.special_token_mapping.get(r) for r in REGISTER_TOKENS ] self.beacon_token_id = self.special_token_mapping.get(BEACON_TOKEN) self.image_token_id = self.special_token_mapping.get(IMAGE_TOKEN) self.pad_token_id = self.special_token_mapping.get(PAD_TOKEN) self.eos_token_id = self.special_token_mapping.get(EOS_TOKEN) self.eoi_token_id = self.special_token_mapping.get(EOI_TOKEN) self.no_output_token = self.special_token_mapping.get(NO_OUTPUT_TOKEN) self.image_rotated_token = self.special_token_mapping.get(IMAGE_ROTATED_TOKEN) self.nomath_token = self.special_token_mapping.get(NOMATH_TOKEN) self.bos_token_id = { TaskNames.ocr_with_boxes: self.special_token_mapping.get( OCR_WITH_BOXES_BOS_TOKEN ), TaskNames.ocr_without_boxes: self.special_token_mapping.get( OCR_WITHOUT_BOXES_BOS_TOKEN ), TaskNames.block_without_boxes: self.special_token_mapping.get( BLOCK_WITHOUT_BOXES_TOKEN ), TaskNames.layout: self.special_token_mapping.get(LAYOUT_BOS_TOKEN), TaskNames.table_structure: self.special_token_mapping.get( TABLE_STRUCTURE_BOS_TOKEN ), } if self.image_token_id is None: logger.warning("Warning: Image token not found in special tokens") self.blank_bbox_token_id = blank_bbox_token_id self.bbox_pad_token_id = self.blank_bbox_token_id self.ignore_bbox_token_ids = [ v for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() if k not in self.ocr_tokenizer.special_tokens["math_external"] ] math_end_token = "</math>" self.math_start_token_ids = [ v for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() if k in self.ocr_tokenizer.special_tokens["math_external"] and k != math_end_token ] self.math_end_token_ids = [ v for (k, v) in self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING.items() if k == math_end_token ] if self.num_register_tokens > len(self.register_token_ids): raise ValueError( "The number of register tokens requested exceeds the number of register tokens defined in the special token mapping." ) self.image_mean = np.array(self.image_mean, dtype=np.float32) self.image_std = np.array(self.image_std, dtype=np.float32) self.model_device = model_device @property def vocab_size(self): return self.tokenizer_vocab_size def image_processor(self, image: Image.Image) -> np.ndarray: # Convert to array image = np.asarray(image, dtype=np.float32) return image @staticmethod def scale_to_fit( img: np.ndarray, max_size: Tuple[int, int], min_size: Tuple[int, int] = (168, 168), ): # Get current dimensions height, width = img.shape[:2] # Check for empty or invalid image if width == 0 or height == 0: return img max_width, max_height = max_size min_width, min_height = min_size # Calculate pixel counts current_pixels = width * height max_pixels = max_width * max_height min_pixels = min_width * min_height if current_pixels > max_pixels: scale_factor = (max_pixels / current_pixels) ** 0.5 new_width = math.floor(width * scale_factor) new_height = math.floor(height * scale_factor) elif current_pixels == 0: return img elif current_pixels < min_pixels: scale_factor = (min_pixels / current_pixels) ** 0.5 new_width = math.ceil(width * scale_factor) new_height = math.ceil(height * scale_factor) else: return img return cv2.resize( img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4 ) def _image_processor(self, image: np.ndarray): image = image.astype(np.float64) * self.rescale_factor image = (image.astype(np.float32) - self.image_mean) / self.image_std return image def _process_and_tile( self, image: np.ndarray ) -> Tuple[torch.Tensor, Tuple[int, int, int]]: """ Resizes the input image to the closest multiple of tile_size while preserving the aspect ratio and returns a tensor of image tiles. """ extra_multipler = ( 4 if settings.FOUNDATION_XLA else 1 ) # Needed to force same size grid_thws per row with padding factor = ( self.patch_size * self.merge_size * extra_multipler ) # Make a multiple of window size height, width = image.shape[:2] h_bar = math.ceil(height / factor) * factor w_bar = math.ceil(width / factor) * factor if h_bar != height or w_bar != width: if height == 0 or width == 0: image = np.zeros((h_bar, w_bar, 3), dtype=np.uint8) else: image = cv2.resize(image, (w_bar, h_bar), interpolation=cv2.INTER_CUBIC) # Handle scaling and normalization image = self._image_processor(image) height, width = image.shape[:2] # Numpy array to torch tensor img_tensor = torch.from_numpy(image.transpose(2, 0, 1)) patches = img_tensor.unsqueeze(0) channel = patches.shape[1] grid_t = patches.shape[0] grid_h, grid_w = height // self.patch_size, width // self.patch_size patches = patches.reshape( grid_t, 1, channel, grid_h // self.merge_size, self.merge_size, self.patch_size, grid_w // self.merge_size, self.merge_size, self.patch_size, ) patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) flatten_patches = patches.reshape( grid_t * grid_h * grid_w, channel * 1 * self.patch_size * self.patch_size ) return flatten_patches, (grid_t, grid_h, grid_w) # Handle image input dictionaries - Process image, tile accordingly, and setup the input ids and boxes correspondingly def _process_image_input(self, image_input: ImageInput) -> ProcessorOutput: rotated = image_input.get("rotated", False) image = image_input.get("image", None) assert image is not None, ( "A PIL Image must be provided when the input type is `image`" ) image_tiles, grid_thw = self._process_and_tile(image) num_tokens = image_tiles.shape[0] / self.merge_size**2 assert num_tokens.is_integer(), ( f"Expected number of tokens to be an integer, got {num_tokens}" ) input_ids = [self.image_token_id] * int(num_tokens) input_ids += self.register_token_ids[: self.num_register_tokens] # Handle the image being rotated in the imdataset if rotated: input_ids = [self.image_rotated_token] + input_ids return ProcessorOutput( input_ids=input_ids, image_tiles=image_tiles, grid_thw=grid_thw, ) def _process_text_input(self, text_input: TextInput, task: str) -> ProcessorOutput: input_text = text_input.get("text", None) math_mode = text_input.get("math", False) input_ids = self.ocr_tokenizer(input_text, tasks=task)["input_ids"][0] input_ids = [self.offsets["ocr"] + id for id in input_ids] # nomath token does not work for layout if not math_mode and task != "layout": input_ids.insert(0, self.nomath_token) return ProcessorOutput( input_ids=input_ids, image_tiles=None, grid_thw=None, ) def _process_input(self, input_dict: dict, task: str): input_type = input_dict["type"] if input_type == "image": return self._process_image_input(input_dict) elif input_type == "text": return self._process_text_input(input_dict, task) raise NotImplementedError(f"Input of type `{input_type}` is not implemented") # Peprocessing for OCR task # The task is expected to have - image_dict, user_input_dict, output_dict # use_input_dict is allowed to have an empty input which is fine, but needs to be present def _process_ocr_with_boxes( self, mixed_input: List[dict], bos_token_id: int, task: str = TaskNames.ocr_with_boxes, ): processed_input_ids = [] all_image_tiles = [] all_grid_thw = [] # 1. Process the image input for i, input_dict in enumerate(mixed_input): processor_output = self._process_input(input_dict, task) input_ids = processor_output["input_ids"] image_tiles = processor_output["image_tiles"] grid_thw = processor_output["grid_thw"] # Special handling of some delimiter tokens if i == 1: assert input_dict["type"] == "text", ( "Expected text input for model input." ) # Case for input - Add task specific bos token + end_of_input token # We do not want the model to learn how to predict inputs. Hence IGNORE_INDEX for these input_ids = [bos_token_id] + input_ids + [self.eoi_token_id] if i == 2: assert input_dict["type"] == "text", ( "Expected text for final model input" ) input_ids = input_ids + [self.eos_token_id] elif i > 2: raise ValueError(f"Too many inputs received. Expected is 2 for inference, 3 for training. Received: {len(mixed_input)}") # Some input types don't return any image tiles, accounting for that if image_tiles is not None: all_image_tiles.append(image_tiles) all_grid_thw.append(grid_thw) processed_input_ids.extend(input_ids) return ( torch.tensor(processed_input_ids, dtype=torch.long), all_image_tiles, all_grid_thw, ) def _process_layout(self, mixed_input: List[dict], bos_token_id: int): return self._process_ocr_with_boxes( mixed_input, bos_token_id=bos_token_id, task="layout" ) def _process_table_structure(self, mixed_input: List[dict], bos_token_id: int): return self._process_ocr_with_boxes( mixed_input, bos_token_id=bos_token_id, task="table_structure" ) def _process_ocr_without_boxes( self, mixed_input: List[dict], bos_token_id: int, task: str = "ocr_without_boxes", ): # Boxes are set to None, so this will work # TODO: improve this behavior return self._process_ocr_with_boxes( mixed_input, bos_token_id=bos_token_id, task=task ) def _process_block_without_boxes( self, mixed_input: List[dict], bos_token_id: int, task: str = "block_without_boxes", ): return self._process_ocr_with_boxes( mixed_input, bos_token_id=bos_token_id, task=task ) def align_long_axis(self, image: np.ndarray) -> Tuple[np.ndarray, bool]: height, width, _ = image.shape if height > width: # Rotate vertical lines image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE) return image, True return image, False def __call__( self, mixed_batch: List[dict], padding_side: Optional[str] = "left", device: Optional[torch.device] = None, pad_to_multiple: Optional[int] = None, ): all_image_tiles = [] all_input_ids = [] all_grid_thw = [] for b in mixed_batch: mixed_input = b["inputs"] task = b["task"] assert task in self.bos_token_id, f"Task {task} has no bos token defined." # Select the correct processing function based on the task type input_ids, image_tiles, grid_thw = getattr(self, f"_process_{task}")( mixed_input, self.bos_token_id[task] ) all_input_ids.append(input_ids) all_image_tiles.extend(image_tiles) all_grid_thw.extend(grid_thw) batched_input_ids = pad_sequence( all_input_ids, batch_first=True, padding_side=padding_side, padding_value=self.pad_token_id, ) if pad_to_multiple is not None: current_len = batched_input_ids.shape[1] # Calculate the next multiple of pad_to_multiple padded_len = ( (current_len + pad_to_multiple - 1) // pad_to_multiple ) * pad_to_multiple if padded_len > current_len: pad_len = padded_len - current_len batched_input_ids = torch.nn.functional.pad( batched_input_ids, (pad_len, 0), value=self.pad_token_id ) attention_mask = batched_input_ids.ne(self.pad_token_id) # Generating position IDs that are independent of left and right padding; # This should ensure same results for either padding side. Exact position id for the pad tokens themselves don't matter since they are masked position_ids = attention_mask.cumsum(dim=-1) - 1 position_ids[position_ids < 0] = ( 0 # For left padding, the position ids for padding will become -1 because of the shift; Setting to 0 ) position_ids = ( attention_mask.to(torch.long) * position_ids ) # Ensure right pad ids get set to zero batched_image_tiles = torch.cat(all_image_tiles, dim=0) batched_grid_thw = torch.from_numpy(np.array(all_grid_thw)) # Pin memory for CUDA if device == torch.device("cuda"): batched_image_tiles = batched_image_tiles.pin_memory() batched_grid_thw = batched_grid_thw.pin_memory() attention_mask = attention_mask.pin_memory() batched_input_ids = batched_input_ids.pin_memory() position_ids = position_ids.pin_memory() return BatchFeature( { "input_ids": batched_input_ids, "image_tiles": batched_image_tiles, "attention_mask": attention_mask, "position_ids": position_ids, "grid_thw": batched_grid_thw, } ) # Decode model outputs; Strips special tokens def decode(self, tokens: List[int], task: str): filtered_tokens = [ t for t in tokens if t not in self.special_token_mapping.values() and t != -100 ] # Skip special tokens and loss ignore index return self.ocr_tokenizer.decode(filtered_tokens, task=task) ``` -------------------------------------------------------------------------------- /surya/recognition/__init__.py: -------------------------------------------------------------------------------- ```python from __future__ import annotations import re from typing import List import numpy as np import torch from PIL import Image import torch.nn.functional as F from surya.common.polygon import PolygonBox from surya.common.surya.processor import NOMATH_TOKEN from surya.common.predictor import BasePredictor from surya.detection import DetectionPredictor from surya.foundation import FoundationPredictor from surya.input.processing import ( convert_if_not_rgb, slice_polys_from_image, slice_bboxes_from_image, ) from surya.recognition.postprocessing import fix_unbalanced_tags from surya.recognition.util import ( sort_text_lines, clean_close_polygons, unwrap_math, clean_math_tags, filter_blacklist_tags, words_from_chars ) from surya.foundation.util import detect_repeat_token, prediction_to_polygon_batch from surya.recognition.schema import TextLine, OCRResult, TextChar from surya.common.surya.schema import TaskNames from surya.settings import settings from surya.logging import get_logger, configure_logging configure_logging() logger = get_logger() class RecognitionPredictor(BasePredictor): batch_size = settings.RECOGNITION_BATCH_SIZE default_batch_sizes = {"cpu": 32, "mps": 64, "cuda": 256, "xla": 128} # Override base init - Do not load model def __init__(self, foundation_predictor: FoundationPredictor): self.foundation_predictor = foundation_predictor self.processor = self.foundation_predictor.processor self.bbox_size = self.foundation_predictor.model.config.bbox_size self.tasks = self.foundation_predictor.tasks # Special handling for disable tqdm to pass into foundation predictor # Make sure they are kept in sync @property def disable_tqdm(self) -> bool: return super().disable_tqdm @disable_tqdm.setter def disable_tqdm(self, value: bool) -> None: self._disable_tqdm = bool(value) self.foundation_predictor.disable_tqdm = bool(value) def detect_and_slice_bboxes( self, images: List[Image.Image], task_names: List[str], det_predictor: DetectionPredictor, detection_batch_size: int | None = None, highres_images: List[Image.Image] | None = None, ): det_predictions = det_predictor(images, batch_size=detection_batch_size) all_slices = [] slice_map = [] all_polygons = [] all_task_names = [] all_res_scales = [] for idx, (det_pred, image, highres_image, task_name) in enumerate( zip(det_predictions, images, highres_images, task_names) ): polygons = [p.polygon for p in det_pred.bboxes] if highres_image: width_scaler = highres_image.size[0] / image.size[0] height_scaler = highres_image.size[1] / image.size[1] scaled_polygons = [ [ [int(p[0] * width_scaler), int(p[1] * height_scaler)] for p in polygon ] for polygon in polygons ] highres_image = self.processor.image_processor(highres_image) slices = slice_polys_from_image(highres_image, scaled_polygons) res_scales = [(width_scaler, height_scaler) for _ in range(len(slices))] else: image = self.processor.image_processor(image) slices = slice_polys_from_image(image, polygons) res_scales = [(1, 1) for _ in range(len(slices))] slice_map.append(len(slices)) all_slices.extend(slices) all_polygons.extend(polygons) all_task_names.extend([task_name] * len(slices)) all_res_scales.extend(res_scales) assert ( len(all_slices) == sum(slice_map) == len(all_polygons) == len(all_task_names) == len(all_res_scales) ) return { "slices": all_slices, "slice_map": slice_map, "polygons": all_polygons, "task_names": all_task_names, "input_text": [None] * len(all_slices), "res_scales": all_res_scales, } def slice_bboxes( self, images: List[Image.Image], task_names: List[str], bboxes: List[List[List[int]]] | None = None, polygons: List[List[List[List[int]]]] | None = None, input_text: List[List[str | None]] | None = None, ) -> dict: assert bboxes is not None or polygons is not None slice_map = [] all_slices = [] all_polygons = [] all_text = [] all_task_names = [] for idx, image in enumerate(images): image = self.processor.image_processor(image) if polygons is not None: polys = polygons[idx] slices = slice_polys_from_image(image, polys) else: slices = slice_bboxes_from_image(image, bboxes[idx]) polys = [ [ [bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]], ] for bbox in bboxes[idx] ] slice_map.append(len(slices)) all_slices.extend(slices) all_polygons.extend(polys) all_task_names.extend([task_names[idx]] * len(slices)) if input_text is None: all_text.extend([None] * len(slices)) else: all_text.extend(input_text[idx]) assert ( len(all_slices) == sum(slice_map) == len(all_polygons) == len(all_text) == len(all_task_names) ), ( f"Mismatch in lengths: {len(all_slices)}, {sum(slice_map)}, {len(all_polygons)}, {len(all_text)}, {len(all_task_names)}" ) return { "slices": all_slices, "slice_map": slice_map, "polygons": all_polygons, "input_text": all_text, "task_names": all_task_names, "res_scales": [(1, 1) for _ in range(len(all_slices))], } def get_bboxes_text( self, flat: dict, predicted_tokens: list, scores: list, predicted_polygons: list, drop_repeated_text: bool = False, ) -> list: char_predictions = [] needs_boxes = [ self.tasks[task_name]["needs_bboxes"] for task_name in flat["task_names"] ] for slice_idx, ( slice_image, image_tokens, image_polygons, image_scores, needs_box, ) in enumerate( zip( flat["slices"], predicted_tokens, predicted_polygons, scores, needs_boxes, ) ): blank_bbox = [[0, 0], [0, 1], [1, 1], [1, 0]] if self.processor.no_output_token in image_tokens: char_predictions.append(None) continue # If the image is very out of distribution, we can get nonsense repeats, and we may need to drop the text entirely if drop_repeated_text and detect_repeat_token(image_tokens): char_predictions.append( [ TextChar( text="", polygon=blank_bbox, confidence=0, bbox_valid=False, ) ] ) continue image_polygons = image_polygons[: len(image_tokens)].cpu().numpy().tolist() detokenize_sequences = [] detokenize_sequence = [] past_char_qwen_token = False def _add_detokenize_sequence( special_token: bool, past_special_token: bool, force: bool = False, ): nonlocal detokenize_sequence, detokenize_sequences if ( special_token or past_special_token or force ) and detokenize_sequence: chars = [dt[0] for dt in detokenize_sequence] scores = [dt[1] for dt in detokenize_sequence] bboxes = [dt[2] for dt in detokenize_sequence] if past_special_token: detokenize_sequences.append((chars, scores, None, "special")) else: detokenize_sequences.append((chars, scores, bboxes, "ocr")) detokenize_sequence = [] # Split up into sequences to detokenize separately past_special_token = False for bbox, char_id, score in zip(image_polygons, image_tokens, image_scores): if char_id in [ self.processor.eos_token_id, self.processor.pad_token_id, ]: break special_token = ( char_id >= self.processor.ocr_tokenizer.ocr_tokenizer.SPECIAL_BASE ) _add_detokenize_sequence( special_token, past_special_token ) detokenize_sequence.append((char_id, score, bbox)) past_special_token = special_token _add_detokenize_sequence( False, past_special_token, force=True ) img_chars = [] for sequence in detokenize_sequences: token_ids, seq_score, bboxes, token_type = sequence if token_type == "ocr": text = self.processor.ocr_tokenizer.decode( token_ids, task=TaskNames.ocr_with_boxes ) bboxes = clean_close_polygons( bboxes ) # clean out bboxes that are close, like what happens with multiple utf-16 tokens per char bbox_idx = 0 for text_idx, text_line in enumerate(text): img_chars.append( TextChar( text=text_line, polygon=bboxes[bbox_idx], confidence=seq_score[bbox_idx], bbox_valid=True, ) ) # Ensure we don't exceed the bbox count # Use the last bbox for the rest of the text if bbox_idx < len(bboxes) - 1: bbox_idx += 1 elif token_type == "special": text = self.processor.ocr_tokenizer.decode( token_ids, task="ocr_without_boxes" ) if text in [NOMATH_TOKEN] or re.match(r"<SCRIPT-\w+>", text): continue img_chars.append( TextChar( text=text, polygon=blank_bbox, confidence=seq_score[0], bbox_valid=False, ) ) else: text = self.processor.ocr_tokenizer.decode( token_ids, task=TaskNames.block_without_boxes ) img_chars.append( TextChar( text=text, polygon=blank_bbox, confidence=seq_score[0], bbox_valid=False, ) ) char_predictions.append(img_chars) return char_predictions def __call__( self, images: List[Image.Image], task_names: List[str] | None = None, det_predictor: DetectionPredictor | None = None, detection_batch_size: int | None = None, recognition_batch_size: int | None = None, highres_images: List[Image.Image] | None = None, bboxes: List[List[List[int]]] | None = None, polygons: List[List[List[List[int]]]] | None = None, input_text: List[List[str | None]] | None = None, sort_lines: bool = False, math_mode: bool = True, return_words: bool = False, drop_repeated_text: bool = False, max_sliding_window: int | None = None, max_tokens: int | None = None, filter_tag_list: List[str] = None ) -> List[OCRResult]: if task_names is None: task_names = [TaskNames.ocr_with_boxes] * len(images) if recognition_batch_size is None: recognition_batch_size = self.get_batch_size() assert len(images) == len(task_names), ( "You need to pass in one task name for each image" ) images = convert_if_not_rgb(images) if highres_images is not None: assert len(images) == len(highres_images), ( "You need to pass in one highres image for each image" ) highres_images = ( convert_if_not_rgb(highres_images) if highres_images is not None else [None] * len(images) ) if bboxes is None and polygons is None: assert det_predictor is not None, ( "You need to pass in a detection predictor if you don't provide bboxes or polygons" ) # Detect then slice flat = self.detect_and_slice_bboxes( images, task_names, det_predictor, detection_batch_size=detection_batch_size, highres_images=highres_images, ) else: if bboxes is not None: assert len(images) == len(bboxes), ( "You need to pass in one list of bboxes for each image" ) if polygons is not None: assert len(images) == len(polygons), ( "You need to pass in one list of polygons for each image" ) flat = self.slice_bboxes( images, bboxes=bboxes, polygons=polygons, input_text=input_text, task_names=task_names, ) # No images passed, or no boxes passed, or no text detected in the images if len(flat["slices"]) == 0: return [ OCRResult( text_lines=[], image_bbox=[0, 0, im.size[0], im.size[1]] ) for im in images ] # Sort by image sizes. Negative so that longer images come first, fits in with continuous batching better sorted_pairs = sorted( enumerate(flat["slices"]), key=lambda x: -(x[1].shape[0] * x[1].shape[1]) # height * width ) indices, sorted_slices = zip(*sorted_pairs) # Reorder input_text and task_names based on the new order flat["slices"] = list(sorted_slices) flat["input_text"] = [flat["input_text"][i] for i in indices] flat["task_names"] = [flat["task_names"][i] for i in indices] # Make predictions predicted_tokens, batch_bboxes, scores, _ = self.foundation_predictor.prediction_loop( images=flat["slices"], input_texts=flat["input_text"], task_names=flat["task_names"], batch_size=recognition_batch_size, math_mode=math_mode, drop_repeated_tokens=True, max_lookahead_tokens=self.foundation_predictor.model.config.multi_output_distance, max_sliding_window=max_sliding_window, max_tokens=max_tokens, tqdm_desc="Recognizing Text" ) # Get text and bboxes in structured form bbox_size = self.bbox_size image_sizes = [img.shape for img in flat["slices"]] predicted_polygons = prediction_to_polygon_batch( batch_bboxes, image_sizes, bbox_size, bbox_size // 2 ) char_predictions = self.get_bboxes_text( flat, predicted_tokens, scores, predicted_polygons, drop_repeated_text=drop_repeated_text, ) char_predictions = sorted(zip(indices, char_predictions), key=lambda x: x[0]) char_predictions = [pred for _, pred in char_predictions] predictions_by_image = [] slice_start = 0 for idx, image in enumerate(images): slice_end = slice_start + flat["slice_map"][idx] image_lines = char_predictions[slice_start:slice_end] polygons = flat["polygons"][slice_start:slice_end] res_scales = flat["res_scales"][slice_start:slice_end] slice_start = slice_end lines = [] for text_line, polygon, res_scale in zip(image_lines, polygons, res_scales): # Special case when input text is good if not text_line: lines.append( TextLine( text="", polygon=polygon, chars=[], confidence=1, original_text_good=True, ) ) else: confidence = ( float(np.mean([char.confidence for char in text_line])) if len(text_line) > 0 else 0 ) poly_box = PolygonBox(polygon=polygon) for char in text_line: char.rescale( res_scale, (1, 1) ) # Rescale from highres if needed char.shift( poly_box.bbox[0], poly_box.bbox[1] ) # Ensure character boxes match line boxes (relative to page) char.clamp(poly_box.bbox) text_line = fix_unbalanced_tags( text_line, self.processor.ocr_tokenizer.special_tokens ) text_line = filter_blacklist_tags(text_line, filter_tag_list) text = "".join([char.text for char in text_line]) text = unwrap_math(text) text = clean_math_tags(text) lines.append( TextLine( text=text, polygon=polygon, chars=text_line, confidence=confidence, words=words_from_chars(text_line, poly_box) if return_words else [], ) ) if sort_lines: lines = sort_text_lines(lines) predictions_by_image.append( OCRResult( text_lines=lines, image_bbox=[0, 0, image.size[0], image.size[1]] ) ) return predictions_by_image ``` -------------------------------------------------------------------------------- /surya/common/surya/decoder/__init__.py: -------------------------------------------------------------------------------- ```python from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import ( Cache, ) from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import ( BaseModelOutputWithPast, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.processing_utils import Unpack from transformers.utils import ( logging, ) from surya.common.pretrained import SuryaPreTrainedModel from surya.common.surya.decoder.config import SuryaDecoderConfig logger = logging.get_logger(__name__) class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query.dtype ) attn_weights = nn.functional.dropout( attn_weights, p=dropout, training=module.training ) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Qwen2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: SuryaDecoderConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads ) self.num_key_value_groups = ( config.num_attention_heads // config.num_key_value_heads ) self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=True ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=False ) def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, cache_idxs: Optional[List[int]] = None, num_valid_tokens: Optional[List[int]] = None, text_lengths: Optional[List[int]] = None, prefill: bool = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism cache_kwargs = { "sin": sin, "cos": cos, "cache_position": cache_position, "cache_idxs": cache_idxs, "num_valid_tokens": num_valid_tokens, "prefill": prefill, "text_lengths": text_lengths, } key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get( "output_attentions", False ): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) elif self.config._attn_implementation == "flash_attention_2": # Needed for CPU -> GPU from surya.common.surya.flash_attn_utils import ( flash_attn_decode, flash_attn_prefill, ) if prefill: attention_interface = flash_attn_prefill else: attention_interface = flash_attn_decode else: attention_interface = ALL_ATTENTION_FUNCTIONS[ self.config._attn_implementation ] """ IMPORTANT: We sometimes use a custom sliding window impl. during training We force this to None to ensure that the HF attention integrations do not perform any special handling - FA2 in particular will ignore the 4D mask, and use this instead to infer the final mask SDPA ignores this completely, and is fully dependent on the 4D mask - (https://github.com/huggingface/transformers/blob/b9faf2f93085e3cf2c65184a69d1d9e502f95786/src/transformers/integrations/sdpa_attention.py#L23) """ sliding_window = None attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=sliding_window, # main diff with Llama **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Qwen2RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Qwen2DecoderLayer(nn.Module): def __init__(self, config: SuryaDecoderConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, cache_idxs: Optional[List[int]] = None, num_valid_tokens: Optional[List[int]] = None, text_lengths: Optional[List[int]] = None, prefill: bool = False, position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, cache_idxs=cache_idxs, num_valid_tokens=num_valid_tokens, text_lengths=text_lengths, prefill=prefill, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class Qwen2RotaryEmbedding(nn.Module): def __init__(self, config: SuryaDecoderConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get( "rope_type", config.rope_scaling.get("type") ) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len ) self.register_buffer( "inv_freq", inv_freq, persistent=False ) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if ( seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len ): # reset # This .to() is needed if the model has been moved to a device after being initialized (because # the buffer is automatically moved, but not the original copy) self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block inv_freq_expanded = ( self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) ) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = ( device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" ) with torch.autocast(device_type=device_type, enabled=False): freqs = ( inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float() ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention cos = cos * self.attention_scaling sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class Qwen2PreTrainedModel(SuryaPreTrainedModel): config_class = SuryaDecoderConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class SuryaDecoderModel(Qwen2PreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] This variant has been modified to remove the embedding layer completely - It only supports inputs_embeds as an input Args: config: Qwen2Config """ def __init__(self, config: SuryaDecoderConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.layers = nn.ModuleList( [ Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] ) self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def forward( self, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, cache_idxs: Optional[List[int]] = None, num_valid_tokens: Optional[List[int]] = None, text_lengths: Optional[List[int]] = None, prefill: bool = False, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if inputs_embeds is None: raise ValueError("You must specify inputs_embeds") if cache_position is None: raise ValueError("You must specify cache_position") if position_ids is None: raise ValueError("You must specify position_ids") hidden_states = inputs_embeds causal_mask = ( attention_mask # We make the 4D mask in the combined model when needed ) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers for decoder_layer in self.layers[: self.config.num_hidden_layers]: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, cache_idxs=cache_idxs, num_valid_tokens=num_valid_tokens, prefill=prefill, text_lengths=text_lengths, **flash_attn_kwargs, ) hidden_states = layer_outputs[0] hidden_states = self.norm(hidden_states) output = BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) return output if return_dict else output.to_tuple() ``` -------------------------------------------------------------------------------- /surya/ocr_error/tokenizer.py: -------------------------------------------------------------------------------- ```python import collections import os import json import unicodedata from typing import List, Optional, Tuple from tokenizers import normalizers from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from surya.common.s3 import S3DownloaderMixin VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} # Copied from transformers.models.bert.tokenization_bert.load_vocab def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() with open(vocab_file, "r", encoding="utf-8") as reader: tokens = reader.readlines() for index, token in enumerate(tokens): token = token.rstrip("\n") vocab[token] = index return vocab # Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize def whitespace_tokenize(text): """Runs basic whitespace cleaning and splitting on a piece of text.""" text = text.strip() if not text: return [] tokens = text.split() return tokens class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer): r""" Construct a DistilBERT tokenizer. Based on WordPiece. This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. Args: vocab_file (`str`): File containing the vocabulary. do_lower_case (`bool`, *optional*, defaults to `True`): Whether or not to lowercase the input when tokenizing. do_basic_tokenize (`bool`, *optional*, defaults to `True`): Whether or not to do basic tokenization before WordPiece. never_split (`Iterable`, *optional*): Collection of tokens which will never be split during tokenization. Only has an effect when `do_basic_tokenize=True` unk_token (`str`, *optional*, defaults to `"[UNK]"`): The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead. sep_token (`str`, *optional*, defaults to `"[SEP]"`): The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens. pad_token (`str`, *optional*, defaults to `"[PAD]"`): The token used for padding, for example when batching sequences of different lengths. cls_token (`str`, *optional*, defaults to `"[CLS]"`): The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens. mask_token (`str`, *optional*, defaults to `"[MASK]"`): The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict. tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this [issue](https://github.com/huggingface/transformers/issues/328)). strip_accents (`bool`, *optional*): Whether or not to strip all accents. If this option is not specified, then it will be determined by the value for `lowercase` (as in the original BERT). """ vocab_files_names = VOCAB_FILES_NAMES model_input_names = ["input_ids", "attention_mask"] def __init__( self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", mask_token="[MASK]", tokenize_chinese_chars=True, strip_accents=None, **kwargs, ): if not os.path.isfile(vocab_file): raise ValueError( f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" " model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" ) self.vocab = load_vocab(vocab_file) self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) self.do_basic_tokenize = do_basic_tokenize if do_basic_tokenize: self.basic_tokenizer = BasicTokenizer( do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars, strip_accents=strip_accents, ) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) super().__init__( do_lower_case=do_lower_case, do_basic_tokenize=do_basic_tokenize, never_split=never_split, unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, tokenize_chinese_chars=tokenize_chinese_chars, strip_accents=strip_accents, **kwargs, ) @property # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case def do_lower_case(self): return self.basic_tokenizer.do_lower_case @property # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size def vocab_size(self): return len(self.vocab) # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab def get_vocab(self): return dict(self.vocab, **self.added_tokens_encoder) # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize def _tokenize(self, text, split_special_tokens=False): split_tokens = [] if self.do_basic_tokenize: for token in self.basic_tokenizer.tokenize( text, never_split=self.all_special_tokens if not split_special_tokens else None ): # If the token is part of the never_split set if token in self.basic_tokenizer.never_split: split_tokens.append(token) else: split_tokens += self.wordpiece_tokenizer.tokenize(token) else: split_tokens = self.wordpiece_tokenizer.tokenize(text) return split_tokens # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" return self.vocab.get(token, self.vocab.get(self.unk_token)) # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" return self.ids_to_tokens.get(index, self.unk_token) # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string def convert_tokens_to_string(self, tokens): """Converts a sequence of tokens (string) in a single string.""" out_string = " ".join(tokens).replace(" ##", "").strip() return out_string # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. A BERT sequence has the following format: - single sequence: `[CLS] X [SEP]` - pair of sequences: `[CLS] A [SEP] B [SEP]` Args: token_ids_0 (`List[int]`): List of IDs to which the special tokens will be added. token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs. Returns: `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. """ if token_ids_1 is None: return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] cls = [self.cls_token_id] sep = [self.sep_token_id] return cls + token_ids_0 + sep + token_ids_1 + sep # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. Args: token_ids_0 (`List[int]`): List of IDs. token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs. already_has_special_tokens (`bool`, *optional*, defaults to `False`): Whether or not the token list is already formatted with special tokens for the model. Returns: `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. """ if already_has_special_tokens: return super().get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True ) if token_ids_1 is not None: return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] return [1] + ([0] * len(token_ids_0)) + [1] # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence pair mask has the following format: ``` 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second sequence | ``` If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). Args: token_ids_0 (`List[int]`): List of IDs. token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs. Returns: `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). """ sep = [self.sep_token_id] cls = [self.cls_token_id] if token_ids_1 is None: return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: index = 0 if os.path.isdir(save_directory): vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) else: vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory with open(vocab_file, "w", encoding="utf-8") as writer: for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): if index != token_index: # logger.warning( # f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." # " Please check that the vocabulary is not corrupted!" # ) index = token_index writer.write(token + "\n") index += 1 return (vocab_file,) # Copied from transformers.models.bert.tokenization_bert.BasicTokenizer class BasicTokenizer(object): """ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). Args: do_lower_case (`bool`, *optional*, defaults to `True`): Whether or not to lowercase the input when tokenizing. never_split (`Iterable`, *optional*): Collection of tokens which will never be split during tokenization. Only has an effect when `do_basic_tokenize=True` tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this [issue](https://github.com/huggingface/transformers/issues/328)). strip_accents (`bool`, *optional*): Whether or not to strip all accents. If this option is not specified, then it will be determined by the value for `lowercase` (as in the original BERT). do_split_on_punc (`bool`, *optional*, defaults to `True`): In some instances we want to skip the basic punctuation splitting so that later tokenization can capture the full context of the words, such as contractions. """ def __init__( self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None, do_split_on_punc=True, ): if never_split is None: never_split = [] self.do_lower_case = do_lower_case self.never_split = set(never_split) self.tokenize_chinese_chars = tokenize_chinese_chars self.strip_accents = strip_accents self.do_split_on_punc = do_split_on_punc def tokenize(self, text, never_split=None): """ Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. Args: never_split (`List[str]`, *optional*) Kept for backward compatibility purposes. Now implemented directly at the base class level (see [`PreTrainedTokenizer.tokenize`]) List of token not to split. """ # union() returns a new set by concatenating the two sets. never_split = self.never_split.union(set(never_split)) if never_split else self.never_split text = self._clean_text(text) # This was added on November 1st, 2018 for the multilingual and Chinese # models. This is also applied to the English models now, but it doesn't # matter since the English models were not trained on any Chinese data # and generally don't have any Chinese data in them (there are Chinese # characters in the vocabulary because Wikipedia does have some Chinese # words in the English Wikipedia.). if self.tokenize_chinese_chars: text = self._tokenize_chinese_chars(text) # prevents treating the same character with different unicode codepoints as different characters unicode_normalized_text = unicodedata.normalize("NFC", text) orig_tokens = whitespace_tokenize(unicode_normalized_text) split_tokens = [] for token in orig_tokens: if token not in never_split: if self.do_lower_case: token = token.lower() if self.strip_accents is not False: token = self._run_strip_accents(token) elif self.strip_accents: token = self._run_strip_accents(token) split_tokens.extend(self._run_split_on_punc(token, never_split)) output_tokens = whitespace_tokenize(" ".join(split_tokens)) return output_tokens def _run_strip_accents(self, text): """Strips accents from a piece of text.""" text = unicodedata.normalize("NFD", text) output = [] for char in text: cat = unicodedata.category(char) if cat == "Mn": continue output.append(char) return "".join(output) def _run_split_on_punc(self, text, never_split=None): """Splits punctuation on a piece of text.""" if not self.do_split_on_punc or (never_split is not None and text in never_split): return [text] chars = list(text) i = 0 start_new_word = True output = [] while i < len(chars): char = chars[i] if _is_punctuation(char): output.append([char]) start_new_word = True else: if start_new_word: output.append([]) start_new_word = False output[-1].append(char) i += 1 return ["".join(x) for x in output] def _tokenize_chinese_chars(self, text): """Adds whitespace around any CJK character.""" output = [] for char in text: cp = ord(char) if self._is_chinese_char(cp): output.append(" ") output.append(char) output.append(" ") else: output.append(char) return "".join(output) def _is_chinese_char(self, cp): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as anything in the CJK Unicode block: # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) # # Note that the CJK Unicode block is NOT all Japanese and Korean characters, # despite its name. The modern Korean Hangul alphabet is a different block, # as is Japanese Hiragana and Katakana. Those alphabets are used to write # space-separated words, so they are not treated specially and handled # like the all of the other languages. if ( (cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) # or (cp >= 0x20000 and cp <= 0x2A6DF) # or (cp >= 0x2A700 and cp <= 0x2B73F) # or (cp >= 0x2B740 and cp <= 0x2B81F) # or (cp >= 0x2B820 and cp <= 0x2CEAF) # or (cp >= 0xF900 and cp <= 0xFAFF) or (cp >= 0x2F800 and cp <= 0x2FA1F) # ): # return True return False def _clean_text(self, text): """Performs invalid character removal and whitespace cleanup on text.""" output = [] for char in text: cp = ord(char) if cp == 0 or cp == 0xFFFD or _is_control(char): continue if _is_whitespace(char): output.append(" ") else: output.append(char) return "".join(output) # Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer class WordpieceTokenizer(object): """Runs WordPiece tokenization.""" def __init__(self, vocab, unk_token, max_input_chars_per_word=100): self.vocab = vocab self.unk_token = unk_token self.max_input_chars_per_word = max_input_chars_per_word def tokenize(self, text): """ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform tokenization using the given vocabulary. For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. Args: text: A single token or whitespace separated tokens. This should have already been passed through *BasicTokenizer*. Returns: A list of wordpiece tokens. """ output_tokens = [] for token in whitespace_tokenize(text): chars = list(token) if len(chars) > self.max_input_chars_per_word: output_tokens.append(self.unk_token) continue is_bad = False start = 0 sub_tokens = [] while start < len(chars): end = len(chars) cur_substr = None while start < end: substr = "".join(chars[start:end]) if start > 0: substr = "##" + substr if substr in self.vocab: cur_substr = substr break end -= 1 if cur_substr is None: is_bad = True break sub_tokens.append(cur_substr) start = end if is_bad: output_tokens.append(self.unk_token) else: output_tokens.extend(sub_tokens) return output_tokens ``` -------------------------------------------------------------------------------- /surya/detection/model/encoderdecoder.py: -------------------------------------------------------------------------------- ```python """ This is an implementation of efficientvit, with some modifications (decode head, etc). Original paper at https://arxiv.org/abs/2205.14756 Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit """ from __future__ import annotations from typing import Optional, Union, Tuple, List, Any from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from transformers.modeling_outputs import SemanticSegmenterOutput from surya.common.pretrained import SuryaPreTrainedModel from surya.common.s3 import S3DownloaderMixin from surya.detection.model.config import EfficientViTConfig def val2list(x: Union[List, Tuple, Any], repeat_time=1): if isinstance(x, (list, tuple)): return list(x) return [x for _ in range(repeat_time)] def val2tuple(x: Union[List, Tuple, Any], min_len: int = 1, idx_repeat: int = -1): # repeat elements if necessary x = val2list(x) if len(x) > 0: x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] return tuple(x) def get_same_padding( kernel_size: Union[int, Tuple[int, ...]], ) -> Union[int, Tuple[int, ...]]: if isinstance(kernel_size, tuple): return tuple([get_same_padding(ks) for ks in kernel_size]) else: assert kernel_size % 2 > 0, "kernel size should be odd number" return kernel_size // 2 def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int: padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding class ConvNormAct(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, dilation=1, groups=1, bias=False, dropout=0.0, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, ): super(ConvNormAct, self).__init__() self.dropout = nn.Dropout(dropout, inplace=False) padding = get_padding(kernel_size, stride, dilation) self.conv = nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups, bias=bias, padding=padding, ) self.norm = ( norm_layer(num_features=out_channels) if norm_layer else nn.Identity() ) self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity() def forward(self, x): x = self.conv(x) x = self.norm(x) x = self.act(x) return x class DSConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, use_bias=False, norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), act_layer=(nn.ReLU6, None), ): super(DSConv, self).__init__() use_bias = val2tuple(use_bias, 2) norm_layer = val2tuple(norm_layer, 2) act_layer = val2tuple(act_layer, 2) self.depth_conv = ConvNormAct( in_channels, in_channels, kernel_size, stride, groups=in_channels, norm_layer=norm_layer[0], act_layer=act_layer[0], bias=use_bias[0], ) self.point_conv = ConvNormAct( in_channels, out_channels, 1, norm_layer=norm_layer[1], act_layer=act_layer[1], bias=use_bias[1], ) def forward(self, x): x = self.depth_conv(x) x = self.point_conv(x) return x class ConvBlock(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, mid_channels=None, expand_ratio=1, use_bias=False, norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), act_layer=(nn.ReLU6, None), ): super(ConvBlock, self).__init__() use_bias = val2tuple(use_bias, 2) norm_layer = val2tuple(norm_layer, 2) act_layer = val2tuple(act_layer, 2) mid_channels = mid_channels or round(in_channels * expand_ratio) self.conv1 = ConvNormAct( in_channels, mid_channels, kernel_size, stride, norm_layer=norm_layer[0], act_layer=act_layer[0], bias=use_bias[0], ) self.conv2 = ConvNormAct( mid_channels, out_channels, kernel_size, 1, norm_layer=norm_layer[1], act_layer=act_layer[1], bias=use_bias[1], ) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x class MBConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, mid_channels=None, expand_ratio=6, use_bias=False, norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d), act_layer=(nn.ReLU6, nn.ReLU6, None), ): super(MBConv, self).__init__() use_bias = val2tuple(use_bias, 3) norm_layer = val2tuple(norm_layer, 3) act_layer = val2tuple(act_layer, 3) mid_channels = mid_channels or round(in_channels * expand_ratio) self.inverted_conv = ConvNormAct( in_channels, mid_channels, 1, stride=1, norm_layer=norm_layer[0], act_layer=act_layer[0], bias=use_bias[0], ) self.depth_conv = ConvNormAct( mid_channels, mid_channels, kernel_size, stride=stride, groups=mid_channels, norm_layer=norm_layer[1], act_layer=act_layer[1], bias=use_bias[1], ) self.point_conv = ConvNormAct( mid_channels, out_channels, 1, norm_layer=norm_layer[2], act_layer=act_layer[2], bias=use_bias[2], ) def forward(self, x): x = self.inverted_conv(x) x = self.depth_conv(x) x = self.point_conv(x) return x class FusedMBConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, mid_channels=None, expand_ratio=6, groups=1, use_bias=False, norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), act_layer=(nn.ReLU6, None), ): super(FusedMBConv, self).__init__() use_bias = val2tuple(use_bias, 2) norm_layer = val2tuple(norm_layer, 2) act_layer = val2tuple(act_layer, 2) mid_channels = mid_channels or round(in_channels * expand_ratio) self.spatial_conv = ConvNormAct( in_channels, mid_channels, kernel_size, stride=stride, groups=groups, norm_layer=norm_layer[0], act_layer=act_layer[0], bias=use_bias[0], ) self.point_conv = ConvNormAct( mid_channels, out_channels, 1, norm_layer=norm_layer[1], act_layer=act_layer[1], bias=use_bias[1], ) def forward(self, x): x = self.spatial_conv(x) x = self.point_conv(x) return x class LiteMLA(nn.Module): """Lightweight multi-scale linear attention""" def __init__( self, in_channels: int, out_channels: int, heads: Union[int, None] = None, heads_ratio: float = 1.0, dim=8, use_bias=False, norm_layer=(None, nn.BatchNorm2d), act_layer=(None, None), kernel_func=nn.ReLU, scales=(5,), eps=1e-5, ): super(LiteMLA, self).__init__() self.eps = eps heads = heads or int(in_channels // dim * heads_ratio) total_dim = heads * dim use_bias = val2tuple(use_bias, 2) norm_layer = val2tuple(norm_layer, 2) act_layer = val2tuple(act_layer, 2) self.dim = dim self.qkv = ConvNormAct( in_channels, 3 * total_dim, 1, bias=use_bias[0], norm_layer=norm_layer[0], act_layer=act_layer[0], ) self.aggreg = nn.ModuleList( [ nn.Sequential( nn.Conv2d( 3 * total_dim, 3 * total_dim, scale, padding=get_same_padding(scale), groups=3 * total_dim, bias=use_bias[0], ), nn.Conv2d( 3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0], ), ) for scale in scales ] ) self.kernel_func = kernel_func(inplace=False) self.proj = ConvNormAct( total_dim * (1 + len(scales)), out_channels, 1, bias=use_bias[1], norm_layer=norm_layer[1], act_layer=act_layer[1], ) def _attn(self, q, k, v): dtype = v.dtype q, k, v = q.float(), k.float(), v.float() kv = k.transpose(-1, -2) @ v out = q @ kv out = out[..., :-1] / (out[..., -1:] + self.eps) return out.to(dtype) def forward(self, x): # Shape is B, C, H, W B, _, H, W = x.shape # generate multi-scale q, k, v qkv = self.qkv(x) multi_scale_qkv = [qkv] for op in self.aggreg: multi_scale_qkv.append(op(qkv)) multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1) multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose( -1, -2 ) # Shape for each is B, C, HW, head_dim q, k, v = multi_scale_qkv.chunk(3, dim=-1) # lightweight global attention q = self.kernel_func(q) k = self.kernel_func(k) v = F.pad(v, (0, 1), mode="constant", value=1.0) out = self._attn(q, k, v) # final projection out = out.transpose(-1, -2).reshape(B, -1, H, W) out = self.proj(out) return out class EfficientVitBlock(nn.Module): def __init__( self, in_channels, heads_ratio=1.0, head_dim=32, expand_ratio=4, norm_layer=nn.BatchNorm2d, act_layer=nn.Hardswish, ): super(EfficientVitBlock, self).__init__() self.context_module = ResidualBlock( LiteMLA( in_channels=in_channels, out_channels=in_channels, heads_ratio=heads_ratio, dim=head_dim, norm_layer=(None, norm_layer), ), nn.Identity(), ) self.local_module = ResidualBlock( MBConv( in_channels=in_channels, out_channels=in_channels, expand_ratio=expand_ratio, use_bias=(True, True, False), norm_layer=(None, None, norm_layer), act_layer=(act_layer, act_layer, None), ), nn.Identity(), ) def forward(self, x): x = self.context_module(x) x = self.local_module(x) return x class ResidualBlock(nn.Module): def __init__( self, main: Optional[nn.Module], shortcut: Optional[nn.Module] = None, pre_norm: Optional[nn.Module] = None, ): super(ResidualBlock, self).__init__() self.pre_norm = pre_norm if pre_norm is not None else nn.Identity() self.main = main self.shortcut = shortcut def forward(self, x): res = self.main(self.pre_norm(x)) if self.shortcut is not None: res = res + self.shortcut(x) return res def build_local_block( in_channels: int, out_channels: int, stride: int, kernel_size: int, expand_ratio: float, norm_layer: str, act_layer: str, fewer_norm: bool = False, block_type: str = "default", ): assert block_type in ["default", "large", "fused"] if expand_ratio == 1: if block_type == "default": block = DSConv( in_channels=in_channels, out_channels=out_channels, stride=stride, kernel_size=kernel_size, use_bias=(True, False) if fewer_norm else False, norm_layer=(None, norm_layer) if fewer_norm else norm_layer, act_layer=(act_layer, None), ) else: block = ConvBlock( in_channels=in_channels, out_channels=out_channels, stride=stride, kernel_size=kernel_size, use_bias=(True, False) if fewer_norm else False, norm_layer=(None, norm_layer) if fewer_norm else norm_layer, act_layer=(act_layer, None), ) else: if block_type == "default": block = MBConv( in_channels=in_channels, out_channels=out_channels, stride=stride, kernel_size=kernel_size, expand_ratio=expand_ratio, use_bias=(True, True, False) if fewer_norm else False, norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer, act_layer=(act_layer, act_layer, None), ) else: block = FusedMBConv( in_channels=in_channels, out_channels=out_channels, stride=stride, kernel_size=kernel_size, expand_ratio=expand_ratio, use_bias=(True, False) if fewer_norm else False, norm_layer=(None, norm_layer) if fewer_norm else norm_layer, act_layer=(act_layer, None), ) return block class Stem(nn.Sequential): def __init__( self, in_chs, out_chs, depth, stride, norm_layer, act_layer, block_type="default", ): super().__init__() self.stride = stride self.add_module( "in_conv", ConvNormAct( in_chs, out_chs, kernel_size=stride + 1, stride=stride, norm_layer=norm_layer, act_layer=act_layer, ), ) stem_block = 0 for _ in range(depth): self.add_module( f"res{stem_block}", ResidualBlock( build_local_block( in_channels=out_chs, out_channels=out_chs, stride=1, kernel_size=3, expand_ratio=1, norm_layer=norm_layer, act_layer=act_layer, block_type=block_type, ), nn.Identity(), ), ) stem_block += 1 class EfficientVitLargeStage(nn.Module): def __init__( self, in_chs, out_chs, depth, stride, norm_layer, act_layer, head_dim, vit_stage=False, fewer_norm=False, ): super(EfficientVitLargeStage, self).__init__() blocks = [ ResidualBlock( build_local_block( in_channels=in_chs, out_channels=out_chs, stride=stride, kernel_size=stride + 1, expand_ratio=24 if vit_stage else 16, norm_layer=norm_layer, act_layer=act_layer, fewer_norm=vit_stage or fewer_norm, block_type="default" if fewer_norm else "fused", ), None, ) ] in_chs = out_chs if vit_stage: # for stage 4 for _ in range(depth): blocks.append( EfficientVitBlock( in_channels=in_chs, head_dim=head_dim, expand_ratio=6, norm_layer=norm_layer, act_layer=act_layer, ) ) else: # for stage 1, 2, 3 for i in range(depth): blocks.append( ResidualBlock( build_local_block( in_channels=in_chs, out_channels=out_chs, stride=1, kernel_size=3, expand_ratio=4, norm_layer=norm_layer, act_layer=act_layer, fewer_norm=fewer_norm, block_type="default" if fewer_norm else "fused", ), nn.Identity(), ) ) self.blocks = nn.Sequential(*blocks) def forward(self, x): return self.blocks(x) class EfficientVitLarge(nn.Module): def __init__( self, config: EfficientViTConfig, norm_layer=nn.BatchNorm2d, act_layer=nn.Hardswish, ): super(EfficientVitLarge, self).__init__() self.grad_checkpointing = False self.num_classes = config.num_classes self.norm_eps = config.layer_norm_eps norm_layer = partial(norm_layer, eps=self.norm_eps) # input stem self.stem = Stem( config.num_channels, config.widths[0], config.depths[0], config.strides[0], norm_layer, act_layer, block_type="large", ) stride = config.strides[0] # stages self.feature_info = [] self.stages = nn.Sequential() in_channels = config.widths[0] for i, (w, d, s) in enumerate( zip(config.widths[1:], config.depths[1:], config.strides[1:]) ): self.stages.append( EfficientVitLargeStage( in_channels, w, depth=d, stride=s, norm_layer=norm_layer, act_layer=act_layer, head_dim=config.head_dim, vit_stage=i >= 3, fewer_norm=i >= 2, ) ) stride *= s in_channels = w self.feature_info += [ dict(num_chs=in_channels, reduction=stride, module=f"stages.{i}") ] self.num_features = in_channels @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable def forward(self, x): x = self.stem(x) encoder_hidden_states = [] for i, module in enumerate(self.stages): x = module(x) encoder_hidden_states.append(x) return encoder_hidden_states class EfficientViTPreTrainedModel(SuryaPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = EfficientViTConfig base_model_prefix = "efficientvit" main_input_name = "pixel_values" def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class DecodeMLP(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.proj = nn.Linear(input_dim, output_dim) def forward(self, hidden_states: torch.Tensor): # Input is B, C, H, W hidden_states = hidden_states.flatten(2).transpose(1, 2) # Output is B, HW, C hidden_states = self.proj(hidden_states) return hidden_states class DecodeHead(EfficientViTPreTrainedModel): def __init__(self, config: EfficientViTConfig): super().__init__(config) # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size mlps = [] for width in config.widths[1:]: mlp = DecodeMLP( input_dim=width, output_dim=config.decoder_layer_hidden_size ) mlps.append(mlp) self.linear_c = nn.ModuleList(mlps) # the following 3 layers implement the ConvModule of the original implementation self.linear_fuse = nn.Conv2d( in_channels=config.decoder_layer_hidden_size * config.num_stages, out_channels=config.decoder_hidden_size, kernel_size=1, bias=False, ) self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size) self.activation = nn.ReLU() self.dropout = nn.Dropout(config.classifier_dropout_prob) self.classifier = nn.Conv2d( config.decoder_hidden_size, config.num_labels, kernel_size=1 ) self.config = config def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor: batch_size = encoder_hidden_states[-1].shape[0] all_hidden_states = () for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c): height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3] encoder_hidden_state = mlp(encoder_hidden_state) # Output is B, HW, C # Permute to B, C, HW encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1) encoder_hidden_state = encoder_hidden_state.reshape( batch_size, -1, height, width ) # upsample encoder_hidden_state = nn.functional.interpolate( encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False, ) all_hidden_states += (encoder_hidden_state,) hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1)) hidden_states = self.batch_norm(hidden_states) hidden_states = self.activation(hidden_states) # logits are of shape (batch_size, num_labels, height/4, width/4) logits = self.classifier(hidden_states) return logits class EfficientViTForSemanticSegmentation( S3DownloaderMixin, EfficientViTPreTrainedModel ): def __init__(self, config, **kwargs): super().__init__(config) self.vit = EfficientVitLarge(config) self.decode_head = DecodeHead(config) # Initialize weights and apply final processing self.post_init() def forward( self, pixel_values: torch.FloatTensor ) -> Union[Tuple, SemanticSegmenterOutput]: # Pixel values should be B,C,H,W encoder_hidden_states = self.vit( pixel_values, ) logits = self.decode_head(encoder_hidden_states) # Apply sigmoid to get 0-1 output logits = torch.special.expit(logits) return SemanticSegmenterOutput( loss=None, logits=logits, hidden_states=encoder_hidden_states ) class EfficientViTForSemanticLayoutSegmentation(EfficientViTPreTrainedModel): def __init__(self, config, **kwargs): super().__init__(config, **kwargs) self.vit = EfficientVitLarge(config) self.decode_head = DecodeHead(config) # Initialize weights and apply final processing self.post_init() def forward( self, pixel_values: torch.FloatTensor ) -> Union[Tuple, SemanticSegmenterOutput]: # Pixel values should be B,C,H,W encoder_hidden_states = self.vit( pixel_values, ) logits = self.decode_head(encoder_hidden_states) # Apply sigmoid to get 0-1 output logits = torch.special.expit(logits) return SemanticSegmenterOutput( loss=None, logits=logits, hidden_states=encoder_hidden_states ) ``` -------------------------------------------------------------------------------- /surya/common/surya/processor/tokenizer.py: -------------------------------------------------------------------------------- ```python import html import re from typing import List, Union, Dict, Optional, Tuple, Iterable import numpy as np import torch from tokenizers import AddedToken import json import os from transformers import PreTrainedTokenizer, Qwen2Tokenizer as Qwen2OriginalTokenizer from surya.common.s3 import S3DownloaderMixin from surya.common.surya.schema import TASK_NAMES, TaskNames from surya.logging import get_logger from surya.settings import settings logger = get_logger() def create_token_regex(tokens): escaped_tokens = [re.escape(token) for token in tokens] escaped_tokens.sort(key=len, reverse=True) pattern = r"^(" + "|".join(escaped_tokens) + r")" regex = re.compile(pattern) return regex class InnerOCRTokenizer: def __init__( self, special_tokens: Dict[str, list] | None = None, qwen_tokenizer: Qwen2OriginalTokenizer | None = None, **kwargs, ): self.qwen_tokenizer = qwen_tokenizer self.qwen_token_offset = len(qwen_tokenizer) all_special_tokens = special_tokens.get("all", []) self.SPECIAL_TOKEN_MAPPING = {} idx = 0 for tag in all_special_tokens: if tag in self.SPECIAL_TOKEN_MAPPING: continue self.SPECIAL_TOKEN_MAPPING[tag] = ( idx + self.qwen_token_offset ) # Assign token ID idx += 1 self.REVERSE_SPECIAL_TOKEN_MAPPING = { v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items() } self.SPECIAL_TOKEN_OFFSET = idx self.FORMAT_TAG_PATTERN = create_token_regex(special_tokens["formatting"]) self.MATH_TAG_PATTERN = create_token_regex(special_tokens["math_external"]) self.LAYOUT_TAG_PATTERN = create_token_regex(special_tokens["layout"]) self.TABLE_STRUCTURE_TAG_PATTERN = create_token_regex( special_tokens["table_structure"] ) self.SYSTEM_TAG_PATTERN = create_token_regex(special_tokens.get("system", [])) if not special_tokens.get("system", []): logger.warning("Warning: No system tokens found in special_tokens") self.MATH_TAG_START = "<math" self.MATH_END_TAG = "</math>" super().__init__(**kwargs) @property def vocab_size(self): return ( 65536 + self.SPECIAL_TOKEN_OFFSET ) # The highest codepoint is 65535, but we add 1 to account for the 0-indexing def _tokenize(self, text: str) -> List[int]: tokens = [] in_math = False text = html.unescape(text) # Unescape html entities like < in equations while text: # Look for EOS, PAD, etc. tokens match = self.SYSTEM_TAG_PATTERN.search(text) if match: tag = match.group(1) tokens.append( self.SPECIAL_TOKEN_MAPPING[tag] ) # These are already offset text = text[match.end() :] continue # Look for layout tokens match = self.LAYOUT_TAG_PATTERN.search(text) if match: tag = match.group(1) tokens.append( self.SPECIAL_TOKEN_MAPPING[tag] ) # Layout tokens are already offset text = text[match.end() :] continue match = self.TABLE_STRUCTURE_TAG_PATTERN.search(text) if match: tag = match.group(1) tokens.append(self.SPECIAL_TOKEN_MAPPING[tag]) text = text[match.end() :] continue # Check for math tags match = self.MATH_TAG_PATTERN.search(text) if match: # We found a tag tag = match.group(1) if tag.startswith(self.MATH_TAG_START): in_math = True elif tag == self.MATH_END_TAG: in_math = False tokens.append( self.SPECIAL_TOKEN_MAPPING[tag] # Special tokens are already offset ) # Use special token ID text = text[match.end() :] continue # Tokenize math content with qwen2 tokenizer if in_math: # If we're in a math block, check to see if we have a special math tag in the text math_end_position = text.find(self.MATH_END_TAG) math_str = text[:math_end_position] # Gets the math content tokens += self.qwen_tokenizer(math_str)["input_ids"] text = text[math_end_position:] continue # Check for formatting tags match = self.FORMAT_TAG_PATTERN.search(text) if match: # We found a tag tag = match.group(1) tokens.append( self.SPECIAL_TOKEN_MAPPING[tag] # Special tokens are already offset ) # Use special token ID text = text[match.end() :] continue # General case, utf-16 tokenization utf_16_tokens = self.text_to_utf16_numbers(text[0]) tokens += [ t + self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset for t in utf_16_tokens ] text = text[1:] return tokens def text_to_utf16_numbers(self, text: str): """Converts text to UTF-16 encoded numbers.""" utf16_bytes = text.encode( "utf-16le" ) # Little-endian to simplify byte order handling numbers = [] for i in range(0, len(utf16_bytes), 2): # Combine two adjacent bytes into a single number number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8) numbers.append(number) return numbers def utf16_numbers_to_text(self, numbers): """Converts UTF-16 numbers back to text.""" byte_array = bytearray() for number in numbers: byte_array.append(number & 0xFF) # Lower byte byte_array.append((number >> 8) & 0xFF) # Upper byte try: text = byte_array.decode("utf-16le", errors="ignore") except Exception as e: logger.warning(f"Error decoding utf16: {e}") text = "" return text def __call__( self, texts: Union[str, List[str]], **kwargs ) -> Dict[str, List[List[int]]]: """Tokenizes text and returns input IDs.""" tokenized = [] if isinstance(texts, str): texts = [texts] for text in texts: tokens = self._tokenize(text) tokenized.append(tokens) return {"input_ids": tokenized} def decode(self, token_ids, **kwargs): """Decodes token IDs back to text.""" if isinstance(token_ids, (np.ndarray, torch.Tensor)): token_ids = token_ids.tolist() decoded_text = "" token_buffer = [] decode_qwen = [False] def decode_buffer(): nonlocal decoded_text, token_buffer, decode_qwen if token_buffer: if decode_qwen[0]: decoded_text += self.qwen_tokenizer.decode(token_buffer) else: token_buffer = [ t - self.SPECIAL_TOKEN_OFFSET - self.qwen_token_offset for t in token_buffer ] decoded_text += self.utf16_numbers_to_text(token_buffer) token_buffer = [] decode_qwen[0] = False for t in token_ids: if t < self.qwen_token_offset: # This is for math tags if token_buffer and token_buffer[-1] >= self.qwen_token_offset: decode_buffer() token_buffer.append(t) decode_qwen[0] = True elif t >= self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset: if token_buffer and token_buffer[-1] < self.qwen_token_offset: decode_buffer() token_buffer.append(t) # We shift this down later on decode_qwen[0] = False elif t in self.REVERSE_SPECIAL_TOKEN_MAPPING: decode_buffer() decoded_text += self.REVERSE_SPECIAL_TOKEN_MAPPING[t] decode_qwen[0] = False else: raise ValueError( f'Unexpected token value while decoding, got "{t}" in token_ids {token_ids}' ) # Detokenize remaining tokens decode_buffer() return decoded_text class Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer): pass class GreedyMathUTF16Tokenizer(S3DownloaderMixin, PreTrainedTokenizer): """ HuggingFace slow tokenizer implementing: - UTF-16 code units as the base [0..65535] - Math tokens as greedy-longest-match ids after UTF-16 - Literal special tokens after math tokens Absolute ID layout: [0 .. 65535] : UTF-16 units [65536 .. 65536+M-1] : math tokens [65536+M .. 65536+M+S-1] : special tokens """ vocab_files_names = { "vocab_file": "vocab_math.json", # {"\\frac": 0, "\\alpha": 1, ...} raw contiguous ids 0..M-1 "specials_file": "specials.json", # [flat list for legacy] "specials_dict_file": "specials_dict.json", # category dict (preferred) } model_input_names = ["input_ids", "attention_mask"] is_fast = False # ---------- helpers ---------- @staticmethod def _to_utf16_units(s: str) -> List[int]: b = s.encode("utf-16le") return [int.from_bytes(b[i : i + 2], "little") for i in range(0, len(b), 2)] @staticmethod def _from_utf16_units(units: List[int]) -> str: b = bytearray() for u in units: b += int(u).to_bytes(2, "little") return b.decode("utf-16le", errors="ignore") class _TrieNode: __slots__ = ("child", "id", "leaf") def __init__(self): self.child: Dict[str, "GreedyMathUTF16Tokenizer._TrieNode"] = {} self.id: Optional[int] = None self.leaf: bool = False @classmethod def _build_trie( cls, token_to_id: Dict[str, int] ) -> "GreedyMathUTF16Tokenizer._TrieNode": root = cls._TrieNode() for tok, tid in token_to_id.items(): node = root for ch in tok: node = node.child.setdefault(ch, cls._TrieNode()) node.leaf = True node.id = tid return root @classmethod def _encode_math_greedy( cls, s: str, trie: "GreedyMathUTF16Tokenizer._TrieNode", math_base: int, debug: bool = False, ) -> List[int]: i, n = 0, len(s) out: List[int] = [] while i < n: node = trie j = i last_id = None last_j = i while j < n and (ch := s[j]) in node.child: node = node.child[ch] j += 1 if node.leaf: last_id, last_j = node.id, j if last_id is not None: if debug: print(f"[MATH] matched {s[i:last_j]!r} -> {last_id}") out.append(math_base + last_id) i = last_j else: units = cls._to_utf16_units(s[i]) if debug: print(f"[MATH] fallback {s[i]!r} -> utf16 {units}") out.extend(units) i += 1 return out # ---------- init ---------- def __init__( self, vocab_file: Optional[str] = None, specials_file: Optional[str] = None, specials_dict_file: Optional[str] = None, *, # You can also pass programmatically instead of files: math_vocab: Optional[Dict[str, int]] = None, special_tokens: Optional[List[str]] = None, special_tokens_dict: Optional[Dict[str, List[str]]] = None, debug: bool = False, # Standard HF special token kwargs: bos_token: Optional[str] = None, eos_token: Optional[str] = None, pad_token: Optional[str] = None, unk_token: Optional[str] = None, **kwargs, ): # Load math vocab if vocab_file and os.path.isfile(vocab_file): with open(vocab_file, "r", encoding="utf-8") as f: mv = json.load(f) else: mv = math_vocab or {} # Make math ids contiguous if needed if mv: max_id = max(mv.values()) if set(mv.values()) != set(range(max_id + 1)): items = sorted(mv.items(), key=lambda kv: kv[1]) mv = {tok: i for i, (tok, _) in enumerate(items)} # Load special tokens (prefer category dict; fallback to flat list or defaults) sp_dict = None if specials_dict_file and os.path.isfile(specials_dict_file): with open(specials_dict_file, "r", encoding="utf-8") as f: sp_dict = json.load(f) elif special_tokens_dict is not None: sp_dict = dict(special_tokens_dict) if sp_dict is None: # Legacy path: flat list from file or provided/default list if specials_file and os.path.isfile(specials_file): with open(specials_file, "r", encoding="utf-8") as f: sp_list_flat = json.load(f) else: sp_list_flat = special_tokens or SPECIAL_TOKENS sp_dict = {"all": list(sp_list_flat)} # Ensure "all" exists and is unique/preserved in order. if "all" not in sp_dict or not isinstance(sp_dict["all"], list): order = [ "system", "formatting", "math_external", "script", "layout", "reasoning", "table_structure", "reserved", ] seen = set() all_tokens: List[str] = [] for k in order: if k in sp_dict and isinstance(sp_dict[k], list): for t in sp_dict[k]: if t not in seen: all_tokens.append(t) seen.add(t) sp_dict["all"] = all_tokens # Keep a copy of categories (if present) for downstream processor logic. self.special_tokens = sp_dict sp_list = list(sp_dict.get("all", [])) # Regex list should favor longest-first to avoid partial matches. specials_for_regex = sorted(sp_list, key=len, reverse=True) self.debug = debug self.UTF16_SPACE = 65536 self.math_token_to_rawid = dict(mv) # 0..M-1 self.math_vocab_size = len(self.math_token_to_rawid) self.MATH_BASE = self.UTF16_SPACE self.SPECIAL_BASE = self.UTF16_SPACE + self.math_vocab_size # Maps self.math_absid_to_token = { self.MATH_BASE + rid: tok for tok, rid in self.math_token_to_rawid.items() } self.special_tokens_list = sp_list # ID assignment order self.special_to_absid = { tok: self.SPECIAL_BASE + i for i, tok in enumerate(self.special_tokens_list) } self.absid_to_special = {v: k for k, v in self.special_to_absid.items()} # Public attributes for legacy/processor: # All specials mapping (token -> absolute id) self.SPECIAL_TOKEN_MAPPING: Dict[str, int] = dict(self.special_to_absid) # Subset used heavily by processor for quick access self.reverse_special_token_mapping = { v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items() } self.LAYOUT_LABEL2ID = { k: v for k, v in self.SPECIAL_TOKEN_MAPPING.items() if k in self.special_tokens["layout"] } self.TABLE_STRUCTURE_LABEL2ID = { k: v for k, v in self.SPECIAL_TOKEN_MAPPING.items() if k in self.special_tokens["table_structure"] } if not self.special_tokens.get("system", []): print("Warning: No system tokens found in special_tokens") self.MATH_TAG_START = "<math" self.MATH_END_TAG = "</math>" sys_list = self.special_tokens.get("system", []) self.system_tokens: Dict[str, int] = { t: self.special_to_absid[t] for t in sys_list if t in self.special_to_absid } # Regex for literal specials self.specials_pattern = ( re.compile(r"(" + "|".join(re.escape(k) for k in specials_for_regex) + r")") if specials_for_regex else None ) # Trie for math greedy match self.trie = self._build_trie(self.math_token_to_rawid) # Tell HF about special tokens (metadata) kwargs.setdefault("bos_token", bos_token) kwargs.setdefault("eos_token", eos_token or "</S>") kwargs.setdefault("pad_token", pad_token or "<PAD>") kwargs.setdefault("unk_token", unk_token) super().__init__( vocab_file=vocab_file, specials_file=specials_file, specials_dict_file=specials_dict_file, **kwargs, ) # ---------- required HF surface ---------- @property def vocab_size(self) -> int: return self.UTF16_SPACE + self.math_vocab_size + len(self.special_tokens_list) def get_vocab(self) -> Dict[str, int]: # Compact vocab: just math+specials with ABSOLUTE ids. v = {tok: self.MATH_BASE + rid for tok, rid in self.math_token_to_rawid.items()} v.update(self.special_to_absid) return v def __len__(self) -> int: return self.vocab_size # Core encode/decode on ABSOLUTE ids def _encode_core(self, text: str) -> List[int]: text = html.unescape(text) ids: List[int] = [] in_math = False chunks = self.specials_pattern.split(text) if self.specials_pattern else [text] for chunk in chunks: if chunk in self.special_to_absid: ids.append(self.special_to_absid[chunk]) if chunk.startswith("<math"): in_math = True elif chunk.startswith("</math>"): in_math = False if self.debug: print(f"[TAG] {chunk!r} -> {self.special_to_absid[chunk]}") continue if in_math: ids.extend( self._encode_math_greedy( chunk, self.trie, self.MATH_BASE, debug=self.debug ) ) else: units = self._to_utf16_units(chunk) if self.debug and units: print( f"[TEXT] utf16 {chunk[:32]!r} -> {units[:8]}{'...' if len(units) > 8 else ''}" ) ids.extend(units) return ids def _decode_core(self, ids: Iterable[int]) -> str: out: List[str] = [] buf: List[int] = [] def flush(): if buf: out.append(self._from_utf16_units(buf)) buf.clear() for tid in ids: if tid >= self.MATH_BASE and tid < self.SPECIAL_BASE: flush() out.append(self.math_absid_to_token.get(tid, "")) elif tid >= self.SPECIAL_BASE: flush() out.append(self.absid_to_special.get(tid, "")) else: buf.append(int(tid)) flush() return "".join(out) # ---- Tokenizer interface ---- def _tokenize(self, text: str, **kwargs) -> List[str]: ids = self._encode_core(text) toks: List[str] = [] for i in ids: if i < self.MATH_BASE: toks.append(f"<U+{i:04X}>") elif i < self.SPECIAL_BASE: toks.append(self.math_absid_to_token.get(i, "<UNK_MATH>")) else: toks.append(self.absid_to_special.get(i, "<UNK_SPECIAL>")) return toks def _convert_token_to_id(self, token: str) -> int: if token.startswith("<U+") and token.endswith(">"): try: return int(token[3:-1], 16) # UTF-16 unit except Exception: return self.unk_token_id if self.unk_token_id is not None else 0 # math or specials if token in self.math_token_to_rawid: return self.MATH_BASE + self.math_token_to_rawid[token] if token in self.special_to_absid: return self.special_to_absid[token] # rare path: single-char token -> its UTF-16 unit if len(token) == 1: u = self._to_utf16_units(token) if len(u) == 1: return u[0] return self.unk_token_id if self.unk_token_id is not None else 0 def _convert_id_to_token(self, index: int) -> str: if index < self.MATH_BASE: return f"<U+{index:04X}>" if index < self.SPECIAL_BASE: return self.math_absid_to_token.get(index, "<UNK_MATH>") return self.absid_to_special.get(index, "<UNK_SPECIAL>") def convert_tokens_to_string(self, tokens: List[str]) -> str: ids = [self._convert_token_to_id(t) for t in tokens] return self._decode_core(ids) def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str: # Accept int, list, tuple, numpy, torch if hasattr(token_ids, "tolist"): token_ids = token_ids.tolist() elif isinstance(token_ids, int): token_ids = [token_ids] else: token_ids = list(token_ids) token_ids = [int(i) for i in token_ids] # normalize early if skip_special_tokens: token_ids = [i for i in token_ids if i < self.SPECIAL_BASE] return self._decode_core(token_ids) # HF plumbing def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: out = ( list(token_ids_0) if token_ids_1 is None else list(token_ids_0) + list(token_ids_1) ) # if self.eos_token_id is not None and (not out or out[-1] != self.eos_token_id): # out.append(self.eos_token_id) return out def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False, ) -> List[int]: def mask(seq: List[int]) -> List[int]: return [1 if i >= self.SPECIAL_BASE else 0 for i in seq] return ( mask(token_ids_0) if token_ids_1 is None else mask(token_ids_0) + mask(token_ids_1) ) def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: return [0] * ( len(token_ids_0) if token_ids_1 is None else len(token_ids_0) + len(token_ids_1) ) # Save/load raw assets def save_vocabulary( self, save_directory: str, filename_prefix: Optional[str] = None ) -> Tuple[str, str]: os.makedirs(save_directory, exist_ok=True) pre = (filename_prefix + "-") if filename_prefix else "" vocab_path = os.path.join( save_directory, pre + self.vocab_files_names["vocab_file"] ) specials_path = os.path.join( save_directory, pre + self.vocab_files_names["specials_file"] ) specials_dict_path = os.path.join( save_directory, pre + self.vocab_files_names["specials_dict_file"] ) with open(vocab_path, "w", encoding="utf-8") as f: json.dump(self.math_token_to_rawid, f, ensure_ascii=False, indent=2) # Save both the flat list ("all") and the category dict (preferred) with open(specials_path, "w", encoding="utf-8") as f: json.dump(self.special_tokens_list, f, ensure_ascii=False, indent=2) with open(specials_dict_path, "w", encoding="utf-8") as f: json.dump(self.special_tokens, f, ensure_ascii=False, indent=2) return (vocab_path, specials_path) class SuryaOCRTokenizer(S3DownloaderMixin, PreTrainedTokenizer): def __init__( self, special_tokens: Dict[str, list] | None = None, model_checkpoint: str = settings.FOUNDATION_MODEL_CHECKPOINT, **kwargs, ): if special_tokens is None: special_tokens = dict() self.special_tokens = special_tokens self.ocr_tokenizer = GreedyMathUTF16Tokenizer.from_pretrained( model_checkpoint, ) self.system_tokens = { v: self.ocr_tokenizer(v)["input_ids"][0] for v in special_tokens.get("system", []) } self.SPECIAL_TOKEN_MAPPING = self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING super().__init__(**kwargs) def get_vocab(self) -> Dict[str, int]: return self.ocr_tokenizer.get_vocab() def _add_tokens( self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False, ) -> int: return self.ocr_tokenizer._add_tokens( new_tokens, special_tokens=special_tokens ) @property def vocab_size(self): return self.ocr_tokenizer.vocab_size def _tokenize(self, text: str, **kwargs): # task = kwargs.get("task", TaskNames.ocr_with_boxes) # assert task in TASK_NAMES, f"Invalid task: {task}" tokens = self.ocr_tokenizer(text)["input_ids"] return tokens def __call__( self, texts: Union[str, List[str]], tasks: Union[str, List[str]] = None, **kwargs, ) -> Dict[str, List[List[int]]]: """Tokenizes text and returns input IDs.""" tokenized = [] if isinstance(texts, str): texts = [texts] assert isinstance(tasks, str), "Tasks must be a string if texts is a string" tasks = [tasks] if isinstance(texts, list): assert isinstance(tasks, list), "Tasks must be a list if texts is a list" for text, task in zip(texts, tasks): tokens = self._tokenize(text, task=task) tokenized.append(tokens) return {"input_ids": tokenized} def decode(self, token_ids, **kwargs): if isinstance(token_ids, (np.ndarray, torch.Tensor)): token_ids = token_ids.tolist() decoded_text = self.ocr_tokenizer.decode(token_ids, skip_special_tokens=False) # replace all <SCRIPT-...> tokens with empty strings decoded_text = re.sub(r"<SCRIPT-.*?>", "", decoded_text) # replace </S> with empty string decoded_text = re.sub(r"</S>", "", decoded_text) return decoded_text ``` -------------------------------------------------------------------------------- /surya/common/surya/__init__.py: -------------------------------------------------------------------------------- ```python import warnings from typing import Optional, Tuple, TypedDict from dataclasses import dataclass import torch from torch import nn import torch.nn.functional as F from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.cache_utils import Cache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from surya.common.pretrained import SuryaPreTrainedModel from surya.common.s3 import S3DownloaderMixin from surya.common.surya.config import SuryaModelConfig from surya.common.surya.decoder import SuryaDecoderModel from surya.common.surya.embedder import SimpleTokenEmbedder from surya.common.surya.encoder import SuryaEncoderModel from surya.common.util import pad_to_batch_size, pad_to_batch_size_repeat from surya.common.xla import get_nearest_pad from surya.settings import settings from surya.logging import get_logger logger = get_logger() @dataclass class SuryaModelOutput(CausalLMOutputWithPast): bbox_logits: torch.FloatTensor = None lm_logits: torch.FloatTensor = None class FlashAttentionKwargs(TypedDict, total=False): """ Keyword arguments for Flash Attention with Compile. Attributes: cu_seq_lens_q (`torch.LongTensor`, *optional*) Gets cumlative sequence length for query state. cu_seq_lens_k (`torch.LongTensor`, *optional*) Gets cumlative sequence length for key state. max_length_q (`int`, *optional*): Maximum sequence length for query state. max_length_k (`int`, *optional*): Maximum sequence length for key state. """ cu_seq_lens_q: Optional[torch.LongTensor] cu_seq_lens_k: Optional[torch.LongTensor] max_length_q: Optional[int] max_length_k: Optional[int] class KwargsForCausalLM(FlashAttentionKwargs): ... class DistanceProjection(nn.Module): def __init__(self, in_features: int, out_features: int): super().__init__() self.fc1 = nn.Linear(in_features, out_features) self.act = nn.SiLU() self.fc2 = nn.Linear(out_features, out_features) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc1(x) x = self.act(x) x = self.fc2(x) return x def init_weights(self): nn.init.xavier_uniform_(self.fc1.weight) nn.init.xavier_uniform_(self.fc2.weight) nn.init.zeros_(self.fc1.bias) nn.init.zeros_(self.fc2.bias) class BboxHead(nn.Module): def __init__(self, in_features: int, out_features: int): super().__init__() self.proj_layers = nn.ModuleList( [nn.Linear(in_features, in_features) for _ in range(6)] ) self.act = nn.SiLU() self.out_proj = nn.Linear(in_features, out_features) def forward(self, x: torch.Tensor) -> torch.Tensor: for layer in self.proj_layers: x = layer(x) x = self.act(x) x = self.out_proj(x) return x class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel): config_class = SuryaModelConfig supports_gradient_checkpointing = True _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True main_input_name = "input_ids" _tied_weights_keys = ["lm_head.weight"] def __init__( self, config: SuryaModelConfig, embedder: SimpleTokenEmbedder = None, vision_encoder: SuryaEncoderModel = None, decoder: SuryaDecoderModel = None, **kwargs, ): super().__init__(config, **kwargs) if vision_encoder is None: vision_encoder = SuryaEncoderModel(config.vision_encoder) if decoder is None: decoder = SuryaDecoderModel(config.decoder) if embedder is None: embedder = SimpleTokenEmbedder(config) self.vision_encoder = vision_encoder self.decoder = decoder self.embedder = embedder # Simple encoding for image patches self.img_w_embed = nn.Embedding( self.config.image_embed_encoding_size, self.config.hidden_size, ) self.img_h_embed = nn.Embedding( self.config.image_embed_encoding_size, self.config.hidden_size, ) # Tying configs self.vision_encoder.config = self.config.vision_encoder self.decoder.config = self.config.decoder self.bbox_head = BboxHead(config.hidden_size, 6) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) if ( self.config.multi_output_distance is not None and self.config.multi_output_distance > 0 ): self.multi_output_projections = nn.ModuleList( [ DistanceProjection( in_features=config.hidden_size, out_features=config.hidden_size ) for _ in range(self.config.multi_output_distance) ] ) def tie_weights(self): self._tie_weights() def _tie_weights(self): # Tie weights of lm head and token embedder self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed) def get_output_embeddings(self) -> nn.Module: return self.lm_head def get_input_embeddings(self) -> nn.Module: return self.embedder.token_embed def set_output_embeddings(self, new_embeddings: nn.Module): self.lm_head = new_embeddings def set_input_embeddings(self, new_embeddings: nn.Module): self.embedder.token_embed = new_embeddings def maybe_static_pad_image_inputs( self, chunk_pixels: torch.Tensor, chunk_grid_thw: torch.Tensor, actual_chunk_len: int, encoder_chunk_size: int, ) -> Tuple[torch.Tensor, torch.Tensor]: valid_embed_len = actual_chunk_len // ( self.vision_encoder.spatial_merge_size**2 ) if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size: padding_len = encoder_chunk_size - actual_chunk_len chunk_pixels = F.pad( chunk_pixels, (0, 0, 0, padding_len), mode="constant", value=0.0, ) padding_grid = torch.tensor( [[1, 2, padding_len // 2]], device=chunk_grid_thw.device, dtype=chunk_grid_thw.dtype, ) chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0) return chunk_pixels, chunk_grid_thw, valid_embed_len def get_image_embeddings( self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, encoder_chunk_size: int, valid_batch_size: torch.Tensor | None = None, max_batch_size: int | None = None, ): # embed all images with the vision encoder after they have already been tiled and flattened into a single batch chunks = [0] grid_chunks = [0] curr_chunk_len = 0 curr_seq_len = 0 for i in range(len(grid_thw)): curr_chunk_len += (grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2]).item() if curr_chunk_len > encoder_chunk_size: chunks.append(curr_chunk_len + curr_seq_len) curr_seq_len += curr_chunk_len curr_chunk_len = 0 grid_chunks.append(i + 1) if curr_chunk_len > 0: chunks.append(pixel_values.shape[0]) grid_chunks.append(len(grid_thw)) assert curr_chunk_len + curr_seq_len == pixel_values.shape[0], ( f"Mismatch in encoder chunking, {curr_chunk_len} + {curr_seq_len} != {pixel_values.shape[0]}" ) logger.debug( f"Chunking encoder sequence into {len(chunks) - 1} chunks of size {encoder_chunk_size} with lengths {chunks} and grids {grid_chunks}" ) embeddings = [] for i in range(len(chunks) - 1): start = chunks[i] end = chunks[i + 1] grid_start = grid_chunks[i] grid_end = grid_chunks[i + 1] chunk_pixels = pixel_values[start:end] chunk_grid_thw = grid_thw[grid_start:grid_end] actual_chunk_len = end - start chunk_pixels, chunk_grid_thw, valid_embed_len = ( self.maybe_static_pad_image_inputs( chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size ) ) chunk_embeddings = self.vision_encoder.embed_images( image_batch=chunk_pixels.unsqueeze(0).to(device=self.device), grid_thw=chunk_grid_thw.unsqueeze(0).to(device=self.device), ) embeddings.append(chunk_embeddings[:valid_embed_len].squeeze(0)) if len(embeddings) == 0: raise ValueError( "No image embeddings were generated. Check the input images and grid sizes." ) elif len(embeddings) == 1: embeddings = embeddings[0] else: embeddings = torch.cat(embeddings, dim=0) encoding_2d = self.get_2d_learned_embeddings( grid_thw, device=embeddings.device, bbox_size=self.config.image_embed_encoding_multiplier, ) assert embeddings.shape[0] == encoding_2d.shape[0], ( f"Mismatch in image embedding seq len: {embeddings.shape} vs {encoding_2d.shape}" ) assert embeddings.shape[1] == encoding_2d.shape[1], ( f"Mismatch in image embedding token counts: {embeddings.shape} vs {encoding_2d.shape}" ) embeddings = embeddings + encoding_2d return embeddings def embed_ids_boxes_images( self, input_ids, image_embeddings, encoder_chunk_size: int, valid_batch_size: torch.Tensor | None = None, input_boxes: torch.Tensor | None = None, embed_boxes: torch.Tensor | None = None, ): """ Insert embedded image tiles into the corresponding positions into the full input sequence Positions to insert new tokens are indicated by the special image token index """ # This is batched in the inner call inputs_embeds = self.embedder.embed( input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes ) if image_embeddings is not None: special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds) if inputs_embeds[special_image_mask].numel() != image_embeddings.numel(): n_image_tokens = torch.sum((input_ids == self.config.image_token_id)) n_image_features = image_embeddings.shape[0] * image_embeddings.shape[1] warnings.warn( f"Image features and image tokens do not match: tokens {n_image_tokens}, features {n_image_features}. This may lead to unexpected results" ) image_features = image_embeddings.to(inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter( special_image_mask, image_features ) else: assert (input_ids == self.config.image_token_id).sum() == 0, ( "Image tokens were present in the input but no input images were provided" ) return inputs_embeds def get_2d_learned_embeddings( self, grid_thw, device: str | torch.device = "cpu", bbox_size: int = 256, ): all_embeddings = [] for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( grid_h // self.config.merge_size, grid_w // self.config.merge_size, ) # Scale to 0-1024 llm_grid_h = ( torch.arange(llm_grid_h, device=device) / max(1, (llm_grid_h - 1)) * bbox_size ) llm_grid_w = ( torch.arange(llm_grid_w, device=device) / max(1, (llm_grid_w - 1)) * bbox_size ) llm_grid_w_idx = llm_grid_w.to(torch.long) llm_grid_h_idx = llm_grid_h.to(torch.long) llm_grid_w = self.img_w_embed(llm_grid_w_idx) llm_grid_h = self.img_h_embed(llm_grid_h_idx) full_grid = llm_grid_h[:, None] + llm_grid_w[None, :] flattened = full_grid.flatten( 0, 1 ) # Flatten first dimension, so they are seq_len x embed_dim all_embeddings.append(flattened) return torch.concat( all_embeddings, dim=0 ) # Shape is num_image_tokens x embed_dim def get_logits(self, hidden_states): assert hidden_states.shape[1] == 1, ( "Multi output predictions only applied on the last token" ) all_lm_logits = [] all_bbox_logits = [] current_hidden = hidden_states # Loop includes initial prediction (i=0) plus multi_out_distance additional predictions for i in range(self.config.multi_output_distance + 1): if i > 0: current_hidden = self.multi_output_projections[i - 1](current_hidden) lm_logits = self.lm_head(current_hidden) bbox_logits = F.sigmoid(self.bbox_head(current_hidden)) all_lm_logits.append(lm_logits) all_bbox_logits.append(bbox_logits) # Concatenate along sequence dimension (dim=1) final_lm_logits = torch.cat(all_lm_logits, dim=1) final_bbox_logits = torch.cat(all_bbox_logits, dim=1) return final_lm_logits, final_bbox_logits def forward( self, input_ids=None, image_embeddings=None, labels=None, image_tiles=None, grid_thw=None, inputs_embeds=None, attention_mask=None, position_ids=None, cache_position=None, past_key_values=None, output_hidden_states=False, output_attentions=False, use_cache=False, encoder_chunk_size=32768, cache_idxs=None, num_valid_tokens=None, prefill=True, text_lengths=None, valid_batch_size: torch.Tensor = None, input_boxes=None, embed_boxes=None, logits_to_keep=None, **kwargs: KwargsForCausalLM, ): if any([ input_ids is None, position_ids is None, cache_position is None, ( prefill and not ( (image_tiles is not None and grid_thw is not None) or image_embeddings is not None ) ), ]): raise ValueError( "`input_ids`, `position_ids`, and `cache_position` **must** be specified. " "For prefill, you must provide either (`image_tiles` and `grid_thw`) or `image_embeddings`." ) inputs_embeds = self.embed_ids_boxes_images( input_ids, image_embeddings, encoder_chunk_size, valid_batch_size, input_boxes, embed_boxes ) # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder # Skipped during decoding since not required if self.decoder.config._attn_implementation == "flash_attention_2" and prefill: # Needed for CPU -> GPU from surya.common.surya.flash_attn_utils import _get_unpad_data batch_size, query_length, _ = inputs_embeds.shape indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( attention_mask ) kwargs["batch_size"] = batch_size kwargs["query_length"] = query_length kwargs["indices_k"] = indices_k kwargs["cu_seqlens_k"] = cu_seqlens_k kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, ) attention_mask = causal_mask outputs = self.decoder( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=True, use_cache=use_cache, cache_idxs=cache_idxs, num_valid_tokens=num_valid_tokens, prefill=prefill, text_lengths=text_lengths, **kwargs, ) hidden_states = outputs.last_hidden_state if logits_to_keep is not None: hidden_states = hidden_states[:, -logits_to_keep:, :] hidden_states = hidden_states.contiguous() loss = None if labels is not None: # Training, return full logits lm_logits = self.lm_head(hidden_states) bbox_logits = None vocab_size = lm_logits.shape[-1] labels = torch.roll(labels, shifts=-1, dims=-1) loss = F.cross_entropy( lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean" ) else: lm_logits, bbox_logits = self.get_logits(hidden_states) return SuryaModelOutput( loss=loss, bbox_logits=bbox_logits, lm_logits=lm_logits, hidden_states=outputs.hidden_states if output_hidden_states else None, attentions=outputs.attentions if output_attentions else None, past_key_values=outputs.past_key_values, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): if self.decoder.config._attn_implementation == "flash_attention_2": return attention_mask # We always pass in a 2D attention mask from the processor - In both static and dynamic cache cases dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_key_values.max_cache_len ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, past_key_values=past_key_values, ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended( causal_mask, min_dtype ) return causal_mask @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, config: SuryaModelConfig, past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to plcae the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. Shape `(batch_size, sequence_length)`. batch_size (`torch.Tensor`): Batch size. config (`Qwen2Config`): The model's configuration class past_key_values (`Cache`): The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device, ) # Batch-aware diagonal attend mask diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze( 0 ) > cache_position.unsqueeze(-1) causal_mask = ( causal_mask.unsqueeze(0) * diagonal_attend_mask ) # (batch_size, seq_len, target_len) causal_mask = causal_mask[ :, None, :, : ] # (batch_size, 1, seq_len, target_len) if attention_mask is not None: causal_mask = ( causal_mask.clone() ) # copy to contiguous memory for in-place edit if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ :, None, None, : ].to(causal_mask.device) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[ :, :, :, :mask_length ].masked_fill(padding_mask, min_dtype) return causal_mask class SuryaXLAModel(SuryaModel): def get_image_embeddings( self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, encoder_chunk_size: int, valid_batch_size: torch.Tensor | None = None, max_batch_size: int | None = None, ): # embed all images with the vision encoder after they have already been tiled and flattened into a single batch unpadded_max_grid_size = ( (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).max().item() ) max_grid_size = get_nearest_pad( unpadded_max_grid_size, ) # If we need zero padding, we still need to allocate a bit of room for the extra grid_thw # Always need 2 items in each row batch if max_grid_size == unpadded_max_grid_size: max_grid_size += 16 full_image_grid = torch.zeros( (valid_batch_size, max_grid_size, pixel_values.shape[-1]), dtype=pixel_values.dtype, ) # Roll out into a full grid seq_len = 0 row_grids = [] for i in range(valid_batch_size): curr_sample_len = grid_thw[i][0] * grid_thw[i][1] * grid_thw[i][2] full_image_grid[i, -curr_sample_len:] = pixel_values[ seq_len : seq_len + curr_sample_len ] padded_len = max_grid_size - curr_sample_len if padded_len > 0: row_grid = torch.tensor( [ [1, 4, padded_len // 4], grid_thw[i].tolist(), ], dtype=torch.long, ) else: row_grid = torch.tensor( [ grid_thw[i].tolist(), ], dtype=torch.long, ) row_grids.append(row_grid) seq_len += curr_sample_len # bsz, 2, 3 row_grids = torch.stack(row_grids, dim=0) if settings.FOUNDATION_STATIC_CACHE: # Pad to max batch size, repeat the final row row_grids = pad_to_batch_size_repeat( row_grids, batch_size=max_batch_size, ) full_image_grid = pad_to_batch_size( full_image_grid, batch_size=max_batch_size, ) full_image_grid = full_image_grid.to(self.device) embeddings = self.vision_encoder.embed_images( image_batch=full_image_grid, grid_thw=row_grids.to(self.device) ) encoding_2d = self.get_2d_learned_embeddings( row_grids, bbox_size=self.config.image_embed_encoding_multiplier, ) embeddings += encoding_2d return embeddings def embed_ids_boxes_images( self, input_ids, image_embeddings, encoder_chunk_size: int, valid_batch_size: torch.Tensor | None = None, input_boxes: torch.Tensor | None = None, embed_boxes: torch.Tensor | None = None, ): """ Insert embedded image tiles into the corresponding positions into the full input sequence Positions to insert new tokens are indicated by the special image token index """ # This is batched in the inner call inputs_embeds = self.embedder.embed( input_tokens=input_ids, input_boxes=input_boxes, embed_boxes=embed_boxes ) if image_embeddings is not None: image_token_id_tensor = torch.tensor( self.config.image_token_id, device=inputs_embeds.device, dtype=torch.long, ) mask = input_ids == image_token_id_tensor last_image_token_pos = ( mask.size(1) - 1 - mask.flip(dims=[1]).long().argmax(dim=1, keepdim=True) ) # Calculate start position to replace N positions ending at (and including) the last image token start_positions = last_image_token_pos - image_embeddings[0].shape[0] batch_size, insert_len = image_embeddings.shape[:2] # Create position indices for each insertion pos_indices = torch.arange( insert_len, device=inputs_embeds.device ).unsqueeze(0) insert_positions = start_positions + pos_indices idx = insert_positions.unsqueeze(-1).expand( -1, -1, inputs_embeds.size(-1) ) # [B,N,D] inputs_embeds = inputs_embeds.scatter(1, idx, image_embeddings) inputs_embeds = inputs_embeds * ( input_ids != self.config.pad_token_id ).unsqueeze(-1).to(inputs_embeds.dtype) return inputs_embeds def get_2d_learned_embeddings( self, grid_thw, bbox_size: int = 256, ): dev = grid_thw.device all_row_coords = [] all_col_coords = [] for row_grid in grid_thw: merge = self.config.merge_size # per-sample grid sizes after merge H = (row_grid[:, 1] // merge).long() # (B,) W = (row_grid[:, 2] // merge).long() # (B,) row_coords = torch.cat( [ torch.linspace(0, bbox_size, steps=int(h), device=dev) .round() .repeat_interleave(w) # repeat each row value w times for h, w in zip(H.tolist(), W.tolist()) ] ) # (full_grid_size,) col_coords = torch.cat( [ torch.linspace(0, bbox_size, steps=int(w), device=dev) .round() .repeat(int(h)) # tile the column vector h times for h, w in zip(H.tolist(), W.tolist()) ] ) # (full_grid_size,) all_row_coords.append(row_coords) all_col_coords.append(col_coords) row_coords = torch.stack(all_row_coords, dim=0).to(self.device) col_coords = torch.stack(all_col_coords, dim=0).to(self.device) emb = self.img_h_embed(row_coords.long()) + self.img_w_embed(col_coords.long()) return emb ``` -------------------------------------------------------------------------------- /surya/common/surya/encoder/__init__.py: -------------------------------------------------------------------------------- ```python import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from transformers.activations import ACT2FN from surya.common.pretrained import SuryaPreTrainedModel from surya.common.surya.encoder.config import SuryaEncoderConfig from surya.common.xla import get_nearest_pad from surya.logging import get_logger from surya.settings import settings if settings.FOUNDATION_XLA: import torch_xla.experimental.custom_kernel from surya.logging import get_logger logger = get_logger() class Qwen2_5_VLMLP(nn.Module): def __init__(self, config, bias: bool = False): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_state): return self.down_proj( self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) ) class Qwen2_5_VisionPatchEmbed(nn.Module): def __init__( self, patch_size: int = 14, temporal_patch_size: int = 2, in_channels: int = 3, embed_dim: int = 1152, ) -> None: super().__init__() self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.in_channels = in_channels self.embed_dim = embed_dim kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = nn.Conv3d( in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = self.proj.weight.dtype bsz = hidden_states.shape[0] hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size, ) hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( bsz, -1, self.embed_dim ) return hidden_states class Qwen2_5_VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.inv_freq = 1.0 / ( theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim) ) def forward(self, seqlen: int) -> torch.Tensor: seq = torch.arange(seqlen, device="cpu", dtype=self.inv_freq.dtype) freqs = torch.outer(seq, self.inv_freq) return freqs class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Qwen2RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Qwen2_5_VLPatchMerger(nn.Module): def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) self.mlp = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.GELU(), nn.Linear(self.hidden_size, dim), ) def forward(self, x: torch.Tensor) -> torch.Tensor: bsz = x.shape[0] x = self.mlp(self.ln_q(x).view(bsz, -1, self.hidden_size)) return x def apply_rotary_pos_emb_flashatt( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: from flash_attn.layers.rotary import apply_rotary_emb cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed class Qwen2_5_VLVisionXLASdpaAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) self.head_dim = dim // num_heads def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1] q, k, v = ( self.qkv(hidden_states) .reshape(bsz, seq_length, 3, self.num_heads, -1) .permute(0, 2, 1, 3, 4) .unbind(1) ) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " "removed and `position_embeddings` will be mandatory." ) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) cos = emb.cos() sin = emb.sin() else: cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) attention_mask = torch.zeros([bsz, 1, seq_length, seq_length], dtype=torch.bool) cu_seqlens_cpu = cu_seqlens.cpu() for j in range(bsz): batch_seqlens = cu_seqlens_cpu[j] for i in range(1, len(batch_seqlens)): attention_mask[ j, ..., batch_seqlens[i - 1] : batch_seqlens[i], batch_seqlens[i - 1] : batch_seqlens[i], ] = True attention_mask = attention_mask.to(q.device) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) attn_output = F.scaled_dot_product_attention( q, k, v, attention_mask, dropout_p=0.0, ) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, seq_length, -1) attn_output = self.proj(attn_output) return attn_output class Qwen2_5_VLVisionXLAFlashAttention2(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) self.head_dim = dim // num_heads def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: # Note, this is faster than SDPA, but pretty memory inefficient # It also has significant accuracy issues bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1] # Single reshape to target layout - avoid multiple operations q, k, v = ( self.qkv(hidden_states) .reshape(bsz, seq_length, 3, self.num_heads, -1) .permute(0, 2, 1, 3, 4) .unbind(1) ) # Apply rotary embeddings if provided if position_embeddings is not None: cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) # Single reshape to flash attention format [batch, num_heads, seq_len, head_dim] q = q.transpose(1, 2) # [bsz, num_heads, seq_len, head_dim] k = k.transpose(1, 2) v = v.transpose(1, 2) total_seqlen = q.shape[2] # from cu_seqlens to segment ids for each position in dim 0 additive_bias = torch.zeros((bsz, 1, total_seqlen, total_seqlen), dtype=q.dtype) min_val = torch.finfo(q.dtype).min for i in range(bsz): padding_end = cu_seqlens[i][1].item() additive_bias[i, :, :, :padding_end] = min_val additive_bias = additive_bias.to(hidden_states.device) attn_scale = 1 / math.sqrt(self.head_dim) attn_output = torch_xla.experimental.custom_kernel.flash_attention( q, k, v, sm_scale=attn_scale, ab=additive_bias ) attn_output = ( attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_length, -1) ) attn_output = self.proj(attn_output) return attn_output class Qwen2_5_VLVisionFlashAttention2(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: from flash_attn import flash_attn_varlen_func bsz = hidden_states.shape[0] seq_length = hidden_states.shape[1] q, k, v = ( self.qkv(hidden_states) .reshape(bsz, seq_length, 3, self.num_heads, -1) .permute(0, 2, 1, 3, 4) .unbind(1) ) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " "removed and `position_embeddings` will be mandatory." ) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) cos = emb.cos() sin = emb.sin() else: cos, sin = position_embeddings q, k = apply_rotary_pos_emb_flashatt(q, k, cos.squeeze(0), sin.squeeze(0)) q = q.squeeze(0) k = k.squeeze(0) v = v.squeeze(0) cu_seqlens = cu_seqlens.squeeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = flash_attn_varlen_func( q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen ).reshape(bsz, seq_length, -1) attn_output = self.proj(attn_output) return attn_output def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb_vision( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: orig_q_dtype = q.dtype orig_k_dtype = k.dtype q, k = q.float(), k.float() cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) q_embed = q_embed.to(orig_q_dtype) k_embed = k_embed.to(orig_k_dtype) return q_embed, k_embed class Qwen2_5_VLVisionAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: bsz, seq_length = hidden_states.shape[0], hidden_states.shape[1] q, k, v = ( self.qkv(hidden_states) .reshape(bsz, seq_length, 3, self.num_heads, -1) .permute(0, 2, 1, 3, 4) .unbind(1) ) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " "removed and `position_embeddings` will be mandatory." ) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) cos = emb.cos() sin = emb.sin() else: cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) attention_mask = torch.full( [bsz, 1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype, ) for j in range(bsz): batch_seqlens = cu_seqlens[j] for i in range(1, len(batch_seqlens)): attention_mask[ j, ..., batch_seqlens[i - 1] : batch_seqlens[i], batch_seqlens[i - 1] : batch_seqlens[i], ] = 0 q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(q.dtype) attn_output = torch.matmul(attn_weights, v) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, seq_length, -1) attn_output = self.proj(attn_output) return attn_output class Qwen2_5_VLVisionSdpaAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) def unpack_qkv_with_mask(self, q, k, v, cu_seqlens): """ Unpacks q, k, v sequences into batch-major form and constructs an additive attention mask. Args: q, k, v: Tensors of shape (total_seq_len, num_heads, head_dim) cu_seqlens: Tensor of shape (batch_size + 1,) with cumulative sequence lengths Returns: batched_q: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim) batched_k: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim) batched_v: Tensor of shape (batch_size, max_seq_len, num_heads, head_dim) attention_mask: Tensor of shape (batch_size, 1, max_seq_len, max_seq_len) with 0 for valid tokens and -inf for padding (for additive attention) """ device = q.device dtype = q.dtype batch_size = cu_seqlens.shape[0] - 1 num_heads = q.shape[1] head_dim = q.shape[2] seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] # Keep as tensor max_seq_len = seq_lengths.max().item() # Use .max() on tensor if settings.FOUNDATION_STATIC_CACHE: # Pad max_seq_len to the nearest multiple for compilation max_seq_len = get_nearest_pad(max_seq_len, pad_multiple=16) # Pad batch_size to the nearest multiple for compilation batch_size = get_nearest_pad(batch_size, pad_multiple=2) # Ensure seq_lengths is a tensor of the correct size seq_lengths = F.pad( seq_lengths, (0, batch_size - seq_lengths.size(0)), "constant", 0 ) # some day, you may look at this, and think: "what if I used repeat_interlave or some other fancy torch instead"? # don't do this - it's a path to madness. For some reason, this loop is optimal batch_indices = [] position_indices = [] for i, seq_len in enumerate( seq_lengths.tolist() ): # Convert to list only for iteration batch_indices.extend([i] * seq_len) position_indices.extend(list(range(seq_len))) batch_indices = torch.tensor(batch_indices, device=device) position_indices = torch.tensor(position_indices, device=device) batched_q = torch.zeros( (batch_size, max_seq_len, num_heads, head_dim), device=device, dtype=dtype ) batched_k = torch.zeros_like(batched_q) batched_v = torch.zeros_like(batched_q) # Create additive attention mask attention_mask = torch.full( (batch_size, max_seq_len, max_seq_len), fill_value=float("-inf"), device=device, dtype=dtype, ) # Create mask for valid positions seq_range = torch.arange(max_seq_len, device=device) valid_mask = seq_range.unsqueeze(0) < seq_lengths.unsqueeze( 1 ) # (batch_size, max_seq_len) valid_2d = valid_mask.unsqueeze(2) & valid_mask.unsqueeze( 1 ) # (batch_size, max_seq_len, max_seq_len) # Simply use boolean indexing to set valid positions to 0 attention_mask[valid_2d] = 0 attention_mask = attention_mask.unsqueeze( 1 ) # (batch_size, 1, max_seq_len, max_seq_len) batched_q[batch_indices, position_indices] = q batched_k[batch_indices, position_indices] = k batched_v[batch_indices, position_indices] = v return ( batched_q, batched_k, batched_v, attention_mask, batch_indices, position_indices, ) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: hidden_states = hidden_states.squeeze(0) cu_seqlens = cu_seqlens.squeeze(0) seq_length = hidden_states.shape[0] q, k, v = ( self.qkv(hidden_states) .reshape(seq_length, 3, self.num_heads, -1) .permute(1, 0, 2, 3) .unbind(0) ) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " "removed and `position_embeddings` will be mandatory." ) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) cos = emb.cos() sin = emb.sin() else: cos, sin = position_embeddings q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) q = q.squeeze(0) k = k.squeeze(0) q, k, v, attention_mask, batch_indices, position_indices = ( self.unpack_qkv_with_mask(q, k, v, cu_seqlens) ) batch_size, max_seqlen = q.shape[:2] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) attn_output = F.scaled_dot_product_attention( q, k, v, attention_mask, dropout_p=0.0, ) attn_output = attn_output.permute(0, 2, 1, 3).reshape( batch_size, max_seqlen, -1 ) # Bring back to (batch_size, max_seqlen, hidden_dim) attn_output = attn_output[batch_indices, position_indices] attn_output = self.proj(attn_output) return attn_output.unsqueeze(0) QWEN2_5_VL_VISION_ATTENTION_CLASSES = { "eager": Qwen2_5_VLVisionAttention, "flash_attention_2": Qwen2_5_VLVisionXLAFlashAttention2 if settings.FOUNDATION_XLA else Qwen2_5_VLVisionFlashAttention2, "sdpa": Qwen2_5_VLVisionXLASdpaAttention if settings.FOUNDATION_XLA else Qwen2_5_VLVisionSdpaAttention, } class Qwen2_5_VLVisionBlock(nn.Module): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( config.hidden_size, num_heads=config.num_heads ) self.mlp = Qwen2_5_VLMLP(config, bias=True) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states Qwen2_5_VL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`Qwen2_5_VLConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ class Qwen2_5_VLPreTrainedModel(SuryaPreTrainedModel): config_class = SuryaEncoderConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv3d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): config_class = SuryaEncoderConfig _no_split_modules = ["Qwen2_5_VLVisionBlock"] def __init__(self, config, *inputs, **kwargs) -> None: super().__init__(config, *inputs, **kwargs) self.spatial_merge_size = config.spatial_merge_size self.patch_size = config.patch_size self.fullatt_block_indexes = config.fullatt_block_indexes self.window_size = config.window_size self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=config.patch_size, temporal_patch_size=config.temporal_patch_size, in_channels=config.in_channels, embed_dim=config.hidden_size, ) head_dim = config.hidden_size // config.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( [ Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth) ] ) self.merger = Qwen2_5_VLPatchMerger( dim=config.out_hidden_size, context_dim=config.hidden_size, spatial_merge_size=config.spatial_merge_size, ) self.gradient_checkpointing = False def rot_pos_emb(self, grid_thw): rotary_pos_emb = [] grid_thw_list = grid_thw.cpu().tolist() for batch_item in grid_thw_list: row_pos_ids = [] heights = [h for _, h, _ in batch_item] widths = [w for _, _, w in batch_item] max_grid_size = max(heights + widths) for t, h, w in batch_item: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = hpos_ids.reshape( h // self.spatial_merge_size, self.spatial_merge_size, w // self.spatial_merge_size, self.spatial_merge_size, ) hpos_ids = hpos_ids.permute(0, 2, 1, 3) hpos_ids = hpos_ids.flatten() wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) wpos_ids = wpos_ids.reshape( h // self.spatial_merge_size, self.spatial_merge_size, w // self.spatial_merge_size, self.spatial_merge_size, ) wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() # shape: token_count, 2 row_pos_ids.append( torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) ) # shape: token_count, 2 pos_ids = torch.cat(row_pos_ids, dim=0) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb_row = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb.append(rotary_pos_emb_row) rotary_pos_emb = torch.stack(rotary_pos_emb, dim=0) return rotary_pos_emb def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, ) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(bsz, seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(bsz, num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. Returns: `torch.Tensor`: hidden_states. """ bsz, seq_len, _ = hidden_states.size() hidden_states = self.patch_embed(hidden_states) # (bsz, seq_len, hidden_dim) rotary_pos_emb = self.rot_pos_emb(grid_thw) # hidden_states = hidden_states.reshape(bsz, seq_len, -1) # rotary_pos_emb = rotary_pos_emb.reshape(bsz, seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to( hidden_states.device ) position_embeddings = (emb.cos(), emb.sin()) cu_seqlens = (grid_thw[:, :, 1] * grid_thw[:, :, 2]).cumsum( dim=1, # Select dtype based on the following factors: # - FA2 requires that cu_seqlens_q must have dtype int32 # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw # See https://github.com/huggingface/transformers/pull/34852 for more information dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) for layer_num, blk in enumerate(self.blocks): if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( blk.__call__, hidden_states, cu_seqlens, None, position_embeddings, ) else: hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, ) hidden_states = self.merger(hidden_states) return hidden_states class SuryaEncoderModel(Qwen2_5_VisionTransformerPretrainedModel): @property def image_size(self) -> int: config: SuryaEncoderConfig = self.config if isinstance(config.image_size, tuple) and len(config.image_size) == 2: return config.image_size elif isinstance(config.image_size, int): return (config.image_size, config.image_size) raise ValueError( f"The `image_size` for SwinConfig should be a tuple of (int, int) or a single int but found {type(config.image_size)}" ) @property def hidden_size(self) -> int: config: SuryaEncoderConfig = self.config return config.hidden_size def embed_images( self, image_batch: torch.Tensor, grid_thw: torch.Tensor, ) -> torch.Tensor: return super().forward( hidden_states=image_batch, grid_thw=grid_thw, ) ```