SmartHeal committed on
Commit
5bf70f5
·
verified ·
1 Parent(s): beadd24

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +47 -48
src/ai_processor.py CHANGED
@@ -24,6 +24,7 @@ from huggingface_hub import HfApi, HfFolder
24
  import spaces
25
  from .config import Config
26
 
 
27
  default_system_prompt = (
28
  "You are a world-class medical AI assistant specializing in wound care "
29
  "with expertise in wound assessment and treatment. Provide concise, "
@@ -45,7 +46,7 @@ def generate_medgemma_report(
45
  segmentation_image_path: str,
46
  max_new_tokens: int = None
47
  ) -> str:
48
- # Lazy-load HF pipeline inside GPU context
49
  if not hasattr(generate_medgemma_report, "_pipe"):
50
  try:
51
  cfg = Config()
@@ -64,22 +65,27 @@ def generate_medgemma_report(
64
 
65
  pipe = generate_medgemma_report._pipe
66
 
 
67
  msgs = [
68
- {'role':'system','content':[{'type':'text','text':default_system_prompt}]},
69
- {'role':'user','content':[]}
70
  ]
 
 
71
  for path in (detection_image_path, segmentation_image_path):
72
  if path and os.path.exists(path):
73
- msgs[1]['content'].append({'type':'image','image': Image.open(path)})
 
 
74
  prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
75
- msgs[1]['content'].append({'type':'text','text': prompt})
76
 
77
  out = pipe(
78
  text=msgs,
79
  max_new_tokens=max_new_tokens or Config().MAX_NEW_TOKENS,
80
  do_sample=False
81
  )
82
- return out[0]['generated_text'][-1].get('content','')
83
 
84
 
85
  class AIProcessor:
@@ -92,24 +98,28 @@ class AIProcessor:
92
  self._load_knowledge_base()
93
 
94
  def _initialize_models(self):
 
 
95
  if self.config.HF_TOKEN:
96
  HfFolder.save_token(self.config.HF_TOKEN)
97
  logging.info("✅ HuggingFace token set")
98
 
 
99
  try:
100
- # YOLO on CPU only (no CUDA)
101
  self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
102
  logging.info("✅ YOLO model loaded (CPU only)")
103
  except Exception as e:
104
  logging.error(f"YOLO load failed: {e}")
105
  raise
106
 
 
107
  try:
108
  self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
109
  logging.info("✅ Segmentation model loaded (CPU)")
110
  except Exception as e:
111
  logging.warning(f"Segmentation model not available: {e}")
112
 
 
113
  try:
114
  self.models_cache['cls'] = pipeline(
115
  'image-classification',
@@ -121,32 +131,38 @@ class AIProcessor:
121
  except Exception as e:
122
  logging.warning(f"Classification pipeline not available: {e}")
123
 
 
124
  try:
125
  self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
126
  model_name='sentence-transformers/all-MiniLM-L6-v2',
127
- model_kwargs={'device':'cpu'}
128
  )
129
  logging.info("✅ Embedding model loaded (CPU)")
130
  except Exception as e:
131
  logging.warning(f"Embedding model not available: {e}")
132
 
133
  def _load_knowledge_base(self):
 
134
  docs = []
135
  for pdf in self.config.GUIDELINE_PDFS:
136
  if os.path.exists(pdf):
137
- docs.extend(PyPDFLoader(pdf).load())
 
138
  logging.info(f"Loaded PDF: {pdf}")
139
 
140
  if docs and 'embedding_model' in self.models_cache:
141
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
142
  chunks = splitter.split_documents(docs)
143
- self.knowledge_base_cache['vectorstore'] = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
 
 
144
  logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
145
  else:
146
  self.knowledge_base_cache['vectorstore'] = None
147
  logging.warning("Knowledge base unavailable")
148
 
149
  def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
 
150
  if 'det' not in self.models_cache:
151
  raise RuntimeError("YOLO model ('det') not loaded")
152
 
@@ -158,6 +174,7 @@ class AIProcessor:
158
  x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
159
  region = img_cv[y1:y2, x1:x2]
160
 
 
161
  det_vis = img_cv.copy()
162
  cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0,255,0), 2)
163
  os.makedirs(f"{self.config.UPLOADS_DIR}/analysis", exist_ok=True)
@@ -165,31 +182,31 @@ class AIProcessor:
165
  det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
166
  cv2.imwrite(det_path, det_vis)
167
 
 
168
  length = breadth = area = 0
169
  seg_path = None
170
  if 'seg' in self.models_cache:
171
  h, w = self.models_cache['seg'].input_shape[1:3]
172
- inp = cv2.resize(region, (w,h)) / 255.0
173
  mask = (self.models_cache['seg'].predict(inp[None])[0,:,:,0] > 0.5).astype(np.uint8)
174
  mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
175
  ov = region.copy(); ov[mask_rs==1] = [0,0,255]
176
- seg_vis = cv2.addWeighted(region,0.7,ov,0.3,0)
177
  seg_path = f"{self.config.UPLOADS_DIR}/analysis/segmentation_{ts}.png"
178
  cv2.imwrite(seg_path, seg_vis)
179
  cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
180
  if cnts:
181
  cnt = max(cnts, key=cv2.contourArea)
182
- _,_,w0,h0 = cv2.boundingRect(cnt)
183
- length = round(h0/self.px_per_cm,2)
184
- breadth= round(w0/self.px_per_cm,2)
185
- area = round(cv2.contourArea(cnt)/(self.px_per_cm**2),2)
186
 
 
187
  wound_type = 'Unknown'
188
  if 'cls' in self.models_cache:
189
  try:
190
- preds = self.models_cache['cls'](
191
- Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))
192
- )
193
  wound_type = max(preds, key=lambda x: x['score'])['label']
194
  except Exception:
195
  pass
@@ -225,12 +242,8 @@ class AIProcessor:
225
  det = visual_results.get('detection_image_path', '')
226
  seg = visual_results.get('segmentation_image_path', '')
227
  report = generate_medgemma_report(
228
- patient_info,
229
- visual_results,
230
- guideline_context,
231
- det,
232
- seg,
233
- max_new_tokens
234
  )
235
  if report:
236
  return report
@@ -269,23 +282,14 @@ class AIProcessor:
269
  logging.warning(f"HF upload failed: {e}")
270
  return path
271
 
272
- def full_analysis_pipeline(
273
- self,
274
- image_pil: Image.Image,
275
- questionnaire_data: dict
276
- ) -> dict:
277
  try:
278
  saved = self.save_and_commit_image(image_pil)
279
- vis = self.perform_visual_analysis(image_pil)
280
- info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
281
- gc = self.query_guidelines(info)
282
- report = self.generate_final_report(info, vis, gc, image_pil)
283
- return {
284
- 'success': True,
285
- 'visual_analysis': vis,
286
- 'report': report,
287
- 'saved_image_path': saved
288
- }
289
  except Exception as e:
290
  logging.error(f"Pipeline error: {e}")
291
  return {'success': False, 'error': str(e)}
@@ -303,7 +307,7 @@ class AIProcessor:
303
  risk_factors.append("Advanced age (>65)"); risk_score += 2
304
  elif age > 50:
305
  risk_factors.append("Older adult (50-65)"); risk_score += 1
306
- dur = questionnaire_data.get('wound_duration','').lower()
307
  if any(t in dur for t in ['month','year']):
308
  risk_factors.append("Chronic wound (>4 weeks)"); risk_score += 3
309
  pain = questionnaire_data.get('pain_level', 0)
@@ -316,13 +320,8 @@ class AIProcessor:
316
  risk_factors.append("Vascular issues"); risk_score += 2
317
  if 'immune' in hist:
318
  risk_factors.append("Immune compromise"); risk_score += 2
319
-
320
- level = (
321
- "High" if risk_score >= 7 else
322
- "Moderate" if risk_score >= 4 else
323
- "Low"
324
- )
325
  return {'risk_score': risk_score, 'risk_level': level, 'risk_factors': risk_factors}
326
  except Exception as e:
327
  logging.error(f"Risk assessment error: {e}")
328
- return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
 
24
  import spaces
25
  from .config import Config
26
 
27
+ # Inline system prompt for MedGemma GPU pipeline
28
  default_system_prompt = (
29
  "You are a world-class medical AI assistant specializing in wound care "
30
  "with expertise in wound assessment and treatment. Provide concise, "
 
46
  segmentation_image_path: str,
47
  max_new_tokens: int = None
48
  ) -> str:
49
+ """Runs on GPU. Lazy-loads the MedGemma pipeline and returns the markdown report."""
50
  if not hasattr(generate_medgemma_report, "_pipe"):
51
  try:
52
  cfg = Config()
 
65
 
66
  pipe = generate_medgemma_report._pipe
67
 
68
+ # Assemble messages
69
  msgs = [
70
+ {'role': 'system', 'content': [{'type': 'text', 'text': default_system_prompt}]},
71
+ {'role': 'user', 'content': []},
72
  ]
73
+
74
+ # Attach images
75
  for path in (detection_image_path, segmentation_image_path):
76
  if path and os.path.exists(path):
77
+ msgs[1]['content'].append({'type': 'image', 'image': Image.open(path)})
78
+
79
+ # Attach text
80
  prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
81
+ msgs[1]['content'].append({'type': 'text', 'text': prompt})
82
 
83
  out = pipe(
84
  text=msgs,
85
  max_new_tokens=max_new_tokens or Config().MAX_NEW_TOKENS,
86
  do_sample=False
87
  )
88
+ return out[0]['generated_text'][-1].get('content', '')
89
 
90
 
91
  class AIProcessor:
 
98
  self._load_knowledge_base()
99
 
100
  def _initialize_models(self):
101
+ """Load all CPU-only models here."""
102
+ # Set HuggingFace token
103
  if self.config.HF_TOKEN:
104
  HfFolder.save_token(self.config.HF_TOKEN)
105
  logging.info("✅ HuggingFace token set")
106
 
107
+ # YOLO detection (CPU-only)
108
  try:
 
109
  self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
110
  logging.info("✅ YOLO model loaded (CPU only)")
111
  except Exception as e:
112
  logging.error(f"YOLO load failed: {e}")
113
  raise
114
 
115
+ # Segmentation model (CPU)
116
  try:
117
  self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
118
  logging.info("✅ Segmentation model loaded (CPU)")
119
  except Exception as e:
120
  logging.warning(f"Segmentation model not available: {e}")
121
 
122
+ # Classification pipeline (CPU)
123
  try:
124
  self.models_cache['cls'] = pipeline(
125
  'image-classification',
 
131
  except Exception as e:
132
  logging.warning(f"Classification pipeline not available: {e}")
133
 
134
+ # Embedding model (CPU)
135
  try:
136
  self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
137
  model_name='sentence-transformers/all-MiniLM-L6-v2',
138
+ model_kwargs={'device': 'cpu'}
139
  )
140
  logging.info("✅ Embedding model loaded (CPU)")
141
  except Exception as e:
142
  logging.warning(f"Embedding model not available: {e}")
143
 
144
  def _load_knowledge_base(self):
145
+ """Load PDF guidelines into a FAISS vector store."""
146
  docs = []
147
  for pdf in self.config.GUIDELINE_PDFS:
148
  if os.path.exists(pdf):
149
+ loader = PyPDFLoader(pdf)
150
+ docs.extend(loader.load())
151
  logging.info(f"Loaded PDF: {pdf}")
152
 
153
  if docs and 'embedding_model' in self.models_cache:
154
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
155
  chunks = splitter.split_documents(docs)
156
+ self.knowledge_base_cache['vectorstore'] = FAISS.from_documents(
157
+ chunks, self.models_cache['embedding_model']
158
+ )
159
  logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
160
  else:
161
  self.knowledge_base_cache['vectorstore'] = None
162
  logging.warning("Knowledge base unavailable")
163
 
164
  def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
165
+ """Detect & segment on CPU; return metrics + file paths."""
166
  if 'det' not in self.models_cache:
167
  raise RuntimeError("YOLO model ('det') not loaded")
168
 
 
174
  x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
175
  region = img_cv[y1:y2, x1:x2]
176
 
177
+ # Save detection overlay
178
  det_vis = img_cv.copy()
179
  cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0,255,0), 2)
180
  os.makedirs(f"{self.config.UPLOADS_DIR}/analysis", exist_ok=True)
 
182
  det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
183
  cv2.imwrite(det_path, det_vis)
184
 
185
+ # Segmentation
186
  length = breadth = area = 0
187
  seg_path = None
188
  if 'seg' in self.models_cache:
189
  h, w = self.models_cache['seg'].input_shape[1:3]
190
+ inp = cv2.resize(region, (w, h)) / 255.0
191
  mask = (self.models_cache['seg'].predict(inp[None])[0,:,:,0] > 0.5).astype(np.uint8)
192
  mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
193
  ov = region.copy(); ov[mask_rs==1] = [0,0,255]
194
+ seg_vis = cv2.addWeighted(region, 0.7, ov, 0.3, 0)
195
  seg_path = f"{self.config.UPLOADS_DIR}/analysis/segmentation_{ts}.png"
196
  cv2.imwrite(seg_path, seg_vis)
197
  cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
198
  if cnts:
199
  cnt = max(cnts, key=cv2.contourArea)
200
+ _, _, w0, h0 = cv2.boundingRect(cnt)
201
+ length = round(h0 / self.px_per_cm, 2)
202
+ breadth = round(w0 / self.px_per_cm, 2)
203
+ area = round(cv2.contourArea(cnt) / (self.px_per_cm**2), 2)
204
 
205
+ # Classification
206
  wound_type = 'Unknown'
207
  if 'cls' in self.models_cache:
208
  try:
209
+ preds = self.models_cache['cls'](Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)))
 
 
210
  wound_type = max(preds, key=lambda x: x['score'])['label']
211
  except Exception:
212
  pass
 
242
  det = visual_results.get('detection_image_path', '')
243
  seg = visual_results.get('segmentation_image_path', '')
244
  report = generate_medgemma_report(
245
+ patient_info, visual_results, guideline_context,
246
+ det, seg, max_new_tokens
 
 
 
 
247
  )
248
  if report:
249
  return report
 
282
  logging.warning(f"HF upload failed: {e}")
283
  return path
284
 
285
+ def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
 
 
 
 
286
  try:
287
  saved = self.save_and_commit_image(image_pil)
288
+ vis = self.perform_visual_analysis(image_pil)
289
+ info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
290
+ gc = self.query_guidelines(info)
291
+ report= self.generate_final_report(info, vis, gc, image_pil)
292
+ return {'success': True, 'visual_analysis': vis, 'report': report, 'saved_image_path': saved}
 
 
 
 
 
293
  except Exception as e:
294
  logging.error(f"Pipeline error: {e}")
295
  return {'success': False, 'error': str(e)}
 
307
  risk_factors.append("Advanced age (>65)"); risk_score += 2
308
  elif age > 50:
309
  risk_factors.append("Older adult (50-65)"); risk_score += 1
310
+ dur = questionnaire_data.get('wound_duration', '').lower()
311
  if any(t in dur for t in ['month','year']):
312
  risk_factors.append("Chronic wound (>4 weeks)"); risk_score += 3
313
  pain = questionnaire_data.get('pain_level', 0)
 
320
  risk_factors.append("Vascular issues"); risk_score += 2
321
  if 'immune' in hist:
322
  risk_factors.append("Immune compromise"); risk_score += 2
323
+ level = ("High" if risk_score >= 7 else "Moderate" if risk_score >= 4 else "Low")
 
 
 
 
 
324
  return {'risk_score': risk_score, 'risk_level': level, 'risk_factors': risk_factors}
325
  except Exception as e:
326
  logging.error(f"Risk assessment error: {e}")
327
+ return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}