Update src/ai_processor.py

src/ai_processor.py  CHANGED  +97 -251
@@ -16,7 +16,7 @@ import cv2
 import numpy as np
 from PIL import Image
 from PIL.ExifTags import TAGS
-
+
 # --- Logging config ---
 logging.basicConfig(
     level=getattr(logging, LOGLEVEL, logging.INFO),
@@ -26,6 +26,12 @@ logging.basicConfig(
 def _log_kv(prefix: str, kv: Dict):
     logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))
 
+# --- Spaces GPU decorator (REQUIRED) ---
+from spaces import GPU as _SPACES_GPU
+
+@_SPACES_GPU(enable_queue=True)
+def smartheal_gpu_stub(ping: int = 0) -> str:
+    return "ready"
 
 # ---- Paths / constants ----
 UPLOADS_DIR = "uploads"
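The new stub simply registers a GPU-decorated function at import time. As a point of reference, a minimal sketch of how a decorated helper behaves under the standard `spaces` ZeroGPU API (this sketch is not code from the commit; the decorator arguments here are assumptions):

import spaces
import torch

@spaces.GPU  # the body only sees a CUDA device while a call is in flight on a GPU worker
def cuda_smoke_test() -> str:
    return f"cuda_available={torch.cuda.is_available()}"

# Wired to a Gradio event handler, the call is queued onto a GPU worker:
# print(cuda_smoke_test())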
@@ -33,7 +39,7 @@ os.makedirs(UPLOADS_DIR, exist_ok=True)
 
 HF_TOKEN = os.getenv("HF_TOKEN", None)
 YOLO_MODEL_PATH = "src/best.pt"
-SEG_MODEL_PATH = "src/
+SEG_MODEL_PATH = "src/segmentation_model.h5"  # optional; legacy .h5 supported
 GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
 DATASET_ID = "SmartHeal/wound-image-uploads"
 DEFAULT_PX_PER_CM = 38.0
@@ -117,10 +123,8 @@ SMARTHEAL_USER_PREFIX = """\
 Patient: {patient_info}
 Visual findings: type={wound_type}, size={length_cm}x{breadth_cm} cm, area={area_cm2} cm^2,
 detection_conf={det_conf:.2f}, calibration={px_per_cm} px/cm.
-
 Guideline context (snippets you can draw principles from; do not quote at length):
 {guideline_context}
-
 Write a structured answer with these headings exactly:
 1. Clinical Summary (max 4 bullet points)
 2. Likely Stage/Type (if uncertain, say 'uncertain')
@@ -128,238 +132,53 @@ Write a structured answer with these headings exactly:
 4. Red Flags (what to escalate and when)
 5. Follow-up Cadence (days)
 6. Notes (assumptions/uncertainties)
-
 Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
 """
 
-
-
+# ---------- MedGemma-only text generator ----------
+@_SPACES_GPU(enable_queue=True)
+def _medgemma_generate_gpu(prompt: str, model_id: str, max_new_tokens: int, token: Optional[str]):
     """
-    Runs entirely inside a Spaces GPU worker.
-    Safe for:
-      - CUDA device selection (no 'Invalid device id')
-      - BF16/FP16 choice via compute capability
-      - LLaVA processors with patch_size=None
-      - Processors WITHOUT a chat template (fallback to plain/LLaVA-style prompt)
+    Runs entirely inside a Spaces GPU worker. Uses Med-Gemma (text-only) to draft the report.
     """
-    import logging
     import torch
-    from
-    from transformers import (
-        AutoProcessor,
-        AutoModelForVision2Seq,
-        StoppingCriteria,
-        StoppingCriteriaList,
-    )
+    from transformers import pipeline
 
-
-
-
-
-
-        idx = 0
-        try:
-            torch.cuda.set_device(idx)
-        except Exception as e:
-            logging.warning(f"torch.cuda.set_device({idx}) failed: {e}; falling back to CPU.")
-            return "cpu", torch.float32
-        device = f"cuda:{idx}"
-        try:
-            props = torch.cuda.get_device_properties(idx)
-            cc = props.major * 10 + props.minor
-            dtype = torch.bfloat16 if cc >= 80 else torch.float16
-        except Exception as e:
-            logging.warning(f"Could not query CUDA props: {e}; defaulting to float16.")
-            dtype = torch.float16
-        return device, dtype
-
-    device, torch_dtype = _pick_device_and_dtype()
-
-    # -------- Load model & processor --------
-    model = AutoModelForVision2Seq.from_pretrained(
-        model_id,
-        torch_dtype=torch_dtype,
-        trust_remote_code=True,
-        low_cpu_mem_usage=True,
-        token=token,
-    ).to(device)
-    model.eval()
-
-    processor = AutoProcessor.from_pretrained(
-        model_id, trust_remote_code=True, token=token
+    pipe = pipeline(
+        "image-text-to-text",
+        model="google/medgemma-4b-it",
+        torch_dtype=torch.bfloat16,
+        device="cuda",
     )
-
-
-
-    text_prompt = ""
-    for m in messages:
-        if m.get("role") == "user":
-            for c in m.get("content", []):
-                if c.get("type") == "image":
-                    image_obj = c.get("image")
-                elif c.get("type") == "text":
-                    text_prompt = c.get("text", "")
-            break
-    if image_obj is None:
-        raise ValueError("No image found in messages for VLM inference.")
-
-    # -------- Normalize image to PIL --------
-    from PIL import Image
-    import numpy as np
-    def _to_pil(x):
-        if isinstance(x, Image.Image):
-            return x.convert("RGB")
-        if isinstance(x, str):
-            return Image.open(x).convert("RGB")
-        if isinstance(x, np.ndarray):
-            if x.ndim == 2:
-                x = np.stack([x]*3, axis=-1)
-            if x.dtype != np.uint8:
-                x = x.astype(np.uint8)
-            return Image.fromarray(x, "RGB")
-        if hasattr(x, "read"):
-            return Image.open(x).convert("RGB")
-        raise TypeError(f"Unsupported image type: {type(x)}")
-    image_pil = _to_pil(image_obj)
-
-    # -------- Ensure patch_size for LLaVA processors --------
-    def _ensure_patch_size(proc, mdl):
-        ps = getattr(proc, "patch_size", None)
-        if not ps:
-            candidates = [
-                getattr(getattr(mdl, "vision_tower", None), "config", None),
-                getattr(mdl.config, "vision_config", None),
-                getattr(proc, "image_processor", None),
-                getattr(getattr(proc, "image_processor", None), "config", None),
-            ]
-            for obj in candidates:
-                if obj is None:
-                    continue
-                maybe = getattr(obj, "patch_size", None)
-                if maybe:
-                    ps = int(maybe); break
-        if not ps:
-            ps = 14  # safe default for ViT-L/14-style
-        try:
-            setattr(proc, "patch_size", ps)
-        except Exception:
-            pass
-        return ps
-    _ensure_patch_size(processor, model)
-
-    # -------- Build text (chat-template only if it truly exists) --------
-    # Some processors expose apply_chat_template but tokenizer has no template → ValueError. Guard it.
-    tokenizer = getattr(processor, "tokenizer", None)
-    has_template = bool(getattr(tokenizer, "chat_template", None))
-    used_chat_template = False
-
-    def _looks_like_llava():
-        name = processor.__class__.__name__.lower()
-        mid = (model_id or "").lower()
-        return ("llava" in name) or ("llava" in mid)
-
-    if hasattr(processor, "apply_chat_template") and has_template:
-        try:
-            chat = [{
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": image_pil},
-                    {"type": "text", "text": text_prompt or "Describe the image."},
-                ],
-            }]
-            text_for_model = processor.apply_chat_template(
-                chat, add_generation_prompt=True, tokenize=False
-            )
-            used_chat_template = True
-        except Exception as e:
-            logging.info(f"No usable chat template ({e}); falling back to plain prompt.")
-            text_for_model = (
-                f"USER: <image>\n{text_prompt or 'Describe the image.'}\nASSISTANT:"
-                if _looks_like_llava() else (text_prompt or "Describe the image.")
-            )
-    else:
-        text_for_model = (
-            f"USER: <image>\n{text_prompt or 'Describe the image.'}\nASSISTANT:"
-            if _looks_like_llava() else (text_prompt or "Describe the image.")
-        )
-
-    # -------- Tokenize --------
-    inputs = processor(
-        text=[text_for_model],
-        images=[image_pil],
-        return_tensors="pt",
-        padding=True,
-    ).to(device)
-
-    # -------- Stopping criteria --------
-    class EosTokenCriteria(StoppingCriteria):
-        def __init__(self, eos_token_ids: List[int]):
-            import torch as _t
-            self.eos = _t.tensor(eos_token_ids, dtype=_t.long)
-        def __call__(self, input_ids, scores, **kwargs) -> bool:
-            import torch as _t
-            last_tok = input_ids[:, -1]
-            return _t.isin(last_tok, self.eos.to(last_tok.device)).any().item()
-
-    eos_ids: List[int] = []
-    if tokenizer is not None:
-        for attr in ("eos_token_id", "eot_token_id"):
-            v = getattr(tokenizer, attr, None)
-            if v is None: continue
-            eos_ids.extend([v] if isinstance(v, int) else list(v))
-    if not eos_ids:
-        cfg = getattr(model, "generation_config", None)
-        if cfg and getattr(cfg, "eos_token_id", None) is not None:
-            eos_ids = [cfg.eos_token_id]
-        else:
-            eos_ids = [2]
-    stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_ids)])
-
-    if tokenizer is not None and getattr(tokenizer, "pad_token_id", None) is None:
-        try: tokenizer.pad_token_id = eos_ids[0]
-        except Exception: pass
-
-    # -------- Generate --------
-    gen_kwargs = dict(
-        max_new_tokens=int(max_new_tokens or 256),
+    out = pipe(
+        prompt,
+        max_new_tokens=max_new_tokens,
         do_sample=False,
-
-
-        pad_token_id=getattr(tokenizer, "pad_token_id", None) if tokenizer else None,
+        temperature=0.2,
+        return_full_text=True,
     )
-
-
-
-
-
-    if "input_ids" in inputs:
-        cut = inputs["input_ids"].shape[-1]
-        seq = seq[cut:]
-    if tokenizer is not None:
-        text_out = tokenizer.decode(seq, skip_special_tokens=True)
-    elif hasattr(processor, "batch_decode"):
-        text_out = processor.batch_decode(seq.unsqueeze(0), skip_special_tokens=True)[0]
-    else:
-        text_out = str(seq.tolist())
-
-    return text_out.strip()
-
+    text = (out[0].get("generated_text") if isinstance(out, list) else out).strip()
+    # Remove the prompt echo if present
+    if text.startswith(prompt):
+        text = text[len(prompt):].lstrip()
+    return text or "⚠️ Empty response"
 
-def generate_medgemma_report(
+def generate_medgemma_report(  # kept name so callers don't change
     patient_info: str,
     visual_results: Dict,
     guideline_context: str,
-    image_pil: Image.Image,
+    image_pil: Image.Image,  # kept for signature compatibility; not used by MedGemma
     max_new_tokens: Optional[int] = None,
 ) -> str:
     """
-    MedGemma
-
+    MedGemma (text-only) report generation.
+    The image is analyzed by the vision pipeline; MedGemma formats clinical guidance text.
     """
     if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
         return "⚠️ VLM disabled"
 
-
+    # Default to a public Med-Gemma instruction-tuned model (update via env if you have access to another).
+    model_id = os.getenv("SMARTHEAL_MEDGEMMA_MODEL", "google/med-gemma-2-2b-it")
     max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
 
     uprompt = SMARTHEAL_USER_PREFIX.format(
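Callers keep the same entry point; only the backend changes from the local VLM path to the MedGemma text pipeline. A hedged usage sketch (argument values invented for illustration; the dict keys mirror those read elsewhere in this file):

from PIL import Image

report = generate_medgemma_report(
    patient_info="62-year-old with diabetes, plantar wound for 3 weeks",
    visual_results={
        "wound_type": "ulcer", "length_cm": 2.1, "breadth_cm": 1.4,
        "surface_area_cm2": 2.3, "detection_confidence": 0.87, "px_per_cm": 38.0,
    },
    guideline_context="(retrieved guideline snippets)",
    image_pil=Image.new("RGB", (224, 224)),  # accepted for compatibility, ignored by the text-only path
    max_new_tokens=400,
)
print(report)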
@@ -373,28 +192,69 @@ def generate_medgemma_report(
         guideline_context=(guideline_context or "")[:900],
     )
 
-    #
-
-        {"role": "system", "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}]},
-        {"role": "user", "content": [
-            {"type": "image", "image": image_pil},
-            {"type": "text", "text": uprompt},
-        ]},
-    ]
+    # Compose a single text prompt
+    prompt = f"{SMARTHEAL_SYSTEM_PROMPT}\n\n{uprompt}\n\nAnswer:"
 
     try:
-        return
+        return _medgemma_generate_gpu(prompt, model_id, max_new_tokens, HF_TOKEN)
     except Exception as e:
-        logging.error(f"
-        return
+        logging.error(f"MedGemma call failed: {e}")
+        return "⚠️ VLM error"
+
+# ---------- Input-shape helpers (avoid `.as_list()` on strings) ----------
+def _shape_to_hw(shape) -> Tuple[Optional[int], Optional[int]]:
+    try:
+        if hasattr(shape, "as_list"):
+            shape = shape.as_list()
+    except Exception:
+        pass
+    if isinstance(shape, (tuple, list)):
+        if len(shape) == 4:    # (None, H, W, C)
+            H, W = shape[1], shape[2]
+        elif len(shape) == 3:  # (H, W, C)
+            H, W = shape[0], shape[1]
+        else:
+            return (None, None)
+        try: H = int(H) if (H is not None and str(H).lower() != "none") else None
+        except Exception: H = None
+        try: W = int(W) if (W is not None and str(W).lower() != "none") else None
+        except Exception: W = None
+        return (H, W)
+    return (None, None)
+
+def _get_model_input_hw(model, default_hw: Tuple[int, int] = (224, 224)) -> Tuple[int, int]:
+    H, W = _shape_to_hw(getattr(model, "input_shape", None))
+    if H and W:
+        return H, W
+    try:
+        inputs = getattr(model, "inputs", None)
+        if inputs:
+            H, W = _shape_to_hw(inputs[0].shape)
+            if H and W:
+                return H, W
+    except Exception:
+        pass
+    try:
+        cfg = model.get_config() if hasattr(model, "get_config") else None
+        if isinstance(cfg, dict):
+            for layer in cfg.get("layers", []):
+                conf = (layer or {}).get("config", {})
+                cand = conf.get("batch_input_shape") or conf.get("batch_shape")
+                H, W = _shape_to_hw(cand)
+                if H and W:
+                    return H, W
+    except Exception:
+        pass
+    logging.warning(f"Could not resolve model input shape; using default {default_hw}.")
+    return default_hw
+
 # ---------- Initialize CPU models ----------
 def load_yolo_model():
     YOLO = _import_ultralytics()
-    # Construct model with CUDA masked to avoid auto-selecting cuda:0
     with _no_cuda_env():
         model = YOLO(YOLO_MODEL_PATH)
     return model
-
+
 def load_segmentation_model():
     import os; os.environ.setdefault("KERAS_BACKEND","tensorflow")
     import tensorflow as tf; tf.config.set_visible_devices([], "GPU")
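A quick sanity check of the new helpers (behavior follows directly from the definitions above), showing why calling `.as_list()` on a string previously failed and when the 224x224 fallback applies:

print(_shape_to_hw((None, 224, 224, 3)))    # (224, 224)
print(_shape_to_hw((128, 128, 3)))          # (128, 128)
print(_shape_to_hw("functional"))           # (None, None): a string is not a shape
print(_shape_to_hw((None, None, None, 3)))  # (None, None): caller falls back to default_hw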
@@ -428,11 +288,11 @@ def initialize_cpu_models() -> None:
     if "seg" not in models_cache:
         try:
             if os.path.exists(SEG_MODEL_PATH):
-
-
-
+                m = load_segmentation_model()  # uses global path by default
+                models_cache["seg"] = m
+                th, tw = _get_model_input_hw(m, default_hw=(224, 224))
                 oshape = getattr(m, "output_shape", None)
-                logging.info(f"✅ Segmentation model loaded (CPU) |
+                logging.info(f"✅ Segmentation model loaded (CPU) | input_hw=({th},{tw}) output_shape={oshape}")
             else:
                 models_cache["seg"] = None
                 logging.warning("Segmentation model file missing; skipping.")
@@ -650,11 +510,7 @@ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndar
     # --- Model path ---
     if seg_model is not None:
         try:
-
-            if not ishape or len(ishape) < 4:
-                raise ValueError(f"Bad seg input_shape: {ishape}")
-            th, tw = int(ishape[1]), int(ishape[2])
-
+            th, tw = _get_model_input_hw(seg_model, default_hw=(224, 224))
             x = _preprocess_for_seg(image_bgr, (th, tw))
             roi_seen_path = None
             if SMARTHEAL_DEBUG:
@@ -745,7 +601,7 @@ def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float,
     cnt = max(contours, key=cv2.contourArea)
     rect = cv2.minAreaRect(cnt)
     (w_px, h_px) = rect[1]
-    length_px, breadth_px = (max(w_px, h_px), min(
+    length_px, breadth_px = (max(w_px, h_px), min(h_px, w_px))
     length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
     breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)
     box = cv2.boxPoints(rect).astype(int)
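The corrected line completes the previously truncated tuple so that length is always the larger side of the rotated rectangle. A standalone sketch of the same measurement on a synthetic mask (OpenCV 4 signature assumed; the numbers are illustrative only):

import cv2
import numpy as np

mask01 = np.zeros((200, 200), np.uint8)
cv2.rectangle(mask01, (50, 80), (150, 120), 1, -1)          # fill a ~100 px by ~40 px region
contours, _ = cv2.findContours(mask01, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
rect = cv2.minAreaRect(max(contours, key=cv2.contourArea))
(w_px, h_px) = rect[1]
length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))  # the corrected pairing
px_per_cm = 38.0                                            # DEFAULT_PX_PER_CM
print(round(length_px / px_per_cm, 2), round(breadth_px / px_per_cm, 2))  # roughly 2.6 and 1.1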
@@ -1004,7 +860,6 @@ class AIProcessor:
         if not vs:
             return "Knowledge base is not available."
         retriever = vs.as_retriever(search_kwargs={"k": 5})
-        # Modern API (avoid get_relevant_documents deprecation)
         docs = retriever.invoke(query)
         lines: List[str] = []
         for d in docs:
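The deleted comment referred to LangChain's runnable retriever interface, which the surviving call already uses. A minimal sketch of that call pattern (assumes a vector store `vs` has been built; the query string is illustrative):

retriever = vs.as_retriever(search_kwargs={"k": 5})
docs = retriever.invoke("offloading for diabetic foot ulcer")  # returns a list of Documents
snippets = [d.page_content[:300] for d in docs]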
@@ -1018,38 +873,30 @@ class AIProcessor:
 
     def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
         return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
-
 ## Patient Information
 {patient_info}
-
 ## Visual Analysis Results
 - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
 - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
 - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
 - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
 - **Calibration**: {visual_results.get('px_per_cm','?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used','default')})
-
 ## Analysis Images
 - **Original**: {visual_results.get('original_image_path', 'N/A')}
 - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
 - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
 - **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}
-
 ## 🎯 Clinical Summary
 Automated analysis provides quantitative measurements; verify via clinical examination.
-
 ## Recommendations
 - Cleanse wound gently; select dressing per exudate/infection risk
 - Debride necrotic tissue if indicated (clinical decision)
 - Document with serial photos and measurements
-
 ## 📅 Monitoring
 - Daily in week 1, then every 2–3 days (or as indicated)
 - Weekly progress review
-
 ## Guideline Context
 {(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}
-
 **Disclaimer:** Automated, for decision support only. Verify clinically.
 """
 
@@ -1103,8 +950,7 @@ Automated analysis provides quantitative measurements; verify via clinical examination.
         except Exception as e:
             logging.error(f"Failed to save/commit image: {e}")
             return ""
-
-
+
     def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
         try:
             saved_path = self.save_and_commit_image(image_pil)
@@ -1150,7 +996,7 @@ Automated analysis provides quantitative measurements; verify via clinical examination.
                 "saved_image_path": None,
                 "guideline_context": "",
             }
-
+
     def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
         try:
             if isinstance(image, str):
@@ -1174,4 +1020,4 @@ Automated analysis provides quantitative measurements; verify via clinical examination.
                 "report": f"Analysis initialization failed: {str(e)}",
                 "saved_image_path": None,
                 "guideline_context": "",
-        }
+        }