SmartHeal committed
Commit 862d7cb · verified · Parent: 3c0b441

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +519 -157
src/ai_processor.py CHANGED
@@ -36,14 +36,19 @@ knowledge_base_cache = {}
 
 # =============== LAZY LOADING FUNCTIONS (CPU-SAFE) ===============
 def load_yolo_model(yolo_model_path):
+    """Lazy import and load YOLO model to avoid CUDA initialization."""
     from ultralytics import YOLO
     return YOLO(yolo_model_path)
 
 def load_segmentation_model(seg_model_path):
+    """Lazy import and load segmentation model."""
+    import tensorflow as tf
+    tf.config.set_visible_devices([], 'GPU')  # Force CPU for TensorFlow
     from tensorflow.keras.models import load_model
     return load_model(seg_model_path, compile=False)
 
 def load_classification_pipeline(hf_token):
+    """Lazy import and load classification pipeline (CPU only)."""
     from transformers import pipeline
     return pipeline(
         "image-classification",
@@ -53,6 +58,7 @@ def load_classification_pipeline(hf_token):
     )
 
 def load_embedding_model():
+    """Load embedding model for knowledge base."""
     return HuggingFaceEmbeddings(
         model_name="sentence-transformers/all-MiniLM-L6-v2",
         model_kwargs={"device": "cpu"}
@@ -60,28 +66,34 @@ def load_embedding_model():
 
 # =============== MODEL INITIALIZATION ===============
 def initialize_cpu_models():
+    """Initialize all CPU-only models once."""
     global models_cache
+
     if HF_TOKEN:
         HfFolder.save_token(HF_TOKEN)
         logging.info("✅ HuggingFace token set")
+
     if "det" not in models_cache:
         try:
             models_cache["det"] = load_yolo_model(YOLO_MODEL_PATH)
             logging.info("✅ YOLO model loaded (CPU only)")
         except Exception as e:
             logging.error(f"YOLO load failed: {e}")
+
     if "seg" not in models_cache:
         try:
             models_cache["seg"] = load_segmentation_model(SEG_MODEL_PATH)
             logging.info("✅ Segmentation model loaded (CPU)")
         except Exception as e:
             logging.warning(f"Segmentation model not available: {e}")
+
     if "cls" not in models_cache:
         try:
             models_cache["cls"] = load_classification_pipeline(HF_TOKEN)
             logging.info("✅ Classification pipeline loaded (CPU)")
         except Exception as e:
             logging.warning(f"Classification pipeline not available: {e}")
+
     if "embedding_model" not in models_cache:
         try:
             models_cache["embedding_model"] = load_embedding_model()
@@ -90,9 +102,11 @@ def initialize_cpu_models():
             logging.warning(f"Embedding model not available: {e}")
 
 def setup_knowledge_base():
+    """Load PDF documents and create FAISS vector store."""
     global knowledge_base_cache
     if "vector_store" in knowledge_base_cache:
         return
+
     docs = []
     for pdf_path in GUIDELINE_PDFS:
         if os.path.exists(pdf_path):
@@ -102,6 +116,7 @@ def setup_knowledge_base():
                 logging.info(f"Loaded PDF: {pdf_path}")
             except Exception as e:
                 logging.warning(f"Failed to load PDF {pdf_path}: {e}")
+
     if docs and "embedding_model" in models_cache:
         splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
         chunks = splitter.split_documents(docs)
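The 1000-character chunks with 100-character overlap keep retrieval units small while preserving continuity across chunk boundaries. A quick way to see the overlap at work (the import path varies across LangChain versions; this is the classic one):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_text("Wound care guideline text. " * 200)
# Consecutive chunks share up to 100 characters of trailing/leading text.
print(len(chunks), chunks[0][-50:], chunks[1][:50], sep="\n")
```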
@@ -111,7 +126,7 @@
         knowledge_base_cache["vector_store"] = None
         logging.warning("Knowledge base unavailable")
 
-# Initialize models at startup
+# Initialize models on app startup
 initialize_cpu_models()
 setup_knowledge_base()
 
@@ -125,6 +140,8 @@ def generate_medgemma_report(
     segmentation_image_path,
     max_new_tokens=None,
 ):
+    """GPU-only function for MedGemma report generation."""
+    # Import GPU libraries ONLY here
     import torch
     from transformers import pipeline
     from PIL import Image
@@ -141,6 +158,7 @@ def generate_medgemma_report(
         "patient context."
     )
 
+    # Lazy-load MedGemma pipeline on GPU
     if not hasattr(generate_medgemma_report, "_pipe"):
        try:
            generate_medgemma_report._pipe = pipeline(
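Stashing the pipeline on a function attribute (`generate_medgemma_report._pipe`) is a lightweight once-per-process cache that avoids a module-level global. The pattern in isolation:

```python
def get_pipe():
    """Build the expensive object on first call; reuse it afterwards."""
    if not hasattr(get_pipe, "_pipe"):
        get_pipe._pipe = object()  # stand-in for the heavyweight pipeline load
    return get_pipe._pipe

assert get_pipe() is get_pipe()  # the second call hits the cache
```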
@@ -158,16 +176,44 @@ def generate_medgemma_report(
 
     pipe = generate_medgemma_report._pipe
 
+    # Load the original image that was analyzed
+    original_image = None
+    if detection_image_path and os.path.exists(detection_image_path.replace('detection_', 'original_')):
+        original_image = Image.open(detection_image_path.replace('detection_', 'original_'))
+    elif segmentation_image_path and os.path.exists(segmentation_image_path.replace('segmentation_', 'original_')):
+        original_image = Image.open(segmentation_image_path.replace('segmentation_', 'original_'))
+
+    # Compose messages
     msgs = [
         {"role": "system", "content": [{"type": "text", "text": default_system_prompt}]},
         {"role": "user", "content": []},
     ]
 
-    for path in (detection_image_path, segmentation_image_path):
-        if path and os.path.exists(path):
-            msgs[1]["content"].append({"type": "image", "image": Image.open(path)})
+    # Attach images if available
+    if original_image:
+        msgs[1]["content"].append({"type": "image", "image": original_image})
+    else:
+        # Fallback to detection or segmentation images
+        for path in (detection_image_path, segmentation_image_path):
+            if path and os.path.exists(path):
+                msgs[1]["content"].append({"type": "image", "image": Image.open(path)})
+                break
+
+    # Attach text prompt
+    prompt = f"""## Patient Information
+{patient_info}
+
+## Visual Analysis Results
+- Wound Type: {visual_results.get('wound_type','Unknown')}
+- Dimensions: {visual_results.get('length_cm', 0)} x {visual_results.get('breadth_cm', 0)} cm
+- Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
+- Detection Confidence: {visual_results.get('detection_confidence', 0):.2f}
+
+## Clinical Guidelines Context
+{guideline_context[:1500]}...
+
+Please provide a comprehensive wound care assessment and treatment recommendations based on the image and provided information."""
 
-    prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
     msgs[1]["content"].append({"type": "text", "text": prompt})
 
     try:
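The `msgs` structure above is the chat format that transformers multimodal pipelines generally accept: each turn carries a list of typed content parts, so a single user turn can mix images and text. A trimmed illustration (the pipeline call is left commented because it needs the loaded model):

```python
from PIL import Image

image = Image.new("RGB", (64, 64))  # stand-in for a saved analysis image
msgs = [
    {"role": "system", "content": [{"type": "text", "text": "You are a medical AI."}]},
    {"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": "Assess this wound."},
    ]},
]
# out = pipe(text=msgs, max_new_tokens=512)  # pipe as constructed above
```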
@@ -188,176 +234,492 @@ class AIProcessor:
         self.hf_token = HF_TOKEN
 
     def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
-        img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
-        yolo = self.models_cache.get("det")
-        if yolo is None:
-            raise RuntimeError("YOLO model ('det') not loaded")
-
-        res = yolo.predict(img_cv, verbose=False, device="cpu")[0]
-        if not res.boxes:
-            raise ValueError("No wound detected")
-
-        # Safely unpack detection boxes
-        try:
-            xyxy = res.boxes.xyxy.cpu().numpy()
-            if xyxy.shape[0] == 0:
-                raise ValueError("No detection boxes found")
-            x1, y1, x2, y2 = xyxy[0]
-        except Exception as e:
-            logging.warning(f"Error unpacking detection boxes: {e}")
-            raise
-
-        region = img_cv[int(y1):int(y2), int(x1):int(x2)]
-
-        # Save detection overlay
-        det_vis = img_cv.copy()
-        cv2.rectangle(det_vis, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
-        os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
-        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
-        det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
-        cv2.imwrite(det_path, det_vis)
-
-        # Segmentation
-        length, breadth, area = 0, 0, 0
-        seg_path = None
-        seg_model = self.models_cache.get("seg")
-        if seg_model:
-            try:
-                h, w = seg_model.input_shape[1:3]
-                inp = cv2.resize(region, (w, h)) / 255.0
-                mask_pred = seg_model.predict(inp[None])
-                if mask_pred.shape[1:3] != (h, w):
-                    # Resize if needed
-                    mask_pred = np.squeeze(mask_pred)
-                mask = (mask_pred[0, :, :, 0] > 0.5).astype(np.uint8)
-                mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
-                # Save segmentation visualization
-                ov = region.copy()
-                ov[mask_rs == 1] = [0, 0, 255]
-                seg_vis = cv2.addWeighted(region, 0.7, ov, 0.3, 0)
-                seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
-                cv2.imwrite(seg_path, seg_vis)
-
-                # Find contours
-                cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-                if cnts:
-                    cnt = max(cnts, key=cv2.contourArea)
-                    x, y, w_box, h_box = cv2.boundingRect(cnt)
-                    length = round(h_box / self.px_per_cm, 2)
-                    breadth = round(w_box / self.px_per_cm, 2)
-                    area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
-            except Exception as e:
-                logging.warning(f"Segmentation processing error: {e}")
-
-        # Classification
-        wound_type = "Unknown"
-        cls_pipe = self.models_cache.get("cls")
-        if cls_pipe:
-            try:
-                preds = cls_pipe(Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)))
-                if preds:
-                    wound_type = max(preds, key=lambda x: x["score"])["label"]
-            except Exception as e:
-                logging.warning(f"Classification error: {e}")
-
-        return {
-            "wound_type": wound_type,
-            "length_cm": length,
-            "breadth_cm": breadth,
-            "surface_area_cm2": area,
-            "detection_confidence": float(res.boxes.conf[0].cpu().item()),
-            "detection_image_path": det_path,
-            "segmentation_image_path": seg_path,
-        }
+        """Performs the full visual analysis pipeline."""
+        try:
+            # Convert PIL to OpenCV format
+            image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
+
+            # YOLO Detection
+            yolo_model = self.models_cache.get("det")
+            if yolo_model is None:
+                raise RuntimeError("YOLO model ('det') not loaded")
+
+            results = yolo_model.predict(image_cv, verbose=False, device="cpu")
+
+            if not results or not results[0].boxes or len(results[0].boxes) == 0:
+                raise ValueError("No wound detected in the image")
+
+            # Extract bounding box - handle different output formats
+            boxes_data = results[0].boxes.xyxy.cpu().numpy()
+
+            if len(boxes_data.shape) == 1:
+                # Single detection case
+                if len(boxes_data) != 4:
+                    raise ValueError(f"Expected 4 coordinates, got {len(boxes_data)}")
+                x1, y1, x2, y2 = boxes_data.astype(int)
+            else:
+                # Multiple detections - take the first one
+                if boxes_data.shape[1] != 4:
+                    raise ValueError(f"Expected 4 coordinates per box, got {boxes_data.shape[1]}")
+                x1, y1, x2, y2 = boxes_data[0].astype(int)
+
+            # Validate coordinates
+            if x1 >= x2 or y1 >= y2 or x1 < 0 or y1 < 0:
+                raise ValueError("Invalid bounding box coordinates")
+
+            # Extract wound region
+            detected_region_cv = image_cv[y1:y2, x1:x2]
+
+            if detected_region_cv.size == 0:
+                raise ValueError("Detected region is empty")
+
+            # Save detection visualization
+            det_vis = image_cv.copy()
+            cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
+            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+            det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
+            cv2.imwrite(det_path, det_vis)
+
+            # Save original image for reference
+            original_path = f"{self.uploads_dir}/analysis/original_{ts}.png"
+            cv2.imwrite(original_path, image_cv)
+
+            # Segmentation Analysis
+            length = breadth = area = 0
+            seg_path = None
+
+            seg_model = self.models_cache.get("seg")
+            if seg_model is not None:
+                try:
+                    # Get input shape from model
+                    input_shape = seg_model.input_shape
+                    if len(input_shape) >= 3:
+                        h, w = input_shape[1:3]
+                    else:
+                        h, w = 256, 256  # Default fallback
+
+                    # Prepare input for segmentation
+                    resized = cv2.resize(detected_region_cv, (w, h))
+                    normalized_input = np.expand_dims(resized / 255.0, 0)
+
+                    # Predict mask
+                    mask_pred = seg_model.predict(normalized_input, verbose=0)
+
+                    # Handle different output formats
+                    if len(mask_pred.shape) == 4:
+                        mask_np = (mask_pred[0, :, :, 0] > 0.5).astype(np.uint8)
+                    elif len(mask_pred.shape) == 3:
+                        mask_np = (mask_pred[0, :, :] > 0.5).astype(np.uint8)
+                    else:
+                        raise ValueError(f"Unexpected segmentation output shape: {mask_pred.shape}")
+
+                    # Resize mask back to detection region size
+                    mask_resized = cv2.resize(
+                        mask_np * 255,
+                        (detected_region_cv.shape[1], detected_region_cv.shape[0]),
+                        interpolation=cv2.INTER_NEAREST
+                    )
+                    mask_resized = (mask_resized > 127).astype(np.uint8)
+
+                    # Create segmentation visualization
+                    overlay = detected_region_cv.copy()
+                    overlay[mask_resized == 1] = [0, 0, 255]  # Red overlay for wound area
+                    seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
+
+                    seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
+                    cv2.imwrite(seg_path, seg_vis)
+
+                    # Calculate measurements
+                    contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+                    if contours:
+                        # Get the largest contour
+                        largest_contour = max(contours, key=cv2.contourArea)
+
+                        # Calculate bounding rectangle
+                        bbox = cv2.boundingRect(largest_contour)
+                        if len(bbox) == 4:
+                            x, y, w_box, h_box = bbox
+                            length = round(h_box / self.px_per_cm, 2)
+                            breadth = round(w_box / self.px_per_cm, 2)
+                            area = round(cv2.contourArea(largest_contour) / (self.px_per_cm ** 2), 2)
+                        else:
+                            logging.warning(f"Unexpected bounding rect format: {bbox}")
+                    else:
+                        logging.info("No contours found in segmentation mask")
+
+                except Exception as seg_error:
+                    logging.error(f"Segmentation processing error: {seg_error}")
+                    seg_path = None
+
+            # Wound Classification
+            wound_type = "Unknown"
+            cls_pipeline = self.models_cache.get("cls")
+            if cls_pipeline is not None:
+                try:
+                    detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
+                    predictions = cls_pipeline(detected_image_pil)
+                    if predictions and len(predictions) > 0:
+                        best_pred = max(predictions, key=lambda x: x.get("score", 0))
+                        wound_type = best_pred.get("label", "Unknown")
+                except Exception as cls_error:
+                    logging.warning(f"Classification failed: {cls_error}")
+
+            # Extract confidence score
+            confidence = 0.0
+            if results[0].boxes.conf is not None and len(results[0].boxes.conf) > 0:
+                confidence = float(results[0].boxes.conf[0].cpu().item())
+
+            return {
+                "wound_type": wound_type,
+                "length_cm": length,
+                "breadth_cm": breadth,
+                "surface_area_cm2": area,
+                "detection_confidence": confidence,
+                "detection_image_path": det_path,
+                "segmentation_image_path": seg_path,
+                "original_image_path": original_path
+            }
+
+        except Exception as e:
+            logging.error(f"Visual analysis failed: {e}")
+            raise e
 
     def query_guidelines(self, query: str) -> str:
-        vs = self.knowledge_base_cache.get("vector_store")
-        if not vs:
-            return "Clinical guidelines unavailable"
-        docs = vs.as_retriever(search_kwargs={"k": 10}).invoke(query)
-        return "\n\n".join(
-            f"Source: {d.metadata.get('source','?')}, Page: {d.metadata.get('page','?')}\n{d.page_content}" for d in docs
-        )
-
-    def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
-        det_path = visual_results.get("detection_image_path", "")
-        seg_path = visual_results.get("segmentation_image_path", "")
-        report = generate_medgemma_report(patient_info, visual_results, guideline_context, det_path, seg_path, max_new_tokens)
-        if report:
-            return report
-        return self._generate_fallback_report(patient_info, visual_results, guideline_context)
-
-    def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
-        return (
-            f"# Fallback Report\n{patient_info}\n"
-            f"Type: {visual_results.get('wound_type','Unknown')}\n"
-            f"Detection Image: {visual_results.get('detection_image_path','N/A')}\n"
-            f"Segmentation Image: {visual_results.get('segmentation_image_path','N/A')}\n"
-            f"Guidelines: {guideline_context[:200]}..."
-        )
+        """Query the knowledge base for relevant information."""
+        try:
+            vector_store = self.knowledge_base_cache.get("vector_store")
+            if not vector_store:
+                return "Clinical guidelines unavailable - knowledge base not loaded"
+
+            retriever = vector_store.as_retriever(search_kwargs={"k": 10})
+            docs = retriever.invoke(query)
+
+            if not docs:
+                return "No relevant guidelines found for the query"
+
+            context = "\n\n".join([
+                f"Source: {doc.metadata.get('source', 'Unknown')}, Page: {doc.metadata.get('page', 'N/A')}\n{doc.page_content}"
+                for doc in docs
+            ])
+
+            return context
+
+        except Exception as e:
+            logging.error(f"Guidelines query failed: {e}")
+            return f"Guidelines query failed: {str(e)}"
+
+    def generate_final_report(
+        self, patient_info: str, visual_results: dict, guideline_context: str,
+        image_pil: Image.Image, max_new_tokens: int = None
+    ) -> str:
+        """Generate final report using MedGemma GPU pipeline."""
+        try:
+            det_path = visual_results.get("detection_image_path", "")
+            seg_path = visual_results.get("segmentation_image_path", "")
+
+            report = generate_medgemma_report(
+                patient_info, visual_results, guideline_context,
+                det_path, seg_path, max_new_tokens
+            )
+
+            if report and report.strip():
+                return report
+            else:
+                return self._generate_fallback_report(patient_info, visual_results, guideline_context)
+
+        except Exception as e:
+            logging.error(f"MedGemma report generation failed: {e}")
+            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
+
+    def _generate_fallback_report(
+        self, patient_info: str, visual_results: dict, guideline_context: str
+    ) -> str:
+        """Generate fallback report if MedGemma fails."""
+
+        report = f"""# Wound Analysis Report
+
+## Patient Information
+{patient_info}
+
+## Visual Analysis Results
+- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
+- **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
+- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
+- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}
+
+## Analysis Images
+- **Detection Image**: {visual_results.get('detection_image_path', 'N/A')}
+- **Segmentation Image**: {visual_results.get('segmentation_image_path', 'N/A')}
+
+## Clinical Guidelines Context
+{guideline_context[:1000]}{'...' if len(guideline_context) > 1000 else ''}
+
+## Assessment Summary
+Based on the automated visual analysis, the wound has been classified as **{visual_results.get('wound_type', 'Unknown')}** with measurable dimensions. The detection confidence indicates the reliability of the automated assessment.
+
+## Recommendations
+1. **Clinical Evaluation**: This automated analysis should be supplemented with professional clinical assessment
+2. **Documentation**: Regular monitoring and documentation of wound progression is recommended
+3. **Treatment Planning**: Develop appropriate treatment protocol based on wound characteristics and patient factors
+4. **Follow-up**: Schedule appropriate follow-up intervals based on wound severity and healing progress
+
+## Important Notes
+- This is an automated analysis and should not replace professional medical judgment
+- All measurements are estimates based on computer vision algorithms
+- Clinical correlation is essential for proper diagnosis and treatment planning
+- Consider patient-specific factors not captured in this automated assessment
+
+## Disclaimer
+This automated analysis is provided for informational purposes only and does not constitute medical advice. Always consult with qualified healthcare professionals for proper diagnosis and treatment.
+"""
+        return report
 
     def save_and_commit_image(self, image_pil: Image.Image) -> str:
-        os.makedirs(self.uploads_dir, exist_ok=True)
-        fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
-        path = os.path.join(self.uploads_dir, fn)
-        image_pil.convert("RGB").save(path)
-        if self.hf_token and self.dataset_id:
-            try:
-                HfApi().upload_file(
-                    path_or_fileobj=path,
-                    path_in_repo=f"images/{fn}",
-                    repo_id=self.dataset_id,
-                    repo_type="dataset",
-                )
-                logging.info("✅ Image committed to HF dataset")
-            except Exception as e:
-                logging.warning(f"HF upload failed: {e}")
-        return path
+        """Save image locally and optionally commit to HF dataset."""
+        try:
+            os.makedirs(self.uploads_dir, exist_ok=True)
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            filename = f"{timestamp}.png"
+            path = os.path.join(self.uploads_dir, filename)
+
+            # Save image
+            image_pil.convert("RGB").save(path)
+            logging.info(f"✅ Image saved locally: {path}")
+
+            # Upload to HuggingFace dataset if configured
+            if self.hf_token and self.dataset_id:
+                try:
+                    api = HfApi()
+                    api.upload_file(
+                        path_or_fileobj=path,
+                        path_in_repo=f"images/{filename}",
+                        repo_id=self.dataset_id,
+                        repo_type="dataset",
+                        token=self.hf_token,
+                        commit_message=f"Upload wound image: {filename}"
+                    )
+                    logging.info("✅ Image committed to HF dataset")
+                except Exception as e:
+                    logging.warning(f"HF upload failed: {e}")
+
+            return path
+
+        except Exception as e:
+            logging.error(f"Failed to save image: {e}")
+            return ""
 
-    def full_analysis_pipeline(self, image, questionnaire_data):
+    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
+        """Run full analysis pipeline."""
         try:
-            saved = self.save_and_commit_image(image)
-            vis = self.perform_visual_analysis(image)
-            info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
-            gc = self.query_guidelines(info)
-            report = self.generate_final_report(info, vis, gc, image)
-            return {'success': True, 'visual_analysis': vis, 'report': report, 'saved_image_path': saved}
+            # Save image first
+            saved_path = self.save_and_commit_image(image_pil)
+            logging.info(f"Image saved: {saved_path}")
+
+            # Perform visual analysis
+            visual_results = self.perform_visual_analysis(image_pil)
+            logging.info(f"Visual analysis completed: {visual_results}")
+
+            # Process questionnaire data
+            patient_info = ", ".join(f"{k}: {v}" for k, v in questionnaire_data.items() if v)
+            if not patient_info:
+                patient_info = "No patient information provided"
+
+            # Query guidelines
+            query = f"wound care treatment for {visual_results.get('wound_type', 'wound')} "
+            if questionnaire_data.get('diabetic') == 'Yes':
+                query += "diabetic patient "
+            if questionnaire_data.get('infection') == 'Yes':
+                query += "with infection signs "
+
+            guideline_context = self.query_guidelines(query)
+            logging.info("Guidelines queried successfully")
+
+            # Generate final report
+            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
+            logging.info("Report generated successfully")
+
+            return {
+                'success': True,
+                'visual_analysis': visual_results,
+                'report': report,
+                'saved_image_path': saved_path,
+                'guideline_context': guideline_context[:500] + "..." if len(guideline_context) > 500 else guideline_context
+            }
+
         except Exception as e:
             logging.error(f"Pipeline error: {e}")
-            return {'success': False, 'error': str(e)}
-
-    def analyze_wound(self, image, questionnaire_data):
-        if isinstance(image, str):
-            image = Image.open(image)
-        return self.full_analysis_pipeline(image, questionnaire_data)
-
-    def _assess_risk_legacy(self, questionnaire_data):
-        risk_factors, risk_score = [], 0
+            return {
+                'success': False,
+                'error': str(e),
+                'visual_analysis': {},
+                'report': f"Analysis failed: {str(e)}",
+                'saved_image_path': None,
+                'guideline_context': ""
+            }
+
+    def analyze_wound(self, image, questionnaire_data: dict) -> dict:
+        """Main analysis entry point - maintains original function name."""
+        try:
+            # Handle different image input formats
+            if isinstance(image, str):
+                if os.path.exists(image):
+                    image_pil = Image.open(image)
+                else:
+                    raise ValueError(f"Image file not found: {image}")
+            elif isinstance(image, Image.Image):
+                image_pil = image
+            elif isinstance(image, np.ndarray):
+                image_pil = Image.fromarray(image)
+            else:
+                raise ValueError(f"Unsupported image type: {type(image)}")
+
+            return self.full_analysis_pipeline(image_pil, questionnaire_data)
+
+        except Exception as e:
+            logging.error(f"Wound analysis error: {e}")
+            return {
+                'success': False,
+                'error': str(e),
+                'visual_analysis': {},
+                'report': f"Analysis initialization failed: {str(e)}",
+                'saved_image_path': None,
+                'guideline_context': ""
+            }
+
+    def _assess_risk_legacy(self, questionnaire_data: dict) -> dict:
+        """Legacy risk assessment function - maintains original function name."""
+        risk_factors = []
+        risk_score = 0
+
         try:
+            # Age assessment
             age = questionnaire_data.get('patient_age', 0)
+            if isinstance(age, str):
+                try:
+                    age = int(age)
+                except ValueError:
+                    age = 0
+
             if age > 65:
-                risk_factors.append("Advanced age (>65)"); risk_score += 2
+                risk_factors.append("Advanced age (>65)")
+                risk_score += 2
             elif age > 50:
-                risk_factors.append("Older adult (50-65)"); risk_score += 1
+                risk_factors.append("Older adult (50-65)")
+                risk_score += 1
 
-            dur = questionnaire_data.get('wound_duration', '').lower()
-            if any(t in dur for t in ['month','year']):
-                risk_factors.append("Chronic wound (>4 weeks)"); risk_score += 3
+            # Wound duration assessment
+            duration = str(questionnaire_data.get('wound_duration', '')).lower()
+            if any(term in duration for term in ['month', 'months', 'year', 'years']):
+                risk_factors.append("Chronic wound (>4 weeks)")
+                risk_score += 3
+            elif any(term in duration for term in ['week', 'weeks']):
+                # Try to extract number of weeks
+                import re
+                weeks_match = re.search(r'(\d+)\s*week', duration)
+                if weeks_match and int(weeks_match.group(1)) > 4:
+                    risk_factors.append("Chronic wound (>4 weeks)")
+                    risk_score += 3
 
+            # Pain level assessment
             pain = questionnaire_data.get('pain_level', 0)
+            if isinstance(pain, str):
+                try:
+                    pain = float(pain)
+                except ValueError:
+                    pain = 0
+
             if pain >= 7:
-                risk_factors.append("High pain level"); risk_score += 2
-
-            hist = questionnaire_data.get('medical_history','').lower()
-            if 'diabetes' in hist:
-                risk_factors.append("Diabetes mellitus"); risk_score += 3
-            if 'vascular' in hist:
-                risk_factors.append("Vascular issues"); risk_score += 2
-            if 'immune' in hist:
-                risk_factors.append("Immune compromise"); risk_score += 2
-
-            level = ("High" if risk_score >= 7 else "Moderate" if risk_score >= 4 else "Low")
-            return {'risk_score': risk_score, 'risk_level': level, 'risk_factors': risk_factors}
+                risk_factors.append("High pain level (≥7/10)")
+                risk_score += 2
+            elif pain >= 5:
+                risk_factors.append("Moderate pain level (5-6/10)")
+                risk_score += 1
+
+            # Medical history assessment
+            medical_history = str(questionnaire_data.get('medical_history', '')).lower()
+            diabetic_status = str(questionnaire_data.get('diabetic', '')).lower()
+
+            if 'diabetes' in medical_history or 'yes' in diabetic_status:
+                risk_factors.append("Diabetes mellitus")
+                risk_score += 3
+
+            if any(term in medical_history for term in ['vascular', 'circulation', 'arterial', 'venous']):
+                risk_factors.append("Vascular disease")
+                risk_score += 2
+
+            if any(term in medical_history for term in ['immune', 'immunocompromised', 'steroid', 'chemotherapy']):
+                risk_factors.append("Immune system compromise")
+                risk_score += 2
+
+            if any(term in medical_history for term in ['smoking', 'smoker', 'tobacco']):
+                risk_factors.append("Smoking history")
+                risk_score += 2
+
+            # Infection signs
+            infection_signs = str(questionnaire_data.get('infection', '')).lower()
+            if 'yes' in infection_signs:
+                risk_factors.append("Signs of infection present")
+                risk_score += 3
+
+            # Moisture level
+            moisture = str(questionnaire_data.get('moisture', '')).lower()
+            if any(term in moisture for term in ['wet', 'heavy', 'excessive']):
+                risk_factors.append("Excessive wound exudate")
+                risk_score += 1
+
+            # Determine risk level
+            if risk_score >= 8:
+                risk_level = "Very High"
+            elif risk_score >= 6:
+                risk_level = "High"
+            elif risk_score >= 3:
+                risk_level = "Moderate"
+            else:
+                risk_level = "Low"
+
+            return {
+                'risk_score': risk_score,
+                'risk_level': risk_level,
+                'risk_factors': risk_factors,
+                'recommendations': self._get_risk_recommendations(risk_level, risk_factors)
+            }
+
         except Exception as e:
             logging.error(f"Risk assessment error: {e}")
-            return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
+            return {
+                'risk_score': 0,
+                'risk_level': 'Unknown',
+                'risk_factors': [],
+                'recommendations': ["Unable to assess risk due to data processing error"]
+            }
+
+    def _get_risk_recommendations(self, risk_level: str, risk_factors: list) -> list:
+        """Generate risk-based recommendations."""
+        recommendations = []
+
+        if risk_level in ["High", "Very High"]:
+            recommendations.append("Urgent referral to wound care specialist recommended")
+            recommendations.append("Consider daily wound monitoring")
+            recommendations.append("Implement aggressive wound care protocol")
+        elif risk_level == "Moderate":
+            recommendations.append("Regular wound care follow-up every 2-3 days")
+            recommendations.append("Monitor for signs of deterioration")
+        else:
+            recommendations.append("Standard wound care monitoring")
+            recommendations.append("Weekly assessment recommended")
+
+        # Specific recommendations based on risk factors
+        if "Diabetes mellitus" in risk_factors:
+            recommendations.append("Strict glycemic control essential")
+            recommendations.append("Monitor for diabetic complications")
+
+        if "Signs of infection present" in risk_factors:
+            recommendations.append("Consider antibiotic therapy")
+            recommendations.append("Increase wound cleaning frequency")
+
+        if "Excessive wound exudate" in risk_factors:
+            recommendations.append("Use high-absorption dressings")
+            recommendations.append("More frequent dressing changes may be needed")
+
+        return recommendations
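The measurement logic divides pixel extents by `self.px_per_cm`, so areas scale by the square of that calibration factor. Worked through on a toy mask (the calibration value is illustrative):

```python
import numpy as np
import cv2

px_per_cm = 10.0                       # illustrative calibration: 10 px = 1 cm
mask = np.zeros((100, 100), np.uint8)
mask[20:60, 30:80] = 1                 # a 40 px tall, 50 px wide wound region

contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cnt = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(cnt)
print(h / px_per_cm, w / px_per_cm)           # 4.0 5.0  (cm)
print(cv2.contourArea(cnt) / px_per_cm ** 2)  # 19.11 cm²; contour area is slightly under w*h
```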
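For a sense of how the scoring lands: with the thresholds above, a 72-year-old (+2) with a 6-week-old wound (+3), pain 8/10 (+2), diabetes (+3), smoking history (+2), and infection signs (+3) totals 15, well past the "Very High" cutoff of 8, which triggers the urgent-referral recommendations. A hypothetical driver (the constructor arguments are not shown in this hunk, so it is left commented):

```python
questionnaire = {
    "patient_age": "72",
    "wound_duration": "6 weeks",
    "pain_level": "8",
    "medical_history": "type 2 diabetes, smoker",
    "infection": "Yes",
}
# processor = AIProcessor()  # constructed as elsewhere in this repo
# risk = processor._assess_risk_legacy(questionnaire)
# print(risk["risk_score"], risk["risk_level"])   # expected: 15 Very High
# result = processor.analyze_wound("uploads/wound.png", questionnaire)
# print(result["report"] if result["success"] else result["error"])
```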