SmartHeal committed
Commit ff7f2b3 · verified · 1 Parent(s): 5bf70f5

Update src/ai_processor.py

Files changed (1):
  1. src/ai_processor.py (+21 -17)

src/ai_processor.py CHANGED
@@ -1,21 +1,11 @@
 import os
-# Force CPU-only until we enter the GPU context
-os.environ['CUDA_VISIBLE_DEVICES'] = ''
-
-import torch
-# Prevent any CUDA initialization in the main process
-torch.cuda.is_available = lambda: False
-
 import io
 import base64
 import logging
-import cv2
 import numpy as np
+import cv2
 from PIL import Image
 from datetime import datetime
-from transformers import pipeline
-from ultralytics import YOLO
-from tensorflow.keras.models import load_model
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -37,6 +27,8 @@ default_system_prompt = (
     "patient context."
 )
 
+# No torch or transformers-related imports at top-level!
+
 @spaces.GPU(enable_queue=True, duration=120)
 def generate_medgemma_report(
     patient_info: str,
@@ -46,14 +38,22 @@ def generate_medgemma_report(
     segmentation_image_path: str,
     max_new_tokens: int = None
 ) -> str:
-    """Runs on GPU. Lazy-loads the MedGemma pipeline and returns the markdown report."""
+    # --- All GPU-related imports and model loading here! ---
+    import torch
+    from transformers import pipeline
+    from PIL import Image
+
+    # System prompt as before
+    global default_system_prompt
+
+    # Lazy-load MedGemma pipeline on GPU
     if not hasattr(generate_medgemma_report, "_pipe"):
         try:
             cfg = Config()
             generate_medgemma_report._pipe = pipeline(
                 'image-text-to-text',
                 model='google/medgemma-4b-it',
+                device='cuda',  # Explicitly on GPU
                 torch_dtype='auto',
                 offload_folder='offload',
                 token=cfg.HF_TOKEN
@@ -65,13 +65,13 @@ def generate_medgemma_report(
 
     pipe = generate_medgemma_report._pipe
 
-    # Assemble messages
+    # Compose messages
     msgs = [
         {'role': 'system', 'content': [{'type': 'text', 'text': default_system_prompt}]},
         {'role': 'user', 'content': []},
     ]
 
-    # Attach images
+    # Attach images if available
     for path in (detection_image_path, segmentation_image_path):
         if path and os.path.exists(path):
             msgs[1]['content'].append({'type': 'image', 'image': Image.open(path)})
@@ -96,7 +96,7 @@ class AIProcessor:
         self.px_per_cm = self.config.PIXELS_PER_CM
         self._initialize_models()
         self._load_knowledge_base()
-
+
     def _initialize_models(self):
         """Load all CPU-only models here."""
         # Set HuggingFace token
@@ -106,6 +106,7 @@ class AIProcessor:
 
         # YOLO detection (CPU-only)
         try:
+            from ultralytics import YOLO
             self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
             logging.info("✅ YOLO model loaded (CPU only)")
         except Exception as e:
@@ -114,6 +115,7 @@
 
         # Segmentation model (CPU)
         try:
+            from tensorflow.keras.models import load_model
             self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
             logging.info("✅ Segmentation model loaded (CPU)")
         except Exception as e:
@@ -121,6 +123,7 @@
 
         # Classification pipeline (CPU)
        try:
+            from transformers import pipeline
             self.models_cache['cls'] = pipeline(
                 'image-classification',
                 model='Hemg/Wound-classification',
@@ -241,6 +244,7 @@
     ) -> str:
         det = visual_results.get('detection_image_path', '')
         seg = visual_results.get('segmentation_image_path', '')
+        # This GPU call is safe: it triggers all CUDA/model code *inside* the decorator context.
         report = generate_medgemma_report(
             patient_info, visual_results, guideline_context,
             det, seg, max_new_tokens
@@ -324,4 +328,4 @@ class AIProcessor:
         return {'risk_score': risk_score, 'risk_level': level, 'risk_factors': risk_factors}
     except Exception as e:
         logging.error(f"Risk assessment error: {e}")
-        return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
+        return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
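
Note: the change follows the Hugging Face ZeroGPU convention that module import must never initialize CUDA. All torch/transformers imports move inside the @spaces.GPU-decorated function, heavy CPU-side libraries are imported lazily inside their try blocks, and the loaded pipeline is cached on the function object so repeat calls skip the reload. Below is a minimal, self-contained sketch of the same two patterns; the model name, weights path, and function names are placeholders for illustration, not this repo's real configuration.

import logging
import os

import spaces  # Hugging Face ZeroGPU helper, already a dependency of this Space


@spaces.GPU(duration=120)  # a GPU is attached only while this function runs
def describe_image(image_path: str) -> str:
    # Heavy, CUDA-touching imports stay inside the decorated function so that
    # merely importing this module never initializes CUDA.
    from PIL import Image
    from transformers import pipeline

    # Cache the pipeline on the function object: the first call pays the load
    # cost, later calls reuse it (same trick as generate_medgemma_report._pipe).
    if not hasattr(describe_image, "_pipe"):
        describe_image._pipe = pipeline(
            'image-to-text',
            model='Salesforce/blip-image-captioning-base',  # placeholder model
            device='cuda',
        )

    if not os.path.exists(image_path):
        return 'No image found.'
    return describe_image._pipe(Image.open(image_path))[0]['generated_text']


def load_cpu_models() -> dict:
    """CPU-side loaders: lazy imports inside try/except, mirroring
    AIProcessor._initialize_models, so a missing optional dependency
    degrades gracefully instead of breaking module import."""
    cache = {}
    try:
        from ultralytics import YOLO
        cache['det'] = YOLO('models/yolo_wound.pt')  # placeholder weights path
    except Exception as e:
        logging.error(f"YOLO load failed: {e}")
    return cache

With nothing CUDA-related imported at module scope, the removed top-level workaround (CUDA_VISIBLE_DEVICES='' plus the torch.cuda.is_available stub) is no longer needed, and it would otherwise have masked the GPU even inside the decorator context.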