|
|
|
|
|
""" |
|
|
granite-docling ONNX usage example with ONNX Runtime.

Demonstrates how to use the converted granite-docling model for document processing.
|
|
""" |
|
|
|
|
|
import onnxruntime as ort |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import json |
|
|
|
|
|
def load_granite_docling_onnx(model_path: str):
    """Create an ONNX Runtime session for the granite-docling model.

    Prints the name, shape and element type of every graph input and
    output so the tensors the session expects and produces are visible.

    Args:
        model_path: Path of the ONNX file passed to ``ort.InferenceSession``.

    Returns:
        The ``ort.InferenceSession`` built from ``model_path``.
    """

    def _report(heading, list_nodes):
        # Print the heading before querying the session so the output
        # order is heading first, then one line per node.
        print(heading)
        for node in list_nodes():
            print(f" {node.name}: {node.shape} ({node.type})")

    print(f"Loading granite-docling ONNX model from: {model_path}")
    runtime_session = ort.InferenceSession(model_path)

    print("Model Information:")
    _report(" Inputs:", runtime_session.get_inputs)
    _report(" Outputs:", runtime_session.get_outputs)

    return runtime_session
|
|
|
|
|
def preprocess_document_image(image_path: str) -> np.ndarray:
    """Load a document image and turn it into a normalized float32 batch.

    The image is converted to RGB, resized to 512x512, scaled to [0, 1],
    normalized per channel with the mean/std constants below, and
    rearranged from height-width-channel order into a single-item
    channel-first batch.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        A ``float32`` array of shape ``(1, 3, 512, 512)``.
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied into the converted, resized image.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB').resize((512, 512))

    pixel_values = np.array(image).astype(np.float32) / 255.0

    # Keep the constants in float32. float64 constants would promote the
    # whole tensor to float64, and ONNX Runtime rejects a double tensor for
    # an input declared as tensor(float).
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    pixel_values = (pixel_values - mean) / std

    # (H, W, C) -> (C, H, W), then add the leading batch axis.
    pixel_values = pixel_values.transpose(2, 0, 1)
    pixel_values = pixel_values[np.newaxis, :]

    return pixel_values
|
|
|
|
|
def create_text_inputs(prompt: str = "Convert this document to DocTags:") -> tuple:
    """Build stand-in ``input_ids`` and ``attention_mask`` arrays for a prompt.

    No tokenizer is involved: the ids are ``1``, then consecutive integers
    starting at ``2`` (one per whitespace-separated word of ``prompt``),
    then a closing ``2``.

    Args:
        prompt: Text whose word count determines the sequence length.

    Returns:
        ``(input_ids, attention_mask)``: two int64 arrays of shape
        ``(1, word_count + 2)``; the mask is all ones.
    """
    word_count = len(prompt.split())
    token_ids = [1, *range(2, word_count + 2), 2]

    # A single-row batch; the mask marks every position as attended.
    input_ids = np.array(token_ids, dtype=np.int64).reshape(1, -1)
    attention_mask = np.ones_like(input_ids)

    return input_ids, attention_mask
|
|
|
|
|
def run_granite_docling_inference(session, image_path: str):
    """Run one forward pass of the model on a document image.

    Builds the image tensor and the default text inputs, feeds them to
    ``session.run`` and reduces the first output to per-position argmax
    indices. Progress and tensor shapes are printed along the way.

    Args:
        session: Object exposing ``run(output_names, feeds)``, e.g. the
            session returned by ``load_granite_docling_onnx``.
        image_path: Path of the document image to process.

    Returns:
        ``np.argmax`` of the first output along its last axis.
    """
    print(f"Processing document: {image_path}")

    image_tensor = preprocess_document_image(image_path)
    token_ids, mask = create_text_inputs()

    # Feed names this example assumes the exported graph declares.
    feeds = {
        'pixel_values': image_tensor,
        'input_ids': token_ids,
        'attention_mask': mask,
    }

    print("Input shapes:")
    for feed_name, tensor in feeds.items():
        print(f" {feed_name}: {tensor.shape}")

    # None as the output list: per the ONNX Runtime API this fetches all
    # outputs; only the first one is used below.
    outputs = session.run(None, feeds)
    logits = outputs[0]
    print(f"Output logits shape: {logits.shape}")

    # Highest-scoring index along the last axis for every position.
    predicted_tokens = np.argmax(logits, axis=-1)
    print(f"Predicted tokens shape: {predicted_tokens.shape}")

    print("✅ Inference completed successfully")

    return predicted_tokens
|
|
|
|
|
def main():
    """Load ``model.onnx`` and run inference on ``example_document.png``.

    If the example image does not exist, a hint is printed instead of
    running inference. Any exception is reported with a message on stdout
    and its traceback on stderr rather than being propagated.
    """
    # Imported locally so main() also works when this module is imported
    # and main() is called directly, not only when run as a script.
    import os
    import traceback

    model_path = "model.onnx"
    image_path = "example_document.png"

    try:
        session = load_granite_docling_onnx(model_path)

        if os.path.exists(image_path):
            run_granite_docling_inference(session, image_path)
            print("✅ granite-docling ONNX inference successful!")
        else:
            print("⚠️ No example document provided")
            print(" Create a test document image to run inference")
    except Exception as e:
        # Top-level boundary of the example: report the failure in full
        # instead of letting the script die with a bare traceback.
        print(f"❌ Example failed: {e}")
        traceback.print_exc()
|
|
|
|
|
# Script entry point: the example runs only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # Binds `os` as a module-level name before main() is called.
    import os

    main()