import gradio as gr
import torch
from PIL import Image as PILImage
from transformers import AutoImageProcessor, SiglipForImageClassification
import os
import warnings

# --- Configuration ---
MODEL_IDENTIFIER = r"Ateeqq/ai-vs-human-image-detector"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Suppress specific warnings ---
# Suppress the PIL warning about possibly corrupt EXIF data
warnings.filterwarnings("ignore", message="Possibly corrupt EXIF data.")
# Suppress the transformers warning about the default legacy behaviour
warnings.filterwarnings("ignore", message=".*You are using the default legacy behaviour.*")

# --- Load Model and Processor (Load once at startup) ---
print(f"Using device: {DEVICE}")
print(f"Loading processor from: {MODEL_IDENTIFIER}")
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_IDENTIFIER)
    print(f"Loading model from: {MODEL_IDENTIFIER}")
    model = SiglipForImageClassification.from_pretrained(MODEL_IDENTIFIER)
    model.to(DEVICE)
    model.eval()
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"FATAL: Error loading model or processor: {e}")
    # If the model fails to load, raise an exception to stop the app
    raise gr.Error(f"Failed to load the model: {e}. Cannot start the application.") from e
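
# Optional sanity check: print the label mapping so the 'ai' / 'human' classes
# expected by classify_image (see its docstring) can be confirmed at startup.
# The index order comes from the model config, so it is not hard-coded here.
print(f"Label mapping: {model.config.id2label}")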

# --- Prediction Function ---
def classify_image(image_pil):
    """
    Classifies an image as AI-generated or Human-made.

    Args:
        image_pil (PIL.Image.Image): Input image in PIL format.

    Returns:
        dict: A dictionary mapping class labels ('ai', 'human') to their
              confidence scores. Returns an empty dict if input is None.
    """
    if image_pil is None:
        # Handle case where the user clears the image input
        print("Warning: No image provided.")
        return {}  # Return empty dict; the Gradio Label component handles this
| print("Processing image...") | |
| try: | |
| # Ensure image is RGB | |
| image = image_pil.convert("RGB") | |
| # Preprocess using the loaded processor | |
| inputs = processor(images=image, return_tensors="pt").to(DEVICE) | |
| # Perform inference | |
| print("Running inference...") | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| logits = outputs.logits | |
| # Get probabilities using softmax | |
| # outputs.logits is shape [1, num_labels], softmax over the last dim | |
| probabilities = torch.softmax(logits, dim=-1)[0] # Get probabilities for the first (and only) image | |
| # Create a dictionary of label -> score | |
| results = {} | |
| for i, prob in enumerate(probabilities): | |
| label = model.config.id2label[i] | |
| results[label] = round(prob.item(), 4) # Round for cleaner display | |
| print(f"Prediction results: {results}") | |
| return results | |
| except Exception as e: | |
| print(f"Error during prediction: {e}") | |
| # Return error in the format expected by gr.Label | |
| # Provide a user-friendly error message in the output | |
| return {"Error": f"Processing failed. Please try again or use a different image."} | |

# --- Define Example Images ---
example_dir = "examples"
example_images = []
if os.path.exists(example_dir) and os.listdir(example_dir):  # Check if dir exists AND is not empty
    for img_name in os.listdir(example_dir):
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
            example_images.append(os.path.join(example_dir, img_name))
    if example_images:
        print(f"Found examples: {example_images}")
    else:
        print("No valid image files found in 'examples' directory.")
else:
    print("No 'examples' directory found or it's empty. Examples will not be shown.")

# --- Custom CSS ---
css = """
body { font-family: 'Inter', sans-serif; } /* Use a clean sans-serif font */

/* Style the main title */
#app-title {
    text-align: center;
    font-weight: bold;
    font-size: 2.5em; /* Larger title */
    margin-bottom: 5px; /* Reduced space below title */
    color: #2c3e50; /* Darker color */
}

/* Style the description */
#app-description {
    text-align: center;
    font-size: 1.1em;
    margin-bottom: 25px; /* More space below description */
    color: #576574; /* Subdued color */
}
#app-description code { /* Style model name */
    font-weight: bold;
    background-color: #f1f2f6;
    padding: 2px 5px;
    border-radius: 4px;
}
#app-description strong { /* Style device name */
    color: #1abc9c; /* Highlight color for device */
}

/* Style the results area */
#prediction-label .label-name { font-weight: bold; font-size: 1.1em; }
#prediction-label .confidence { font-size: 1em; }

/* Style the results heading */
#results-heading {
    text-align: center;
    font-size: 1.2em; /* Slightly larger heading for results */
    margin-bottom: 10px; /* Space below heading */
    color: #34495e; /* Match other heading colors */
}

/* Style the examples section */
.gradio-container .examples-container { padding-top: 15px; }
.gradio-container .examples-header { font-size: 1.1em; font-weight: bold; margin-bottom: 10px; color: #34495e; }

/* Add a subtle border/shadow to input/output columns for definition */
#input-column, #output-column {
    border: 1px solid #e0e0e0;
    border-radius: 12px; /* More rounded corners */
    padding: 20px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); /* Subtle shadow */
    background-color: #ffffff; /* Ensure white background */
}

/* Footer styling */
#app-footer {
    margin-top: 40px;
    padding-top: 20px;
    border-top: 1px solid #dfe6e9;
    text-align: center;
    font-size: 0.9em;
    color: #8395a7;
}
#app-footer a { color: #3498db; text-decoration: none; }
#app-footer a:hover { text-decoration: underline; }
"""

# --- Gradio Interface using Blocks and Theme ---
# Choose a theme: gr.themes.Soft(), gr.themes.Monochrome(), gr.themes.Glass(), etc.
theme = gr.themes.Soft(
    primary_hue="emerald",  # Color scheme based on emerald green
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_lg,  # Larger corner radius
    spacing_size=gr.themes.sizes.spacing_lg,  # More spacing
).set(
    # Further fine-tuning
    body_background_fill="#f8f9fa",  # Very light grey background
    block_radius="12px",
)

with gr.Blocks(theme=theme, css=css) as iface:
    # Title and Description using Markdown for better formatting
    gr.Markdown("# AI vs Human Image Detector", elem_id="app-title")
    gr.Markdown(
        f"Upload an image to classify whether it was likely generated by AI or created by a human. "
        f"Uses the `{MODEL_IDENTIFIER}` model. Running on **{str(DEVICE).upper()}**.",
        elem_id="app-description"
    )

    # Main layout with Input and Output side-by-side
    with gr.Row(variant='panel'):  # 'panel' adds a light border/background
        with gr.Column(scale=1, min_width=300, elem_id="input-column"):
            image_input = gr.Image(
                type="pil",
                label="🖼️ Upload Your Image",
                sources=["upload", "webcam", "clipboard"],
                height=400,  # Adjust height as needed
            )
            submit_button = gr.Button("🔎 Classify Image", variant="primary")  # Make button prominent

        with gr.Column(scale=1, min_width=300, elem_id="output-column"):
            # Use elem_id and target with CSS for styling
| gr.Markdown("π **Prediction Results**", elem_id="results-heading") | |
            result_output = gr.Label(
                num_top_classes=2,
                label="Classification",
                elem_id="prediction-label"
            )

    # Examples Section
    if example_images:  # Only show examples if they exist and the list is not empty
        gr.Examples(
            examples=example_images,
            inputs=image_input,
            outputs=result_output,
            fn=classify_image,
            cache_examples=True,  # Caching is good for static examples
            label="✨ Click an Example to Try!"
        )

    # Footer / Article section
    gr.Markdown(
        """
        ---
        This application uses a fine-tuned [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) vision model
        specifically trained to differentiate between images generated by Artificial Intelligence and those created by humans.

        You can find the model card here: <a href='https://huggingface.co/{model_id}' target='_blank'>{model_id}</a>

        Fine-tuning code is available at [https://exnrt.com/blog/ai/fine-tuning-siglip2/](https://exnrt.com/blog/ai/fine-tuning-siglip2/).
        """.format(model_id=MODEL_IDENTIFIER),
        elem_id="app-footer"
    )

    # Connect the button click or image change to the prediction function
    # Use api_name for potential API usage later
    submit_button.click(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_button")
    image_input.change(fn=classify_image, inputs=image_input, outputs=result_output, api_name="classify_image_change")
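
    # The api_name values above also expose these handlers on the app's Gradio API.
    # A minimal client-side sketch (requires the gradio_client package; the local
    # URL below is hypothetical and the exact call style depends on your versions):
    #
    #   from gradio_client import Client, handle_file
    #   client = Client("http://127.0.0.1:7860")
    #   client.predict(handle_file("examples/sample.jpg"), api_name="/classify_image_button")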

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()  # Add share=True for a temporary public link if needed: iface.launch(share=True)
    print("Gradio interface launched.")