Spaces:
Sleeping
Sleeping
| from flask import Flask, request, jsonify, send_file | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| import os | |
| import numpy as np | |
| from datetime import datetime | |
| import sqlite3 | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| import cv2 | |
| # β New Grad-CAM++ imports | |
| from pytorch_grad_cam import GradCAMPlusPlus | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
app = Flask(__name__)

# Every artefact this service produces (uploaded images, Grad-CAM++ renderings
# and the SQLite results database) is stored under OUTPUT_DIR.
# NOTE(review): /tmp is usually not persistent across restarts; confirm that
# losing stored results on restart is acceptable.
OUTPUT_DIR = '/tmp/results'
# exist_ok=True replaces the check-then-create `if not exists: makedirs` idiom,
# which raises FileExistsError when two processes start at the same moment.
os.makedirs(OUTPUT_DIR, exist_ok=True)
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
def init_db(db_path=None):
    """Create the ``results`` table if it does not already exist.

    Safe to call repeatedly: the DDL uses ``CREATE TABLE IF NOT EXISTS``.

    Args:
        db_path: SQLite database file to initialise. When omitted, the
            module-level ``DB_PATH`` is used (the original behaviour).
    """
    path = DB_PATH if db_path is None else db_path
    conn = sqlite3.connect(path)
    try:
        # `with conn` commits on success and rolls back on error, but it does
        # not close the connection -- the surrounding try/finally does that,
        # so the handle is released even if the statement fails.
        with conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    image_filename TEXT,
                    prediction TEXT,
                    confidence REAL,
                    gradcam_filename TEXT,
                    gradcam_gray_filename TEXT,
                    timestamp TEXT
                )
                """
            )
    finally:
        conn.close()
# Make sure the results table exists before any request is handled.
init_db()

# Project-local factory for the DenseNet-169 backbone with GLAM attention.
from densenet_withglam import get_model_with_attention

# Build the 3-class network, restore the trained weights onto the CPU, then
# switch layers such as dropout / batch-norm to inference behaviour.
# NOTE(review): torch.load unpickles the checkpoint. That is acceptable for this
# bundled file, but it must never be pointed at user-supplied data; on PyTorch
# versions that support it, weights_only=True would make this explicit.
model = get_model_with_attention('densenet169', num_classes=3)
_checkpoint_state = torch.load('densenet169_seed40_best.pt', map_location='cpu')
model.load_state_dict(_checkpoint_state)
del _checkpoint_state  # leave the module namespace exactly as before
model.eval()
# Label for each output index of the 3-class head (index 0, 1, 2).
# NOTE(review): this must match the label order used during training; the
# alphabetical order is what torchvision's ImageFolder would produce -- confirm
# against the training code.
CLASS_NAMES = [
    "Advanced",
    "Early",
    "Normal",
]

# Preprocessing for every uploaded image: resize the shorter side to 256 px,
# take the central 224x224 crop, convert to a float tensor in [0, 1], then
# normalise each channel with the ImageNet mean / standard deviation.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def home():
    """Liveness check: answer with a fixed banner while the API is up.

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source; confirm how it is registered.
    """
    banner = "Glaucoma Detection Flask API (3-Class Model) is running!"
    return banner
def test_file():
    """Report whether the model checkpoint used at start-up is present and readable.

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source; confirm how it is registered.

    Returns:
        A status string naming the checkpoint path when it is a readable
        regular file, otherwise a "not found" message.
    """
    # Must name the same file that is passed to torch.load at start-up; the
    # previous check looked at "densenet169_seed40_best2.pt", a different file,
    # so it could report success while the served model was missing (or vice versa).
    filepath = "densenet169_seed40_best.pt"
    # A bare existence check is also true for directories and unreadable files.
    if os.path.isfile(filepath) and os.access(filepath, os.R_OK):
        return f"β Model file found at: {filepath}"
    return "β Model file NOT found."
def predict():
    """Classify an uploaded image, explain it with Grad-CAM++ and record the result.

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source; confirm it is registered (it reads ``request.files``).

    The multipart form field ``file`` must carry the image. The image is
    stored in ``OUTPUT_DIR`` as PNG and classified with ``model``. Two
    Grad-CAM++ renderings for the predicted class are written next to it: an
    RGB overlay and the raw map as an 8-bit grayscale image. One row
    describing the run is inserted into the ``results`` table.

    Returns:
        JSON with ``prediction``, ``confidence``, the three per-class
        probabilities and the stored file names (HTTP 200); a JSON error with
        HTTP 400 when no file was provided; a JSON error carrying the
        exception message with HTTP 500 when any processing step fails.
    """
    import uuid  # only needed here: random suffix for unique output file names

    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    uploaded_file = request.files['file']
    if not uploaded_file.filename:
        return jsonify({'error': 'No file selected'}), 400

    try:
        now = datetime.now()
        # Whole-second timestamps collide when two uploads arrive in the same
        # second (the later request overwrote the earlier one's files), so a
        # short random suffix keeps every name unique.
        stem = f"{int(now.timestamp())}_{uuid.uuid4().hex[:8]}"
        uploaded_filename = f"uploaded_{stem}.png"
        gradcam_filename = f"gradcam_{stem}.png"
        gray_filename = f"gradcam_gray_{stem}.png"

        # Decode before touching the disk and store a genuine PNG: the name
        # ends in .png and the file is served as image/png, which the raw
        # upload bytes (JPEG, TIFF, ...) were not. Undecodable uploads now
        # leave no stray file behind.
        img = Image.open(uploaded_file.stream).convert('RGB')
        img.save(os.path.join(OUTPUT_DIR, uploaded_filename), format='PNG')

        input_tensor = transform(img).unsqueeze(0)
        # The prediction itself needs no autograd graph; Grad-CAM++ runs its
        # own forward/backward pass below.
        with torch.no_grad():
            output = model(input_tensor)
        probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
        class_index = int(np.argmax(probabilities))
        result = CLASS_NAMES[class_index]
        confidence = float(probabilities[class_index])

        cam_output = _compute_gradcam_plus_plus(input_tensor, class_index)
        _write_cam_images(img, cam_output, gradcam_filename, gray_filename)
        _insert_result_row(
            uploaded_filename, result, confidence,
            gradcam_filename, gray_filename, now.isoformat(),
        )

        # Look each probability up by class name so these keys always agree
        # with `prediction`. The previous fixed indices (0 = normal,
        # 2 = advanced) contradicted CLASS_NAMES, where index 0 is "Advanced".
        def probability_of(name):
            return float(probabilities[CLASS_NAMES.index(name)])

        return jsonify({
            'prediction': result,
            'confidence': confidence,
            'normal_probability': probability_of("Normal"),
            'early_glaucoma_probability': probability_of("Early"),
            'advanced_glaucoma_probability': probability_of("Advanced"),
            'gradcam_image': gradcam_filename,
            'gradcam_gray_image': gray_filename,
            'image_filename': uploaded_filename,
        })
    except Exception as e:
        # Request boundary: keep the JSON 500 contract, but record the full
        # traceback, which str(e) alone loses.
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500


def _compute_gradcam_plus_plus(input_tensor, class_index):
    """Return the Grad-CAM++ map of ``model`` for ``class_index`` on the first image.

    The target layer is looked up by attribute path on the model built by
    ``get_model_with_attention``; update the path if that architecture changes.
    """
    target_layer = model.features[2].global_spatial_conv
    cam_map = None
    # The context manager removes the hooks Grad-CAM++ registers on `model` as
    # soon as the map is computed, instead of whenever the object is collected.
    with GradCAMPlusPlus(model=model, target_layers=[target_layer]) as cam:
        cam_map = cam(
            input_tensor=input_tensor,
            targets=[ClassifierOutputTarget(class_index)],
        )[0]
    # NOTE(review): in some pytorch-grad-cam versions __exit__ suppresses
    # certain exceptions (e.g. IndexError); fail loudly rather than continue
    # without a map.
    if cam_map is None:
        raise RuntimeError("Grad-CAM++ did not produce a map")
    return cam_map


def _write_cam_images(img, cam_output, gradcam_filename, gray_filename):
    """Write the Grad-CAM++ RGB overlay and grayscale map for ``img`` into OUTPUT_DIR.

    ``cam_output`` is assumed to be a 224x224 float map in [0, 1], matching the
    model input size -- NOTE(review): confirm for the installed pytorch-grad-cam.
    """
    # Reproduce the geometric steps of `transform` (Resize(256), then
    # CenterCrop(224)) so the overlay shows exactly the pixels the model saw;
    # squashing the whole image to 224x224 misaligned the map with the picture.
    # Keep these sizes in sync with `transform`.
    model_view = transforms.CenterCrop(224)(transforms.Resize(256)(img))
    rgb = np.asarray(model_view, dtype=np.float32) / 255.0
    overlay = show_cam_on_image(rgb, cam_output, use_rgb=True)
    heatmap_gray = np.uint8(255 * cam_output)

    outputs = (
        # cv2 expects BGR channel order when writing colour images.
        (gradcam_filename, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)),
        (gray_filename, heatmap_gray),
    )
    for filename, pixels in outputs:
        path = os.path.join(OUTPUT_DIR, filename)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(path, pixels):
            raise OSError(f"Could not write {path}")


def _insert_result_row(image_filename, prediction, confidence,
                       gradcam_filename, gray_filename, timestamp):
    """Insert one row into the ``results`` table of the database at DB_PATH."""
    conn = sqlite3.connect(DB_PATH)
    try:
        with conn:  # commit on success, roll back on error
            conn.execute(
                """
                INSERT INTO results (image_filename, prediction, confidence,
                                     gradcam_filename, gradcam_gray_filename, timestamp)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (image_filename, prediction, confidence,
                 gradcam_filename, gray_filename, timestamp),
            )
    finally:
        conn.close()  # `with conn` does not close the connection
def results():
    """Return every stored prediction as a JSON list, newest first.

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source; confirm how it is registered.

    Each element has the keys ``id``, ``image_filename``, ``prediction``,
    ``confidence``, ``gradcam_filename``, ``gradcam_gray_filename`` and
    ``timestamp`` -- the column names of the ``results`` table.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # Name-addressed rows plus an explicit column list (instead of
        # SELECT * with positional indexes) keep the key mapping correct if the
        # table schema changes. `id DESC` only breaks ties between rows that
        # share a timestamp, making the order deterministic.
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            """
            SELECT id, image_filename, prediction, confidence,
                   gradcam_filename, gradcam_gray_filename, timestamp
            FROM results
            ORDER BY timestamp DESC, id DESC
            """
        ).fetchall()
    finally:
        # Close even when the query fails; the original leaked the handle then.
        conn.close()
    return jsonify([dict(row) for row in rows])
def _serve_output_png(filename):
    """Send ``OUTPUT_DIR/filename`` as image/png, or a JSON 404 error.

    Only bare ``*.png`` file names are served. A name with a directory
    component (``../x``, ``a/b.png``, an absolute path) or any other extension
    -- notably the ``results.db`` database that lives in the same directory --
    receives the same 404 as a missing file. Directories are rejected too,
    since ``send_file`` cannot serve them.
    """
    is_bare_png = (
        os.path.basename(filename) == filename
        and filename.lower().endswith('.png')
    )
    filepath = os.path.join(OUTPUT_DIR, filename)
    if not is_bare_png or not os.path.isfile(filepath):
        return jsonify({'error': 'File not found'}), 404
    return send_file(filepath, mimetype='image/png')


def get_gradcam(filename):
    """Serve a Grad-CAM image stored in OUTPUT_DIR (see ``_serve_output_png``).

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source; confirm how it is registered.
    """
    return _serve_output_png(filename)


def get_image(filename):
    """Serve an uploaded image stored in OUTPUT_DIR (see ``_serve_output_png``).

    NOTE(review): no ``@app.route`` decorator is visible for this view in the
    reviewed source; confirm how it is registered.
    """
    return _serve_output_png(filename)
if __name__ == '__main__':
    # Development-server entry point: listen on every interface on port 7860.
    # NOTE(review): 7860 appears to be the port the hosting Space expects -- confirm.
    app.run(port=7860, host='0.0.0.0')