SmartHeal committed
Commit 74de941 · verified · 1 Parent(s): 5e405a7

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +55 -53
src/ai_processor.py CHANGED
@@ -140,53 +140,39 @@ Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
 
 # ---------- MedGemma-only text generator ----------
 @_SPACES_GPU(enable_queue=True)
-def _build_vlm_pipeline(model_id: str, token: str | None):
+def vlm_generate(prompt, image_pil, model_id="unsloth/medgemma-4b-it-bnb-4bit",
+                 max_new_tokens=256, token=None):
+    """
+    Simple helper: messages-style image+text → text using a 4-bit MedGemma pipeline.
+    - No explicit `device` argument (pipeline will auto-detect).
+    - Uses HF token from arg or HF_TOKEN env.
+    """
     import os, torch
-    from transformers import pipeline
+    from transformers import pipeline, BitsAndBytesConfig
 
-    # don't mask CUDA here
+    # Unmask GPU if it was masked upstream (harmless on CPU too)
     os.environ.pop("CUDA_VISIBLE_DEVICES", None)
 
-    use_cuda = torch.cuda.is_available()
-    kwargs = dict(
-        task="image-text-to-text",
+    hf_token = token or os.getenv("HF_TOKEN")
+    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+    # 4-bit quantization config (required by the Unsloth 4-bit model)
+    bnb = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_compute_dtype=dtype,
+    )
+
+    pipe = pipeline(
+        "image-text-to-text",
         model=model_id,
+        model_kwargs={"quantization_config": bnb},
+        torch_dtype=dtype,
+        token=hf_token,
         trust_remote_code=True,
-        torch_dtype=(torch.bfloat16 if use_cuda else torch.float32),
-        device=(0 if use_cuda else -1),
     )
 
-    if token:
-        try: kwargs["token"] = token
-        except TypeError: kwargs["use_auth_token"] = token
-
-    # if it's a 4-bit Unsloth build, attach bnb config (GPU required)
-    if "bnb-4bit" in model_id.lower() or "4bit" in model_id.lower():
-        if not use_cuda:
-            raise RuntimeError("CUDA not available for 4-bit quantized model.")
-        from transformers import BitsAndBytesConfig
-        kwargs["model_kwargs"] = {
-            "quantization_config": BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_compute_dtype=torch.bfloat16,
-            )
-        }
-
-    return pipeline(**kwargs)
-
-def _vlm_generate_with_messages(prompt: str,
-                                image_pil,
-                                model_id: str,
-                                max_new_tokens: int,
-                                token: str | None) -> str:
-    # try preferred; on error, fall back to a tiny CPU-friendly VLM
-    try:
-        p = _build_vlm_pipeline(model_id or "unsloth/medgemma-4b-it-bnb-4bit", token)
-    except Exception:
-        p = _build_vlm_pipeline("bczhou/tiny-llava-v1-hf", None)
-
     messages = [{
         "role": "user",
         "content": [
@@ -195,18 +181,31 @@ def _vlm_generate_with_messages(prompt: str,
         ],
     }]
 
-    out = p(text=messages,
-            max_new_tokens=int(max_new_tokens or 256),
-            do_sample=False,
-            temperature=0.2,
-            return_full_text=False)
-
-    # robust extraction
+    out = pipe(
+        text=messages,
+        max_new_tokens=int(max_new_tokens),
+        do_sample=False,
+        temperature=0.2,
+        return_full_text=False,
+    )
     if isinstance(out, list) and out and isinstance(out[0], dict) and "generated_text" in out[0]:
         return (out[0]["generated_text"] or "").strip()
     return (str(out) or "").strip() or "⚠️ Empty response"
 
-def generate_medgemma_report(patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None) -> str:
+
+def generate_medgemma_report(
+    patient_info: str,
+    visual_results: dict,
+    guideline_context: str,
+    image_pil,  # PIL.Image
+    max_new_tokens: int | None = None,
+) -> str:
+    """
+    Build SmartHeal prompt and generate with the Unsloth MedGemma 4-bit VLM.
+    No fallback to any other model.
+    """
+    import os
+
     if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
         return "⚠️ VLM disabled"
 
@@ -221,15 +220,18 @@ def generate_medgemma_report(patient_info, visual_results, guideline_context, im
         guideline_context=(guideline_context or "")[:900],
     )
     prompt = f"{SMARTHEAL_SYSTEM_PROMPT}\n\n{uprompt}\n\nAnswer:"
+
     model_id = os.getenv("SMARTHEAL_MEDGEMMA_MODEL", "unsloth/medgemma-4b-it-bnb-4bit")
    max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
 
-    try:
-        return _vlm_generate_with_messages(prompt, image_pil, model_id, max_new_tokens, os.getenv("HF_TOKEN"))
-    except Exception as e:
-        logging.error(f"MedGemma pipeline failed: {e}", exc_info=True)
-        return "⚠️ VLM error"
-
+    # Uses the simple messages-based VLM helper you added earlier (no device param).
+    return vlm_generate(
+        prompt=prompt,
+        image_pil=image_pil,
+        model_id=model_id,
+        max_new_tokens=max_new_tokens,
+        token=os.getenv("HF_TOKEN"),
+    )
 
 
 # ---------- Input-shape helpers (avoid `.as_list()` on strings) ----------
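
For reference, a minimal usage sketch of the new entry point. It assumes the module is importable as `src.ai_processor`, that `HF_TOKEN` grants access to the gated MedGemma weights, and that a CUDA GPU with bitsandbytes is available for the 4-bit model; the image path, patient details, and `visual_results` keys below are illustrative and not part of the commit.

# Minimal usage sketch (illustrative values only; not part of the commit).
import os

from PIL import Image

from src.ai_processor import generate_medgemma_report  # assumed import path

os.environ.setdefault("SMARTHEAL_ENABLE_VLM", "1")        # gate checked by the function
os.environ.setdefault("SMARTHEAL_VLM_MAX_TOKENS", "600")  # default generation budget

image = Image.open("wound_photo.jpg").convert("RGB")      # illustrative image file

report = generate_medgemma_report(
    patient_info="62-year-old with type 2 diabetes, heel pressure ulcer",  # illustrative
    visual_results={"wound_area_cm2": 3.8, "tissue_type": "granulating"},  # illustrative keys
    guideline_context="(retrieved guideline excerpt)",                     # illustrative
    image_pil=image,
)
print(report)

Note that the rewritten path no longer wraps generation in try/except or falls back to a smaller model, so callers that previously relied on the "⚠️ VLM error" string may want to add their own error handling around this call.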