// granite-docling-258M-onnx/examples/rust_ort_example.rs
// granite-docling ONNX Rust Example with ORT crate
// Demonstrates how to use granite-docling ONNX model in Rust applications
use anyhow::Result;
use ndarray::{Array2, Array4};
use ort::{
    execution_providers::{
        CPUExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider,
        ExecutionProvider,
    },
    inputs,
    session::{builder::GraphOptimizationLevel, Session},
    value::TensorRef,
};

/// granite-docling ONNX inference engine
pub struct GraniteDoclingONNX {
    session: Session,
}

impl GraniteDoclingONNX {
    /// Load the granite-docling ONNX model
    pub fn new(model_path: &str) -> Result<Self> {
        println!("Loading granite-docling ONNX model from: {}", model_path);

        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_execution_providers([
                DirectMLExecutionProvider::default().build(), // Windows ML acceleration
                CUDAExecutionProvider::default().build(),     // NVIDIA acceleration
                CPUExecutionProvider::default().build(),      // Universal fallback
            ])?
            .commit_from_file(model_path)?;

        // Print model information
        println!("Model loaded successfully:");
        for (i, input) in session.inputs.iter().enumerate() {
            println!("  Input {}: {} {:?}", i, input.name, input.input_type);
        }
        for (i, output) in session.outputs.iter().enumerate() {
            println!("  Output {}: {} {:?}", i, output.name, output.output_type);
        }

        Ok(Self { session })
    }

    /// Process a document image to DocTags markup
    pub async fn process_document(
        &mut self,
        document_image: Array4<f32>, // [batch, channels, height, width]
        prompt: &str,
    ) -> Result<String> {
        println!("Processing document with granite-docling...");

        // Prepare text inputs (simplified tokenization)
        let input_ids = self.tokenize_prompt(prompt)?;
        let attention_mask = Array2::<i64>::ones((1, input_ids.len()));

        // Convert token IDs to the [batch, seq_len] i64 layout the model expects
        let input_ids_2d = Array2::from_shape_vec(
            (1, input_ids.len()),
            input_ids.iter().map(|&x| x as i64).collect(),
        )?;

        // Run inference
        let outputs = self.session.run(inputs![
            "pixel_values" => TensorRef::from_array_view(&document_image.view())?,
            "input_ids" => TensorRef::from_array_view(&input_ids_2d.view())?,
            "attention_mask" => TensorRef::from_array_view(&attention_mask.view())?,
        ])?;

        // Extract logits and greedily decode to text. This is a single forward
        // pass for brevity; see generate_greedy() below for an autoregressive loop.
        // (The exact tensor-extraction API differs slightly across ort 2.0 RCs.)
        let logits = outputs["logits"].try_extract_tensor::<f32>()?;
        let tokens = self.decode_logits_to_tokens(&logits)?;
        let doctags = self.detokenize_to_doctags(&tokens)?;

        println!("✅ Document processing complete");
        Ok(doctags)
    }

    /// Simple tokenization (in practice, use a proper tokenizer)
    fn tokenize_prompt(&self, prompt: &str) -> Result<Vec<u32>> {
        // Simplified placeholder: assigns sequential IDs to whitespace-split words.
        // In practice, load tokenizer.json and use real HuggingFace tokenization,
        // as sketched in tokenize_with_hf() below.
        let tokens: Vec<u32> = prompt
            .split_whitespace()
            .enumerate()
            .map(|(i, _)| (i + 1) as u32)
            .collect();
        Ok(tokens)
    }

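    /// A minimal sketch of real tokenization with the HuggingFace `tokenizers`
    /// crate. This is an assumption, not part of the original example: it
    /// requires adding `tokenizers` to Cargo.toml (see the note at the bottom
    /// of this file) and a local tokenizer.json from the model repository.
    #[allow(dead_code)]
    fn tokenize_with_hf(&self, prompt: &str, tokenizer_path: &str) -> Result<Vec<u32>> {
        use tokenizers::Tokenizer;
        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
        let encoding = tokenizer
            .encode(prompt, true) // true = add special tokens
            .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
        Ok(encoding.get_ids().to_vec())
    }
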
    /// Decode logits to the most likely token at each position (greedy argmax)
    fn decode_logits_to_tokens(&self, logits: &ndarray::ArrayViewD<f32>) -> Result<Vec<u32>> {
        // logits shape: [batch, seq_len, vocab_size]; take the argmax over the
        // vocab axis for every sequence position of the first batch element
        let tokens: Vec<u32> = logits
            .index_axis(ndarray::Axis(0), 0)
            .axis_iter(ndarray::Axis(0))
            .map(|vocab_logits| {
                vocab_logits
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx as u32)
                    .unwrap_or(0)
            })
            .collect();
        Ok(tokens)
    }

    /// Convert tokens back to DocTags markup
    fn detokenize_to_doctags(&self, tokens: &[u32]) -> Result<String> {
        // In practice, use the granite-docling tokenizer to convert tokens → text,
        // then parse the text as DocTags markup. Simplified placeholder below.
        let mock_doctags = format!(
            "<doctag>\n <text>Document processed with {} tokens</text>\n</doctag>",
            tokens.len()
        );
        Ok(mock_doctags)
    }

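    /// A minimal sketch of greedy autoregressive decoding, which a real
    /// integration would use instead of the single forward pass above. This is
    /// an assumption layered on the example: it reuses the same input/output
    /// names and extraction calls as process_document(), re-feeds the full
    /// token sequence each step, and takes `eos_token_id` / `max_new_tokens`
    /// as placeholders you must fill from the real tokenizer config.
    #[allow(dead_code)]
    pub fn generate_greedy(
        &mut self,
        document_image: &Array4<f32>,
        mut token_ids: Vec<i64>,
        eos_token_id: i64,
        max_new_tokens: usize,
    ) -> Result<Vec<i64>> {
        for _ in 0..max_new_tokens {
            let ids = Array2::from_shape_vec((1, token_ids.len()), token_ids.clone())?;
            let mask = Array2::<i64>::ones((1, token_ids.len()));
            let outputs = self.session.run(inputs![
                "pixel_values" => TensorRef::from_array_view(&document_image.view())?,
                "input_ids" => TensorRef::from_array_view(&ids.view())?,
                "attention_mask" => TensorRef::from_array_view(&mask.view())?,
            ])?;
            let logits = outputs["logits"].try_extract_tensor::<f32>()?;
            // Argmax over the vocab axis at the last sequence position
            let next_token = logits
                .index_axis(ndarray::Axis(0), 0)
                .index_axis(ndarray::Axis(0), token_ids.len() - 1)
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx as i64)
                .unwrap_or(0);
            token_ids.push(next_token);
            if next_token == eos_token_id {
                break;
            }
        }
        Ok(token_ids)
    }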
}

/// Preprocess a document image for granite-docling inference
pub fn preprocess_document_image(_image_path: &str) -> Result<Array4<f32>> {
    // Load the image, resize to 512x512 (SigLIP2 input size), normalize with
    // SigLIP2 parameters, and lay it out as [batch, channels, height, width].
    // Simplified placeholder: returns a zeroed tensor of the right shape; see
    // load_and_preprocess() below for a sketch using a real image library.
    let document_image = Array4::zeros((1, 3, 512, 512));
    Ok(document_image)
}

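/// A minimal sketch of real preprocessing with the `image` crate. This is an
/// assumption, not part of the original example: it requires adding `image` to
/// Cargo.toml (see the note at the bottom of this file) and assumes SigLIP-style
/// normalization (mean 0.5, std 0.5 per channel); verify both against the
/// model's preprocessor_config.json before relying on it.
#[allow(dead_code)]
pub fn load_and_preprocess(image_path: &str) -> Result<Array4<f32>> {
    let img = image::open(image_path)?
        .resize_exact(512, 512, image::imageops::FilterType::Triangle)
        .to_rgb8();
    let mut pixels = Array4::<f32>::zeros((1, 3, 512, 512));
    for (x, y, pixel) in img.enumerate_pixels() {
        for c in 0..3 {
            // Scale to [0, 1], then normalize with mean 0.5 / std 0.5 -> [-1, 1]
            let v = pixel[c] as f32 / 255.0;
            pixels[[0, c, y as usize, x as usize]] = (v - 0.5) / 0.5;
        }
    }
    Ok(pixels)
}
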
#[tokio::main]
async fn main() -> Result<()> {
    println!("granite-docling ONNX Rust Example");

    // Load granite-docling ONNX model
    let model_path = "granite-docling-258M-onnx/model.onnx";
    let mut granite_docling = GraniteDoclingONNX::new(model_path)?;

    // Preprocess document image
    let document_image = preprocess_document_image("example_document.png")?;

    // Process document
    let prompt = "Convert this document to DocTags:";
    let doctags = granite_docling.process_document(document_image, prompt).await?;

    println!("Generated DocTags:");
    println!("{}", doctags);
    Ok(())
}

// Cargo.toml dependencies:
/*
[dependencies]
ort = { version = "2.0.0-rc.10", features = ["directml", "cuda", "tensorrt"] }
ndarray = "0.15"
anyhow = "1.0"
tokio = { version = "1.0", features = ["full"] }
*/
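
// Extras assumed by the tokenizer and image-preprocessing sketches above.
// These are hypothetical additions, not dependencies of the original example;
// pick versions compatible with your toolchain, and note that recent ort 2.0
// release candidates track ndarray 0.16 rather than 0.15:
/*
tokenizers = "0.20"
image = "0.25"
*/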