This is page 5 of 9. Use http://codebase.md/datalab-to/marker?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── breaking-bug-report.md │ │ ├── feature_request.md │ │ └── output-bug-report.md │ └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── benchmarks │ ├── __init__.py │ ├── overall │ │ ├── __init__.py │ │ ├── display │ │ │ ├── __init__.py │ │ │ ├── dataset.py │ │ │ └── table.py │ │ ├── download │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── llamaparse.py │ │ │ ├── main.py │ │ │ ├── mathpix.py │ │ │ └── mistral.py │ │ ├── elo.py │ │ ├── methods │ │ │ ├── __init__.py │ │ │ ├── docling.py │ │ │ ├── gt.py │ │ │ ├── llamaparse.py │ │ │ ├── marker.py │ │ │ ├── mathpix.py │ │ │ ├── mistral.py │ │ │ ├── olmocr.py │ │ │ └── schema.py │ │ ├── overall.py │ │ ├── registry.py │ │ ├── schema.py │ │ └── scorers │ │ ├── __init__.py │ │ ├── clean.py │ │ ├── heuristic.py │ │ ├── llm.py │ │ └── schema.py │ ├── table │ │ ├── __init__.py │ │ ├── gemini.py │ │ ├── inference.py │ │ ├── scoring.py │ │ └── table.py │ ├── throughput │ │ ├── __init__.py │ │ └── main.py │ └── verify_scores.py ├── chunk_convert.py ├── CLA.md ├── convert_single.py ├── convert.py ├── data │ ├── .gitignore │ ├── examples │ │ ├── json │ │ │ ├── multicolcnn.json │ │ │ ├── switch_trans.json │ │ │ └── thinkpython.json │ │ └── markdown │ │ ├── multicolcnn │ │ │ ├── _page_1_Figure_0.jpeg │ │ │ ├── _page_2_Picture_0.jpeg │ │ │ ├── _page_6_Figure_0.jpeg │ │ │ ├── _page_7_Figure_0.jpeg │ │ │ ├── multicolcnn_meta.json │ │ │ └── multicolcnn.md │ │ ├── switch_transformers │ │ │ ├── _page_11_Figure_4.jpeg │ │ │ ├── _page_12_Figure_4.jpeg │ │ │ ├── _page_13_Figure_2.jpeg │ │ │ ├── _page_18_Figure_1.jpeg │ │ │ ├── _page_18_Figure_3.jpeg │ │ │ ├── _page_2_Figure_3.jpeg │ │ │ ├── _page_20_Figure_1.jpeg │ │ │ ├── _page_20_Figure_4.jpeg │ │ │ ├── _page_27_Figure_1.jpeg │ 
│ │ ├── _page_29_Figure_1.jpeg │ │ │ ├── _page_30_Figure_1.jpeg │ │ │ ├── _page_31_Figure_3.jpeg │ │ │ ├── _page_4_Figure_1.jpeg │ │ │ ├── _page_5_Figure_3.jpeg │ │ │ ├── switch_trans_meta.json │ │ │ └── switch_trans.md │ │ └── thinkpython │ │ ├── _page_109_Figure_1.jpeg │ │ ├── _page_115_Figure_1.jpeg │ │ ├── _page_116_Figure_3.jpeg │ │ ├── _page_127_Figure_1.jpeg │ │ ├── _page_128_Figure_1.jpeg │ │ ├── _page_167_Figure_1.jpeg │ │ ├── _page_169_Figure_1.jpeg │ │ ├── _page_173_Figure_1.jpeg │ │ ├── _page_190_Figure_1.jpeg │ │ ├── _page_195_Figure_1.jpeg │ │ ├── _page_205_Figure_1.jpeg │ │ ├── _page_23_Figure_1.jpeg │ │ ├── _page_23_Figure_3.jpeg │ │ ├── _page_230_Figure_1.jpeg │ │ ├── _page_233_Figure_1.jpeg │ │ ├── _page_233_Figure_3.jpeg │ │ ├── _page_234_Figure_1.jpeg │ │ ├── _page_235_Figure_1.jpeg │ │ ├── _page_236_Figure_1.jpeg │ │ ├── _page_236_Figure_3.jpeg │ │ ├── _page_237_Figure_1.jpeg │ │ ├── _page_238_Figure_1.jpeg │ │ ├── _page_46_Figure_1.jpeg │ │ ├── _page_60_Figure_1.jpeg │ │ ├── _page_60_Figure_3.jpeg │ │ ├── _page_67_Figure_1.jpeg │ │ ├── _page_71_Figure_1.jpeg │ │ ├── _page_78_Figure_1.jpeg │ │ ├── _page_85_Figure_1.jpeg │ │ ├── _page_94_Figure_1.jpeg │ │ ├── _page_99_Figure_17.jpeg │ │ ├── _page_99_Figure_178.jpeg │ │ ├── thinkpython_meta.json │ │ └── thinkpython.md │ ├── images │ │ ├── overall.png │ │ ├── per_doc.png │ │ └── table.png │ └── latex_to_md.sh ├── examples │ ├── marker_modal_deployment.py │ └── README.md ├── extraction_app.py ├── LICENSE ├── marker │ ├── builders │ │ ├── __init__.py │ │ ├── document.py │ │ ├── layout.py │ │ ├── line.py │ │ ├── ocr.py │ │ └── structure.py │ ├── config │ │ ├── __init__.py │ │ ├── crawler.py │ │ ├── parser.py │ │ └── printer.py │ ├── converters │ │ ├── __init__.py │ │ ├── extraction.py │ │ ├── ocr.py │ │ ├── pdf.py │ │ └── table.py │ ├── extractors │ │ ├── __init__.py │ │ ├── document.py │ │ └── page.py │ ├── logger.py │ ├── models.py │ ├── output.py │ ├── processors │ │ ├── __init__.py │ │ ├── 
blank_page.py │ │ ├── block_relabel.py │ │ ├── blockquote.py │ │ ├── code.py │ │ ├── debug.py │ │ ├── document_toc.py │ │ ├── equation.py │ │ ├── footnote.py │ │ ├── ignoretext.py │ │ ├── line_merge.py │ │ ├── line_numbers.py │ │ ├── list.py │ │ ├── llm │ │ │ ├── __init__.py │ │ │ ├── llm_complex.py │ │ │ ├── llm_equation.py │ │ │ ├── llm_form.py │ │ │ ├── llm_handwriting.py │ │ │ ├── llm_image_description.py │ │ │ ├── llm_mathblock.py │ │ │ ├── llm_meta.py │ │ │ ├── llm_page_correction.py │ │ │ ├── llm_sectionheader.py │ │ │ ├── llm_table_merge.py │ │ │ └── llm_table.py │ │ ├── order.py │ │ ├── page_header.py │ │ ├── reference.py │ │ ├── sectionheader.py │ │ ├── table.py │ │ ├── text.py │ │ └── util.py │ ├── providers │ │ ├── __init__.py │ │ ├── document.py │ │ ├── epub.py │ │ ├── html.py │ │ ├── image.py │ │ ├── pdf.py │ │ ├── powerpoint.py │ │ ├── registry.py │ │ ├── spreadsheet.py │ │ └── utils.py │ ├── renderers │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── extraction.py │ │ ├── html.py │ │ ├── json.py │ │ ├── markdown.py │ │ └── ocr_json.py │ ├── schema │ │ ├── __init__.py │ │ ├── blocks │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── basetable.py │ │ │ ├── caption.py │ │ │ ├── code.py │ │ │ ├── complexregion.py │ │ │ ├── equation.py │ │ │ ├── figure.py │ │ │ ├── footnote.py │ │ │ ├── form.py │ │ │ ├── handwriting.py │ │ │ ├── inlinemath.py │ │ │ ├── listitem.py │ │ │ ├── pagefooter.py │ │ │ ├── pageheader.py │ │ │ ├── picture.py │ │ │ ├── reference.py │ │ │ ├── sectionheader.py │ │ │ ├── table.py │ │ │ ├── tablecell.py │ │ │ ├── text.py │ │ │ └── toc.py │ │ ├── document.py │ │ ├── groups │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── figure.py │ │ │ ├── list.py │ │ │ ├── page.py │ │ │ ├── picture.py │ │ │ └── table.py │ │ ├── polygon.py │ │ ├── registry.py │ │ └── text │ │ ├── __init__.py │ │ ├── char.py │ │ ├── line.py │ │ └── span.py │ ├── scripts │ │ ├── __init__.py │ │ ├── chunk_convert.py │ │ ├── chunk_convert.sh │ │ ├── common.py │ │ ├── 
convert_single.py │ │ ├── convert.py │ │ ├── extraction_app.py │ │ ├── file_to_s3.py │ │ ├── run_streamlit_app.py │ │ ├── server.py │ │ └── streamlit_app.py │ ├── services │ │ ├── __init__.py │ │ ├── azure_openai.py │ │ ├── claude.py │ │ ├── gemini.py │ │ ├── ollama.py │ │ ├── openai.py │ │ └── vertex.py │ ├── settings.py │ ├── util.py │ └── utils │ ├── __init__.py │ ├── batch.py │ ├── gpu.py │ └── image.py ├── marker_app.py ├── marker_server.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── README.md ├── signatures │ └── version1 │ └── cla.json ├── static │ └── fonts │ └── .gitignore └── tests ├── builders │ ├── test_blank_page.py │ ├── test_document_builder.py │ ├── test_garbled_pdf.py │ ├── test_layout_replace.py │ ├── test_ocr_builder.py │ ├── test_ocr_pipeline.py │ ├── test_overriding.py │ ├── test_pdf_links.py │ ├── test_rotated_bboxes.py │ ├── test_strip_existing_ocr.py │ └── test_structure.py ├── config │ └── test_config.py ├── conftest.py ├── converters │ ├── test_extraction_converter.py │ ├── test_ocr_converter.py │ ├── test_pdf_converter.py │ └── test_table_converter.py ├── processors │ ├── test_document_toc_processor.py │ ├── test_equation_processor.py │ ├── test_footnote_processor.py │ ├── test_ignoretext.py │ ├── test_llm_processors.py │ ├── test_table_merge.py │ └── test_table_processor.py ├── providers │ ├── test_document_providers.py │ ├── test_image_provider.py │ └── test_pdf_provider.py ├── renderers │ ├── test_chunk_renderer.py │ ├── test_extract_images.py │ ├── test_html_renderer.py │ ├── test_json_renderer.py │ └── test_markdown_renderer.py ├── schema │ └── groups │ └── test_list_grouping.py ├── services │ └── test_service_init.py └── utils.py ``` # Files -------------------------------------------------------------------------------- /marker/builders/line.py: -------------------------------------------------------------------------------- ```python 1 | from copy import deepcopy 2 | from typing import Annotated, List, Tuple 3 | 4 | 
import numpy as np 5 | from PIL import Image 6 | import cv2 7 | 8 | from surya.detection import DetectionPredictor 9 | from surya.ocr_error import OCRErrorPredictor 10 | 11 | from marker.builders import BaseBuilder 12 | from marker.providers import ProviderOutput, ProviderPageLines 13 | from marker.providers.pdf import PdfProvider 14 | from marker.schema import BlockTypes 15 | from marker.schema.document import Document 16 | from marker.schema.groups.page import PageGroup 17 | from marker.schema.polygon import PolygonBox 18 | from marker.schema.registry import get_block_class 19 | from marker.schema.text.line import Line 20 | from marker.settings import settings 21 | from marker.util import matrix_intersection_area, sort_text_lines 22 | from marker.utils.image import is_blank_image 23 | 24 | 25 | class LineBuilder(BaseBuilder): 26 | """ 27 | A builder for detecting text lines. Merges the detected lines with the lines from the provider 28 | """ 29 | 30 | detection_batch_size: Annotated[ 31 | int, 32 | "The batch size to use for the detection model.", 33 | "Default is None, which will use the default batch size for the model.", 34 | ] = None 35 | ocr_error_batch_size: Annotated[ 36 | int, 37 | "The batch size to use for the ocr error detection model.", 38 | "Default is None, which will use the default batch size for the model.", 39 | ] = None 40 | layout_coverage_min_lines: Annotated[ 41 | int, 42 | "The minimum number of PdfProvider lines that must be covered by the layout model", 43 | "to consider the lines from the PdfProvider valid.", 44 | ] = 1 45 | layout_coverage_threshold: Annotated[ 46 | float, 47 | "The minimum coverage ratio required for the layout model to consider", 48 | "the lines from the PdfProvider valid.", 49 | ] = 0.25 50 | min_document_ocr_threshold: Annotated[ 51 | float, 52 | "If less pages than this threshold are good, OCR will happen in the document. 
Otherwise it will not.", 53 | ] = 0.85 54 | provider_line_provider_line_min_overlap_pct: Annotated[ 55 | float, 56 | "The percentage of a provider line that has to be covered by a detected line", 57 | ] = 0.1 58 | excluded_for_coverage: Annotated[ 59 | Tuple[BlockTypes], 60 | "A list of block types to exclude from the layout coverage check.", 61 | ] = ( 62 | BlockTypes.Figure, 63 | BlockTypes.Picture, 64 | BlockTypes.Table, 65 | BlockTypes.FigureGroup, 66 | BlockTypes.TableGroup, 67 | BlockTypes.PictureGroup, 68 | ) 69 | ocr_remove_blocks: Tuple[BlockTypes, ...] = ( 70 | BlockTypes.Table, 71 | BlockTypes.Form, 72 | BlockTypes.TableOfContents, 73 | ) 74 | disable_tqdm: Annotated[ 75 | bool, 76 | "Disable tqdm progress bars.", 77 | ] = False 78 | disable_ocr: Annotated[ 79 | bool, 80 | "Disable OCR for the document. This will only use the lines from the provider.", 81 | ] = False 82 | keep_chars: Annotated[bool, "Keep individual characters."] = False 83 | detection_line_min_confidence: Annotated[float, "Minimum confidence for a detected line to be included"] = 0.8 84 | 85 | def __init__( 86 | self, 87 | detection_model: DetectionPredictor, 88 | ocr_error_model: OCRErrorPredictor, 89 | config=None, 90 | ): 91 | super().__init__(config) 92 | 93 | self.detection_model = detection_model 94 | self.ocr_error_model = ocr_error_model 95 | 96 | def __call__(self, document: Document, provider: PdfProvider): 97 | # Disable inline detection for documents where layout model doesn't detect any equations 98 | # Also disable if we won't use the inline detections (if we aren't using the LLM) 99 | provider_lines, ocr_lines = self.get_all_lines(document, provider) 100 | self.merge_blocks(document, provider_lines, ocr_lines) 101 | 102 | def get_detection_batch_size(self): 103 | if self.detection_batch_size is not None: 104 | return self.detection_batch_size 105 | elif settings.TORCH_DEVICE_MODEL == "cuda": 106 | return 10 107 | return 4 108 | 109 | def get_ocr_error_batch_size(self): 
110 | if self.ocr_error_batch_size is not None: 111 | return self.ocr_error_batch_size 112 | elif settings.TORCH_DEVICE_MODEL == "cuda": 113 | return 14 114 | return 4 115 | 116 | def get_detection_results( 117 | self, page_images: List[Image.Image], run_detection: List[bool] 118 | ): 119 | self.detection_model.disable_tqdm = self.disable_tqdm 120 | page_detection_results = self.detection_model( 121 | images=page_images, batch_size=self.get_detection_batch_size() 122 | ) 123 | 124 | assert len(page_detection_results) == sum(run_detection) 125 | detection_results = [] 126 | idx = 0 127 | for good in run_detection: 128 | if good: 129 | detection_results.append(page_detection_results[idx]) 130 | idx += 1 131 | else: 132 | detection_results.append(None) 133 | assert idx == len(page_images) 134 | 135 | assert len(run_detection) == len(detection_results) 136 | return detection_results 137 | 138 | def get_all_lines(self, document: Document, provider: PdfProvider): 139 | ocr_error_detection_results = self.ocr_error_detection( 140 | document.pages, provider.page_lines 141 | ) 142 | 143 | boxes_to_ocr = {page.page_id: [] for page in document.pages} 144 | page_lines = {page.page_id: [] for page in document.pages} 145 | 146 | LineClass: Line = get_block_class(BlockTypes.Line) 147 | 148 | layout_good = [] 149 | for document_page, ocr_error_detection_label in zip( 150 | document.pages, ocr_error_detection_results.labels 151 | ): 152 | document_page.ocr_errors_detected = ocr_error_detection_label == "bad" 153 | provider_lines: List[ProviderOutput] = provider.page_lines.get( 154 | document_page.page_id, [] 155 | ) 156 | provider_lines_good = all( 157 | [ 158 | bool(provider_lines), 159 | not document_page.ocr_errors_detected, 160 | self.check_layout_coverage(document_page, provider_lines), 161 | self.check_line_overlaps( 162 | document_page, provider_lines 163 | ), # Ensure provider lines don't overflow the page or intersect 164 | ] 165 | ) 166 | if self.disable_ocr: 167 | 
provider_lines_good = True 168 | 169 | layout_good.append(provider_lines_good) 170 | 171 | run_detection = [not good for good in layout_good] 172 | page_images = [ 173 | page.get_image(highres=False, remove_blocks=self.ocr_remove_blocks) 174 | for page, bad in zip(document.pages, run_detection) 175 | if bad 176 | ] 177 | 178 | # Note: run_detection is longer than page_images, since it has a value for each page, not just good ones 179 | # Detection results and inline detection results are for every page (we use run_detection to make the list full length) 180 | detection_results = self.get_detection_results(page_images, run_detection) 181 | 182 | assert len(detection_results) == len(layout_good) == len(document.pages) 183 | for document_page, detection_result, provider_lines_good in zip( 184 | document.pages, detection_results, layout_good 185 | ): 186 | provider_lines: List[ProviderOutput] = provider.page_lines.get( 187 | document_page.page_id, [] 188 | ) 189 | 190 | # Setup detection results 191 | detection_boxes = [] 192 | if detection_result: 193 | detection_boxes = [ 194 | PolygonBox(polygon=box.polygon) for box in detection_result.bboxes if box.confidence > self.detection_line_min_confidence 195 | ] 196 | 197 | detection_boxes = sort_text_lines(detection_boxes) 198 | 199 | if provider_lines_good: 200 | document_page.text_extraction_method = "pdftext" 201 | 202 | # Mark extraction method as pdftext, since all lines are good 203 | for provider_line in provider_lines: 204 | provider_line.line.text_extraction_method = "pdftext" 205 | 206 | page_lines[document_page.page_id] = provider_lines 207 | else: 208 | document_page.text_extraction_method = "surya" 209 | boxes_to_ocr[document_page.page_id].extend(detection_boxes) 210 | 211 | # Dummy lines to merge into the document - Contains no spans, will be filled in later by OCRBuilder 212 | ocr_lines = {document_page.page_id: [] for document_page in document.pages} 213 | for page_id, page_ocr_boxes in 
boxes_to_ocr.items(): 214 | page_size = provider.get_page_bbox(page_id).size 215 | image_size = document.get_page(page_id).get_image(highres=False).size 216 | for box_to_ocr in page_ocr_boxes: 217 | line_polygon = PolygonBox(polygon=box_to_ocr.polygon).rescale( 218 | image_size, page_size 219 | ) 220 | ocr_lines[page_id].append( 221 | ProviderOutput( 222 | line=LineClass( 223 | polygon=line_polygon, 224 | page_id=page_id, 225 | text_extraction_method="surya", 226 | ), 227 | spans=[], 228 | chars=[], 229 | ) 230 | ) 231 | 232 | return page_lines, ocr_lines 233 | 234 | def ocr_error_detection( 235 | self, pages: List[PageGroup], provider_page_lines: ProviderPageLines 236 | ): 237 | page_texts = [] 238 | for document_page in pages: 239 | provider_lines = provider_page_lines.get(document_page.page_id, []) 240 | page_text = "\n".join( 241 | " ".join(s.text for s in line.spans) for line in provider_lines 242 | ) 243 | page_texts.append(page_text) 244 | 245 | self.ocr_error_model.disable_tqdm = self.disable_tqdm 246 | ocr_error_detection_results = self.ocr_error_model( 247 | page_texts, batch_size=int(self.get_ocr_error_batch_size()) 248 | ) 249 | return ocr_error_detection_results 250 | 251 | def check_line_overlaps( 252 | self, document_page: PageGroup, provider_lines: List[ProviderOutput] 253 | ) -> bool: 254 | provider_bboxes = [line.line.polygon.bbox for line in provider_lines] 255 | # Add a small margin to account for minor overflows 256 | page_bbox = document_page.polygon.expand(5, 5).bbox 257 | 258 | for bbox in provider_bboxes: 259 | if bbox[0] < page_bbox[0]: 260 | return False 261 | if bbox[1] < page_bbox[1]: 262 | return False 263 | if bbox[2] > page_bbox[2]: 264 | return False 265 | if bbox[3] > page_bbox[3]: 266 | return False 267 | 268 | intersection_matrix = matrix_intersection_area(provider_bboxes, provider_bboxes) 269 | for i, line in enumerate(provider_lines): 270 | intersect_counts = np.sum( 271 | intersection_matrix[i] 272 | > 
self.provider_line_provider_line_min_overlap_pct 273 | ) 274 | 275 | # There should be one intersection with itself 276 | if intersect_counts > 2: 277 | return False 278 | 279 | return True 280 | 281 | def check_layout_coverage( 282 | self, 283 | document_page: PageGroup, 284 | provider_lines: List[ProviderOutput], 285 | ): 286 | covered_blocks = 0 287 | total_blocks = 0 288 | large_text_blocks = 0 289 | 290 | layout_blocks = [ 291 | document_page.get_block(block) for block in document_page.structure 292 | ] 293 | layout_blocks = [ 294 | b for b in layout_blocks if b.block_type not in self.excluded_for_coverage 295 | ] 296 | 297 | layout_bboxes = [block.polygon.bbox for block in layout_blocks] 298 | provider_bboxes = [line.line.polygon.bbox for line in provider_lines] 299 | 300 | if len(layout_bboxes) == 0: 301 | return True 302 | 303 | if len(provider_bboxes) == 0: 304 | return False 305 | 306 | intersection_matrix = matrix_intersection_area(layout_bboxes, provider_bboxes) 307 | 308 | for idx, layout_block in enumerate(layout_blocks): 309 | total_blocks += 1 310 | intersecting_lines = np.count_nonzero(intersection_matrix[idx] > 0) 311 | 312 | if intersecting_lines >= self.layout_coverage_min_lines: 313 | covered_blocks += 1 314 | 315 | if ( 316 | layout_block.polygon.intersection_pct(document_page.polygon) > 0.8 317 | and layout_block.block_type == BlockTypes.Text 318 | ): 319 | large_text_blocks += 1 320 | 321 | coverage_ratio = covered_blocks / total_blocks if total_blocks > 0 else 1 322 | text_okay = coverage_ratio >= self.layout_coverage_threshold 323 | 324 | # Model will sometimes say there is a single block of text on the page when it is blank 325 | if not text_okay and (total_blocks == 1 and large_text_blocks == 1): 326 | text_okay = True 327 | return text_okay 328 | 329 | def filter_blank_lines(self, page: PageGroup, lines: List[ProviderOutput]): 330 | page_size = (page.polygon.width, page.polygon.height) 331 | page_image = page.get_image() 332 | 
image_size = page_image.size 333 | 334 | good_lines = [] 335 | for line in lines: 336 | line_polygon_rescaled = deepcopy(line.line.polygon).rescale( 337 | page_size, image_size 338 | ) 339 | line_bbox = line_polygon_rescaled.fit_to_bounds((0, 0, *image_size)).bbox 340 | 341 | if not is_blank_image(page_image.crop(line_bbox)): 342 | good_lines.append(line) 343 | 344 | return good_lines 345 | 346 | def merge_blocks( 347 | self, 348 | document: Document, 349 | page_provider_lines: ProviderPageLines, 350 | page_ocr_lines: ProviderPageLines, 351 | ): 352 | for document_page in document.pages: 353 | provider_lines: List[ProviderOutput] = page_provider_lines[ 354 | document_page.page_id 355 | ] 356 | ocr_lines: List[ProviderOutput] = page_ocr_lines[document_page.page_id] 357 | 358 | # Only one or the other will have lines 359 | # Filter out blank lines which come from bad provider boxes, or invisible text 360 | merged_lines = self.filter_blank_lines( 361 | document_page, provider_lines + ocr_lines 362 | ) 363 | 364 | # Text extraction method is overridden later for OCRed documents 365 | document_page.merge_blocks( 366 | merged_lines, 367 | text_extraction_method="pdftext" if provider_lines else "surya", 368 | keep_chars=self.keep_chars, 369 | ) 370 | ``` -------------------------------------------------------------------------------- /examples/marker_modal_deployment.py: -------------------------------------------------------------------------------- ```python 1 | """ 2 | Modal deployment for Datalab Marker PDF conversion service. 
3 | """ 4 | 5 | import modal 6 | import os 7 | from typing import Optional 8 | 9 | # Define the Modal app 10 | app = modal.App("datalab-marker-modal-demo") 11 | GPU_TYPE = "L40S" 12 | MODEL_PATH_PREFIX = "/root/.cache/datalab/models" 13 | 14 | # Define the container image with all dependencies 15 | image = ( 16 | modal.Image.debian_slim(python_version="3.10") 17 | .apt_install(["git", "wget"]) 18 | .env({"TORCH_DEVICE": "cuda"}) 19 | .pip_install([ 20 | "marker-pdf[full]", 21 | "fastapi==0.104.1", 22 | "uvicorn==0.24.0", 23 | "python-multipart==0.0.6", 24 | "torch>=2.2.2,<3.0.0", 25 | "torchvision>=0.17.0", 26 | "torchaudio>=2.2.0", 27 | ]) 28 | ) 29 | 30 | # Create a persistent volume for model caching 31 | models_volume = modal.Volume.from_name("marker-models-modal-demo", create_if_missing=True) 32 | 33 | def setup_models_with_cache_check(logger, commit_volume=False): 34 | """ 35 | Shared function to create models and handle cache checking/logging. 36 | """ 37 | import os 38 | import gc 39 | from marker.models import create_model_dict 40 | 41 | # Check if models exist in cache 42 | models_dir_exists = os.path.exists(MODEL_PATH_PREFIX) 43 | models_dir_contents = os.listdir(MODEL_PATH_PREFIX) if models_dir_exists else [] 44 | 45 | logger.info(f"Models cache directory exists: {models_dir_exists}") 46 | logger.info(f"Models cache directory contents: {models_dir_contents}") 47 | 48 | if models_dir_exists and models_dir_contents: 49 | logger.info("Found existing models in volume cache, loading from cache...") 50 | else: 51 | logger.warning("No models found in volume cache. 
Models will be downloaded now (this may take several minutes).") 52 | 53 | # Create/load models 54 | models = create_model_dict() 55 | logger.info(f"Successfully loaded {len(models)} models") 56 | 57 | # Check what was downloaded/cached 58 | if os.path.exists(MODEL_PATH_PREFIX): 59 | contents = os.listdir(MODEL_PATH_PREFIX) 60 | logger.info(f"Models in cache: {contents}") 61 | 62 | # Commit volume if requested (for download function) 63 | if commit_volume: 64 | gc.collect() 65 | logger.info("Attempting to commit volume...") 66 | models_volume.commit() 67 | logger.info("Volume committed successfully") 68 | 69 | return models 70 | 71 | @app.function( 72 | image=image, 73 | volumes={MODEL_PATH_PREFIX: models_volume}, 74 | gpu=GPU_TYPE, 75 | timeout=600, 76 | ) 77 | def download_models(): 78 | """ 79 | Helper function to download models used in marker into a Modal volume. 80 | """ 81 | import logging 82 | 83 | logging.basicConfig(level=logging.INFO) 84 | logger = logging.getLogger(__name__) 85 | 86 | logger.info("Downloading models to persistent volume...") 87 | logger.info(f"Volume mounted at: {MODEL_PATH_PREFIX}") 88 | 89 | try: 90 | models = setup_models_with_cache_check(logger, commit_volume=True) 91 | return f"Models downloaded successfully: {list(models.keys())}" 92 | except Exception as e: 93 | logger.error(f"Failed to download models: {e}") 94 | raise 95 | 96 | @app.cls( 97 | image=image, 98 | gpu=GPU_TYPE, 99 | memory=16384, 100 | timeout=600, # 10 minute timeout for large documents 101 | volumes={MODEL_PATH_PREFIX: models_volume}, 102 | scaledown_window=300, 103 | ) 104 | class MarkerModalDemoService: 105 | @modal.enter() 106 | def load_models(self): 107 | """Load models once per container using @modal.enter() for efficiency.""" 108 | import logging 109 | import traceback 110 | 111 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 112 | logger = logging.getLogger(__name__) 113 | 114 | logger.info("Loading Marker 
models using @modal.enter()...") 115 | try: 116 | self.models = setup_models_with_cache_check(logger, commit_volume=True) 117 | except Exception as e: 118 | logger.error(f"Error loading models: {e}") 119 | traceback.print_exc() 120 | self.models = None 121 | 122 | @modal.asgi_app() 123 | def marker_api(self): 124 | import traceback 125 | import io 126 | import base64 127 | import logging 128 | from contextlib import asynccontextmanager 129 | from typing import Optional 130 | from pathlib import Path 131 | 132 | from fastapi import FastAPI, Form, File, UploadFile, HTTPException 133 | from fastapi.responses import JSONResponse 134 | 135 | from marker.converters.pdf import PdfConverter 136 | from marker.config.parser import ConfigParser 137 | from marker.settings import settings 138 | 139 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 140 | logger = logging.getLogger(__name__) 141 | 142 | @asynccontextmanager 143 | async def lifespan(app: FastAPI): 144 | # Models are already loaded in @modal.enter() 145 | logger.info("Datalab Marker / Modal demo app starting up...") 146 | yield 147 | logger.info("Datalab Marker / Modal demo app shutting down...") 148 | 149 | # Create FastAPI app 150 | web_app = FastAPI( 151 | title="Datalab Marker PDF Conversion Service - Modal Demo", 152 | description="Convert PDFs and documents to markdown, JSON, or HTML using Marker, deployed on Modal", 153 | version="1.0.0", 154 | lifespan=lifespan 155 | ) 156 | 157 | @web_app.get("/health") 158 | async def health_check(): 159 | models_loaded = hasattr(self, 'models') and self.models is not None 160 | model_count = len(self.models) if models_loaded else 0 161 | 162 | # Check volume contents for debugging 163 | cache_exists = os.path.exists(MODEL_PATH_PREFIX) 164 | cache_contents = os.listdir(MODEL_PATH_PREFIX) if cache_exists else [] 165 | 166 | return { 167 | "status": "healthy" if models_loaded else "loading", 168 | "models_loaded": models_loaded, 
169 | "model_count": model_count, 170 | "cache_dir": MODEL_PATH_PREFIX, 171 | "cache_exists": cache_exists, 172 | "cache_contents": cache_contents[:10] 173 | } 174 | 175 | @web_app.post("/convert") 176 | async def convert_document( 177 | file: UploadFile = File(..., description="Document to convert"), 178 | page_range: Optional[str] = Form(None), 179 | force_ocr: bool = Form(False), 180 | paginate_output: bool = Form(False), 181 | output_format: str = Form("markdown"), 182 | use_llm: bool = Form(False), 183 | ): 184 | """Convert uploaded document to specified format.""" 185 | 186 | if not hasattr(self, 'models') or self.models is None: 187 | logger.error("Models not available for conversion") 188 | raise HTTPException(status_code=503, detail="Models not loaded yet. Please wait for model initialization.") 189 | 190 | # Validate file type 191 | allowed_extensions = {'.pdf', '.png', '.jpg', '.jpeg', '.tiff', '.bmp'} 192 | file_ext = Path(file.filename).suffix.lower() 193 | if file_ext not in allowed_extensions: 194 | raise HTTPException( 195 | status_code=400, 196 | detail=f"Unsupported file type: {file_ext}. 
Supported: {allowed_extensions}" 197 | ) 198 | 199 | # Validate output format 200 | if output_format not in ["markdown", "json", "html", "chunks"]: 201 | raise HTTPException( 202 | status_code=400, 203 | detail="Output format must be one of: markdown, json, html, chunks" 204 | ) 205 | 206 | try: 207 | # Read file content 208 | file_content = await file.read() 209 | 210 | # Save to temporary file 211 | temp_path = f"/tmp/{file.filename}" 212 | with open(temp_path, "wb") as temp_file: 213 | temp_file.write(file_content) 214 | 215 | # Configure conversion parameters 216 | config = { 217 | "filepath": temp_path, 218 | "page_range": page_range, 219 | "force_ocr": force_ocr, 220 | "paginate_output": paginate_output, 221 | "output_format": output_format, 222 | "use_llm": use_llm, 223 | } 224 | 225 | # Create converter 226 | config_parser = ConfigParser(config) 227 | config_dict = config_parser.generate_config_dict() 228 | config_dict["pdftext_workers"] = 1 229 | 230 | converter = PdfConverter( 231 | config=config_dict, 232 | artifact_dict=self.models, 233 | processor_list=config_parser.get_processors(), 234 | renderer=config_parser.get_renderer(), 235 | llm_service=config_parser.get_llm_service() if use_llm else None, 236 | ) 237 | 238 | # Convert document - converter already applies the appropriate renderer 239 | logger.info(f"Converting {file.filename} to {output_format}...") 240 | rendered_output = converter(temp_path) 241 | 242 | # Extract content based on output format 243 | json_content = None 244 | html_content = None 245 | markdown_content = None 246 | encoded_images = {} 247 | 248 | if output_format == "json": 249 | # For JSON, return the structured data directly 250 | json_content = rendered_output.model_dump() 251 | else: 252 | from marker.output import text_from_rendered 253 | text, _, images = text_from_rendered(rendered_output) 254 | 255 | # Assign to appropriate content field 256 | if output_format == "html": 257 | html_content = text 258 | else: 259 | 
markdown_content = text 260 | 261 | # Encode images as base64 262 | for img_name, img_obj in images.items(): 263 | byte_stream = io.BytesIO() 264 | img_obj.save(byte_stream, format=settings.OUTPUT_IMAGE_FORMAT) 265 | encoded_images[img_name] = base64.b64encode(byte_stream.getvalue()).decode('utf-8') 266 | 267 | metadata = rendered_output.metadata 268 | 269 | logger.info(f"Conversion completed for {file.filename}") 270 | 271 | # Clean up temp file 272 | os.unlink(temp_path) 273 | 274 | return JSONResponse({ 275 | "success": True, 276 | "filename": file.filename, 277 | "output_format": output_format, 278 | "json": json_content, 279 | "html": html_content, 280 | "markdown": markdown_content, 281 | "images": encoded_images, 282 | "metadata": metadata, 283 | "page_count": len(metadata.get("page_stats", [])), 284 | }) 285 | 286 | except Exception as e: 287 | # Clean up temp file if it exists 288 | if os.path.exists(temp_path): 289 | os.unlink(temp_path) 290 | 291 | logger.error(f"Conversion error for {file.filename}: {str(e)}") 292 | traceback.print_exc() 293 | 294 | raise HTTPException( 295 | status_code=500, 296 | detail=f"Conversion failed: {str(e)}" 297 | ) 298 | 299 | return web_app 300 | 301 | 302 | # 303 | # This does not get deployed. It's a useful entrypoint from your local CLI 304 | # that you can use to test your deployment. It'll store the 305 | # API response in a new file on your machine. 306 | # 307 | @app.local_entrypoint() 308 | async def invoke_conversion( 309 | pdf_file: Optional[str] = None, 310 | output_format: str = "markdown", 311 | env: str = 'main' 312 | ): 313 | """ 314 | Local entrypoint to test your deployed Marker endpoint in Modal. 315 | 316 | Usage: 317 | modal run marker_modal_deployment.py::invoke_conversion --pdf-file /path/to/file.pdf --output-format markdown 318 | """ 319 | import requests 320 | import json 321 | from pathlib import Path 322 | 323 | if not pdf_file: 324 | print("No PDF file specified. 
Use --pdf-file /path/to/your.pdf") 325 | return 326 | 327 | pdf_path = Path(pdf_file) 328 | if not pdf_path.exists(): 329 | print(f"File not found: {pdf_file}") 330 | return 331 | 332 | # 333 | # Get the web URL for our deployed service 334 | # 335 | try: 336 | service = modal.Cls.from_name( 337 | "datalab-marker-modal-demo", 338 | "MarkerModalDemoService", 339 | environment_name=env 340 | ) 341 | web_url = service().marker_api.get_web_url() 342 | print(f"Found deployed service at: {web_url}") 343 | except Exception as e: 344 | print(f"Error getting web URL: {e}") 345 | print("Make sure you've deployed the service first with: modal deploy marker_modal_deployment.py") 346 | return 347 | 348 | print(f"Testing conversion of: {pdf_path.name}") 349 | print(f"Output format: {output_format}") 350 | 351 | # 352 | # Test health endpoint first 353 | # 354 | try: 355 | health_response = requests.get(f"{web_url}/health") 356 | health_data = health_response.json() 357 | print(f"Service health: {health_data['status']}") 358 | print(f"Models loaded: {health_data['models_loaded']} ({health_data['model_count']} models)") 359 | 360 | if not health_data['models_loaded']: 361 | print("Warning: Models not loaded yet. 
First request may be slow.") 362 | 363 | except Exception as e: 364 | print(f"Health check failed: {e}") 365 | 366 | # 367 | # Make conversion request 368 | # 369 | try: 370 | with open(pdf_path, 'rb') as f: 371 | files = {'file': (pdf_path.name, f, 'application/pdf')} 372 | data = {'output_format': output_format} 373 | 374 | print(f"Sending request to {web_url}/convert...") 375 | response = requests.post(f"{web_url}/convert", files=files, data=data) 376 | 377 | if response.status_code == 200: 378 | result = response.json() 379 | print(f"✅ Conversion successful!") 380 | print(f"Filename: {result['filename']}") 381 | print(f"Format: {result['output_format']}") 382 | print(f"Pages: {result['page_count']}") 383 | 384 | output_file = f"{pdf_path.stem}_response.json" 385 | with open(output_file, 'w', encoding='utf-8') as f: 386 | json.dump(result, f, indent=2, ensure_ascii=False) 387 | print(f"Full API response saved to: {output_file}") 388 | 389 | if result['images']: 390 | print(f"Images extracted: {len(result['images'])}") 391 | 392 | else: 393 | print(f"❌ Conversion failed: {response.status_code}") 394 | print(f"Error: {response.text}") 395 | 396 | except Exception as e: 397 | print(f"Request failed: {e}") 398 | ``` -------------------------------------------------------------------------------- /marker/processors/llm/llm_table_merge.py: -------------------------------------------------------------------------------- ```python 1 | from concurrent.futures import ThreadPoolExecutor, as_completed 2 | from typing import Annotated, List, Tuple, Literal 3 | 4 | from pydantic import BaseModel 5 | from tqdm import tqdm 6 | from PIL import Image 7 | 8 | from marker.output import json_to_html 9 | from marker.processors.llm import BaseLLMComplexBlockProcessor 10 | from marker.schema import BlockTypes 11 | from marker.schema.blocks import Block, TableCell 12 | from marker.schema.document import Document 13 | from marker.logger import get_logger 14 | 15 | logger = get_logger() 
16 | 17 | class LLMTableMergeProcessor(BaseLLMComplexBlockProcessor): 18 | block_types: Annotated[ 19 | Tuple[BlockTypes], 20 | "The block types to process.", 21 | ] = (BlockTypes.Table, BlockTypes.TableOfContents) 22 | table_height_threshold: Annotated[ 23 | float, 24 | "The minimum height ratio relative to the page for the first table in a pair to be considered for merging.", 25 | ] = 0.6 26 | table_start_threshold: Annotated[ 27 | float, 28 | "The maximum percentage down the page the second table can start to be considered for merging." 29 | ] = 0.2 30 | vertical_table_height_threshold: Annotated[ 31 | float, 32 | "The height tolerance for 2 adjacent tables to be merged into one." 33 | ] = 0.25 34 | vertical_table_distance_threshold: Annotated[ 35 | int, 36 | "The maximum distance between table edges for adjacency." 37 | ] = 20 38 | horizontal_table_width_threshold: Annotated[ 39 | float, 40 | "The width tolerance for 2 adjacent tables to be merged into one." 41 | ] = 0.25 42 | horizontal_table_distance_threshold: Annotated[ 43 | int, 44 | "The maximum distance between table edges for adjacency." 45 | ] = 10 46 | column_gap_threshold: Annotated[ 47 | int, 48 | "The maximum gap between columns to merge tables" 49 | ] = 50 50 | disable_tqdm: Annotated[ 51 | bool, 52 | "Whether to disable the tqdm progress bar.", 53 | ] = False 54 | no_merge_tables_across_pages: Annotated[ 55 | bool, 56 | "Whether to disable merging tables across pages and keep page delimiters.", 57 | ] = False 58 | table_merge_prompt: Annotated[ 59 | str, 60 | "The prompt to use for rewriting text.", 61 | "Default is a string containing the Gemini rewriting prompt." 62 | ] = """You're a text correction expert specializing in accurately reproducing tables from PDFs. 63 | You'll receive two images of tables from successive pages of a PDF. Table 1 is from the first page, and Table 2 is from the second page. Both tables may actually be part of the same larger table. 
Your job is to decide if Table 2 should be merged with Table 1, and how they should be joined. They should only be merged if they're part of the same larger table, and Table 2 cannot be interpreted without merging. 64 | 65 | You'll specify your judgement in json format - first whether Table 2 should be merged with Table 1, then the direction of the merge, either `bottom` or `right`. A bottom merge means that the rows of Table 2 are joined to the rows of Table 1. A right merge means that the columns of Table 2 are joined to the columns of Table 1. (bottom merge is equal to np.vstack, right merge is equal to np.hstack) 66 | 67 | Table 2 should be merged at the bottom of Table 1 if Table 2 has no headers, and the rows have similar values, meaning that Table 2 continues Table 1. Table 2 should be merged to the right of Table 1 if each row in Table 2 matches a row in Table 1, meaning that Table 2 contains additional columns that augment Table 1. 68 | 69 | Only merge Table 1 and Table 2 if Table 2 cannot be interpreted without merging. Only merge Table 1 and Table 2 if you can read both images properly. 70 | 71 | **Instructions:** 72 | 1. Carefully examine the provided table images. Table 1 is the first image, and Table 2 is the second image. 73 | 2. Examine the provided html representations of Table 1 and Table 2. 74 | 3. Write a description of Table 1. 75 | 4. Write a description of Table 2. 76 | 5. Analyze whether Table 2 should be merged into Table 1, and write an explanation. 77 | 6. Output your decision on whether they should be merged, and merge direction.
78 | **Example:** 79 | Input: 80 | Table 1 81 | ```html 82 | <table> 83 | <tr> 84 | <th>Name</th> 85 | <th>Age</th> 86 | <th>City</th> 87 | <th>State</th> 88 | </tr> 89 | <tr> 90 | <td>John</td> 91 | <td>25</td> 92 | <td>Chicago</td> 93 | <td>IL</td> 94 | </tr> 95 | ``` 96 | Table 2 97 | ```html 98 | <table> 99 | <tr> 100 | <td>Jane</td> 101 | <td>30</td> 102 | <td>Los Angeles</td> 103 | <td>CA</td> 104 | </tr> 105 | ``` 106 | Output: 107 | ```json 108 | { 109 | "table1_description": "Table 1 has 4 headers, and 1 row. The headers are Name, Age, City, and State.", 110 | "table2_description": "Table 2 has no headers, but the values appear to represent a person's name, age, city, and state.", 111 | "explanation": "The values in Table 2 match the headers in Table 1, and Table 2 has no headers. Table 2 should be merged to the bottom of Table 1.", 112 | "merge": "true", 113 | "direction": "bottom" 114 | } 115 | ``` 116 | **Input:** 117 | Table 1 118 | ```html 119 | {{table1}} 120 | Table 2 121 | ```html 122 | {{table2}} 123 | ``` 124 | """ 125 | 126 | @staticmethod 127 | def get_row_count(cells: List[TableCell]): 128 | if not cells: 129 | return 0 130 | 131 | max_rows = None 132 | for col_id in set([cell.col_id for cell in cells]): 133 | col_cells = [cell for cell in cells if cell.col_id == col_id] 134 | rows = 0 135 | for cell in col_cells: 136 | rows += cell.rowspan 137 | if max_rows is None or rows > max_rows: 138 | max_rows = rows 139 | return max_rows 140 | 141 | @staticmethod 142 | def get_column_count(cells: List[TableCell]): 143 | if not cells: 144 | return 0 145 | 146 | max_cols = None 147 | for row_id in set([cell.row_id for cell in cells]): 148 | row_cells = [cell for cell in cells if cell.row_id == row_id] 149 | cols = 0 150 | for cell in row_cells: 151 | cols += cell.colspan 152 | if max_cols is None or cols > max_cols: 153 | max_cols = cols 154 | return max_cols 155 | 156 | def rewrite_blocks(self, document: Document): 157 | # Skip table merging if disabled 
via config 158 | if self.no_merge_tables_across_pages: 159 | logger.info("Skipping table merging across pages due to --no_merge_tables_across_pages flag") 160 | return 161 | 162 | table_runs = [] 163 | table_run = [] 164 | prev_block = None 165 | prev_page_block_count = None 166 | for page in document.pages: 167 | page_blocks = page.contained_blocks(document, self.block_types) 168 | for block in page_blocks: 169 | merge_condition = False 170 | if prev_block is not None: 171 | prev_cells = prev_block.contained_blocks(document, (BlockTypes.TableCell,)) 172 | curr_cells = block.contained_blocks(document, (BlockTypes.TableCell,)) 173 | row_match = abs(self.get_row_count(prev_cells) - self.get_row_count(curr_cells)) < 5, # Similar number of rows 174 | col_match = abs(self.get_column_count(prev_cells) - self.get_column_count(curr_cells)) < 2 175 | 176 | subsequent_page_table = all([ 177 | prev_block.page_id == block.page_id - 1, # Subsequent pages 178 | max(prev_block.polygon.height / page.polygon.height, 179 | block.polygon.height / page.polygon.height) > self.table_height_threshold, # Take up most of the page height 180 | (len(page_blocks) == 1 or prev_page_block_count == 1), # Only table on the page 181 | (row_match or col_match) 182 | ]) 183 | 184 | same_page_vertical_table = all([ 185 | prev_block.page_id == block.page_id, # On the same page 186 | (1 - self.vertical_table_height_threshold) < prev_block.polygon.height / block.polygon.height < (1 + self.vertical_table_height_threshold), # Similar height 187 | abs(block.polygon.x_start - prev_block.polygon.x_end) < self.vertical_table_distance_threshold, # Close together in x 188 | abs(block.polygon.y_start - prev_block.polygon.y_start) < self.vertical_table_distance_threshold, # Close together in y 189 | row_match 190 | ]) 191 | 192 | same_page_horizontal_table = all([ 193 | prev_block.page_id == block.page_id, # On the same page 194 | (1 - self.horizontal_table_width_threshold) < prev_block.polygon.width / 
block.polygon.width < (1 + self.horizontal_table_width_threshold), # Similar width 195 | abs(block.polygon.y_start - prev_block.polygon.y_end) < self.horizontal_table_distance_threshold, # Close together in y 196 | abs(block.polygon.x_start - prev_block.polygon.x_start) < self.horizontal_table_distance_threshold, # Close together in x 197 | col_match 198 | ]) 199 | 200 | same_page_new_column = all([ 201 | prev_block.page_id == block.page_id, # On the same page 202 | abs(block.polygon.x_start - prev_block.polygon.x_end) < self.column_gap_threshold, 203 | block.polygon.y_start < prev_block.polygon.y_end, 204 | block.polygon.width * (1 - self.vertical_table_height_threshold) < prev_block.polygon.width < block.polygon.width * (1 + self.vertical_table_height_threshold), # Similar width 205 | col_match 206 | ]) 207 | merge_condition = any([subsequent_page_table, same_page_vertical_table, same_page_new_column, same_page_horizontal_table]) 208 | 209 | if prev_block is not None and merge_condition: 210 | if prev_block not in table_run: 211 | table_run.append(prev_block) 212 | table_run.append(block) 213 | else: 214 | if table_run: 215 | table_runs.append(table_run) 216 | table_run = [] 217 | prev_block = block 218 | prev_page_block_count = len(page_blocks) 219 | 220 | if table_run: 221 | table_runs.append(table_run) 222 | 223 | # Don't show progress if there is nothing to process 224 | total_table_runs = len(table_runs) 225 | if total_table_runs == 0: 226 | return 227 | 228 | pbar = tqdm( 229 | total=total_table_runs, 230 | desc=f"{self.__class__.__name__} running", 231 | disable=self.disable_tqdm, 232 | ) 233 | 234 | with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor: 235 | for future in as_completed([ 236 | executor.submit(self.process_rewriting, document, blocks) 237 | for blocks in table_runs 238 | ]): 239 | future.result() # Raise exceptions if any occurred 240 | pbar.update(1) 241 | 242 | pbar.close() 243 | 244 | def process_rewriting(self, 
document: Document, blocks: List[Block]): 245 | if len(blocks) < 2: 246 | # Can't merge single tables 247 | return 248 | 249 | start_block = blocks[0] 250 | for i in range(1, len(blocks)): 251 | curr_block = blocks[i] 252 | children = start_block.contained_blocks(document, (BlockTypes.TableCell,)) 253 | children_curr = curr_block.contained_blocks(document, (BlockTypes.TableCell,)) 254 | if not children or not children_curr: 255 | # Happens if table/form processors didn't run 256 | break 257 | 258 | start_image = start_block.get_image(document, highres=False) 259 | curr_image = curr_block.get_image(document, highres=False) 260 | start_html = json_to_html(start_block.render(document)) 261 | curr_html = json_to_html(curr_block.render(document)) 262 | 263 | prompt = self.table_merge_prompt.replace("{{table1}}", start_html).replace("{{table2}}", curr_html) 264 | 265 | response = self.llm_service( 266 | prompt, 267 | [start_image, curr_image], 268 | curr_block, 269 | MergeSchema, 270 | ) 271 | 272 | if not response or ("direction" not in response or "merge" not in response): 273 | curr_block.update_metadata(llm_error_count=1) 274 | break 275 | 276 | merge = response["merge"] 277 | 278 | # The original table is okay 279 | if "true" not in merge: 280 | start_block = curr_block 281 | continue 282 | 283 | # Merge the cells and images of the tables 284 | direction = response["direction"] 285 | if not self.validate_merge(children, children_curr, direction): 286 | start_block = curr_block 287 | continue 288 | 289 | merged_image = self.join_images(start_image, curr_image, direction) 290 | merged_cells = self.join_cells(children, children_curr, direction) 291 | curr_block.structure = [] 292 | start_block.structure = [b.id for b in merged_cells] 293 | start_block.lowres_image = merged_image 294 | 295 | def validate_merge(self, cells1: List[TableCell], cells2: List[TableCell], direction: Literal['right', 'bottom'] = 'right'): 296 | if direction == "right": 297 | # Check if the 
number of rows is the same 298 | cells1_row_count = self.get_row_count(cells1) 299 | cells2_row_count = self.get_row_count(cells2) 300 | return abs(cells1_row_count - cells2_row_count) < 5 301 | elif direction == "bottom": 302 | # Check if the number of columns is the same 303 | cells1_col_count = self.get_column_count(cells1) 304 | cells2_col_count = self.get_column_count(cells2) 305 | return abs(cells1_col_count - cells2_col_count) < 2 306 | 307 | 308 | def join_cells(self, cells1: List[TableCell], cells2: List[TableCell], direction: Literal['right', 'bottom'] = 'right') -> List[TableCell]: 309 | if direction == 'right': 310 | # Shift columns right 311 | col_count = self.get_column_count(cells1) 312 | for cell in cells2: 313 | cell.col_id += col_count 314 | new_cells = cells1 + cells2 315 | else: 316 | # Shift rows up 317 | row_count = self.get_row_count(cells1) 318 | for cell in cells2: 319 | cell.row_id += row_count 320 | new_cells = cells1 + cells2 321 | return new_cells 322 | 323 | @staticmethod 324 | def join_images(image1: Image.Image, image2: Image.Image, direction: Literal['right', 'bottom'] = 'right') -> Image.Image: 325 | # Get dimensions 326 | w1, h1 = image1.size 327 | w2, h2 = image2.size 328 | 329 | if direction == 'right': 330 | new_height = max(h1, h2) 331 | new_width = w1 + w2 332 | new_img = Image.new('RGB', (new_width, new_height), 'white') 333 | new_img.paste(image1, (0, 0)) 334 | new_img.paste(image2, (w1, 0)) 335 | else: 336 | new_width = max(w1, w2) 337 | new_height = h1 + h2 338 | new_img = Image.new('RGB', (new_width, new_height), 'white') 339 | new_img.paste(image1, (0, 0)) 340 | new_img.paste(image2, (0, h1)) 341 | return new_img 342 | 343 | 344 | class MergeSchema(BaseModel): 345 | table1_description: str 346 | table2_description: str 347 | explanation: str 348 | merge: Literal["true", "false"] 349 | direction: Literal["bottom", "right"] ``` -------------------------------------------------------------------------------- 
/marker/providers/pdf.py: -------------------------------------------------------------------------------- ```python 1 | import contextlib 2 | import ctypes 3 | import logging 4 | import re 5 | from typing import Annotated, Dict, List, Optional, Set 6 | 7 | import pypdfium2 as pdfium 8 | import pypdfium2.raw as pdfium_c 9 | from ftfy import fix_text 10 | from pdftext.extraction import dictionary_output 11 | from pdftext.schema import Reference 12 | from pdftext.pdf.utils import flatten as flatten_pdf_page 13 | 14 | from PIL import Image 15 | from pypdfium2 import PdfiumError, PdfDocument 16 | 17 | from marker.providers import BaseProvider, ProviderOutput, Char, ProviderPageLines 18 | from marker.providers.utils import alphanum_ratio 19 | from marker.schema import BlockTypes 20 | from marker.schema.polygon import PolygonBox 21 | from marker.schema.registry import get_block_class 22 | from marker.schema.text.line import Line 23 | from marker.schema.text.span import Span 24 | 25 | # Ignore pypdfium2 warning about form flattening 26 | logging.getLogger("pypdfium2").setLevel(logging.ERROR) 27 | 28 | 29 | class PdfProvider(BaseProvider): 30 | """ 31 | A provider for PDF files. 
32 | """ 33 | 34 | page_range: Annotated[ 35 | List[int], 36 | "The range of pages to process.", 37 | "Default is None, which will process all pages.", 38 | ] = None 39 | pdftext_workers: Annotated[ 40 | int, 41 | "The number of workers to use for pdftext.", 42 | ] = 4 43 | flatten_pdf: Annotated[ 44 | bool, 45 | "Whether to flatten the PDF structure.", 46 | ] = True 47 | force_ocr: Annotated[ 48 | bool, 49 | "Whether to force OCR on the whole document.", 50 | ] = False 51 | ocr_invalid_chars: Annotated[ 52 | tuple, 53 | "The characters to consider invalid for OCR.", 54 | ] = (chr(0xFFFD), "�") 55 | ocr_space_threshold: Annotated[ 56 | float, 57 | "The minimum ratio of spaces to non-spaces to detect bad text.", 58 | ] = 0.7 59 | ocr_newline_threshold: Annotated[ 60 | float, 61 | "The minimum ratio of newlines to non-newlines to detect bad text.", 62 | ] = 0.6 63 | ocr_alphanum_threshold: Annotated[ 64 | float, 65 | "The minimum ratio of alphanumeric characters to non-alphanumeric characters to consider an alphanumeric character.", 66 | ] = 0.3 67 | image_threshold: Annotated[ 68 | float, 69 | "The minimum coverage ratio of the image to the page to consider skipping the page.", 70 | ] = 0.65 71 | strip_existing_ocr: Annotated[ 72 | bool, 73 | "Whether to strip existing OCR text from the PDF.", 74 | ] = False 75 | disable_links: Annotated[ 76 | bool, 77 | "Whether to disable links.", 78 | ] = False 79 | keep_chars: Annotated[ 80 | bool, 81 | "Whether to keep character-level information in the output.", 82 | ] = False 83 | 84 | def __init__(self, filepath: str, config=None): 85 | super().__init__(filepath, config) 86 | 87 | self.filepath = filepath 88 | 89 | with self.get_doc() as doc: 90 | self.page_count = len(doc) 91 | self.page_lines: ProviderPageLines = {i: [] for i in range(len(doc))} 92 | self.page_refs: Dict[int, List[Reference]] = { 93 | i: [] for i in range(len(doc)) 94 | } 95 | 96 | if self.page_range is None: 97 | self.page_range = range(len(doc)) 98 | 99 
| assert max(self.page_range) < len(doc) and min(self.page_range) >= 0, ( 100 | f"Invalid page range, values must be between 0 and {len(doc) - 1}. Min of provided page range is {min(self.page_range)} and max is {max(self.page_range)}." 101 | ) 102 | 103 | if self.force_ocr: 104 | # Manually assign page bboxes, since we can't get them from pdftext 105 | self.page_bboxes = {i: doc[i].get_bbox() for i in self.page_range} 106 | else: 107 | self.page_lines = self.pdftext_extraction(doc) 108 | 109 | @contextlib.contextmanager 110 | def get_doc(self): 111 | doc = None 112 | try: 113 | doc = pdfium.PdfDocument(self.filepath) 114 | 115 | # Must be called on the parent pdf, before retrieving pages to render correctly 116 | if self.flatten_pdf: 117 | doc.init_forms() 118 | 119 | yield doc 120 | finally: 121 | if doc: 122 | doc.close() 123 | 124 | def __len__(self) -> int: 125 | return self.page_count 126 | 127 | def font_flags_to_format(self, flags: Optional[int]) -> Set[str]: 128 | if flags is None: 129 | return {"plain"} 130 | 131 | flag_map = { 132 | 1: "FixedPitch", 133 | 2: "Serif", 134 | 3: "Symbolic", 135 | 4: "Script", 136 | 6: "Nonsymbolic", 137 | 7: "Italic", 138 | 17: "AllCap", 139 | 18: "SmallCap", 140 | 19: "ForceBold", 141 | 20: "UseExternAttr", 142 | } 143 | set_flags = set() 144 | for bit_position, flag_name in flag_map.items(): 145 | if flags & (1 << (bit_position - 1)): 146 | set_flags.add(flag_name) 147 | if not set_flags: 148 | set_flags.add("Plain") 149 | 150 | formats = set() 151 | if set_flags == {"Symbolic", "Italic"} or set_flags == { 152 | "Symbolic", 153 | "Italic", 154 | "UseExternAttr", 155 | }: 156 | formats.add("plain") 157 | elif set_flags == {"UseExternAttr"}: 158 | formats.add("plain") 159 | elif set_flags == {"Plain"}: 160 | formats.add("plain") 161 | else: 162 | if set_flags & {"Italic"}: 163 | formats.add("italic") 164 | if set_flags & {"ForceBold"}: 165 | formats.add("bold") 166 | if set_flags & { 167 | "FixedPitch", 168 | "Serif", 169 | 
"Script", 170 | "Nonsymbolic", 171 | "AllCap", 172 | "SmallCap", 173 | "UseExternAttr", 174 | }: 175 | formats.add("plain") 176 | return formats 177 | 178 | def font_names_to_format(self, font_name: str | None) -> Set[str]: 179 | formats = set() 180 | if font_name is None: 181 | return formats 182 | 183 | if "bold" in font_name.lower(): 184 | formats.add("bold") 185 | if "ital" in font_name.lower(): 186 | formats.add("italic") 187 | return formats 188 | 189 | @staticmethod 190 | def normalize_spaces(text): 191 | space_chars = [ 192 | "\u2003", # em space 193 | "\u2002", # en space 194 | "\u00a0", # non-breaking space 195 | "\u200b", # zero-width space 196 | "\u3000", # ideographic space 197 | ] 198 | for space in space_chars: 199 | text = text.replace(space, " ") 200 | return text 201 | 202 | def pdftext_extraction(self, doc: PdfDocument) -> ProviderPageLines: 203 | page_lines: ProviderPageLines = {} 204 | page_char_blocks = dictionary_output( 205 | self.filepath, 206 | page_range=self.page_range, 207 | keep_chars=self.keep_chars, 208 | workers=self.pdftext_workers, 209 | flatten_pdf=self.flatten_pdf, 210 | quote_loosebox=False, 211 | disable_links=self.disable_links, 212 | ) 213 | self.page_bboxes = { 214 | i: [0, 0, page["width"], page["height"]] 215 | for i, page in zip(self.page_range, page_char_blocks) 216 | } 217 | 218 | SpanClass: Span = get_block_class(BlockTypes.Span) 219 | LineClass: Line = get_block_class(BlockTypes.Line) 220 | CharClass: Char = get_block_class(BlockTypes.Char) 221 | 222 | for page in page_char_blocks: 223 | page_id = page["page"] 224 | lines: List[ProviderOutput] = [] 225 | if not self.check_page(page_id, doc): 226 | continue 227 | 228 | for block in page["blocks"]: 229 | for line in block["lines"]: 230 | spans: List[Span] = [] 231 | chars: List[List[Char]] = [] 232 | for span in line["spans"]: 233 | if not span["text"]: 234 | continue 235 | font_formats = self.font_flags_to_format( 236 | span["font"]["flags"] 237 | 
).union(self.font_names_to_format(span["font"]["name"])) 238 | font_name = span["font"]["name"] or "Unknown" 239 | font_weight = span["font"]["weight"] or 0 240 | font_size = span["font"]["size"] or 0 241 | polygon = PolygonBox.from_bbox( 242 | span["bbox"], ensure_nonzero_area=True 243 | ) 244 | superscript = span.get("superscript", False) 245 | subscript = span.get("subscript", False) 246 | text = self.normalize_spaces(fix_text(span["text"])) 247 | if superscript or subscript: 248 | text = text.strip() 249 | 250 | spans.append( 251 | SpanClass( 252 | polygon=polygon, 253 | text=text, 254 | font=font_name, 255 | font_weight=font_weight, 256 | font_size=font_size, 257 | minimum_position=span["char_start_idx"], 258 | maximum_position=span["char_end_idx"], 259 | formats=list(font_formats), 260 | page_id=page_id, 261 | text_extraction_method="pdftext", 262 | url=span.get("url"), 263 | has_superscript=superscript, 264 | has_subscript=subscript, 265 | ) 266 | ) 267 | 268 | if self.keep_chars: 269 | span_chars = [ 270 | CharClass( 271 | text=c["char"], 272 | polygon=PolygonBox.from_bbox( 273 | c["bbox"], ensure_nonzero_area=True 274 | ), 275 | idx=c["char_idx"], 276 | ) 277 | for c in span["chars"] 278 | ] 279 | chars.append(span_chars) 280 | else: 281 | chars.append([]) 282 | 283 | polygon = PolygonBox.from_bbox( 284 | line["bbox"], ensure_nonzero_area=True 285 | ) 286 | 287 | assert len(spans) == len(chars), ( 288 | f"Spans and chars length mismatch on page {page_id}: {len(spans)} spans, {len(chars)} chars" 289 | ) 290 | lines.append( 291 | ProviderOutput( 292 | line=LineClass(polygon=polygon, page_id=page_id), 293 | spans=spans, 294 | chars=chars, 295 | ) 296 | ) 297 | if self.check_line_spans(lines): 298 | page_lines[page_id] = lines 299 | 300 | self.page_refs[page_id] = [] 301 | if page_refs := page.get("refs", None): 302 | self.page_refs[page_id] = page_refs 303 | 304 | return page_lines 305 | 306 | def check_line_spans(self, page_lines: List[ProviderOutput]) -> 
bool: 307 | page_spans = [span for line in page_lines for span in line.spans] 308 | if len(page_spans) == 0: 309 | return False 310 | 311 | text = "" 312 | for span in page_spans: 313 | text = text + " " + span.text 314 | text = text + "\n" 315 | if len(text.strip()) == 0: 316 | return False 317 | if self.detect_bad_ocr(text): 318 | return False 319 | return True 320 | 321 | def check_page(self, page_id: int, doc: PdfDocument) -> bool: 322 | page = doc.get_page(page_id) 323 | page_bbox = PolygonBox.from_bbox(page.get_bbox()) 324 | try: 325 | page_objs = list( 326 | page.get_objects( 327 | filter=[pdfium_c.FPDF_PAGEOBJ_TEXT, pdfium_c.FPDF_PAGEOBJ_IMAGE] 328 | ) 329 | ) 330 | except PdfiumError: 331 | # Happens when pdfium fails to get the number of page objects 332 | return False 333 | 334 | # if we do not see any text objects in the pdf, we can skip this page 335 | if not any([obj.type == pdfium_c.FPDF_PAGEOBJ_TEXT for obj in page_objs]): 336 | return False 337 | 338 | if self.strip_existing_ocr: 339 | # If any text objects on the page are in invisible render mode, skip this page 340 | for text_obj in filter( 341 | lambda obj: obj.type == pdfium_c.FPDF_PAGEOBJ_TEXT, page_objs 342 | ): 343 | if pdfium_c.FPDFTextObj_GetTextRenderMode(text_obj) in [ 344 | pdfium_c.FPDF_TEXTRENDERMODE_INVISIBLE, 345 | pdfium_c.FPDF_TEXTRENDERMODE_UNKNOWN, 346 | ]: 347 | return False 348 | 349 | non_embedded_fonts = [] 350 | empty_fonts = [] 351 | font_map = {} 352 | for text_obj in filter( 353 | lambda obj: obj.type == pdfium_c.FPDF_PAGEOBJ_TEXT, page_objs 354 | ): 355 | font = pdfium_c.FPDFTextObj_GetFont(text_obj) 356 | font_name = self._get_fontname(font) 357 | 358 | # we also skip pages without embedded fonts and fonts without names 359 | non_embedded_fonts.append(pdfium_c.FPDFFont_GetIsEmbedded(font) == 0) 360 | empty_fonts.append( 361 | "glyphless" in font_name.lower() 362 | ) # Add font name check back in when we bump pypdfium2 363 | if font_name not in font_map: 364 | 
font_map[font_name or "Unknown"] = font 365 | 366 | if all(non_embedded_fonts) or all(empty_fonts): 367 | return False 368 | 369 | # if we see very large images covering most of the page, we can skip this page 370 | for img_obj in filter( 371 | lambda obj: obj.type == pdfium_c.FPDF_PAGEOBJ_IMAGE, page_objs 372 | ): 373 | img_bbox = PolygonBox.from_bbox(img_obj.get_pos()) 374 | if page_bbox.intersection_pct(img_bbox) >= self.image_threshold: 375 | return False 376 | 377 | return True 378 | 379 | def detect_bad_ocr(self, text): 380 | if len(text) == 0: 381 | # Assume OCR failed if we have no text 382 | return True 383 | 384 | spaces = len(re.findall(r"\s+", text)) 385 | alpha_chars = len(re.sub(r"\s+", "", text)) 386 | if spaces / (alpha_chars + spaces) > self.ocr_space_threshold: 387 | return True 388 | 389 | newlines = len(re.findall(r"\n+", text)) 390 | non_newlines = len(re.sub(r"\n+", "", text)) 391 | if newlines / (newlines + non_newlines) > self.ocr_newline_threshold: 392 | return True 393 | 394 | if alphanum_ratio(text) < self.ocr_alphanum_threshold: # Garbled text 395 | return True 396 | 397 | invalid_chars = len([c for c in text if c in self.ocr_invalid_chars]) 398 | if invalid_chars > max(6.0, len(text) * 0.03): 399 | return True 400 | 401 | return False 402 | 403 | @staticmethod 404 | def _render_image( 405 | pdf: pdfium.PdfDocument, idx: int, dpi: int, flatten_page: bool 406 | ) -> Image.Image: 407 | page = pdf[idx] 408 | if flatten_page: 409 | flatten_pdf_page(page) 410 | page = pdf[idx] 411 | image = page.render(scale=dpi / 72, draw_annots=False).to_pil() 412 | image = image.convert("RGB") 413 | return image 414 | 415 | def get_images(self, idxs: List[int], dpi: int) -> List[Image.Image]: 416 | with self.get_doc() as doc: 417 | images = [ 418 | self._render_image(doc, idx, dpi, self.flatten_pdf) for idx in idxs 419 | ] 420 | return images 421 | 422 | def get_page_bbox(self, idx: int) -> PolygonBox | None: 423 | bbox = self.page_bboxes.get(idx) 424 | if 
bbox: 425 | return PolygonBox.from_bbox(bbox) 426 | 427 | def get_page_lines(self, idx: int) -> List[ProviderOutput]: 428 | return self.page_lines[idx] 429 | 430 | def get_page_refs(self, idx: int) -> List[Reference]: 431 | return self.page_refs[idx] 432 | 433 | @staticmethod 434 | def _get_fontname(font) -> str: 435 | font_name = "" 436 | buffer_size = 256 437 | 438 | try: 439 | font_name_buffer = ctypes.create_string_buffer(buffer_size) 440 | length = pdfium_c.FPDFFont_GetBaseFontName( 441 | font, font_name_buffer, buffer_size 442 | ) 443 | if length < buffer_size: 444 | font_name = font_name_buffer.value.decode("utf-8") 445 | else: 446 | font_name_buffer = ctypes.create_string_buffer(length) 447 | pdfium_c.FPDFFont_GetBaseFontName(font, font_name_buffer, length) 448 | font_name = font_name_buffer.value.decode("utf-8") 449 | except Exception: 450 | pass 451 | 452 | return font_name 453 | ``` -------------------------------------------------------------------------------- /data/examples/markdown/multicolcnn/multicolcnn_meta.json: -------------------------------------------------------------------------------- ```json 1 | { 2 | "table_of_contents": [ 3 | { 4 | "title": "An Aggregated Multicolumn Dilated Convolution Network\nfor Perspective-Free Counting", 5 | "heading_level": null, 6 | "page_id": 0, 7 | "polygon": [ 8 | [ 9 | 117.5888671875, 10 | 105.9219970703125 11 | ], 12 | [ 13 | 477.371826171875, 14 | 105.9219970703125 15 | ], 16 | [ 17 | 477.371826171875, 18 | 138.201171875 19 | ], 20 | [ 21 | 117.5888671875, 22 | 138.201171875 23 | ] 24 | ] 25 | }, 26 | { 27 | "title": "Abstract", 28 | "heading_level": null, 29 | "page_id": 0, 30 | "polygon": [ 31 | [ 32 | 144.1845703125, 33 | 232.4891357421875 34 | ], 35 | [ 36 | 190.48028564453125, 37 | 232.4891357421875 38 | ], 39 | [ 40 | 190.48028564453125, 41 | 244.4443359375 42 | ], 43 | [ 44 | 144.1845703125, 45 | 244.4443359375 46 | ] 47 | ] 48 | }, 49 | { 50 | "title": "1. 
Introduction", 51 | "heading_level": null, 52 | "page_id": 0, 53 | "polygon": [ 54 | [ 55 | 50.016357421875, 56 | 512.06591796875 57 | ], 58 | [ 59 | 128.49609375, 60 | 512.06591796875 61 | ], 62 | [ 63 | 128.49609375, 64 | 524.0211181640625 65 | ], 66 | [ 67 | 50.016357421875, 68 | 524.0211181640625 69 | ] 70 | ] 71 | }, 72 | { 73 | "title": "2. Related Work", 74 | "heading_level": null, 75 | "page_id": 0, 76 | "polygon": [ 77 | [ 78 | 307.1953125, 79 | 621.7747497558594 80 | ], 81 | [ 82 | 392.0625, 83 | 621.7747497558594 84 | ], 85 | [ 86 | 392.0625, 87 | 633.7299499511719 88 | ], 89 | [ 90 | 307.1953125, 91 | 633.7299499511719 92 | ] 93 | ] 94 | }, 95 | { 96 | "title": "3. Method", 97 | "heading_level": null, 98 | "page_id": 2, 99 | "polygon": [ 100 | [ 101 | 49.4560546875, 102 | 371.27313232421875 103 | ], 104 | [ 105 | 101.91387939453125, 106 | 371.27313232421875 107 | ], 108 | [ 109 | 101.91387939453125, 110 | 383.22833251953125 111 | ], 112 | [ 113 | 49.4560546875, 114 | 383.22833251953125 115 | ] 116 | ] 117 | }, 118 | { 119 | "title": "3.1. Dilated Convolutions for Multicolumn Net-\nworks", 120 | "heading_level": null, 121 | "page_id": 2, 122 | "polygon": [ 123 | [ 124 | 49.53076171875, 125 | 391.4488220214844 126 | ], 127 | [ 128 | 287.173828125, 129 | 391.4488220214844 130 | ], 131 | [ 132 | 287.173828125, 133 | 414.3627014160156 134 | ], 135 | [ 136 | 49.53076171875, 137 | 414.3627014160156 138 | ] 139 | ] 140 | }, 141 | { 142 | "title": "3.2. 
Experiments", 143 | "heading_level": null, 144 | "page_id": 3, 145 | "polygon": [ 146 | [ 147 | 49.119873046875, 148 | 263.935546875 149 | ], 150 | [ 151 | 128.95028686523438, 152 | 263.935546875 153 | ], 154 | [ 155 | 128.95028686523438, 156 | 274.936767578125 157 | ], 158 | [ 159 | 49.119873046875, 160 | 274.936767578125 161 | ] 162 | ] 163 | }, 164 | { 165 | "title": "3.2.1 UCF50 Crowd Counting", 166 | "heading_level": null, 167 | "page_id": 3, 168 | "polygon": [ 169 | [ 170 | 307.79296875, 171 | 339.732421875 172 | ], 173 | [ 174 | 443.4609375, 175 | 339.732421875 176 | ], 177 | [ 178 | 443.4609375, 179 | 350.13201904296875 180 | ], 181 | [ 182 | 307.79296875, 183 | 350.13201904296875 184 | ] 185 | ] 186 | }, 187 | { 188 | "title": "3.2.2 TRANCOS Traffic Counting", 189 | "heading_level": null, 190 | "page_id": 3, 191 | "polygon": [ 192 | [ 193 | 308.689453125, 194 | 624.1640625 195 | ], 196 | [ 197 | 461.689453125, 198 | 624.1640625 199 | ], 200 | [ 201 | 461.689453125, 202 | 634.7828826904297 203 | ], 204 | [ 205 | 308.689453125, 206 | 634.7828826904297 207 | ] 208 | ] 209 | }, 210 | { 211 | "title": "3.2.3 UCSD Crowd Counting", 212 | "heading_level": null, 213 | "page_id": 4, 214 | "polygon": [ 215 | [ 216 | 49.38134765625, 217 | 314.06341552734375 218 | ], 219 | [ 220 | 182.28515625, 221 | 314.06341552734375 222 | ], 223 | [ 224 | 182.28515625, 225 | 324.0260009765625 226 | ], 227 | [ 228 | 49.38134765625, 229 | 324.0260009765625 230 | ] 231 | ] 232 | }, 233 | { 234 | "title": "3.2.4 WorldExpo '10 Crowd Counting", 235 | "heading_level": null, 236 | "page_id": 4, 237 | "polygon": [ 238 | [ 239 | 308.86199951171875, 240 | 259.17828369140625 241 | ], 242 | [ 243 | 477.4889221191406, 244 | 259.17828369140625 245 | ], 246 | [ 247 | 477.4889221191406, 248 | 269.140869140625 249 | ], 250 | [ 251 | 308.86199951171875, 252 | 269.140869140625 253 | ] 254 | ] 255 | }, 256 | { 257 | "title": "4. 
Results", 258 | "heading_level": null, 259 | "page_id": 5, 260 | "polygon": [ 261 | [ 262 | 49.343994140625, 263 | 231.4151611328125 264 | ], 265 | [ 266 | 100.5556640625, 267 | 231.4151611328125 268 | ], 269 | [ 270 | 100.5556640625, 271 | 243.370361328125 272 | ], 273 | [ 274 | 49.343994140625, 275 | 243.370361328125 276 | ] 277 | ] 278 | }, 279 | { 280 | "title": "4.1. UCF Crowd Counting", 281 | "heading_level": null, 282 | "page_id": 5, 283 | "polygon": [ 284 | [ 285 | 49.418701171875, 286 | 251.10882568359375 287 | ], 288 | [ 289 | 173.4697265625, 290 | 251.10882568359375 291 | ], 292 | [ 293 | 173.4697265625, 294 | 262.0677490234375 295 | ], 296 | [ 297 | 49.418701171875, 298 | 262.0677490234375 299 | ] 300 | ] 301 | }, 302 | { 303 | "title": "4.2. TRANCOS Traffic Counting", 304 | "heading_level": null, 305 | "page_id": 5, 306 | "polygon": [ 307 | [ 308 | 49.68017578125, 309 | 455.92767333984375 310 | ], 311 | [ 312 | 203.80078125, 313 | 455.92767333984375 314 | ], 315 | [ 316 | 203.80078125, 317 | 466.8865661621094 318 | ], 319 | [ 320 | 49.68017578125, 321 | 466.8865661621094 322 | ] 323 | ] 324 | }, 325 | { 326 | "title": "4.3. UCSD Crowd Counting", 327 | "heading_level": null, 328 | "page_id": 5, 329 | "polygon": [ 330 | [ 331 | 49.941650390625, 332 | 553.1486358642578 333 | ], 334 | [ 335 | 181.08984375, 336 | 553.1486358642578 337 | ], 338 | [ 339 | 181.08984375, 340 | 564.1075286865234 341 | ], 342 | [ 343 | 49.941650390625, 344 | 564.1075286865234 345 | ] 346 | ] 347 | }, 348 | { 349 | "title": "4.4. WorldExpo '10 Crowd Counting", 350 | "heading_level": null, 351 | "page_id": 5, 352 | "polygon": [ 353 | [ 354 | 308.689453125, 355 | 318.3517761230469 356 | ], 357 | [ 358 | 480.814453125, 359 | 318.3517761230469 360 | ], 361 | [ 362 | 480.814453125, 363 | 329.3106689453125 364 | ], 365 | [ 366 | 308.689453125, 367 | 329.3106689453125 368 | ] 369 | ] 370 | }, 371 | { 372 | "title": "4.5. 
Ablation Studies", 373 | "heading_level": null, 374 | "page_id": 5, 375 | "polygon": [ 376 | [ 377 | 308.689453125, 378 | 475.50469970703125 379 | ], 380 | [ 381 | 405.6838684082031, 382 | 475.50469970703125 383 | ], 384 | [ 385 | 405.6838684082031, 386 | 486.4635925292969 387 | ], 388 | [ 389 | 308.689453125, 390 | 486.4635925292969 391 | ] 392 | ] 393 | }, 394 | { 395 | "title": "5. Conclusion", 396 | "heading_level": null, 397 | "page_id": 6, 398 | "polygon": [ 399 | [ 400 | 48.48486328125, 401 | 594.6561584472656 402 | ], 403 | [ 404 | 119.20110321044922, 405 | 594.6561584472656 406 | ], 407 | [ 408 | 119.20110321044922, 409 | 607.1484375 410 | ], 411 | [ 412 | 48.48486328125, 413 | 607.1484375 414 | ] 415 | ] 416 | }, 417 | { 418 | "title": "5.1. Summary", 419 | "heading_level": null, 420 | "page_id": 6, 421 | "polygon": [ 422 | [ 423 | 49.194580078125, 424 | 619.6148376464844 425 | ], 426 | [ 427 | 115.55853271484375, 428 | 619.6148376464844 429 | ], 430 | [ 431 | 115.55853271484375, 432 | 630.73828125 433 | ], 434 | [ 435 | 49.194580078125, 436 | 630.73828125 437 | ] 438 | ] 439 | }, 440 | { 441 | "title": "5.2. 
Future Work", 442 | "heading_level": null, 443 | "page_id": 7, 444 | "polygon": [ 445 | [ 446 | 49.269287109375, 447 | 611.3048095703125 448 | ], 449 | [ 450 | 130.67086791992188, 451 | 611.3048095703125 452 | ], 453 | [ 454 | 130.67086791992188, 455 | 622.2637023925781 456 | ], 457 | [ 458 | 49.269287109375, 459 | 622.2637023925781 460 | ] 461 | ] 462 | }, 463 | { 464 | "title": "Acknowledgment", 465 | "heading_level": null, 466 | "page_id": 7, 467 | "polygon": [ 468 | [ 469 | 308.86199951171875, 470 | 446.23602294921875 471 | ], 472 | [ 473 | 398.337890625, 474 | 446.23602294921875 475 | ], 476 | [ 477 | 398.337890625, 478 | 458.19122314453125 479 | ], 480 | [ 481 | 308.86199951171875, 482 | 458.19122314453125 483 | ] 484 | ] 485 | }, 486 | { 487 | "title": "References", 488 | "heading_level": null, 489 | "page_id": 7, 490 | "polygon": [ 491 | [ 492 | 308.86199951171875, 493 | 571.0409851074219 494 | ], 495 | [ 496 | 365.16796875, 497 | 571.0409851074219 498 | ], 499 | [ 500 | 365.16796875, 501 | 582.9961853027344 502 | ], 503 | [ 504 | 308.86199951171875, 505 | 582.9961853027344 506 | ] 507 | ] 508 | } 509 | ], 510 | "page_stats": [ 511 | { 512 | "page_id": 0, 513 | "text_extraction_method": "pdftext", 514 | "block_counts": [ 515 | [ 516 | "Span", 517 | 176 518 | ], 519 | [ 520 | "Line", 521 | 84 522 | ], 523 | [ 524 | "Text", 525 | 10 526 | ], 527 | [ 528 | "SectionHeader", 529 | 4 530 | ], 531 | [ 532 | "PageHeader", 533 | 1 534 | ], 535 | [ 536 | "PageFooter", 537 | 1 538 | ] 539 | ], 540 | "block_metadata": { 541 | "llm_request_count": 0, 542 | "llm_error_count": 0, 543 | "llm_tokens_used": 0 544 | } 545 | }, 546 | { 547 | "page_id": 1, 548 | "text_extraction_method": "pdftext", 549 | "block_counts": [ 550 | [ 551 | "Span", 552 | 201 553 | ], 554 | [ 555 | "Line", 556 | 74 557 | ], 558 | [ 559 | "Text", 560 | 5 561 | ], 562 | [ 563 | "Figure", 564 | 1 565 | ], 566 | [ 567 | "Caption", 568 | 1 569 | ], 570 | [ 571 | "FigureGroup", 572 | 1 573 | ], 574 | [ 575 
| "Reference", 576 | 1 577 | ] 578 | ], 579 | "block_metadata": { 580 | "llm_request_count": 0, 581 | "llm_error_count": 0, 582 | "llm_tokens_used": 0 583 | } 584 | }, 585 | { 586 | "page_id": 2, 587 | "text_extraction_method": "pdftext", 588 | "block_counts": [ 589 | [ 590 | "Span", 591 | 327 592 | ], 593 | [ 594 | "Line", 595 | 96 596 | ], 597 | [ 598 | "Text", 599 | 10 600 | ], 601 | [ 602 | "Reference", 603 | 3 604 | ], 605 | [ 606 | "SectionHeader", 607 | 2 608 | ], 609 | [ 610 | "Equation", 611 | 2 612 | ], 613 | [ 614 | "Picture", 615 | 1 616 | ], 617 | [ 618 | "Caption", 619 | 1 620 | ], 621 | [ 622 | "TextInlineMath", 623 | 1 624 | ], 625 | [ 626 | "Footnote", 627 | 1 628 | ], 629 | [ 630 | "PictureGroup", 631 | 1 632 | ] 633 | ], 634 | "block_metadata": { 635 | "llm_request_count": 2, 636 | "llm_error_count": 0, 637 | "llm_tokens_used": 4608 638 | } 639 | }, 640 | { 641 | "page_id": 3, 642 | "text_extraction_method": "pdftext", 643 | "block_counts": [ 644 | [ 645 | "Span", 646 | 337 647 | ], 648 | [ 649 | "Line", 650 | 109 651 | ], 652 | [ 653 | "Text", 654 | 8 655 | ], 656 | [ 657 | "SectionHeader", 658 | 3 659 | ], 660 | [ 661 | "Equation", 662 | 1 663 | ], 664 | [ 665 | "TextInlineMath", 666 | 1 667 | ], 668 | [ 669 | "Reference", 670 | 1 671 | ] 672 | ], 673 | "block_metadata": { 674 | "llm_request_count": 1, 675 | "llm_error_count": 0, 676 | "llm_tokens_used": 3057 677 | } 678 | }, 679 | { 680 | "page_id": 4, 681 | "text_extraction_method": "pdftext", 682 | "block_counts": [ 683 | [ 684 | "Span", 685 | 505 686 | ], 687 | [ 688 | "Line", 689 | 121 690 | ], 691 | [ 692 | "Text", 693 | 6 694 | ], 695 | [ 696 | "TextInlineMath", 697 | 6 698 | ], 699 | [ 700 | "Equation", 701 | 2 702 | ], 703 | [ 704 | "SectionHeader", 705 | 2 706 | ], 707 | [ 708 | "Reference", 709 | 1 710 | ] 711 | ], 712 | "block_metadata": { 713 | "llm_request_count": 2, 714 | "llm_error_count": 0, 715 | "llm_tokens_used": 3814 716 | } 717 | }, 718 | { 719 | "page_id": 5, 720 | 
"text_extraction_method": "pdftext", 721 | "block_counts": [ 722 | [ 723 | "Span", 724 | 332 725 | ], 726 | [ 727 | "TableCell", 728 | 113 729 | ], 730 | [ 731 | "Line", 732 | 100 733 | ], 734 | [ 735 | "Text", 736 | 7 737 | ], 738 | [ 739 | "SectionHeader", 740 | 6 741 | ], 742 | [ 743 | "Reference", 744 | 3 745 | ], 746 | [ 747 | "Table", 748 | 2 749 | ], 750 | [ 751 | "Caption", 752 | 2 753 | ], 754 | [ 755 | "TableGroup", 756 | 2 757 | ], 758 | [ 759 | "TextInlineMath", 760 | 1 761 | ] 762 | ], 763 | "block_metadata": { 764 | "llm_request_count": 3, 765 | "llm_error_count": 0, 766 | "llm_tokens_used": 7669 767 | } 768 | }, 769 | { 770 | "page_id": 6, 771 | "text_extraction_method": "pdftext", 772 | "block_counts": [ 773 | [ 774 | "Span", 775 | 229 776 | ], 777 | [ 778 | "TableCell", 779 | 180 780 | ], 781 | [ 782 | "Line", 783 | 37 784 | ], 785 | [ 786 | "Caption", 787 | 4 788 | ], 789 | [ 790 | "SectionHeader", 791 | 2 792 | ], 793 | [ 794 | "Text", 795 | 2 796 | ], 797 | [ 798 | "Reference", 799 | 2 800 | ], 801 | [ 802 | "Figure", 803 | 1 804 | ], 805 | [ 806 | "Table", 807 | 1 808 | ], 809 | [ 810 | "FigureGroup", 811 | 1 812 | ], 813 | [ 814 | "TableGroup", 815 | 1 816 | ] 817 | ], 818 | "block_metadata": { 819 | "llm_request_count": 2, 820 | "llm_error_count": 0, 821 | "llm_tokens_used": 7459 822 | } 823 | }, 824 | { 825 | "page_id": 7, 826 | "text_extraction_method": "pdftext", 827 | "block_counts": [ 828 | [ 829 | "Span", 830 | 145 831 | ], 832 | [ 833 | "Line", 834 | 68 835 | ], 836 | [ 837 | "TableCell", 838 | 32 839 | ], 840 | [ 841 | "Text", 842 | 5 843 | ], 844 | [ 845 | "Reference", 846 | 5 847 | ], 848 | [ 849 | "SectionHeader", 850 | 3 851 | ], 852 | [ 853 | "ListItem", 854 | 3 855 | ], 856 | [ 857 | "Caption", 858 | 2 859 | ], 860 | [ 861 | "Figure", 862 | 1 863 | ], 864 | [ 865 | "Table", 866 | 1 867 | ], 868 | [ 869 | "FigureGroup", 870 | 1 871 | ], 872 | [ 873 | "TableGroup", 874 | 1 875 | ], 876 | [ 877 | "ListGroup", 878 | 1 879 | ] 880 | 
], 881 | "block_metadata": { 882 | "llm_request_count": 1, 883 | "llm_error_count": 0, 884 | "llm_tokens_used": 2613 885 | } 886 | }, 887 | { 888 | "page_id": 8, 889 | "text_extraction_method": "pdftext", 890 | "block_counts": [ 891 | [ 892 | "Span", 893 | 312 894 | ], 895 | [ 896 | "Line", 897 | 101 898 | ], 899 | [ 900 | "ListItem", 901 | 24 902 | ], 903 | [ 904 | "Reference", 905 | 24 906 | ], 907 | [ 908 | "ListGroup", 909 | 2 910 | ], 911 | [ 912 | "Text", 913 | 1 914 | ] 915 | ], 916 | "block_metadata": { 917 | "llm_request_count": 0, 918 | "llm_error_count": 0, 919 | "llm_tokens_used": 0 920 | } 921 | }, 922 | { 923 | "page_id": 9, 924 | "text_extraction_method": "pdftext", 925 | "block_counts": [ 926 | [ 927 | "Span", 928 | 26 929 | ], 930 | [ 931 | "Line", 932 | 7 933 | ], 934 | [ 935 | "Text", 936 | 1 937 | ], 938 | [ 939 | "ListItem", 940 | 1 941 | ], 942 | [ 943 | "Reference", 944 | 1 945 | ] 946 | ], 947 | "block_metadata": { 948 | "llm_request_count": 0, 949 | "llm_error_count": 0, 950 | "llm_tokens_used": 0 951 | } 952 | } 953 | ], 954 | "debug_data_path": "debug_data/multicolcnn" 955 | } ``` -------------------------------------------------------------------------------- /marker/builders/ocr.py: -------------------------------------------------------------------------------- ```python 1 | import copy 2 | from typing import Annotated, List 3 | 4 | from ftfy import fix_text 5 | from PIL import Image 6 | from surya.common.surya.schema import TaskNames 7 | from surya.recognition import RecognitionPredictor, OCRResult, TextChar 8 | 9 | from marker.builders import BaseBuilder 10 | from marker.providers.pdf import PdfProvider 11 | from marker.schema import BlockTypes 12 | from marker.schema.blocks import BlockId 13 | from marker.schema.blocks.base import Block 14 | from marker.schema.document import Document 15 | from marker.schema.groups import PageGroup 16 | from marker.schema.registry import get_block_class 17 | from marker.schema.text.char import Char 
18 | from marker.schema.text.line import Line 19 | from marker.schema.text.span import Span 20 | from marker.settings import settings 21 | from marker.schema.polygon import PolygonBox 22 | from marker.util import get_opening_tag_type, get_closing_tag_type 23 | 24 | 25 | class OcrBuilder(BaseBuilder): 26 | """ 27 | A builder for performing OCR on PDF pages and merging the results into the document. 28 | """ 29 | 30 | recognition_batch_size: Annotated[ 31 | int, 32 | "The batch size to use for the recognition model.", 33 | "Default is None, which will use the default batch size for the model.", 34 | ] = None 35 | disable_tqdm: Annotated[ 36 | bool, 37 | "Disable tqdm progress bars.", 38 | ] = False 39 | # We can skip tables here, since the TableProcessor will re-OCR 40 | skip_ocr_blocks: Annotated[ 41 | List[BlockTypes], 42 | "Blocktypes to skip OCRing by the model in this stage." 43 | "By default, this avoids recognizing lines inside equations/tables (handled later), figures, and pictures", 44 | "Note that we **do not** have to skip group types, since they are not built by this point" 45 | ] = [ 46 | BlockTypes.Equation, 47 | BlockTypes.Figure, 48 | BlockTypes.Picture, 49 | BlockTypes.Table, 50 | BlockTypes.Form, 51 | BlockTypes.TableOfContents, 52 | ] 53 | full_ocr_block_types: Annotated[ 54 | List[BlockTypes], 55 | "Blocktypes for which OCR is done at the **block level** instead of line-level." 56 | "This feature is still in beta, and should be used sparingly." 57 | ] = [ 58 | BlockTypes.SectionHeader, 59 | BlockTypes.ListItem, 60 | BlockTypes.Footnote, 61 | BlockTypes.Text, 62 | BlockTypes.TextInlineMath, 63 | BlockTypes.Code, 64 | BlockTypes.Caption, 65 | ] 66 | ocr_task_name: Annotated[ 67 | str, 68 | "The OCR mode to use, see surya for details. 
Set to 'ocr_without_boxes' for potentially better performance, at the expense of formatting.", 69 | ] = TaskNames.ocr_with_boxes 70 | keep_chars: Annotated[bool, "Keep individual characters."] = False 71 | disable_ocr_math: Annotated[bool, "Disable inline math recognition in OCR"] = False 72 | drop_repeated_text: Annotated[bool, "Drop repeated text in OCR results."] = False 73 | block_mode_intersection_thresh: Annotated[float, "Max intersection before falling back to line mode"] = 0.5 74 | block_mode_max_lines: Annotated[int, "Max lines within a block before falling back to line mode"] = 15 75 | block_mode_max_height_frac: Annotated[float, "Max height of a block as a percentage of the page before falling back to line mode"] = 0.5 76 | 77 | def __init__(self, recognition_model: RecognitionPredictor, config=None): 78 | super().__init__(config) 79 | 80 | self.recognition_model = recognition_model 81 | 82 | def __call__(self, document: Document, provider: PdfProvider): 83 | pages_to_ocr = [page for page in document.pages if page.text_extraction_method == 'surya'] 84 | ocr_page_images, block_polygons, block_ids, block_original_texts = ( 85 | self.get_ocr_images_polygons_ids(document, pages_to_ocr, provider) 86 | ) 87 | self.ocr_extraction( 88 | document, 89 | pages_to_ocr, 90 | ocr_page_images, 91 | block_polygons, 92 | block_ids, 93 | block_original_texts, 94 | ) 95 | 96 | def get_recognition_batch_size(self): 97 | if self.recognition_batch_size is not None: 98 | return self.recognition_batch_size 99 | elif settings.TORCH_DEVICE_MODEL == "cuda": 100 | return 48 101 | elif settings.TORCH_DEVICE_MODEL == "mps": 102 | return 16 103 | return 32 104 | 105 | def select_ocr_blocks_by_mode( 106 | self, page: PageGroup, block: Block, block_lines: List[Block], page_max_intersection_pct: float 107 | ): 108 | if any([ 109 | page_max_intersection_pct > self.block_mode_intersection_thresh, 110 | block.block_type not in self.full_ocr_block_types, 111 | len(block_lines) > 
self.block_mode_max_lines, 112 | block.polygon.height >= self.block_mode_max_height_frac * page.polygon.height 113 | ]): 114 | # Line mode 115 | return block_lines 116 | 117 | # Block mode 118 | return [block] 119 | 120 | def get_ocr_images_polygons_ids( 121 | self, document: Document, pages: List[PageGroup], provider: PdfProvider 122 | ): 123 | highres_images, highres_polys, block_ids, block_original_texts = [], [], [], [] 124 | for document_page in pages: 125 | page_highres_image = document_page.get_image(highres=True) 126 | page_highres_polys = [] 127 | page_block_ids = [] 128 | page_block_original_texts = [] 129 | 130 | page_size = provider.get_page_bbox(document_page.page_id).size 131 | image_size = page_highres_image.size 132 | max_intersection_pct = document_page.compute_max_structure_block_intersection_pct() 133 | for block in document_page.structure_blocks(document): 134 | if block.block_type in self.skip_ocr_blocks: 135 | # Skip OCR 136 | continue 137 | 138 | block_lines = block.contained_blocks(document, [BlockTypes.Line]) 139 | blocks_to_ocr = self.select_ocr_blocks_by_mode(document_page, block, block_lines, max_intersection_pct) 140 | 141 | block.text_extraction_method = "surya" 142 | for block in blocks_to_ocr: 143 | # Fit the polygon to image bounds since PIL image crop expands by default which might create bad images for the OCR model. 
144 | block_polygon_rescaled = ( 145 | copy.deepcopy(block.polygon) 146 | .rescale(page_size, image_size) 147 | .fit_to_bounds((0, 0, *image_size)) 148 | ) 149 | block_bbox_rescaled = block_polygon_rescaled.polygon 150 | block_bbox_rescaled = [ 151 | [int(x) for x in point] for point in block_bbox_rescaled 152 | ] 153 | 154 | page_highres_polys.append(block_bbox_rescaled) 155 | page_block_ids.append(block.id) 156 | page_block_original_texts.append("") 157 | 158 | highres_images.append(page_highres_image) 159 | highres_polys.append(page_highres_polys) 160 | block_ids.append(page_block_ids) 161 | block_original_texts.append(page_block_original_texts) 162 | 163 | return highres_images, highres_polys, block_ids, block_original_texts 164 | 165 | def ocr_extraction( 166 | self, 167 | document: Document, 168 | pages: List[PageGroup], 169 | images: List[any], 170 | block_polygons: List[List[List[List[int]]]], # polygons 171 | block_ids: List[List[BlockId]], 172 | block_original_texts: List[List[str]], 173 | ): 174 | if sum(len(b) for b in block_polygons) == 0: 175 | return 176 | 177 | self.recognition_model.disable_tqdm = self.disable_tqdm 178 | recognition_results: List[OCRResult] = self.recognition_model( 179 | images=images, 180 | task_names=[self.ocr_task_name] * len(images), 181 | polygons=block_polygons, 182 | input_text=block_original_texts, 183 | recognition_batch_size=int(self.get_recognition_batch_size()), 184 | sort_lines=False, 185 | math_mode=not self.disable_ocr_math, 186 | drop_repeated_text=self.drop_repeated_text, 187 | max_sliding_window=2148, 188 | max_tokens=2048 189 | ) 190 | 191 | assert len(recognition_results) == len(images) == len(pages) == len(block_ids), ( 192 | f"Mismatch in OCR lengths: {len(recognition_results)}, {len(images)}, {len(pages)}, {len(block_ids)}" 193 | ) 194 | for document_page, page_recognition_result, page_block_ids, image in zip( 195 | pages, recognition_results, block_ids, images 196 | ): 197 | for block_id, block_ocr_result 
in zip( 198 | page_block_ids, page_recognition_result.text_lines 199 | ): 200 | if block_ocr_result.original_text_good: 201 | continue 202 | if not fix_text(block_ocr_result.text): 203 | continue 204 | 205 | block = document_page.get_block(block_id) 206 | # This is a nested list of spans, so multiple lines are supported 207 | all_line_spans = self.spans_from_html_chars( 208 | block_ocr_result.chars, document_page, image 209 | ) 210 | if block.block_type == BlockTypes.Line: 211 | # flatten all spans across lines 212 | flat_spans = [s for line_spans in all_line_spans for s in line_spans] 213 | self.replace_line_spans(document, document_page, block, flat_spans) 214 | else: 215 | # Clear out any old lines. Mark as removed for the json ocr renderer 216 | for line in block.contained_blocks(document_page, block_types=[BlockTypes.Line]): 217 | line.removed = True 218 | block.structure = [] 219 | 220 | for line_spans in all_line_spans: 221 | # TODO Replace this polygon with the polygon for each line, constructed from the spans 222 | # This needs the OCR model bbox predictions to improve first 223 | new_line = Line( 224 | polygon=block.polygon, 225 | page_id=block.page_id, 226 | text_extraction_method="surya" 227 | ) 228 | document_page.add_full_block(new_line) 229 | block.add_structure(new_line) 230 | self.replace_line_spans(document, document_page, new_line, line_spans) 231 | 232 | # TODO Fix polygons when we cut the span into multiple spans 233 | def link_and_break_span(self, span: Span, text: str, match_text, url: str): 234 | before_text, _, after_text = text.partition(match_text) 235 | before_span, after_span = None, None 236 | if before_text: 237 | before_span = copy.deepcopy(span) 238 | before_span.structure = [] # Avoid duplicate characters 239 | before_span.text = before_text 240 | if after_text: 241 | after_span = copy.deepcopy(span) 242 | after_span.text = after_text 243 | after_span.structure = [] # Avoid duplicate characters 244 | 245 | match_span = 
copy.deepcopy(span) 246 | match_span.text = match_text 247 | match_span.url = url 248 | 249 | return before_span, match_span, after_span 250 | 251 | # Pull all refs from old spans and attempt to insert back into appropriate place in new spans 252 | def replace_line_spans( 253 | self, document: Document, page: PageGroup, line: Line, new_spans: List[Span] 254 | ): 255 | old_spans = line.contained_blocks(document, [BlockTypes.Span]) 256 | text_ref_matching = {span.text: span.url for span in old_spans if span.url} 257 | 258 | # Insert refs into new spans, since the OCR model does not (cannot) generate these 259 | final_new_spans = [] 260 | for span in new_spans: 261 | # Use for copying attributes into new spans 262 | original_span = copy.deepcopy(span) 263 | remaining_text = span.text 264 | while remaining_text: 265 | matched = False 266 | for match_text, url in text_ref_matching.items(): 267 | if match_text in remaining_text: 268 | matched = True 269 | before, current, after = self.link_and_break_span( 270 | original_span, remaining_text, match_text, url 271 | ) 272 | if before: 273 | final_new_spans.append(before) 274 | final_new_spans.append(current) 275 | if after: 276 | remaining_text = after.text 277 | else: 278 | remaining_text = "" # No more text left 279 | # Prevent repeat matches 280 | del text_ref_matching[match_text] 281 | break 282 | if not matched: 283 | remaining_span = copy.deepcopy(original_span) 284 | remaining_span.text = remaining_text 285 | final_new_spans.append(remaining_span) 286 | break 287 | 288 | # Clear the old spans from the line 289 | line.structure = [] 290 | for span in final_new_spans: 291 | page.add_full_block(span) 292 | line.structure.append(span.id) 293 | 294 | def assign_chars(self, span: Span, current_chars: List[Char]): 295 | if self.keep_chars: 296 | span.structure = [c.id for c in current_chars] 297 | 298 | return [] 299 | 300 | def store_char(self, char: Char, current_chars: List[Char], page: PageGroup): 301 | if 
self.keep_chars: 302 | current_chars.append(char) 303 | page.add_full_block(char) 304 | 305 | def spans_from_html_chars( 306 | self, chars: List[TextChar], page: PageGroup, image: Image.Image 307 | ) -> List[List[Span]]: 308 | # Turn input characters from surya into spans - also store the raw characters 309 | SpanClass: Span = get_block_class(BlockTypes.Span) 310 | CharClass: Char = get_block_class(BlockTypes.Char) 311 | 312 | all_line_spans = [] 313 | current_line_spans = [] 314 | formats = {"plain"} 315 | current_span = None 316 | current_chars = [] 317 | image_size = image.size 318 | 319 | for idx, char in enumerate(chars): 320 | char_box = PolygonBox(polygon=char.polygon).rescale( 321 | image_size, page.polygon.size 322 | ) 323 | marker_char = CharClass( 324 | text=char.text, 325 | idx=idx, 326 | page_id=page.page_id, 327 | polygon=char_box, 328 | ) 329 | 330 | if char.text == "<br>": 331 | if current_span: 332 | current_chars = self.assign_chars(current_span, current_chars) 333 | current_line_spans.append(current_span) 334 | current_span = None 335 | if current_line_spans: 336 | current_line_spans[-1].text += "\n" 337 | all_line_spans.append(current_line_spans) 338 | current_line_spans = [] 339 | continue 340 | 341 | is_opening_tag, format = get_opening_tag_type(char.text) 342 | if is_opening_tag and format not in formats: 343 | formats.add(format) 344 | if current_span: 345 | current_chars = self.assign_chars(current_span, current_chars) 346 | current_line_spans.append(current_span) 347 | current_span = None 348 | 349 | if format == "math": 350 | current_span = SpanClass( 351 | text="", 352 | formats=list(formats), 353 | page_id=page.page_id, 354 | polygon=char_box, 355 | minimum_position=0, 356 | maximum_position=0, 357 | font="Unknown", 358 | font_weight=0, 359 | font_size=0, 360 | ) 361 | self.store_char(marker_char, current_chars, page) 362 | continue 363 | 364 | is_closing_tag, format = get_closing_tag_type(char.text) 365 | if is_closing_tag: 366 | # 
Useful since the OCR model sometimes returns closing tags without an opening tag 367 | try: 368 | formats.remove(format) 369 | except Exception: 370 | continue 371 | if current_span: 372 | current_chars = self.assign_chars(current_span, current_chars) 373 | current_line_spans.append(current_span) 374 | current_span = None 375 | continue 376 | 377 | if not current_span: 378 | current_span = SpanClass( 379 | text=fix_text(char.text), 380 | formats=list(formats), 381 | page_id=page.page_id, 382 | polygon=char_box, 383 | minimum_position=0, 384 | maximum_position=0, 385 | font="Unknown", 386 | font_weight=0, 387 | font_size=0, 388 | ) 389 | self.store_char(marker_char, current_chars, page) 390 | continue 391 | 392 | current_span.text = fix_text(current_span.text + char.text) 393 | self.store_char(marker_char, current_chars, page) 394 | 395 | # Tokens inside a math span don't have valid boxes, so we skip the merging 396 | if "math" not in formats: 397 | current_span.polygon = current_span.polygon.merge([char_box]) 398 | 399 | # Add the last span to the list 400 | if current_span: 401 | self.assign_chars(current_span, current_chars) 402 | current_line_spans.append(current_span) 403 | 404 | # flush last line 405 | if current_line_spans: 406 | current_line_spans[-1].text += "\n" 407 | all_line_spans.append(current_line_spans) 408 | 409 | return all_line_spans 410 | ``` -------------------------------------------------------------------------------- /marker/processors/table.py: -------------------------------------------------------------------------------- ```python 1 | import re 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | from typing import Annotated, List 5 | from collections import Counter 6 | from PIL import Image 7 | 8 | from ftfy import fix_text 9 | from surya.detection import DetectionPredictor, TextDetectionResult 10 | from surya.recognition import RecognitionPredictor, TextLine 11 | from surya.table_rec import TableRecPredictor 12 | 
from surya.table_rec.schema import TableResult, TableCell as SuryaTableCell 13 | from pdftext.extraction import table_output 14 | 15 | from marker.processors import BaseProcessor 16 | from marker.schema import BlockTypes 17 | from marker.schema.blocks.tablecell import TableCell 18 | from marker.schema.document import Document 19 | from marker.schema.polygon import PolygonBox 20 | from marker.settings import settings 21 | from marker.util import matrix_intersection_area, unwrap_math 22 | from marker.utils.image import is_blank_image 23 | from marker.logger import get_logger 24 | 25 | logger = get_logger() 26 | 27 | 28 | class TableProcessor(BaseProcessor): 29 | """ 30 | A processor for recognizing tables in the document. 31 | """ 32 | 33 | block_types = (BlockTypes.Table, BlockTypes.TableOfContents, BlockTypes.Form) 34 | table_rec_batch_size: Annotated[ 35 | int, 36 | "The batch size to use for the table recognition model.", 37 | "Default is None, which will use the default batch size for the model.", 38 | ] = None 39 | detection_batch_size: Annotated[ 40 | int, 41 | "The batch size to use for the table detection model.", 42 | "Default is None, which will use the default batch size for the model.", 43 | ] = None 44 | recognition_batch_size: Annotated[ 45 | int, 46 | "The batch size to use for the table recognition model.", 47 | "Default is None, which will use the default batch size for the model.", 48 | ] = None 49 | contained_block_types: Annotated[ 50 | List[BlockTypes], 51 | "Block types to remove if they're contained inside the tables.", 52 | ] = (BlockTypes.Text, BlockTypes.TextInlineMath) 53 | row_split_threshold: Annotated[ 54 | float, 55 | "The percentage of rows that need to be split across the table before row splitting is active.", 56 | ] = 0.5 57 | pdftext_workers: Annotated[ 58 | int, 59 | "The number of workers to use for pdftext.", 60 | ] = 1 61 | disable_tqdm: Annotated[ 62 | bool, 63 | "Whether to disable the tqdm progress bar.", 64 | ] = False 65 
| drop_repeated_table_text: Annotated[bool, "Drop repeated text in OCR results."] = ( 66 | False 67 | ) 68 | filter_tag_list = ["p", "table", "td", "tr", "th", "tbody"] 69 | disable_ocr_math: Annotated[bool, "Disable inline math recognition in OCR"] = False 70 | disable_ocr: Annotated[bool, "Disable OCR entirely."] = False 71 | 72 | def __init__( 73 | self, 74 | recognition_model: RecognitionPredictor, 75 | table_rec_model: TableRecPredictor, 76 | detection_model: DetectionPredictor, 77 | config=None, 78 | ): 79 | super().__init__(config) 80 | 81 | self.recognition_model = recognition_model 82 | self.table_rec_model = table_rec_model 83 | self.detection_model = detection_model 84 | 85 | def __call__(self, document: Document): 86 | filepath = document.filepath # Path to original pdf file 87 | 88 | table_data = [] 89 | for page in document.pages: 90 | for block in page.contained_blocks(document, self.block_types): 91 | if block.block_type == BlockTypes.Table: 92 | block.polygon = block.polygon.expand(0.01, 0.01) 93 | image = block.get_image(document, highres=True) 94 | image_poly = block.polygon.rescale( 95 | (page.polygon.width, page.polygon.height), 96 | page.get_image(highres=True).size, 97 | ) 98 | 99 | table_data.append( 100 | { 101 | "block_id": block.id, 102 | "page_id": page.page_id, 103 | "table_image": image, 104 | "table_bbox": image_poly.bbox, 105 | "img_size": page.get_image(highres=True).size, 106 | "ocr_block": any( 107 | [ 108 | page.text_extraction_method in ["surya"], 109 | page.ocr_errors_detected, 110 | ] 111 | ), 112 | } 113 | ) 114 | 115 | # Detect tables and cells 116 | self.table_rec_model.disable_tqdm = self.disable_tqdm 117 | tables: List[TableResult] = self.table_rec_model( 118 | [t["table_image"] for t in table_data], 119 | batch_size=self.get_table_rec_batch_size(), 120 | ) 121 | assert len(tables) == len(table_data), ( 122 | "Number of table results should match the number of tables" 123 | ) 124 | 125 | # Assign cell text if we don't 
need OCR 126 | # We do this at a line level 127 | extract_blocks = [t for t in table_data if not t["ocr_block"]] 128 | self.assign_pdftext_lines( 129 | extract_blocks, filepath 130 | ) # Handle tables where good text exists in the PDF 131 | self.assign_text_to_cells(tables, table_data) 132 | 133 | # Assign OCR lines if needed - we do this at a cell level 134 | self.assign_ocr_lines(tables, table_data) 135 | 136 | self.split_combined_rows(tables) # Split up rows that were combined 137 | self.combine_dollar_column(tables) # Combine columns that are just dollar signs 138 | 139 | # Assign table cells to the table 140 | table_idx = 0 141 | for page in document.pages: 142 | for block in page.contained_blocks(document, self.block_types): 143 | block.structure = [] # Remove any existing lines, spans, etc. 144 | cells: List[SuryaTableCell] = tables[table_idx].cells 145 | for cell in cells: 146 | # Rescale the cell polygon to the page size 147 | cell_polygon = PolygonBox(polygon=cell.polygon).rescale( 148 | page.get_image(highres=True).size, page.polygon.size 149 | ) 150 | 151 | # Rescale cell polygon to be relative to the page instead of the table 152 | for corner in cell_polygon.polygon: 153 | corner[0] += block.polygon.bbox[0] 154 | corner[1] += block.polygon.bbox[1] 155 | 156 | cell_block = TableCell( 157 | polygon=cell_polygon, 158 | text_lines=self.finalize_cell_text(cell), 159 | rowspan=cell.rowspan, 160 | colspan=cell.colspan, 161 | row_id=cell.row_id, 162 | col_id=cell.col_id, 163 | is_header=bool(cell.is_header), 164 | page_id=page.page_id, 165 | ) 166 | page.add_full_block(cell_block) 167 | block.add_structure(cell_block) 168 | table_idx += 1 169 | 170 | # Clean out other blocks inside the table 171 | # This can happen with stray text blocks inside the table post-merging 172 | for page in document.pages: 173 | child_contained_blocks = page.contained_blocks( 174 | document, self.contained_block_types 175 | ) 176 | for block in page.contained_blocks(document, 
self.block_types): 177 | intersections = matrix_intersection_area( 178 | [c.polygon.bbox for c in child_contained_blocks], 179 | [block.polygon.bbox], 180 | ) 181 | for child, intersection in zip(child_contained_blocks, intersections): 182 | # Adjust this to percentage of the child block that is enclosed by the table 183 | intersection_pct = intersection / max(child.polygon.area, 1) 184 | if intersection_pct > 0.95 and child.id in page.structure: 185 | page.structure.remove(child.id) 186 | 187 | def finalize_cell_text(self, cell: SuryaTableCell): 188 | fixed_text = [] 189 | text_lines = cell.text_lines if cell.text_lines else [] 190 | for line in text_lines: 191 | text = line["text"].strip() 192 | if not text or text == ".": 193 | continue 194 | # Spaced sequences: ". . .", "- - -", "_ _ _", "… … …" 195 | text = re.sub(r"(\s?[.\-_…]){2,}", "", text) 196 | # Unspaced sequences: "...", "---", "___", "……" 197 | text = re.sub(r"[.\-_…]{2,}", "", text) 198 | # Remove mathbf formatting if there is only digits with decimals/commas/currency symbols inside 199 | text = re.sub(r"\\mathbf\{([0-9.,$€£]+)\}", r"<b>\1</b>", text) 200 | # Drop empty tags like \overline{} 201 | text = re.sub(r"\\[a-zA-Z]+\{\s*\}", "", text) 202 | # Drop \phantom{...} (remove contents too) 203 | text = re.sub(r"\\phantom\{.*?\}", "", text) 204 | # Drop \quad 205 | text = re.sub(r"\\quad", "", text) 206 | # Drop \, 207 | text = re.sub(r"\\,", "", text) 208 | # Unwrap \mathsf{...} 209 | text = re.sub(r"\\mathsf\{([^}]*)\}", r"\1", text) 210 | # Handle unclosed tags: keep contents, drop the command 211 | text = re.sub(r"\\[a-zA-Z]+\{([^}]*)$", r"\1", text) 212 | # If the whole string is \text{...} → unwrap 213 | text = re.sub(r"^\s*\\text\{([^}]*)\}\s*$", r"\1", text) 214 | 215 | # In case the above steps left no more latex math - We can unwrap 216 | text = unwrap_math(text) 217 | text = self.normalize_spaces(fix_text(text)) 218 | fixed_text.append(text) 219 | return fixed_text 220 | 221 | 
@staticmethod 222 | def normalize_spaces(text): 223 | space_chars = [ 224 | "\u2003", # em space 225 | "\u2002", # en space 226 | "\u00a0", # non-breaking space 227 | "\u200b", # zero-width space 228 | "\u3000", # ideographic space 229 | ] 230 | for space in space_chars: 231 | text = text.replace(space, " ") 232 | return text 233 | 234 | def combine_dollar_column(self, tables: List[TableResult]): 235 | for table in tables: 236 | if len(table.cells) == 0: 237 | # Skip empty tables 238 | continue 239 | unique_cols = sorted(list(set([c.col_id for c in table.cells]))) 240 | max_col = max(unique_cols) 241 | dollar_cols = [] 242 | for col in unique_cols: 243 | # Cells in this col 244 | col_cells = [c for c in table.cells if c.col_id == col] 245 | col_text = [ 246 | "\n".join(self.finalize_cell_text(c)).strip() for c in col_cells 247 | ] 248 | all_dollars = all([ct in ["", "$"] for ct in col_text]) 249 | colspans = [c.colspan for c in col_cells] 250 | span_into_col = [ 251 | c 252 | for c in table.cells 253 | if c.col_id != col and c.col_id + c.colspan > col > c.col_id 254 | ] 255 | 256 | # This is a column that is entirely dollar signs 257 | if all( 258 | [ 259 | all_dollars, 260 | len(col_cells) > 1, 261 | len(span_into_col) == 0, 262 | all([c == 1 for c in colspans]), 263 | col < max_col, 264 | ] 265 | ): 266 | next_col_cells = [c for c in table.cells if c.col_id == col + 1] 267 | next_col_rows = [c.row_id for c in next_col_cells] 268 | col_rows = [c.row_id for c in col_cells] 269 | if ( 270 | len(next_col_cells) == len(col_cells) 271 | and next_col_rows == col_rows 272 | ): 273 | dollar_cols.append(col) 274 | 275 | if len(dollar_cols) == 0: 276 | continue 277 | 278 | dollar_cols = sorted(dollar_cols) 279 | col_offset = 0 280 | for col in unique_cols: 281 | col_cells = [c for c in table.cells if c.col_id == col] 282 | if col_offset == 0 and col not in dollar_cols: 283 | continue 284 | 285 | if col in dollar_cols: 286 | col_offset += 1 287 | for cell in col_cells: 288 | 
text_lines = cell.text_lines if cell.text_lines else [] 289 | next_row_col = [ 290 | c 291 | for c in table.cells 292 | if c.row_id == cell.row_id and c.col_id == col + 1 293 | ] 294 | 295 | # Add dollar to start of the next column 296 | next_text_lines = ( 297 | next_row_col[0].text_lines 298 | if next_row_col[0].text_lines 299 | else [] 300 | ) 301 | next_row_col[0].text_lines = deepcopy(text_lines) + deepcopy( 302 | next_text_lines 303 | ) 304 | table.cells = [ 305 | c for c in table.cells if c.cell_id != cell.cell_id 306 | ] # Remove original cell 307 | next_row_col[0].col_id -= col_offset 308 | else: 309 | for cell in col_cells: 310 | cell.col_id -= col_offset 311 | 312 | def split_combined_rows(self, tables: List[TableResult]): 313 | for table in tables: 314 | if len(table.cells) == 0: 315 | # Skip empty tables 316 | continue 317 | unique_rows = sorted(list(set([c.row_id for c in table.cells]))) 318 | row_info = [] 319 | for row in unique_rows: 320 | # Cells in this row 321 | # Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows 322 | # making them be processed twice 323 | row_cells = deepcopy([c for c in table.cells if c.row_id == row]) 324 | rowspans = [c.rowspan for c in row_cells] 325 | line_lens = [ 326 | len(c.text_lines) if isinstance(c.text_lines, list) else 1 327 | for c in row_cells 328 | ] 329 | 330 | # Other cells that span into this row 331 | rowspan_cells = [ 332 | c 333 | for c in table.cells 334 | if c.row_id != row and c.row_id + c.rowspan > row > c.row_id 335 | ] 336 | should_split_entire_row = all( 337 | [ 338 | len(row_cells) > 1, 339 | len(rowspan_cells) == 0, 340 | all([rowspan == 1 for rowspan in rowspans]), 341 | all([line_len > 1 for line_len in line_lens]), 342 | all([line_len == line_lens[0] for line_len in line_lens]), 343 | ] 344 | ) 345 | line_lens_counter = Counter(line_lens) 346 | counter_keys = sorted(list(line_lens_counter.keys())) 347 | should_split_partial_row 
= all( 348 | [ 349 | len(row_cells) > 3, # Only split if there are more than 3 cells 350 | len(rowspan_cells) == 0, 351 | all([r == 1 for r in rowspans]), 352 | len(line_lens_counter) == 2 353 | and counter_keys[0] <= 1 354 | and counter_keys[1] > 1 355 | and line_lens_counter[counter_keys[0]] 356 | == 1, # Allow a single column with a single line - keys are the line lens, values are the counts 357 | ] 358 | ) 359 | should_split = should_split_entire_row or should_split_partial_row 360 | row_info.append( 361 | { 362 | "should_split": should_split, 363 | "row_cells": row_cells, 364 | "line_lens": line_lens, 365 | } 366 | ) 367 | 368 | # Don't split if we're not splitting most of the rows in the table. This avoids splitting stray multiline rows. 369 | if ( 370 | sum([r["should_split"] for r in row_info]) / len(row_info) 371 | < self.row_split_threshold 372 | ): 373 | continue 374 | 375 | new_cells = [] 376 | shift_up = 0 377 | max_cell_id = max([c.cell_id for c in table.cells]) 378 | new_cell_count = 0 379 | for row, item_info in zip(unique_rows, row_info): 380 | max_lines = max(item_info["line_lens"]) 381 | if item_info["should_split"]: 382 | for i in range(0, max_lines): 383 | for cell in item_info["row_cells"]: 384 | # Calculate height based on number of splits 385 | split_height = cell.bbox[3] - cell.bbox[1] 386 | current_bbox = [ 387 | cell.bbox[0], 388 | cell.bbox[1] + i * split_height, 389 | cell.bbox[2], 390 | cell.bbox[1] + (i + 1) * split_height, 391 | ] 392 | 393 | line = ( 394 | [cell.text_lines[i]] 395 | if cell.text_lines and i < len(cell.text_lines) 396 | else None 397 | ) 398 | cell_id = max_cell_id + new_cell_count 399 | new_cells.append( 400 | SuryaTableCell( 401 | polygon=current_bbox, 402 | text_lines=line, 403 | rowspan=1, 404 | colspan=cell.colspan, 405 | row_id=cell.row_id + shift_up + i, 406 | col_id=cell.col_id, 407 | is_header=cell.is_header 408 | and i == 0, # Only first line is header 409 | within_row_id=cell.within_row_id, 410 | 
cell_id=cell_id, 411 | ) 412 | ) 413 | new_cell_count += 1 414 | 415 | # For each new row we add, shift up subsequent rows 416 | # The max is to account for partial rows 417 | shift_up += max_lines - 1 418 | else: 419 | for cell in item_info["row_cells"]: 420 | cell.row_id += shift_up 421 | new_cells.append(cell) 422 | 423 | # Only update the cells if we added new cells 424 | if len(new_cells) > len(table.cells): 425 | table.cells = new_cells 426 | 427 | def assign_text_to_cells(self, tables: List[TableResult], table_data: list): 428 | for table_result, table_page_data in zip(tables, table_data): 429 | if table_page_data["ocr_block"]: 430 | continue 431 | 432 | table_text_lines = table_page_data["table_text_lines"] 433 | table_cells: List[SuryaTableCell] = table_result.cells 434 | text_line_bboxes = [t["bbox"] for t in table_text_lines] 435 | table_cell_bboxes = [c.bbox for c in table_cells] 436 | 437 | intersection_matrix = matrix_intersection_area( 438 | text_line_bboxes, table_cell_bboxes 439 | ) 440 | 441 | cell_text = defaultdict(list) 442 | for text_line_idx, table_text_line in enumerate(table_text_lines): 443 | intersections = intersection_matrix[text_line_idx] 444 | if intersections.sum() == 0: 445 | continue 446 | 447 | max_intersection = intersections.argmax() 448 | cell_text[max_intersection].append(table_text_line) 449 | 450 | for k in cell_text: 451 | # TODO: see if the text needs to be sorted (based on rotation) 452 | text = cell_text[k] 453 | assert all("text" in t for t in text), "All text lines must have text" 454 | assert all("bbox" in t for t in text), "All text lines must have a bbox" 455 | table_cells[k].text_lines = text 456 | 457 | def assign_pdftext_lines(self, extract_blocks: list, filepath: str): 458 | table_inputs = [] 459 | unique_pages = list(set([t["page_id"] for t in extract_blocks])) 460 | if len(unique_pages) == 0: 461 | return 462 | 463 | for page in unique_pages: 464 | tables = [] 465 | img_size = None 466 | for block in 
extract_blocks: 467 | if block["page_id"] == page: 468 | tables.append(block["table_bbox"]) 469 | img_size = block["img_size"] 470 | 471 | table_inputs.append({"tables": tables, "img_size": img_size}) 472 | cell_text = table_output( 473 | filepath, 474 | table_inputs, 475 | page_range=unique_pages, 476 | workers=self.pdftext_workers, 477 | ) 478 | assert len(cell_text) == len(unique_pages), ( 479 | "Number of pages and table inputs must match" 480 | ) 481 | 482 | for pidx, (page_tables, pnum) in enumerate(zip(cell_text, unique_pages)): 483 | table_idx = 0 484 | for block in extract_blocks: 485 | if block["page_id"] == pnum: 486 | table_text = page_tables[table_idx] 487 | if len(table_text) == 0: 488 | block["ocr_block"] = ( 489 | True # Re-OCR the block if pdftext didn't find any text 490 | ) 491 | else: 492 | block["table_text_lines"] = page_tables[table_idx] 493 | table_idx += 1 494 | assert table_idx == len(page_tables), ( 495 | "Number of tables and table inputs must match" 496 | ) 497 | 498 | def align_table_cells( 499 | self, table: TableResult, table_detection_result: TextDetectionResult 500 | ): 501 | table_cells = table.cells 502 | table_text_lines = table_detection_result.bboxes 503 | 504 | text_line_bboxes = [t.bbox for t in table_text_lines] 505 | table_cell_bboxes = [c.bbox for c in table_cells] 506 | 507 | intersection_matrix = matrix_intersection_area( 508 | text_line_bboxes, table_cell_bboxes 509 | ) 510 | 511 | # Map cells -> list of assigned text lines 512 | cell_text = defaultdict(list) 513 | for text_line_idx, table_text_line in enumerate(table_text_lines): 514 | intersections = intersection_matrix[text_line_idx] 515 | if intersections.sum() == 0: 516 | continue 517 | max_intersection = intersections.argmax() 518 | cell_text[max_intersection].append(table_text_line) 519 | 520 | # Adjust cell polygons in place 521 | for cell_idx, cell in enumerate(table_cells): 522 | # all intersecting lines 523 | intersecting_line_indices = [ 524 | i for i, area 
in enumerate(intersection_matrix[:, cell_idx]) if area > 0 525 | ] 526 | if not intersecting_line_indices: 527 | continue 528 | 529 | assigned_lines = cell_text.get(cell_idx, []) 530 | # Expand to fit assigned lines - **Only in the y direction** 531 | for assigned_line in assigned_lines: 532 | x1 = cell.bbox[0] 533 | x2 = cell.bbox[2] 534 | y1 = min(cell.bbox[1], assigned_line.bbox[1]) 535 | y2 = max(cell.bbox[3], assigned_line.bbox[3]) 536 | cell.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] 537 | 538 | # Clear out non-assigned lines 539 | non_assigned_lines = [ 540 | table_text_lines[i] 541 | for i in intersecting_line_indices 542 | if table_text_lines[i] not in cell_text.get(cell_idx, []) 543 | ] 544 | if non_assigned_lines: 545 | # Find top-most and bottom-most non-assigned boxes 546 | top_box = min( 547 | non_assigned_lines, key=lambda line: line.bbox[1] 548 | ) # smallest y0 549 | bottom_box = max( 550 | non_assigned_lines, key=lambda line: line.bbox[3] 551 | ) # largest y1 552 | 553 | # Current cell bbox (from polygon) 554 | x0, y0, x1, y1 = cell.bbox 555 | 556 | # Adjust y-limits based on non-assigned boxes 557 | new_y0 = max(y0, top_box.bbox[3]) # top moves down 558 | new_y1 = min(y1, bottom_box.bbox[1]) # bottom moves up 559 | 560 | if new_y0 < new_y1: 561 | # Replace polygon with a new shrunken rectangle 562 | cell.polygon = [ 563 | [x0, new_y0], 564 | [x1, new_y0], 565 | [x1, new_y1], 566 | [x0, new_y1], 567 | ] 568 | 569 | def needs_ocr(self, tables: List[TableResult], table_blocks: List[dict]): 570 | ocr_tables = [] 571 | ocr_idxs = [] 572 | for j, (table_result, table_block) in enumerate(zip(tables, table_blocks)): 573 | table_cells: List[SuryaTableCell] = table_result.cells 574 | text_lines_need_ocr = any([tc.text_lines is None for tc in table_cells]) 575 | if ( 576 | table_block["ocr_block"] 577 | and text_lines_need_ocr 578 | and not self.disable_ocr 579 | ): 580 | logger.debug( 581 | f"Table {j} needs OCR, info table block needs ocr: 
{table_block['ocr_block']}, text_lines {text_lines_need_ocr}" 582 | ) 583 | ocr_tables.append(table_result) 584 | ocr_idxs.append(j) 585 | 586 | detection_results: List[TextDetectionResult] = self.detection_model( 587 | images=[table_blocks[i]["table_image"] for i in ocr_idxs], 588 | batch_size=self.get_detection_batch_size(), 589 | ) 590 | assert len(detection_results) == len(ocr_idxs), ( 591 | "Every OCRed table requires a text detection result" 592 | ) 593 | 594 | for idx, table_detection_result in zip(ocr_idxs, detection_results): 595 | self.align_table_cells(tables[idx], table_detection_result) 596 | 597 | ocr_polys = [] 598 | for ocr_idx in ocr_idxs: 599 | table_cells = tables[ocr_idx].cells 600 | polys = [tc for tc in table_cells if tc.text_lines is None] 601 | ocr_polys.append(polys) 602 | return ocr_tables, ocr_polys, ocr_idxs 603 | 604 | def get_ocr_results( 605 | self, table_images: List[Image.Image], ocr_polys: List[List[SuryaTableCell]] 606 | ): 607 | ocr_polys_bad = [] 608 | 609 | for table_image, polys in zip(table_images, ocr_polys): 610 | table_polys_bad = [ 611 | any( 612 | [ 613 | poly.height < 6, 614 | is_blank_image(table_image.crop(poly.bbox), poly.polygon), 615 | ] 616 | ) 617 | for poly in polys 618 | ] 619 | ocr_polys_bad.append(table_polys_bad) 620 | 621 | filtered_polys = [] 622 | for table_polys, table_polys_bad in zip(ocr_polys, ocr_polys_bad): 623 | filtered_table_polys = [] 624 | for p, is_bad in zip(table_polys, table_polys_bad): 625 | if is_bad: 626 | continue 627 | polygon = p.polygon 628 | # Round the polygon 629 | for corner in polygon: 630 | for i in range(2): 631 | corner[i] = int(corner[i]) 632 | 633 | filtered_table_polys.append(polygon) 634 | filtered_polys.append(filtered_table_polys) 635 | 636 | ocr_results = self.recognition_model( 637 | images=table_images, 638 | task_names=["ocr_with_boxes"] * len(table_images), 639 | recognition_batch_size=self.get_recognition_batch_size(), 640 | 
drop_repeated_text=self.drop_repeated_table_text, 641 | polygons=filtered_polys, 642 | filter_tag_list=self.filter_tag_list, 643 | max_tokens=2048, 644 | max_sliding_window=2148, 645 | math_mode=not self.disable_ocr_math, 646 | ) 647 | 648 | # Re-align the predictions to the original length, since we skipped some predictions 649 | for table_ocr_result, table_polys_bad in zip(ocr_results, ocr_polys_bad): 650 | updated_lines = [] 651 | idx = 0 652 | for is_bad in table_polys_bad: 653 | if is_bad: 654 | updated_lines.append( 655 | TextLine( 656 | text="", 657 | polygon=[[0, 0], [0, 0], [0, 0], [0, 0]], 658 | confidence=1, 659 | chars=[], 660 | original_text_good=False, 661 | words=None, 662 | ) 663 | ) 664 | else: 665 | updated_lines.append(table_ocr_result.text_lines[idx]) 666 | idx += 1 667 | table_ocr_result.text_lines = updated_lines 668 | 669 | return ocr_results 670 | 671 | def assign_ocr_lines(self, tables: List[TableResult], table_blocks: list): 672 | ocr_tables, ocr_polys, ocr_idxs = self.needs_ocr(tables, table_blocks) 673 | det_images = [ 674 | t["table_image"] for i, t in enumerate(table_blocks) if i in ocr_idxs 675 | ] 676 | assert len(det_images) == len(ocr_polys), ( 677 | f"Number of detection images and OCR polygons must match: {len(det_images)} != {len(ocr_polys)}" 678 | ) 679 | self.recognition_model.disable_tqdm = self.disable_tqdm 680 | ocr_results = self.get_ocr_results(table_images=det_images, ocr_polys=ocr_polys) 681 | 682 | for result, ocr_res in zip(ocr_tables, ocr_results): 683 | table_cells: List[SuryaTableCell] = result.cells 684 | cells_need_text = [tc for tc in table_cells if tc.text_lines is None] 685 | 686 | assert len(cells_need_text) == len(ocr_res.text_lines), ( 687 | "Number of cells needing text and OCR results must match" 688 | ) 689 | 690 | for cell_text, cell_needs_text in zip(ocr_res.text_lines, cells_need_text): 691 | # Don't need to correct back to image size 692 | # Table rec boxes are relative to the table 693 | 
cell_text_lines = [{"text": t} for t in cell_text.text.split("<br>")] 694 | cell_needs_text.text_lines = cell_text_lines 695 | 696 | def get_table_rec_batch_size(self): 697 | if self.table_rec_batch_size is not None: 698 | return self.table_rec_batch_size 699 | elif settings.TORCH_DEVICE_MODEL == "mps": 700 | return 6 701 | elif settings.TORCH_DEVICE_MODEL == "cuda": 702 | return 14 703 | return 6 704 | 705 | def get_recognition_batch_size(self): 706 | if self.recognition_batch_size is not None: 707 | return self.recognition_batch_size 708 | elif settings.TORCH_DEVICE_MODEL == "mps": 709 | return 32 710 | elif settings.TORCH_DEVICE_MODEL == "cuda": 711 | return 48 712 | return 32 713 | 714 | def get_detection_batch_size(self): 715 | if self.detection_batch_size is not None: 716 | return self.detection_batch_size 717 | elif settings.TORCH_DEVICE_MODEL == "cuda": 718 | return 10 719 | return 4 720 | ```