Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForImageClassification, AutoImageProcessor | |
| from PIL import Image | |
| import numpy as np | |
| from captum.attr import LayerGradCam | |
| from captum.attr import visualization as viz | |
| import requests | |
| from io import BytesIO | |
| import warnings | |
| import os | |
# Silence library warnings so the Space logs stay readable.
# NOTE(review): this filter is process-wide and also hides deprecation
# warnings from torch / transformers / matplotlib that announce upcoming
# breakage; consider narrowing it to specific categories or modules.
warnings.filterwarnings("ignore")

# The app runs on CPU only: pin the torch device and limit torch to a single
# intra-op thread.
device = torch.device("cpu")
torch.set_num_threads(1)

# --- 1. Load Model and Processor ---
print("Loading model and processor...")
model_id = "Organika/sdxl-detector"
try:
    processor = AutoImageProcessor.from_pretrained(model_id)
    # float32 weights placed on the CPU, then switched to inference mode.
    # `.to()` and `.eval()` both return the module itself, so they chain.
    # NOTE(review): recent transformers releases need the `accelerate` package
    # for `device_map` / `low_cpu_mem_usage`; confirm it is in requirements.
    model = AutoModelForImageClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="cpu",
        torch_dtype=torch.float32,
    ).to(device).eval()
except Exception as e:
    print(f"Error loading model: {e}")
    raise
print("Model and processor loaded successfully on CPU.")
# --- 2. Define the Explainability (attribution heatmap) Function ---
# The heatmap comes from Captum's IntegratedGradients, with a plain
# input-gradient fallback (LayerGradCam is imported above but not used).
def _to_heatmap_overlay(attr_np, original_image, alpha):
    """Colour-map a per-pixel attribution array and blend it onto an image.

    Args:
        attr_np: attribution array for a single image. A 3-D array is treated
            as (channels, H, W) and reduced to (H, W) by averaging absolute
            values over axis 0; a 2-D array is used as its absolute value.
        original_image: PIL image to draw on. It is converted to RGB so that
            grayscale, palette or alpha images blend with the 3-channel heatmap.
        alpha: weight of the heatmap in the blend (``1 - alpha`` goes to the image).

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` sized like ``original_image``.
    """
    import matplotlib

    # NaN/inf would break the min/max normalisation and the uint8 cast below.
    attr = np.nan_to_num(np.asarray(attr_np, dtype=np.float64))
    if attr.ndim == 3:
        attr = np.mean(np.abs(attr), axis=0)
    else:
        attr = np.abs(attr)

    # Rescale to [0, 1]. A constant map carries no information (and could
    # overflow the uint8 cast if > 1), so it is rendered as all zeros.
    lo, hi = attr.min(), attr.max()
    attr = (attr - lo) / (hi - lo) if hi > lo else np.zeros_like(attr)

    # Resize the map from the model-input resolution to the image size.
    rgb_image = original_image.convert("RGB")
    heat = Image.fromarray((attr * 255).astype(np.uint8))
    heat = heat.resize(rgb_image.size, Image.Resampling.LANCZOS)
    heat = np.asarray(heat, dtype=np.float64) / 255.0

    # 'hot' runs black -> red -> yellow -> white. `matplotlib.colormaps` is used
    # because `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9.
    colored = matplotlib.colormaps["hot"](heat)[..., :3]  # drop the alpha channel
    base = np.asarray(rgb_image, dtype=np.float64) / 255.0
    blended = np.clip((1 - alpha) * base + alpha * colored, 0, 1)
    return (blended * 255).astype(np.uint8)


def generate_heatmap(image_tensor, original_image, target_class_index,
                     *, n_steps=50, internal_batch_size=10):
    """Build an RGB attribution heatmap for ``target_class_index`` over the image.

    Tries Captum Integrated Gradients first, then a single plain input gradient.
    If both fail, the unmodified image is returned; no synthetic heatmap is
    ever shown in place of a real attribution.

    Args:
        image_tensor: ``pixel_values`` tensor from the image processor; presumably
            ``(1, C, H, W)`` (the batch dimension is squeezed away). The caller's
            tensor is not modified.
        original_image: PIL image the heatmap is overlaid on.
        target_class_index: logit index whose attribution is visualised.
        n_steps: number of Integrated Gradients interpolation steps.
        internal_batch_size: how many interpolated inputs Integrated Gradients
            sends through the model per forward/backward pass; bounds peak memory.

    Returns:
        ``uint8`` numpy array of shape ``(height, width, 3)`` sized like
        ``original_image``, or the image itself as an array if every attempt fails.
    """
    try:
        print(f"Starting heatmap generation for class {target_class_index}")
        print(f"Input tensor shape: {image_tensor.shape}")
        print(f"Original image size: {original_image.size}")
        # Work on a private leaf copy: the caller's tensor keeps requires_grad=False.
        image_tensor = image_tensor.detach().clone().to(device).requires_grad_(True)

        def model_forward_wrapper(input_tensor):
            # Captum needs a callable that maps pixel values to the logits tensor.
            return model(pixel_values=input_tensor).logits

        try:
            from captum.attr import IntegratedGradients

            print("Trying IntegratedGradients...")
            ig = IntegratedGradients(model_forward_wrapper)
            attributions = ig.attribute(
                image_tensor,
                target=target_class_index,
                n_steps=n_steps,
                internal_batch_size=internal_batch_size,
            )
            attr_np = attributions.squeeze().detach().cpu().numpy()
            print(f"Attribution shape: {attr_np.shape}")
            print(f"Attribution stats: min={attr_np.min():.4f}, max={attr_np.max():.4f}")
            blended = _to_heatmap_overlay(attr_np, original_image, alpha=0.7)
            print("Heatmap generation successful with IntegratedGradients")
            return blended
        except Exception as e1:
            print(f"IntegratedGradients failed: {e1}")

        # Fallback: gradient of the target logit with respect to the input pixels.
        try:
            print("Trying simple gradient approach...")
            with torch.enable_grad():
                target_score = model_forward_wrapper(image_tensor)[0, target_class_index]
                # autograd.grad returns just the input gradient and, unlike
                # .backward(), leaves no .grad tensors on the model parameters.
                (gradients,) = torch.autograd.grad(target_score, image_tensor)
            grad_np = gradients.squeeze().cpu().numpy()
            print(f"Gradient shape: {grad_np.shape}")
            blended = _to_heatmap_overlay(grad_np, original_image, alpha=0.4)
            print("Heatmap generation successful with simple gradients")
            return blended
        except Exception as e2:
            print(f"Simple gradient approach failed: {e2}")

        # Deliberately no placeholder/"demo" heatmap here: the UI labels this
        # output as the areas that influenced the decision, so fabricated hot
        # spots would mislead the user. Show the plain image instead.
        print("No attribution method succeeded; returning the original image.")
        return np.array(original_image.convert("RGB"))
    except Exception as e:
        print(f"Complete heatmap generation failed: {e}")
        import traceback
        traceback.print_exc()
        # Return original image if everything fails
        return np.array(original_image)
# --- 3. Main Prediction Function ---
def predict(image_upload: Image.Image, image_url: str):
    """Classify an image as AI-generated or human-made and explain the decision.

    Args:
        image_upload: PIL image from the upload tab, or ``None`` if nothing was
            uploaded. An upload takes precedence over ``image_url``.
        image_url: URL typed into the URL tab; surrounding whitespace is ignored.

    Returns:
        ``(labels_dict, explanation, heatmap_image)``: label -> probability for
        every class, a markdown explanation string, and the heatmap array from
        ``generate_heatmap``.

    Raises:
        gr.Error: no input was given, the URL could not be fetched, or anything
            else failed during inference. The message is shown in the UI.
    """
    try:
        if image_upload is not None:
            input_image = image_upload
            print(f"Processing uploaded image of size: {input_image.size}")
        elif image_url and image_url.strip():
            # Fetch the stripped URL: pasted links often carry stray whitespace,
            # which the emptiness check above already ignores.
            url = image_url.strip()
            try:
                response = requests.get(url, timeout=10)
                response.raise_for_status()
                input_image = Image.open(BytesIO(response.content))
                print(f"Processing image from URL: {url}")
            except Exception as e:
                raise gr.Error(
                    f"Could not load image from URL. Please check the link. Error: {e}"
                ) from e
        else:
            raise gr.Error("Please upload an image or provide a URL to analyze.")

        # Normalise every non-RGB mode (RGBA, LA, P, L, CMYK, ...) to RGB: the
        # heatmap blends an (H, W, 3) colour map onto this image, and a 2-D
        # grayscale/palette array would not broadcast against it.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")

        # Downscale large images in place (aspect ratio kept) to save memory.
        max_size = 512
        if max(input_image.size) > max_size:
            input_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        inputs = processor(images=input_image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Plain inference pass; gradients are only needed later for the heatmap.
        with torch.no_grad():
            logits = model(**inputs).logits

        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class_idx = logits.argmax(-1).item()
        confidence_score = probabilities[0][predicted_class_idx].item()
        predicted_label = model.config.id2label[predicted_class_idx]

        if predicted_label.lower() == 'artificial':
            explanation = (
                f"🤖 The model is {confidence_score:.2%} confident that this image is **AI-GENERATED**.\n\n"
                "The heatmap highlights areas that most influenced this decision. "
                "Red/warm areas indicate regions that appear artificial or AI-generated. "
                "Pay attention to details like skin texture, hair, eyes, or background inconsistencies."
            )
        else:
            explanation = (
                f"👤 The model is {confidence_score:.2%} confident that this image is **HUMAN-MADE**.\n\n"
                "The heatmap shows areas the model considers natural and realistic. "
                "Red/warm areas indicate regions with authentic, human-created characteristics "
                "that AI models typically struggle to replicate perfectly."
            )

        print("Generating heatmap...")
        heatmap_image = generate_heatmap(inputs['pixel_values'], input_image, predicted_class_idx)
        print("Heatmap generated successfully.")

        # Label -> probability for every class, as expected by gr.Label.
        labels_dict = {
            label: float(probabilities[0][idx])
            for idx, label in model.config.id2label.items()
        }
        return labels_dict, explanation, heatmap_image
    except gr.Error:
        # Already a user-facing message; don't wrap it a second time.
        raise
    except Exception as e:
        print(f"Error in prediction: {e}")
        raise gr.Error(f"An error occurred during prediction: {str(e)}") from e
# --- 4. Gradio Interface ---
# Builds the page: a header, then one row split into an input column (left)
# and a results column twice as wide (right). Components appear in the order
# they are created inside the nested `with` contexts, so statement order here
# is the page layout.
# NOTE(review): nesting was reconstructed from a copy of this file whose
# indentation was lost; confirm it against the deployed app.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="AI Image Detector",
    # Custom CSS: cap the app width and add space below the tab headers.
    css="""
    .gradio-container {
    max-width: 1200px !important;
    }
    .tab-nav {
    margin-bottom: 1rem;
    }
    """
) as demo:
    # Page header and feature summary.
    gr.Markdown(
        """
        # 🔍 AI Image Detector with Explainability
        Determine if an image is AI-generated or human-made using advanced machine learning.
        **Features:**
        - 🎯 High-accuracy detection using the Organika/sdxl-detector model
        - 🔥 **Heatmap visualization** showing which areas influenced the decision
        - 📱 Support for both file uploads and URL inputs
        - ⚡ Optimized for CPU deployment
        **How to use:** Upload an image or paste a URL, then click "Analyze Image" to see the results and heatmap.
        """
    )
    with gr.Row():
        # Left column: the two alternative inputs, the trigger button and tips.
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input")
            with gr.Tabs():
                with gr.TabItem("📁 Upload File"):
                    # type="pil" hands predict() a PIL image (or None when empty).
                    input_image_upload = gr.Image(
                        type="pil",
                        label="Upload Your Image",
                        height=300
                    )
                with gr.TabItem("🔗 Use URL"):
                    input_image_url = gr.Textbox(
                        label="Paste Image URL here",
                        placeholder="https://example.com/image.jpg"
                    )
            submit_btn = gr.Button(
                "🔍 Analyze Image",
                variant="primary",
                size="lg"
            )
            gr.Markdown(
                """
                ### ℹ️ Tips
                - Supported formats: JPG, PNG, WebP
                - Images are automatically resized for optimal processing
                - For best results, use clear, high-quality images
                """
            )
        # Right column (scale=2, twice the width of the input column): outputs.
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results")
            with gr.Row():
                with gr.Column():
                    # Shows the top two entries of predict()'s label -> probability dict.
                    output_label = gr.Label(
                        label="Prediction Confidence",
                        num_top_classes=2
                    )
                with gr.Column():
                    # Read-only box for predict()'s explanation string.
                    output_text = gr.Textbox(
                        label="Detailed Explanation",
                        lines=6,
                        interactive=False
                    )
            # Displays the numpy array returned by generate_heatmap().
            output_heatmap = gr.Image(
                label="🔥 AI Detection Heatmap - Red areas influenced the decision most",
                height=400
            )
    # Connect the interface
    # predict(upload, url) returns (labels_dict, explanation, heatmap_image),
    # which map positionally onto these three outputs.
    submit_btn.click(
        fn=predict,
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap]
    )
    # Add examples
    # One URL-only example (no upload). With cache_examples=False nothing is
    # precomputed at startup; per the gradio Examples docs, clicking the
    # example only fills the inputs unless run_on_click=True is set.
    gr.Examples(
        examples=[
            [None, "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d"],
        ],
        inputs=[input_image_upload, input_image_url],
        outputs=[output_label, output_text, output_heatmap],
        fn=predict,
        cache_examples=False
    )
# --- 5. Launch the App ---
if __name__ == "__main__":
    # Listen on every network interface at port 7860 (Gradio's default port),
    # with no public share link and debug mode off.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=False)