Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from datetime import datetime | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| import time | |
| from huggingface_hub import HfApi, HfFolder | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
# =============== LOGGING SETUP ===============
# Root logger: timestamp, level, message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# =============== CONFIGURATION ===============
# Local folder that receives uploaded images and analysis artefacts.
UPLOADS_DIR = "uploads"
if not os.path.exists(UPLOADS_DIR):
    os.makedirs(UPLOADS_DIR)
    logging.info(f"Created uploads directory: {UPLOADS_DIR}")

# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model weights and guideline documents shipped under src/.
YOLO_MODEL_PATH = "src/best.pt"
SEG_MODEL_PATH = "src/segmentation_model.h5"
GUIDELINE_PDFS = [
    "src/eHealth in Wound Care.pdf",
    "src/IWGDF Guideline.pdf",
    "src/evaluation.pdf",
]

# Dataset repository that uploaded images are committed to.
DATASET_ID = "SmartHeal/wound-image-uploads"

MAX_NEW_TOKENS = 1024  # Reduced for stability
PIXELS_PER_CM = 38  # pixel-to-centimetre factor used for wound measurements

# =============== GLOBAL CACHES ===============
# Process-wide singletons, filled in by the initialisation functions.
models_cache = dict()
knowledge_base_cache = dict()
# =============== LAZY LOADING FUNCTIONS (CPU-SAFE) ===============
def load_yolo_model(yolo_model_path):
    """Build the YOLO detector from the weights file at ``yolo_model_path``.

    ``ultralytics`` is imported at call time rather than at module import,
    so merely importing this module does not pull it in (the intent stated
    by the original author is to avoid CUDA initialisation).
    """
    from ultralytics import YOLO

    detector = YOLO(yolo_model_path)
    return detector
def load_segmentation_model(seg_model_path):
    """Load the Keras segmentation network stored at ``seg_model_path``.

    TensorFlow is imported lazily and told to see no GPU devices before the
    model is deserialised. The model is loaded with ``compile=False``, i.e.
    for inference only.
    """
    import tensorflow as tf

    # Hide every GPU from TensorFlow so it runs on CPU only.
    tf.config.set_visible_devices([], 'GPU')

    from tensorflow.keras import models as keras_models

    segmenter = keras_models.load_model(seg_model_path, compile=False)
    return segmenter
def load_classification_pipeline(hf_token):
    """Create the wound-type image classifier, pinned to the CPU.

    ``hf_token`` is forwarded to ``transformers.pipeline`` as the access
    token for the ``Hemg/Wound-classification`` checkpoint. ``transformers``
    is imported lazily, at call time.
    """
    from transformers import pipeline

    options = {
        "model": "Hemg/Wound-classification",
        "token": hf_token,
        "device": "cpu",
    }
    return pipeline("image-classification", **options)
def load_embedding_model():
    """Return the CPU sentence-embedding model used to index the guidelines."""
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )
    return embedder
# =============== MODEL INITIALIZATION ===============
def initialize_cpu_models():
    """Populate ``models_cache`` with every CPU model that can be loaded.

    Models already present in the cache are left alone, so calling this more
    than once is cheap. A loader that raises is logged and skipped; the
    corresponding key is simply absent from the cache afterwards.
    """
    global models_cache

    if HF_TOKEN:
        HfFolder.save_token(HF_TOKEN)
        logging.info("β HuggingFace token set")

    # One row per model:
    # (cache key, loader thunk, success message, failure logger, failure prefix)
    load_plan = (
        (
            "det",
            lambda: load_yolo_model(YOLO_MODEL_PATH),
            "β YOLO model loaded (CPU only)",
            logging.error,
            "YOLO load failed",
        ),
        (
            "seg",
            lambda: load_segmentation_model(SEG_MODEL_PATH),
            "β Segmentation model loaded (CPU)",
            logging.warning,
            "Segmentation model not available",
        ),
        (
            "cls",
            lambda: load_classification_pipeline(HF_TOKEN),
            "β Classification pipeline loaded (CPU)",
            logging.warning,
            "Classification pipeline not available",
        ),
        (
            "embedding_model",
            lambda: load_embedding_model(),
            "β Embedding model loaded (CPU)",
            logging.warning,
            "Embedding model not available",
        ),
    )

    for cache_key, loader, loaded_msg, report_failure, failure_prefix in load_plan:
        if cache_key in models_cache:
            continue
        try:
            models_cache[cache_key] = loader()
            logging.info(loaded_msg)
        except Exception as exc:
            report_failure(f"{failure_prefix}: {exc}")
def setup_knowledge_base():
    """Index the guideline PDFs into a FAISS store, once per process.

    The result is stored under ``knowledge_base_cache["vector_store"]``:
    a FAISS index when at least one PDF page was read and the embedding
    model is in ``models_cache``, otherwise ``None``. PDFs that do not exist
    on disk are skipped; PDFs that fail to parse are logged and skipped.
    """
    global knowledge_base_cache

    if "vector_store" in knowledge_base_cache:
        return

    pages = []
    for pdf_path in GUIDELINE_PDFS:
        if not os.path.exists(pdf_path):
            continue
        try:
            pages.extend(PyPDFLoader(pdf_path).load())
            logging.info(f"Loaded PDF: {pdf_path}")
        except Exception as exc:
            logging.warning(f"Failed to load PDF {pdf_path}: {exc}")

    # Without pages or without an embedder there is nothing to index.
    if not pages or "embedding_model" not in models_cache:
        knowledge_base_cache["vector_store"] = None
        logging.warning("Knowledge base unavailable")
        return

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    pieces = text_splitter.split_documents(pages)
    knowledge_base_cache["vector_store"] = FAISS.from_documents(
        pieces, models_cache["embedding_model"]
    )
    logging.info(f"β Knowledge base ready with {len(pieces)} chunks")
# Import-time warm-up: load the CPU models first, then build the guideline
# index (which needs the embedding model placed in the cache by the first call).
initialize_cpu_models()
setup_knowledge_base()
# =============== GPU-DECORATED MEDGEMMA FUNCTION WITH TIMEOUT HANDLING ===============
# Reduced duration for stability
# NOTE(review): this banner and the `spaces` import suggest the function was
# meant to carry a `@spaces.GPU(duration=...)` decorator, but none is visible
# here. Confirm whether it was dropped and which duration was intended.
def generate_medgemma_report_with_timeout(
    patient_info,
    visual_results,
    guideline_context,
    image_pil,
    max_new_tokens=None,
):
    """Generate a clinical wound report with the MedGemma image-text model.

    A fresh ``image-text-to-text`` pipeline for ``google/medgemma-4b-it`` is
    built on every call, prompted with the patient summary, the wound type
    and dimensions from ``visual_results``, and the image itself.

    Args:
        patient_info: Free-text patient summary inserted into the prompt.
        visual_results: Mapping read with ``.get`` for ``wound_type``,
            ``length_cm`` and ``breadth_cm``.
        guideline_context: Accepted for signature compatibility; it is not
            included in the prompt.
        image_pil: Image passed to the model as the ``image`` content part.
        max_new_tokens: Generation budget; any falsy value means 800.

    Returns:
        The generated report text. When generation yields nothing or raises,
        a message string beginning with the warning/error marker is returned
        instead (callers such as ``AIProcessor.generate_final_report`` detect
        failure by that prefix).

    Raises:
        ImportError: If ``torch`` or ``transformers`` cannot be imported;
            these imports happen before the ``try`` block.
    """
    import torch
    from transformers import pipeline
    try:
        # Clear GPU cache first
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Use a shorter, more focused prompt to reduce processing time
        prompt = f"""
You are a medical AI assistant. Analyze this wound image and patient data to provide a clinical assessment.
Patient: {patient_info}
Wound: {visual_results.get('wound_type', 'Unknown')} - {visual_results.get('length_cm', 0)}Γ{visual_results.get('breadth_cm', 0)}cm
Provide a structured report with:
1. Clinical Summary (wound appearance, size, location)
2. Treatment Recommendations (dressings, care protocols)
3. Risk Assessment (healing factors)
4. Monitoring Plan (follow-up schedule)
Keep response concise but medically comprehensive.
"""

        # Initialize pipeline with optimized settings
        pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-4b-it",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=HF_TOKEN,
            model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
        )

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_pil},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Greedy decoding. No `temperature` is passed: it only affects
        # sampling, and combining it with do_sample=False is contradictory
        # (the library warns that the value is ignored).
        start_time = time.time()
        output = pipe(
            text=messages,
            max_new_tokens=max_new_tokens or 800,  # Reduced for stability
            do_sample=False,
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
        processing_time = time.time() - start_time
        logging.info(f"β MedGemma processing completed in {processing_time:.2f} seconds")

        if output and len(output) > 0:
            # The last chat message is the model's reply.
            result = output[0]["generated_text"][-1].get("content", "").strip()
            return result if result else "β οΈ Empty response generated"
        else:
            return "β οΈ No output generated"
    except Exception as e:
        logging.error(f"β MedGemma generation error: {e}")
        return f"β Report generation failed: {str(e)}"
    finally:
        # Clear GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# =============== AI PROCESSOR CLASS ===============
class AIProcessor:
    """End-to-end wound analysis: detection, segmentation, measurement,
    classification, guideline retrieval and report generation.

    The instance holds references to the module-level model and
    knowledge-base caches plus the configuration constants.
    """

    def __init__(self):
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.px_per_cm = PIXELS_PER_CM
        self.uploads_dir = UPLOADS_DIR
        self.dataset_id = DATASET_ID
        self.hf_token = HF_TOKEN

    def perform_visual_analysis(self, image_pil: "Image.Image") -> dict:
        """Detect, segment, measure and classify the wound in ``image_pil``.

        Pipeline: YOLO detection (first box) -> crop -> segmentation mask ->
        contour measurements -> classification of the crop -> visualisation
        PNGs written under ``<uploads_dir>/analysis``.

        Args:
            image_pil: Input photograph in any PIL mode; it is converted to
                RGB before processing.

        Returns:
            dict with keys ``wound_type``, ``length_cm``, ``breadth_cm``,
            ``surface_area_cm2``, ``detection_confidence``,
            ``detection_image_path``, ``segmentation_image_path`` (``None``
            when the mask contains no contour) and ``original_image_path``.

        Raises:
            ValueError: No box was detected, or the detected box is empty.
            KeyError: A required model (``det``, ``seg``, ``cls``) is missing
                from the cache.
            Exception: Any other model or OpenCV error is logged and
                re-raised unchanged.
        """
        try:
            # Normalise to 3-channel RGB first: palette ("P") and grayscale
            # ("L") images would otherwise break the RGB->BGR conversion.
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            # YOLO detection
            results = self.models_cache["det"].predict(image_cv, verbose=False, device="cpu")
            if not results or not results[0].boxes:
                raise ValueError("No wound could be detected.")
            box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
            detected_region_cv = image_cv[box[1]:box[3], box[0]:box[2]]
            if detected_region_cv.size == 0:
                # A zero-width or zero-height box cannot be resized or measured.
                raise ValueError("Detected wound region is empty.")
            crop_h, crop_w = detected_region_cv.shape[:2]

            # Segmentation at the network's input resolution
            seg_model = self.models_cache["seg"]
            in_h, in_w = seg_model.input_shape[1:3]
            resized = cv2.resize(detected_region_cv, (in_w, in_h))
            mask_pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
            mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)

            # Map the mask back to the crop's own resolution BEFORE measuring.
            # px_per_cm is applied to these pixel counts, so they must be in
            # image pixels, not in the network's resized input grid.
            mask_full = cv2.resize(mask_np, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)

            # Measurements from the largest contour
            contours, _ = cv2.findContours(mask_full, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            length, breadth, area = (0, 0, 0)
            if contours:
                cnt = max(contours, key=cv2.contourArea)
                _, _, w, h = cv2.boundingRect(cnt)
                length = round(h / self.px_per_cm, 2)
                breadth = round(w / self.px_per_cm, 2)
                area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)

            # Classification: keep the label with the highest score
            detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
            predictions = self.models_cache["cls"](detected_image_pil)
            wound_type = max(predictions, key=lambda x: x["score"])["label"]

            # Save visualization images
            os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Detection visualization
            det_vis = image_cv.copy()
            cv2.rectangle(det_vis, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
            det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
            cv2.imwrite(det_path, det_vis)

            # Original image
            original_path = f"{self.uploads_dir}/analysis/original_{ts}.png"
            cv2.imwrite(original_path, image_cv)

            # Segmentation visualization
            seg_path = None
            if contours:
                overlay = detected_region_cv.copy()
                overlay[mask_full > 0] = [0, 0, 255]  # Red overlay for wound area
                seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
                seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
                cv2.imwrite(seg_path, seg_vis)

            confidences = results[0].boxes.conf
            visual_results = {
                "wound_type": wound_type,
                "length_cm": length,
                "breadth_cm": breadth,
                "surface_area_cm2": area,
                "detection_confidence": float(confidences[0].cpu().item()) if confidences is not None else 0.0,
                "detection_image_path": det_path,
                "segmentation_image_path": seg_path,
                "original_image_path": original_path,
            }
            return visual_results
        except Exception as e:
            logging.error(f"Visual analysis failed: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    def query_guidelines(self, query: str) -> str:
        """Return the top guideline snippets for ``query`` as one text block.

        Each of the (up to) five retrieved documents contributes its source
        and the first 300 characters of its content. A fixed message is
        returned when no vector store is available, and an error message
        (not an exception) when retrieval fails.
        """
        try:
            vector_store = self.knowledge_base_cache.get("vector_store")
            if not vector_store:
                return "Knowledge base is not available."
            retriever = vector_store.as_retriever(search_kwargs={"k": 5})  # Reduced for efficiency
            docs = retriever.invoke(query)
            return "\n\n".join(
                f"Source: {doc.metadata.get('source', 'N/A')}\nContent: {doc.page_content[:300]}..."
                for doc in docs
            )
        except Exception as e:
            logging.error(f"Guidelines query failed: {e}")
            return f"Guidelines query failed: {str(e)}"

    def generate_final_report(
        self, patient_info: str, visual_results: dict, guideline_context: str,
        image_pil: "Image.Image", max_new_tokens: int = None
    ) -> str:
        """Return the MedGemma report, or the templated fallback report.

        The fallback is used when MedGemma raises, returns an empty string,
        or returns one of its failure messages (recognised by prefix).
        """
        # Prefixes of the failure strings returned by the MedGemma helper.
        failure_prefixes = ("β", "β οΈ")
        try:
            report = generate_medgemma_report_with_timeout(
                patient_info, visual_results, guideline_context, image_pil, max_new_tokens
            )
            if report and report.strip() and not report.startswith(failure_prefixes):
                return report
            logging.warning("MedGemma returned invalid response, using fallback")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
        except Exception as e:
            logging.error(f"MedGemma report generation failed: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _generate_fallback_report(
        self, patient_info: str, visual_results: dict, guideline_context: str
    ) -> str:
        """Build the static Markdown report used when MedGemma is unavailable.

        Only the patient summary, the values in ``visual_results`` (read with
        ``.get`` and defaults), the first 800 characters of
        ``guideline_context`` and the current timestamp are interpolated;
        the rest is fixed template text.
        """
        report = f"""# π©Ί SmartHeal AI - Comprehensive Wound Analysis Report
## π Patient Information
{patient_info}
## π Visual Analysis Results
- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
- **Dimensions**: {visual_results.get('length_cm', 0)} cm Γ {visual_results.get('breadth_cm', 0)} cm
- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cmΒ²
- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
## π Analysis Images Available
- **Original Image**: {visual_results.get('original_image_path', 'Available')}
- **Detection Visualization**: {visual_results.get('detection_image_path', 'Available')}
- **Segmentation Overlay**: {visual_results.get('segmentation_image_path', 'Available')}
## π― Clinical Assessment Summary
### Wound Classification
Based on automated analysis, this wound has been classified as **{visual_results.get('wound_type', 'Unspecified')}** with the following characteristics:
- Size: {visual_results.get('length_cm', 0)} Γ {visual_results.get('breadth_cm', 0)} cm
- Total area: {visual_results.get('surface_area_cm2', 0)} cmΒ²
- Detection confidence: {visual_results.get('detection_confidence', 0):.1%}
### Clinical Observations
The automated visual analysis provides quantitative measurements that should be verified through clinical examination. The wound type classification helps guide initial treatment considerations.
## π Treatment Recommendations
### Wound Care Protocol
1. **Assessment**: Comprehensive clinical evaluation by qualified healthcare professional
2. **Cleaning**: Gentle wound cleansing with appropriate solution
3. **Debridement**: Remove necrotic tissue if present (professional assessment required)
4. **Dressing Selection**: Choose appropriate dressing based on wound characteristics:
- Moisture level assessment
- Infection risk evaluation
- Patient comfort and mobility
### Monitoring Plan
- **Initial Phase**: Daily assessment for first week
- **Ongoing Care**: Reassessment every 2-3 days or as clinically indicated
- **Documentation**: Regular photo documentation and measurement tracking
- **Progress Evaluation**: Weekly review of healing progression
## β οΈ Risk Factors & Considerations
### Patient-Specific Factors
Review patient history for factors that may impact healing:
- Age and general health status
- Diabetes or metabolic conditions
- Circulation and vascular health
- Nutritional status
- Mobility and pressure relief
### Warning Signs
Monitor for signs requiring immediate attention:
- Increased pain, redness, or swelling
- Purulent drainage or odor
- Fever or systemic signs of infection
- Wound expansion or deterioration
- Delayed healing beyond expected timeframe
## π Clinical Guidelines Context
{guideline_context[:800]}{'...' if len(guideline_context) > 800 else ''}
## π₯ Next Steps
### Immediate Actions
1. **Professional Consultation**: Schedule appointment with wound care specialist
2. **Baseline Documentation**: Establish comprehensive baseline assessment
3. **Treatment Plan**: Develop individualized care protocol
4. **Patient Education**: Provide wound care instructions and warning signs
### Follow-up Schedule
- **Week 1**: Daily monitoring and assessment
- **Week 2-4**: Every 2-3 days or as indicated
- **Monthly**: Comprehensive reassessment and plan review
- **As Needed**: Immediate evaluation for any concerning changes
## βοΈ Important Medical Disclaimer
**This automated analysis is provided for informational and educational purposes only.**
- This report does not constitute medical diagnosis or treatment advice
- All measurements are computer-generated estimates requiring clinical verification
- Professional medical evaluation is essential for proper diagnosis and treatment
- This AI tool should supplement, not replace, clinical judgment
- Always consult qualified healthcare professionals for medical decisions
### Clinical Correlation Required
- Verify all measurements with standard clinical tools
- Correlate findings with patient symptoms and history
- Consider factors not captured in automated analysis
- Follow institutional protocols and guidelines
---
*Generated by SmartHeal AI - Advanced Wound Care Analysis System*
*Report Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
*Version: AI-Processor v1.2 with Enhanced Fallback Reporting*
"""
        return report

    def save_and_commit_image(self, image_pil: "Image.Image") -> str:
        """Save the image under ``uploads_dir`` and try to push it to the HF dataset.

        The upload is best-effort: a failed upload is logged and the local
        path is still returned.

        Returns:
            Path of the saved PNG, or ``""`` if the local save failed.
        """
        try:
            os.makedirs(self.uploads_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{timestamp}.png"
            path = os.path.join(self.uploads_dir, filename)

            # Save image
            image_pil.convert("RGB").save(path)
            logging.info(f"β Image saved locally: {path}")

            # Upload to HuggingFace dataset if configured
            if self.hf_token and self.dataset_id:
                try:
                    api = HfApi()
                    api.upload_file(
                        path_or_fileobj=path,
                        path_in_repo=f"images/{filename}",
                        repo_id=self.dataset_id,
                        repo_type="dataset",
                        token=self.hf_token,
                        commit_message=f"Upload wound image: {filename}",
                    )
                    logging.info("β Image committed to HF dataset")
                except Exception as e:
                    logging.warning(f"HF upload failed: {e}")
            return path
        except Exception as e:
            logging.error(f"Failed to save image: {e}")
            return ""

    def full_analysis_pipeline(self, image_pil: "Image.Image", questionnaire_data: dict) -> dict:
        """Run save -> visual analysis -> guideline lookup -> report.

        Args:
            image_pil: The wound photograph.
            questionnaire_data: Patient answers; missing keys (or ``None``
                for the whole mapping) are rendered as ``N/A`` / ``unknown``.

        Returns:
            A dict with keys ``success``, ``visual_analysis``, ``report``,
            ``saved_image_path`` and ``guideline_context`` (truncated to 500
            characters). On failure ``success`` is ``False`` and an
            ``error`` key carries the message; no exception escapes.
        """
        try:
            # Treat a missing questionnaire as an empty one so every field
            # falls back to its placeholder instead of aborting the analysis.
            questionnaire_data = questionnaire_data or {}

            # Save image first
            saved_path = self.save_and_commit_image(image_pil)
            logging.info(f"Image saved: {saved_path}")

            # Perform visual analysis
            visual_results = self.perform_visual_analysis(image_pil)
            logging.info(f"Visual analysis completed: {visual_results}")

            # Process questionnaire data
            patient_info = f"Age: {questionnaire_data.get('age', 'N/A')}, Diabetic: {questionnaire_data.get('diabetic', 'N/A')}, Allergies: {questionnaire_data.get('allergies', 'N/A')}, Date of Wound Sustained: {questionnaire_data.get('date_of_injury', 'N/A')}, Professional Care: {questionnaire_data.get('professional_care', 'N/A')}, Oozing/Bleeding: {questionnaire_data.get('oozing_bleeding', 'N/A')}, Infection: {questionnaire_data.get('infection', 'N/A')}, Moisture: {questionnaire_data.get('moisture', 'N/A')}"

            # Query guidelines
            query = f"best practices for managing a {visual_results['wound_type']} with moisture level '{questionnaire_data.get('moisture', 'unknown')}' and signs of infection '{questionnaire_data.get('infection', 'unknown')}' in a patient who is diabetic '{questionnaire_data.get('diabetic', 'unknown')}'"
            guideline_context = self.query_guidelines(query)
            logging.info("Guidelines queried successfully")

            # Generate final report
            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            logging.info("Report generated successfully")

            return {
                'success': True,
                'visual_analysis': visual_results,
                'report': report,
                'saved_image_path': saved_path,
                'guideline_context': guideline_context[:500] + "..." if len(guideline_context) > 500 else guideline_context,
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {
                'success': False,
                'error': str(e),
                'visual_analysis': {},
                'report': f"Analysis failed: {str(e)}",
                'saved_image_path': None,
                'guideline_context': "",
            }

    def analyze_wound(self, image, questionnaire_data: dict) -> dict:
        """Main analysis entry point - maintains original function name.

        Accepts a file path, a PIL image or a NumPy array, normalises it to
        a PIL image and delegates to ``full_analysis_pipeline``. Any error
        while preparing the input is reported in the returned dict
        (``success`` is ``False``) rather than raised.
        """
        try:
            # Handle different image input formats
            if isinstance(image, str):
                if not os.path.exists(image):
                    raise ValueError(f"Image file not found: {image}")
                image_pil = Image.open(image)
            elif isinstance(image, Image.Image):
                image_pil = image
            elif isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            return self.full_analysis_pipeline(image_pil, questionnaire_data)
        except Exception as e:
            logging.error(f"Wound analysis error: {e}")
            return {
                'success': False,
                'error': str(e),
                'visual_analysis': {},
                'report': f"Analysis initialization failed: {str(e)}",
                'saved_image_path': None,
                'guideline_context': "",
            }