|
|
import copy |
|
|
import warnings |
|
|
from pathlib import Path |
|
|
from typing import Iterable, Optional, Union |
|
|
|
|
|
import numpy |
|
|
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell |
|
|
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor |
|
|
from PIL import ImageDraw |
|
|
|
|
|
from docling.datamodel.base_models import Page, Table, TableStructurePrediction |
|
|
from docling.datamodel.document import ConversionResult |
|
|
from docling.datamodel.pipeline_options import ( |
|
|
AcceleratorDevice, |
|
|
AcceleratorOptions, |
|
|
TableFormerMode, |
|
|
TableStructureOptions, |
|
|
) |
|
|
from docling.datamodel.settings import settings |
|
|
from docling.models.base_model import BasePageModel |
|
|
from docling.utils.accelerator_utils import decide_device |
|
|
from docling.utils.profiling import TimeRecorder |
|
|
|
|
|
|
|
|
class TableStructureModel(BasePageModel):
    """Page model that predicts table structure with TableFormer.

    For every layout cluster labelled ``TABLE`` or ``DOCUMENT_INDEX`` the
    model runs a ``TFPredictor`` on an upscaled page image and stores the
    resulting ``Table`` in ``page.predictions.tablestructure.table_map``,
    keyed by the cluster id.
    """

    # Folder name looked up inside a user-supplied artifacts directory.
    _model_repo_folder = "ds4sd--docling-models"
    # Location of the TableFormer artifacts relative to the repository root.
    _model_path = "model_artifacts/tableformer"

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: TableStructureOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """Configure the model and, when enabled, load the predictor.

        Args:
            enabled: When False no artifacts are resolved or loaded and
                ``__call__`` passes pages through untouched.
            artifacts_path: Directory holding the model artifacts. ``None``
                triggers ``download_models()``.
            options: Supplies ``do_cell_matching`` and the TableFormer ``mode``.
            accelerator_options: Supplies the requested device and the number
                of threads handed to the predictor.
        """
        self.options = options
        self.do_cell_matching = self.options.do_cell_matching
        self.mode = self.options.mode

        self.enabled = enabled
        if self.enabled:
            if artifacts_path is None:
                artifacts_path = self.download_models() / self._model_path
            else:
                # Preferred layout: artifacts_path is the parent of the
                # repository folder.
                if (artifacts_path / self._model_repo_folder).exists():
                    artifacts_path = (
                        artifacts_path / self._model_repo_folder / self._model_path
                    )
                # Legacy layout: artifacts_path directly contains _model_path.
                elif (artifacts_path / self._model_path).exists():
                    warnings.warn(
                        "The usage of artifacts_path containing directly "
                        f"{self._model_path} is deprecated. Please point "
                        "the artifacts_path to the parent containing "
                        f"the {self._model_repo_folder} folder.",
                        DeprecationWarning,
                        stacklevel=3,
                    )
                    artifacts_path = artifacts_path / self._model_path
                # Otherwise artifacts_path is used as given.

            # Each mode has its own sub-directory of artifacts.
            if self.mode == TableFormerMode.ACCURATE:
                artifacts_path = artifacts_path / "accurate"
            else:
                artifacts_path = artifacts_path / "fast"

            # Imported here so the dependency is only touched when the model
            # is actually enabled.
            import docling_ibm_models.tableformer.common as c

            device = decide_device(accelerator_options.device)

            # MPS is never used for this model; it is replaced by CPU.
            if device == AcceleratorDevice.MPS.value:
                device = AcceleratorDevice.CPU.value

            self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
            self.tm_config["model"]["save_dir"] = artifacts_path
            self.tm_model_type = self.tm_config["model"]["type"]

            self.tf_predictor = TFPredictor(
                self.tm_config, device, accelerator_options.num_threads
            )
            # Factor applied to the page image, page size and all bounding
            # boxes before they are handed to the predictor.
            self.scale = 2.0

    @staticmethod
    def download_models(
        local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
    ) -> Path:
        """Download the ``ds4sd/docling-models`` snapshot (revision ``v2.1.0``).

        Args:
            local_dir: Forwarded to ``snapshot_download`` as ``local_dir``.
            force: Forwarded to ``snapshot_download`` as ``force_download``.
            progress: When False, ``disable_progress_bars()`` is called before
                downloading.

        Returns:
            The path returned by ``snapshot_download``, as a ``Path``.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        if not progress:
            disable_progress_bars()
        download_path = snapshot_download(
            repo_id="ds4sd/docling-models",
            force_download=force,
            local_dir=local_dir,
            revision="v2.1.0",
        )

        return Path(download_path)

    def draw_table_and_cells(
        self,
        conv_res: ConversionResult,
        page: Page,
        tbl_list: Iterable[Table],
        show: bool = False,
    ):
        """Render a debug overlay of tables and their cells on the page image.

        Table clusters are outlined in red, the cells of each cluster in
        green, and predicted table cells in blue (thicker outline for column
        headers) together with their ``row, col`` start offsets.

        Args:
            conv_res: Conversion result; its input file stem names the output
                directory when the image is saved.
            page: Page whose backend provides the image to draw on.
            tbl_list: Tables to visualise.
            show: When True the image is displayed via ``image.show()``;
                otherwise it is written as PNG below
                ``settings.debug.debug_output_path``.
        """
        assert page._backend is not None
        assert page.size is not None

        image = page._backend.get_page_image()

        # Page coordinates are converted to image pixels per axis.
        scale_x = image.width / page.size.width
        scale_y = image.height / page.size.height

        def to_pixels(bbox):
            """Return the two corner points of *bbox* in image pixels."""
            x0, y0, x1, y1 = bbox.as_tuple()
            # Both y values use scale_y; previously y0 was scaled with
            # scale_x, misplacing the top edge for non-uniform scaling.
            return (x0 * scale_x, y0 * scale_y), (x1 * scale_x, y1 * scale_y)

        draw = ImageDraw.Draw(image)

        for table_element in tbl_list:
            top_left, bottom_right = to_pixels(table_element.cluster.bbox)
            draw.rectangle([top_left, bottom_right], outline="red")

            for cell in table_element.cluster.cells:
                top_left, bottom_right = to_pixels(cell.bbox)
                draw.rectangle([top_left, bottom_right], outline="green")

            for tc in table_element.table_cells:
                if tc.bbox is not None:
                    (x0, y0), (x1, y1) = to_pixels(tc.bbox)

                    # Column headers get a thicker outline.
                    width = 3 if tc.column_header else 1
                    draw.rectangle([(x0, y0), (x1, y1)], outline="blue", width=width)
                    draw.text(
                        (x0 + 3, y0 + 3),
                        text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
                        fill="black",
                    )
        if show:
            image.show()
        else:
            out_path: Path = (
                Path(settings.debug.debug_output_path)
                / f"debug_{conv_res.input.file.stem}"
            )
            out_path.mkdir(parents=True, exist_ok=True)

            out_file = out_path / f"table_struct_page_{page.page_no:05}.png"
            image.save(str(out_file), format="png")

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Predict table structures for each page and yield the pages.

        Pages are yielded unchanged when the model is disabled or the page
        backend is invalid. Otherwise ``page.predictions.tablestructure`` is
        replaced by a fresh ``TableStructurePrediction`` that is filled with
        one ``Table`` per TABLE / DOCUMENT_INDEX layout cluster.

        Args:
            conv_res: Conversion result used for timing and debug output.
            page_batch: Pages to process; layout predictions must be present
                on every page with a valid backend.

        Yields:
            Every page of ``page_batch``, in order.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
            else:
                with TimeRecorder(conv_res, "table_structure"):
                    assert page.predictions.layout is not None
                    assert page.size is not None

                    # Start from an empty prediction so pages without tables
                    # still carry a tablestructure object.
                    page.predictions.tablestructure = TableStructurePrediction()

                    # Pair each table-like cluster with its rounded, upscaled
                    # bounding box in [l, t, r, b] order.
                    in_tables = [
                        (
                            cluster,
                            [
                                round(cluster.bbox.l) * self.scale,
                                round(cluster.bbox.t) * self.scale,
                                round(cluster.bbox.r) * self.scale,
                                round(cluster.bbox.b) * self.scale,
                            ],
                        )
                        for cluster in page.predictions.layout.clusters
                        if cluster.label
                        in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
                    ]
                    if not in_tables:
                        yield page
                        continue

                    page_input = {
                        "width": page.size.width * self.scale,
                        "height": page.size.height * self.scale,
                        "image": numpy.asarray(page.get_image(scale=self.scale)),
                    }

                    # Tables are predicted one at a time; the tokens entry of
                    # page_input is replaced for every table.
                    for table_cluster, tbl_box in in_tables:
                        tokens = []
                        for cell in table_cluster.cells:
                            # Whitespace-only cells are not passed on as tokens.
                            if cell.text.strip():
                                # Copy so the cluster's own cell keeps its
                                # unscaled bbox.
                                new_cell = copy.deepcopy(cell)
                                new_cell.bbox = new_cell.bbox.scaled(
                                    scale=self.scale
                                )
                                tokens.append(new_cell.model_dump())
                        page_input["tokens"] = tokens

                        tf_output = self.tf_predictor.multi_table_predict(
                            page_input, [tbl_box], do_matching=self.do_cell_matching
                        )
                        table_out = tf_output[0]
                        table_cells = []
                        for element in table_out["tf_responses"]:
                            if not self.do_cell_matching:
                                # Without matching, the text is fetched from
                                # the backend using the bbox scaled back down.
                                the_bbox = BoundingBox.model_validate(
                                    element["bbox"]
                                ).scaled(1 / self.scale)
                                text_piece = page._backend.get_text_in_rect(the_bbox)
                                element["bbox"]["token"] = text_piece

                            tc = TableCell.model_validate(element)
                            # NOTE(review): the bbox is only scaled back when
                            # cell matching is on; without matching it appears
                            # to stay in the upscaled space - confirm intended.
                            if self.do_cell_matching and tc.bbox is not None:
                                tc.bbox = tc.bbox.scaled(1 / self.scale)
                            table_cells.append(tc)

                        assert "predict_details" in table_out

                        # Row/column counts and the structure sequence default
                        # to 0 / [] when absent from the prediction details.
                        num_rows = table_out["predict_details"].get("num_rows", 0)
                        num_cols = table_out["predict_details"].get("num_cols", 0)
                        otsl_seq = (
                            table_out["predict_details"]
                            .get("prediction", {})
                            .get("rs_seq", [])
                        )

                        tbl = Table(
                            otsl_seq=otsl_seq,
                            table_cells=table_cells,
                            num_rows=num_rows,
                            num_cols=num_cols,
                            id=table_cluster.id,
                            page_no=page.page_no,
                            cluster=table_cluster,
                            label=table_cluster.label,
                        )

                        page.predictions.tablestructure.table_map[
                            table_cluster.id
                        ] = tbl

                    # Optional debug rendering of the predicted tables.
                    if settings.debug.visualize_tables:
                        self.draw_table_and_cells(
                            conv_res,
                            page,
                            page.predictions.tablestructure.table_map.values(),
                        )

                yield page
|
|
|