use anyhow::Result;
use ndarray::{Array2, Array4};
use ort::{
    execution_providers::{CPUExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider},
    inputs,
    session::{builder::GraphOptimizationLevel, Session},
    value::TensorRef,
};
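
/// Thin wrapper around an ONNX Runtime session holding the granite-docling model.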
pub struct GraniteDoclingONNX {
    session: Session,
}

impl GraniteDoclingONNX {
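    /// Load the granite-docling ONNX model from disk and log its input/output signature.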
    pub fn new(model_path: &str) -> Result<Self> {
        println!("Loading granite-docling ONNX model from: {}", model_path);

        // Register execution providers in order of preference; ONNX Runtime
        // falls back to the next provider if one is unavailable.
        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_execution_providers([
                DirectMLExecutionProvider::default().build(),
                CUDAExecutionProvider::default().build(),
                CPUExecutionProvider::default().build(),
            ])?
            .commit_from_file(model_path)?;

        println!("Model loaded successfully:");
        for (i, input) in session.inputs.iter().enumerate() {
            println!("  Input {}: {} {:?}", i, input.name, input.input_type);
        }
        for (i, output) in session.outputs.iter().enumerate() {
            println!("  Output {}: {} {:?}", i, output.name, output.output_type);
        }

        Ok(Self { session })
    }
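
    /// Run a single forward pass over the document image and prompt, then
    /// decode the result into DocTags markup.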
    pub async fn process_document(
        &self,
        document_image: Array4<f32>,
        prompt: &str,
    ) -> Result<String> {
        println!("Processing document with granite-docling...");

        // Tokenize the prompt and build a matching attention mask.
        let input_ids = self.tokenize_prompt(prompt)?;
        let attention_mask = Array2::<i64>::ones((1, input_ids.len()));

        // Reshape the token ids into the [batch, seq_len] layout the model expects.
        let input_ids_2d = Array2::from_shape_vec(
            (1, input_ids.len()),
            input_ids.iter().map(|&x| x as i64).collect(),
        )?;

        // Run inference with the document image, token ids, and attention mask.
        let outputs = self.session.run(inputs![
            "pixel_values" => TensorRef::from_array_view(&document_image)?,
            "input_ids" => TensorRef::from_array_view(&input_ids_2d)?,
            "attention_mask" => TensorRef::from_array_view(&attention_mask)?,
        ])?;

        // Greedy-decode the output logits into token ids, then into DocTags text.
        let logits = outputs["logits"].try_extract_tensor::<f32>()?;
        let tokens = self.decode_logits_to_tokens(&logits)?;
        let doctags = self.detokenize_to_doctags(&tokens)?;

        println!("✅ Document processing complete");
        Ok(doctags)
    }
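
    /// Placeholder tokenizer: maps each whitespace-separated word to a sequential
    /// id. A real integration would use the model's actual tokenizer vocabulary.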
    fn tokenize_prompt(&self, prompt: &str) -> Result<Vec<u32>> {
        let tokens: Vec<u32> = prompt
            .split_whitespace()
            .enumerate()
            .map(|(i, _)| (i + 1) as u32)
            .collect();

        Ok(tokens)
    }
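
    /// Greedy decoding: for each sequence position, take the argmax over the
    /// vocabulary dimension (assumes logits of shape [batch, seq_len, vocab_size]).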
    fn decode_logits_to_tokens(&self, logits: &ndarray::ArrayViewD<f32>) -> Result<Vec<u32>> {
        let tokens: Vec<u32> = logits
            // Walk the sequence axis; each slice holds the vocabulary logits
            // for one position.
            .axis_iter(ndarray::Axis(1))
            .map(|logit_slice| {
                logit_slice
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .map(|(idx, _)| idx as u32)
                    .unwrap_or(0)
            })
            .collect();

        Ok(tokens)
    }
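
    /// Placeholder detokenizer: emits mock DocTags markup. A real integration
    /// would map token ids back through the model's vocabulary.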
    fn detokenize_to_doctags(&self, tokens: &[u32]) -> Result<String> {
        let mock_doctags = format!(
            "<doctag>\n  <text>Document processed with {} tokens</text>\n</doctag>",
            tokens.len()
        );

        Ok(mock_doctags)
    }
}
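
/// Placeholder preprocessing: a real implementation would load the image, resize
/// it to the model's expected resolution, and normalize pixel values.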
pub fn preprocess_document_image(_image_path: &str) -> Result<Array4<f32>> {
    // Placeholder: return a zeroed [batch, channels, height, width] tensor.
    let document_image = Array4::zeros((1, 3, 512, 512));

    Ok(document_image)
}

#[tokio::main]
async fn main() -> Result<()> {
    println!("granite-docling ONNX Rust Example");

    let model_path = "granite-docling-258M-onnx/model.onnx";
    let granite_docling = GraniteDoclingONNX::new(model_path)?;

    let document_image = preprocess_document_image("example_document.png")?;

    let prompt = "Convert this document to DocTags:";
    let doctags = granite_docling.process_document(document_image, prompt).await?;

    println!("Generated DocTags:");
    println!("{}", doctags);

    Ok(())
}