SmartHeal committed on
Commit
beadd24
·
verified ·
1 Parent(s): a923317

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +33 -77
src/ai_processor.py CHANGED
@@ -1,7 +1,11 @@
1
  import os
2
- # Ensure all CPU-only models never touch CUDA
3
  os.environ['CUDA_VISIBLE_DEVICES'] = ''
4
 
 
 
 
 
5
  import io
6
  import base64
7
  import logging
@@ -20,7 +24,6 @@ from huggingface_hub import HfApi, HfFolder
20
  import spaces
21
  from .config import Config
22
 
23
- # System prompt for MedGemma
24
  default_system_prompt = (
25
  "You are a world-class medical AI assistant specializing in wound care "
26
  "with expertise in wound assessment and treatment. Provide concise, "
@@ -42,11 +45,7 @@ def generate_medgemma_report(
42
  segmentation_image_path: str,
43
  max_new_tokens: int = None
44
  ) -> str:
45
- """
46
- Runs on GPU. Lazy-loads the MedGemma pipeline and returns the markdown report.
47
- Accepts only primitive types and file-paths, so pickling works.
48
- """
49
- # Lazy-load pipeline
50
  if not hasattr(generate_medgemma_report, "_pipe"):
51
  try:
52
  cfg = Config()
@@ -65,16 +64,13 @@ def generate_medgemma_report(
65
 
66
  pipe = generate_medgemma_report._pipe
67
 
68
- # Assemble messages
69
  msgs = [
70
  {'role':'system','content':[{'type':'text','text':default_system_prompt}]},
71
  {'role':'user','content':[]}
72
  ]
73
- # Attach images
74
  for path in (detection_image_path, segmentation_image_path):
75
  if path and os.path.exists(path):
76
  msgs[1]['content'].append({'type':'image','image': Image.open(path)})
77
- # Attach text
78
  prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
79
  msgs[1]['content'].append({'type':'text','text': prompt})
80
 
@@ -96,28 +92,24 @@ class AIProcessor:
96
  self._load_knowledge_base()
97
 
98
  def _initialize_models(self):
99
- """Load all CPU-only models here."""
100
- # Set HuggingFace token
101
  if self.config.HF_TOKEN:
102
  HfFolder.save_token(self.config.HF_TOKEN)
103
  logging.info("✅ HuggingFace token set")
104
 
105
- # YOLO detection (CPU)
106
  try:
 
107
  self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
108
  logging.info("✅ YOLO model loaded (CPU only)")
109
  except Exception as e:
110
  logging.error(f"YOLO load failed: {e}")
111
  raise
112
 
113
- # Segmentation (CPU)
114
  try:
115
  self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
116
  logging.info("✅ Segmentation model loaded (CPU)")
117
  except Exception as e:
118
  logging.warning(f"Segmentation model not available: {e}")
119
 
120
- # Classification (CPU)
121
  try:
122
  self.models_cache['cls'] = pipeline(
123
  'image-classification',
@@ -129,7 +121,6 @@ class AIProcessor:
129
  except Exception as e:
130
  logging.warning(f"Classification pipeline not available: {e}")
131
 
132
- # Embedding model (CPU)
133
  try:
134
  self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
135
  model_name='sentence-transformers/all-MiniLM-L6-v2',
@@ -140,26 +131,22 @@ class AIProcessor:
140
  logging.warning(f"Embedding model not available: {e}")
141
 
142
  def _load_knowledge_base(self):
143
- """Load PDF guidelines into a FAISS vector store."""
144
  docs = []
145
  for pdf in self.config.GUIDELINE_PDFS:
146
  if os.path.exists(pdf):
147
- loader = PyPDFLoader(pdf)
148
- docs.extend(loader.load())
149
  logging.info(f"Loaded PDF: {pdf}")
150
 
151
  if docs and 'embedding_model' in self.models_cache:
152
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
153
  chunks = splitter.split_documents(docs)
154
- vs = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
155
- self.knowledge_base_cache['vectorstore'] = vs
156
  logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
157
  else:
158
  self.knowledge_base_cache['vectorstore'] = None
159
  logging.warning("Knowledge base unavailable")
160
 
161
  def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
162
- """Detect & segment on CPU; return metrics + file paths."""
163
  if 'det' not in self.models_cache:
164
  raise RuntimeError("YOLO model ('det') not loaded")
165
 
@@ -171,7 +158,6 @@ class AIProcessor:
171
  x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
172
  region = img_cv[y1:y2, x1:x2]
173
 
174
- # Save detection overlay
175
  det_vis = img_cv.copy()
176
  cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0,255,0), 2)
177
  os.makedirs(f"{self.config.UPLOADS_DIR}/analysis", exist_ok=True)
@@ -179,7 +165,6 @@ class AIProcessor:
179
  det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
180
  cv2.imwrite(det_path, det_vis)
181
 
182
- # Segmentation metrics
183
  length = breadth = area = 0
184
  seg_path = None
185
  if 'seg' in self.models_cache:
@@ -199,7 +184,6 @@ class AIProcessor:
199
  breadth= round(w0/self.px_per_cm,2)
200
  area = round(cv2.contourArea(cnt)/(self.px_per_cm**2),2)
201
 
202
- # Classification
203
  wound_type = 'Unknown'
204
  if 'cls' in self.models_cache:
205
  try:
@@ -238,9 +222,6 @@ class AIProcessor:
238
  image_pil: Image.Image,
239
  max_new_tokens: int = None
240
  ) -> str:
241
- """
242
- Signature unchanged. Gathers arguments, calls GPU function, and falls back if needed.
243
- """
244
  det = visual_results.get('detection_image_path', '')
245
  seg = visual_results.get('segmentation_image_path', '')
246
  report = generate_medgemma_report(
@@ -314,59 +295,34 @@ class AIProcessor:
314
  image = Image.open(image)
315
  return self.full_analysis_pipeline(image, questionnaire_data)
316
 
317
- def _assess_risk_legacy(self, questionnaire_data):
318
- """Legacy risk assessment for backward compatibility"""
319
- risk_factors = []
320
- risk_score = 0
321
-
322
  try:
323
- # Age factor
324
  age = questionnaire_data.get('patient_age', 0)
325
  if age > 65:
326
- risk_factors.append("Advanced age (>65)")
327
- risk_score += 2
328
  elif age > 50:
329
- risk_factors.append("Older adult (50-65)")
330
- risk_score += 1
331
-
332
- # Duration factor
333
- duration = questionnaire_data.get('wound_duration', '').lower()
334
- if any(term in duration for term in ['month', 'months', 'year']):
335
- risk_factors.append("Chronic wound (>4 weeks)")
336
- risk_score += 3
337
-
338
- # Pain level
339
- pain_level = questionnaire_data.get('pain_level', 0)
340
- if pain_level >= 7:
341
- risk_factors.append("High pain level")
342
- risk_score += 2
343
-
344
- # Medical history risk factors
345
- medical_history = questionnaire_data.get('medical_history', '').lower()
346
- if 'diabetes' in medical_history:
347
- risk_factors.append("Diabetes mellitus")
348
- risk_score += 3
349
- if 'circulation' in medical_history or 'vascular' in medical_history:
350
- risk_factors.append("Vascular/circulation issues")
351
- risk_score += 2
352
- if 'immune' in medical_history:
353
- risk_factors.append("Immune system compromise")
354
- risk_score += 2
355
-
356
- # Determine risk level
357
- if risk_score >= 7:
358
- risk_level = "High"
359
- elif risk_score >= 4:
360
- risk_level = "Moderate"
361
- else:
362
- risk_level = "Low"
363
-
364
- return {
365
- 'risk_score': risk_score,
366
- 'risk_level': risk_level,
367
- 'risk_factors': risk_factors
368
- }
369
-
370
  except Exception as e:
371
  logging.error(f"Risk assessment error: {e}")
372
  return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
 
1
  import os
2
+ # Force CPU-only until we enter the GPU context
3
  os.environ['CUDA_VISIBLE_DEVICES'] = ''
4
 
5
+ import torch
6
+ # Prevent any CUDA initialization in the main process
7
+ torch.cuda.is_available = lambda: False
8
+
9
  import io
10
  import base64
11
  import logging
 
24
  import spaces
25
  from .config import Config
26
 
 
27
  default_system_prompt = (
28
  "You are a world-class medical AI assistant specializing in wound care "
29
  "with expertise in wound assessment and treatment. Provide concise, "
 
45
  segmentation_image_path: str,
46
  max_new_tokens: int = None
47
  ) -> str:
48
+ # Lazy-load HF pipeline inside GPU context
 
 
 
 
49
  if not hasattr(generate_medgemma_report, "_pipe"):
50
  try:
51
  cfg = Config()
 
64
 
65
  pipe = generate_medgemma_report._pipe
66
 
 
67
  msgs = [
68
  {'role':'system','content':[{'type':'text','text':default_system_prompt}]},
69
  {'role':'user','content':[]}
70
  ]
 
71
  for path in (detection_image_path, segmentation_image_path):
72
  if path and os.path.exists(path):
73
  msgs[1]['content'].append({'type':'image','image': Image.open(path)})
 
74
  prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
75
  msgs[1]['content'].append({'type':'text','text': prompt})
76
 
 
92
  self._load_knowledge_base()
93
 
94
  def _initialize_models(self):
 
 
95
  if self.config.HF_TOKEN:
96
  HfFolder.save_token(self.config.HF_TOKEN)
97
  logging.info("✅ HuggingFace token set")
98
 
 
99
  try:
100
+ # YOLO on CPU only (no CUDA)
101
  self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
102
  logging.info("✅ YOLO model loaded (CPU only)")
103
  except Exception as e:
104
  logging.error(f"YOLO load failed: {e}")
105
  raise
106
 
 
107
  try:
108
  self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
109
  logging.info("✅ Segmentation model loaded (CPU)")
110
  except Exception as e:
111
  logging.warning(f"Segmentation model not available: {e}")
112
 
 
113
  try:
114
  self.models_cache['cls'] = pipeline(
115
  'image-classification',
 
121
  except Exception as e:
122
  logging.warning(f"Classification pipeline not available: {e}")
123
 
 
124
  try:
125
  self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
126
  model_name='sentence-transformers/all-MiniLM-L6-v2',
 
131
  logging.warning(f"Embedding model not available: {e}")
132
 
133
  def _load_knowledge_base(self):
 
134
  docs = []
135
  for pdf in self.config.GUIDELINE_PDFS:
136
  if os.path.exists(pdf):
137
+ docs.extend(PyPDFLoader(pdf).load())
 
138
  logging.info(f"Loaded PDF: {pdf}")
139
 
140
  if docs and 'embedding_model' in self.models_cache:
141
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
142
  chunks = splitter.split_documents(docs)
143
+ self.knowledge_base_cache['vectorstore'] = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
 
144
  logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
145
  else:
146
  self.knowledge_base_cache['vectorstore'] = None
147
  logging.warning("Knowledge base unavailable")
148
 
149
  def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
 
150
  if 'det' not in self.models_cache:
151
  raise RuntimeError("YOLO model ('det') not loaded")
152
 
 
158
  x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
159
  region = img_cv[y1:y2, x1:x2]
160
 
 
161
  det_vis = img_cv.copy()
162
  cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0,255,0), 2)
163
  os.makedirs(f"{self.config.UPLOADS_DIR}/analysis", exist_ok=True)
 
165
  det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
166
  cv2.imwrite(det_path, det_vis)
167
 
 
168
  length = breadth = area = 0
169
  seg_path = None
170
  if 'seg' in self.models_cache:
 
184
  breadth= round(w0/self.px_per_cm,2)
185
  area = round(cv2.contourArea(cnt)/(self.px_per_cm**2),2)
186
 
 
187
  wound_type = 'Unknown'
188
  if 'cls' in self.models_cache:
189
  try:
 
222
  image_pil: Image.Image,
223
  max_new_tokens: int = None
224
  ) -> str:
 
 
 
225
  det = visual_results.get('detection_image_path', '')
226
  seg = visual_results.get('segmentation_image_path', '')
227
  report = generate_medgemma_report(
 
295
  image = Image.open(image)
296
  return self.full_analysis_pipeline(image, questionnaire_data)
297
 
298
+ def _assess_risk_legacy(self, questionnaire_data: dict) -> dict:
299
+ risk_factors, risk_score = [], 0
 
 
 
300
  try:
 
301
  age = questionnaire_data.get('patient_age', 0)
302
  if age > 65:
303
+ risk_factors.append("Advanced age (>65)"); risk_score += 2
 
304
  elif age > 50:
305
+ risk_factors.append("Older adult (50-65)"); risk_score += 1
306
+ dur = questionnaire_data.get('wound_duration','').lower()
307
+ if any(t in dur for t in ['month','year']):
308
+ risk_factors.append("Chronic wound (>4 weeks)"); risk_score += 3
309
+ pain = questionnaire_data.get('pain_level', 0)
310
+ if pain >= 7:
311
+ risk_factors.append("High pain level"); risk_score += 2
312
+ hist = questionnaire_data.get('medical_history','').lower()
313
+ if 'diabetes' in hist:
314
+ risk_factors.append("Diabetes mellitus"); risk_score += 3
315
+ if 'vascular' in hist:
316
+ risk_factors.append("Vascular issues"); risk_score += 2
317
+ if 'immune' in hist:
318
+ risk_factors.append("Immune compromise"); risk_score += 2
319
+
320
+ level = (
321
+ "High" if risk_score >= 7 else
322
+ "Moderate" if risk_score >= 4 else
323
+ "Low"
324
+ )
325
+ return {'risk_score': risk_score, 'risk_level': level, 'risk_factors': risk_factors}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  except Exception as e:
327
  logging.error(f"Risk assessment error: {e}")
328
  return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}