Update src/ai_processor.py
src/ai_processor.py  CHANGED  (+610 -610)
import os
import logging
import cv2
import numpy as np
from PIL import Image
import torch
import json
from datetime import datetime
import tensorflow as tf
from transformers import pipeline
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from huggingface_hub import HfApi, HfFolder
import spaces

from .config import Config


class AIProcessor:
    def __init__(self):
        self.models_cache = {}
        self.knowledge_base_cache = {}
        self.config = Config()
        self.px_per_cm = self.config.PIXELS_PER_CM
        self._initialize_models()
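    # Configuration contract: this class reads PIXELS_PER_CM, HF_TOKEN,
    # YOLO_MODEL_PATH, SEG_MODEL_PATH, GUIDELINE_PDFS, UPLOADS_DIR,
    # DATASET_ID and MAX_NEW_TOKENS from .config.Config.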
    def _initialize_models(self):
        """Initialize all AI models, including real-time models."""
        try:
            # Set HuggingFace token
            if self.config.HF_TOKEN:
                HfFolder.save_token(self.config.HF_TOKEN)
                logging.info("HuggingFace token set successfully")

            # Initialize MedGemma pipeline for medical text generation
            try:
                self.models_cache["medgemma_pipe"] = pipeline(
                    "image-text-to-text",
                    model="google/medgemma-4b-it",
                    torch_dtype=torch.bfloat16,
                    offload_folder="offload",
                    device_map="auto",
                    token=self.config.HF_TOKEN,
                )
                logging.info("✅ MedGemma pipeline loaded successfully")
            except Exception as e:
                logging.warning(f"MedGemma pipeline not available: {e}")

            # Initialize YOLO model for wound detection
            try:
                self.models_cache["det"] = YOLO(self.config.YOLO_MODEL_PATH)
                logging.info("✅ YOLO detection model loaded successfully")
            except Exception as e:
                logging.warning(f"YOLO model not available: {e}")

            # Initialize segmentation model
            try:
                self.models_cache["seg"] = load_model(self.config.SEG_MODEL_PATH, compile=False)
                logging.info("✅ Segmentation model loaded successfully")
            except Exception as e:
                logging.warning(f"Segmentation model not available: {e}")

            # Initialize wound classification model
            try:
                self.models_cache["cls"] = pipeline(
                    "image-classification",
                    model="Hemg/Wound-classification",
                    token=self.config.HF_TOKEN,
                    device="cpu",
                )
                logging.info("✅ Wound classification model loaded successfully")
            except Exception as e:
                logging.warning(f"Wound classification model not available: {e}")

            # Initialize embedding model for the knowledge base
            try:
                self.models_cache["embedding_model"] = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/all-MiniLM-L6-v2",
                    model_kwargs={'device': 'cpu'},
                )
                logging.info("✅ Embedding model loaded successfully")
            except Exception as e:
                logging.warning(f"Embedding model not available: {e}")

            logging.info("✅ Model initialization complete.")
            self._load_knowledge_base()

        except Exception as e:
            logging.error(f"Error initializing AI models: {e}")
    def _load_knowledge_base(self):
        """Load the knowledge base from PDF guidelines."""
        try:
            documents = []
            for pdf_path in self.config.GUIDELINE_PDFS:
                if os.path.exists(pdf_path):
                    loader = PyPDFLoader(pdf_path)
                    docs = loader.load()
                    documents.extend(docs)
                    logging.info(f"Loaded PDF: {pdf_path}")

            if documents and 'embedding_model' in self.models_cache:
                # Split documents into chunks
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=1000,
                    chunk_overlap=100,
                )
                chunks = text_splitter.split_documents(documents)

                # Create vector store
                vectorstore = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
                self.knowledge_base_cache['vectorstore'] = vectorstore
                logging.info(f"✅ Knowledge base loaded with {len(chunks)} chunks")
            else:
                self.knowledge_base_cache['vectorstore'] = None
                logging.warning("Knowledge base not available - no PDFs found or embedding model unavailable")

        except Exception as e:
            logging.warning(f"Knowledge base loading error: {e}")
            self.knowledge_base_cache['vectorstore'] = None
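    # Flow note: perform_visual_analysis runs YOLO detection first, crops the
    # top detection, then optionally segments and classifies the crop. All
    # pixel-to-centimetre conversions assume the fixed PIXELS_PER_CM
    # calibration from Config, so the measurements are only as accurate as
    # that calibration is for the capture distance used.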
    def perform_visual_analysis(self, image_pil):
        """Perform comprehensive visual analysis of a wound image."""
        try:
            # Convert PIL to OpenCV format
            image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

            # YOLO detection
            if 'det' not in self.models_cache:
                raise ValueError("YOLO detection model not available.")

            results = self.models_cache['det'].predict(image_cv, verbose=False, device="cpu")

            if not results or not results[0].boxes:
                raise ValueError("No wound detected in the image.")

            # Extract bounding box
            box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
            x1, y1, x2, y2 = box
            region_cv = image_cv[y1:y2, x1:x2]

            # Save detection image
            detection_image_cv = image_cv.copy()
            cv2.rectangle(detection_image_cv, (x1, y1), (x2, y2), (0, 255, 0), 2)
            os.makedirs(os.path.join(self.config.UPLOADS_DIR, "analysis"), exist_ok=True)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            detection_image_path = os.path.join(self.config.UPLOADS_DIR, "analysis", f"detection_{timestamp}.png")
            cv2.imwrite(detection_image_path, detection_image_cv)
            detection_image_pil = Image.fromarray(cv2.cvtColor(detection_image_cv, cv2.COLOR_BGR2RGB))

            # Initialize outputs
            length = breadth = area = 0
            segmentation_image_pil = None
            segmentation_image_path = None

            # Segmentation (optional)
            if 'seg' in self.models_cache:
                input_size = self.models_cache['seg'].input_shape[1:3]  # (height, width)
                resized_region = cv2.resize(region_cv, (input_size[1], input_size[0]))

                seg_input = np.expand_dims(resized_region / 255.0, 0)
                mask_pred = self.models_cache['seg'].predict(seg_input, verbose=0)[0]
                mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)

                # Resize mask back to the original region size
                mask_resized = cv2.resize(mask_np, (region_cv.shape[1], region_cv.shape[0]), interpolation=cv2.INTER_NEAREST)

                # Overlay mask on the region for visualization
                overlay = region_cv.copy()
                overlay[mask_resized == 1] = [0, 0, 255]  # Red overlay (BGR order)

                # Blend overlay for the final output
                segmented_visual = cv2.addWeighted(region_cv, 0.7, overlay, 0.3, 0)

                # Save segmentation image
                segmentation_image_path = os.path.join(self.config.UPLOADS_DIR, "analysis", f"segmentation_{timestamp}.png")
                cv2.imwrite(segmentation_image_path, segmented_visual)
                segmentation_image_pil = Image.fromarray(cv2.cvtColor(segmented_visual, cv2.COLOR_BGR2RGB))

                # Wound measurements from the resized mask
                contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if contours:
                    cnt = max(contours, key=cv2.contourArea)
                    x, y, w, h = cv2.boundingRect(cnt)
                    length = round(h / self.px_per_cm, 2)
                    breadth = round(w / self.px_per_cm, 2)
                    area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)

            # Classification (optional)
            wound_type = "Unknown"
            if 'cls' in self.models_cache:
                try:
                    region_pil = Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB))
                    cls_result = self.models_cache['cls'](region_pil)
                    wound_type = max(cls_result, key=lambda x: x['score'])['label']
                except Exception as e:
                    logging.warning(f"Wound classification error: {e}")

            return {
                'wound_type': wound_type,
                'length_cm': length,
                'breadth_cm': breadth,
                'surface_area_cm2': area,
                'detection_confidence': float(results[0].boxes[0].conf.cpu().item()),
                'bounding_box': box.tolist(),
                'detection_image_path': detection_image_path,
                'detection_image_pil': detection_image_pil,
                'segmentation_image_path': segmentation_image_path,
                'segmentation_image_pil': segmentation_image_pil,
            }

        except Exception as e:
            logging.error(f"Visual analysis error: {e}")
            raise ValueError(f"Visual analysis failed: {str(e)}")
    def query_guidelines(self, query: str):
        """Query the knowledge base for relevant guidelines."""
        try:
            vector_store = self.knowledge_base_cache.get("vectorstore")
            if not vector_store:
                return "Knowledge base unavailable - clinical guidelines not loaded"

            # Retrieve relevant documents
            retriever = vector_store.as_retriever(search_kwargs={"k": 10})
            docs = retriever.invoke(query)

            if not docs:
                return "No relevant guidelines found for the query"

            # Format the results
            formatted_results = []
            for doc in docs:
                source = doc.metadata.get('source', 'Unknown')
                page = doc.metadata.get('page', 'N/A')
                content = doc.page_content.strip()
                formatted_results.append(f"Source: {source}, Page: {page}\nContent: {content}")

            return "\n\n".join(formatted_results)

        except Exception as e:
            logging.error(f"Guidelines query error: {e}")
            return f"Error querying guidelines: {str(e)}"
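    # Message-format note: the "image-text-to-text" pipeline used below
    # accepts chat-style messages whose user content is a list of
    # {"type": "image"} and {"type": "text"} parts; the images are inserted
    # ahead of the prompt text so the model sees the original, detection and
    # segmentation views before the instructions.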
    def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
        """Generate a comprehensive medical report using MedGemma."""
        try:
            if 'medgemma_pipe' not in self.models_cache:
                return self._generate_fallback_report(patient_info, visual_results, guideline_context)

            max_tokens = max_new_tokens or self.config.MAX_NEW_TOKENS

            # Get detection and segmentation images if available
            detection_image = visual_results.get('detection_image_pil', None)
            segmentation_image = visual_results.get('segmentation_image_pil', None)

            # Image paths for the report
            detection_path = visual_results.get('detection_image_path', '')
            segmentation_path = visual_results.get('segmentation_image_path', '')

            # Detailed prompt for medical analysis
            prompt = f"""
# Wound Care Report

## Patient Information
{patient_info}

## Visual Analysis Summary
- Wound Type: {visual_results.get('wound_type', 'Unknown')}
- Length: {visual_results.get('length_cm', 0)} cm
- Breadth: {visual_results.get('breadth_cm', 0)} cm
- Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
- Detection Confidence: {visual_results.get('detection_confidence', 0):.2f}

## Clinical Reference
{guideline_context}

You are SmartHeal-AI Agent, a world-class wound care AI specialist trained in clinical wound assessment and guideline-based treatment planning.
Your task is to process the following structured inputs (patient data, wound measurements, clinical guidelines, and image) and perform **clinical reasoning and decision-making** to generate a complete wound care report.
---
🔍 **YOUR PROCESS — FOLLOW STRICTLY:**
### Step 1: Clinical Reasoning (Chain-of-Thought)
Use the provided information to think step by step about:
- The patient's risk factors (e.g., diabetes, age, healing limitations)
- Wound characteristics (size, tissue appearance, moisture, infection signs)
- Visual clues from the image (location, granulation, maceration, inflammation, surrounding skin)

---
### Step 2: Structured Clinical Report
Generate the following report sections using markdown and medical terminology:
**1. Clinical Summary**
- Describe wound appearance and tissue types (e.g., slough, necrotic, granulating, epithelializing)
- Include size, wound bed condition, peri-wound skin, and signs of infection or biofilm
- Mention the inferred location (e.g., heel, forefoot) if the image allows
- Summarize the patient's systemic risk profile
**2. Medicinal & Dressing Recommendations**
Based on your analysis:
- Recommend specific **wound care dressings** (e.g., hydrocolloid, alginate, foam, antimicrobial silver) suited to the wound's moisture level and infection risk
- Propose **topical or systemic agents** ONLY if relevant — include named classes (e.g., antiseptic: povidone-iodine, antibiotic ointments, enzymatic debriders)
- Mention **techniques** (e.g., sharp debridement, NPWT, moisture balance, pressure offloading, dressing frequency)
- Do not repeat the guidelines — **apply them**
**3. Key Risk Factors**
Explain how the patient's condition (e.g., diabetes, poor circulation, advanced age, poor hygiene) may affect wound healing
**4. Prognosis & Monitoring Advice**
- State how often the wound should be reassessed
- Indicate signs to monitor for deterioration or improvement
- Include when escalation to a specialist is necessary

**Note:** Every dressing change is a chance for wound reassessment. Always perform a thorough wound evaluation at each dressing change.
"""

            # Prepare messages for MedGemma with all available images
            content_list = [{"type": "text", "text": prompt}]

            # Add the original image
            if image_pil:
                content_list.insert(0, {"type": "image", "image": image_pil})

            # Add the detection image if available
            if detection_image:
                content_list.insert(1, {"type": "image", "image": detection_image})

            # Add the segmentation image if available
            if segmentation_image:
                content_list.insert(2, {"type": "image", "image": segmentation_image})

            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are a world-class medical AI assistant specializing in wound care with expertise in wound assessment and treatment. Provide concise, evidence-based medical assessments focusing on: (1) Precise wound classification based on tissue type and appearance, (2) Specific treatment recommendations with exact product names or interventions when appropriate, (3) Objective evaluation of healing progression or deterioration indicators, and (4) Clear follow-up timelines. Avoid general statements and prioritize actionable insights based on the visual analysis measurements and patient context."}],
                },
                {
                    "role": "user",
                    "content": content_list,
                },
            ]

            # Generate the report using MedGemma
            output = self.models_cache['medgemma_pipe'](
                text=messages,
                max_new_tokens=max_tokens,
                do_sample=False,
            )

            generated_content = output[0]['generated_text'][-1].get('content', '').strip()

            # Include image paths in the final report for display in the UI
            if generated_content:
                # NOTE: image_pil is a PIL object, so its repr (not a file
                # path) is embedded for the "Original Image" entry.
                image_paths_section = f"""
## Analysis Images
- Original Image: {image_pil}
- Detection Image: {detection_path}
- Segmentation Image: {segmentation_path}
"""
                generated_content = image_paths_section + generated_content

            return generated_content if generated_content else self._generate_fallback_report(patient_info, visual_results, guideline_context)

        except Exception as e:
            logging.error(f"MedGemma report generation error: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
    def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
        """Generate a fallback report when MedGemma is not available."""
        # Get image paths for the report
        detection_path = visual_results.get('detection_image_path', 'Not available')
        segmentation_path = visual_results.get('segmentation_image_path', 'Not available')

        report = f"""
# Wound Analysis Report
## Patient Information
{patient_info}

## Visual Analysis Results
- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
- **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}

## Analysis Images
- **Detection Image**: {detection_path}
- **Segmentation Image**: {segmentation_path}

## Assessment
Based on the visual analysis, this appears to be a {visual_results.get('wound_type', 'wound')} with measurable dimensions.

## Recommendations
- Continue monitoring wound healing progress
- Maintain proper wound hygiene
- Follow appropriate dressing protocols
- Seek medical attention if signs of infection develop

## Clinical Guidelines
{guideline_context[:500]}...

*Note: This is an automated analysis. Please consult with a healthcare professional for definitive diagnosis and treatment.*
"""
        return report
    def save_and_commit_image(self, image_pil):
        """Save the image locally and optionally upload it to a HuggingFace dataset."""
        try:
            # Ensure the uploads directory exists
            os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)

            # Generate a timestamped filename
            filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.png"
            local_path = os.path.join(self.config.UPLOADS_DIR, filename)

            # Save the image locally
            image_pil.convert("RGB").save(local_path)
            logging.info(f"Image saved locally: {local_path}")

            # Upload to the HuggingFace dataset if configured
            if self.config.HF_TOKEN and self.config.DATASET_ID:
                try:
                    api = HfApi()
                    api.upload_file(
                        path_or_fileobj=local_path,
                        path_in_repo=f"images/{filename}",
                        repo_id=self.config.DATASET_ID,
                        repo_type="dataset",
                        commit_message=f"Upload wound image: {filename}",
                    )
                    logging.info("✅ Image uploaded to HuggingFace dataset")
                except Exception as e:
                    logging.warning(f"HuggingFace upload failed: {e}")

            return local_path

        except Exception as e:
            logging.error(f"Image saving error: {e}")
            return None
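    # The decorator below requests ZeroGPU hardware on Hugging Face Spaces
    # for up to `duration` seconds per call; outside a ZeroGPU Space the
    # `spaces` package treats it as a no-op wrapper, so the pipeline still
    # runs (on CPU).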
    @spaces.GPU(enable_queue=True, duration=120)
    def full_analysis_pipeline(self, image, questionnaire_data):
        """Complete analysis pipeline using the real-time models."""
        try:
            # Save the image
            saved_path = self.save_and_commit_image(image)

            # Perform visual analysis
            visual_results = self.perform_visual_analysis(image)

            # Format patient information
            patient_info = ", ".join([f"{k}: {v}" for k, v in questionnaire_data.items() if v])

            # Build the guidelines query
            wound_type = visual_results.get('wound_type', 'wound')
            moisture = questionnaire_data.get('moisture', 'unknown')
            infection = questionnaire_data.get('infection', 'unknown')
            diabetic = questionnaire_data.get('diabetic', 'unknown')

            query = f"best practices for managing a {wound_type} with moisture level '{moisture}' and signs of infection '{infection}' in a patient who is diabetic '{diabetic}'"

            # Query guidelines
            guideline_context = self.query_guidelines(query)

            # Generate the final report
            final_report = self.generate_final_report(patient_info, visual_results, guideline_context, image)

            return {
                'success': True,
                'visual_analysis': visual_results,
                'report': final_report,
                'saved_image_path': saved_path,
                'timestamp': datetime.now().isoformat(),
            }

        except Exception as e:
            logging.error(f"Full analysis pipeline error: {e}")
            return {
                'success': False,
                'error': str(e),
                'timestamp': datetime.now().isoformat(),
            }
    # Legacy methods for backward compatibility
    def analyze_wound(self, image, questionnaire_data):
        """Legacy method for backward compatibility."""
        try:
            # Convert a string path to a PIL Image if needed. (This uses the
            # module-level PIL import; re-importing Image inside this function
            # would make it a local name and raise UnboundLocalError on the
            # isinstance check below whenever the input is already an Image.)
            if isinstance(image, str):
                try:
                    image = Image.open(image)
                    logging.info(f"Converted string path to PIL Image: {image}")
                except Exception as e:
                    logging.error(f"Error converting string path to image: {e}")

            # Ensure we have a PIL Image object
            if not isinstance(image, Image.Image):
                try:
                    # If it's a file-like object
                    if hasattr(image, 'read'):
                        # Reset the file pointer if possible
                        if hasattr(image, 'seek'):
                            image.seek(0)
                        image = Image.open(image)
                        logging.info("Converted file-like object to PIL Image")
                except Exception as e:
                    logging.error(f"Error ensuring image is a PIL Image: {e}")
                    raise ValueError(f"Invalid image format: {type(image)}")

            result = self.full_analysis_pipeline(image, questionnaire_data)

            if result['success']:
                return {
                    'timestamp': result['timestamp'],
                    'summary': f"Analysis completed for {questionnaire_data.get('patient_name', 'patient')}",
                    'recommendations': result['report'],
                    'wound_detection': {
                        'status': 'success',
                        'detections': [result['visual_analysis']],
                        'total_wounds': 1,
                    },
                    'segmentation_result': {
                        'status': 'success',
                        # Legacy field name: the value reported here is cm², not a percentage
                        'wound_area_percentage': result['visual_analysis'].get('surface_area_cm2', 0),
                    },
                    'risk_assessment': self._assess_risk_legacy(questionnaire_data),
                    'guideline_recommendations': [result['report'][:200] + "..."],
                }
            else:
                return {
                    'timestamp': result['timestamp'],
                    'summary': f"Analysis failed: {result['error']}",
                    'recommendations': "Please consult with a healthcare professional.",
                    'wound_detection': {'status': 'error', 'message': result['error']},
                    'segmentation_result': {'status': 'error', 'message': result['error']},
                    'risk_assessment': {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []},
                    'guideline_recommendations': ["Analysis unavailable due to error"],
                }

        except Exception as e:
            logging.error(f"Legacy analyze_wound error: {e}")
            return {
                'timestamp': datetime.now().isoformat(),
                'summary': f"Analysis error: {str(e)}",
                'recommendations': "Please consult with a healthcare professional.",
                'wound_detection': {'status': 'error', 'message': str(e)},
                'segmentation_result': {'status': 'error', 'message': str(e)},
                'risk_assessment': {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []},
                'guideline_recommendations': ["Analysis unavailable due to error"],
            }
    def _assess_risk_legacy(self, questionnaire_data):
        """Legacy risk assessment for backward compatibility."""
        risk_factors = []
        risk_score = 0

        try:
            # Age factor
            age = questionnaire_data.get('patient_age', 0)
            if age > 65:
                risk_factors.append("Advanced age (>65)")
                risk_score += 2
            elif age > 50:
                risk_factors.append("Older adult (50-65)")
                risk_score += 1

            # Duration factor
            duration = questionnaire_data.get('wound_duration', '').lower()
            if any(term in duration for term in ['month', 'months', 'year']):
                risk_factors.append("Chronic wound (>4 weeks)")
                risk_score += 3

            # Pain level
            pain_level = questionnaire_data.get('pain_level', 0)
            if pain_level >= 7:
                risk_factors.append("High pain level")
                risk_score += 2

            # Medical history risk factors
            medical_history = questionnaire_data.get('medical_history', '').lower()
            if 'diabetes' in medical_history:
                risk_factors.append("Diabetes mellitus")
                risk_score += 3
            if 'circulation' in medical_history or 'vascular' in medical_history:
                risk_factors.append("Vascular/circulation issues")
                risk_score += 2
            if 'immune' in medical_history:
                risk_factors.append("Immune system compromise")
                risk_score += 2

            # Determine the risk level
            if risk_score >= 7:
                risk_level = "High"
            elif risk_score >= 4:
                risk_level = "Moderate"
            else:
                risk_level = "Low"

            return {
                'risk_score': risk_score,
                'risk_level': risk_level,
                'risk_factors': risk_factors,
            }

        except Exception as e:
            logging.error(f"Risk assessment error: {e}")
            return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
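

# --- Usage sketch (illustrative; not part of the committed module) ---
# A minimal example of driving the pipeline, assuming the model files and
# Config values referenced above exist. The image path and questionnaire
# values are hypothetical, though the keys match those read by
# full_analysis_pipeline and _assess_risk_legacy. Because this module uses a
# relative import for Config, run it with `python -m <package>.ai_processor`
# from the project root rather than as a plain script.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    processor = AIProcessor()
    sample_image = Image.open("sample_wound.png")  # hypothetical test image
    questionnaire = {
        'patient_name': 'Test Patient',
        'patient_age': 67,
        'wound_duration': '2 months',
        'pain_level': 5,
        'moisture': 'moderate',
        'infection': 'none observed',
        'diabetic': 'yes',
        'medical_history': 'type 2 diabetes',
    }
    result = processor.full_analysis_pipeline(sample_image, questionnaire)
    print(result['report'] if result['success'] else result['error'])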