Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForImageClassification, AutoImageProcessor | |
| from PIL import Image | |
| import numpy as np | |
| from captum.attr import LayerGradCam | |
| from captum.attr import visualization as viz | |
| import requests | |
| from io import BytesIO | |
| import warnings | |
| import os | |
# Silence third-party warning noise so the Space logs stay readable.
warnings.simplefilter("ignore")

# Pinned to the CPU for the Hugging Face Spaces deployment; PyTorch is
# limited to a single intra-op thread.
torch.set_num_threads(1)
device = torch.device("cpu")
# --- 1. Load Model and Processor ---
# Runs once at import time: fetches the detector checkpoint and its image
# processor, moves the model to `device` and switches it to eval mode.
# Any failure is printed and re-raised so the app does not start half-loaded.
print("Loading model and processor...")
try:
    model_id = "Organika/sdxl-detector"
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = (
        AutoModelForImageClassification.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            device_map="cpu",
            torch_dtype=torch.float32,
        )
        .to(device)
        .eval()
    )
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    raise
else:
    print("Model and processor loaded successfully on CPU.")
# --- 2. Define the Explainability (Grad-CAM) Function ---
def _tokens_to_cam(activations, gradients):
    """Turn token activations and their gradients into a square Grad-CAM map.

    Args:
        activations: output of the target layer, assumed to be
            (batch, tokens, channels) with the tokens laid out row-major on a
            square grid (as for a Swin final layernorm — TODO confirm for
            other backbones).
        gradients: gradient of the target logit w.r.t. ``activations``
            (same shape).

    Returns:
        Non-negative tensor of shape (batch, 1, side, side).

    Raises:
        ValueError: if the activations are not 3-D or the token count is not
            a perfect square.
    """
    if activations.dim() != 3:
        raise ValueError(
            f"Expected (batch, tokens, channels) activations, got {tuple(activations.shape)}"
        )
    batch, tokens, _ = activations.shape
    side = int(round(tokens ** 0.5))
    if side * side != tokens:
        raise ValueError(f"Cannot arrange {tokens} tokens on a square grid")
    # Grad-CAM channel weights: mean gradient of each channel over all tokens.
    weights = gradients.mean(dim=1, keepdim=True)
    cam = torch.relu((weights * activations).sum(dim=-1))
    return cam.reshape(batch, 1, side, side)


def generate_heatmap(image_tensor, original_image, target_class_index):
    """Render a Grad-CAM overlay explaining the model's score for one class.

    Args:
        image_tensor: processed ``pixel_values`` for a single image (batch of
            one — only index 0 of the logits is explained).
        original_image: the PIL image the heatmap is drawn over.
        target_class_index: index of the logit to explain.

    Returns:
        An RGB ``numpy`` array of the rendered overlay (with colorbar and
        title). If anything fails, the error is printed and
        ``np.array(original_image)`` is returned instead.
    """
    try:
        # A fresh leaf that requires grad guarantees the captured activations
        # are part of the autograd graph even if the weights were frozen.
        pixel_values = image_tensor.to(device).detach().requires_grad_(True)

        # Final layernorm of the Swin backbone; its output is token-shaped,
        # so Captum's LayerGradCam (which expects a channel-first spatial map)
        # cannot be used directly — Grad-CAM is computed by hand instead.
        target_layer = model.swin.layernorm
        captured = {}

        def _capture(_module, _inputs, output):
            captured["acts"] = output

        handle = target_layer.register_forward_hook(_capture)
        try:
            # Grad-CAM needs gradients: explicitly enable autograd even if a
            # caller is inside a no_grad() context.
            with torch.enable_grad():
                logits = model(pixel_values=pixel_values).logits
                activations = captured["acts"]
                (gradients,) = torch.autograd.grad(
                    logits[0, target_class_index], activations
                )
        finally:
            handle.remove()

        cam = _tokens_to_cam(activations.detach(), gradients.detach())

        # Upsample to the displayed image so the overlay lines up pixel-wise.
        width, height = original_image.size
        cam = torch.nn.functional.interpolate(
            cam, size=(height, width), mode="bilinear", align_corners=False
        )
        heatmap = cam[0].permute(1, 2, 0).cpu().numpy()  # (H, W, 1)

        # use_pyplot=False avoids piling up global pyplot figures per request.
        # Attributions are ReLU'd (non-negative), so sign="positive" with a
        # warm colormap makes the most influential regions render red.
        figure, _ = viz.visualize_image_attr(
            heatmap,
            np.array(original_image.convert("RGB")),
            method="blended_heat_map",
            sign="positive",
            cmap="jet",
            show_colorbar=True,
            title="AI Detection Heatmap",
            alpha_overlay=0.6,
            use_pyplot=False,
        )

        # visualize_image_attr returns a matplotlib Figure, which gr.Image
        # cannot display; rasterize it to an RGB array.
        buffer = BytesIO()
        figure.savefig(buffer, format="png", bbox_inches="tight")
        buffer.seek(0)
        with Image.open(buffer) as rendered:
            return np.array(rendered.convert("RGB"))
    except Exception as e:
        print(f"Error generating heatmap: {e}")
        # Return original image if heatmap generation fails
        return np.array(original_image)
# --- 3. Main Prediction Function ---
# Lower-cased label names treated as the "AI-generated" class.
# NOTE(review): the real label strings come from the checkpoint's
# model.config.id2label, which is not visible here. The original code only
# matched "ai"; the extra spellings cover common naming (e.g. "artificial").
# Confirm against the deployed config.
_AI_LABEL_NAMES = frozenset(
    {"ai", "artificial", "ai-generated", "generated", "fake", "synthetic"}
)


def _is_ai_label(label):
    """Return True if *label* names the AI-generated class."""
    return str(label).strip().lower() in _AI_LABEL_NAMES


def _build_explanation(predicted_label, confidence_score):
    """Return the Markdown explanation text for the predicted label."""
    if _is_ai_label(predicted_label):
        return (
            f"🤖 The model is {confidence_score:.2%} confident that this image is **AI-GENERATED**.\n\n"
            "The heatmap highlights areas that most influenced this decision. "
            "Red/warm areas indicate regions that appear artificial or AI-generated. "
            "Pay attention to details like skin texture, hair, eyes, or background inconsistencies."
        )
    return (
        f"👤 The model is {confidence_score:.2%} confident that this image is **HUMAN-MADE**.\n\n"
        "The heatmap shows areas the model considers natural and realistic. "
        "Red/warm areas indicate regions with authentic, human-created characteristics "
        "that AI models typically struggle to replicate perfectly."
    )


def _load_input_image(image_upload, image_url):
    """Return the uploaded image, or else the image downloaded from *image_url*.

    Raises:
        gr.Error: if neither input is provided or the URL cannot be fetched
            and decoded as an image.
    """
    if image_upload is not None:
        print(f"Processing uploaded image of size: {image_upload.size}")
        return image_upload
    url = (image_url or "").strip()
    if not url:
        raise gr.Error("Please upload an image or provide a URL to analyze.")
    # NOTE(security): this fetches an arbitrary user-supplied URL from the
    # server (SSRF risk) with no download-size cap; review before exposing
    # beyond a public demo.
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data hits this handler.
        image.load()
    except Exception as e:
        raise gr.Error(
            f"Could not load image from URL. Please check the link. Error: {e}"
        ) from e
    print(f"Processing image from URL: {url}")
    return image


def predict(image_upload: Image.Image, image_url: str):
    """Classify an image as AI-generated or human-made and explain the result.

    Args:
        image_upload: image from the upload widget; takes precedence if set.
        image_url: URL to download the image from when nothing is uploaded.

    Returns:
        Tuple ``(labels_dict, explanation, heatmap_image)``: per-label
        probabilities for ``gr.Label``, a Markdown explanation string, and the
        Grad-CAM overlay as a numpy array.

    Raises:
        gr.Error: for missing/unloadable input (message passed through
            unchanged) or any other failure during prediction.
    """
    try:
        input_image = _load_input_image(image_upload, image_url)

        # The processor and the heatmap overlay both expect 3-channel RGB;
        # normalize RGBA, grayscale ("L"), palette ("P"), CMYK, etc.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")

        # Resize image if too large to save memory
        max_size = 512
        if max(input_image.size) > max_size:
            input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        inputs = processor(images=input_image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits

        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        predicted_class_idx = int(logits.argmax(-1).item())
        confidence_score = probabilities[predicted_class_idx].item()
        predicted_label = model.config.id2label[predicted_class_idx]
        explanation = _build_explanation(predicted_label, confidence_score)

        print("Generating heatmap...")
        heatmap_image = generate_heatmap(
            inputs["pixel_values"], input_image, predicted_class_idx
        )
        print("Heatmap generated successfully.")

        labels_dict = {
            label: float(probabilities[idx])
            for idx, label in model.config.id2label.items()
        }
        return labels_dict, explanation, heatmap_image
    except gr.Error:
        # Already user-facing; don't re-wrap it in the generic message.
        raise
    except Exception as e:
        print(f"Error in prediction: {e}")
        raise gr.Error(f"An error occurred during prediction: {str(e)}") from e
# --- 4. Gradio Interface ---
# Builds the Blocks UI: inputs (upload or URL tab) on the left, results
# (label confidences, Markdown explanation, heatmap) on the right, wired to
# `predict`, plus two URL examples.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="AI Image Detector",
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .tab-nav {
        margin-bottom: 1rem;
    }
    """,
) as demo:
    gr.Markdown(
        """
        # 🔍 AI Image Detector with Explainability

        Determine if an image is AI-generated or human-made using advanced machine learning.

        **Features:**

        - 🎯 High-accuracy detection using the Organika/sdxl-detector model
        - 🔥 **Heatmap visualization** showing which areas influenced the decision
        - 📱 Support for both file uploads and URL inputs
        - ⚡ Optimized for CPU deployment

        **How to use:** Upload an image or paste a URL, then click "Analyze Image" to see the results and heatmap.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input")
            with gr.Tabs():
                with gr.TabItem("📁 Upload File"):
                    input_image_upload = gr.Image(
                        type="pil",
                        label="Upload Your Image",
                        height=300,
                    )
                with gr.TabItem("🔗 Use URL"):
                    input_image_url = gr.Textbox(
                        label="Paste Image URL here",
                        placeholder="https://example.com/image.jpg",
                    )
            submit_btn = gr.Button(
                "🔍 Analyze Image",
                variant="primary",
                size="lg",
            )
            gr.Markdown(
                """
                ### ℹ️ Tips

                - Supported formats: JPG, PNG, WebP
                - Images are automatically resized for optimal processing
                - For best results, use clear, high-quality images
                """
            )
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results")
            with gr.Row():
                with gr.Column():
                    output_label = gr.Label(
                        label="Prediction Confidence",
                        num_top_classes=2,
                    )
                with gr.Column():
                    # `predict` returns Markdown (**bold**, blank lines); a
                    # Textbox would show the asterisks literally, so render it
                    # with a Markdown component under a small heading.
                    gr.Markdown("**Detailed Explanation**")
                    output_text = gr.Markdown()
            output_heatmap = gr.Image(
                label="🔥 AI Detection Heatmap - Red areas influenced the decision most",
                height=400,
            )

    # Connect the interface
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap],
    )

    # Add examples
    gr.Examples(
        examples=[
            [None, "https://images.unsplash.com/photo-1494790108755-2616b612b786"],
            [None, "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d"],
        ],
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=False,
    )
# --- 5. Launch the App ---
if __name__ == "__main__":
    # Serve on every interface at port 7860, without a public share link
    # and with debug mode off.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": False,
    }
    demo.launch(**launch_options)