strelizi commited on
Commit
7eda223
Β·
verified Β·
1 Parent(s): d3e9a13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -64
app.py CHANGED
@@ -1,8 +1,3 @@
1
- """
2
- XAI Image Classifier - Optimized Production Version
3
- ===============================================================
4
- """
5
-
6
  import torch
7
  import torch.nn as nn
8
  from torchvision import models, transforms
@@ -20,22 +15,28 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  @torch.no_grad()
22
  def load_model_and_labels():
23
- """Load ResNet50 with optimized settings"""
24
- model = models.resnet50(weights='IMAGENET1K_V2')
25
- model.eval()
26
- model = model.to(DEVICE)
27
 
 
 
 
 
 
28
  url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
29
  response = urllib.request.urlopen(url)
30
  labels = [line.decode('utf-8').strip() for line in response.readlines()]
31
 
 
32
  return model, labels
33
 
34
  model, IMAGENET_LABELS = load_model_and_labels()
35
 
 
36
  target_layer = model.layer4[-1]
37
  gradcam = LayerGradCam(model, target_layer)
38
 
 
39
  transform = transforms.Compose([
40
  transforms.Resize((224, 224)),
41
  transforms.ToTensor(),
@@ -48,29 +49,29 @@ def predict_and_explain(image):
48
  return "Please upload an image", None, None
49
 
50
  try:
51
-
52
  img_tensor = transform(image).unsqueeze(0).to(DEVICE)
53
 
54
  with torch.no_grad():
 
55
  output = model(img_tensor)
56
-
57
- temperature = 1.0
58
- scaled_output = output / temperature
59
- probabilities = torch.softmax(scaled_output, dim=1)
60
- top10_prob, top10_idx = torch.topk(probabilities, 10)
61
 
 
62
  pred_class = top10_idx[0][0].item()
63
  confidence = top10_prob[0][0].item()
64
 
 
65
  attributions = gradcam.attribute(img_tensor, target=pred_class)
66
  attr_resized = interpolate(attributions, size=(224, 224), mode='bilinear', align_corners=False)
67
  attr_np = attr_resized.squeeze().cpu().detach().numpy()
68
  attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min() + 1e-8)
69
 
70
- fig = plt.figure(figsize=(20, 12))
 
71
  fig.patch.set_facecolor('#0a0a0a')
72
 
73
- gs = fig.add_gridspec(2, 3, height_ratios=[2, 1], hspace=0.25, wspace=0.12)
74
 
75
  ax1 = fig.add_subplot(gs[0, 0])
76
  ax2 = fig.add_subplot(gs[0, 1])
@@ -78,31 +79,31 @@ def predict_and_explain(image):
78
  ax4 = fig.add_subplot(gs[1, :])
79
 
80
  ax1.imshow(image)
81
- ax1.set_title("Original Image", fontsize=15, fontweight='600', color='#e0e0e0', pad=15)
82
  ax1.axis('off')
83
 
84
  im = ax2.imshow(attr_np, cmap='jet', interpolation='bilinear')
85
- ax2.set_title("Grad-CAM Heatmap", fontsize=15, fontweight='600', color='#e0e0e0', pad=15)
86
  ax2.axis('off')
87
  cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
88
- cbar.ax.tick_params(labelsize=10, colors='#a0a0a0')
89
- cbar.set_label('Importance', rotation=270, labelpad=20, color='#e0e0e0', fontsize=11, fontweight='600')
90
 
91
  ax3.imshow(image)
92
  ax3.imshow(attr_np, cmap='jet', alpha=0.5, interpolation='bilinear')
93
- ax3.set_title(f"AI Focus: {IMAGENET_LABELS[pred_class]}", fontsize=15, fontweight='600', color='#e0e0e0', pad=15)
94
  ax3.axis('off')
95
 
96
  top10_labels = [IMAGENET_LABELS[idx.item()] for idx in top10_idx[0]]
97
  top10_probs = [prob.item() * 100 for prob in top10_prob[0]]
98
 
99
  colors = ['#10b981' if i == 9 else '#3b82f6' if i >= 7 else '#8b5cf6' for i in range(10)]
100
- bars = ax4.barh(range(10), top10_probs[::-1], color=colors[::-1], edgecolor='#1a1a1a', linewidth=1.5)
101
 
102
  ax4.set_yticks(range(10))
103
- ax4.set_yticklabels(top10_labels[::-1], fontsize=12, color='#e0e0e0')
104
- ax4.set_xlabel('Confidence (%)', fontsize=13, color='#e0e0e0', fontweight='600')
105
- ax4.set_title('Top 10 Predictions', fontsize=16, fontweight='700', color='#e0e0e0', pad=15)
106
  ax4.set_xlim([0, 100])
107
  ax4.grid(axis='x', alpha=0.2, color='#404040', linestyle='--')
108
  ax4.set_facecolor('#0a0a0a')
@@ -110,48 +111,53 @@ def predict_and_explain(image):
110
  ax4.spines['right'].set_visible(False)
111
  ax4.spines['left'].set_color('#404040')
112
  ax4.spines['bottom'].set_color('#404040')
113
- ax4.tick_params(colors='#a0a0a0', labelsize=11)
114
 
115
  for bar, prob in zip(bars, top10_probs[::-1]):
116
  ax4.text(prob + 1.5, bar.get_y() + bar.get_height()/2,
117
- f'{prob:.1f}%', va='center', fontsize=11, color='#e0e0e0', fontweight='600')
118
 
119
  plt.tight_layout()
120
 
121
  buf = BytesIO()
122
- plt.savefig(buf, format='png', dpi=120, bbox_inches='tight', facecolor='#0a0a0a')
123
  buf.seek(0)
124
  result_image = Image.open(buf)
125
  plt.close(fig)
126
 
127
- fig2, axes = plt.subplots(1, 3, figsize=(18, 6))
 
128
  fig2.patch.set_facecolor('#0a0a0a')
129
 
130
- axes[0].imshow(image)
131
- axes[0].set_title("Original", fontsize=14, fontweight='600', color='#e0e0e0', pad=12)
132
- axes[0].axis('off')
 
 
 
 
 
133
 
134
- im2 = axes[1].imshow(attr_np, cmap='viridis', interpolation='gaussian')
135
- axes[1].set_title("High-Res Heatmap", fontsize=14, fontweight='600', color='#e0e0e0', pad=12)
136
- axes[1].axis('off')
137
- cbar2 = plt.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)
138
- cbar2.ax.tick_params(labelsize=10, colors='#a0a0a0')
139
 
140
- axes[2].imshow(image)
141
- axes[2].imshow(attr_np, cmap='hot', alpha=0.6, interpolation='bilinear')
142
- axes[2].contour(attr_np, levels=5, colors='white', linewidths=1.5, alpha=0.8)
143
- axes[2].set_title("Contour Analysis", fontsize=14, fontweight='600', color='#e0e0e0', pad=12)
144
- axes[2].axis('off')
145
 
146
  plt.tight_layout()
147
 
148
  buf2 = BytesIO()
149
- plt.savefig(buf2, format='png', dpi=110, bbox_inches='tight', facecolor='#0a0a0a')
150
  buf2.seek(0)
151
  detailed_heatmap = Image.open(buf2)
152
  plt.close(fig2)
153
 
154
-
155
  badge = "high" if confidence > 0.8 else "medium" if confidence > 0.5 else "low"
156
  badge_text = "High Confidence" if confidence > 0.8 else "Medium Confidence" if confidence > 0.5 else "Low Confidence"
157
  badge_icon = "🎯" if confidence > 0.8 else "⚑" if confidence > 0.5 else "⚠️"
@@ -176,6 +182,7 @@ def predict_and_explain(image):
176
  <div class="badge badge-{badge}">{badge_icon} {badge_text}</div>
177
  </div>
178
  <div class="conf-score">{confidence*100:.2f}%</div>
 
179
  <div class="divider"></div>
180
  {top5_html}
181
  </div>"""
@@ -187,31 +194,35 @@ def predict_and_explain(image):
187
 
188
 
189
  custom_css = """
190
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap');
191
  * { box-sizing: border-box; margin: 0; padding: 0; }
192
  body, .gradio-container { margin: 0 !important; padding: 0 !important; width: 100vw !important; min-height: 100vh !important; max-width: 100vw !important; background: linear-gradient(135deg, #0a0a0a 0%, #1a1a1a 50%, #0f0f0f 100%) !important; font-family: 'Inter', sans-serif !important; color: #e0e0e0 !important; overflow-x: hidden !important; }
193
  .gradio-container { padding: 0 !important; }
194
  .main-wrapper { padding: 1.5rem; max-width: 1920px; margin: 0 auto; position: relative; z-index: 2; }
195
- .hero-header { text-align: center; padding: 2rem 1rem 1.5rem; margin-bottom: 1.5rem; }
196
- .hero-header h1 { font-size: clamp(2rem, 5vw, 3.5rem); font-weight: 800; background: linear-gradient(135deg, #3b82f6 0%, #8b5cf6 50%, #3b82f6 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin: 0 0 0.5rem; letter-spacing: -1px; }
197
- .hero-header .subtitle { font-size: clamp(0.95rem, 2vw, 1.2rem); color: #808080; font-weight: 400; margin: 0; }
 
198
  .top-section { display: grid; grid-template-columns: 400px 1fr; gap: 1.25rem; margin-bottom: 1.25rem; }
199
  .upload-panel, .results-panel, .viz-section { background: rgba(20, 20, 20, 0.8); border: 1px solid rgba(255, 255, 255, 0.1); border-radius: 24px; padding: 1.5rem; backdrop-filter: blur(20px); box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4); }
200
- .section-label { font-size: 1.1rem; font-weight: 700; background: linear-gradient(135deg, #3b82f6, #8b5cf6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin: 0 0 1rem; text-align: center; letter-spacing: 0.5px; }
201
  #input-image { border: 2px dashed rgba(59, 130, 246, 0.4) !important; border-radius: 20px !important; background: rgba(10, 10, 10, 0.6) !important; height: 320px !important; transition: all 0.3s ease; }
202
  #input-image:hover { border-color: #3b82f6 !important; background: rgba(20, 20, 30, 0.8) !important; transform: scale(1.02); box-shadow: 0 0 30px rgba(59, 130, 246, 0.2); }
 
 
203
  .btn-row { display: flex; gap: 0.75rem; margin-top: 1rem; }
204
  .gr-button { border-radius: 14px !important; font-weight: 700 !important; height: 50px !important; font-size: 0.95rem !important; transition: all 0.3s ease !important; border: none !important; letter-spacing: 0.5px; text-transform: uppercase; }
205
  .gr-button-primary { background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important; color: white !important; box-shadow: 0 4px 20px rgba(59, 130, 246, 0.4) !important; }
206
  .gr-button-primary:hover { transform: translateY(-3px) !important; box-shadow: 0 8px 30px rgba(59, 130, 246, 0.6) !important; }
207
  .gr-button-secondary { background: rgba(40, 40, 40, 0.8) !important; color: #a0a0a0 !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; }
208
  .pred-header { display: flex; align-items: center; justify-content: space-between; flex-wrap: wrap; gap: 1rem; margin-bottom: 0.75rem; }
209
- .pred-label { font-size: clamp(1.5rem, 3vw, 2rem); font-weight: 800; color: #ffffff; margin: 0; letter-spacing: -0.5px; }
210
  .badge { padding: 0.5rem 1.25rem; border-radius: 50px; font-size: 0.875rem; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3); }
211
  .badge-high { background: linear-gradient(135deg, #10b981, #059669); color: white; }
212
  .badge-medium { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; }
213
  .badge-low { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; }
214
- .conf-score { font-size: clamp(2rem, 5vw, 3rem); font-weight: 900; background: linear-gradient(135deg, #3b82f6, #8b5cf6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom: 1.25rem; letter-spacing: -1px; }
 
215
  .divider { height: 2px; background: linear-gradient(90deg, transparent, rgba(59, 130, 246, 0.3), transparent); margin: 1.5rem 0; }
216
  .top5-grid { display: flex; flex-direction: column; gap: 0.875rem; }
217
  .top5-row { display: grid; grid-template-columns: 40px 1fr auto 80px; align-items: center; gap: 0.875rem; font-size: 0.95rem; padding: 0.5rem; border-radius: 12px; background: rgba(30, 30, 30, 0.5); transition: all 0.3s ease; }
@@ -221,25 +232,43 @@ body, .gradio-container { margin: 0 !important; padding: 0 !important; width: 10
221
  .bar-wrap { background: rgba(40, 40, 40, 0.8); height: 10px; border-radius: 5px; overflow: hidden; min-width: 100px; box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.3); }
222
  .bar { background: linear-gradient(90deg, #3b82f6, #8b5cf6); height: 100%; transition: width 1s ease; border-radius: 5px; box-shadow: 0 0 10px rgba(59, 130, 246, 0.5); }
223
  .pct { color: #3b82f6; font-weight: 700; font-size: 0.9rem; text-align: right; }
224
- #result-image, #detailed-heatmap { border-radius: 16px !important; overflow: hidden; width: 100%; height: auto; box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5); }
225
  .placeholder { text-align: center; padding: 4rem 1.5rem; color: #606060; font-size: 1.1rem; line-height: 1.6; }
226
  .placeholder strong { color: #3b82f6; }
227
  .error-msg { color: #ef4444; background: rgba(239, 68, 68, 0.1); padding: 1.5rem; border-radius: 16px; text-align: center; border: 1px solid rgba(239, 68, 68, 0.3); }
228
- .gr-accordion { background: rgba(20, 20, 20, 0.8) !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; border-radius: 20px !important; margin-top: 1.5rem; }
229
- .gr-accordion summary { color: #e0e0e0 !important; font-weight: 700 !important; padding: 1.25rem 1.5rem !important; font-size: 1.1rem !important; }
230
  footer, .footer { display: none !important; }
231
  ::-webkit-scrollbar { width: 10px; }
232
  ::-webkit-scrollbar-track { background: rgba(20, 20, 20, 0.5); }
233
  ::-webkit-scrollbar-thumb { background: rgba(59, 130, 246, 0.5); border-radius: 5px; }
234
- @media (max-width: 768px) { .top-section { grid-template-columns: 1fr; } #input-image { height: 240px !important; } .top5-row { grid-template-columns: 35px 1fr 70px; } .bar-wrap { grid-column: 1 / -1; margin-top: 0.375rem; } }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  """
236
 
237
 
238
- with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="Explainable AI") as demo:
239
  gr.HTML('<link rel="icon" href="https://res.cloudinary.com/ddn0xuwut/image/upload/v1761284764/encryption_hc0fxo.png" type="image/png">')
240
 
241
  with gr.Column(elem_classes="main-wrapper"):
242
- gr.HTML('<div class="hero-header"><h1>XAI Classifier</h1><p class="subtitle">See exactly what the AI sees – powered by ResNet50 + Grad-CAM</p></div>')
 
 
 
 
 
 
243
 
244
  with gr.Row(elem_classes="top-section"):
245
  with gr.Column(scale=0, min_width=400, elem_classes="upload-panel"):
@@ -250,17 +279,17 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="Explainable AI") a
250
  clear_btn = gr.ClearButton([input_image], value="πŸ—‘οΈ Clear", size="lg", scale=1)
251
 
252
  with gr.Column(scale=1, elem_classes="results-panel"):
253
- output_text = gr.HTML('<div class="placeholder"><strong>πŸ‘‹ Welcome!</strong><br><br>Upload an image and click <strong>Analyze</strong></div>')
254
 
255
  with gr.Column(elem_classes="viz-section"):
256
- gr.HTML("<div class='section-label'>🎯 Visual Explainability (Includes Bar Graph)</div>")
257
- output_image = gr.Image(label=None, type="pil", show_label=False, elem_id="result-image", height=700)
258
 
259
  with gr.Column(elem_classes="viz-section"):
260
- gr.HTML("<div class='section-label'>πŸ”¬ Advanced Heatmap Analysis</div>")
261
- detailed_heatmap = gr.Image(label=None, type="pil", show_label=False, elem_id="detailed-heatmap", height=500)
262
 
263
  predict_btn.click(fn=predict_and_explain, inputs=[input_image], outputs=[output_text, output_image, detailed_heatmap])
264
 
265
  if __name__ == "__main__":
266
- demo.launch(share=True, show_error=True)
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import models, transforms
 
15
 
16
  @torch.no_grad()
17
  def load_model_and_labels():
18
+ """Load ResNet152 model for maximum accuracy"""
19
+ print("πŸš€ Loading ResNet152 model...")
 
 
20
 
21
+ # ResNet152 (Best accuracy in ResNet family)
22
+ model = models.resnet152(weights='IMAGENET1K_V2')
23
+ model.eval().to(DEVICE)
24
+
25
+ # Load ImageNet labels
26
  url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
27
  response = urllib.request.urlopen(url)
28
  labels = [line.decode('utf-8').strip() for line in response.readlines()]
29
 
30
+ print("βœ… Model loaded successfully!")
31
  return model, labels
32
 
33
  model, IMAGENET_LABELS = load_model_and_labels()
34
 
35
+ # Setup Grad-CAM
36
  target_layer = model.layer4[-1]
37
  gradcam = LayerGradCam(model, target_layer)
38
 
39
+ # Transform for ResNet152
40
  transform = transforms.Compose([
41
  transforms.Resize((224, 224)),
42
  transforms.ToTensor(),
 
49
  return "Please upload an image", None, None
50
 
51
  try:
52
+ # Prepare input
53
  img_tensor = transform(image).unsqueeze(0).to(DEVICE)
54
 
55
  with torch.no_grad():
56
+ # ResNet152 prediction
57
  output = model(img_tensor)
58
+ probabilities = torch.softmax(output, dim=1)
 
 
 
 
59
 
60
+ top10_prob, top10_idx = torch.topk(probabilities, 10)
61
  pred_class = top10_idx[0][0].item()
62
  confidence = top10_prob[0][0].item()
63
 
64
+ # Generate Grad-CAM
65
  attributions = gradcam.attribute(img_tensor, target=pred_class)
66
  attr_resized = interpolate(attributions, size=(224, 224), mode='bilinear', align_corners=False)
67
  attr_np = attr_resized.squeeze().cpu().detach().numpy()
68
  attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min() + 1e-8)
69
 
70
+ # Main visualization
71
+ fig = plt.figure(figsize=(24, 14))
72
  fig.patch.set_facecolor('#0a0a0a')
73
 
74
+ gs = fig.add_gridspec(2, 3, height_ratios=[2, 1], hspace=0.3, wspace=0.15)
75
 
76
  ax1 = fig.add_subplot(gs[0, 0])
77
  ax2 = fig.add_subplot(gs[0, 1])
 
79
  ax4 = fig.add_subplot(gs[1, :])
80
 
81
  ax1.imshow(image)
82
+ ax1.set_title("Original Image", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
83
  ax1.axis('off')
84
 
85
  im = ax2.imshow(attr_np, cmap='jet', interpolation='bilinear')
86
+ ax2.set_title("Grad-CAM Heatmap", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
87
  ax2.axis('off')
88
  cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
89
+ cbar.ax.tick_params(labelsize=12, colors='#a0a0a0')
90
+ cbar.set_label('Importance', rotation=270, labelpad=25, color='#e0e0e0', fontsize=13, fontweight='600')
91
 
92
  ax3.imshow(image)
93
  ax3.imshow(attr_np, cmap='jet', alpha=0.5, interpolation='bilinear')
94
+ ax3.set_title(f"AI Focus: {IMAGENET_LABELS[pred_class]}", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
95
  ax3.axis('off')
96
 
97
  top10_labels = [IMAGENET_LABELS[idx.item()] for idx in top10_idx[0]]
98
  top10_probs = [prob.item() * 100 for prob in top10_prob[0]]
99
 
100
  colors = ['#10b981' if i == 9 else '#3b82f6' if i >= 7 else '#8b5cf6' for i in range(10)]
101
+ bars = ax4.barh(range(10), top10_probs[::-1], color=colors[::-1], edgecolor='#1a1a1a', linewidth=2)
102
 
103
  ax4.set_yticks(range(10))
104
+ ax4.set_yticklabels(top10_labels[::-1], fontsize=14, color='#e0e0e0', fontweight='600')
105
+ ax4.set_xlabel('Confidence (%)', fontsize=15, color='#e0e0e0', fontweight='700')
106
+ ax4.set_title('Top 10 Predictions', fontsize=19, fontweight='800', color='#e0e0e0', pad=20)
107
  ax4.set_xlim([0, 100])
108
  ax4.grid(axis='x', alpha=0.2, color='#404040', linestyle='--')
109
  ax4.set_facecolor('#0a0a0a')
 
111
  ax4.spines['right'].set_visible(False)
112
  ax4.spines['left'].set_color('#404040')
113
  ax4.spines['bottom'].set_color('#404040')
114
+ ax4.tick_params(colors='#a0a0a0', labelsize=13)
115
 
116
  for bar, prob in zip(bars, top10_probs[::-1]):
117
  ax4.text(prob + 1.5, bar.get_y() + bar.get_height()/2,
118
+ f'{prob:.1f}%', va='center', fontsize=13, color='#e0e0e0', fontweight='700')
119
 
120
  plt.tight_layout()
121
 
122
  buf = BytesIO()
123
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='#0a0a0a')
124
  buf.seek(0)
125
  result_image = Image.open(buf)
126
  plt.close(fig)
127
 
128
+ # Detailed heatmap analysis
129
+ fig2, axes = plt.subplots(2, 2, figsize=(20, 18))
130
  fig2.patch.set_facecolor('#0a0a0a')
131
 
132
+ axes[0, 0].imshow(image)
133
+ axes[0, 0].set_title("Original Image", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
134
+ axes[0, 0].axis('off')
135
+
136
+ axes[0, 1].imshow(image)
137
+ axes[0, 1].imshow(attr_np, cmap='jet', alpha=0.6, interpolation='bilinear')
138
+ axes[0, 1].set_title("Jet Colormap Overlay", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
139
+ axes[0, 1].axis('off')
140
 
141
+ axes[1, 0].imshow(image)
142
+ axes[1, 0].imshow(attr_np, cmap='hot', alpha=0.6, interpolation='bilinear')
143
+ axes[1, 0].set_title("Hot Colormap Overlay", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
144
+ axes[1, 0].axis('off')
 
145
 
146
+ axes[1, 1].imshow(image)
147
+ axes[1, 1].imshow(attr_np, cmap='viridis', alpha=0.6, interpolation='gaussian')
148
+ axes[1, 1].contour(attr_np, levels=6, colors='white', linewidths=2, alpha=0.9)
149
+ axes[1, 1].set_title("Viridis + Contours", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
150
+ axes[1, 1].axis('off')
151
 
152
  plt.tight_layout()
153
 
154
  buf2 = BytesIO()
155
+ plt.savefig(buf2, format='png', dpi=140, bbox_inches='tight', facecolor='#0a0a0a')
156
  buf2.seek(0)
157
  detailed_heatmap = Image.open(buf2)
158
  plt.close(fig2)
159
 
160
+ # Prediction card
161
  badge = "high" if confidence > 0.8 else "medium" if confidence > 0.5 else "low"
162
  badge_text = "High Confidence" if confidence > 0.8 else "Medium Confidence" if confidence > 0.5 else "Low Confidence"
163
  badge_icon = "🎯" if confidence > 0.8 else "⚑" if confidence > 0.5 else "⚠️"
 
182
  <div class="badge badge-{badge}">{badge_icon} {badge_text}</div>
183
  </div>
184
  <div class="conf-score">{confidence*100:.2f}%</div>
185
+ <div class="model-tag">πŸ”¬ ResNet152 Architecture (82.3% ImageNet Accuracy)</div>
186
  <div class="divider"></div>
187
  {top5_html}
188
  </div>"""
 
194
 
195
 
196
  custom_css = """
197
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800;900&display=swap');
198
  * { box-sizing: border-box; margin: 0; padding: 0; }
199
  body, .gradio-container { margin: 0 !important; padding: 0 !important; width: 100vw !important; min-height: 100vh !important; max-width: 100vw !important; background: linear-gradient(135deg, #0a0a0a 0%, #1a1a1a 50%, #0f0f0f 100%) !important; font-family: 'Inter', sans-serif !important; color: #e0e0e0 !important; overflow-x: hidden !important; }
200
  .gradio-container { padding: 0 !important; }
201
  .main-wrapper { padding: 1.5rem; max-width: 1920px; margin: 0 auto; position: relative; z-index: 2; }
202
+ .hero-header { text-align: center; padding: 2rem 1rem 1.5rem; margin-bottom: 1.5rem; position: relative; }
203
+ .hero-header h1 { font-size: clamp(2rem, 5vw, 3.5rem); font-weight: 900; background-color: #d8b4fe; -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin: 0 0 0.5rem; letter-spacing: -1px; }
204
+ .hero-header .subtitle { font-size: clamp(0.95rem, 2vw, 1.2rem); color: #808080; font-weight: 400; margin: 0 0 0.5rem; }
205
+ .hero-header .model-tag { display: inline-block; background: #93c5fd; border: 1px solid rgba(59, 130, 246, 0.3); color: #3b82f6; padding: 0.5rem 1.5rem; border-radius: 50px; font-size: 0.85rem; font-weight: 700; letter-spacing: 0.5px; margin-top: 0.5rem; }
206
  .top-section { display: grid; grid-template-columns: 400px 1fr; gap: 1.25rem; margin-bottom: 1.25rem; }
207
  .upload-panel, .results-panel, .viz-section { background: rgba(20, 20, 20, 0.8); border: 1px solid rgba(255, 255, 255, 0.1); border-radius: 24px; padding: 1.5rem; backdrop-filter: blur(20px); box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4); }
208
+ .section-label { font-size: 1.1rem; font-weight: 700; background: #93c5fd; -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin: 0 0 1rem; text-align: center; letter-spacing: 0.5px; }
209
  #input-image { border: 2px dashed rgba(59, 130, 246, 0.4) !important; border-radius: 20px !important; background: rgba(10, 10, 10, 0.6) !important; height: 320px !important; transition: all 0.3s ease; }
210
  #input-image:hover { border-color: #3b82f6 !important; background: rgba(20, 20, 30, 0.8) !important; transform: scale(1.02); box-shadow: 0 0 30px rgba(59, 130, 246, 0.2); }
211
+ #input-image .upload-text { border-radius: 0 !important; }
212
+ #input-image [data-testid="image"] { border-radius: 0 !important; }
213
  .btn-row { display: flex; gap: 0.75rem; margin-top: 1rem; }
214
  .gr-button { border-radius: 14px !important; font-weight: 700 !important; height: 50px !important; font-size: 0.95rem !important; transition: all 0.3s ease !important; border: none !important; letter-spacing: 0.5px; text-transform: uppercase; }
215
  .gr-button-primary { background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important; color: white !important; box-shadow: 0 4px 20px rgba(59, 130, 246, 0.4) !important; }
216
  .gr-button-primary:hover { transform: translateY(-3px) !important; box-shadow: 0 8px 30px rgba(59, 130, 246, 0.6) !important; }
217
  .gr-button-secondary { background: rgba(40, 40, 40, 0.8) !important; color: #a0a0a0 !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; }
218
  .pred-header { display: flex; align-items: center; justify-content: space-between; flex-wrap: wrap; gap: 1rem; margin-bottom: 0.75rem; }
219
+ .pred-label { font-size: clamp(1.5rem, 3vw, 2rem); font-weight: 900; color: #ffffff; margin: 0; letter-spacing: -0.5px; }
220
  .badge { padding: 0.5rem 1.25rem; border-radius: 50px; font-size: 0.875rem; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3); }
221
  .badge-high { background: linear-gradient(135deg, #10b981, #059669); color: white; }
222
  .badge-medium { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; }
223
  .badge-low { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; }
224
+ .conf-score { font-size: clamp(2rem, 5vw, 3rem); font-weight: 900; background: linear-gradient(135deg, #3b82f6, #8b5cf6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom: 1rem; letter-spacing: -1px; }
225
+ .model-tag { background: rgba(16, 185, 129, 0.15); border: 1px solid rgba(16, 185, 129, 0.3); color: #10b981; padding: 0.5rem 1rem; border-radius: 12px; font-size: 0.8rem; font-weight: 700; text-align: center; margin-bottom: 1rem; }
226
  .divider { height: 2px; background: linear-gradient(90deg, transparent, rgba(59, 130, 246, 0.3), transparent); margin: 1.5rem 0; }
227
  .top5-grid { display: flex; flex-direction: column; gap: 0.875rem; }
228
  .top5-row { display: grid; grid-template-columns: 40px 1fr auto 80px; align-items: center; gap: 0.875rem; font-size: 0.95rem; padding: 0.5rem; border-radius: 12px; background: rgba(30, 30, 30, 0.5); transition: all 0.3s ease; }
 
232
  .bar-wrap { background: rgba(40, 40, 40, 0.8); height: 10px; border-radius: 5px; overflow: hidden; min-width: 100px; box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.3); }
233
  .bar { background: linear-gradient(90deg, #3b82f6, #8b5cf6); height: 100%; transition: width 1s ease; border-radius: 5px; box-shadow: 0 0 10px rgba(59, 130, 246, 0.5); }
234
  .pct { color: #3b82f6; font-weight: 700; font-size: 0.9rem; text-align: right; }
235
+ #result-image, #detailed-heatmap { border-radius: 16px !important; overflow: hidden; width: 100% !important; height: auto !important; min-height: 500px !important; box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5); object-fit: contain !important; }
236
  .placeholder { text-align: center; padding: 4rem 1.5rem; color: #606060; font-size: 1.1rem; line-height: 1.6; }
237
  .placeholder strong { color: #3b82f6; }
238
  .error-msg { color: #ef4444; background: rgba(239, 68, 68, 0.1); padding: 1.5rem; border-radius: 16px; text-align: center; border: 1px solid rgba(239, 68, 68, 0.3); }
 
 
239
  footer, .footer { display: none !important; }
240
  ::-webkit-scrollbar { width: 10px; }
241
  ::-webkit-scrollbar-track { background: rgba(20, 20, 20, 0.5); }
242
  ::-webkit-scrollbar-thumb { background: rgba(59, 130, 246, 0.5); border-radius: 5px; }
243
+ @media (max-width: 768px) {
244
+ .top-section { grid-template-columns: 1fr; }
245
+ #input-image { height: 240px !important; }
246
+ .top5-row { grid-template-columns: 35px 1fr 70px; }
247
+ .bar-wrap { grid-column: 1 / -1; margin-top: 0.375rem; }
248
+ #result-image { min-height: 600px !important; max-height: none !important; }
249
+ #detailed-heatmap { min-height: 450px !important; max-height: none !important; }
250
+ .viz-section { padding: 1rem; }
251
+ .section-label { font-size: 1rem; }
252
+ }
253
+ @media (max-width: 480px) {
254
+ .main-wrapper { padding: 1rem; }
255
+ #result-image { min-height: 550px !important; }
256
+ #detailed-heatmap { min-height: 400px !important; }
257
+ }
258
  """
259
 
260
 
261
+ with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="XAI Image Classifier") as demo:
262
  gr.HTML('<link rel="icon" href="https://res.cloudinary.com/ddn0xuwut/image/upload/v1761284764/encryption_hc0fxo.png" type="image/png">')
263
 
264
  with gr.Column(elem_classes="main-wrapper"):
265
+ gr.HTML('''
266
+ <div class="hero-header">
267
+ <h1>XAI Image Classifier</h1>
268
+ <p class="subtitle">ResNet152 with Grad-CAM Explainability</p>
269
+ <div class="model-tag">⚑ Maximum Accuracy Production Version</div>
270
+ </div>
271
+ ''')
272
 
273
  with gr.Row(elem_classes="top-section"):
274
  with gr.Column(scale=0, min_width=400, elem_classes="upload-panel"):
 
279
  clear_btn = gr.ClearButton([input_image], value="πŸ—‘οΈ Clear", size="lg", scale=1)
280
 
281
  with gr.Column(scale=1, elem_classes="results-panel"):
282
+ output_text = gr.HTML('<div class="placeholder"><strong>πŸ‘‹ Welcome to XAI Classifier!</strong><br><br>This classifier uses ResNet152:<br>β€’ 82.3% ImageNet Top-1 Accuracy<br>β€’ Grad-CAM Visual Explainability<br>β€’ 1000 Object Categories<br><br>Upload an image to see the magic! ✨</div>')
283
 
284
  with gr.Column(elem_classes="viz-section"):
285
+ gr.HTML("<div class='section-label'>🎯 Visual Explainability Analysis</div>")
286
+ output_image = gr.Image(label=None, type="pil", show_label=False, elem_id="result-image", container=False)
287
 
288
  with gr.Column(elem_classes="viz-section"):
289
+ gr.HTML("<div class='section-label'>πŸ”¬ Detailed Heatmap Comparison</div>")
290
+ detailed_heatmap = gr.Image(label=None, type="pil", show_label=False, elem_id="detailed-heatmap", container=False)
291
 
292
  predict_btn.click(fn=predict_and_explain, inputs=[input_image], outputs=[output_text, output_image, detailed_heatmap])
293
 
294
  if __name__ == "__main__":
295
+ demo.launch(share=False, show_error=True)