|
|
import copy |
|
|
import logging |
|
|
import warnings |
|
|
from pathlib import Path |
|
|
from typing import Iterable, Optional, Union |
|
|
|
|
|
from docling_core.types.doc import DocItemLabel |
|
|
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor |
|
|
from PIL import Image |
|
|
|
|
|
from docling.datamodel.base_models import BoundingBox, Cluster, LayoutPrediction, Page |
|
|
from docling.datamodel.document import ConversionResult |
|
|
from docling.datamodel.pipeline_options import AcceleratorOptions |
|
|
from docling.datamodel.settings import settings |
|
|
from docling.models.base_model import BasePageModel |
|
|
from docling.utils.accelerator_utils import decide_device |
|
|
from docling.utils.layout_postprocessor import LayoutPostprocessor |
|
|
from docling.utils.profiling import TimeRecorder |
|
|
from docling.utils.visualization import draw_clusters |
|
|
|
|
|
_log = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class LayoutModel(BasePageModel): |
|
|
_model_repo_folder = "ds4sd--docling-models" |
|
|
_model_path = "model_artifacts/layout" |
|
|
|
|
|
TEXT_ELEM_LABELS = [ |
|
|
DocItemLabel.TEXT, |
|
|
DocItemLabel.FOOTNOTE, |
|
|
DocItemLabel.CAPTION, |
|
|
DocItemLabel.CHECKBOX_UNSELECTED, |
|
|
DocItemLabel.CHECKBOX_SELECTED, |
|
|
DocItemLabel.SECTION_HEADER, |
|
|
DocItemLabel.PAGE_HEADER, |
|
|
DocItemLabel.PAGE_FOOTER, |
|
|
DocItemLabel.CODE, |
|
|
DocItemLabel.LIST_ITEM, |
|
|
DocItemLabel.FORMULA, |
|
|
] |
|
|
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER] |
|
|
|
|
|
TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX] |
|
|
FIGURE_LABEL = DocItemLabel.PICTURE |
|
|
FORMULA_LABEL = DocItemLabel.FORMULA |
|
|
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION] |
|
|
|
|
|
def __init__( |
|
|
self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions |
|
|
): |
|
|
device = decide_device(accelerator_options.device) |
|
|
|
|
|
if artifacts_path is None: |
|
|
artifacts_path = self.download_models() / self._model_path |
|
|
else: |
|
|
|
|
|
if (artifacts_path / self._model_repo_folder).exists(): |
|
|
artifacts_path = ( |
|
|
artifacts_path / self._model_repo_folder / self._model_path |
|
|
) |
|
|
elif (artifacts_path / self._model_path).exists(): |
|
|
warnings.warn( |
|
|
"The usage of artifacts_path containing directly " |
|
|
f"{self._model_path} is deprecated. Please point " |
|
|
"the artifacts_path to the parent containing " |
|
|
f"the {self._model_repo_folder} folder.", |
|
|
DeprecationWarning, |
|
|
stacklevel=3, |
|
|
) |
|
|
artifacts_path = artifacts_path / self._model_path |
|
|
|
|
|
self.layout_predictor = LayoutPredictor( |
|
|
artifact_path=str(artifacts_path), |
|
|
device=device, |
|
|
num_threads=accelerator_options.num_threads, |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def download_models( |
|
|
local_dir: Optional[Path] = None, |
|
|
force: bool = False, |
|
|
progress: bool = False, |
|
|
) -> Path: |
|
|
from huggingface_hub import snapshot_download |
|
|
from huggingface_hub.utils import disable_progress_bars |
|
|
|
|
|
if not progress: |
|
|
disable_progress_bars() |
|
|
download_path = snapshot_download( |
|
|
repo_id="ds4sd/docling-models", |
|
|
force_download=force, |
|
|
local_dir=local_dir, |
|
|
revision="v2.1.0", |
|
|
) |
|
|
|
|
|
return Path(download_path) |
|
|
|
|
|
def draw_clusters_and_cells_side_by_side( |
|
|
self, conv_res, page, clusters, mode_prefix: str, show: bool = False |
|
|
): |
|
|
""" |
|
|
Draws a page image side by side with clusters filtered into two categories: |
|
|
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE. |
|
|
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE. |
|
|
Includes label names and confidence scores for each cluster. |
|
|
""" |
|
|
scale_x = page.image.width / page.size.width |
|
|
scale_y = page.image.height / page.size.height |
|
|
|
|
|
|
|
|
exclude_labels = { |
|
|
DocItemLabel.FORM, |
|
|
DocItemLabel.KEY_VALUE_REGION, |
|
|
DocItemLabel.PICTURE, |
|
|
} |
|
|
left_clusters = [c for c in clusters if c.label not in exclude_labels] |
|
|
right_clusters = [c for c in clusters if c.label in exclude_labels] |
|
|
|
|
|
left_image = copy.deepcopy(page.image) |
|
|
right_image = copy.deepcopy(page.image) |
|
|
|
|
|
|
|
|
draw_clusters(left_image, left_clusters, scale_x, scale_y) |
|
|
draw_clusters(right_image, right_clusters, scale_x, scale_y) |
|
|
|
|
|
combined_width = left_image.width * 2 |
|
|
combined_height = left_image.height |
|
|
combined_image = Image.new("RGB", (combined_width, combined_height)) |
|
|
combined_image.paste(left_image, (0, 0)) |
|
|
combined_image.paste(right_image, (left_image.width, 0)) |
|
|
if show: |
|
|
combined_image.show() |
|
|
else: |
|
|
out_path: Path = ( |
|
|
Path(settings.debug.debug_output_path) |
|
|
/ f"debug_{conv_res.input.file.stem}" |
|
|
) |
|
|
out_path.mkdir(parents=True, exist_ok=True) |
|
|
out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png" |
|
|
combined_image.save(str(out_file), format="png") |
|
|
|
|
|
def __call__( |
|
|
self, conv_res: ConversionResult, page_batch: Iterable[Page] |
|
|
) -> Iterable[Page]: |
|
|
|
|
|
for page in page_batch: |
|
|
assert page._backend is not None |
|
|
if not page._backend.is_valid(): |
|
|
yield page |
|
|
else: |
|
|
with TimeRecorder(conv_res, "layout"): |
|
|
assert page.size is not None |
|
|
page_image = page.get_image(scale=1.0) |
|
|
assert page_image is not None |
|
|
|
|
|
clusters = [] |
|
|
for ix, pred_item in enumerate( |
|
|
self.layout_predictor.predict(page_image) |
|
|
): |
|
|
label = DocItemLabel( |
|
|
pred_item["label"] |
|
|
.lower() |
|
|
.replace(" ", "_") |
|
|
.replace("-", "_") |
|
|
) |
|
|
cluster = Cluster( |
|
|
id=ix, |
|
|
label=label, |
|
|
confidence=pred_item["confidence"], |
|
|
bbox=BoundingBox.model_validate(pred_item), |
|
|
cells=[], |
|
|
) |
|
|
clusters.append(cluster) |
|
|
|
|
|
if settings.debug.visualize_raw_layout: |
|
|
self.draw_clusters_and_cells_side_by_side( |
|
|
conv_res, page, clusters, mode_prefix="raw" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
processed_clusters, processed_cells = LayoutPostprocessor( |
|
|
page.cells, clusters, page.size |
|
|
).postprocess() |
|
|
|
|
|
|
|
|
page.cells = processed_cells |
|
|
page.predictions.layout = LayoutPrediction( |
|
|
clusters=processed_clusters |
|
|
) |
|
|
|
|
|
if settings.debug.visualize_layout: |
|
|
self.draw_clusters_and_cells_side_by_side( |
|
|
conv_res, page, processed_clusters, mode_prefix="postprocessed" |
|
|
) |
|
|
|
|
|
yield page |
|
|
|