import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification
import os
import warnings

# --- Configuration ---
MODEL_IDENTIFIER = "Ateeqq/ai-vs-human-image-detector"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Suppress specific warnings ---
# Suppress the PIL warning about possibly corrupt EXIF data
warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
# Suppress the transformers warning about the default legacy processor behaviour
warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")

# --- Load Model and Processor (load once at startup) ---
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER)
    model.to(DEVICE)
    model.eval()
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"FATAL: Error loading model or processor: {e}")
    # If the model fails to load, raise an exception to stop the app
    raise gr.Error(f"Failed to load the model: {e}. Cannot start the application.") from e

# --- Prediction Function ---
def classify_image(image_pil):
    """
    Classifies an image as AI-generated or human-made.

    Args:
        image_pil (PIL.Image.Image): Input image in PIL format.

    Returns:
        dict: A dictionary mapping class labels ('ai', 'human') to their
            confidence scores. Returns an empty dict if the input is None.
    """
    if image_pil is None:
        # Handle the case where the user clears the image input
        print("Warning: No image provided.")
        return {}  # gr.Label treats an empty dict as "no prediction"
| print("Processing image...") | |
| try: | |
| # Ensure image is RGB | |
| image = image_pil.convert("RGB") | |
| # Preprocess using the loaded processor | |
| inputs = processor(images=image, return_tensors="pt").to(DEVICE) | |
| # Perform inference | |
| print("Running inference...") | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| logits = outputs.logits | |
| # Get probabilities using softmax | |
| # outputs.logits is shape [1, num_labels], softmax over the last dim | |
| probabilities = torch.softmax(logits, dim=-1)[0] # Get probabilities for the first (and only) image | |
| # Create a dictionary of label -> score | |
| results = {} | |
| for i, prob in enumerate(probabilities): | |
| label = model.config.id2label[i] | |
| results[label] = prob.item() # Use .item() to get Python float | |
| print(f"Prediction results: {results}") | |
| return results | |
    except Exception as e:
        print(f"Error during prediction: {e}")
        # Returning a plain string here would break gr.Label, which expects
        # label -> float scores, so surface the error in the UI instead
        raise gr.Error(f"Error processing image: {e}") from e

# --- Gradio Interface Definition ---
# Define example images (optional, but recommended).
# Create an 'examples' folder in your Space repo and put images there.
example_dir = "examples"
example_images = []
if os.path.exists(example_dir):
    for img_name in os.listdir(example_dir):
        if img_name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            example_images.append(os.path.join(example_dir, img_name))
    print(f"Found examples: {example_images}")
else:
    print("No 'examples' directory found. Examples will not be shown.")

# Define the Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Image", sources=["upload", "webcam", "clipboard"]),  # Deliver images to the function as PIL objects
    outputs=gr.Label(num_top_classes=2, label="Prediction Results"),  # gr.Label renders classification scores
    title="AI vs Human Image Detector",
    description=(
        f"Upload an image to classify whether it was likely generated by AI or created by a human. "
        f"Uses the `{MODEL_IDENTIFIER}` model from Hugging Face. Running on **{str(DEVICE).upper()}**."
    ),
    article=(
        "<div>"
        "<p>This tool uses a SigLIP model fine-tuned to distinguish between AI-generated and human-made images.</p>"
        f"<p>Model Card: <a href='https://huggingface.co/{MODEL_IDENTIFIER}' target='_blank'>{MODEL_IDENTIFIER}</a></p>"
        "<p>Fine-tuning code is available at <a href='https://exnrt.com/blog/ai/fine-tuning-siglip2/' target='_blank'>https://exnrt.com/blog/ai/fine-tuning-siglip2/</a></p>"
        "</div>"
    ),
    examples=example_images if example_images else None,  # Only add examples if any were found
    cache_examples=bool(example_images),  # Cache example outputs when examples exist
    allow_flagging="never",  # Or "auto" if you want users to flag issues
)
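
# Compatibility note (not from the original Space): newer Gradio releases
# rename `allow_flagging` to `flagging_mode`; if gr.Interface raises a
# TypeError on that argument, the rename is the likely cause.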

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()  # Blocks until the server is shut down
    print("Gradio interface closed.")