SmartHeal commited on
Commit
afbe010
·
verified ·
1 Parent(s): c897c01

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +246 -2
src/ai_processor.py CHANGED
@@ -771,7 +771,89 @@ class AIProcessor:
771
  os.makedirs(out_dir, exist_ok=True)
772
  return out_dir
773
 
774
- def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775
  """
776
  YOLO detect → crop ROI → segment_wound(ROI) → clean mask →
777
  minAreaRect measurement (cm) using EXIF px/cm → save outputs.
@@ -783,8 +865,158 @@ class AIProcessor:
783
  if (exif_meta or {}).get("used") != "exif":
784
  logging.warning(f"Calibration fallback used: px_per_cm={px_per_cm:.2f} (default). Prefer ruler/Aruco for accuracy.")
785
 
 
786
  image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
787
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
788
  # --- Detection ---
789
  det_model = self.models_cache.get("det")
790
  if det_model is None:
@@ -906,6 +1138,9 @@ class AIProcessor:
906
  }
907
  _log_kv("SEG_SUMMARY", seg_summary)
908
 
 
 
 
909
  return {
910
  "wound_type": wound_type,
911
  "length_cm": length_cm,
@@ -923,6 +1158,10 @@ class AIProcessor:
923
  "segmentation_empty": segmentation_empty,
924
  "segmentation_debug": seg_debug,
925
  "original_image_path": original_path,
 
 
 
 
926
  }
927
  except Exception as e:
928
  logging.error(f"Visual analysis failed: {e}", exc_info=True)
@@ -1029,7 +1268,12 @@ Automated analysis provides quantitative measurements; verify via clinical exami
1029
  def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
1030
  try:
1031
  saved_path = self.save_and_commit_image(image_pil)
1032
- visual_results = self.perform_visual_analysis(image_pil)
 
 
 
 
 
1033
 
1034
  pi = questionnaire_data or {}
1035
  patient_info = (
 
771
  os.makedirs(out_dir, exist_ok=True)
772
  return out_dir
773
 
774
+ def _compute_skin_tone(self, image_cv: np.ndarray, mask01: np.ndarray) -> Tuple[str, float]:
775
+ """
776
+ Estimate Fitzpatrick skin tone based on the Individual Typology Angle (ITA) computed
777
+ from the non-wound region of the provided ROI. If no background region exists,
778
+ use the entire image. Returns a tuple (label, ita_degrees).
779
+ """
780
+ try:
781
+ # Convert BGR to LAB color space
782
+ lab = cv2.cvtColor(image_cv, cv2.COLOR_BGR2LAB).astype(np.float32)
783
+ # Split channels (L in [0,255], A and B in [0,255])
784
+ L_channel = lab[:, :, 0]
785
+ b_channel = lab[:, :, 2]
786
+
787
+ # Determine background (skin) pixels: where mask==0
788
+ if mask01 is not None and mask01.size == L_channel.shape[0] * L_channel.shape[1]:
789
+ bg_mask = (mask01 == 0)
790
+ else:
791
+ bg_mask = np.ones_like(L_channel, dtype=bool)
792
+
793
+ # If background region is too small, use entire image
794
+ if bg_mask.sum() < max(100, 0.05 * bg_mask.size):
795
+ bg_mask = np.ones_like(L_channel, dtype=bool)
796
+
797
+ L_vals = L_channel[bg_mask]
798
+ b_vals = b_channel[bg_mask]
799
+
800
+ # Convert to CIELAB ranges: L* ∈ [0,100]; b* ∈ [-128,127]
801
+ L_star = L_vals * (100.0 / 255.0)
802
+ b_star = (b_vals - 128.0) * (200.0 / 255.0)
803
+ # Mean values
804
+ L_mean = float(np.mean(L_star)) if L_star.size > 0 else 50.0
805
+ b_mean = float(np.mean(b_star)) if b_star.size > 0 else 0.0
806
+
807
+ # ITA calculation
808
+ ita = np.degrees(np.arctan2((L_mean - 50.0), b_mean))
809
+ ita = float(ita)
810
+
811
+ # Classification based on Del Bino ranges
812
+ if ita > 55:
813
+ label = "Type I (Very Light)"
814
+ elif ita > 41:
815
+ label = "Type II (Light)"
816
+ elif ita > 28:
817
+ label = "Type III (Intermediate)"
818
+ elif ita > 10:
819
+ label = "Type IV (Tan)"
820
+ elif ita > -30:
821
+ label = "Type V (Brown)"
822
+ else:
823
+ label = "Type VI (Dark)"
824
+ return label, round(ita, 2)
825
+ except Exception:
826
+ return "Unknown", 0.0
827
+
828
+ def _compute_tissue_type(self, image_cv: np.ndarray, mask01: np.ndarray) -> str:
829
+ """
830
+ Classify wound tissue based on the average color of the wound region.
831
+ Returns one of ["Granulation", "Slough", "Necrotic", "Unknown"].
832
+ """
833
+ try:
834
+ if mask01 is None or not mask01.any():
835
+ return "Unknown"
836
+ # convert BGR to HSV; OpenCV hue in [0,179], saturation/value in [0,255]
837
+ hsv = cv2.cvtColor(image_cv, cv2.COLOR_BGR2HSV).astype(np.float32)
838
+ h = hsv[:, :, 0][mask01 == 1]
839
+ s = hsv[:, :, 1][mask01 == 1] / 255.0
840
+ v = hsv[:, :, 2][mask01 == 1] / 255.0
841
+ if v.size == 0:
842
+ return "Unknown"
843
+ h_mean = float(np.mean(h)) * 2.0 # convert to degrees 0-360
844
+ v_mean = float(np.mean(v))
845
+ # Dark = necrotic
846
+ if v_mean < 0.2:
847
+ return "Necrotic"
848
+ # Slough: yellowish hues ~15-40°
849
+ if 15.0 <= h_mean <= 40.0:
850
+ return "Slough"
851
+ # Otherwise granulation (reddish)
852
+ return "Granulation"
853
+ except Exception:
854
+ return "Unknown"
855
+
856
+ def perform_visual_analysis(self, image_pil: Image.Image, manual_mask_data: Optional[dict] = None) -> Dict:
857
  """
858
  YOLO detect → crop ROI → segment_wound(ROI) → clean mask →
859
  minAreaRect measurement (cm) using EXIF px/cm → save outputs.
 
865
  if (exif_meta or {}).get("used") != "exif":
866
  logging.warning(f"Calibration fallback used: px_per_cm={px_per_cm:.2f} (default). Prefer ruler/Aruco for accuracy.")
867
 
868
+ # Convert input PIL image to OpenCV BGR once here; manual mode may reuse it
869
  image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
870
 
871
+ # ----------------------------------------------------------------------
872
+ # Manual annotation branch
873
+ # If a manual_mask_data dictionary/image is provided, bypass model detection
874
+ if manual_mask_data:
875
+ try:
876
+ # Extract mask image from various possible structures
877
+ mask_source = manual_mask_data
878
+ if isinstance(mask_source, dict):
879
+ mask_source = mask_source.get("mask") or mask_source.get("image") or mask_source
880
+ # Load mask into numpy grayscale
881
+ if isinstance(mask_source, Image.Image):
882
+ mask_np = np.array(mask_source.convert("L"))
883
+ elif isinstance(mask_source, np.ndarray):
884
+ mask_np = mask_source.copy()
885
+ if mask_np.ndim == 3:
886
+ # If mask is RGB/A, convert to grayscale
887
+ mask_np = cv2.cvtColor(mask_np, cv2.COLOR_BGR2GRAY)
888
+ elif isinstance(mask_source, str) and os.path.exists(mask_source):
889
+ mask_np = np.array(Image.open(mask_source).convert("L"))
890
+ else:
891
+ mask_np = None
892
+
893
+ if mask_np is None:
894
+ raise ValueError("Invalid manual mask")
895
+
896
+ # Binary mask (1 for wound, 0 for background)
897
+ mask01_full = (mask_np > 0).astype(np.uint8)
898
+ h_full, w_full = image_cv.shape[:2]
899
+
900
+ # Measurement using full-size mask
901
+ segmentation_empty = not mask01_full.any()
902
+ if not segmentation_empty:
903
+ length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01_full, px_per_cm)
904
+ area_poly_cm2, largest_cnt = area_cm2_from_contour(mask01_full, px_per_cm)
905
+ if largest_cnt is not None:
906
+ surface_area_cm2 = clamp_area_with_minrect(largest_cnt, px_per_cm, area_poly_cm2)
907
+ else:
908
+ surface_area_cm2 = area_poly_cm2
909
+ anno_roi = draw_measurement_overlay(image_cv.copy(), mask01_full, box_pts, length_cm, breadth_cm)
910
+ else:
911
+ # fallback: use full image dims
912
+ h_px = h_full; w_px = w_full
913
+ length_cm = round(max(h_px, w_px) / px_per_cm, 2)
914
+ breadth_cm = round(min(h_px, w_px) / px_per_cm, 2)
915
+ surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
916
+ anno_roi = image_cv.copy()
917
+ cv2.rectangle(anno_roi, (2, 2), (anno_roi.shape[1]-3, anno_roi.shape[0]-3), (0, 0, 255), 3)
918
+ cv2.line(anno_roi, (0, 0), (anno_roi.shape[1]-1, anno_roi.shape[0]-1), (0, 0, 255), 2)
919
+ cv2.line(anno_roi, (anno_roi.shape[1]-1, 0), (0, anno_roi.shape[0]-1), (0, 0, 255), 2)
920
+ box_pts = None
921
+
922
+ # Prepare output directory
923
+ out_dir = self._ensure_analysis_dir()
924
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
925
+ original_path = os.path.join(out_dir, f"original_{ts}.png")
926
+ cv2.imwrite(original_path, image_cv)
927
+
928
+ # Detection visualization: draw bounding box of mask
929
+ det_vis = image_cv.copy()
930
+ if mask01_full.any():
931
+ ys, xs = np.where(mask01_full == 1)
932
+ y_min, y_max = int(ys.min()), int(ys.max())
933
+ x_min, x_max = int(xs.min()), int(xs.max())
934
+ cv2.rectangle(det_vis, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
935
+ detection_path = os.path.join(out_dir, f"detection_{ts}.png")
936
+ cv2.imwrite(detection_path, det_vis)
937
+
938
+ # Save mask and overlays
939
+ roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
940
+ cv2.imwrite(roi_mask_path, (mask01_full * 255).astype(np.uint8))
941
+
942
+ # ROI overlay: tinted mask
943
+ mask255 = (mask01_full * 255).astype(np.uint8)
944
+ mask3 = cv2.merge([mask255, mask255, mask255])
945
+ red = np.zeros_like(image_cv); red[:] = (0, 0, 255)
946
+ alpha = 0.55
947
+ tinted = cv2.addWeighted(image_cv, 1 - alpha, red, alpha, 0)
948
+ if mask255.any():
949
+ roi_overlay = np.where(mask3 > 0, tinted, image_cv)
950
+ cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
951
+ cv2.drawContours(roi_overlay, cnts, -1, (255, 255, 255), 2)
952
+ else:
953
+ roi_overlay = anno_roi
954
+
955
+ seg_full = image_cv.copy()
956
+ seg_full[:, :] = roi_overlay
957
+ segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
958
+ cv2.imwrite(segmentation_path, seg_full)
959
+
960
+ segmentation_roi_path = os.path.join(out_dir, f"segmentation_roi_{ts}.png")
961
+ cv2.imwrite(segmentation_roi_path, roi_overlay)
962
+
963
+ anno_full = image_cv.copy()
964
+ anno_full[:, :] = anno_roi
965
+ annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
966
+ cv2.imwrite(annotated_seg_path, anno_full)
967
+
968
+ # Classification: crop bounding box region for classification
969
+ wound_type = "Unknown"
970
+ cls_pipe = self.models_cache.get("cls")
971
+ if cls_pipe is not None and mask01_full.any():
972
+ try:
973
+ roi_bbox = image_cv[y_min:y_max+1, x_min:x_max+1]
974
+ preds = cls_pipe(Image.fromarray(cv2.cvtColor(roi_bbox, cv2.COLOR_BGR2RGB)))
975
+ if preds:
976
+ wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
977
+ except Exception as e:
978
+ logging.warning(f"Classification failed: {e}")
979
+
980
+ # Compute skin tone and tissue classification using full image and mask
981
+ skin_tone_label, ita_degrees = self._compute_skin_tone(image_cv, mask01_full)
982
+ tissue_type = self._compute_tissue_type(image_cv, mask01_full)
983
+
984
+ seg_debug = {
985
+ "used": "manual",
986
+ "reason": "Manual annotation provided",
987
+ "positive_fraction": float(mask01_full.mean()) if mask01_full.size > 0 else 0.0,
988
+ "thr": SEG_THRESH,
989
+ }
990
+
991
+ return {
992
+ "wound_type": wound_type,
993
+ "length_cm": length_cm,
994
+ "breadth_cm": breadth_cm,
995
+ "surface_area_cm2": surface_area_cm2,
996
+ "px_per_cm": round(px_per_cm, 2),
997
+ "calibration_meta": exif_meta,
998
+ "detection_confidence": 1.0 if mask01_full.any() else 0.0,
999
+ "detection_image_path": detection_path,
1000
+ "segmentation_image_path": annotated_seg_path,
1001
+ "segmentation_annotated_path": annotated_seg_path,
1002
+ "segmentation_roi_path": segmentation_roi_path,
1003
+ "roi_mask_path": roi_mask_path,
1004
+ "segmentation_empty": segmentation_empty,
1005
+ "segmentation_debug": seg_debug,
1006
+ "original_image_path": original_path,
1007
+ "skin_tone_label": skin_tone_label,
1008
+ "ita_degrees": ita_degrees,
1009
+ "tissue_type": tissue_type,
1010
+ "segmentation_used": "manual",
1011
+ }
1012
+ except Exception as e_manual:
1013
+ logging.error(f"Manual analysis failed: {e_manual}", exc_info=True)
1014
+ # If manual branch fails, fall back to automatic detection
1015
+ pass
1016
+
1017
+ # ----------------------------------------------------------------------
1018
+ # Automatic model-based pipeline below
1019
+
1020
  # --- Detection ---
1021
  det_model = self.models_cache.get("det")
1022
  if det_model is None:
 
1138
  }
1139
  _log_kv("SEG_SUMMARY", seg_summary)
1140
 
1141
+ # Compute skin tone and tissue classification on ROI for automatic segmentation
1142
+ skin_tone_label, ita_degrees = self._compute_skin_tone(roi, mask01)
1143
+ tissue_type = self._compute_tissue_type(roi, mask01)
1144
  return {
1145
  "wound_type": wound_type,
1146
  "length_cm": length_cm,
 
1158
  "segmentation_empty": segmentation_empty,
1159
  "segmentation_debug": seg_debug,
1160
  "original_image_path": original_path,
1161
+ "skin_tone_label": skin_tone_label,
1162
+ "ita_degrees": ita_degrees,
1163
+ "tissue_type": tissue_type,
1164
+ "segmentation_used": "automatic",
1165
  }
1166
  except Exception as e:
1167
  logging.error(f"Visual analysis failed: {e}", exc_info=True)
 
1268
  def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
1269
  try:
1270
  saved_path = self.save_and_commit_image(image_pil)
1271
+ # Extract any manual annotation mask from questionnaire data; remove it from the dict so it doesn't
1272
+ # get forwarded to the text generation pipeline.
1273
+ manual_mask_data = None
1274
+ if isinstance(questionnaire_data, dict) and 'manual_mask' in questionnaire_data:
1275
+ manual_mask_data = questionnaire_data.pop('manual_mask')
1276
+ visual_results = self.perform_visual_analysis(image_pil, manual_mask_data)
1277
 
1278
  pi = questionnaire_data or {}
1279
  patient_info = (