SmartHeal committed · verified
Commit acd4594 · 1 Parent(s): 49e5ed3

Update src/ai_processor.py

Files changed (1): src/ai_processor.py (+97 -251)

src/ai_processor.py CHANGED
@@ -16,7 +16,7 @@ import cv2
 import numpy as np
 from PIL import Image
 from PIL.ExifTags import TAGS
-import spaces
+
 # --- Logging config ---
 logging.basicConfig(
     level=getattr(logging, LOGLEVEL, logging.INFO),
@@ -26,6 +26,12 @@ logging.basicConfig(
 def _log_kv(prefix: str, kv: Dict):
     logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))
 
+# --- Spaces GPU decorator (REQUIRED) ---
+from spaces import GPU as _SPACES_GPU
+
+@_SPACES_GPU(enable_queue=True)
+def smartheal_gpu_stub(ping: int = 0) -> str:
+    return "ready"
 
 # ---- Paths / constants ----
 UPLOADS_DIR = "uploads"
@@ -33,7 +39,7 @@ os.makedirs(UPLOADS_DIR, exist_ok=True)
 
 HF_TOKEN = os.getenv("HF_TOKEN", None)
 YOLO_MODEL_PATH = "src/best.pt"
-SEG_MODEL_PATH = "src/segmentation_model_fixed.h5"  # optional
+SEG_MODEL_PATH = "src/segmentation_model.h5"  # optional; legacy .h5 supported
 GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
 DATASET_ID = "SmartHeal/wound-image-uploads"
 DEFAULT_PX_PER_CM = 38.0
@@ -117,10 +123,8 @@ SMARTHEAL_USER_PREFIX = """\
 Patient: {patient_info}
 Visual findings: type={wound_type}, size={length_cm}x{breadth_cm} cm, area={area_cm2} cm^2,
 detection_conf={det_conf:.2f}, calibration={px_per_cm} px/cm.
-
 Guideline context (snippets you can draw principles from; do not quote at length):
 {guideline_context}
-
 Write a structured answer with these headings exactly:
 1. Clinical Summary (max 4 bullet points)
 2. Likely Stage/Type (if uncertain, say 'uncertain')
@@ -128,238 +132,53 @@ Write a structured answer with these headings exactly:
 4. Red Flags (what to escalate and when)
 5. Follow-up Cadence (days)
 6. Notes (assumptions/uncertainties)
-
 Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
 """
 
-
-def _vlm_infer_gpu(messages, model_id: str, max_new_tokens: int, token: Optional[str]):
+# ---------- MedGemma-only text generator ----------
+@_SPACES_GPU(enable_queue=True)
+def _medgemma_generate_gpu(prompt: str, model_id: str, max_new_tokens: int, token: Optional[str]):
     """
-    Runs entirely inside a Spaces GPU worker. It's the ONLY place we allow CUDA init.
-    Safe for:
-      - CUDA device selection (no 'Invalid device id')
-      - BF16/FP16 choice via compute capability
-      - LLaVA processors with patch_size=None
-      - Processors WITHOUT a chat template (fallback to plain/LLaVA-style prompt)
+    Runs entirely inside a Spaces GPU worker. Uses Med-Gemma (text-only) to draft the report.
     """
-    import logging
     import torch
-    from typing import Optional, List
-    from transformers import (
-        AutoProcessor,
-        AutoModelForVision2Seq,
-        StoppingCriteria,
-        StoppingCriteriaList,
-    )
+    from transformers import pipeline
 
-    # -------- Device & dtype (robust) --------
-    def _pick_device_and_dtype():
-        if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
-            logging.warning("CUDA not available; using CPU.")
-            return "cpu", torch.float32
-        idx = 0
-        try:
-            torch.cuda.set_device(idx)
-        except Exception as e:
-            logging.warning(f"torch.cuda.set_device({idx}) failed: {e}; falling back to CPU.")
-            return "cpu", torch.float32
-        device = f"cuda:{idx}"
-        try:
-            props = torch.cuda.get_device_properties(idx)
-            cc = props.major * 10 + props.minor
-            dtype = torch.bfloat16 if cc >= 80 else torch.float16
-        except Exception as e:
-            logging.warning(f"Could not query CUDA props: {e}; defaulting to float16.")
-            dtype = torch.float16
-        return device, dtype
-
-    device, torch_dtype = _pick_device_and_dtype()
-
-    # -------- Load model & processor --------
-    model = AutoModelForVision2Seq.from_pretrained(
-        model_id,
-        torch_dtype=torch_dtype,
-        trust_remote_code=True,
-        low_cpu_mem_usage=True,
-        token=token,
-    ).to(device)
-    model.eval()
-
-    processor = AutoProcessor.from_pretrained(
-        model_id, trust_remote_code=True, token=token
+    pipe = pipeline(
+        "image-text-to-text",
+        model="google/medgemma-4b-it",
+        torch_dtype=torch.bfloat16,
+        device="cuda",
     )
-
-    # -------- Extract image & text --------
-    image_obj = None
-    text_prompt = ""
-    for m in messages:
-        if m.get("role") == "user":
-            for c in m.get("content", []):
-                if c.get("type") == "image":
-                    image_obj = c.get("image")
-                elif c.get("type") == "text":
-                    text_prompt = c.get("text", "")
-            break
-    if image_obj is None:
-        raise ValueError("No image found in messages for VLM inference.")
-
-    # -------- Normalize image to PIL --------
-    from PIL import Image
-    import numpy as np
-    def _to_pil(x):
-        if isinstance(x, Image.Image):
-            return x.convert("RGB")
-        if isinstance(x, str):
-            return Image.open(x).convert("RGB")
-        if isinstance(x, np.ndarray):
-            if x.ndim == 2:
-                x = np.stack([x]*3, axis=-1)
-            if x.dtype != np.uint8:
-                x = x.astype(np.uint8)
-            return Image.fromarray(x, "RGB")
-        if hasattr(x, "read"):
-            return Image.open(x).convert("RGB")
-        raise TypeError(f"Unsupported image type: {type(x)}")
-    image_pil = _to_pil(image_obj)
-
-    # -------- Ensure patch_size for LLaVA processors --------
-    def _ensure_patch_size(proc, mdl):
-        ps = getattr(proc, "patch_size", None)
-        if not ps:
-            candidates = [
-                getattr(getattr(mdl, "vision_tower", None), "config", None),
-                getattr(mdl.config, "vision_config", None),
-                getattr(proc, "image_processor", None),
-                getattr(getattr(proc, "image_processor", None), "config", None),
-            ]
-            for obj in candidates:
-                if obj is None:
-                    continue
-                maybe = getattr(obj, "patch_size", None)
-                if maybe:
-                    ps = int(maybe); break
-        if not ps:
-            ps = 14  # safe default for ViT-L/14-style
-        try:
-            setattr(proc, "patch_size", ps)
-        except Exception:
-            pass
-        return ps
-    _ensure_patch_size(processor, model)
-
-    # -------- Build text (chat-template only if it truly exists) --------
-    # Some processors expose apply_chat_template but tokenizer has no template → ValueError. Guard it.
-    tokenizer = getattr(processor, "tokenizer", None)
-    has_template = bool(getattr(tokenizer, "chat_template", None))
-    used_chat_template = False
-
-    def _looks_like_llava():
-        name = processor.__class__.__name__.lower()
-        mid = (model_id or "").lower()
-        return ("llava" in name) or ("llava" in mid)
-
-    if hasattr(processor, "apply_chat_template") and has_template:
-        try:
-            chat = [{
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": image_pil},
-                    {"type": "text", "text": text_prompt or "Describe the image."},
-                ],
-            }]
-            text_for_model = processor.apply_chat_template(
-                chat, add_generation_prompt=True, tokenize=False
-            )
-            used_chat_template = True
-        except Exception as e:
-            logging.info(f"No usable chat template ({e}); falling back to plain prompt.")
-            text_for_model = (
-                f"USER: <image>\n{text_prompt or 'Describe the image.'}\nASSISTANT:"
-                if _looks_like_llava() else (text_prompt or "Describe the image.")
-            )
-    else:
-        text_for_model = (
-            f"USER: <image>\n{text_prompt or 'Describe the image.'}\nASSISTANT:"
-            if _looks_like_llava() else (text_prompt or "Describe the image.")
-        )
-
-    # -------- Tokenize --------
-    inputs = processor(
-        text=[text_for_model],
-        images=[image_pil],
-        return_tensors="pt",
-        padding=True,
-    ).to(device)
-
-    # -------- Stopping criteria --------
-    class EosTokenCriteria(StoppingCriteria):
-        def __init__(self, eos_token_ids: List[int]):
-            import torch as _t
-            self.eos = _t.tensor(eos_token_ids, dtype=_t.long)
-        def __call__(self, input_ids, scores, **kwargs) -> bool:
-            import torch as _t
-            last_tok = input_ids[:, -1]
-            return _t.isin(last_tok, self.eos.to(last_tok.device)).any().item()
-
-    eos_ids: List[int] = []
-    if tokenizer is not None:
-        for attr in ("eos_token_id", "eot_token_id"):
-            v = getattr(tokenizer, attr, None)
-            if v is None: continue
-            eos_ids.extend([v] if isinstance(v, int) else list(v))
-    if not eos_ids:
-        cfg = getattr(model, "generation_config", None)
-        if cfg and getattr(cfg, "eos_token_id", None) is not None:
-            eos_ids = [cfg.eos_token_id]
-        else:
-            eos_ids = [2]
-    stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_ids)])
-
-    if tokenizer is not None and getattr(tokenizer, "pad_token_id", None) is None:
-        try: tokenizer.pad_token_id = eos_ids[0]
-        except Exception: pass
-
-    # -------- Generate --------
-    gen_kwargs = dict(
-        max_new_tokens=int(max_new_tokens or 256),
+    out = pipe(
+        prompt,
+        max_new_tokens=max_new_tokens,
         do_sample=False,
-        stopping_criteria=stopping_criteria,
-        eos_token_id=eos_ids[0] if eos_ids else None,
-        pad_token_id=getattr(tokenizer, "pad_token_id", None) if tokenizer else None,
+        temperature=0.2,
+        return_full_text=True,
     )
-    with torch.inference_mode():
-        out = model.generate(**inputs, **gen_kwargs)
-
-    # -------- Decode --------
-    seq = out[0]
-    if "input_ids" in inputs:
-        cut = inputs["input_ids"].shape[-1]
-        seq = seq[cut:]
-    if tokenizer is not None:
-        text_out = tokenizer.decode(seq, skip_special_tokens=True)
-    elif hasattr(processor, "batch_decode"):
-        text_out = processor.batch_decode(seq.unsqueeze(0), skip_special_tokens=True)[0]
-    else:
-        text_out = str(seq.tolist())
-
-    return text_out.strip()
-
+    text = (out[0].get("generated_text") if isinstance(out, list) else out).strip()
+    # Remove the prompt echo if present
+    if text.startswith(prompt):
+        text = text[len(prompt):].lstrip()
+    return text or "⚠️ Empty response"
 
-def generate_medgemma_report(
+def generate_medgemma_report(  # kept name so callers don't change
     patient_info: str,
     visual_results: Dict,
     guideline_context: str,
-    image_pil: Image.Image,
+    image_pil: Image.Image,  # kept for signature compatibility; not used by MedGemma
     max_new_tokens: Optional[int] = None,
 ) -> str:
     """
-    MedGemma replacement using a vision-language model.
-    Loads & runs ONLY inside a GPU worker to satisfy Stateless GPU constraints.
+    MedGemma (text-only) report generation.
+    The image is analyzed by the vision pipeline; MedGemma formats clinical guidance text.
     """
     if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
         return "⚠️ VLM disabled"
 
-    model_id = os.getenv("SMARTHEAL_VLM_MODEL", "bczhou/tiny-llava-v1-hf")
+    # Default to a public Med-Gemma instruction-tuned model (update via env if you have access to another).
+    model_id = os.getenv("SMARTHEAL_MEDGEMMA_MODEL", "google/med-gemma-2-2b-it")
     max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
 
     uprompt = SMARTHEAL_USER_PREFIX.format(
@@ -373,28 +192,69 @@ def generate_medgemma_report(
         guideline_context=(guideline_context or "")[:900],
     )
 
-    # The `messages` structure is passed to the verified `_vlm_infer_gpu` function
-    messages = [
-        {"role": "system", "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}]},
-        {"role": "user", "content": [
-            {"type": "image", "image": image_pil},
-            {"type": "text", "text": uprompt},
-        ]},
-    ]
+    # Compose a single text prompt
+    prompt = f"{SMARTHEAL_SYSTEM_PROMPT}\n\n{uprompt}\n\nAnswer:"
 
     try:
-        return _vlm_infer_gpu(messages, model_id, max_new_tokens, HF_TOKEN)
+        return _medgemma_generate_gpu(prompt, model_id, max_new_tokens, HF_TOKEN)
     except Exception as e:
-        logging.error(f"VLM call failed: {e}", exc_info=True)
-        return f"⚠️ VLM error: {e}"
+        logging.error(f"MedGemma call failed: {e}")
+        return "⚠️ VLM error"
+
+# ---------- Input-shape helpers (avoid `.as_list()` on strings) ----------
+def _shape_to_hw(shape) -> Tuple[Optional[int], Optional[int]]:
+    try:
+        if hasattr(shape, "as_list"):
+            shape = shape.as_list()
+    except Exception:
+        pass
+    if isinstance(shape, (tuple, list)):
+        if len(shape) == 4:  # (None, H, W, C)
+            H, W = shape[1], shape[2]
+        elif len(shape) == 3:  # (H, W, C)
+            H, W = shape[0], shape[1]
+        else:
+            return (None, None)
+        try: H = int(H) if (H is not None and str(H).lower() != "none") else None
+        except Exception: H = None
+        try: W = int(W) if (W is not None and str(W).lower() != "none") else None
+        except Exception: W = None
+        return (H, W)
+    return (None, None)
+
+def _get_model_input_hw(model, default_hw: Tuple[int, int] = (224, 224)) -> Tuple[int, int]:
+    H, W = _shape_to_hw(getattr(model, "input_shape", None))
+    if H and W:
+        return H, W
+    try:
+        inputs = getattr(model, "inputs", None)
+        if inputs:
+            H, W = _shape_to_hw(inputs[0].shape)
+            if H and W:
+                return H, W
+    except Exception:
+        pass
+    try:
+        cfg = model.get_config() if hasattr(model, "get_config") else None
+        if isinstance(cfg, dict):
+            for layer in cfg.get("layers", []):
+                conf = (layer or {}).get("config", {})
+                cand = conf.get("batch_input_shape") or conf.get("batch_shape")
+                H, W = _shape_to_hw(cand)
+                if H and W:
+                    return H, W
+    except Exception:
+        pass
+    logging.warning(f"Could not resolve model input shape; using default {default_hw}.")
+    return default_hw
+
 # ---------- Initialize CPU models ----------
 def load_yolo_model():
     YOLO = _import_ultralytics()
-    # Construct model with CUDA masked to avoid auto-selecting cuda:0
     with _no_cuda_env():
         model = YOLO(YOLO_MODEL_PATH)
     return model
-
+
 def load_segmentation_model():
     import os; os.environ.setdefault("KERAS_BACKEND","tensorflow")
    import tensorflow as tf; tf.config.set_visible_devices([], "GPU")
@@ -428,11 +288,11 @@ def initialize_cpu_models() -> None:
     if "seg" not in models_cache:
         try:
             if os.path.exists(SEG_MODEL_PATH):
-                models_cache["seg"] = load_segmentation_model()
-                m = models_cache["seg"]
-                ishape = getattr(m, "input_shape", None)
+                m = load_segmentation_model()  # uses global path by default
+                models_cache["seg"] = m
+                th, tw = _get_model_input_hw(m, default_hw=(224, 224))
                 oshape = getattr(m, "output_shape", None)
-                logging.info(f"✅ Segmentation model loaded (CPU) | input_shape={ishape} output_shape={oshape}")
+                logging.info(f"✅ Segmentation model loaded (CPU) | input_hw=({th},{tw}) output_shape={oshape}")
             else:
                 models_cache["seg"] = None
                 logging.warning("Segmentation model file missing; skipping.")
@@ -650,11 +510,7 @@ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndar
     # --- Model path ---
     if seg_model is not None:
         try:
-            ishape = getattr(seg_model, "input_shape", None)
-            if not ishape or len(ishape) < 4:
-                raise ValueError(f"Bad seg input_shape: {ishape}")
-            th, tw = int(ishape[1]), int(ishape[2])
-
+            th, tw = _get_model_input_hw(seg_model, default_hw=(224, 224))
             x = _preprocess_for_seg(image_bgr, (th, tw))
             roi_seen_path = None
             if SMARTHEAL_DEBUG:
@@ -745,7 +601,7 @@ def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float,
     cnt = max(contours, key=cv2.contourArea)
     rect = cv2.minAreaRect(cnt)
     (w_px, h_px) = rect[1]
-    length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))
+    length_px, breadth_px = (max(w_px, h_px), min(h_px, w_px))
     length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
     breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)
     box = cv2.boxPoints(rect).astype(int)
@@ -1004,7 +860,6 @@ class AIProcessor:
         if not vs:
             return "Knowledge base is not available."
         retriever = vs.as_retriever(search_kwargs={"k": 5})
-        # Modern API (avoid get_relevant_documents deprecation)
         docs = retriever.invoke(query)
        lines: List[str] = []
         for d in docs:
@@ -1018,38 +873,30 @@
 
     def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
         return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
-
 ## 📋 Patient Information
 {patient_info}
-
 ## 🔍 Visual Analysis Results
 - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
 - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
 - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
 - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
 - **Calibration**: {visual_results.get('px_per_cm','?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used','default')})
-
 ## 📊 Analysis Images
 - **Original**: {visual_results.get('original_image_path', 'N/A')}
 - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
 - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
 - **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}
-
 ## 🎯 Clinical Summary
 Automated analysis provides quantitative measurements; verify via clinical examination.
-
 ## 💊 Recommendations
 - Cleanse wound gently; select dressing per exudate/infection risk
 - Debride necrotic tissue if indicated (clinical decision)
 - Document with serial photos and measurements
-
 ## 📅 Monitoring
 - Daily in week 1, then every 2–3 days (or as indicated)
 - Weekly progress review
-
 ## 📚 Guideline Context
 {(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}
-
 **Disclaimer:** Automated, for decision support only. Verify clinically.
 """
 
@@ -1103,8 +950,7 @@ Automated analysis provides quantitative measurements; verify via clinical exami
         except Exception as e:
             logging.error(f"Failed to save/commit image: {e}")
             return ""
-
-
+
     def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
         try:
             saved_path = self.save_and_commit_image(image_pil)
@@ -1150,7 +996,7 @@ Automated analysis provides quantitative measurements; verify via clinical exami
                 "saved_image_path": None,
                 "guideline_context": "",
             }
-
+
     def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
         try:
             if isinstance(image, str):
@@ -1174,4 +1020,4 @@ Automated analysis provides quantitative measurements; verify via clinical exami
                 "report": f"Analysis initialization failed: {str(e)}",
                 "saved_image_path": None,
                 "guideline_context": "",
-            }
+            }
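
For context, the pattern this commit standardizes on is the Spaces ZeroGPU decorator: anything that touches CUDA runs only inside a decorated function executed in a GPU worker, while the main process stays CPU-only. A minimal sketch, assuming the same decorator arguments used in the diff (the `run_on_gpu` helper below is hypothetical, and `enable_queue=True` is taken from the commit, not from the `spaces` documentation):

from spaces import GPU

@GPU(enable_queue=True)
def run_on_gpu(prompt: str) -> str:
    # CUDA is initialised only inside the GPU worker process
    import torch
    return f"cuda available: {torch.cuda.is_available()} | {prompt[:40]}"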
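The new `_shape_to_hw` / `_get_model_input_hw` helpers exist to resolve a target (H, W) without calling `.as_list()` on values that are not TensorShapes. A minimal sketch of how they behave, using a hypothetical `FakeKerasModel` stand-in that is not part of the repository:

class FakeKerasModel:
    input_shape = (None, 256, 256, 3)  # Keras-style (batch, H, W, C)

th, tw = _get_model_input_hw(FakeKerasModel(), default_hw=(224, 224))
# th, tw == (256, 256); if no usable shape can be found anywhere, (224, 224) is returned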
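In the `measure_min_area_rect` hunk the pixel-to-centimetre arithmetic is unchanged; a worked example assuming the default calibration of 38 px/cm (DEFAULT_PX_PER_CM):

w_px, h_px = 76.0, 190.0                                   # side lengths reported by cv2.minAreaRect
length_px, breadth_px = max(w_px, h_px), min(h_px, w_px)   # 190.0, 76.0
px_per_cm = 38.0
length_cm = round(length_px / max(px_per_cm, 1e-6), 2)     # 5.0 cm
breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)   # 2.0 cm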