SmartHeal committed
Commit 5b2e7ae · verified · 1 Parent(s): 89cc302

Update src/ai_processor.py

Files changed (1):
  1. src/ai_processor.py +244 -658
src/ai_processor.py CHANGED
@@ -1,73 +1,42 @@
  # smartheal_ai_processor.py
- # Verbose, instrumented version — preserves public class/function names
- # Turn on deep logging: export LOGLEVEL=DEBUG SMARTHEAL_DEBUG=1

  import os
  import time
  import logging
  from datetime import datetime
- from typing import Optional, Dict, List, Tuple
-
- # ---- Environment defaults ----
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
- os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
- LOGLEVEL = os.getenv("LOGLEVEL", "INFO").upper()
- SMARTHEAL_DEBUG = os.getenv("SMARTHEAL_DEBUG", "0") == "1"

  import cv2
  import numpy as np
  from PIL import Image
- from PIL.ExifTags import TAGS
-
- # --- Logging config ---
- logging.basicConfig(
-     level=getattr(logging, LOGLEVEL, logging.INFO),
-     format="%(asctime)s - %(levelname)s - %(message)s",
- )
-
- def _log_kv(prefix: str, kv: Dict):
-     logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))
-
- # --- Optional Spaces GPU stub (harmless) ---
- try:
-     import spaces as _spaces
-     @_spaces.GPU(enable_queue=False)
-     def smartheal_gpu_stub(ping: int = 0) -> str:
-         return "ready"
-     logging.info("Registered @spaces.GPU stub (enable_queue=False).")
- except Exception:
-     pass

  UPLOADS_DIR = "uploads"
  os.makedirs(UPLOADS_DIR, exist_ok=True)

  HF_TOKEN = os.getenv("HF_TOKEN", None)
  YOLO_MODEL_PATH = "src/best.pt"
- SEG_MODEL_PATH = "src/segmentation_model.h5"  # optional
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
- DATASET_ID = "SmartHeal/wound-image-uploads"
- DEFAULT_PX_PER_CM = 38.0
- PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0
-
- # Segmentation preprocessing knobs
- SEG_EXPECTS_RGB = os.getenv("SEG_EXPECTS_RGB", "1") == "1"  # most TF models trained on RGB
- SEG_NORM = os.getenv("SEG_NORM", "0to1")  # "0to1" | "imagenet"
- SEG_THRESH = float(os.getenv("SEG_THRESH", "0.5"))

  models_cache: Dict[str, object] = {}
  knowledge_base_cache: Dict[str, object] = {}

- # ---------- Lazy imports ----------
  def _import_ultralytics():
      from ultralytics import YOLO
      return YOLO

  def _import_tf_loader():
      import tensorflow as tf
-     try:
-         tf.config.set_visible_devices([], "GPU")  # keep TF on CPU
-     except Exception:
-         pass
      from tensorflow.keras.models import load_model
      return load_model
@@ -91,50 +60,107 @@ def _import_hf_hub():
      from huggingface_hub import HfApi, HfFolder
      return HfApi, HfFolder

- # ---------- VLM (disabled by default) ----------
- def generate_medgemma_report(
-     patient_info: str,
-     visual_results: Dict,
-     guideline_context: str,
-     image_pil: Image.Image,
-     max_new_tokens: Optional[int] = None,
- ) -> str:
-     if os.getenv("SMARTHEAL_ENABLE_VLM", "0") != "1":
-         return "⚠️ VLM disabled"
      try:
          from transformers import pipeline
-         pipe = pipeline(
-             task="image-text-to-text",
-             model="google/medgemma-4b-it",
-             device_map=None,
-             token=HF_TOKEN,
-             trust_remote_code=True,
-             model_kwargs={"low_cpu_mem_usage": True},
-         )
-         prompt = (
-             "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
-             f"Patient: {patient_info}\n"
-             f"Wound: {visual_results.get('wound_type', 'Unknown')} - "
-             f"{visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm\n\n"
-             "Provide a structured report with:\n"
-             "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
-         )
-         messages = [{"role": "user", "content": [
-             {"type": "image", "image": image_pil},
-             {"type": "text", "text": prompt},
-         ]}]
-         out = pipe(text=messages, max_new_tokens=max_new_tokens or 600, do_sample=False, temperature=0.7)
-         if out and len(out) > 0:
              try:
-                 return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
              except Exception:
-                 return (out[0].get("generated_text", "") or "").strip() or "⚠️ Empty response"
-         return "⚠️ No output generated"
-     except Exception as e:
-         logging.error(f"❌ MedGemma generation error: {e}")
-         return "⚠️ VLM error"

- # ---------- Initialize CPU models ----------
  def load_yolo_model():
      YOLO = _import_ultralytics()
      return YOLO(YOLO_MODEL_PATH)
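Both the removed CPU path above and the GPU path added later in this commit parse the pipeline result defensively, because an image-text-to-text pipeline may return chat-style messages (a list under "generated_text") or a flat string. A minimal sketch of that extraction pulled out as a hypothetical extract_text helper, not part of the commit:

def extract_text(out) -> str:
    # No generations at all.
    if not out:
        return "⚠️ No output generated"
    gen = out[0].get("generated_text", "")
    if isinstance(gen, list):
        # Chat-style result: the answer is the last message's content.
        return (gen[-1].get("content", "") or "").strip() or "⚠️ Empty response"
    # Flat string result.
    return (gen or "").strip() or "⚠️ Empty response"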
@@ -145,25 +171,32 @@ def load_segmentation_model():

  def load_classification_pipeline():
      pipe = _import_hf_cls()
-     return pipe("image-classification", model="Hemg/Wound-classification", token=HF_TOKEN, device="cpu")

  def load_embedding_model():
      Emb = _import_embeddings()
      return Emb(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})

  def initialize_cpu_models() -> None:
      if HF_TOKEN:
          try:
              HfApi, HfFolder = _import_hf_hub()
              HfFolder.save_token(HF_TOKEN)
-             logging.info("✅ HF token set")
          except Exception as e:
              logging.warning(f"HF token save failed: {e}")

      if "det" not in models_cache:
          try:
              models_cache["det"] = load_yolo_model()
-             logging.info("✅ YOLO loaded (CPU)")
          except Exception as e:
              logging.error(f"YOLO load failed: {e}")

@@ -171,46 +204,46 @@ def initialize_cpu_models() -> None:
          try:
              if os.path.exists(SEG_MODEL_PATH):
                  models_cache["seg"] = load_segmentation_model()
-                 m = models_cache["seg"]
-                 ishape = getattr(m, "input_shape", None)
-                 oshape = getattr(m, "output_shape", None)
-                 logging.info(f"✅ Segmentation model loaded (CPU) | input_shape={ishape} output_shape={oshape}")
              else:
                  models_cache["seg"] = None
-                 logging.warning("Segmentation model file missing; skipping.")
          except Exception as e:
              models_cache["seg"] = None
-             logging.warning(f"Segmentation unavailable: {e}")

      if "cls" not in models_cache:
          try:
              models_cache["cls"] = load_classification_pipeline()
-             logging.info("✅ Classifier loaded (CPU)")
          except Exception as e:
              models_cache["cls"] = None
-             logging.warning(f"Classifier unavailable: {e}")

      if "embedding_model" not in models_cache:
          try:
              models_cache["embedding_model"] = load_embedding_model()
-             logging.info("✅ Embeddings loaded (CPU)")
          except Exception as e:
              models_cache["embedding_model"] = None
-             logging.warning(f"Embeddings unavailable: {e}")

  def setup_knowledge_base() -> None:
      if "vector_store" in knowledge_base_cache:
          return
-     docs: List = []
      try:
          PyPDFLoader = _import_langchain_pdf()
          for pdf in GUIDELINE_PDFS:
              if os.path.exists(pdf):
                  try:
-                     docs.extend(PyPDFLoader(pdf).load())
                      logging.info(f"Loaded PDF: {pdf}")
                  except Exception as e:
-                     logging.warning(f"PDF load failed ({pdf}): {e}")
      except Exception as e:
          logging.warning(f"LangChain PDF loader unavailable: {e}")

@@ -218,603 +251,145 @@ def setup_knowledge_base() -> None:
          try:
              from langchain.text_splitter import RecursiveCharacterTextSplitter
              FAISS = _import_langchain_faiss()
-             chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
              knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
-             logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
          except Exception as e:
              knowledge_base_cache["vector_store"] = None
-             logging.warning(f"KB build failed: {e}")
      else:
          knowledge_base_cache["vector_store"] = None
-         logging.warning("KB disabled (no docs or embeddings).")

  initialize_cpu_models()
  setup_knowledge_base()

- # ---------- Calibration helpers ----------
- def _adaptive_prob_threshold(p: np.ndarray) -> float:
-     """
-     Pick a threshold that avoids tiny blobs while not swallowing skin.
-     Strategy:
-       - try Otsu on the prob map
-       - clamp to a reasonable band [0.25, 0.65]
-       - also consider percentile cut (p90) and take the "best" by area heuristic
-     """
-     p01 = np.clip(p.astype(np.float32), 0, 1)
-     p255 = (p01 * 255).astype(np.uint8)
-
-     # Otsu → use the returned scalar threshold (ret), NOT the image
-     ret_otsu, _dst = cv2.threshold(p255, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-     thr_otsu = float(np.clip(ret_otsu / 255.0, 0.25, 0.65))
-
-     # Percentile (90th)
-     thr_pctl = float(np.clip(np.percentile(p01, 90), 0.25, 0.65))
-
-     # Area fraction helper
-     def area_frac(thr: float) -> float:
-         return float((p01 >= thr).sum()) / float(p01.size)
-
-     af_otsu = area_frac(thr_otsu)
-     af_pctl = area_frac(thr_pctl)
-
-     # Score: prefer ~3–10% coverage
-     def score(af: float) -> float:
-         target_low, target_high = 0.03, 0.10
-         if af < target_low: return abs(af - target_low) * 3.0
-         if af > target_high: return abs(af - target_high) * 1.5
-         return 0.0
-
-     return thr_otsu if score(af_otsu) <= score(af_pctl) else thr_pctl
-
-
-     # Score: closeness to a target area fraction (aim ~3–10%)
-     def score(af):
-         target_low, target_high = 0.03, 0.10
-         if af < target_low: return abs(af - target_low) * 3.0
-         if af > target_high: return abs(af - target_high) * 1.5
-         return 0.0
-
-     return thr_otsu if score(af_otsu) <= score(af_pctl) else thr_pctl
-
-
- def _grabcut_refine(bgr: np.ndarray, seed01: np.ndarray, iters: int = 3) -> np.ndarray:
-     """
-     Use OpenCV GrabCut to grow from a confident core into low-contrast margins.
-     seed01: 1=probable FG core, 0=unknown/other
-     """
-     h, w = bgr.shape[:2]
-     # Build GC mask: start with "unknown"
-     gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
-     # definite FG = dilated seed; probable FG = seed
-     k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
-     seed_dil = cv2.dilate(seed01, k, iterations=1)
-     gc[seed01.astype(bool)] = cv2.GC_PR_FGD
-     gc[seed_dil.astype(bool)] = cv2.GC_FGD
-     # border is probable background
-     gc[0, :], gc[-1, :], gc[:, 0], gc[:, -1] = cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD
-
-     bgdModel = np.zeros((1, 65), np.float64)
-     fgdModel = np.zeros((1, 65), np.float64)
-     cv2.grabCut(bgr, gc, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
-
-     # FG = definite or probable foreground
-     mask01 = np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 1, 0).astype(np.uint8)
-     return mask01
-
-
- def _fill_holes(mask01: np.ndarray) -> np.ndarray:
-     h, w = mask01.shape[:2]
-     ff = np.zeros((h + 2, w + 2), np.uint8)
-     m = (mask01 * 255).astype(np.uint8).copy()
-     cv2.floodFill(m, ff, (0, 0), 255)
-     m_inv = cv2.bitwise_not(m)
-     out = ((mask01 * 255) | m_inv) // 255
-     return out.astype(np.uint8)
-
-
- def _clean_mask(mask01: np.ndarray) -> np.ndarray:
-     """Open → Close → Fill holes → Largest component → light smooth."""
-     mask01 = (mask01 > 0).astype(np.uint8)
-     k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
-     k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
-     mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, k3, iterations=1)
-     mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, k5, iterations=2)
-     mask01 = _fill_holes(mask01)
-
-     # keep largest component
-     num, labels, stats, _ = cv2.connectedComponentsWithStats(mask01, 8)
-     if num > 1:
-         areas = stats[1:, cv2.CC_STAT_AREA]
-         if areas.size:
-             largest_idx = 1 + int(np.argmax(areas))
-             mask01 = (labels == largest_idx).astype(np.uint8)
-
-     # tiny masks → gentle grow (distance transform based)
-     area = int(mask01.sum())
-     if area > 0:
-         grow = 1 if area < 2000 else 0
-         if grow:
-             k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
-             mask01 = cv2.dilate(mask01, k, iterations=1)
-
-     return (mask01 > 0).astype(np.uint8)
-
-
- def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
-     out = {}
-     try:
-         exif = pil_img.getexif()
-         if not exif:
-             return out
-         for k, v in exif.items():
-             tag = TAGS.get(k, k)
-             out[tag] = v
-     except Exception:
-         pass
-     return out
-
- def _to_float(val) -> Optional[float]:
-     try:
-         if val is None:
-             return None
-         if isinstance(val, tuple) and len(val) == 2:
-             num, den = float(val[0]), float(val[1]) if float(val[1]) != 0 else 1.0
-             return num / den
-         return float(val)
-     except Exception:
-         return None
-
- def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
-     if f_mm and f35 and f35 > 0:
-         return 36.0 * f_mm / f35
-     return None
-
- def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
-     meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None}
-     try:
-         exif = _exif_to_dict(pil_img)
-         f_mm = _to_float(exif.get("FocalLength"))
-         f35 = _to_float(exif.get("FocalLengthIn35mmFilm") or exif.get("FocalLengthIn35mm"))
-         subj_dist_m = _to_float(exif.get("SubjectDistance"))
-         sensor_w_mm = _estimate_sensor_width_mm(f_mm, f35)
-         meta.update({"f_mm": f_mm, "f35": f35, "sensor_w_mm": sensor_w_mm, "distance_m": subj_dist_m})
-
-         if f_mm and sensor_w_mm and subj_dist_m and subj_dist_m > 0:
-             w_px = pil_img.width
-             field_w_mm = sensor_w_mm * (subj_dist_m * 1000.0) / f_mm
-             field_w_cm = field_w_mm / 10.0
-             px_per_cm = w_px / max(field_w_cm, 1e-6)
-             px_per_cm = float(np.clip(px_per_cm, PX_PER_CM_MIN, PX_PER_CM_MAX))
-             meta["used"] = "exif"
-             return px_per_cm, meta
-         return float(default_px_per_cm), meta
-     except Exception:
-         return float(default_px_per_cm), meta
-
- # ---------- Segmentation helpers ----------
- def _imagenet_norm(arr: np.ndarray) -> np.ndarray:
-     # expects RGB 0..255 -> float
-     mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
-     std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
-     return (arr.astype(np.float32) - mean) / std
-
- def _preprocess_for_seg(bgr_roi: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray:
-     H, W = target_hw
-     resized = cv2.resize(bgr_roi, (W, H), interpolation=cv2.INTER_LINEAR)
-     if SEG_EXPECTS_RGB:
-         resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
-     if SEG_NORM.lower() == "imagenet":
-         x = _imagenet_norm(resized)
-     else:
-         x = resized.astype(np.float32) / 255.0
-     x = np.expand_dims(x, axis=0)  # (1,H,W,3)
-     return x
-
- def _to_prob(pred: np.ndarray) -> np.ndarray:
-     p = np.squeeze(pred)
-     pmin, pmax = float(p.min()), float(p.max())
-     if pmax > 1.0 or pmin < 0.0:
-         p = 1.0 / (1.0 + np.exp(-p))
-     return p.astype(np.float32)
-
- # ---- Robust mask post-processing (for "proper" masking) ----
- def _fill_holes(mask01: np.ndarray) -> np.ndarray:
-     # Flood-fill from border, then invert
-     h, w = mask01.shape[:2]
-     ff = np.zeros((h + 2, w + 2), np.uint8)
-     m = (mask01 * 255).astype(np.uint8).copy()
-     cv2.floodFill(m, ff, (0, 0), 255)
-     m_inv = cv2.bitwise_not(m)
-     # Combine original with filled holes
-     out = ((mask01 * 255) | m_inv) // 255
-     return out.astype(np.uint8)
-
- # Global last debug dict (per-process) to attach into results
- _last_seg_debug: Dict[str, object] = {}
-
- def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndarray, Dict[str, object]]:
-     """
-     TF model → adaptive threshold on prob → (optional) GrabCut grow → cleanup.
-     Falls back to KMeans-Lab when model missing/fails.
-     Returns (mask_uint8_0_255, debug_dict)
-     """
-     debug = {"used": None, "reason": None, "positive_fraction": 0.0,
-              "thr": None, "heatmap_path": None, "roi_seen_by_model": None}
-
-     seg_model = models_cache.get("seg", None)
-
-     # --- Model path ---
-     if seg_model is not None:
-         try:
-             ishape = getattr(seg_model, "input_shape", None)
-             if not ishape or len(ishape) < 4:
-                 raise ValueError(f"Bad seg input_shape: {ishape}")
-             th, tw = int(ishape[1]), int(ishape[2])
-
-             # preprocess
-             x = _preprocess_for_seg(image_bgr, (th, tw))
-             rgb_for_view = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
-             roi_seen_path = None
-             if SMARTHEAL_DEBUG:
-                 roi_seen_path = os.path.join(out_dir, f"roi_for_seg_{ts}.png")
-                 cv2.imwrite(roi_seen_path, cv2.cvtColor(rgb_for_view, cv2.COLOR_RGB2BGR))
-
-             # predict → prob map back to ROI size
-             pred = seg_model.predict(x, verbose=0)
-             if isinstance(pred, (list, tuple)): pred = pred[0]
-             p = _to_prob(pred)
-             p = cv2.resize(p, (image_bgr.shape[1], image_bgr.shape[0]), interpolation=cv2.INTER_LINEAR)
-
-             # visualization (optional)
-             heatmap_path = None
-             if SMARTHEAL_DEBUG:
-                 hm = (np.clip(p, 0, 1) * 255).astype(np.uint8)
-                 heat = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
-                 heatmap_path = os.path.join(out_dir, f"seg_pred_heatmap_{ts}.png")
-                 cv2.imwrite(heatmap_path, heat)
-
-             # --- Adaptive threshold ---
-             thr = _adaptive_prob_threshold(p)
-             core01 = (p >= thr).astype(np.uint8)
-             core_frac = float(core01.sum()) / float(core01.size)
-
-             # If still too tiny, try a gentler threshold
-             if core_frac < 0.005:
-                 thr2 = max(thr - 0.10, 0.15)
-                 core01 = (p >= thr2).astype(np.uint8)
-                 thr = thr2
-                 core_frac = float(core01.sum()) / float(core01.size)
-
-             # --- Grow with GrabCut (only if some core exists) ---
-             if core01.any():
-                 gc01 = _grabcut_refine(image_bgr, core01, iters=3)
-                 mask01 = _clean_mask(gc01)
-             else:
-                 mask01 = np.zeros(core01.shape, np.uint8)
-
-             pos_frac = float(mask01.sum()) / float(mask01.size)
-             logging.info(f"SegModel USED | thr={thr:.2f} core_frac={core_frac:.4f} final_frac={pos_frac:.4f}")
-
-             debug.update({
-                 "used": "tf_model",
-                 "reason": "ok",
-                 "positive_fraction": pos_frac,
-                 "thr": thr,
-                 "heatmap_path": heatmap_path,
-                 "roi_seen_by_model": roi_seen_path
-             })
-             return (mask01 * 255).astype(np.uint8), debug
-
-         except Exception as e:
-             logging.warning(f"⚠️ Segmentation model failed → fallback. Reason: {e}")
-             debug.update({"used": "fallback_kmeans", "reason": f"model_failed: {e}"})
-
-     # --- Fallback: KMeans in Lab (reddest cluster as wound) ---
-     Z = image_bgr.reshape((-1, 3)).astype(np.float32)
-     criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
-     _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
-     centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
-     centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
-     wound_idx = int(np.argmax(centers_lab[:, 1]))  # maximize a* (red)
-     mask01 = (labels.reshape(image_bgr.shape[:2]) == wound_idx).astype(np.uint8)
-     mask01 = _clean_mask(mask01)
-
-     pos_frac = float(mask01.sum()) / float(mask01.size)
-     logging.info(f"KMeans USED | final_frac={pos_frac:.4f}")
-
-     debug.update({
-         "used": "fallback_kmeans",
-         "reason": debug.get("reason") or "no_model",
-         "positive_fraction": pos_frac,
-         "thr": None
-     })
-     return (mask01 * 255).astype(np.uint8), debug
-
-
- # ---------- Measurement + overlay helpers ----------
- def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
-     num, labels, stats, _ = cv2.connectedComponentsWithStats(binary01.astype(np.uint8), connectivity=8)
-     if num <= 1:
-         return binary01.astype(np.uint8)
-     areas = stats[1:, cv2.CC_STAT_AREA]
-     if areas.size == 0 or areas.max() < min_area_px:
-         return binary01.astype(np.uint8)
-     largest_idx = 1 + int(np.argmax(areas))
-     return (labels == largest_idx).astype(np.uint8)
-
- def _clean_mask(mask01: np.ndarray) -> np.ndarray:
-     """Open → Close → Fill holes → Largest component."""
-     if mask01.dtype != np.uint8:
-         mask01 = mask01.astype(np.uint8)
-     k = np.ones((3, 3), np.uint8)
-     mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, k, iterations=1)
-     mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, k, iterations=2)
-     mask01 = _fill_holes(mask01)
-     mask01 = largest_component_mask(mask01, min_area_px=30)
-     return (mask01 > 0).astype(np.uint8)
-
- def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
-     contours, _ = cv2.findContours(mask01.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-     if not contours:
-         return 0.0, 0.0, (None, None)
-     cnt = max(contours, key=cv2.contourArea)
-     rect = cv2.minAreaRect(cnt)
-     (w_px, h_px) = rect[1]
-     length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))
-     length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
-     breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)
-     box = cv2.boxPoints(rect).astype(int)
-     return length_cm, breadth_cm, (box, rect[0])
-
- def count_area_cm2(mask01: np.ndarray, px_per_cm: float) -> float:
-     px_count = float(mask01.astype(bool).sum())
-     return round(px_count / (max(px_per_cm, 1e-6) ** 2), 2)
-
- def draw_measurement_overlay(
-     base_bgr: np.ndarray,
-     mask01: np.ndarray,
-     rect_box: np.ndarray,
-     length_cm: float,
-     breadth_cm: float,
-     thickness: int = 2
- ) -> np.ndarray:
-     """
-     Draws:
-       1) Strong red mask overlay with white contour.
-       2) Min-area rectangle.
-       3) Two double-headed arrows:
-          - 'Length' along the longer side.
-          - 'Width' along the shorter side.
-     """
-     overlay = base_bgr.copy()
-
-     # --- Strong overlay from mask (tinted red where mask==1) ---
-     mask255 = (mask01 * 255).astype(np.uint8)
-     mask3 = cv2.merge([mask255, mask255, mask255])
-     red = np.zeros_like(overlay); red[:] = (0, 0, 255)
-     alpha = 0.55
-     tinted = cv2.addWeighted(overlay, 1 - alpha, red, alpha, 0)
-     overlay = np.where(mask3 > 0, tinted, overlay)
-
-     # Draw wound contour
-     cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-     if cnts:
-         cv2.drawContours(overlay, cnts, -1, (255, 255, 255), 2)
-
-     if rect_box is not None:
-         cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
-         pts = rect_box.reshape(-1, 2)
-
-         def midpoint(a, b):
-             return (int((a[0] + b[0]) / 2), int((a[1] + b[1]) / 2))
-
-         # Edge lengths
-         e = [np.linalg.norm(pts[i] - pts[(i + 1) % 4]) for i in range(4)]
-         long_edge_idx = int(np.argmax(e))
-         short_edge_idx = (long_edge_idx + 1) % 2  # 0/1 map for pairs below
-
-         # Midpoints of opposite edges for arrows
-         mids = [midpoint(pts[i], pts[(i + 1) % 4]) for i in range(4)]
-         # Long side uses edges long_edge_idx and the opposite edge (i+2)
-         long_pair = (long_edge_idx, (long_edge_idx + 2) % 4)
-         # Short side uses the other pair
-         short_pair = ((long_edge_idx + 1) % 4, (long_edge_idx + 3) % 4)
-
-         def draw_double_arrow(img, p1, p2):
-             cv2.arrowedLine(img, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
-             cv2.arrowedLine(img, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
-             cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
-             cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)
-
-         def put_label(text, anchor):
-             org = (anchor[0] + 6, anchor[1] - 6)
-             cv2.putText(overlay, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
-             cv2.putText(overlay, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
-
-         # Draw arrows and labels
-         draw_double_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
-         draw_double_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
-         put_label(f"Length: {length_cm:.2f} cm", mids[long_pair[0]])
-         put_label(f"Width: {breadth_cm:.2f} cm", mids[short_pair[0]])
-
-     return overlay
-
- # ---------- AI PROCESSOR ----------
  class AIProcessor:
      def __init__(self):
          self.models_cache = models_cache
          self.knowledge_base_cache = knowledge_base_cache
          self.uploads_dir = UPLOADS_DIR
          self.dataset_id = DATASET_ID
          self.hf_token = HF_TOKEN

      def _ensure_analysis_dir(self) -> str:
          out_dir = os.path.join(self.uploads_dir, "analysis")
          os.makedirs(out_dir, exist_ok=True)
          return out_dir

      def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
-         """
-         YOLO detect → crop ROI → segment_wound(ROI) → clean mask →
-         minAreaRect measurement (cm) using EXIF px/cm → save outputs.
-         """
          try:
-             px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
              image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

-             # --- Detection ---
-             det_model = self.models_cache.get("det")
-             if det_model is None:
                  raise RuntimeError("YOLO model not loaded")
-             results = det_model.predict(image_cv, verbose=False, device="cpu")
-             if (not results) or (not getattr(results[0], "boxes", None)) or (len(results[0].boxes) == 0):
-                 try:
-                     import gradio as gr
-                     raise gr.Error("No wound could be detected.")
-                 except Exception:
-                     raise RuntimeError("No wound could be detected.")

              box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
              x1, y1, x2, y2 = [int(v) for v in box]
              x1, y1 = max(0, x1), max(0, y1)
              x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
-             roi = image_cv[y1:y2, x1:x2].copy()
-             if roi.size == 0:
-                 try:
-                     import gradio as gr
-                     raise gr.Error("Detected ROI is empty.")
-                 except Exception:
-                     raise RuntimeError("Detected ROI is empty.")
-
-             out_dir = self._ensure_analysis_dir()
-             ts = datetime.now().strftime("%Y%m%d_%H%M%S")

-             # --- Segmentation (model-first + KMeans fallback) ---
-             mask_u8_255, seg_debug = segment_wound(roi, ts, out_dir)
-             mask01 = (mask_u8_255 > 127).astype(np.uint8)
-
-             # Robust post-processing to ensure "proper" masking
-             if mask01.any():
-                 mask01 = _clean_mask(mask01)
-                 logging.debug(f"Mask postproc: px_after={int(mask01.sum())}")
-
-             # --- Measurement ---
-             if mask01.any():
-                 length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01, px_per_cm)
-                 surface_area_cm2 = count_area_cm2(mask01, px_per_cm)
-                 # Final annotated ROI with mask + arrows + labels
-                 anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
-                 segmentation_empty = False
-             else:
-                 # Graceful fallback if seg failed: use ROI box as bounds
-                 h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
-                 length_cm = round(max(h_px, w_px) / px_per_cm, 2)
-                 breadth_cm = round(min(h_px, w_px) / px_per_cm, 2)
-                 surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
-                 anno_roi = roi.copy()
-                 cv2.rectangle(anno_roi, (2, 2), (anno_roi.shape[1]-3, anno_roi.shape[0]-3), (0, 0, 255), 3)
-                 cv2.line(anno_roi, (0, 0), (anno_roi.shape[1]-1, anno_roi.shape[0]-1), (0, 0, 255), 2)
-                 cv2.line(anno_roi, (anno_roi.shape[1]-1, 0), (0, anno_roi.shape[0]-1), (0, 0, 255), 2)
-                 box_pts = None
-                 segmentation_empty = True
-
-             # --- Save visualizations ---
-             original_path = os.path.join(out_dir, f"original_{ts}.png")
-             cv2.imwrite(original_path, image_cv)
-
-             det_vis = image_cv.copy()
-             cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
-             detection_path = os.path.join(out_dir, f"detection_{ts}.png")
-             cv2.imwrite(detection_path, det_vis)
-
-             roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
-             cv2.imwrite(roi_mask_path, (mask01 * 255).astype(np.uint8))
-
-             # ROI overlay (clear mask w/ white contour, no arrows)
-             mask255 = (mask01 * 255).astype(np.uint8)
-             mask3 = cv2.merge([mask255, mask255, mask255])
-             red = np.zeros_like(roi); red[:] = (0, 0, 255)
-             alpha = 0.55
-             tinted = cv2.addWeighted(roi, 1 - alpha, red, alpha, 0)
-             if mask255.any():
-                 roi_overlay = np.where(mask3 > 0, tinted, roi)
-                 cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-                 cv2.drawContours(roi_overlay, cnts, -1, (255, 255, 255), 2)
-             else:
-                 roi_overlay = anno_roi
-
-             seg_full = image_cv.copy()
-             seg_full[y1:y2, x1:x2] = roi_overlay
-             segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
-             cv2.imwrite(segmentation_path, seg_full)
-
-             segmentation_roi_path = os.path.join(out_dir, f"segmentation_roi_{ts}.png")
-             cv2.imwrite(segmentation_roi_path, roi_overlay)
-
-             # Annotated (mask + arrows + labels) in full-frame
-             anno_full = image_cv.copy()
-             anno_full[y1:y2, x1:x2] = anno_roi
-             annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
-             cv2.imwrite(annotated_seg_path, anno_full)

-             # --- Optional classification ---
              wound_type = "Unknown"
              cls_pipe = self.models_cache.get("cls")
              if cls_pipe is not None:
                  try:
-                     preds = cls_pipe(Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)))
                      if preds:
                          wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
                  except Exception as e:
-                     logging.warning(f"Classification failed: {e}")
-
-             # Log end-of-seg summary
-             seg_summary = {
-                 "seg_used": seg_debug.get("used"),
-                 "seg_reason": seg_debug.get("reason"),
-                 "positive_fraction": round(float(seg_debug.get("positive_fraction", 0.0)), 6),
-                 "threshold": seg_debug.get("threshold", SEG_THRESH),
-                 "segmentation_empty": segmentation_empty,
-                 "exif_px_per_cm": round(px_per_cm, 3),
-             }
-             _log_kv("SEG_SUMMARY", seg_summary)

              return {
                  "wound_type": wound_type,
-                 "length_cm": length_cm,
-                 "breadth_cm": breadth_cm,
-                 "surface_area_cm2": surface_area_cm2,
-                 "px_per_cm": round(px_per_cm, 2),
-                 "calibration_meta": exif_meta,
                  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
-                 if getattr(results[0].boxes, "conf", None) is not None else 0.0,
-                 "detection_image_path": detection_path,
-                 "segmentation_image_path": annotated_seg_path,
-                 "segmentation_annotated_path": annotated_seg_path,
-                 "segmentation_roi_path": segmentation_roi_path,
-                 "roi_mask_path": roi_mask_path,
-                 "segmentation_empty": segmentation_empty,
-                 "segmentation_debug": seg_debug,
                  "original_image_path": original_path,
              }
          except Exception as e:
-             logging.error(f"Visual analysis failed: {e}", exc_info=True)
              raise

-     # ---------- Knowledge base + reporting ----------
      def query_guidelines(self, query: str) -> str:
          try:
              vs = self.knowledge_base_cache.get("vector_store")
              if not vs:
                  return "Knowledge base is not available."
              try:
                  retriever = vs.as_retriever(search_kwargs={"k": 5})
-                 docs = retriever.get_relevant_documents(query)
              except Exception:
                  retriever = vs.as_retriever(search_kwargs={"k": 5})
                  docs = retriever.invoke(query)
              lines: List[str] = []
              for d in docs:
@@ -826,7 +401,9 @@ class AIProcessor:
              logging.warning(f"Guidelines query failed: {e}")
              return f"Guidelines query failed: {str(e)}"

      def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
          return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
  ## 📋 Patient Information
  {patient_info}
@@ -835,12 +412,10 @@ class AIProcessor:
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
- - **Calibration**: {visual_results.get('px_per_cm','?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used','default')})
  ## 📊 Analysis Images
  - **Original**: {visual_results.get('original_image_path', 'N/A')}
  - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
  - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
- - **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}
  ## 🎯 Clinical Summary
  Automated analysis provides quantitative measurements; verify via clinical examination.
  ## 💊 Recommendations
@@ -848,10 +423,10 @@ Automated analysis provides quantitative measurements; verify via clinical exami
  - Debride necrotic tissue if indicated (clinical decision)
  - Document with serial photos and measurements
  ## 📅 Monitoring
- - Daily in week 1, then every 2–3 days (or as indicated)
  - Weekly progress review
  ## 📚 Guideline Context
- {(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}
  **Disclaimer:** Automated, for decision support only. Verify clinically.
  """

@@ -863,8 +438,9 @@ Automated analysis provides quantitative measurements; verify via clinical exami
          image_pil: Image.Image,
          max_new_tokens: Optional[int] = None,
      ) -> str:
          try:
-             report = generate_medgemma_report(
                  patient_info, visual_results, guideline_context, image_pil, max_new_tokens
              )
              if report and report.strip() and not report.startswith(("⚠️", "❌")):
@@ -875,7 +451,9 @@ Automated analysis provides quantitative measurements; verify via clinical exami
              logging.error(f"Report generation failed: {e}")
              return self._generate_fallback_report(patient_info, visual_results, guideline_context)

      def save_and_commit_image(self, image_pil: Image.Image) -> str:
          try:
              os.makedirs(self.uploads_dir, exist_ok=True)
              ts = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -884,17 +462,17 @@ Automated analysis provides quantitative measurements; verify via clinical exami
              image_pil.convert("RGB").save(path)
              logging.info(f"✅ Image saved locally: {path}")

-             if HF_TOKEN and DATASET_ID:
                  try:
                      HfApi, HfFolder = _import_hf_hub()
-                     HfFolder.save_token(HF_TOKEN)
                      api = HfApi()
                      api.upload_file(
                          path_or_fileobj=path,
                          path_in_repo=f"images/{filename}",
-                         repo_id=DATASET_ID,
                          repo_type="dataset",
-                         token=HF_TOKEN,
                          commit_message=f"Upload wound image: {filename}",
                      )
                      logging.info("✅ Image committed to HF dataset")
@@ -906,23 +484,28 @@ Automated analysis provides quantitative measurements; verify via clinical exami
              logging.error(f"Failed to save/commit image: {e}")
              return ""

      def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
          try:
              saved_path = self.save_and_commit_image(image_pil)
              visual_results = self.perform_visual_analysis(image_pil)

              pi = questionnaire_data or {}
              patient_info = (
-                 f"Age: {pi.get('age','N/A')}, "
-                 f"Diabetic: {pi.get('diabetic','N/A')}, "
-                 f"Allergies: {pi.get('allergies','N/A')}, "
-                 f"Date of Wound: {pi.get('date_of_injury','N/A')}, "
-                 f"Professional Care: {pi.get('professional_care','N/A')}, "
-                 f"Oozing/Bleeding: {pi.get('oozing_bleeding','N/A')}, "
-                 f"Infection: {pi.get('infection','N/A')}, "
-                 f"Moisture: {pi.get('moisture','N/A')}"
              )

              query = (
                  f"best practices for managing a {visual_results.get('wound_type','Unknown')} "
                  f"with moisture '{pi.get('moisture','unknown')}' and infection '{pi.get('infection','unknown')}' "
@@ -930,16 +513,18 @@ Automated analysis provides quantitative measurements; verify via clinical exami
              )
              guideline_context = self.query_guidelines(query)

-             report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)

              return {
                  "success": True,
                  "visual_analysis": visual_results,
                  "report": report,
                  "saved_image_path": saved_path,
-                 "guideline_context": (guideline_context or "")[:500] + (
-                     "..." if guideline_context and len(guideline_context) > 500 else ""
-                 ),
              }
          except Exception as e:
              logging.error(f"Pipeline error: {e}")
@@ -953,6 +538,7 @@ Automated analysis provides quantitative measurements; verify via clinical exami
              }

      def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
          try:
              if isinstance(image, str):
                  if not os.path.exists(image):
@@ -975,4 +561,4 @@ Automated analysis provides quantitative measurements; verify via clinical exami
                  "report": f"Analysis initialization failed: {str(e)}",
                  "saved_image_path": None,
                  "guideline_context": "",
-             }
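This commit replaces the EXIF-calibrated, rotated-rectangle measurement removed above with plain bounding-box arithmetic at a fixed 38 px/cm. A minimal sketch of the removed approach, assuming a binary mask01 array and a px_per_cm scale as in the old helpers:

import cv2
import numpy as np

def measure_mask_cm(mask01: np.ndarray, px_per_cm: float):
    # Largest external contour of the mask -> rotated min-area rectangle.
    contours, _ = cv2.findContours(mask01.astype(np.uint8),
                                   cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return 0.0, 0.0
    (w_px, h_px) = cv2.minAreaRect(max(contours, key=cv2.contourArea))[1]
    # The longer side is reported as "length"; pixels convert to cm via the scale.
    length_cm = max(w_px, h_px) / max(px_per_cm, 1e-6)
    breadth_cm = min(w_px, h_px) / max(px_per_cm, 1e-6)
    return round(length_cm, 2), round(breadth_cm, 2)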
 
src/ai_processor.py (after)

  # smartheal_ai_processor.py
+ # Full, functional module with conditional Spaces GPU support and CPU fallbacks.

  import os
  import time
  import logging
  from datetime import datetime
+ from typing import Optional, Dict, List

  import cv2
  import numpy as np
  from PIL import Image

+ # =============== LOGGING SETUP ===============
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+ # =============== CONFIGURATION ===============
  UPLOADS_DIR = "uploads"
  os.makedirs(UPLOADS_DIR, exist_ok=True)

  HF_TOKEN = os.getenv("HF_TOKEN", None)
  YOLO_MODEL_PATH = "src/best.pt"
+ SEG_MODEL_PATH = "src/segmentation_model.h5"  # optional
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
+ DATASET_ID = "SmartHeal/wound-image-uploads"  # optional (set HF_TOKEN too)
+ PIXELS_PER_CM = 38  # heuristic

+ # =============== GLOBAL CACHES ===============
  models_cache: Dict[str, object] = {}
  knowledge_base_cache: Dict[str, object] = {}

+ # ---------- Optional imports guarded ----------
  def _import_ultralytics():
      from ultralytics import YOLO
      return YOLO

  def _import_tf_loader():
      import tensorflow as tf
+     tf.config.set_visible_devices([], "GPU")  # force CPU
      from tensorflow.keras.models import load_model
      return load_model

      from huggingface_hub import HfApi, HfFolder
      return HfApi, HfFolder

+ # =============== SPACES GPU CONDITIONAL ===============
+ def _spaces_gpu_available() -> bool:
      try:
+         import torch
+         return bool(torch.cuda.is_available())
+     except Exception:
+         return False
+
+ def _spaces_lib_available() -> bool:
+     try:
+         import spaces  # noqa
+         return True
+     except Exception:
+         return False
+
+ HAVE_SPACES_GPU = _spaces_gpu_available() and _spaces_lib_available()
+
+ if HAVE_SPACES_GPU:
+     import spaces  # define only if available & GPU present
+
+     @spaces.GPU(enable_queue=True, duration=90)
+     def generate_medgemma_report_with_timeout(
+         patient_info: str,
+         visual_results: Dict,
+         guideline_context: str,
+         image_pil: Image.Image,
+         max_new_tokens: Optional[int] = None,
+     ) -> str:
+         """Runs on Spaces GPU only; callers keep one signature on both paths."""
+         import torch
          from transformers import pipeline
+         try:
+             torch.cuda.empty_cache()
+
+             prompt = f"""
+ You are a medical AI assistant. Analyze this wound image and patient data.
+ Patient: {patient_info}
+ Wound: {visual_results.get('wound_type', 'Unknown')} - {visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm
+ Provide a structured report with:
+ 1. Clinical Summary
+ 2. Treatment Recommendations
+ 3. Risk Assessment
+ 4. Monitoring Plan
+ """.strip()
+
+             pipe = pipeline(
+                 "image-text-to-text",
+                 model="google/medgemma-4b-it",
+                 torch_dtype=torch.bfloat16,
+                 device_map="auto",
+                 token=HF_TOKEN,
+                 model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
+             )
+
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "image", "image": image_pil},
+                         {"type": "text", "text": prompt},
+                     ],
+                 }
+             ]
+
+             t0 = time.time()
+             out = pipe(
+                 text=messages,
+                 max_new_tokens=max_new_tokens or 800,
+                 do_sample=False,
+                 temperature=0.7,
+                 pad_token_id=pipe.tokenizer.eos_token_id,
+             )
+             logging.info(f"✅ MedGemma completed in {time.time() - t0:.2f}s")
+
+             if out and len(out) > 0:
+                 # Defensive extraction
+                 try:
+                     return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
+                 except Exception:
+                     return (out[0].get("generated_text", "") or "").strip() or "⚠️ Empty response"
+             return "⚠️ No output generated"
+         except Exception as e:
+             logging.error(f"❌ MedGemma generation error: {e}")
+             return f"❌ Report generation failed: {str(e)}"
+         finally:
              try:
+                 torch.cuda.empty_cache()
              except Exception:
+                 pass
+ else:
+     def generate_medgemma_report_with_timeout(
+         patient_info: str,
+         visual_results: Dict,
+         guideline_context: str,
+         image_pil: Image.Image,
+         max_new_tokens: Optional[int] = None,
+     ) -> str:
+         """CPU-only path: return a warning so caller uses fallback."""
+         return "⚠️ GPU not available"
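Both branches expose the same generate_medgemma_report_with_timeout signature, so call sites never probe for a GPU themselves. A minimal calling sketch (the patient values and image path are illustrative only):

report = generate_medgemma_report_with_timeout(
    patient_info="Age: 54, Diabetic: yes",
    visual_results={"wound_type": "Ulcer", "length_cm": 2.1, "breadth_cm": 1.4},
    guideline_context="",
    image_pil=Image.open("uploads/sample.jpg"),  # hypothetical path
)
if report.startswith(("⚠️", "❌")):
    # CPU-only host or GPU failure: callers drop to the markdown fallback report.
    ...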
 
+ # =============== MODEL INITIALIZATION (CPU-SAFE) ===============
  def load_yolo_model():
      YOLO = _import_ultralytics()
      return YOLO(YOLO_MODEL_PATH)

  def load_classification_pipeline():
      pipe = _import_hf_cls()
+     return pipe(
+         "image-classification",
+         model="Hemg/Wound-classification",
+         token=HF_TOKEN,
+         device="cpu",
+     )

  def load_embedding_model():
      Emb = _import_embeddings()
      return Emb(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})

  def initialize_cpu_models() -> None:
+     """Initialize all CPU-only models once with robust fallbacks."""
+     # Hugging Face auth (optional)
      if HF_TOKEN:
          try:
              HfApi, HfFolder = _import_hf_hub()
              HfFolder.save_token(HF_TOKEN)
+             logging.info("✅ HuggingFace token set")
          except Exception as e:
              logging.warning(f"HF token save failed: {e}")

      if "det" not in models_cache:
          try:
              models_cache["det"] = load_yolo_model()
+             logging.info("✅ YOLO model loaded (CPU)")
          except Exception as e:
              logging.error(f"YOLO load failed: {e}")

          try:
              if os.path.exists(SEG_MODEL_PATH):
                  models_cache["seg"] = load_segmentation_model()
+                 logging.info("✅ Segmentation model loaded (CPU)")
              else:
                  models_cache["seg"] = None
+                 logging.warning("Segmentation model file not found; skipping seg.")
          except Exception as e:
              models_cache["seg"] = None
+             logging.warning(f"Segmentation model not available: {e}")

      if "cls" not in models_cache:
          try:
              models_cache["cls"] = load_classification_pipeline()
+             logging.info("✅ Classification pipeline loaded (CPU)")
          except Exception as e:
              models_cache["cls"] = None
+             logging.warning(f"Classification pipeline not available: {e}")

      if "embedding_model" not in models_cache:
          try:
              models_cache["embedding_model"] = load_embedding_model()
+             logging.info("✅ Embedding model loaded (CPU)")
          except Exception as e:
              models_cache["embedding_model"] = None
+             logging.warning(f"Embedding model not available: {e}")

  def setup_knowledge_base() -> None:
+     """Load PDFs and create FAISS vector store (optional)."""
      if "vector_store" in knowledge_base_cache:
          return
+
+     docs = []
      try:
          PyPDFLoader = _import_langchain_pdf()
          for pdf in GUIDELINE_PDFS:
              if os.path.exists(pdf):
                  try:
+                     loader = PyPDFLoader(pdf)
+                     docs.extend(loader.load())
                      logging.info(f"Loaded PDF: {pdf}")
                  except Exception as e:
+                     logging.warning(f"Failed to load PDF {pdf}: {e}")
      except Exception as e:
          logging.warning(f"LangChain PDF loader unavailable: {e}")

          try:
              from langchain.text_splitter import RecursiveCharacterTextSplitter
              FAISS = _import_langchain_faiss()
+             splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+             chunks = splitter.split_documents(docs)
              knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
+             logging.info(f"✅ Knowledge base ready with {len(chunks)} chunks")
          except Exception as e:
              knowledge_base_cache["vector_store"] = None
+             logging.warning(f"Knowledge base unavailable: {e}")
      else:
          knowledge_base_cache["vector_store"] = None
+         logging.warning("Knowledge base disabled (no docs or embeddings).")

+ # Initialize on import
  initialize_cpu_models()
  setup_knowledge_base()
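Since both initializers run at import time, the vector store is either ready or explicitly None as soon as the module loads. A minimal sketch of querying it directly, assuming the FAISS build succeeded and using LangChain's standard similarity_search:

vs = knowledge_base_cache.get("vector_store")
if vs is not None:
    # Top-3 guideline chunks for an illustrative query.
    for doc in vs.similarity_search("offloading for diabetic foot ulcers", k=3):
        print(doc.metadata.get("source"), doc.page_content[:120])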

+ # =============== AI PROCESSOR ===============
  class AIProcessor:
      def __init__(self):
          self.models_cache = models_cache
          self.knowledge_base_cache = knowledge_base_cache
+         self.px_per_cm = PIXELS_PER_CM
          self.uploads_dir = UPLOADS_DIR
          self.dataset_id = DATASET_ID
          self.hf_token = HF_TOKEN

+     # ---------- Image utilities ----------
      def _ensure_analysis_dir(self) -> str:
          out_dir = os.path.join(self.uploads_dir, "analysis")
          os.makedirs(out_dir, exist_ok=True)
          return out_dir

      def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
+         """YOLO detect → (optional) Keras seg → (optional) HF classifier → save visuals."""
          try:
              image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

+             det = self.models_cache.get("det")
+             if det is None:
                  raise RuntimeError("YOLO model not loaded")
+
+             # YOLO on CPU
+             results = det.predict(image_cv, verbose=False, device="cpu")
+             if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
+                 raise ValueError("No wound could be detected.")

              box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
              x1, y1, x2, y2 = [int(v) for v in box]
              x1, y1 = max(0, x1), max(0, y1)
              x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
+             detected_region_cv = image_cv[y1:y2, x1:x2]

+             # Optional segmentation
+             seg_model = self.models_cache.get("seg")
+             length = breadth = area = 0.0
+             seg_path = None
+             if seg_model is not None and detected_region_cv.size > 0:
+                 try:
+                     input_size = seg_model.input_shape[1:3]
+                     resized = cv2.resize(detected_region_cv, (input_size[1], input_size[0]))
+                     mask_pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
+                     mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
+
+                     contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+                     if contours:
+                         cnt = max(contours, key=cv2.contourArea)
+                         x, y, w, h = cv2.boundingRect(cnt)
+                         length = round(h / self.px_per_cm, 2)
+                         breadth = round(w / self.px_per_cm, 2)
+                         area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
+
+                     # overlay visualization
+                     mask_resized = cv2.resize(
+                         mask_np * 255,
+                         (detected_region_cv.shape[1], detected_region_cv.shape[0]),
+                         interpolation=cv2.INTER_NEAREST,
+                     )
+                     overlay = detected_region_cv.copy()
+                     overlay[mask_resized > 127] = [0, 0, 255]
+                     seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
+
+                     ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+                     out_dir = self._ensure_analysis_dir()
+                     seg_path = os.path.join(out_dir, f"segmentation_{ts}.png")
+                     cv2.imwrite(seg_path, seg_vis)
+                 except Exception as e:
+                     logging.warning(f"Segmentation step skipped: {e}")

+             # Optional classification
              wound_type = "Unknown"
              cls_pipe = self.models_cache.get("cls")
              if cls_pipe is not None:
                  try:
+                     detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
+                     preds = cls_pipe(detected_image_pil)
                      if preds:
                          wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
                  except Exception as e:
+                     logging.warning(f"Classification step failed: {e}")
+
+             # Save detection & original
+             out_dir = self._ensure_analysis_dir()
+             ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+             det_vis = image_cv.copy()
+             cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+             det_path = os.path.join(out_dir, f"detection_{ts}.png")
+             cv2.imwrite(det_path, det_vis)
+
+             original_path = os.path.join(out_dir, f"original_{ts}.png")
+             cv2.imwrite(original_path, image_cv)

              return {
                  "wound_type": wound_type,
+                 "length_cm": length,
+                 "breadth_cm": breadth,
+                 "surface_area_cm2": area,
                  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
+                 if getattr(results[0].boxes, "conf", None) is not None
+                 else 0.0,
+                 "detection_image_path": det_path,
+                 "segmentation_image_path": seg_path,
                  "original_image_path": original_path,
              }
          except Exception as e:
+             logging.error(f"Visual analysis failed: {e}")
              raise

      def query_guidelines(self, query: str) -> str:
+         """Query the knowledge base (optional)."""
          try:
              vs = self.knowledge_base_cache.get("vector_store")
              if not vs:
                  return "Knowledge base is not available."
+             # support both old and new retriever APIs
              try:
                  retriever = vs.as_retriever(search_kwargs={"k": 5})
+                 docs = retriever.get_relevant_documents(query)  # legacy retriever API
              except Exception:
                  retriever = vs.as_retriever(search_kwargs={"k": 5})
+                 # newer runnable invoke API
                  docs = retriever.invoke(query)
              lines: List[str] = []
              for d in docs:

              logging.warning(f"Guidelines query failed: {e}")
              return f"Guidelines query failed: {str(e)}"

+     # ---------- Report builders ----------
      def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
+         """Plaintext/markdown fallback when MedGemma is unavailable."""
          return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
  ## 📋 Patient Information
  {patient_info}

  - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
  ## 📊 Analysis Images
  - **Original**: {visual_results.get('original_image_path', 'N/A')}
  - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
  - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
  ## 🎯 Clinical Summary
  Automated analysis provides quantitative measurements; verify via clinical examination.
  ## 💊 Recommendations
  - Debride necrotic tissue if indicated (clinical decision)
  - Document with serial photos and measurements
  ## 📅 Monitoring
+ - Daily in week 1, then every 2-3 days (or as indicated)
  - Weekly progress review
  ## 📚 Guideline Context
+ {(guideline_context or '')[:800]}{'...' if guideline_context and len(guideline_context) > 800 else ''}
  **Disclaimer:** Automated, for decision support only. Verify clinically.
  """

          image_pil: Image.Image,
          max_new_tokens: Optional[int] = None,
      ) -> str:
+         """Try MedGemma (GPU) → fallback report."""
          try:
+             report = generate_medgemma_report_with_timeout(
                  patient_info, visual_results, guideline_context, image_pil, max_new_tokens
              )
              if report and report.strip() and not report.startswith(("⚠️", "❌")):

              logging.error(f"Report generation failed: {e}")
              return self._generate_fallback_report(patient_info, visual_results, guideline_context)

+     # ---------- HF dataset commit ----------
      def save_and_commit_image(self, image_pil: Image.Image) -> str:
+         """Save image locally and optionally upload to HF dataset."""
          try:
              os.makedirs(self.uploads_dir, exist_ok=True)
              ts = datetime.now().strftime("%Y%m%d_%H%M%S")

              image_pil.convert("RGB").save(path)
              logging.info(f"✅ Image saved locally: {path}")

+             if self.hf_token and self.dataset_id:
                  try:
                      HfApi, HfFolder = _import_hf_hub()
+                     HfFolder.save_token(self.hf_token)
                      api = HfApi()
                      api.upload_file(
                          path_or_fileobj=path,
                          path_in_repo=f"images/{filename}",
+                         repo_id=self.dataset_id,
                          repo_type="dataset",
+                         token=self.hf_token,
                          commit_message=f"Upload wound image: {filename}",
                      )
                      logging.info("✅ Image committed to HF dataset")

              logging.error(f"Failed to save/commit image: {e}")
              return ""

+     # ---------- Orchestrator ----------
      def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
+         """End-to-end analysis with robust fallbacks."""
          try:
              saved_path = self.save_and_commit_image(image_pil)
+
              visual_results = self.perform_visual_analysis(image_pil)

+             # Patient info summary text
              pi = questionnaire_data or {}
              patient_info = (
+                 f"Age: {pi.get('age', 'N/A')}, "
+                 f"Diabetic: {pi.get('diabetic', 'N/A')}, "
+                 f"Allergies: {pi.get('allergies', 'N/A')}, "
+                 f"Date of Wound: {pi.get('date_of_injury', 'N/A')}, "
+                 f"Professional Care: {pi.get('professional_care', 'N/A')}, "
+                 f"Oozing/Bleeding: {pi.get('oozing_bleeding', 'N/A')}, "
+                 f"Infection: {pi.get('infection', 'N/A')}, "
+                 f"Moisture: {pi.get('moisture', 'N/A')}"
              )

+             # Query guidelines
              query = (
                  f"best practices for managing a {visual_results.get('wound_type','Unknown')} "
                  f"with moisture '{pi.get('moisture','unknown')}' and infection '{pi.get('infection','unknown')}' "
              )
              guideline_context = self.query_guidelines(query)

+             # Generate final report
+             report = self.generate_final_report(patient_info=patient_info,
+                                                 visual_results=visual_results,
+                                                 guideline_context=guideline_context,
+                                                 image_pil=image_pil)

              return {
                  "success": True,
                  "visual_analysis": visual_results,
                  "report": report,
                  "saved_image_path": saved_path,
+                 "guideline_context": (guideline_context or "")[:500] + ("..." if guideline_context and len(guideline_context) > 500 else ""),
              }
          except Exception as e:
              logging.error(f"Pipeline error: {e}")

              }

      def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
+         """Public entrypoint used by the UI."""
          try:
              if isinstance(image, str):
                  if not os.path.exists(image):

              "report": f"Analysis initialization failed: {str(e)}",
              "saved_image_path": None,
              "guideline_context": "",
+         }
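A minimal usage sketch of the public entrypoint, assuming the module imports as src.ai_processor and a wound photo exists at the hypothetical path uploads/sample.jpg:

from src.ai_processor import AIProcessor

processor = AIProcessor()
result = processor.analyze_wound(
    "uploads/sample.jpg",  # a file path or a PIL image both work
    {"age": 54, "diabetic": "yes", "moisture": "moderate", "infection": "no"},
)
if result["success"]:
    print(result["visual_analysis"]["wound_type"], result["report"][:400])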