Spaces:
Runtime error
Runtime error
| from utils.onnx_helpers import postprocess_onnx_output | |
| # Add missing import for infer_onnx_model | |
| from utils.onnx_helpers import infer_onnx_model | |
| # Add missing import for preprocess_onnx_input | |
| from utils.onnx_helpers import preprocess_onnx_input | |
| """ | |
| Model loading and registration logic for OpenSight Deepfake Detection Playground. | |
| Handles ONNX, HuggingFace, and Gradio API model registration and metadata. | |
| """ | |
| from utils.registry import register_model, MODEL_REGISTRY, ModelEntry | |
| from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache | |
| from utils.utils import preprocess_resize_256, postprocess_logits, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api | |
| from transformers import AutoFeatureExtractor, AutoModelForImageClassification | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
# Hugging Face repository backing each detector, keyed by model id
# (copied from app_mcp.py).
MODEL_PATHS = dict(
    model_1="LPX55/detection-model-1-ONNX",
    model_2="LPX55/detection-model-2-ONNX",
    model_3="LPX55/detection-model-3-ONNX",
    model_4="cmckinle/sdxl-flux-detector_v1.1",
    model_5="LPX55/detection-model-5-ONNX",
    model_6="LPX55/detection-model-6-ONNX",
    model_7="LPX55/detection-model-7-ONNX",
    model_8="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
)

# Class labels reported by each detector, keyed by the same model ids
# (copied from app_mcp.py). Label order differs between models.
CLASS_NAMES = dict(
    model_1=['artificial', 'real'],
    model_2=['AI Image', 'Real Image'],
    model_3=['artificial', 'human'],
    model_4=['AI', 'Real'],
    model_5=['Realism', 'Deepfake'],
    model_6=['ai_gen', 'human'],
    model_7=['Fake', 'Real'],
    model_8=['Fake', 'Real'],
)

# Module-wide cache of ONNX sessions and preprocessors; every
# ONNXModelWrapper reads and fills it through get_onnx_model_from_cache.
_onnx_model_cache = {}
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
    """Build a ModelEntry for a model and store it in MODEL_REGISTRY.

    Args:
        model_id: Registry key; an entry already stored under this key is replaced.
        model: Callable that performs inference.
        preprocess: Callable that prepares raw input for ``model``.
        postprocess: Callable that turns ``model`` output into results.
        class_names: Label list associated with the model.
        display_name: Human-readable name.
        contributor: Credited author of the model.
        model_path: Hugging Face repository id the model comes from.
        architecture: Optional architecture label.
        dataset: Optional training-dataset label.
    """
    MODEL_REGISTRY[model_id] = ModelEntry(
        model,
        preprocess,
        postprocess,
        class_names,
        display_name=display_name,
        contributor=contributor,
        model_path=model_path,
        architecture=architecture,
        dataset=dataset,
    )
class ONNXModelWrapper:
    """Callable wrapper around a Hugging Face-hosted ONNX model, loaded lazily.

    Construction only records the repository id. The ONNX session,
    preprocessor config and model config are fetched on first use through
    get_onnx_model_from_cache, backed by the module-level _onnx_model_cache.
    """

    def __init__(self, hf_model_id):
        self.hf_model_id = hf_model_id
        # Filled together by load(); None means "not loaded yet".
        self._session = self._preprocessor_config = self._model_config = None

    def load(self):
        """Fetch the session and configs from the cache unless already present."""
        if self._session is not None:
            return
        loaded = get_onnx_model_from_cache(
            self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
        )
        self._session, self._preprocessor_config, self._model_config = loaded

    def __call__(self, image_np):
        """Run infer_onnx_model on ``image_np``.

        ``image_np`` is presumably the output of preprocess(); confirm
        against infer_onnx_model.
        """
        self.load()
        model_config = self._model_config
        return infer_onnx_model(self.hf_model_id, image_np, model_config)

    def preprocess(self, image: Image.Image):
        """Convert a PIL image with preprocess_onnx_input and the preprocessor config."""
        self.load()
        preprocessor_config = self._preprocessor_config
        return preprocess_onnx_input(image, preprocessor_config)

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        """Decode raw ONNX output with postprocess_onnx_output.

        ``class_names_from_registry`` is accepted for interface compatibility
        but is not used; only the model config is forwarded.
        """
        self.load()
        model_config = self._model_config
        return postprocess_onnx_output(onnx_output, model_config)
# Contributor, architecture and dataset labels for the ONNX detectors.
# ONNX entries not listed here keep the ("Unknown", "Unknown", "TBA") defaults.
_ONNX_MODEL_METADATA = {
    "model_1": ("haywoodsloan", "SwinV2", "Mixed"),
    "model_2": ("Heem2", "ViT", "Mixed"),
    "model_3": ("Organika", "VIT", "SDXL"),
    "model_5": ("prithivMLmods", "VIT", "TBA"),
    "model_6": ("ideepankarsharma2003", "SWINv1", "SDXL, Midjourney"),
    "model_7": ("date3k2", "VIT", "TBA"),
}
_DEFAULT_MODEL_METADATA = ("Unknown", "Unknown", "TBA")


def _build_display_name(model_num, architecture, dataset, suffix=""):
    """Return "<num>[-<architecture>][-<dataset>]<suffix>", skipping placeholder values."""
    parts = [model_num]
    if architecture and architecture != "Unknown":
        parts.append(architecture)
    if dataset and dataset != "TBA":
        parts.append(dataset)
    return "-".join(parts) + suffix


def _register_onnx_model(model_key, model_num, hf_model_path, class_names):
    """Register an ONNX repository behind a lazily-loading ONNXModelWrapper."""
    contributor, architecture, dataset = _ONNX_MODEL_METADATA.get(
        model_key, _DEFAULT_MODEL_METADATA
    )
    wrapper = ONNXModelWrapper(hf_model_path)
    register_model_with_metadata(
        model_id=model_key,
        model=wrapper,
        preprocess=wrapper.preprocess,
        postprocess=wrapper.postprocess,
        class_names=class_names,
        display_name=_build_display_name(model_num, architecture, dataset, "_ONNX"),
        contributor=contributor,
        model_path=hf_model_path,
        architecture=architecture,
        dataset=dataset,
    )


def _register_gradio_api_model(model_key, model_num, hf_model_path, class_names):
    """Register a model served through the Gradio API helpers from utils.utils."""
    contributor, architecture, dataset = "aiwithoutborders-xyz", "ViT", "Massive"
    register_model_with_metadata(
        model_id=model_key,
        model=infer_gradio_api,
        preprocess=preprocess_gradio_api,
        postprocess=postprocess_gradio_api,
        class_names=class_names,
        display_name=_build_display_name(model_num, architecture, dataset),
        contributor=contributor,
        model_path=hf_model_path,
        architecture=architecture,
        dataset=dataset,
    )


def _register_transformers_model(model_key, model_num, hf_model_path, class_names, device):
    """Load a transformers image classifier onto ``device`` and register it."""
    contributor, architecture, dataset = "cmckinle", "VIT", "SDXL, FLUX"
    # `device` is not a documented from_pretrained option for feature
    # extractors, so it is not passed here. The processed inputs are moved
    # to `device` inside custom_infer instead.
    processor = AutoFeatureExtractor.from_pretrained(hf_model_path)
    classifier = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)

    # Processor and model are bound as defaults so each registered callable
    # keeps its own instances.
    def custom_infer(image, processor_local=processor, model_local=classifier):
        inputs = processor_local(image, return_tensors="pt").to(device)
        with torch.no_grad():
            return model_local(**inputs)

    register_model_with_metadata(
        model_id=model_key,
        model=custom_infer,
        preprocess=preprocess_resize_256,
        postprocess=postprocess_logits,
        class_names=class_names,
        display_name=_build_display_name(model_num, architecture, dataset),
        contributor=contributor,
        model_path=hf_model_path,
        architecture=architecture,
        dataset=dataset,
    )


def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output):
    """Register every model in ``MODEL_PATHS`` into MODEL_REGISTRY with display metadata.

    The backend is chosen per entry, checked in this order:
      * repository path contains "ONNX": ONNXModelWrapper; display name
        gets an "_ONNX" suffix;
      * ``"model_8"``: the Gradio API helpers from utils.utils;
      * ``"model_4"``: a local transformers image classifier on ``device``.
    Any other entry is skipped and a warning is logged.

    Args:
        MODEL_PATHS: Mapping of model id to Hugging Face repository id.
        CLASS_NAMES: Mapping of model id to label list; missing ids get ``[]``.
        device: Device the transformers model and its inputs are moved to.
        infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output:
            Accepted for backward compatibility but unused. ONNXModelWrapper
            resolves the module-level helpers imported from utils.onnx_helpers.
    """
    import logging

    for model_key, hf_model_path in MODEL_PATHS.items():
        model_num = model_key.replace("model_", "").upper()
        class_names = CLASS_NAMES.get(model_key, [])
        if "ONNX" in hf_model_path:
            _register_onnx_model(model_key, model_num, hf_model_path, class_names)
        elif model_key == "model_8":
            _register_gradio_api_model(model_key, model_num, hf_model_path, class_names)
        elif model_key == "model_4":
            _register_transformers_model(model_key, model_num, hf_model_path, class_names, device)
        else:
            # Previously skipped silently; surface it so a newly added model
            # without a handler does not just vanish from the registry.
            logging.getLogger(__name__).warning(
                "No registration handler for %s (%s); skipping.", model_key, hf_model_path
            )