|
|
from pathlib import Path |
|
|
from typing import Iterable, List, Literal, Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
from docling_core.types.doc import ( |
|
|
DoclingDocument, |
|
|
NodeItem, |
|
|
PictureClassificationClass, |
|
|
PictureClassificationData, |
|
|
PictureItem, |
|
|
) |
|
|
from PIL import Image |
|
|
from pydantic import BaseModel |
|
|
|
|
|
from docling.datamodel.pipeline_options import AcceleratorOptions |
|
|
from docling.models.base_model import BaseEnrichmentModel |
|
|
from docling.utils.accelerator_utils import decide_device |
|
|
|
|
|
|
|
|
class DocumentPictureClassifierOptions(BaseModel):
    """
    Configuration model for the DocumentPictureClassifier.

    Attributes
    ----------
    kind : Literal["document_picture_classifier"]
        Fixed tag naming this classifier type. The only accepted value,
        and the default, is ``"document_picture_classifier"``.
    """

    kind: Literal["document_picture_classifier"] = "document_picture_classifier"
|
|
|
|
|
|
|
|
class DocumentPictureClassifier(BaseEnrichmentModel):
    """
    A model for classifying pictures in documents.

    This class enriches document pictures with predicted classifications
    based on a predefined set of classes.

    Attributes
    ----------
    enabled : bool
        Whether the classifier is enabled for use.
    options : DocumentPictureClassifierOptions
        Configuration options for the classifier.
    document_picture_classifier : DocumentFigureClassifierPredictor
        The underlying prediction model. Only set when the classifier is
        enabled; a disabled instance never loads (or downloads) the model.

    Methods
    -------
    __init__(enabled, artifacts_path, options, accelerator_options)
        Initializes the classifier with specified configurations.
    download_models(local_dir, force, progress)
        Downloads the model artifacts and returns their local path.
    is_processable(doc, element)
        Checks if the given element can be processed by the classifier.
    __call__(doc, element_batch)
        Processes a batch of elements and adds classification annotations.
    """

    _model_repo_folder = "ds4sd--DocumentFigureClassifier"
    images_scale = 2

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Union[Path, str]],
        options: DocumentPictureClassifierOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Initializes the DocumentPictureClassifier.

        Parameters
        ----------
        enabled : bool
            Indicates whether the classifier is enabled.
        artifacts_path : Optional[Union[Path, str]]
            Directory that contains the ``ds4sd--DocumentFigureClassifier``
            model folder. If None, the artifacts are fetched through
            ``download_models()``.
        options : DocumentPictureClassifierOptions
            Configuration options for the classifier.
        accelerator_options : AcceleratorOptions
            Options for configuring the device and parallelism.
        """
        self.enabled = enabled
        self.options = options

        if self.enabled:
            device = decide_device(accelerator_options.device)
            # Imported lazily so the model package is only required when the
            # classifier is actually enabled.
            from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
                DocumentFigureClassifierPredictor,
            )

            if artifacts_path is None:
                artifacts_path = self.download_models()
            else:
                # Coerce to Path so that a plain string (as the documented
                # parameter type allows) does not fail on the "/" operator.
                artifacts_path = Path(artifacts_path) / self._model_repo_folder

            self.document_picture_classifier = DocumentFigureClassifierPredictor(
                artifacts_path=str(artifacts_path),
                device=device,
                num_threads=accelerator_options.num_threads,
            )

    @staticmethod
    def download_models(
        local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
    ) -> Path:
        """
        Downloads the picture classifier artifacts from the Hugging Face Hub.

        Fetches revision ``v1.0.0`` of the ``ds4sd/DocumentFigureClassifier``
        repository via ``huggingface_hub.snapshot_download``.

        Parameters
        ----------
        local_dir : Optional[Path]
            Forwarded as ``local_dir`` to ``snapshot_download``.
        force : bool
            Forwarded as ``force_download`` to ``snapshot_download``.
        progress : bool
            If False, ``disable_progress_bars()`` is called before downloading.

        Returns
        -------
        Path
            The path returned by ``snapshot_download``.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        if not progress:
            disable_progress_bars()
        download_path = snapshot_download(
            repo_id="ds4sd/DocumentFigureClassifier",
            force_download=force,
            local_dir=local_dir,
            revision="v1.0.0",
        )

        return Path(download_path)

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Determines if the given element can be processed by the classifier.

        Parameters
        ----------
        doc : DoclingDocument
            The document containing the element.
        element : NodeItem
            The element to be checked.

        Returns
        -------
        bool
            True if the element is a PictureItem and processing is enabled; False otherwise.
        """
        return self.enabled and isinstance(element, PictureItem)

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[NodeItem],
    ) -> Iterable[NodeItem]:
        """
        Processes a batch of elements and enriches them with classification predictions.

        Parameters
        ----------
        doc : DoclingDocument
            The document containing the elements to be processed.
        element_batch : Iterable[NodeItem]
            A batch of pictures to classify.

        Returns
        -------
        Iterable[NodeItem]
            The elements of the batch, yielded in order. When the classifier
            is enabled, a PictureClassificationData entry holding the
            predicted classes is appended to each picture's ``annotations``
            before it is yielded. When disabled, elements pass through
            unchanged.
        """
        if not self.enabled:
            for element in element_batch:
                yield element
            return

        images: List[Union[Image.Image, np.ndarray]] = []
        elements: List[PictureItem] = []
        for el in element_batch:
            # The batch is expected to contain only pictures whose image can
            # be resolved from the document.
            assert isinstance(el, PictureItem)
            elements.append(el)
            img = el.get_image(doc)
            assert img is not None
            images.append(img)

        # Nothing to classify: avoid invoking the predictor on an empty list.
        if not elements:
            return

        outputs = self.document_picture_classifier.predict(images)

        for element, output in zip(elements, outputs):
            element.annotations.append(
                PictureClassificationData(
                    provenance="DocumentPictureClassifier",
                    predicted_classes=[
                        # Each prediction is indexed as (class_name, confidence).
                        PictureClassificationClass(
                            class_name=pred[0],
                            confidence=pred[1],
                        )
                        for pred in output
                    ],
                )
            )

            yield element
|
|
|