Akis Giannoukos committed
Commit 9325a21 · 1 Parent(s): 44521ed

Using gemma-2 model

Files changed (1):
app.py +27 -16
app.py CHANGED
@@ -25,7 +25,7 @@ import spaces
 # ---------------------------
 # Configuration
 # ---------------------------
-DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
 DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
 CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
 MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
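
Note: the default chat model is now resolved from the LLM_MODEL_ID environment variable, with google/gemma-2-2b-it as the in-source fallback, so a deployment can swap models without editing app.py. A minimal sketch of the override behavior (the gemma-2-9b-it ID below is only an illustration, not part of this commit):

import os

# No override set: the fallback baked into the source wins.
print(os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it"))  # google/gemma-2-2b-it

# Exported before launch (e.g. `LLM_MODEL_ID=... python app.py`),
# the environment value takes precedence over the fallback.
os.environ["LLM_MODEL_ID"] = "google/gemma-2-9b-it"  # simulate the export
print(os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it"))  # google/gemma-2-9b-it
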
@@ -191,18 +191,21 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
         "\n\nRespond with a single short clinician-style question for the patient."
     )
     pipe = get_textgen_pipeline()
-    out = pipe(
-        f"<|system|>\n{system_prompt}\n<|user|>\n{user_prompt}\n<|assistant|>",
-        max_new_tokens=128,
+    tokenizer = pipe.tokenizer
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt},
+    ]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    gen = pipe(
+        prompt,
+        max_new_tokens=96,
         temperature=0.7,
         do_sample=True,
-        pad_token_id=pipe.tokenizer.eos_token_id,
-    )[0]["generated_text"]
-
-    # Extract assistant content after the last assistant tag if present
-    reply = out.split("<|assistant|>")[-1].strip()
-    # Post-process to avoid trailing special tokens
-    reply = re.split(r"</s>|<\|endoftext\|>", reply)[0].strip()
+        pad_token_id=tokenizer.eos_token_id,
+        return_full_text=False,
+    )
+    reply = gen[0]["generated_text"].strip()
     # Ensure it's a single concise question/sentence
     if len(reply) > 300:
         reply = reply[:300].rstrip() + "…"
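
Note: this hunk replaces the hand-rolled TinyLlama/Zephyr-style <|system|>/<|user|>/<|assistant|> prompt with tokenizer.apply_chat_template, which renders whatever chat format the loaded tokenizer defines, and adds return_full_text=False, so the pipeline returns only the completion and the old split-on-<|assistant|> and special-token stripping go away. One hedged caveat: Gemma-2's chat template, as shipped, raises on a separate "system" role; if that happens here, the usual workaround is to fold the system prompt into the first user turn. A minimal self-contained sketch under that assumption (model ID and prompts are illustrative):

from transformers import pipeline

pipe = pipeline("text-generation", model="google/gemma-2-2b-it")
tokenizer = pipe.tokenizer

system_prompt = "You are a concise, empathetic clinical interviewer."
user_prompt = "Ask the patient one short follow-up question."

# Fold the system prompt into the user turn, since Gemma's shipped
# template rejects a separate "system" role.
messages = [{"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

gen = pipe(
    prompt,
    max_new_tokens=96,
    temperature=0.7,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    return_full_text=False,  # return only the completion, not prompt + completion
)
reply = gen[0]["generated_text"].strip()
print(reply)
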
@@ -227,14 +230,22 @@ def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str,
         "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
     )
     pipe = get_textgen_pipeline()
-    out = pipe(
-        f"<|system|>\n{system_prompt}\n<|user|>\n{user_prompt}\n<|assistant|>",
+    tokenizer = pipe.tokenizer
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt},
+    ]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    gen = pipe(
+        prompt,
         max_new_tokens=256,
         temperature=0.2,
         do_sample=True,
-        pad_token_id=pipe.tokenizer.eos_token_id,
-    )[0]["generated_text"]
-    parsed = safe_json_extract(out)
+        pad_token_id=tokenizer.eos_token_id,
+        return_full_text=False,
+    )
+    out_text = gen[0]["generated_text"]
+    parsed = safe_json_extract(out_text)
 
     # Validate and coerce
     if parsed is None or "PHQ9_Scores" not in parsed:
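
Note: safe_json_extract is called here but not defined in this diff. A hypothetical sketch of such a helper, inferred only from the call site (it must return a dict on success and None on failure; the real implementation in app.py may differ):

import json
from typing import Dict, Optional

def safe_json_extract(text: str) -> Optional[Dict]:
    """Best-effort extraction of a JSON object from a model completion."""
    # Try the raw completion first, in case the model obeyed "Return ONLY JSON".
    # Otherwise fall back to the widest {...} span, which survives code fences
    # or stray prose around the object.
    candidates = [text.strip()]
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None

# safe_json_extract('```json\n{"PHQ9_Scores": [1, 0, 2], "High_Risk": false}\n```')
# -> {'PHQ9_Scores': [1, 0, 2], 'High_Risk': False}
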
 