Spaces:
Runtime error
Runtime error
# Shared ONNX inference function for use by app.py and model_loader.py.

# Loaded ONNX sessions, handed to ``get_onnx_model_from_cache`` on every call.
# This must live at module level: a dict created inside the function is empty on
# every call, which makes the "cache" useless and reloads the model per inference.
_onnx_model_cache = {}


def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
    """Run the (cached) ONNX model for ``hf_model_id`` on a preprocessed input.

    Args:
        hf_model_id: Model identifier forwarded to the ONNX cache/loader helpers.
        preprocessed_image_np: Array fed unchanged as the session's first input.
            It must already match the model's expected input (presumably the
            output of ``preprocess_onnx_input`` - confirm at call sites).
        model_config: Currently unused; kept for call-site compatibility.

    Returns:
        dict with ``"logits"`` (the session's first output) and
        ``"probabilities"`` (``softmax`` of the first row of those logits).
        If anything inside the inference step fails, the error is logged with
        its traceback and both values are empty numpy arrays.
    """
    from .onnx_model_loader import get_onnx_model_from_cache, load_onnx_model_and_preprocessor
    from .utils import softmax
    import numpy as np
    import logging

    logger = logging.getLogger(__name__)
    try:
        ort_session, _, _ = get_onnx_model_from_cache(
            hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
        )
        input_metas = ort_session.get_inputs()
        for input_meta in input_metas:
            logger.debug(
                "ONNX model expected input name: %s, shape: %s, type: %s",
                input_meta.name, input_meta.shape, input_meta.type,
            )
        logger.debug("preprocessed_image_np shape: %s", preprocessed_image_np.shape)

        ort_outputs = ort_session.run(None, {input_metas[0].name: preprocessed_image_np})
        logits = ort_outputs[0]
        logger.debug("logits type: %s, shape: %s", type(logits), logits.shape)

        # Only the first row (first item of the batch) is turned into probabilities.
        probabilities = softmax(logits[0])
        return {"logits": logits, "probabilities": probabilities}
    except Exception:
        # Best-effort boundary: callers get empty arrays instead of an exception,
        # but the traceback is preserved in the log.
        logger.exception("Error during ONNX inference for %s", hf_model_id)
        return {"logits": np.array([]), "probabilities": np.array([])}
| import numpy as np | |
| from torchvision import transforms | |
| from PIL import Image | |
| import logging | |
def _size_spec_to_resize_arg(spec):
    """Convert a preprocessor ``size``/``crop_size`` entry into a torchvision size.

    Returns a ``(height, width)`` tuple for a bare number, a dict with
    ``height`` and ``width``, or a two-element sequence. Returns an ``int`` for a
    ``{"shortest_edge": S}`` dict; torchvision's ``Resize`` then matches the
    shorter image side to ``S`` and keeps the aspect ratio.

    Raises:
        ValueError: if ``spec`` is in none of those forms.
    """
    if isinstance(spec, (int, float)):
        side = int(spec)
        return (side, side)
    if isinstance(spec, dict):
        if 'height' in spec and 'width' in spec:
            return (int(spec['height']), int(spec['width']))
        if 'shortest_edge' in spec:
            return int(spec['shortest_edge'])
    elif isinstance(spec, (list, tuple)) and len(spec) == 2:
        return (int(spec[0]), int(spec[1]))
    raise ValueError(f"Unsupported image size specification: {spec!r}")


def preprocess_onnx_input(image, preprocessor_config):
    """Turn a PIL image into a normalised ``(1, 3, H, W)`` float array for ONNX.

    The steps are:

    1. Convert to RGB.
    2. Resize.
    3. Centre-crop.
    4. Scale to [0, 1] (``ToTensor``).
    5. Normalise with the configured mean/std.
    6. Add a leading batch dimension.

    Args:
        image: PIL image in any mode; non-RGB images are converted to RGB.
        preprocessor_config: Mapping (``None`` is treated as empty) that may contain:

            * ``size``: defaults to 224x224.
            * ``crop_size``: defaults to a square whose side is the resize
              height, or the shortest edge.
            * ``image_mean`` / ``image_std``: default to the ImageNet values
              below.

            ``size`` and ``crop_size`` may each be a number, a
            ``{"height", "width"}`` dict, a ``{"shortest_edge"}`` dict, or a
            two-element sequence.

    Returns:
        numpy array of shape ``(1, 3, crop_height, crop_width)``.

    Raises:
        ValueError: if ``size`` or ``crop_size`` has an unsupported form.
    """
    config = preprocessor_config or {}
    if image.mode != 'RGB':
        image = image.convert('RGB')

    size_spec = config.get('size')
    if size_spec is None:
        size_spec = {'height': 224, 'width': 224}
    resize_arg = _size_spec_to_resize_arg(size_spec)

    crop_spec = config.get('crop_size')
    if crop_spec is None:
        # Historical default: square crop whose side is the resize height
        # (or the shortest edge when only that is configured).
        crop_arg = resize_arg if isinstance(resize_arg, int) else resize_arg[0]
    else:
        crop_arg = _size_spec_to_resize_arg(crop_spec)

    mean = config.get('image_mean', [0.485, 0.456, 0.406])
    std = config.get('image_std', [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize(resize_arg),
        transforms.CenterCrop(crop_arg),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return transform(image).unsqueeze(0).cpu().numpy()
def _label_sort_key(label_id):
    """Sort key that orders ``id2label`` keys numerically where possible.

    If the config was loaded from JSON, the keys are strings (JSON object keys
    always are). A plain ``sorted`` would then put ``"10"`` before ``"2"`` and
    pair labels with the wrong probabilities. Keys that are not integers sort
    after the numeric ones, compared as strings.
    """
    try:
        return (0, int(label_id))
    except (TypeError, ValueError):
        return (1, str(label_id))


def postprocess_onnx_output(onnx_output, model_config):
    """Map ONNX class probabilities to a ``{class_name: probability}`` dict.

    Class names come from ``model_config["id2label"]``, ordered by class index.
    When it is missing or empty, the names default to ``['Fake', 'Real']``.

    Args:
        onnx_output: Mapping with a ``"probabilities"`` sequence (such as the
            dict returned by ``infer_onnx_model``).
        model_config: Mapping that may contain ``id2label`` and ``num_classes``.

    Returns:
        One float per class name. If ``"probabilities"`` is missing, or its
        length does not match the class names, a warning is logged and every
        class maps to 0.0.
    """
    logger = logging.getLogger(__name__)

    class_names_map = model_config.get('id2label')
    if class_names_map:
        class_names = [class_names_map[k] for k in sorted(class_names_map, key=_label_sort_key)]
    else:
        # Default binary labels: index 0 is Fake, index 1 is Real.
        class_names = ['Fake', 'Real']

    probabilities = onnx_output.get("probabilities")
    if probabilities is None:
        logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
        return {name: 0.0 for name in class_names}

    # Single-logit model reporting a [fake, real] pair: label it with the first two
    # class names. This needs at least two names; otherwise fall through to the
    # length check instead of raising IndexError.
    if (model_config.get('num_classes') == 1 and len(probabilities) == 2
            and len(class_names) >= 2):
        return {
            class_names[0]: float(probabilities[0]),
            class_names[1]: float(probabilities[1]),
        }

    if len(probabilities) == len(class_names):
        return {name: float(prob) for name, prob in zip(class_names, probabilities)}

    logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
    return {name: 0.0 for name in class_names}