Spaces:
Sleeping
Sleeping
import os
import sqlite3
import uuid
from contextlib import closing
from datetime import datetime

import cv2
import numpy as np
import tensorflow as tf
from flask import Flask, jsonify, request, send_file
from PIL import Image
from tensorflow.keras.models import Model, load_model
app = Flask(__name__)

# Directory holding uploaded images, Grad-CAM overlays and the SQLite database.
OUTPUT_DIR = '/tmp/results'
# exist_ok=True creates the directory only when missing and avoids the race
# between a separate exists() check and makedirs() (e.g. several worker
# processes importing this module at the same time).
os.makedirs(OUTPUT_DIR, exist_ok=True)
DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
def init_db(db_path=None):
    """Create the ``results`` table if it does not already exist.

    Args:
        db_path: Optional SQLite database file to initialise; defaults to
            ``DB_PATH`` so existing ``init_db()`` calls are unchanged.

    The statement uses ``CREATE TABLE IF NOT EXISTS``, so repeated calls are
    harmless. The connection is closed even when the statement fails.
    """
    path = DB_PATH if db_path is None else db_path
    # closing() guarantees close(); using ``conn`` as a context manager commits
    # on success and rolls back on error, but does not close the connection.
    with closing(sqlite3.connect(path)) as conn, conn:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_filename TEXT,
                prediction TEXT,
                confidence REAL,
                gradcam_filename TEXT,
                timestamp TEXT
            )
            """
        )
# Create the results table (if needed) before any request can be served.
init_db()

# Load the classifier once at import time; compile=False skips compiling the
# model (optimizer/loss setup), which pure inference does not need.
# NOTE(review): the relative path is resolved against the process working
# directory — confirm the server always starts in the folder holding the model.
model = load_model(filepath='mobilenet_glaucoma_model.h5', compile=False)
def preprocess_image(img, target_size=(224, 224)):
    """Turn a PIL image into a normalized single-image batch for the model.

    Args:
        img: PIL image. Images in any mode other than ``RGB`` (e.g. ``RGBA``,
            ``L``) are converted first so the result always has 3 channels.
        target_size: ``(width, height)`` handed to ``resize``; defaults to the
            224x224 size used throughout this file.

    Returns:
        numpy array of shape ``(1, height, width, 3)`` with pixel values divided
        by 255 (so in ``[0, 1]`` for 8-bit images).
    """
    if img.mode != 'RGB':
        # Without this, RGBA or grayscale uploads yield 4 or 1 channels and
        # no longer match a 3-channel model input.
        img = img.convert('RGB')
    pixels = np.asarray(img.resize(target_size)) / 255.0
    return np.expand_dims(pixels, axis=0)
def make_gradcam(img_array, model, last_conv_layer_name='Conv_1_bn'):
    """Compute a Grad-CAM heatmap for one preprocessed image.

    Args:
        img_array: Batch passed straight to the model; only the first image's
            activations (``conv_outputs[0]``) are used for the map.
        model: Keras model containing a layer named ``last_conv_layer_name``.
        last_conv_layer_name: Layer whose activations are weighted by the
            gradients of output column 0.

    Returns:
        2-D numpy array with that layer's spatial size, scaled to ``[0, 1]``.
        It is all zeros when no location has a positive weighted activation;
        this avoids a 0/0 division that would otherwise produce NaNs.
    """
    last_conv_layer = model.get_layer(last_conv_layer_name)
    grad_model = Model(inputs=model.inputs, outputs=[last_conv_layer.output, model.output])
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        # NOTE(review): predict() reads output column 0 as the "Normal"
        # probability, so this map highlights evidence for "Normal", not
        # "Glaucoma" — confirm that is intended.
        loss = predictions[:, 0]
    grads = tape.gradient(loss, conv_outputs)
    # One weight per channel: gradients averaged over batch and both spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)).numpy()
    # Broadcasting scales every channel by its weight (replaces the per-channel loop).
    heatmap = np.mean(conv_outputs[0].numpy() * pooled_grads, axis=-1)
    heatmap = np.maximum(heatmap, 0)
    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap
def save_gradcam_image(original_img, heatmap, filename='gradcam.png', output_dir=OUTPUT_DIR):
    """Blend a Grad-CAM heatmap over an image, write it to disk, return the path.

    Args:
        original_img: PIL image; converted to RGB when needed and resized to 224x224.
        heatmap: 2-D array expected in ``[0, 1]``. NaNs are treated as 0 and
            values are clipped to ``[0, 1]`` before colouring.
        filename: Output file name inside ``output_dir``.
        output_dir: Destination directory.

    Returns:
        Path of the written overlay image.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    if original_img.mode != 'RGB':
        original_img = original_img.convert('RGB')
    # PIL yields RGB, but applyColorMap returns BGR and imwrite expects BGR;
    # convert so the photo is not saved with red and blue swapped.
    img = cv2.cvtColor(np.array(original_img.resize((224, 224))), cv2.COLOR_RGB2BGR)
    heatmap = np.clip(np.nan_to_num(heatmap), 0.0, 1.0)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(img, 0.6, heatmap, 0.4, 0)
    filepath = os.path.join(output_dir, filename)
    # imwrite signals failure via its return value instead of raising.
    if not cv2.imwrite(filepath, overlay):
        raise OSError(f"Could not write Grad-CAM image to {filepath}")
    return filepath
def home():
    """Return a fixed plain-text banner saying the API is up.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file — confirm how it is registered.
    """
    banner = "Glaucoma Detection Flask API is running!"
    return banner
def test_file():
    """Report whether the model file exists relative to the working directory.

    Only existence is checked (``os.path.exists``); readability is not verified.
    Returns a human-readable status string.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file — confirm how it is registered.
    """
    model_path = "mobilenet_glaucoma_model.h5"
    if not os.path.exists(model_path):
        return "β Model file NOT found."
    return f"β Model file found at: {model_path}"
def predict():
    """Classify an uploaded image, save a Grad-CAM overlay and record the result.

    Reads the multipart upload from the ``file`` form field, stores it in
    ``OUTPUT_DIR``, runs the model, writes a Grad-CAM overlay next to it and
    inserts one row into the SQLite ``results`` table.

    Returns:
        On success, JSON with ``prediction``, ``confidence``,
        ``normal_probability``, ``glaucoma_probability``, ``gradcam_image`` and
        ``image_filename``. ``(json, 400)`` when no file is provided, and
        ``(json, 500)`` when any processing step raises.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file — confirm how (and for which HTTP methods) it is registered.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    uploaded_file = request.files['file']
    if not uploaded_file.filename:
        return jsonify({'error': 'No file selected'}), 400
    try:
        now = datetime.now()
        # A whole-second timestamp alone collides when two requests arrive in
        # the same second (one would overwrite the other's files); the random
        # suffix keeps names unique.
        file_id = f"{int(now.timestamp())}_{uuid.uuid4().hex[:12]}"
        uploaded_filename = f"uploaded_{file_id}.png"
        uploaded_file_path = os.path.join(OUTPUT_DIR, uploaded_filename)
        # NOTE(review): the raw upload bytes are stored under a .png name
        # whatever their actual format is.
        uploaded_file.save(uploaded_file_path)

        # ``with`` releases the file handle deterministically once pixels are copied.
        with Image.open(uploaded_file_path) as source_img:
            img = source_img.convert('RGB')
        img_array = preprocess_image(img)
        prediction = model.predict(img_array)[0]
        # Output 0 is read as P(Normal); Glaucoma is taken as its complement.
        normal_prob = prediction[0]
        glaucoma_prob = 1 - normal_prob
        result = 'Glaucoma' if glaucoma_prob > normal_prob else 'Normal'
        confidence = float(glaucoma_prob) if result == 'Glaucoma' else float(normal_prob)

        heatmap = make_gradcam(img_array, model, last_conv_layer_name='Conv_1_bn')
        gradcam_filename = f"gradcam_{file_id}.png"
        save_gradcam_image(img, heatmap, filename=gradcam_filename)

        # closing() always releases the connection; ``conn`` as a context
        # manager commits on success and rolls back if the INSERT fails.
        with closing(sqlite3.connect(DB_PATH)) as conn, conn:
            conn.execute(
                """
                INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, timestamp)
                VALUES (?, ?, ?, ?, ?)
                """,
                (uploaded_filename, result, confidence, gradcam_filename, now.isoformat()),
            )

        return jsonify({
            'prediction': result,
            'confidence': confidence,
            'normal_probability': float(normal_prob),
            'glaucoma_probability': float(glaucoma_prob),
            'gradcam_image': gradcam_filename,
            'image_filename': uploaded_filename,
        })
    except Exception as e:
        # Request boundary: keep the JSON 500 contract but record the traceback.
        app.logger.exception("Prediction failed")
        # NOTE(review): str(e) may expose internal details to API clients.
        return jsonify({'error': str(e)}), 500
def results():
    """Return every stored prediction as a JSON list, newest first.

    Each item has ``id``, ``image_filename``, ``prediction``, ``confidence``,
    ``gradcam_filename`` and ``timestamp``. Rows are ordered by the
    ``timestamp`` text column descending, with ``id`` descending as tie-breaker.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file — confirm how it is registered.
    """
    columns = ('id', 'image_filename', 'prediction', 'confidence', 'gradcam_filename', 'timestamp')
    # Explicit column list (not SELECT *) so the mapping below cannot silently
    # shift if the table later gains or reorders columns.
    query = (
        "SELECT id, image_filename, prediction, confidence, gradcam_filename, timestamp "
        "FROM results ORDER BY timestamp DESC, id DESC"
    )
    # closing() releases the connection even if the query raises.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute(query).fetchall()
    return jsonify([dict(zip(columns, row)) for row in rows])
def _resolve_output_png(filename):
    """Map a requested file name to a servable PNG inside ``OUTPUT_DIR``.

    Returns the resolved absolute path, or ``None`` in any of these cases:
    the name does not end in ``.png``; it resolves outside ``OUTPUT_DIR``
    (absolute paths, ``..`` segments, symlinks); or it is not an existing
    regular file.
    """
    # predict() stores uploads and overlays as .png files. Refusing every other
    # extension keeps the SQLite database in the same directory (results.db)
    # from being downloadable through these endpoints.
    if not filename.lower().endswith('.png'):
        return None
    base_dir = os.path.realpath(OUTPUT_DIR)
    # os.path.join discards base_dir for absolute names, so check containment
    # after resolving.
    filepath = os.path.realpath(os.path.join(base_dir, filename))
    try:
        if os.path.commonpath([base_dir, filepath]) != base_dir:
            return None
    except ValueError:  # e.g. paths on different drives
        return None
    if not os.path.isfile(filepath):
        return None
    return filepath


def get_gradcam(filename):
    """Serve a Grad-CAM overlay PNG from ``OUTPUT_DIR``, or a JSON 404.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file — confirm how it is registered.
    """
    filepath = _resolve_output_png(filename)
    if filepath is None:
        return jsonify({'error': 'File not found'}), 404
    return send_file(filepath, mimetype='image/png')


def get_image(filename):
    """Serve an uploaded image (saved under a ``.png`` name) from ``OUTPUT_DIR``, or a JSON 404.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    copy of the file — confirm how it is registered.
    """
    filepath = _resolve_output_png(filename)
    if filepath is None:
        return jsonify({'error': 'File not found'}), 404
    return send_file(filepath, mimetype='image/png')
if __name__ == '__main__':
    # Flask development server: listen on every interface, port 7860.
    # NOTE(review): 7860 presumably matches the port the hosting container
    # exposes — confirm.
    app.run(port=7860, host='0.0.0.0')