SmartHeal committed
Commit 3c0b441 · verified · 1 Parent(s): 25c58c9

Update src/ai_processor.py

Files changed (1): src/ai_processor.py (+58 −75)

src/ai_processor.py CHANGED
@@ -36,17 +36,14 @@ knowledge_base_cache = {}
 
 # =============== LAZY LOADING FUNCTIONS (CPU-SAFE) ===============
 def load_yolo_model(yolo_model_path):
-    """Lazy import and load YOLO model to avoid CUDA initialization."""
     from ultralytics import YOLO
     return YOLO(yolo_model_path)
 
 def load_segmentation_model(seg_model_path):
-    """Lazy import and load segmentation model."""
     from tensorflow.keras.models import load_model
     return load_model(seg_model_path, compile=False)
 
 def load_classification_pipeline(hf_token):
-    """Lazy import and load classification pipeline (CPU only)."""
     from transformers import pipeline
     return pipeline(
         "image-classification",
@@ -56,7 +53,6 @@ def load_classification_pipeline(hf_token):
     )
 
 def load_embedding_model():
-    """Load embedding model for knowledge base."""
     return HuggingFaceEmbeddings(
         model_name="sentence-transformers/all-MiniLM-L6-v2",
         model_kwargs={"device": "cpu"}
@@ -64,34 +60,28 @@ def load_embedding_model():
 
 # =============== MODEL INITIALIZATION ===============
 def initialize_cpu_models():
-    """Initialize all CPU-only models once."""
     global models_cache
-
     if HF_TOKEN:
         HfFolder.save_token(HF_TOKEN)
         logging.info("✅ HuggingFace token set")
-
     if "det" not in models_cache:
         try:
             models_cache["det"] = load_yolo_model(YOLO_MODEL_PATH)
             logging.info("✅ YOLO model loaded (CPU only)")
         except Exception as e:
             logging.error(f"YOLO load failed: {e}")
-
     if "seg" not in models_cache:
         try:
             models_cache["seg"] = load_segmentation_model(SEG_MODEL_PATH)
             logging.info("✅ Segmentation model loaded (CPU)")
         except Exception as e:
             logging.warning(f"Segmentation model not available: {e}")
-
     if "cls" not in models_cache:
         try:
             models_cache["cls"] = load_classification_pipeline(HF_TOKEN)
             logging.info("✅ Classification pipeline loaded (CPU)")
         except Exception as e:
             logging.warning(f"Classification pipeline not available: {e}")
-
     if "embedding_model" not in models_cache:
         try:
             models_cache["embedding_model"] = load_embedding_model()
@@ -100,11 +90,9 @@ def initialize_cpu_models():
             logging.warning(f"Embedding model not available: {e}")
 
 def setup_knowledge_base():
-    """Load PDF documents and create FAISS vector store."""
     global knowledge_base_cache
     if "vector_store" in knowledge_base_cache:
         return
-
     docs = []
     for pdf_path in GUIDELINE_PDFS:
         if os.path.exists(pdf_path):
@@ -114,7 +102,6 @@ def setup_knowledge_base():
                 logging.info(f"Loaded PDF: {pdf_path}")
             except Exception as e:
                 logging.warning(f"Failed to load PDF {pdf_path}: {e}")
-
     if docs and "embedding_model" in models_cache:
         splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
         chunks = splitter.split_documents(docs)
@@ -124,7 +111,7 @@
         knowledge_base_cache["vector_store"] = None
         logging.warning("Knowledge base unavailable")
 
-# Initialize models on app startup
+# Initialize models at startup
 initialize_cpu_models()
 setup_knowledge_base()
@@ -138,8 +125,6 @@ def generate_medgemma_report(
     segmentation_image_path,
     max_new_tokens=None,
 ):
-    """GPU-only function for MedGemma report generation."""
-    # Import GPU libraries ONLY here
     import torch
     from transformers import pipeline
     from PIL import Image
@@ -156,7 +141,6 @@ def generate_medgemma_report(
         "patient context."
     )
 
-    # Lazy-load MedGemma pipeline on GPU
     if not hasattr(generate_medgemma_report, "_pipe"):
         try:
             generate_medgemma_report._pipe = pipeline(
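
This hunk lazy-loads the MedGemma pipeline and caches it as an attribute on the function object itself, so the model is built once on the first call and reused afterwards. A minimal standalone sketch of that function-attribute caching pattern (the sleep is a hypothetical stand-in for an expensive model load):

import time

def get_resource():
    # Build once on the first call, then reuse the cached object;
    # mirrors how generate_medgemma_report stores its pipeline in _pipe.
    if not hasattr(get_resource, "_cache"):
        time.sleep(0.1)  # hypothetical stand-in for an expensive model load
        get_resource._cache = {"model": "loaded"}
    return get_resource._cache

assert get_resource() is get_resource()  # same cached object every call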
@@ -174,18 +158,15 @@ def generate_medgemma_report(
 
     pipe = generate_medgemma_report._pipe
 
-    # Compose messages
     msgs = [
         {"role": "system", "content": [{"type": "text", "text": default_system_prompt}]},
         {"role": "user", "content": []},
     ]
 
-    # Attach images if available
     for path in (detection_image_path, segmentation_image_path):
         if path and os.path.exists(path):
             msgs[1]["content"].append({"type": "image", "image": Image.open(path)})
 
-    # Attach text prompt
     prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
     msgs[1]["content"].append({"type": "text", "text": prompt})
 
@@ -207,7 +188,6 @@ class AIProcessor:
         self.hf_token = HF_TOKEN
 
     def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
-        """Detect & segment on CPU; return metrics + file paths."""
         img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
         yolo = self.models_cache.get("det")
         if yolo is None:
@@ -217,40 +197,55 @@
         if not res.boxes:
             raise ValueError("No wound detected")
 
-        x1, y1, x2, y2 = res.boxes.xyxy.cpu().numpy().astype(int)
-        region = img_cv[y1:y2, x1:x2]
+        # Safely unpack detection boxes
+        try:
+            xyxy = res.boxes.xyxy.cpu().numpy()
+            if xyxy.shape[0] == 0:
+                raise ValueError("No detection boxes found")
+            x1, y1, x2, y2 = xyxy[0]
+        except Exception as e:
+            logging.warning(f"Error unpacking detection boxes: {e}")
+            raise
+
+        region = img_cv[int(y1):int(y2), int(x1):int(x2)]
 
         # Save detection overlay
         det_vis = img_cv.copy()
-        cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+        cv2.rectangle(det_vis, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
         os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
         ts = datetime.now().strftime("%Y%m%d_%H%M%S")
         det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
         cv2.imwrite(det_path, det_vis)
 
         # Segmentation
-        length = breadth = area = 0
+        length, breadth, area = 0, 0, 0
         seg_path = None
         seg_model = self.models_cache.get("seg")
         if seg_model:
-            h, w = seg_model.input_shape[1:3]
-            inp = cv2.resize(region, (w, h)) / 255.0
-            mask = (seg_model.predict(inp[None])[0, :, :, 0] > 0.5).astype(np.uint8)
-            mask_rs = cv2.resize(mask, (region.shape[1], region.shape), interpolation=cv2.INTER_NEAREST)
-
-            ov = region.copy()
-            ov[mask_rs == 1] = [0, 0, 255]
-            seg_vis = cv2.addWeighted(region, 0.7, ov, 0.3, 0)
-            seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
-            cv2.imwrite(seg_path, seg_vis)
-
-            cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-            if cnts:
-                cnt = max(cnts, key=cv2.contourArea)
-                _, _, w0, h0 = cv2.boundingRect(cnt)
-                length = round(h0 / self.px_per_cm, 2)
-                breadth = round(w0 / self.px_per_cm, 2)
-                area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
+            try:
+                h, w = seg_model.input_shape[1:3]
+                inp = cv2.resize(region, (w, h)) / 255.0
+                mask_pred = seg_model.predict(inp[None])
+                # Collapse batch/channel axes so thresholding yields a 2-D mask
+                mask = (np.squeeze(mask_pred) > 0.5).astype(np.uint8)
+                mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
+                # Save segmentation visualization
+                ov = region.copy()
+                ov[mask_rs == 1] = [0, 0, 255]
+                seg_vis = cv2.addWeighted(region, 0.7, ov, 0.3, 0)
+                seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
+                cv2.imwrite(seg_path, seg_vis)
+
+                # Measure the largest contour in centimetres
+                cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+                if cnts:
+                    cnt = max(cnts, key=cv2.contourArea)
+                    x, y, w_box, h_box = cv2.boundingRect(cnt)
+                    length = round(h_box / self.px_per_cm, 2)
+                    breadth = round(w_box / self.px_per_cm, 2)
+                    area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
+            except Exception as e:
+                logging.warning(f"Segmentation processing error: {e}")
 
         # Classification
         wound_type = "Unknown"
@@ -258,9 +255,10 @@
         if cls_pipe:
             try:
                 preds = cls_pipe(Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)))
-                wound_type = max(preds, key=lambda x: x["score"])["label"]
-            except Exception:
-                pass
+                if preds:
+                    wound_type = max(preds, key=lambda x: x["score"])["label"]
+            except Exception as e:
+                logging.warning(f"Classification error: {e}")
 
         return {
             "wound_type": wound_type,
@@ -273,7 +271,6 @@ class AIProcessor:
         }
 
     def query_guidelines(self, query: str) -> str:
-        """Query the knowledge base for relevant information."""
        vs = self.knowledge_base_cache.get("vector_store")
        if not vs:
            return "Clinical guidelines unavailable"
@@ -282,39 +279,28 @@
             f"Source: {d.metadata.get('source','?')}, Page: {d.metadata.get('page','?')}\n{d.page_content}" for d in docs
         )
 
-    def generate_final_report(
-        self, patient_info: str, visual_results: dict, guideline_context: str, image_pil: Image.Image, max_new_tokens: int = None
-    ) -> str:
-        """Generate final report using MedGemma GPU pipeline."""
-        det = visual_results.get("detection_image_path", "")
-        seg = visual_results.get("segmentation_image_path", "")
-
-        report = generate_medgemma_report(patient_info, visual_results, guideline_context, det, seg, max_new_tokens)
+    def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
+        det_path = visual_results.get("detection_image_path", "")
+        seg_path = visual_results.get("segmentation_image_path", "")
+        report = generate_medgemma_report(patient_info, visual_results, guideline_context, det_path, seg_path, max_new_tokens)
         if report:
             return report
         return self._generate_fallback_report(patient_info, visual_results, guideline_context)
 
-    def _generate_fallback_report(
-        self, patient_info: str, visual_results: dict, guideline_context: str
-    ) -> str:
-        """Generate fallback report if MedGemma fails."""
-        dp = visual_results.get('detection_image_path','N/A')
-        sp = visual_results.get('segmentation_image_path','N/A')
+    def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
         return (
             f"# Fallback Report\n{patient_info}\n"
             f"Type: {visual_results.get('wound_type','Unknown')}\n"
-            f"Detection Image: {dp}\n"
-            f"Segmentation Image: {sp}\n"
+            f"Detection Image: {visual_results.get('detection_image_path','N/A')}\n"
+            f"Segmentation Image: {visual_results.get('segmentation_image_path','N/A')}\n"
             f"Guidelines: {guideline_context[:200]}..."
         )
 
     def save_and_commit_image(self, image_pil: Image.Image) -> str:
-        """Save image locally and optionally commit to HF dataset."""
         os.makedirs(self.uploads_dir, exist_ok=True)
         fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
         path = os.path.join(self.uploads_dir, fn)
         image_pil.convert("RGB").save(path)
-
         if self.hf_token and self.dataset_id:
             try:
                 HfApi().upload_file(
@@ -328,27 +314,24 @@ class AIProcessor:
                 logging.warning(f"HF upload failed: {e}")
         return path
 
-    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
-        """Run full analysis pipeline."""
+    def full_analysis_pipeline(self, image, questionnaire_data):
         try:
-            saved = self.save_and_commit_image(image_pil)
-            vis = self.perform_visual_analysis(image_pil)
+            saved = self.save_and_commit_image(image)
+            vis = self.perform_visual_analysis(image)
             info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
             gc = self.query_guidelines(info)
-            report = self.generate_final_report(info, vis, gc, image_pil)
+            report = self.generate_final_report(info, vis, gc, image)
             return {'success': True, 'visual_analysis': vis, 'report': report, 'saved_image_path': saved}
         except Exception as e:
             logging.error(f"Pipeline error: {e}")
             return {'success': False, 'error': str(e)}
 
-    def analyze_wound(self, image, questionnaire_data: dict) -> dict:
-        """Main analysis entry point."""
+    def analyze_wound(self, image, questionnaire_data):
         if isinstance(image, str):
             image = Image.open(image)
         return self.full_analysis_pipeline(image, questionnaire_data)
 
-    def _assess_risk_legacy(self, questionnaire_data: dict) -> dict:
-        """Legacy risk assessment function."""
+    def _assess_risk_legacy(self, questionnaire_data):
         risk_factors, risk_score = [], 0
         try:
             age = questionnaire_data.get('patient_age', 0)
@@ -377,4 +360,4 @@ class AIProcessor:
             return {'risk_score': risk_score, 'risk_level': level, 'risk_factors': risk_factors}
         except Exception as e:
             logging.error(f"Risk assessment error: {e}")
-            return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
+            return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
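
For reference, a minimal driver for the updated entry point. It assumes `AIProcessor()` needs no constructor arguments and that the module is importable as `src.ai_processor`; the sample path and any questionnaire fields beyond `patient_age` are hypothetical:

from src.ai_processor import AIProcessor

processor = AIProcessor()
questionnaire = {"patient_age": 67}  # hypothetical sample data

# analyze_wound accepts either a file path or a PIL.Image
result = processor.analyze_wound("samples/wound.jpg", questionnaire)
if result["success"]:
    print(result["visual_analysis"]["wound_type"])
    print(result["saved_image_path"])
else:
    print("Analysis failed:", result["error"])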