import time
from urllib.request import urlopen

import gradio as gr
import timm
import torch
from PIL import Image
# Load the pretrained NSFW classifier from the Hugging Face Hub and put it in
# evaluation mode (nn.Module.eval() returns the module itself, so it chains).
model = timm.create_model(
    "hf_hub:Marqo/nsfw-image-detection-384", pretrained=True
).eval()

# Build the inference-time preprocessing pipeline from the model's own
# pretrained data config so inputs are prepared the way the model expects.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
def predict(image):
    """Classify an image as NSFW or SFW and report how long inference took.

    Args:
        image: The ``PIL.Image.Image`` delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without uploading.

    Returns:
        A tuple ``(scores, timing)``. ``scores`` maps each label in
        ``model.pretrained_cfg["label_names"]`` to its softmax probability,
        suitable for ``gr.Label``. ``timing`` is a human-readable string
        with the elapsed inference time.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a clear message in the UI instead of an opaque failure from
        # the transform pipeline receiving ``None``.
        raise gr.Error("Please upload an image.")

    # perf_counter is monotonic and high-resolution, so it is the right clock
    # for measuring an interval; time.time() can jump with clock changes.
    start_time = time.perf_counter()

    # The classifier takes 3-channel input. RGBA, grayscale or palette
    # images would otherwise yield a tensor with the wrong channel count.
    image = image.convert("RGB")

    with torch.no_grad():
        input_tensor = transforms(image).unsqueeze(0)  # add a batch dimension
        # Take the single row of probabilities for our batch of one.
        probabilities = model(input_tensor).softmax(dim=-1).cpu()[0]

    class_names = model.pretrained_cfg["label_names"]
    result = {
        name: float(prob)
        for name, prob in zip(class_names, probabilities.tolist())
    }

    inference_time = time.perf_counter() - start_time
    return result, f"Inference time: {inference_time:.2f} seconds"
# Markdown rendered beneath the title in the Gradio UI.
_DESCRIPTION = (
    "Upload an image to detect if it is **NSFW (Not Safe For Work)** or **Safe For Work (SFW)**.\n\n"
    "This app uses the [Marqo/nsfw-image-detection-384](https://huggingface.co/Marqo/nsfw-image-detection-384) "
    "image classification model from Hugging Face's `timm` library."
)

# Wire ``predict`` to one PIL image input and two outputs: the per-class
# probabilities (top 2 shown) and the inference-time message.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", height=512),
    outputs=[
        gr.Label(num_top_classes=2),
        gr.Textbox(label="Inference Time"),
    ],
    title="NSFW Image Detection",
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()