|
|
import logging |
|
|
import re |
|
|
from typing import Iterable, List |
|
|
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
from docling.datamodel.base_models import ( |
|
|
AssembledUnit, |
|
|
ContainerElement, |
|
|
FigureElement, |
|
|
Page, |
|
|
PageElement, |
|
|
Table, |
|
|
TextElement, |
|
|
) |
|
|
from docling.datamodel.document import ConversionResult |
|
|
from docling.models.base_model import BasePageModel |
|
|
from docling.models.layout_model import LayoutModel |
|
|
from docling.utils.profiling import TimeRecorder |
|
|
|
|
|
# Module-level logger named after this module.
_log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class PageAssembleOptions(BaseModel):
    """Options accepted by ``PageAssembleModel``.

    The model currently declares no fields; it exists so the page-assembly
    step receives a dedicated options object.
    """
|
|
|
|
|
|
|
|
class PageAssembleModel(BasePageModel):
    """Turn the layout clusters of each page into typed page elements.

    For every page with a valid backend, each predicted layout cluster is
    converted into a ``TextElement``, ``Table``, ``FigureElement`` or
    ``ContainerElement`` (depending on the cluster label) and the result is
    stored on ``page.assembled`` as an ``AssembledUnit``.
    """

    def __init__(self, options: PageAssembleOptions):
        """Keep the given options on ``self.options``.

        Args:
            options: Options object for this model.
        """
        self.options = options

    def sanitize_text(self, lines: List[str]) -> str:
        """Join text lines into a single string, merging hyphenated breaks.

        Consecutive lines are normally joined with a single space. When a
        line ends with ``-`` and both the last word of that line and the
        first word of the following line are alphanumeric, the trailing
        hyphen is dropped and the two lines are glued together without a
        space. A line ending with ``-`` that does not meet this condition is
        concatenated with the next line as-is (hyphen kept, no space added).

        The input list is not modified.

        Args:
            lines: Text lines in reading order.

        Returns:
            The joined text. With two or more lines the result is stripped
            of leading/trailing whitespace; with zero or one line the input
            is returned unchanged (``""`` for an empty list).
        """
        if len(lines) <= 1:
            return " ".join(lines)

        # Compile once instead of re-resolving the pattern for every line.
        word_pattern = re.compile(r"\b[\w]+\b")

        # Build the output in a fresh list so the caller's list is untouched.
        pieces: List[str] = []
        for prev_line, line in zip(lines, lines[1:]):
            if not prev_line.endswith("-"):
                # Ordinary line break: separate the lines with one space.
                pieces.append(prev_line + " ")
                continue

            prev_words = word_pattern.findall(prev_line)
            line_words = word_pattern.findall(line)

            if (
                prev_words
                and line_words
                and prev_words[-1].isalnum()
                and line_words[0].isalnum()
            ):
                # Word split across lines: drop the trailing hyphen.
                pieces.append(prev_line[:-1])
            else:
                # Hyphen is kept and the next line follows directly.
                pieces.append(prev_line)

        # The last line never has a successor, so it is appended unchanged.
        pieces.append(lines[-1])

        return "".join(pieces).strip()

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Assemble page elements for every page in ``page_batch``.

        Pages whose backend reports itself invalid are yielded unchanged.
        For all other pages, ``page.assembled`` is set to an
        ``AssembledUnit`` holding:

        * ``elements``: every element created, in cluster order;
        * ``headers``: text elements whose label is in
          ``LayoutModel.PAGE_HEADER_LABELS``;
        * ``body``: all remaining elements.

        Clusters whose label matches none of the handled label groups are
        skipped.

        Args:
            conv_res: Conversion result passed to ``TimeRecorder`` under the
                key ``"page_assemble"``.
            page_batch: Pages to process.

        Yields:
            Each input page, in order, after processing.
        """
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                # Nothing to assemble for an invalid backend; pass it through.
                yield page
                continue

            with TimeRecorder(conv_res, "page_assemble"):
                assert page.predictions.layout is not None

                elements: List[PageElement] = []
                headers: List[PageElement] = []
                body: List[PageElement] = []

                for cluster in page.predictions.layout.clusters:
                    if cluster.label in LayoutModel.TEXT_ELEM_LABELS:
                        # Skip whitespace-only cells; "\x02" is replaced by
                        # "-" so sanitize_text can treat it as a hyphen
                        # (presumably a backend hyphenation marker -- confirm).
                        textlines = [
                            cell.text.replace("\x02", "-").strip()
                            for cell in cluster.cells
                            if cell.text.strip()
                        ]
                        text_el = TextElement(
                            label=cluster.label,
                            id=cluster.id,
                            text=self.sanitize_text(textlines),
                            page_no=page.page_no,
                            cluster=cluster,
                        )
                        elements.append(text_el)

                        if cluster.label in LayoutModel.PAGE_HEADER_LABELS:
                            headers.append(text_el)
                        else:
                            body.append(text_el)

                    elif cluster.label in LayoutModel.TABLE_LABELS:
                        tbl = None
                        if page.predictions.tablestructure:
                            tbl = page.predictions.tablestructure.table_map.get(
                                cluster.id, None
                            )
                        if not tbl:
                            # Fallback: no table-structure prediction for
                            # this cluster, so emit a table without structure.
                            tbl = Table(
                                label=cluster.label,
                                id=cluster.id,
                                text="",
                                otsl_seq=[],
                                table_cells=[],
                                cluster=cluster,
                                page_no=page.page_no,
                            )

                        elements.append(tbl)
                        body.append(tbl)

                    elif cluster.label == LayoutModel.FIGURE_LABEL:
                        fig = None
                        if page.predictions.figures_classification:
                            fig = page.predictions.figures_classification.figure_map.get(
                                cluster.id, None
                            )
                        if not fig:
                            # Fallback: no figure-classification prediction
                            # for this cluster, so emit a figure without data.
                            fig = FigureElement(
                                label=cluster.label,
                                id=cluster.id,
                                text="",
                                data=None,
                                cluster=cluster,
                                page_no=page.page_no,
                            )

                        elements.append(fig)
                        body.append(fig)

                    elif cluster.label in LayoutModel.CONTAINER_LABELS:
                        container_el = ContainerElement(
                            label=cluster.label,
                            id=cluster.id,
                            page_no=page.page_no,
                            cluster=cluster,
                        )
                        elements.append(container_el)
                        body.append(container_el)

                page.assembled = AssembledUnit(
                    elements=elements, headers=headers, body=body
                )

            yield page
|
|
|