# Author: Satwickchikkala1
# Last change: "Update app.py" (commit 108fc0e, verified)
import gradio as gr
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing import image
from PIL import Image
import cv2
import traceback
def _build_fallback_model():
    """Build an untrained stand-in model with a (224, 224, 3) input.

    Only used when the real weights cannot be loaded, so the interface can
    still start for deployment debugging. Its weights are freshly initialised,
    so its predictions carry no meaning.
    """
    image_input = tf.keras.Input(shape=(224, 224, 3))
    pooled = tf.keras.layers.GlobalAveragePooling2D()(image_input)
    score = tf.keras.layers.Dense(1, activation="sigmoid")(pooled)
    return tf.keras.Model(image_input, score)


# Load the trained model; on any failure, log the error and fall back to the
# dummy model above so the app still launches.
try:
    model = tf.keras.models.load_model("./model/deepfake_mobilenet_model.h5")
except Exception as load_error:
    print(f"Error loading model. Make sure the path is correct. Error: {load_error}")
    model = _build_fallback_model()
# ==============================================================================
# --- Grad-CAM heatmap generation (falls back to a blank map when no gradient is available) ---
# ==============================================================================
def get_last_conv_layer_name(model):
    """Return the name of the layer nearest the output whose output has rank 4.

    Layers are scanned from the end of ``model.layers`` backwards and the first
    one producing a rank-4 tensor wins. Any such layer qualifies (not only
    convolutions); presumably this is the backbone's final spatial feature map,
    which Grad-CAM uses as its target activation.

    Raises:
        ValueError: if no layer in the model produces a rank-4 output.
    """
    for candidate in model.layers[::-1]:
        if len(candidate.output.shape) == 4:
            break
    else:
        raise ValueError("Could not find a conv layer in the model")
    return candidate.name
def make_gradcam_heatmap(img_array, model, last_conv_layer_name):
    """Compute a Grad-CAM heatmap for the first output unit of ``model``.

    Args:
        img_array: Preprocessed input batch. Only the first image is explained
            (presumably shape (1, 224, 224, 3) here -- confirm against the
            model's input).
        model: Keras model to explain.
        last_conv_layer_name: Name of the layer whose activations are weighted
            by the pooled gradients.

    Returns:
        A 2-D float32 NumPy array with the spatial size of the target layer,
        scaled to [0, 1]. An all-zero map is returned when no gradient flows
        from the score to the target layer.
    """
    grad_model = tf.keras.models.Model(
        model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_output, preds = grad_model([img_array], training=False)
        # Explain the first output unit of the first image -- the same [0][0]
        # score the app reports. Differentiating a single unit (rather than the
        # sum over all units) keeps softmax heads from yielding zero gradients.
        # Reshaping makes this independent of the output's rank.
        class_score = tf.reshape(preds[0], [-1])[0]

    grads = tape.gradient(class_score, conv_output)

    # tape.gradient returns None when there is no differentiable path from the
    # score to the target layer; return a blank map instead of crashing later.
    if grads is None:
        print("Warning: Gradient is None. Cannot compute heatmap for this image. Returning a blank map.")
        height, width = conv_output.shape[1:3]
        return np.zeros((height, width), dtype=np.float32)

    # Per-channel importance weights: gradients averaged over batch and space.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    feature_map = conv_output[0]

    # Channel-weighted sum of activations. Squeeze only the trailing unit axis
    # so a map with a spatial dimension of size 1 stays 2-D.
    heatmap = tf.squeeze(feature_map @ pooled_grads[..., tf.newaxis], axis=-1)

    # ReLU first, then normalize by the max of the clipped map: the denominator
    # is then never negative, and an all-non-positive map becomes all zeros
    # instead of NaN.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)
    return heatmap.numpy()
def superimpose_gradcam(original_img_pil, heatmap):
    """Blend a Grad-CAM heatmap over an image using the JET colormap.

    Args:
        original_img_pil: Image to annotate. PIL images are converted to RGB so
            that grayscale or RGBA inputs blend with the 3-channel colormap.
        heatmap: 2-D array of importance scores, expected in [0, 1]. It is
            resized to the image, NaNs are treated as 0, and values are clipped
            to [0, 1] before colour mapping.

    Returns:
        A new PIL image of the same size: 60% original, 40% colormapped heatmap
        (red = high importance, blue = low).
    """
    if isinstance(original_img_pil, Image.Image):
        original_img_pil = original_img_pil.convert("RGB")
    original_img = np.array(original_img_pil)
    height, width = original_img.shape[:2]

    heatmap = np.nan_to_num(np.asarray(heatmap, dtype=np.float32))
    heatmap = cv2.resize(heatmap, (width, height))  # cv2 expects (width, height)
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; the image array is RGB. Without this
    # conversion the most important regions would render blue instead of red.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)
    return Image.fromarray(superimposed_img)
# ==============================================================================
# --- Main Prediction Function ---
# ==============================================================================
# Resolve the Grad-CAM target layer once at start-up rather than per request.
last_conv_layer_name = get_last_conv_layer_name(model)


def predict_and_visualize(img):
    """Classify an uploaded image and build its Grad-CAM overlay.

    Args:
        img: PIL image from the Gradio input, or None when nothing is uploaded.

    Returns:
        A ``(labels, overlay)`` pair for the two Gradio outputs:
        - ``labels`` maps "Real Image" / "Fake Image" to confidences summing
          to 1. The model's first output is reported as the "Real" score.
        - ``overlay`` is the heatmap blended onto the image.
        Returns ``(None, None)`` when ``img`` is None. On any error, prints the
        traceback and returns ``({"Error: <message>": 0.0}, None)``.
    """
    if img is None:
        return None, None

    try:
        # Force 3 channels: RGBA or grayscale uploads would otherwise feed the
        # model a 4- or 1-channel tensor and break the overlay blend.
        rgb_img = img.convert("RGB")
        img_array = image.img_to_array(rgb_img.resize((224, 224))) / 255.0
        img_array_expanded = np.expand_dims(img_array, axis=0)

        # Calling the model directly is the Keras-recommended path for a single
        # small batch; model.predict adds per-call batching overhead.
        prediction = float(np.asarray(model(img_array_expanded, training=False))[0][0])
        labels = {"Real Image": prediction, "Fake Image": 1.0 - prediction}

        heatmap = make_gradcam_heatmap(img_array_expanded, model, last_conv_layer_name)
        superimposed_image = superimpose_gradcam(rgb_img, heatmap)
        return labels, superimposed_image
    except Exception as e:
        # Top-level UI boundary: log the full traceback server-side and surface
        # the message in the label output instead of crashing the request.
        print("--- GRADIO APP ERROR ---")
        traceback.print_exc()
        print("------------------------")
        error_msg = f"Error: {e}"
        return {error_msg: 0.0}, None
# ==============================================================================
# --- Gradio Interface ---
# ==============================================================================
# Build the web UI: one image input; a two-class confidence label and the
# Grad-CAM overlay as outputs. The interface is bound to `demo` (Gradio's
# conventional name) and launched when this script runs.
demo = gr.Interface(
    fn=predict_and_visualize,
    inputs=gr.Image(type="pil", label="📷 Upload a Face Image"),
    outputs=[
        gr.Label(num_top_classes=2, label="🧠 Model Prediction"),
        gr.Image(label="🔥 Grad-CAM Heatmap Overlay"),
    ],
    title="✨ Deepfake Image Detector with Visual Explanation ✨",
    # Blank lines separate paragraphs, and the numbered list gets its own
    # lines, so the Markdown renders as intended rather than as one paragraph.
    description="""
**Detect whether an uploaded image is Real or AI-Generated (Deepfake).** The confidence bars show the model's certainty for both **Real** and **Fake** categories.

Below, the **Grad-CAM heatmap** highlights the regions the model focused on (red = most important).

⚡ **Instructions:**

1. Upload a face image (JPEG/PNG).
2. Wait a few seconds for the prediction and heatmap.
3. Observe the confidence bars and heatmap for model explanation.
""",
    theme="default",
)
demo.launch()