import os
import logging
import cv2
import numpy as np
from PIL import Image
from datetime import datetime
import gradio as gr
import spaces
import torch
import time
from huggingface_hub import HfApi, HfFolder
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# =============== LOGGING SETUP ===============
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# =============== CONFIGURATION ===============
UPLOADS_DIR = "uploads"
if not os.path.isdir(UPLOADS_DIR):
    # exist_ok avoids a crash if another worker creates the directory concurrently.
    os.makedirs(UPLOADS_DIR, exist_ok=True)
    logging.info(f"Created uploads directory: {UPLOADS_DIR}")

HF_TOKEN = os.getenv("HF_TOKEN")
YOLO_MODEL_PATH = "src/best.pt"
SEG_MODEL_PATH = "src/segmentation_model.h5"
GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
DATASET_ID = "SmartHeal/wound-image-uploads"
MAX_NEW_TOKENS = 1024  # Reduced for stability
# Fixed scale used to convert original-image pixels to centimetres.
# NOTE(review): assumes every photo has the same scale; confirm calibration.
PIXELS_PER_CM = 38
# Maximum number of characters of retrieved guideline text placed in the MedGemma prompt.
GUIDELINE_PROMPT_CHARS = 1500
# Timestamp format for saved files; microseconds keep requests that arrive
# within the same second from overwriting each other's files.
TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S_%f"

# =============== GLOBAL CACHES ===============
# Loaded models keyed by role: "det", "seg", "cls", "embedding_model".
models_cache = {}
# Holds the FAISS store under "vector_store" (None when unavailable).
knowledge_base_cache = {}


# =============== LAZY LOADING FUNCTIONS (CPU-SAFE) ===============
def load_yolo_model(yolo_model_path):
    """Lazy import and load YOLO model to avoid CUDA initialization."""
    from ultralytics import YOLO
    return YOLO(yolo_model_path)


def load_segmentation_model(seg_model_path):
    """Lazy import and load segmentation model."""
    import tensorflow as tf
    tf.config.set_visible_devices([], 'GPU')  # Force CPU for TensorFlow
    from tensorflow.keras.models import load_model
    return load_model(seg_model_path, compile=False)


def load_classification_pipeline(hf_token):
    """Lazy import and load classification pipeline (CPU only)."""
    from transformers import pipeline
    return pipeline(
        "image-classification",
        model="Hemg/Wound-classification",
        token=hf_token,
        device="cpu",
    )


def load_embedding_model():
    """Load embedding model for knowledge base."""
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )


# =============== MODEL INITIALIZATION ===============
def initialize_cpu_models():
    """Load every CPU-only model into ``models_cache`` once.

    Each model is loaded independently: a failure is logged and the key is
    simply left absent, so later callers must check for it.
    """
    if HF_TOKEN:
        HfFolder.save_token(HF_TOKEN)
        logging.info("✅ HuggingFace token set")

    if "det" not in models_cache:
        try:
            models_cache["det"] = load_yolo_model(YOLO_MODEL_PATH)
            logging.info("✅ YOLO model loaded (CPU only)")
        except Exception as e:
            logging.error(f"YOLO load failed: {e}")

    if "seg" not in models_cache:
        try:
            models_cache["seg"] = load_segmentation_model(SEG_MODEL_PATH)
            logging.info("✅ Segmentation model loaded (CPU)")
        except Exception as e:
            logging.warning(f"Segmentation model not available: {e}")

    if "cls" not in models_cache:
        try:
            models_cache["cls"] = load_classification_pipeline(HF_TOKEN)
            logging.info("✅ Classification pipeline loaded (CPU)")
        except Exception as e:
            logging.warning(f"Classification pipeline not available: {e}")

    if "embedding_model" not in models_cache:
        try:
            models_cache["embedding_model"] = load_embedding_model()
            logging.info("✅ Embedding model loaded (CPU)")
        except Exception as e:
            logging.warning(f"Embedding model not available: {e}")


def setup_knowledge_base():
    """Load the guideline PDFs and build a FAISS vector store (once).

    Stores the result in ``knowledge_base_cache["vector_store"]``. The value
    is ``None`` when no PDF could be loaded or no embedding model is available.
    """
    if "vector_store" in knowledge_base_cache:
        return

    docs = []
    for pdf_path in GUIDELINE_PDFS:
        if not os.path.exists(pdf_path):
            continue
        try:
            loader = PyPDFLoader(pdf_path)
            docs.extend(loader.load())
            logging.info(f"Loaded PDF: {pdf_path}")
        except Exception as e:
            logging.warning(f"Failed to load PDF {pdf_path}: {e}")

    if docs and "embedding_model" in models_cache:
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = splitter.split_documents(docs)
        knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
        logging.info(f"✅ Knowledge base ready with {len(chunks)} chunks")
    else:
        knowledge_base_cache["vector_store"] = None
        logging.warning("Knowledge base unavailable")


# Initialize models on app startup
initialize_cpu_models()
setup_knowledge_base()


# =============== GPU-DECORATED MEDGEMMA FUNCTION WITH TIMEOUT HANDLING ===============
@spaces.GPU(enable_queue=True, duration=90)  # Reduced duration for stability
def generate_medgemma_report_with_timeout(
    patient_info,
    visual_results,
    guideline_context,
    image_pil,
    max_new_tokens=None,
):
    """Generate a clinical report with MedGemma on the GPU.

    Args:
        patient_info: Free-text summary of the questionnaire answers.
        visual_results: Dict from ``AIProcessor.perform_visual_analysis``.
        guideline_context: Retrieved guideline text; a truncated excerpt is
            included in the prompt.
        image_pil: The wound image passed to the multimodal model.
        max_new_tokens: Generation budget; defaults to 800 when falsy.

    Returns:
        The generated report text. On failure it returns (never raises) a
        string starting with "⚠️" (empty output) or "❌" (exception).
    """
    from transformers import pipeline

    try:
        # Clear GPU cache first
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        guideline_excerpt = (guideline_context or "").strip()[:GUIDELINE_PROMPT_CHARS] or "Not available"

        # Use a short, focused prompt to reduce processing time
        prompt = f"""
You are a medical AI assistant. Analyze this wound image and patient data to provide a clinical assessment.

Patient: {patient_info}
Wound: {visual_results.get('wound_type', 'Unknown')} - {visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)}cm

Relevant clinical guideline excerpts:
{guideline_excerpt}

Provide a structured report with:
1. Clinical Summary (wound appearance, size, location)
2. Treatment Recommendations (dressings, care protocols)
3. Risk Assessment (healing factors)
4. Monitoring Plan (follow-up schedule)

Keep response concise but medically comprehensive.
"""

        # NOTE(review): the pipeline (and model weights) is rebuilt on every call.
        pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-4b-it",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=HF_TOKEN,
            model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
        )

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_pil},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Greedy decoding; a sampling temperature would be ignored here, so none is passed.
        generate_kwargs = {"do_sample": False}
        tokenizer = getattr(pipe, "tokenizer", None)
        if tokenizer is not None and getattr(tokenizer, "eos_token_id", None) is not None:
            generate_kwargs["pad_token_id"] = tokenizer.eos_token_id

        start_time = time.perf_counter()
        output = pipe(
            text=messages,
            max_new_tokens=max_new_tokens or 800,  # Reduced for stability
            generate_kwargs=generate_kwargs,
        )
        processing_time = time.perf_counter() - start_time
        logging.info(f"✅ MedGemma processing completed in {processing_time:.2f} seconds")

        if not output:
            return "⚠️ No output generated"

        generated = output[0].get("generated_text")
        # For chat-style input the pipeline returns the message list; the last entry is the reply.
        if isinstance(generated, list) and generated:
            content = generated[-1].get("content", "")
        else:
            content = generated or ""
        result = str(content).strip()
        return result if result else "⚠️ Empty response generated"

    except Exception as e:
        logging.error(f"❌ MedGemma generation error: {e}")
        return f"❌ Report generation failed: {str(e)}"

    finally:
        # Clear GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# =============== AI PROCESSOR CLASS ===============
class AIProcessor:
    """Runs detection, segmentation, classification, retrieval and reporting for a wound image."""

    # (label shown in the report, questionnaire key) pairs used to build patient_info.
    _PATIENT_FIELDS = (
        ("Age", "age"),
        ("Diabetic", "diabetic"),
        ("Allergies", "allergies"),
        ("Date of Wound Sustained", "date_of_injury"),
        ("Professional Care", "professional_care"),
        ("Oozing/Bleeding", "oozing_bleeding"),
        ("Infection", "infection"),
        ("Moisture", "moisture"),
    )

    def __init__(self):
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.px_per_cm = PIXELS_PER_CM
        self.uploads_dir = UPLOADS_DIR
        self.dataset_id = DATASET_ID
        self.hf_token = HF_TOKEN

    # ---------- visual analysis helpers ----------
    def _require_model(self, key: str, description: str):
        """Return ``models_cache[key]`` or raise a clear RuntimeError if it failed to load."""
        model = self.models_cache.get(key)
        if model is None:
            raise RuntimeError(f"{description} is not loaded.")
        return model

    def _detect_wound(self, image_cv):
        """Run YOLO and return (clamped int xyxy box, confidence) for the first detection.

        Raises:
            ValueError: if no wound is detected.
        """
        detector = self._require_model("det", "Wound detection model")
        results = detector.predict(image_cv, verbose=False, device="cpu")
        if not results or not results[0].boxes:
            raise ValueError("No wound could be detected.")

        boxes = results[0].boxes
        box = boxes[0].xyxy[0].cpu().numpy().astype(int)
        # Clamp to the image so a slightly out-of-bounds box cannot wrap numpy slicing.
        img_h, img_w = image_cv.shape[:2]
        box = np.clip(box, 0, [img_w, img_h, img_w, img_h])
        confidence = float(boxes.conf[0].cpu().item()) if boxes.conf is not None else 0.0
        return box, confidence

    def _segment_wound(self, region_cv):
        """Return a binary uint8 mask (1 = wound) at the same resolution as ``region_cv``."""
        seg_model = self._require_model("seg", "Segmentation model")
        in_h, in_w = seg_model.input_shape[1:3]
        resized = cv2.resize(region_cv, (in_w, in_h))
        # NOTE(review): the model receives BGR pixels scaled to [0, 1]; confirm this matches training.
        mask_pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
        mask_small = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
        # Upsample back to crop size: PIXELS_PER_CM is applied to original-image
        # pixels, so measuring on the model-input-sized mask would give wrong sizes.
        crop_h, crop_w = region_cv.shape[:2]
        return cv2.resize(mask_small, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)

    def _measure_wound(self, mask):
        """Return (length_cm, breadth_cm, area_cm2, contours) from the largest external contour.

        Length is the bounding-box height and breadth its width; all values are
        0 when the mask has no contours.
        """
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return 0, 0, 0, contours
        cnt = max(contours, key=cv2.contourArea)
        _, _, w, h = cv2.boundingRect(cnt)
        length = round(h / self.px_per_cm, 2)
        breadth = round(w / self.px_per_cm, 2)
        area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
        return length, breadth, area, contours

    def _classify_wound(self, region_cv) -> str:
        """Return the highest-scoring label from the classification pipeline for the crop."""
        classifier = self._require_model("cls", "Wound classification model")
        region_pil = Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB))
        predictions = classifier(region_pil)
        if not predictions:
            return "Unknown"
        return max(predictions, key=lambda p: p["score"])["label"]

    def _save_visualizations(self, image_cv, box, region_cv, mask, has_contours):
        """Write the detection, original and (if any) segmentation images.

        Returns:
            (detection_path, segmentation_path_or_None, original_path)
        """
        analysis_dir = os.path.join(self.uploads_dir, "analysis")
        os.makedirs(analysis_dir, exist_ok=True)
        ts = datetime.now().strftime(TIMESTAMP_FORMAT)

        # Detection visualization
        det_vis = image_cv.copy()
        cv2.rectangle(det_vis, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 255, 0), 2)
        det_path = os.path.join(analysis_dir, f"detection_{ts}.png")
        cv2.imwrite(det_path, det_vis)

        # Original image
        original_path = os.path.join(analysis_dir, f"original_{ts}.png")
        cv2.imwrite(original_path, image_cv)

        # Segmentation visualization
        seg_path = None
        if has_contours:
            overlay = region_cv.copy()
            overlay[mask > 0] = [0, 0, 255]  # Red overlay for wound area (BGR)
            seg_vis = cv2.addWeighted(region_cv, 0.7, overlay, 0.3, 0)
            seg_path = os.path.join(analysis_dir, f"segmentation_{ts}.png")
            cv2.imwrite(seg_path, seg_vis)

        return det_path, seg_path, original_path

    def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
        """Run detection, segmentation, measurement, classification and save visualizations.

        Returns:
            Dict with keys wound_type, length_cm, breadth_cm, surface_area_cm2,
            detection_confidence, detection_image_path, segmentation_image_path
            (None when no wound contour is found) and original_image_path.

        Raises:
            ValueError: if no wound (or an empty region) is detected.
            RuntimeError: if a required model failed to load.
        """
        try:
            # Normalise to 3-channel RGB so RGBA/grayscale/palette inputs do not
            # break the RGB->BGR conversion.
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            box, confidence = self._detect_wound(image_cv)
            detected_region_cv = image_cv[box[1]:box[3], box[0]:box[2]]
            if detected_region_cv.size == 0:
                raise ValueError("Detected wound region is empty.")

            mask = self._segment_wound(detected_region_cv)
            length, breadth, area, contours = self._measure_wound(mask)
            wound_type = self._classify_wound(detected_region_cv)

            det_path, seg_path, original_path = self._save_visualizations(
                image_cv, box, detected_region_cv, mask, bool(contours)
            )

            return {
                "wound_type": wound_type,
                "length_cm": length,
                "breadth_cm": breadth,
                "surface_area_cm2": area,
                "detection_confidence": confidence,
                "detection_image_path": det_path,
                "segmentation_image_path": seg_path,
                "original_image_path": original_path,
            }
        except Exception as e:
            logging.exception(f"Visual analysis failed: {e}")
            raise

    # ---------- retrieval & reporting ----------
    def query_guidelines(self, query: str) -> str:
        """Return the top-5 guideline chunks for ``query``, each truncated to 300 chars.

        Never raises: returns an explanatory string when the knowledge base is
        missing or the query fails.
        """
        try:
            vector_store = self.knowledge_base_cache.get("vector_store")
            if not vector_store:
                return "Knowledge base is not available."

            retriever = vector_store.as_retriever(search_kwargs={"k": 5})  # Reduced for efficiency
            docs = retriever.invoke(query)
            return "\n\n".join(
                f"Source: {doc.metadata.get('source', 'N/A')}\nContent: {doc.page_content[:300]}..."
                for doc in docs
            )
        except Exception as e:
            logging.error(f"Guidelines query failed: {e}")
            return f"Guidelines query failed: {str(e)}"

    def generate_final_report(
        self,
        patient_info: str,
        visual_results: dict,
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: int = None,
    ) -> str:
        """Generate the report with MedGemma, falling back to a template report on failure.

        MedGemma output is rejected when it is empty or starts with the
        "❌"/"⚠️" error markers produced by the generator.
        """
        try:
            report = generate_medgemma_report_with_timeout(
                patient_info, visual_results, guideline_context, image_pil, max_new_tokens
            )
            if isinstance(report, str) and report.strip() and not report.startswith(("❌", "⚠️")):
                return report

            logging.warning("MedGemma returned invalid response, using fallback")
        except Exception as e:
            logging.error(f"MedGemma report generation failed: {e}")

        return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _generate_fallback_report(
        self,
        patient_info: str,
        visual_results: dict,
        guideline_context: str,
    ) -> str:
        """Build a static Markdown report from the visual results when MedGemma is unavailable."""
        # `or` (not a .get default) so a stored None path is reported as missing.
        original_img = visual_results.get('original_image_path') or 'Not available'
        detection_img = visual_results.get('detection_image_path') or 'Not available'
        segmentation_img = visual_results.get('segmentation_image_path') or 'Not available'

        report = f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report

## 📋 Patient Information
{patient_info}

## 🔍 Visual Analysis Results
- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
- **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}

## 📊 Analysis Images Available
- **Original Image**: {original_img}
- **Detection Visualization**: {detection_img}
- **Segmentation Overlay**: {segmentation_img}

## 🎯 Clinical Assessment Summary

### Wound Classification
Based on automated analysis, this wound has been classified as **{visual_results.get('wound_type', 'Unspecified')}** with the following characteristics:
- Size: {visual_results.get('length_cm', 0)} × {visual_results.get('breadth_cm', 0)} cm
- Total area: {visual_results.get('surface_area_cm2', 0)} cm²
- Detection confidence: {visual_results.get('detection_confidence', 0):.1%}

### Clinical Observations
The automated visual analysis provides quantitative measurements that should be verified through clinical examination. The wound type classification helps guide initial treatment considerations.

## 💊 Treatment Recommendations

### Wound Care Protocol
1. **Assessment**: Comprehensive clinical evaluation by qualified healthcare professional
2. **Cleaning**: Gentle wound cleansing with appropriate solution
3. **Debridement**: Remove necrotic tissue if present (professional assessment required)
4. **Dressing Selection**: Choose appropriate dressing based on wound characteristics:
   - Moisture level assessment
   - Infection risk evaluation
   - Patient comfort and mobility

### Monitoring Plan
- **Initial Phase**: Daily assessment for first week
- **Ongoing Care**: Reassessment every 2-3 days or as clinically indicated
- **Documentation**: Regular photo documentation and measurement tracking
- **Progress Evaluation**: Weekly review of healing progression

## ⚠️ Risk Factors & Considerations

### Patient-Specific Factors
Review patient history for factors that may impact healing:
- Age and general health status
- Diabetes or metabolic conditions
- Circulation and vascular health
- Nutritional status
- Mobility and pressure relief

### Warning Signs
Monitor for signs requiring immediate attention:
- Increased pain, redness, or swelling
- Purulent drainage or odor
- Fever or systemic signs of infection
- Wound expansion or deterioration
- Delayed healing beyond expected timeframe

## 📚 Clinical Guidelines Context
{guideline_context[:800]}{'...' if len(guideline_context) > 800 else ''}

## 🏥 Next Steps

### Immediate Actions
1. **Professional Consultation**: Schedule appointment with wound care specialist
2. **Baseline Documentation**: Establish comprehensive baseline assessment
3. **Treatment Plan**: Develop individualized care protocol
4. **Patient Education**: Provide wound care instructions and warning signs

### Follow-up Schedule
- **Week 1**: Daily monitoring and assessment
- **Week 2-4**: Every 2-3 days or as indicated
- **Monthly**: Comprehensive reassessment and plan review
- **As Needed**: Immediate evaluation for any concerning changes

## ⚖️ Important Medical Disclaimer

**This automated analysis is provided for informational and educational purposes only.**

- This report does not constitute medical diagnosis or treatment advice
- All measurements are computer-generated estimates requiring clinical verification
- Professional medical evaluation is essential for proper diagnosis and treatment
- This AI tool should supplement, not replace, clinical judgment
- Always consult qualified healthcare professionals for medical decisions

### Clinical Correlation Required
- Verify all measurements with standard clinical tools
- Correlate findings with patient symptoms and history
- Consider factors not captured in automated analysis
- Follow institutional protocols and guidelines

---
*Generated by SmartHeal AI - Advanced Wound Care Analysis System*
*Report Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
*Version: AI-Processor v1.2 with Enhanced Fallback Reporting*
"""
        return report

    # ---------- persistence & entry points ----------
    def save_and_commit_image(self, image_pil: Image.Image) -> str:
        """Save the image as PNG under ``uploads_dir`` and best-effort upload it to the HF dataset.

        Returns:
            The local file path, or "" if saving failed. Upload failures are
            only logged.
        """
        try:
            os.makedirs(self.uploads_dir, exist_ok=True)
            timestamp = datetime.now().strftime(TIMESTAMP_FORMAT)
            filename = f"{timestamp}.png"
            path = os.path.join(self.uploads_dir, filename)

            image_pil.convert("RGB").save(path)
            logging.info(f"✅ Image saved locally: {path}")

            # Upload to HuggingFace dataset if configured
            if self.hf_token and self.dataset_id:
                try:
                    api = HfApi()
                    api.upload_file(
                        path_or_fileobj=path,
                        path_in_repo=f"images/{filename}",
                        repo_id=self.dataset_id,
                        repo_type="dataset",
                        token=self.hf_token,
                        commit_message=f"Upload wound image: {filename}",
                    )
                    logging.info("✅ Image committed to HF dataset")
                except Exception as e:
                    logging.warning(f"HF upload failed: {e}")

            return path
        except Exception as e:
            logging.error(f"Failed to save image: {e}")
            return ""

    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
        """Save, analyse, retrieve guidelines and report for one image.

        Returns:
            On success: ``{'success': True, 'visual_analysis', 'report',
            'saved_image_path', 'guideline_context'}`` (context truncated to
            500 chars). On failure: ``success`` is False and ``error`` holds
            the message; exceptions are not propagated.
        """
        questionnaire_data = questionnaire_data or {}
        try:
            saved_path = self.save_and_commit_image(image_pil)
            logging.info(f"Image saved: {saved_path}")

            visual_results = self.perform_visual_analysis(image_pil)
            logging.info(f"Visual analysis completed: {visual_results}")

            patient_info = ", ".join(
                f"{label}: {questionnaire_data.get(key, 'N/A')}" for label, key in self._PATIENT_FIELDS
            )

            query = (
                f"best practices for managing a {visual_results['wound_type']} "
                f"with moisture level '{questionnaire_data.get('moisture', 'unknown')}' "
                f"and signs of infection '{questionnaire_data.get('infection', 'unknown')}' "
                f"in a patient who is diabetic '{questionnaire_data.get('diabetic', 'unknown')}'"
            )
            guideline_context = self.query_guidelines(query)
            logging.info("Guidelines queried successfully")

            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            logging.info("Report generated successfully")

            return {
                'success': True,
                'visual_analysis': visual_results,
                'report': report,
                'saved_image_path': saved_path,
                'guideline_context': (
                    guideline_context[:500] + "..." if len(guideline_context) > 500 else guideline_context
                ),
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {
                'success': False,
                'error': str(e),
                'visual_analysis': {},
                'report': f"Analysis failed: {str(e)}",
                'saved_image_path': None,
                'guideline_context': "",
            }

    def analyze_wound(self, image, questionnaire_data: dict) -> dict:
        """Main analysis entry point; accepts a file path, PIL image or numpy array.

        Returns the same dict shape as ``full_analysis_pipeline``. Input errors
        are reported with ``success=False`` rather than raised.
        """
        try:
            if isinstance(image, (str, os.PathLike)):
                if not os.path.exists(image):
                    raise ValueError(f"Image file not found: {image}")
                # Decode fully inside the context so the file handle is closed
                # now (Image.open is lazy and would otherwise keep it open).
                with Image.open(image) as opened:
                    image_pil = opened.convert("RGB")
            elif isinstance(image, Image.Image):
                image_pil = image
            elif isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            return self.full_analysis_pipeline(image_pil, questionnaire_data)
        except Exception as e:
            logging.error(f"Wound analysis error: {e}")
            return {
                'success': False,
                'error': str(e),
                'visual_analysis': {},
                'report': f"Analysis initialization failed: {str(e)}",
                'saved_image_path': None,
                'guideline_context': "",
            }