Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -1,8 +1,3 @@ 
     | 
|
| 1 | 
         
            -
            """
         
     | 
| 2 | 
         
            -
             XAI Image Classifier - Optimized Production Version
         
     | 
| 3 | 
         
            -
            ===============================================================
         
     | 
| 4 | 
         
            -
            """
         
     | 
| 5 | 
         
            -
             
     | 
| 6 | 
         
             
            import torch
         
     | 
| 7 | 
         
             
            import torch.nn as nn
         
     | 
| 8 | 
         
             
            from torchvision import models, transforms
         
     | 
| 
         @@ -20,22 +15,28 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
     | 
|
| 20 | 
         | 
| 21 | 
         
             
            @torch.no_grad()
         
     | 
| 22 | 
         
             
            def load_model_and_labels():
         
     | 
| 23 | 
         
            -
                """Load  
     | 
| 24 | 
         
            -
                 
     | 
| 25 | 
         
            -
                model.eval()
         
     | 
| 26 | 
         
            -
                model = model.to(DEVICE)
         
     | 
| 27 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 28 | 
         
             
                url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
         
     | 
| 29 | 
         
             
                response = urllib.request.urlopen(url)
         
     | 
| 30 | 
         
             
                labels = [line.decode('utf-8').strip() for line in response.readlines()]
         
     | 
| 31 | 
         | 
| 
         | 
|
| 32 | 
         
             
                return model, labels
         
     | 
| 33 | 
         | 
| 34 | 
         
             
            model, IMAGENET_LABELS = load_model_and_labels()
         
     | 
| 35 | 
         | 
| 
         | 
|
| 36 | 
         
             
            target_layer = model.layer4[-1]
         
     | 
| 37 | 
         
             
            gradcam = LayerGradCam(model, target_layer)
         
     | 
| 38 | 
         | 
| 
         | 
|
| 39 | 
         
             
            transform = transforms.Compose([
         
     | 
| 40 | 
         
             
                transforms.Resize((224, 224)),
         
     | 
| 41 | 
         
             
                transforms.ToTensor(),
         
     | 
| 
         @@ -48,29 +49,29 @@ def predict_and_explain(image): 
     | 
|
| 48 | 
         
             
                    return "Please upload an image", None, None
         
     | 
| 49 | 
         | 
| 50 | 
         
             
                try:
         
     | 
| 51 | 
         
            -
                    
         
     | 
| 52 | 
         
             
                    img_tensor = transform(image).unsqueeze(0).to(DEVICE)
         
     | 
| 53 | 
         | 
| 54 | 
         
             
                    with torch.no_grad():
         
     | 
| 
         | 
|
| 55 | 
         
             
                        output = model(img_tensor)
         
     | 
| 56 | 
         
            -
                        
         
     | 
| 57 | 
         
            -
                        temperature = 1.0  
         
     | 
| 58 | 
         
            -
                        scaled_output = output / temperature
         
     | 
| 59 | 
         
            -
                        probabilities = torch.softmax(scaled_output, dim=1)
         
     | 
| 60 | 
         
            -
                        top10_prob, top10_idx = torch.topk(probabilities, 10)
         
     | 
| 61 | 
         | 
| 
         | 
|
| 62 | 
         
             
                    pred_class = top10_idx[0][0].item()
         
     | 
| 63 | 
         
             
                    confidence = top10_prob[0][0].item()
         
     | 
| 64 | 
         | 
| 
         | 
|
| 65 | 
         
             
                    attributions = gradcam.attribute(img_tensor, target=pred_class)
         
     | 
| 66 | 
         
             
                    attr_resized = interpolate(attributions, size=(224, 224), mode='bilinear', align_corners=False)
         
     | 
| 67 | 
         
             
                    attr_np = attr_resized.squeeze().cpu().detach().numpy()
         
     | 
| 68 | 
         
             
                    attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min() + 1e-8)
         
     | 
| 69 | 
         | 
| 70 | 
         
            -
                     
     | 
| 
         | 
|
| 71 | 
         
             
                    fig.patch.set_facecolor('#0a0a0a')
         
     | 
| 72 | 
         | 
| 73 | 
         
            -
                    gs = fig.add_gridspec(2, 3, height_ratios=[2, 1], hspace=0. 
     | 
| 74 | 
         | 
| 75 | 
         
             
                    ax1 = fig.add_subplot(gs[0, 0])
         
     | 
| 76 | 
         
             
                    ax2 = fig.add_subplot(gs[0, 1])
         
     | 
| 
         @@ -78,31 +79,31 @@ def predict_and_explain(image): 
     | 
|
| 78 | 
         
             
                    ax4 = fig.add_subplot(gs[1, :])
         
     | 
| 79 | 
         | 
| 80 | 
         
             
                    ax1.imshow(image)
         
     | 
| 81 | 
         
            -
                    ax1.set_title("Original Image", fontsize= 
     | 
| 82 | 
         
             
                    ax1.axis('off')
         
     | 
| 83 | 
         | 
| 84 | 
         
             
                    im = ax2.imshow(attr_np, cmap='jet', interpolation='bilinear')
         
     | 
| 85 | 
         
            -
                    ax2.set_title("Grad-CAM Heatmap", fontsize= 
     | 
| 86 | 
         
             
                    ax2.axis('off')
         
     | 
| 87 | 
         
             
                    cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
         
     | 
| 88 | 
         
            -
                    cbar.ax.tick_params(labelsize= 
     | 
| 89 | 
         
            -
                    cbar.set_label('Importance', rotation=270, labelpad= 
     | 
| 90 | 
         | 
| 91 | 
         
             
                    ax3.imshow(image)
         
     | 
| 92 | 
         
             
                    ax3.imshow(attr_np, cmap='jet', alpha=0.5, interpolation='bilinear')
         
     | 
| 93 | 
         
            -
                    ax3.set_title(f"AI Focus: {IMAGENET_LABELS[pred_class]}", fontsize= 
     | 
| 94 | 
         
             
                    ax3.axis('off')
         
     | 
| 95 | 
         | 
| 96 | 
         
             
                    top10_labels = [IMAGENET_LABELS[idx.item()] for idx in top10_idx[0]]
         
     | 
| 97 | 
         
             
                    top10_probs = [prob.item() * 100 for prob in top10_prob[0]]
         
     | 
| 98 | 
         | 
| 99 | 
         
             
                    colors = ['#10b981' if i == 9 else '#3b82f6' if i >= 7 else '#8b5cf6' for i in range(10)]
         
     | 
| 100 | 
         
            -
                    bars = ax4.barh(range(10), top10_probs[::-1], color=colors[::-1], edgecolor='#1a1a1a', linewidth= 
     | 
| 101 | 
         | 
| 102 | 
         
             
                    ax4.set_yticks(range(10))
         
     | 
| 103 | 
         
            -
                    ax4.set_yticklabels(top10_labels[::-1], fontsize= 
     | 
| 104 | 
         
            -
                    ax4.set_xlabel('Confidence (%)', fontsize= 
     | 
| 105 | 
         
            -
                    ax4.set_title('Top 10 Predictions', fontsize= 
     | 
| 106 | 
         
             
                    ax4.set_xlim([0, 100])
         
     | 
| 107 | 
         
             
                    ax4.grid(axis='x', alpha=0.2, color='#404040', linestyle='--')
         
     | 
| 108 | 
         
             
                    ax4.set_facecolor('#0a0a0a')
         
     | 
| 
         @@ -110,48 +111,53 @@ def predict_and_explain(image): 
     | 
|
| 110 | 
         
             
                    ax4.spines['right'].set_visible(False)
         
     | 
| 111 | 
         
             
                    ax4.spines['left'].set_color('#404040')
         
     | 
| 112 | 
         
             
                    ax4.spines['bottom'].set_color('#404040')
         
     | 
| 113 | 
         
            -
                    ax4.tick_params(colors='#a0a0a0', labelsize= 
     | 
| 114 | 
         | 
| 115 | 
         
             
                    for bar, prob in zip(bars, top10_probs[::-1]):
         
     | 
| 116 | 
         
             
                        ax4.text(prob + 1.5, bar.get_y() + bar.get_height()/2, 
         
     | 
| 117 | 
         
            -
                                f'{prob:.1f}%', va='center', fontsize= 
     | 
| 118 | 
         | 
| 119 | 
         
             
                    plt.tight_layout()
         
     | 
| 120 | 
         | 
| 121 | 
         
             
                    buf = BytesIO()
         
     | 
| 122 | 
         
            -
                    plt.savefig(buf, format='png', dpi= 
     | 
| 123 | 
         
             
                    buf.seek(0)
         
     | 
| 124 | 
         
             
                    result_image = Image.open(buf)
         
     | 
| 125 | 
         
             
                    plt.close(fig)
         
     | 
| 126 | 
         | 
| 127 | 
         
            -
                     
     | 
| 
         | 
|
| 128 | 
         
             
                    fig2.patch.set_facecolor('#0a0a0a')
         
     | 
| 129 | 
         | 
| 130 | 
         
            -
                    axes[0].imshow(image)
         
     | 
| 131 | 
         
            -
                    axes[0].set_title("Original", fontsize= 
     | 
| 132 | 
         
            -
                    axes[0].axis('off')
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 133 | 
         | 
| 134 | 
         
            -
                     
     | 
| 135 | 
         
            -
                    axes[1]. 
     | 
| 136 | 
         
            -
                    axes[1]. 
     | 
| 137 | 
         
            -
                     
     | 
| 138 | 
         
            -
                    cbar2.ax.tick_params(labelsize=10, colors='#a0a0a0')
         
     | 
| 139 | 
         | 
| 140 | 
         
            -
                    axes[ 
     | 
| 141 | 
         
            -
                    axes[ 
     | 
| 142 | 
         
            -
                    axes[ 
     | 
| 143 | 
         
            -
                    axes[ 
     | 
| 144 | 
         
            -
                    axes[ 
     | 
| 145 | 
         | 
| 146 | 
         
             
                    plt.tight_layout()
         
     | 
| 147 | 
         | 
| 148 | 
         
             
                    buf2 = BytesIO()
         
     | 
| 149 | 
         
            -
                    plt.savefig(buf2, format='png', dpi= 
     | 
| 150 | 
         
             
                    buf2.seek(0)
         
     | 
| 151 | 
         
             
                    detailed_heatmap = Image.open(buf2)
         
     | 
| 152 | 
         
             
                    plt.close(fig2)
         
     | 
| 153 | 
         | 
| 154 | 
         
            -
             
     | 
| 155 | 
         
             
                    badge = "high" if confidence > 0.8 else "medium" if confidence > 0.5 else "low"
         
     | 
| 156 | 
         
             
                    badge_text = "High Confidence" if confidence > 0.8 else "Medium Confidence" if confidence > 0.5 else "Low Confidence"
         
     | 
| 157 | 
         
             
                    badge_icon = "π―" if confidence > 0.8 else "β‘" if confidence > 0.5 else "β οΈ"
         
     | 
| 
         @@ -176,6 +182,7 @@ def predict_and_explain(image): 
     | 
|
| 176 | 
         
             
                    <div class="badge badge-{badge}">{badge_icon} {badge_text}</div>
         
     | 
| 177 | 
         
             
                </div>
         
     | 
| 178 | 
         
             
                <div class="conf-score">{confidence*100:.2f}%</div>
         
     | 
| 
         | 
|
| 179 | 
         
             
                <div class="divider"></div>
         
     | 
| 180 | 
         
             
                {top5_html}
         
     | 
| 181 | 
         
             
            </div>"""
         
     | 
| 
         @@ -187,31 +194,35 @@ def predict_and_explain(image): 
     | 
|
| 187 | 
         | 
| 188 | 
         | 
| 189 | 
         
             
            custom_css = """
         
     | 
| 190 | 
         
            -
            @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap');
         
     | 
| 191 | 
         
             
            * { box-sizing: border-box; margin: 0; padding: 0; }
         
     | 
| 192 | 
         
             
            body, .gradio-container { margin: 0 !important; padding: 0 !important; width: 100vw !important; min-height: 100vh !important; max-width: 100vw !important; background: linear-gradient(135deg, #0a0a0a 0%, #1a1a1a 50%, #0f0f0f 100%) !important; font-family: 'Inter', sans-serif !important; color: #e0e0e0 !important; overflow-x: hidden !important; }
         
     | 
| 193 | 
         
             
            .gradio-container { padding: 0 !important; }
         
     | 
| 194 | 
         
             
            .main-wrapper { padding: 1.5rem; max-width: 1920px; margin: 0 auto; position: relative; z-index: 2; }
         
     | 
| 195 | 
         
            -
            .hero-header { text-align: center; padding: 2rem 1rem 1.5rem; margin-bottom: 1.5rem; }
         
     | 
| 196 | 
         
            -
            .hero-header h1 { font-size: clamp(2rem, 5vw, 3.5rem); font-weight:  
     | 
| 197 | 
         
            -
            .hero-header .subtitle { font-size: clamp(0.95rem, 2vw, 1.2rem); color: #808080; font-weight: 400; margin: 0; }
         
     | 
| 
         | 
|
| 198 | 
         
             
            .top-section { display: grid; grid-template-columns: 400px 1fr; gap: 1.25rem; margin-bottom: 1.25rem; }
         
     | 
| 199 | 
         
             
            .upload-panel, .results-panel, .viz-section { background: rgba(20, 20, 20, 0.8); border: 1px solid rgba(255, 255, 255, 0.1); border-radius: 24px; padding: 1.5rem; backdrop-filter: blur(20px); box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4); }
         
     | 
| 200 | 
         
            -
            .section-label { font-size: 1.1rem; font-weight: 700; background:  
     | 
| 201 | 
         
             
            #input-image { border: 2px dashed rgba(59, 130, 246, 0.4) !important; border-radius: 20px !important; background: rgba(10, 10, 10, 0.6) !important; height: 320px !important; transition: all 0.3s ease; }
         
     | 
| 202 | 
         
             
            #input-image:hover { border-color: #3b82f6 !important; background: rgba(20, 20, 30, 0.8) !important; transform: scale(1.02); box-shadow: 0 0 30px rgba(59, 130, 246, 0.2); }
         
     | 
| 
         | 
|
| 
         | 
|
| 203 | 
         
             
            .btn-row { display: flex; gap: 0.75rem; margin-top: 1rem; }
         
     | 
| 204 | 
         
             
            .gr-button { border-radius: 14px !important; font-weight: 700 !important; height: 50px !important; font-size: 0.95rem !important; transition: all 0.3s ease !important; border: none !important; letter-spacing: 0.5px; text-transform: uppercase; }
         
     | 
| 205 | 
         
             
            .gr-button-primary { background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important; color: white !important; box-shadow: 0 4px 20px rgba(59, 130, 246, 0.4) !important; }
         
     | 
| 206 | 
         
             
            .gr-button-primary:hover { transform: translateY(-3px) !important; box-shadow: 0 8px 30px rgba(59, 130, 246, 0.6) !important; }
         
     | 
| 207 | 
         
             
            .gr-button-secondary { background: rgba(40, 40, 40, 0.8) !important; color: #a0a0a0 !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; }
         
     | 
| 208 | 
         
             
            .pred-header { display: flex; align-items: center; justify-content: space-between; flex-wrap: wrap; gap: 1rem; margin-bottom: 0.75rem; }
         
     | 
| 209 | 
         
            -
            .pred-label { font-size: clamp(1.5rem, 3vw, 2rem); font-weight:  
     | 
| 210 | 
         
             
            .badge { padding: 0.5rem 1.25rem; border-radius: 50px; font-size: 0.875rem; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3); }
         
     | 
| 211 | 
         
             
            .badge-high { background: linear-gradient(135deg, #10b981, #059669); color: white; }
         
     | 
| 212 | 
         
             
            .badge-medium { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; }
         
     | 
| 213 | 
         
             
            .badge-low { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; }
         
     | 
| 214 | 
         
            -
            .conf-score { font-size: clamp(2rem, 5vw, 3rem); font-weight: 900; background: linear-gradient(135deg, #3b82f6, #8b5cf6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom:  
     | 
| 
         | 
|
| 215 | 
         
             
            .divider { height: 2px; background: linear-gradient(90deg, transparent, rgba(59, 130, 246, 0.3), transparent); margin: 1.5rem 0; }
         
     | 
| 216 | 
         
             
            .top5-grid { display: flex; flex-direction: column; gap: 0.875rem; }
         
     | 
| 217 | 
         
             
            .top5-row { display: grid; grid-template-columns: 40px 1fr auto 80px; align-items: center; gap: 0.875rem; font-size: 0.95rem; padding: 0.5rem; border-radius: 12px; background: rgba(30, 30, 30, 0.5); transition: all 0.3s ease; }
         
     | 
| 
         @@ -221,25 +232,43 @@ body, .gradio-container { margin: 0 !important; padding: 0 !important; width: 10 
     | 
|
| 221 | 
         
             
            .bar-wrap { background: rgba(40, 40, 40, 0.8); height: 10px; border-radius: 5px; overflow: hidden; min-width: 100px; box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.3); }
         
     | 
| 222 | 
         
             
            .bar { background: linear-gradient(90deg, #3b82f6, #8b5cf6); height: 100%; transition: width 1s ease; border-radius: 5px; box-shadow: 0 0 10px rgba(59, 130, 246, 0.5); }
         
     | 
| 223 | 
         
             
            .pct { color: #3b82f6; font-weight: 700; font-size: 0.9rem; text-align: right; }
         
     | 
| 224 | 
         
            -
            #result-image, #detailed-heatmap { border-radius: 16px !important; overflow: hidden; width: 100 
     | 
| 225 | 
         
             
            .placeholder { text-align: center; padding: 4rem 1.5rem; color: #606060; font-size: 1.1rem; line-height: 1.6; }
         
     | 
| 226 | 
         
             
            .placeholder strong { color: #3b82f6; }
         
     | 
| 227 | 
         
             
            .error-msg { color: #ef4444; background: rgba(239, 68, 68, 0.1); padding: 1.5rem; border-radius: 16px; text-align: center; border: 1px solid rgba(239, 68, 68, 0.3); }
         
     | 
| 228 | 
         
            -
            .gr-accordion { background: rgba(20, 20, 20, 0.8) !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; border-radius: 20px !important; margin-top: 1.5rem; }
         
     | 
| 229 | 
         
            -
            .gr-accordion summary { color: #e0e0e0 !important; font-weight: 700 !important; padding: 1.25rem 1.5rem !important; font-size: 1.1rem !important; }
         
     | 
| 230 | 
         
             
            footer, .footer { display: none !important; }
         
     | 
| 231 | 
         
             
            ::-webkit-scrollbar { width: 10px; }
         
     | 
| 232 | 
         
             
            ::-webkit-scrollbar-track { background: rgba(20, 20, 20, 0.5); }
         
     | 
| 233 | 
         
             
            ::-webkit-scrollbar-thumb { background: rgba(59, 130, 246, 0.5); border-radius: 5px; }
         
     | 
| 234 | 
         
            -
            @media (max-width: 768px) {  
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 235 | 
         
             
            """
         
     | 
| 236 | 
         | 
| 237 | 
         | 
| 238 | 
         
            -
            with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title=" 
     | 
| 239 | 
         
             
                gr.HTML('<link rel="icon" href="https://res.cloudinary.com/ddn0xuwut/image/upload/v1761284764/encryption_hc0fxo.png" type="image/png">')
         
     | 
| 240 | 
         | 
| 241 | 
         
             
                with gr.Column(elem_classes="main-wrapper"):
         
     | 
| 242 | 
         
            -
                    gr.HTML(' 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 243 | 
         | 
| 244 | 
         
             
                    with gr.Row(elem_classes="top-section"):
         
     | 
| 245 | 
         
             
                        with gr.Column(scale=0, min_width=400, elem_classes="upload-panel"):
         
     | 
| 
         @@ -250,17 +279,17 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="Explainable AI") a 
     | 
|
| 250 | 
         
             
                                clear_btn = gr.ClearButton([input_image], value="ποΈ Clear", size="lg", scale=1)
         
     | 
| 251 | 
         | 
| 252 | 
         
             
                        with gr.Column(scale=1, elem_classes="results-panel"):
         
     | 
| 253 | 
         
            -
                            output_text = gr.HTML('<div class="placeholder"><strong>π Welcome!</strong><br><br>Upload an image  
     | 
| 254 | 
         | 
| 255 | 
         
             
                    with gr.Column(elem_classes="viz-section"):
         
     | 
| 256 | 
         
            -
                        gr.HTML("<div class='section-label'>π― Visual Explainability  
     | 
| 257 | 
         
            -
                        output_image = gr.Image(label=None, type="pil", show_label=False, elem_id="result-image",  
     | 
| 258 | 
         | 
| 259 | 
         
             
                    with gr.Column(elem_classes="viz-section"):
         
     | 
| 260 | 
         
            -
                        gr.HTML("<div class='section-label'>π¬  
     | 
| 261 | 
         
            -
                        detailed_heatmap = gr.Image(label=None, type="pil", show_label=False, elem_id="detailed-heatmap",  
     | 
| 262 | 
         | 
| 263 | 
         
             
                predict_btn.click(fn=predict_and_explain, inputs=[input_image], outputs=[output_text, output_image, detailed_heatmap])
         
     | 
| 264 | 
         | 
| 265 | 
         
             
            if __name__ == "__main__":
         
     | 
| 266 | 
         
            -
                demo.launch(share= 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 1 | 
         
             
            import torch
         
     | 
| 2 | 
         
             
            import torch.nn as nn
         
     | 
| 3 | 
         
             
            from torchvision import models, transforms
         
     | 
| 
         | 
|
| 15 | 
         | 
| 16 | 
         
             
            @torch.no_grad()
         
     | 
| 17 | 
         
             
            def load_model_and_labels():
         
     | 
| 18 | 
         
            +
                """Load ResNet152 model for maximum accuracy"""
         
     | 
| 19 | 
         
            +
                print("π Loading ResNet152 model...")
         
     | 
| 
         | 
|
| 
         | 
|
| 20 | 
         | 
| 21 | 
         
            +
                # ResNet152 (Best accuracy in ResNet family)
         
     | 
| 22 | 
         
            +
                model = models.resnet152(weights='IMAGENET1K_V2')
         
     | 
| 23 | 
         
            +
                model.eval().to(DEVICE)
         
     | 
| 24 | 
         
            +
                
         
     | 
| 25 | 
         
            +
                # Load ImageNet labels
         
     | 
| 26 | 
         
             
                url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
         
     | 
| 27 | 
         
             
                response = urllib.request.urlopen(url)
         
     | 
| 28 | 
         
             
                labels = [line.decode('utf-8').strip() for line in response.readlines()]
         
     | 
| 29 | 
         | 
| 30 | 
         
            +
                print("β
 Model loaded successfully!")
         
     | 
| 31 | 
         
             
                return model, labels
         
     | 
| 32 | 
         | 
| 33 | 
         
             
            model, IMAGENET_LABELS = load_model_and_labels()
         
     | 
| 34 | 
         | 
| 35 | 
         
            +
            # Setup Grad-CAM
         
     | 
| 36 | 
         
             
            target_layer = model.layer4[-1]
         
     | 
| 37 | 
         
             
            gradcam = LayerGradCam(model, target_layer)
         
     | 
| 38 | 
         | 
| 39 | 
         
            +
            # Transform for ResNet152
         
     | 
| 40 | 
         
             
            transform = transforms.Compose([
         
     | 
| 41 | 
         
             
                transforms.Resize((224, 224)),
         
     | 
| 42 | 
         
             
                transforms.ToTensor(),
         
     | 
| 
         | 
|
| 49 | 
         
             
                    return "Please upload an image", None, None
         
     | 
| 50 | 
         | 
| 51 | 
         
             
                try:
         
     | 
| 52 | 
         
            +
                    # Prepare input
         
     | 
| 53 | 
         
             
                    img_tensor = transform(image).unsqueeze(0).to(DEVICE)
         
     | 
| 54 | 
         | 
| 55 | 
         
             
                    with torch.no_grad():
         
     | 
| 56 | 
         
            +
                        # ResNet152 prediction
         
     | 
| 57 | 
         
             
                        output = model(img_tensor)
         
     | 
| 58 | 
         
            +
                        probabilities = torch.softmax(output, dim=1)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 59 | 
         | 
| 60 | 
         
            +
                    top10_prob, top10_idx = torch.topk(probabilities, 10)
         
     | 
| 61 | 
         
             
                    pred_class = top10_idx[0][0].item()
         
     | 
| 62 | 
         
             
                    confidence = top10_prob[0][0].item()
         
     | 
| 63 | 
         | 
| 64 | 
         
            +
                    # Generate Grad-CAM
         
     | 
| 65 | 
         
             
                    attributions = gradcam.attribute(img_tensor, target=pred_class)
         
     | 
| 66 | 
         
             
                    attr_resized = interpolate(attributions, size=(224, 224), mode='bilinear', align_corners=False)
         
     | 
| 67 | 
         
             
                    attr_np = attr_resized.squeeze().cpu().detach().numpy()
         
     | 
| 68 | 
         
             
                    attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min() + 1e-8)
         
     | 
| 69 | 
         | 
| 70 | 
         
            +
                    # Main visualization
         
     | 
| 71 | 
         
            +
                    fig = plt.figure(figsize=(24, 14))  
         
     | 
| 72 | 
         
             
                    fig.patch.set_facecolor('#0a0a0a')
         
     | 
| 73 | 
         | 
| 74 | 
         
            +
                    gs = fig.add_gridspec(2, 3, height_ratios=[2, 1], hspace=0.3, wspace=0.15)  
         
     | 
| 75 | 
         | 
| 76 | 
         
             
                    ax1 = fig.add_subplot(gs[0, 0])
         
     | 
| 77 | 
         
             
                    ax2 = fig.add_subplot(gs[0, 1])
         
     | 
| 
         | 
|
| 79 | 
         
             
                    ax4 = fig.add_subplot(gs[1, :])
         
     | 
| 80 | 
         | 
| 81 | 
         
             
                    ax1.imshow(image)
         
     | 
| 82 | 
         
            +
                    ax1.set_title("Original Image", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
         
     | 
| 83 | 
         
             
                    ax1.axis('off')
         
     | 
| 84 | 
         | 
| 85 | 
         
             
                    im = ax2.imshow(attr_np, cmap='jet', interpolation='bilinear')
         
     | 
| 86 | 
         
            +
                    ax2.set_title("Grad-CAM Heatmap", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
         
     | 
| 87 | 
         
             
                    ax2.axis('off')
         
     | 
| 88 | 
         
             
                    cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
         
     | 
| 89 | 
         
            +
                    cbar.ax.tick_params(labelsize=12, colors='#a0a0a0')
         
     | 
| 90 | 
         
            +
                    cbar.set_label('Importance', rotation=270, labelpad=25, color='#e0e0e0', fontsize=13, fontweight='600')
         
     | 
| 91 | 
         | 
| 92 | 
         
             
                    ax3.imshow(image)
         
     | 
| 93 | 
         
             
                    ax3.imshow(attr_np, cmap='jet', alpha=0.5, interpolation='bilinear')
         
     | 
| 94 | 
         
            +
                    ax3.set_title(f"AI Focus: {IMAGENET_LABELS[pred_class]}", fontsize=18, fontweight='700', color='#e0e0e0', pad=20)
         
     | 
| 95 | 
         
             
                    ax3.axis('off')
         
     | 
| 96 | 
         | 
| 97 | 
         
             
                    top10_labels = [IMAGENET_LABELS[idx.item()] for idx in top10_idx[0]]
         
     | 
| 98 | 
         
             
                    top10_probs = [prob.item() * 100 for prob in top10_prob[0]]
         
     | 
| 99 | 
         | 
| 100 | 
         
             
                    colors = ['#10b981' if i == 9 else '#3b82f6' if i >= 7 else '#8b5cf6' for i in range(10)]
         
     | 
| 101 | 
         
            +
                    bars = ax4.barh(range(10), top10_probs[::-1], color=colors[::-1], edgecolor='#1a1a1a', linewidth=2)
         
     | 
| 102 | 
         | 
| 103 | 
         
             
                    ax4.set_yticks(range(10))
         
     | 
| 104 | 
         
            +
                    ax4.set_yticklabels(top10_labels[::-1], fontsize=14, color='#e0e0e0', fontweight='600')
         
     | 
| 105 | 
         
            +
                    ax4.set_xlabel('Confidence (%)', fontsize=15, color='#e0e0e0', fontweight='700')
         
     | 
| 106 | 
         
            +
                    ax4.set_title('Top 10 Predictions', fontsize=19, fontweight='800', color='#e0e0e0', pad=20)
         
     | 
| 107 | 
         
             
                    ax4.set_xlim([0, 100])
         
     | 
| 108 | 
         
             
                    ax4.grid(axis='x', alpha=0.2, color='#404040', linestyle='--')
         
     | 
| 109 | 
         
             
                    ax4.set_facecolor('#0a0a0a')
         
     | 
| 
         | 
|
| 111 | 
         
             
                    ax4.spines['right'].set_visible(False)
         
     | 
| 112 | 
         
             
                    ax4.spines['left'].set_color('#404040')
         
     | 
| 113 | 
         
             
                    ax4.spines['bottom'].set_color('#404040')
         
     | 
| 114 | 
         
            +
                    ax4.tick_params(colors='#a0a0a0', labelsize=13)
         
     | 
| 115 | 
         | 
| 116 | 
         
             
                    for bar, prob in zip(bars, top10_probs[::-1]):
         
     | 
| 117 | 
         
             
                        ax4.text(prob + 1.5, bar.get_y() + bar.get_height()/2, 
         
     | 
| 118 | 
         
            +
                                f'{prob:.1f}%', va='center', fontsize=13, color='#e0e0e0', fontweight='700')
         
     | 
| 119 | 
         | 
| 120 | 
         
             
                    plt.tight_layout()
         
     | 
| 121 | 
         | 
| 122 | 
         
             
                    buf = BytesIO()
         
     | 
| 123 | 
         
            +
                    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='#0a0a0a')
         
     | 
| 124 | 
         
             
                    buf.seek(0)
         
     | 
| 125 | 
         
             
                    result_image = Image.open(buf)
         
     | 
| 126 | 
         
             
                    plt.close(fig)
         
     | 
| 127 | 
         | 
| 128 | 
         
            +
                    # Detailed heatmap analysis
         
     | 
| 129 | 
         
            +
                    fig2, axes = plt.subplots(2, 2, figsize=(20, 18)) 
         
     | 
| 130 | 
         
             
                    fig2.patch.set_facecolor('#0a0a0a')
         
     | 
| 131 | 
         | 
| 132 | 
         
            +
                    axes[0, 0].imshow(image)
         
     | 
| 133 | 
         
            +
                    axes[0, 0].set_title("Original Image", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
         
     | 
| 134 | 
         
            +
                    axes[0, 0].axis('off')
         
     | 
| 135 | 
         
            +
                    
         
     | 
| 136 | 
         
            +
                    axes[0, 1].imshow(image)
         
     | 
| 137 | 
         
            +
                    axes[0, 1].imshow(attr_np, cmap='jet', alpha=0.6, interpolation='bilinear')
         
     | 
| 138 | 
         
            +
                    axes[0, 1].set_title("Jet Colormap Overlay", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
         
     | 
| 139 | 
         
            +
                    axes[0, 1].axis('off')
         
     | 
| 140 | 
         | 
| 141 | 
         
            +
                    axes[1, 0].imshow(image)
         
     | 
| 142 | 
         
            +
                    axes[1, 0].imshow(attr_np, cmap='hot', alpha=0.6, interpolation='bilinear')
         
     | 
| 143 | 
         
            +
                    axes[1, 0].set_title("Hot Colormap Overlay", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
         
     | 
| 144 | 
         
            +
                    axes[1, 0].axis('off')
         
     | 
| 
         | 
|
| 145 | 
         | 
| 146 | 
         
            +
                    axes[1, 1].imshow(image)
         
     | 
| 147 | 
         
            +
                    axes[1, 1].imshow(attr_np, cmap='viridis', alpha=0.6, interpolation='gaussian')
         
     | 
| 148 | 
         
            +
                    axes[1, 1].contour(attr_np, levels=6, colors='white', linewidths=2, alpha=0.9)
         
     | 
| 149 | 
         
            +
                    axes[1, 1].set_title("Viridis + Contours", fontsize=17, fontweight='700', color='#e0e0e0', pad=15)
         
     | 
| 150 | 
         
            +
                    axes[1, 1].axis('off')
         
     | 
| 151 | 
         | 
| 152 | 
         
             
                    plt.tight_layout()
         
     | 
| 153 | 
         | 
| 154 | 
         
             
                    buf2 = BytesIO()
         
     | 
| 155 | 
         
            +
                    plt.savefig(buf2, format='png', dpi=140, bbox_inches='tight', facecolor='#0a0a0a')
         
     | 
| 156 | 
         
             
                    buf2.seek(0)
         
     | 
| 157 | 
         
             
                    detailed_heatmap = Image.open(buf2)
         
     | 
| 158 | 
         
             
                    plt.close(fig2)
         
     | 
| 159 | 
         | 
| 160 | 
         
            +
                    # Prediction card
         
     | 
| 161 | 
         
             
                    badge = "high" if confidence > 0.8 else "medium" if confidence > 0.5 else "low"
         
     | 
| 162 | 
         
             
                    badge_text = "High Confidence" if confidence > 0.8 else "Medium Confidence" if confidence > 0.5 else "Low Confidence"
         
     | 
| 163 | 
         
             
                    badge_icon = "π―" if confidence > 0.8 else "β‘" if confidence > 0.5 else "β οΈ"
         
     | 
| 
         | 
|
| 182 | 
         
             
                    <div class="badge badge-{badge}">{badge_icon} {badge_text}</div>
         
     | 
| 183 | 
         
             
                </div>
         
     | 
| 184 | 
         
             
                <div class="conf-score">{confidence*100:.2f}%</div>
         
     | 
| 185 | 
         
            +
                <div class="model-tag">π¬ ResNet152 Architecture (82.3% ImageNet Accuracy)</div>
         
     | 
| 186 | 
         
             
                <div class="divider"></div>
         
     | 
| 187 | 
         
             
                {top5_html}
         
     | 
| 188 | 
         
             
            </div>"""
         
     | 
| 
         | 
|
| 194 | 
         | 
| 195 | 
         | 
| 196 | 
         
             
            custom_css = """
         
     | 
| 197 | 
         
            +
            @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800;900&display=swap');
         
     | 
| 198 | 
         
             
            * { box-sizing: border-box; margin: 0; padding: 0; }
         
     | 
| 199 | 
         
             
            body, .gradio-container { margin: 0 !important; padding: 0 !important; width: 100vw !important; min-height: 100vh !important; max-width: 100vw !important; background: linear-gradient(135deg, #0a0a0a 0%, #1a1a1a 50%, #0f0f0f 100%) !important; font-family: 'Inter', sans-serif !important; color: #e0e0e0 !important; overflow-x: hidden !important; }
         
     | 
| 200 | 
         
             
            .gradio-container { padding: 0 !important; }
         
     | 
| 201 | 
         
             
            .main-wrapper { padding: 1.5rem; max-width: 1920px; margin: 0 auto; position: relative; z-index: 2; }
         
     | 
| 202 | 
         
            +
            .hero-header { text-align: center; padding: 2rem 1rem 1.5rem; margin-bottom: 1.5rem; position: relative; }
         
     | 
| 203 | 
         
            +
            .hero-header h1 { font-size: clamp(2rem, 5vw, 3.5rem); font-weight: 900; background-color: #d8b4fe; -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin: 0 0 0.5rem; letter-spacing: -1px; }
         
     | 
| 204 | 
         
            +
            .hero-header .subtitle { font-size: clamp(0.95rem, 2vw, 1.2rem); color: #808080; font-weight: 400; margin: 0 0 0.5rem; }
         
     | 
| 205 | 
         
            +
            .hero-header .model-tag { display: inline-block; background: #93c5fd; border: 1px solid rgba(59, 130, 246, 0.3); color: #3b82f6; padding: 0.5rem 1.5rem; border-radius: 50px; font-size: 0.85rem; font-weight: 700; letter-spacing: 0.5px; margin-top: 0.5rem; }
         
     | 
| 206 | 
         
             
            .top-section { display: grid; grid-template-columns: 400px 1fr; gap: 1.25rem; margin-bottom: 1.25rem; }
         
     | 
| 207 | 
         
             
            .upload-panel, .results-panel, .viz-section { background: rgba(20, 20, 20, 0.8); border: 1px solid rgba(255, 255, 255, 0.1); border-radius: 24px; padding: 1.5rem; backdrop-filter: blur(20px); box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4); }
         
     | 
| 208 | 
         
            +
            .section-label { font-size: 1.1rem; font-weight: 700; background: #93c5fd; -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin: 0 0 1rem; text-align: center; letter-spacing: 0.5px; }
         
     | 
| 209 | 
         
             
            #input-image { border: 2px dashed rgba(59, 130, 246, 0.4) !important; border-radius: 20px !important; background: rgba(10, 10, 10, 0.6) !important; height: 320px !important; transition: all 0.3s ease; }
         
     | 
| 210 | 
         
             
            #input-image:hover { border-color: #3b82f6 !important; background: rgba(20, 20, 30, 0.8) !important; transform: scale(1.02); box-shadow: 0 0 30px rgba(59, 130, 246, 0.2); }
         
     | 
| 211 | 
         
            +
            #input-image .upload-text { border-radius: 0 !important; }
         
     | 
| 212 | 
         
            +
            #input-image [data-testid="image"] { border-radius: 0 !important; }
         
     | 
| 213 | 
         
             
            .btn-row { display: flex; gap: 0.75rem; margin-top: 1rem; }
         
     | 
| 214 | 
         
             
            .gr-button { border-radius: 14px !important; font-weight: 700 !important; height: 50px !important; font-size: 0.95rem !important; transition: all 0.3s ease !important; border: none !important; letter-spacing: 0.5px; text-transform: uppercase; }
         
     | 
| 215 | 
         
             
            .gr-button-primary { background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important; color: white !important; box-shadow: 0 4px 20px rgba(59, 130, 246, 0.4) !important; }
         
     | 
| 216 | 
         
             
            .gr-button-primary:hover { transform: translateY(-3px) !important; box-shadow: 0 8px 30px rgba(59, 130, 246, 0.6) !important; }
         
     | 
| 217 | 
         
             
            .gr-button-secondary { background: rgba(40, 40, 40, 0.8) !important; color: #a0a0a0 !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; }
         
     | 
| 218 | 
         
             
            .pred-header { display: flex; align-items: center; justify-content: space-between; flex-wrap: wrap; gap: 1rem; margin-bottom: 0.75rem; }
         
     | 
| 219 | 
         
            +
            .pred-label { font-size: clamp(1.5rem, 3vw, 2rem); font-weight: 900; color: #ffffff; margin: 0; letter-spacing: -0.5px; }
         
     | 
| 220 | 
         
             
            .badge { padding: 0.5rem 1.25rem; border-radius: 50px; font-size: 0.875rem; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.3); }
         
     | 
| 221 | 
         
             
            .badge-high { background: linear-gradient(135deg, #10b981, #059669); color: white; }
         
     | 
| 222 | 
         
             
            .badge-medium { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; }
         
     | 
| 223 | 
         
             
            .badge-low { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; }
         
     | 
| 224 | 
         
            +
            .conf-score { font-size: clamp(2rem, 5vw, 3rem); font-weight: 900; background: linear-gradient(135deg, #3b82f6, #8b5cf6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; margin-bottom: 1rem; letter-spacing: -1px; }
         
     | 
| 225 | 
         
            +
            .model-tag { background: rgba(16, 185, 129, 0.15); border: 1px solid rgba(16, 185, 129, 0.3); color: #10b981; padding: 0.5rem 1rem; border-radius: 12px; font-size: 0.8rem; font-weight: 700; text-align: center; margin-bottom: 1rem; }
         
     | 
| 226 | 
         
             
            .divider { height: 2px; background: linear-gradient(90deg, transparent, rgba(59, 130, 246, 0.3), transparent); margin: 1.5rem 0; }
         
     | 
| 227 | 
         
             
            .top5-grid { display: flex; flex-direction: column; gap: 0.875rem; }
         
     | 
| 228 | 
         
             
            .top5-row { display: grid; grid-template-columns: 40px 1fr auto 80px; align-items: center; gap: 0.875rem; font-size: 0.95rem; padding: 0.5rem; border-radius: 12px; background: rgba(30, 30, 30, 0.5); transition: all 0.3s ease; }
         
     | 
| 
         | 
|
| 232 | 
         
             
            .bar-wrap { background: rgba(40, 40, 40, 0.8); height: 10px; border-radius: 5px; overflow: hidden; min-width: 100px; box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.3); }
         
     | 
| 233 | 
         
             
            .bar { background: linear-gradient(90deg, #3b82f6, #8b5cf6); height: 100%; transition: width 1s ease; border-radius: 5px; box-shadow: 0 0 10px rgba(59, 130, 246, 0.5); }
         
     | 
| 234 | 
         
             
            .pct { color: #3b82f6; font-weight: 700; font-size: 0.9rem; text-align: right; }
         
     | 
| 235 | 
         
            +
            #result-image, #detailed-heatmap { border-radius: 16px !important; overflow: hidden; width: 100% !important; height: auto !important; min-height: 500px !important; box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5); object-fit: contain !important; }
         
     | 
| 236 | 
         
             
            .placeholder { text-align: center; padding: 4rem 1.5rem; color: #606060; font-size: 1.1rem; line-height: 1.6; }
         
     | 
| 237 | 
         
             
            .placeholder strong { color: #3b82f6; }
         
     | 
| 238 | 
         
             
            .error-msg { color: #ef4444; background: rgba(239, 68, 68, 0.1); padding: 1.5rem; border-radius: 16px; text-align: center; border: 1px solid rgba(239, 68, 68, 0.3); }
         
     | 
| 
         | 
|
| 
         | 
|
| 239 | 
         
             
            footer, .footer { display: none !important; }
         
     | 
| 240 | 
         
             
            ::-webkit-scrollbar { width: 10px; }
         
     | 
| 241 | 
         
             
            ::-webkit-scrollbar-track { background: rgba(20, 20, 20, 0.5); }
         
     | 
| 242 | 
         
             
            ::-webkit-scrollbar-thumb { background: rgba(59, 130, 246, 0.5); border-radius: 5px; }
         
     | 
| 243 | 
         
            +
            @media (max-width: 768px) { 
         
     | 
| 244 | 
         
            +
                .top-section { grid-template-columns: 1fr; } 
         
     | 
| 245 | 
         
            +
                #input-image { height: 240px !important; } 
         
     | 
| 246 | 
         
            +
                .top5-row { grid-template-columns: 35px 1fr 70px; } 
         
     | 
| 247 | 
         
            +
                .bar-wrap { grid-column: 1 / -1; margin-top: 0.375rem; }
         
     | 
| 248 | 
         
            +
                #result-image { min-height: 600px !important; max-height: none !important; }
         
     | 
| 249 | 
         
            +
                #detailed-heatmap { min-height: 450px !important; max-height: none !important; }
         
     | 
| 250 | 
         
            +
                .viz-section { padding: 1rem; }
         
     | 
| 251 | 
         
            +
                .section-label { font-size: 1rem; }
         
     | 
| 252 | 
         
            +
            }
         
     | 
| 253 | 
         
            +
            @media (max-width: 480px) {
         
     | 
| 254 | 
         
            +
                .main-wrapper { padding: 1rem; }
         
     | 
| 255 | 
         
            +
                #result-image { min-height: 550px !important; }
         
     | 
| 256 | 
         
            +
                #detailed-heatmap { min-height: 400px !important; }
         
     | 
| 257 | 
         
            +
            }
         
     | 
| 258 | 
         
             
            """
         
     | 
| 259 | 
         | 
| 260 | 
         | 
| 261 | 
         
            +
            with gr.Blocks(css=custom_css, theme=gr.themes.Base(), title="XAI Image Classifier") as demo:
         
     | 
| 262 | 
         
             
                gr.HTML('<link rel="icon" href="https://res.cloudinary.com/ddn0xuwut/image/upload/v1761284764/encryption_hc0fxo.png" type="image/png">')
         
     | 
| 263 | 
         | 
| 264 | 
         
             
                with gr.Column(elem_classes="main-wrapper"):
         
     | 
| 265 | 
         
            +
                    gr.HTML('''
         
     | 
| 266 | 
         
            +
                        <div class="hero-header">
         
     | 
| 267 | 
         
            +
                            <h1>XAI Image Classifier</h1>
         
     | 
| 268 | 
         
            +
                            <p class="subtitle">ResNet152 with Grad-CAM Explainability</p>
         
     | 
| 269 | 
         
            +
                            <div class="model-tag">β‘ Maximum Accuracy Production Version</div>
         
     | 
| 270 | 
         
            +
                        </div>
         
     | 
| 271 | 
         
            +
                    ''')
         
     | 
| 272 | 
         | 
| 273 | 
         
             
                    with gr.Row(elem_classes="top-section"):
         
     | 
| 274 | 
         
             
                        with gr.Column(scale=0, min_width=400, elem_classes="upload-panel"):
         
     | 
| 
         | 
|
| 279 | 
         
             
                                clear_btn = gr.ClearButton([input_image], value="ποΈ Clear", size="lg", scale=1)
         
     | 
| 280 | 
         | 
| 281 | 
         
             
                        with gr.Column(scale=1, elem_classes="results-panel"):
         
     | 
| 282 | 
         
            +
                            output_text = gr.HTML('<div class="placeholder"><strong>π Welcome to XAI Classifier!</strong><br><br>This classifier uses ResNet152:<br>β’ 82.3% ImageNet Top-1 Accuracy<br>β’ Grad-CAM Visual Explainability<br>β’ 1000 Object Categories<br><br>Upload an image to see the magic! β¨</div>')
         
     | 
| 283 | 
         | 
| 284 | 
         
             
                    with gr.Column(elem_classes="viz-section"):
         
     | 
| 285 | 
         
            +
                        gr.HTML("<div class='section-label'>π― Visual Explainability Analysis</div>")
         
     | 
| 286 | 
         
            +
                        output_image = gr.Image(label=None, type="pil", show_label=False, elem_id="result-image", container=False)
         
     | 
| 287 | 
         | 
| 288 | 
         
             
                    with gr.Column(elem_classes="viz-section"):
         
     | 
| 289 | 
         
            +
                        gr.HTML("<div class='section-label'>π¬ Detailed Heatmap Comparison</div>")
         
     | 
| 290 | 
         
            +
                        detailed_heatmap = gr.Image(label=None, type="pil", show_label=False, elem_id="detailed-heatmap", container=False)
         
     | 
| 291 | 
         | 
| 292 | 
         
             
                predict_btn.click(fn=predict_and_explain, inputs=[input_image], outputs=[output_text, output_image, detailed_heatmap])
         
     | 
| 293 | 
         | 
| 294 | 
         
             
            if __name__ == "__main__":
         
     | 
| 295 | 
         
            +
                demo.launch(share=False, show_error=True)
         
     |