File size: 5,220 Bytes
566351c
f234b07
 
 
 
 
 
 
780aa84
f234b07
be59fd9
f234b07
 
2c503bf
f234b07
780aa84
f234b07
 
780aa84
108fc0e
780aa84
2c503bf
f234b07
9d3e381
f234b07
780aa84
f234b07
780aa84
f234b07
 
108fc0e
f234b07
 
 
 
9d3e381
2c503bf
88f01bc
 
 
 
 
 
780aa84
88f01bc
108fc0e
 
 
 
 
 
 
f234b07
780aa84
 
f234b07
 
 
 
 
108fc0e
f234b07
 
 
88f01bc
 
f234b07
 
780aa84
 
 
f234b07
 
 
88f01bc
f234b07
2c503bf
9d3e381
 
f234b07
 
780aa84
be59fd9
780aa84
 
 
f234b07
780aa84
 
 
 
 
9d3e381
f234b07
780aa84
f234b07
780aa84
 
 
5b6b48c
780aa84
9d3e381
780aa84
f234b07
 
 
 
 
be59fd9
f234b07
780aa84
f234b07
88f01bc
 
 
 
 
 
 
780aa84
9d3e381
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing import image
from PIL import Image
import cv2
import traceback

# Load the trained classifier from disk. If loading fails for any reason the
# error is printed and a minimal stand-in network (224x224x3 input -> global
# average pooling -> single sigmoid unit) is built instead, so that `model`
# is always defined for the code below.
try:
    model = tf.keras.models.load_model("./model/deepfake_mobilenet_model.h5")
except Exception as e:
    print(f"Error loading model. Make sure the path is correct. Error: {e}")
    # Fallback dummy model for deployment debugging
    inputs = tf.keras.Input(shape=(224, 224, 3))
    pooled = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(pooled)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

# ==============================================================================
# --- Grad-CAM Heatmap Generation Functions (with safety check restored) ---
# ==============================================================================

def get_last_conv_layer_name(model):
    """Return the name of the last layer whose output tensor is rank 4.

    Layers are scanned from the end of ``model.layers`` towards the start and
    the first one whose ``output.shape`` has four dimensions wins.

    Raises:
        ValueError: If no layer in the model has a rank-4 output.
    """
    # A private sentinel distinguishes "nothing matched" from any real name.
    not_found = object()
    rank4_names = (
        candidate.name
        for candidate in reversed(model.layers)
        if len(candidate.output.shape) == 4
    )
    layer_name = next(rank4_names, not_found)
    if layer_name is not_found:
        raise ValueError("Could not find a conv layer in the model")
    return layer_name

def make_gradcam_heatmap(img_array, model, last_conv_layer_name):
    """Compute a Grad-CAM heatmap for a preprocessed image batch.

    A helper model is built that maps the input to both the activations of
    ``last_conv_layer_name`` and the final prediction. The prediction score is
    differentiated with respect to those activations, each feature-map channel
    is weighted by its mean gradient, and the weighted sum is rectified and
    scaled by its maximum.

    Args:
        img_array: Batched model input. Only the first batch item's
            activations are used for the heatmap.
        model: Keras model to explain.
        last_conv_layer_name: Name of the layer whose output is used as the
            feature map (expected to be rank 4).

    Returns:
        A 2-D numpy array with the spatial size of the chosen layer's output
        and values in [0, 1]. If no gradient can be computed, a float32 array
        of zeros with that spatial size is returned instead.
    """
    grad_model = tf.keras.models.Model(
        model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model([img_array], training=False)

        # preds[0] is the first batch row when preds is a single tensor;
        # presumably it is the first output tensor when the model returns a
        # list of outputs, in which case column 0 is the score to explain.
        preds_tensor = preds[0]
        if tf.rank(preds_tensor) > 1:
            class_channel = preds_tensor[:, 0]
        else:
            class_channel = preds_tensor

    grads = tape.gradient(class_channel, last_conv_layer_output)

    # tape.gradient yields None when the score is not connected to the chosen
    # activations on the tape; return a blank map rather than crashing later.
    if grads is None:
        print("Warning: Gradient is None. Cannot compute heatmap for this image. Returning a blank map.")
        h, w = last_conv_layer_output.shape[1:3]
        return np.zeros((h, w), dtype=np.float32)

    # Channel importance = gradient averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    activations = last_conv_layer_output[0]
    heatmap = activations @ pooled_grads[..., tf.newaxis]
    # Drop only the trailing singleton channel axis so 1xN or 1x1 feature maps
    # keep their (h, w) shape, matching the blank-map branch above.
    heatmap = tf.squeeze(heatmap, axis=-1)
    # Rectify first, then scale by the rectified maximum: the denominator is
    # then never negative, and an all-negative map becomes clean zeros.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)
    return heatmap.numpy()

def superimpose_gradcam(original_img_pil, heatmap):
    """Blend a Grad-CAM heatmap over the original image.

    Args:
        original_img_pil: PIL image to draw on. It is converted to RGB, so
            grayscale, palette and RGBA inputs are accepted.
        heatmap: 2-D array of importance values, expected in [0, 1]; values
            outside that range are clipped.

    Returns:
        A PIL RGB image of the original's size: 60% original pixels blended
        with 40% JET-coloured heatmap (red marks the highest values).
    """
    # cv2.addWeighted needs both inputs to have the same shape, and the colour
    # map produced below always has three channels, so force RGB here.
    original_img = np.array(original_img_pil.convert("RGB"))
    height, width = original_img.shape[:2]

    # cv2.resize takes the target size as (width, height).
    heatmap = cv2.resize(np.asarray(heatmap, dtype=np.float32), (width, height))
    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    heatmap = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # OpenCV returns the colour map in BGR order while the image array and PIL
    # use RGB; without this swap the hottest regions would render blue.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)
    return Image.fromarray(superimposed_img)

# ==============================================================================
# --- Main Prediction Function ---
# ==============================================================================
# Resolve the Grad-CAM target layer once at import time; predict_and_visualize
# reuses this name on every request instead of rescanning the model's layers.
last_conv_layer_name = get_last_conv_layer_name(model)

def predict_and_visualize(img):
    """Classify an uploaded image and build its Grad-CAM overlay.

    Args:
        img: PIL image supplied by the Gradio input, or ``None`` when nothing
            was uploaded.

    Returns:
        A ``(labels, overlay)`` tuple. ``labels`` maps "Real Image" to the
        model output and "Fake Image" to one minus it; ``overlay`` is the
        image with the heatmap blended in. Returns ``(None, None)`` when
        ``img`` is ``None``. If anything fails, the traceback is printed and
        ``({"Error: <message>": 0.0}, None)`` is returned so the message shows
        up in the label output.
    """
    try:
        if img is None:
            return None, None

        # Normalise the mode so grayscale, palette and RGBA uploads still give
        # the three colour channels used for both the model input and overlay.
        img = img.convert("RGB")

        # Resize to 224x224 and scale pixel values to [0, 1], then add the
        # batch axis expected by model.predict.
        img_resized = img.resize((224, 224))
        img_array = image.img_to_array(img_resized) / 255.0
        img_array_expanded = np.expand_dims(img_array, axis=0)

        prediction = model.predict(img_array_expanded, verbose=0)[0][0]
        # NOTE(review): treats the sigmoid output as P(real) - confirm this
        # matches the class indices used when the model was trained.
        real_conf = float(prediction)
        fake_conf = float(1 - prediction)
        labels = {"Real Image": real_conf, "Fake Image": fake_conf}

        heatmap = make_gradcam_heatmap(img_array_expanded, model, last_conv_layer_name)
        # The overlay is drawn on the full-size image, not the resized copy.
        superimposed_image = superimpose_gradcam(img, heatmap)

        return labels, superimposed_image

    except Exception as e:
        # Top-level UI boundary: log the full traceback and surface the
        # message in the label widget instead of crashing the request.
        print("--- GRADIO APP ERROR ---")
        traceback.print_exc()
        print("------------------------")
        error_msg = f"Error: {e}"
        return {error_msg: 0.0}, None

# ==============================================================================
# --- Gradio Interface ---
# ==============================================================================
# Build the web UI (one image in, a label widget and an overlay image out) and
# start serving it. The emoji in the labels are stored as real UTF-8
# characters rather than mis-decoded byte sequences.
gr.Interface(
    fn=predict_and_visualize,
    inputs=gr.Image(type="pil", label="📷 Upload a Face Image"),
    outputs=[
        gr.Label(num_top_classes=2, label="🧠 Model Prediction"),
        gr.Image(label="🔥 Grad-CAM Heatmap Overlay")
    ],
    title="✨ Deepfake Image Detector with Visual Explanation ✨",
    description="""
    **Detect whether an uploaded image is Real or AI-Generated (Deepfake).** The confidence bars show the model's certainty for both **Real** and **Fake** categories.  

    Below, the **Grad-CAM heatmap** highlights the regions the model focused on (red = most important).  

    ⚡ **Instructions:** 1. Upload a face image (JPEG/PNG).  
    2. Wait a few seconds for the prediction and heatmap.  
    3. Observe the confidence bars and heatmap for model explanation.
    """,
    theme="default"
).launch()