Satwickchikkala1 commited on
Commit
780aa84
·
verified ·
1 Parent(s): 5b6b48c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -66
app.py CHANGED
@@ -6,45 +6,36 @@ from PIL import Image
6
  import cv2
7
  import traceback
8
 
9
- # =========================
10
- # --- Load the model ---
11
- # =========================
12
  try:
13
  model = tf.keras.models.load_model("./model/deepfake_mobilenet_model.h5")
14
  except Exception as e:
15
  print(f"Error loading model. Make sure the path is correct. Error: {e}")
16
- # Fallback dummy model
17
  inputs = tf.keras.Input(shape=(224, 224, 3))
18
- x = tf.keras.layers.GlobalAveragePooling2D()(inputs)
19
- outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
20
  model = tf.keras.Model(inputs, outputs)
21
 
22
- # =========================
23
- # --- Grad-CAM helpers ---
24
- # =========================
25
 
26
  def get_last_conv_layer_name(model):
27
  for layer in reversed(model.layers):
28
- if hasattr(layer.output, "shape") and len(layer.output.shape) == 4:
29
  return layer.name
30
- raise ValueError("No convolutional layer found in model.")
31
 
32
  def make_gradcam_heatmap(img_array, model, last_conv_layer_name):
33
  grad_model = tf.keras.models.Model(
34
  model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
35
  )
36
-
37
  with tf.GradientTape() as tape:
38
- last_conv_output, preds = grad_model(img_array)
39
-
40
- # Robustly handle list/tensor output
41
- preds = preds[0] if isinstance(preds, (list, tuple)) else preds
42
- class_channel = preds[:, 0] if len(preds.shape) > 1 else preds
43
-
44
- grads = tape.gradient(class_channel, last_conv_output)
45
  pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
46
- last_conv_output = last_conv_output[0]
47
- heatmap = last_conv_output @ pooled_grads[..., tf.newaxis]
48
  heatmap = tf.squeeze(heatmap)
49
  heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-8)
50
  return heatmap.numpy()
@@ -57,54 +48,38 @@ def superimpose_gradcam(original_img_pil, heatmap):
57
  superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)
58
  return Image.fromarray(superimposed_img)
59
 
60
- # =========================
61
- # --- Main prediction ---
62
- # =========================
63
 
64
  last_conv_layer_name = get_last_conv_layer_name(model)
65
 
66
  def predict_and_visualize(img):
67
  try:
68
- # --- Preprocess Image ---
69
  img_resized = img.resize((224, 224))
70
  img_array = image.img_to_array(img_resized) / 255.0
71
- img_array_exp = np.expand_dims(img_array, axis=0)
72
-
73
- # --- Model Prediction ---
74
- prediction = model.predict(img_array_exp, verbose=0)
75
 
76
- # Robust scalar extraction
77
- if isinstance(prediction, (list, tuple)):
78
- prediction = np.array(prediction[0])
79
- if isinstance(prediction, np.ndarray):
80
- prediction = prediction.item()
81
- prediction = float(prediction)
82
-
83
- # Confidence bars
84
- real_conf = max(0.0, min(1.0, prediction)) # clamp between 0 and 1
85
- fake_conf = 1.0 - real_conf
86
  labels = {"Real Image": real_conf, "Fake Image": fake_conf}
87
-
88
- # --- Grad-CAM (optional, fail-safe) ---
89
- try:
90
- heatmap = make_gradcam_heatmap(img_array_exp, model, last_conv_layer_name)
91
- superimposed_img = superimpose_gradcam(img, heatmap)
92
- except Exception as e:
93
- print("Grad-CAM failed:", e)
94
- superimposed_img = img # fallback to original image
95
-
96
- return labels, superimposed_img
97
 
98
  except Exception as e:
99
- print("--- PREDICTION ERROR ---")
100
  traceback.print_exc()
101
- # Always return a valid dict for gr.Label
102
- return {"Real Image": 0.0, "Fake Image": 0.0}, img
 
103
 
104
-
105
- # =========================
106
- # --- Gradio Interface ---
107
- # =========================
108
 
109
  gr.Interface(
110
  fn=predict_and_visualize,
@@ -113,16 +88,17 @@ gr.Interface(
113
  gr.Label(num_top_classes=2, label="🧠 Model Prediction"),
114
  gr.Image(label="🔥 Grad-CAM Heatmap Overlay")
115
  ],
116
- title="✨ Deepfake Image Detector with Grad-CAM ✨",
117
  description="""
118
- **Detect Real vs AI-Generated (Deepfake) Images.**
119
- The confidence bars show the model's certainty for **Real** and **Fake**.
120
- The Grad-CAM heatmap highlights areas most important to the model (red = most important).
 
121
 
122
- ⚡ **Instructions:**
123
- 1. Upload a face image (JPEG/PNG).
124
- 2. Wait a few seconds for prediction & heatmap.
125
- 3. Observe confidence bars and Grad-CAM overlay.
126
- """,
127
- theme="default"
128
  ).launch()
 
6
  import cv2
7
  import traceback
8
 
9
+ # Load the trained model
 
 
10
  try:
11
  model = tf.keras.models.load_model("./model/deepfake_mobilenet_model.h5")
12
  except Exception as e:
13
  print(f"Error loading model. Make sure the path is correct. Error: {e}")
 
14
  inputs = tf.keras.Input(shape=(224, 224, 3))
15
+ outputs = tf.keras.layers.Dense(1, activation="sigmoid")(tf.keras.layers.GlobalAveragePooling2D()(inputs))
 
16
  model = tf.keras.Model(inputs, outputs)
17
 
18
+ # ==============================================================================
19
+ # --- Grad-CAM Heatmap Generation Functions ---
20
+ # ==============================================================================
21
 
22
  def get_last_conv_layer_name(model):
23
  for layer in reversed(model.layers):
24
+ if len(layer.output.shape) == 4:
25
  return layer.name
26
+ raise ValueError("Could not find a conv layer in the model")
27
 
28
  def make_gradcam_heatmap(img_array, model, last_conv_layer_name):
29
  grad_model = tf.keras.models.Model(
30
  model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
31
  )
 
32
  with tf.GradientTape() as tape:
33
+ last_conv_layer_output, preds = grad_model([img_array])
34
+ class_channel = preds[0][:, 0]
35
+ grads = tape.gradient(class_channel, last_conv_layer_output)
 
 
 
 
36
  pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
37
+ last_conv_layer_output = last_conv_layer_output[0]
38
+ heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
39
  heatmap = tf.squeeze(heatmap)
40
  heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-8)
41
  return heatmap.numpy()
 
48
  superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)
49
  return Image.fromarray(superimposed_img)
50
 
51
+ # ==============================================================================
52
+ # --- Main Prediction Function ---
53
+ # ==============================================================================
54
 
55
  last_conv_layer_name = get_last_conv_layer_name(model)
56
 
57
  def predict_and_visualize(img):
58
  try:
 
59
  img_resized = img.resize((224, 224))
60
  img_array = image.img_to_array(img_resized) / 255.0
61
+ img_array_expanded = np.expand_dims(img_array, axis=0)
 
 
 
62
 
63
+ prediction = model.predict(img_array_expanded, verbose=0)[0][0]
64
+ real_conf = float(prediction)
65
+ fake_conf = float(1 - prediction)
 
 
 
 
 
 
 
66
  labels = {"Real Image": real_conf, "Fake Image": fake_conf}
67
+
68
+ heatmap = make_gradcam_heatmap(img_array_expanded, model, last_conv_layer_name)
69
+ superimposed_image = superimpose_gradcam(img, heatmap)
70
+
71
+ return labels, superimposed_image
 
 
 
 
 
72
 
73
  except Exception as e:
74
+ print("--- GRADIO APP ERROR ---")
75
  traceback.print_exc()
76
+ print("------------------------")
77
+ error_msg = f"Error: {e}"
78
+ return {error_msg: 0.0}, None
79
 
80
+ # ==============================================================================
81
+ # --- Gradio Interface with Improved Design ---
82
+ # ==============================================================================
 
83
 
84
  gr.Interface(
85
  fn=predict_and_visualize,
 
88
  gr.Label(num_top_classes=2, label="🧠 Model Prediction"),
89
  gr.Image(label="🔥 Grad-CAM Heatmap Overlay")
90
  ],
91
+ title="✨ Deepfake Image Detector with Visual Explanation ✨",
92
  description="""
93
+ **Detect whether an uploaded image is Real or AI-Generated (Deepfake).**
94
+ The confidence bars show the model's certainty for both **Real** and **Fake** categories.
95
+
96
+ Below, the **Grad-CAM heatmap** highlights the regions the model focused on (red = most important).
97
 
98
+ ⚡ **Instructions:**
99
+ 1. Upload a face image (JPEG/PNG).
100
+ 2. Wait a few seconds for the prediction and heatmap.
101
+ 3. Observe the confidence bars and heatmap for model explanation.
102
+ """,
103
+ theme="default" # you can later try 'soft', 'grass', or 'peach'
104
  ).launch()