SmartHeal committed on
Commit 5e405a7 · verified · 1 Parent(s): cfd566e

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +53 -105
src/ai_processor.py CHANGED
@@ -140,116 +140,76 @@ Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
 
 # ---------- MedGemma-only text generator ----------
 @_SPACES_GPU(enable_queue=True)
-def _medgemma_generate_gpu_with_pipeline(
-    prompt: str,
-    image_pil,                     # PIL.Image (the wound image)
-    model_id: str | None = None,   # e.g. "unsloth/medgemma-4b-it-bnb-4bit"
-    max_new_tokens: int = 256,
-    token: str | None = None,
-) -> str:
-    """
-    Vision LLM via Transformers pipeline using the "messages" format:
-      [{"role":"user","content":[{"type":"image","image": PIL}, {"type":"text","text": "..."}]}]
-    Returns a generated string.
-    """
+def _build_vlm_pipeline(model_id: str, token: str | None):
     import os, torch
     from transformers import pipeline
-    try:
-        from transformers import BitsAndBytesConfig  # only needed for 4-bit
-    except Exception:
-        BitsAndBytesConfig = None
 
-    # <<< START OF FIX >>>
-    # Force CUDA initialization to prevent IndexError in bitsandbytes/triton check.
-    # This ensures the CUDA context is ready before transformers and bnb probe the device.
+    # don't mask CUDA here
+    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+
     use_cuda = torch.cuda.is_available()
-    if use_cuda:
-        try:
-            torch.tensor([1.0]).cuda()
-        except Exception as e:
-            # If even this fails, CUDA is truly not working.
-            print(f"WARNING: CUDA pre-initialization failed: {e}")
-            use_cuda = False
-    # <<< END OF FIX >>>
+    kwargs = dict(
+        task="image-text-to-text",
+        model=model_id,
+        trust_remote_code=True,
+        torch_dtype=(torch.bfloat16 if use_cuda else torch.float32),
+        device=(0 if use_cuda else -1),
+    )
 
-    hf_token = token or os.getenv("HF_TOKEN")
-    mid = model_id or "unsloth/medgemma-4b-it-bnb-4bit"
+    if token:
+        try: kwargs["token"] = token
+        except TypeError: kwargs["use_auth_token"] = token
+
+    # if it's a 4-bit Unsloth build, attach bnb config (GPU required)
+    if "bnb-4bit" in model_id.lower() or "4bit" in model_id.lower():
+        if not use_cuda:
+            raise RuntimeError("CUDA not available for 4-bit quantized model.")
+        from transformers import BitsAndBytesConfig
+        kwargs["model_kwargs"] = {
+            "quantization_config": BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+        }
 
-    # device / dtype
-    # use_cuda is already defined above
-    device = 0 if use_cuda else -1
-    dtype = torch.bfloat16 if use_cuda else torch.float32
+    return pipeline(**kwargs)
+
+def _vlm_generate_with_messages(prompt: str,
+                                image_pil,
+                                model_id: str,
+                                max_new_tokens: int,
+                                token: str | None) -> str:
+    # try preferred; on error, fall back to a tiny CPU-friendly VLM
+    try:
+        p = _build_vlm_pipeline(model_id or "unsloth/medgemma-4b-it-bnb-4bit", token)
+    except Exception:
+        p = _build_vlm_pipeline("bczhou/tiny-llava-v1-hf", None)
 
-    # Build messages in the doc format
     messages = [{
         "role": "user",
         "content": [
-            {"type": "image", "image": image_pil},   # local PIL image
+            {"type": "image", "image": image_pil},
             {"type": "text", "text": prompt},
         ],
     }]
 
-    pipe_kwargs = dict(
-        task="image-text-to-text",
-        model=mid,
-        torch_dtype=dtype,
-        device=device,              # GPU=0 or CPU=-1
-        trust_remote_code=True,
-    )
+    out = p(text=messages,
+            max_new_tokens=int(max_new_tokens or 256),
+            do_sample=False,
+            temperature=0.2,
+            return_full_text=False)
 
-    # Pass HF token (newer Transformers uses `token`; older uses `use_auth_token`)
-    if hf_token:
-        try:
-            pipe_kwargs["token"] = hf_token
-        except TypeError:
-            pipe_kwargs["use_auth_token"] = hf_token
-
-    # If this is the 4-bit Unsloth build, attach quantization (requires CUDA + bitsandbytes)
-    if "bnb-4bit" in mid.lower():
-        if not use_cuda or BitsAndBytesConfig is None:
-            raise RuntimeError("Unsloth 4-bit requires CUDA + bitsandbytes; no GPU available.")
-        bnb = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_compute_dtype=torch.bfloat16,
-        )
-        pipe_kwargs["model_kwargs"] = {"quantization_config": bnb}
-
-    # Create pipeline and run with messages
-    p = pipeline(**pipe_kwargs)
-    out = p(
-        text=messages,
-        max_new_tokens=int(max_new_tokens or 256),
-        do_sample=False,
-        temperature=0.2,
-        return_full_text=False,   # we just want the answer
-    )
-
-    # Normalize output to a string
-    if isinstance(out, list):
-        # pipelines often return a list of strings or dicts; handle both
-        first = out[0]
-        text = first.get("generated_text") if isinstance(first, dict) else str(first)
-    else:
-        text = str(out)
+    # robust extraction
+    if isinstance(out, list) and out and isinstance(out[0], dict) and "generated_text" in out[0]:
+        return (out[0]["generated_text"] or "").strip()
+    return (str(out) or "").strip() or "⚠️ Empty response"
 
-    return (text or "").strip() or "⚠️ Empty response"
-
-
-
-
-def generate_medgemma_report(
-    patient_info: str,
-    visual_results: Dict,
-    guideline_context: str,
-    image_pil,                     # keep passing the PIL image
-    max_new_tokens: int | None = None,
-) -> str:
+def generate_medgemma_report(patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None) -> str:
     if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
         return "⚠️ VLM disabled"
 
-    # Build your prompt as before
     uprompt = SMARTHEAL_USER_PREFIX.format(
         patient_info=patient_info,
         wound_type=visual_results.get("wound_type", "Unknown"),
@@ -261,29 +221,17 @@ def generate_medgemma_report(
         guideline_context=(guideline_context or "")[:900],
     )
     prompt = f"{SMARTHEAL_SYSTEM_PROMPT}\n\n{uprompt}\n\nAnswer:"
-
     model_id = os.getenv("SMARTHEAL_MEDGEMMA_MODEL", "unsloth/medgemma-4b-it-bnb-4bit")
     max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
 
     try:
-        return _medgemma_generate_gpu_with_pipeline(prompt, image_pil, model_id, max_new_tokens, HF_TOKEN)
+        return _vlm_generate_with_messages(prompt, image_pil, model_id, max_new_tokens, os.getenv("HF_TOKEN"))
     except Exception as e:
-        # Optional: automatic tiny fallback if CUDA/bnb/space issues show up
-        err = str(e)
-        if any(s in err for s in ("No space left", "bitsandbytes", "CUDA", "requires CUDA")):
-            try:
-                return _medgemma_generate_gpu_with_pipeline(
-                    prompt, image_pil,
-                    model_id="bczhou/tiny-llava-v1-hf",   # ~1GB; CPU OK
-                    max_new_tokens=max_new_tokens,
-                    token=HF_TOKEN,
-                )
-            except Exception:
-                pass
         logging.error(f"MedGemma pipeline failed: {e}", exc_info=True)
         return "⚠️ VLM error"
 
 
+
 # ---------- Input-shape helpers (avoid `.as_list()` on strings) ----------
 def _shape_to_hw(shape) -> Tuple[Optional[int], Optional[int]]:
     try:
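
For context, a minimal usage sketch of the refactored entry point. This is not part of the commit: it assumes the module is importable as src.ai_processor and that a sample image exists locally; the patient details, the extra visual_results keys, and the file path are illustrative placeholders.

# Hypothetical driver (not part of this commit): exercises generate_medgemma_report()
# after the refactor. Import path, image path, and sample values are assumptions.
import os
from PIL import Image

from src.ai_processor import generate_medgemma_report  # assumed module path

os.environ.setdefault("SMARTHEAL_ENABLE_VLM", "1")       # "0" would return "⚠️ VLM disabled"
os.environ.setdefault("SMARTHEAL_MEDGEMMA_MODEL", "unsloth/medgemma-4b-it-bnb-4bit")
os.environ.setdefault("SMARTHEAL_VLM_MAX_TOKENS", "600")
# HF_TOKEN is read via os.getenv("HF_TOKEN") for gated/private model access.

image = Image.open("wound_photo.jpg").convert("RGB")     # assumed sample image
visual_results = {"wound_type": "Pressure ulcer"}        # "wound_type" is read here; other keys depend on SMARTHEAL_USER_PREFIX

report = generate_medgemma_report(
    patient_info="72-year-old with diabetes, sacral wound, 3 weeks' duration",
    visual_results=visual_results,
    guideline_context="(retrieved guideline excerpt; truncated to 900 characters)",
    image_pil=image,
)
print(report)  # returns "⚠️ VLM error" if both the 4-bit model and the tiny fallback fail to load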