Akis Giannoukos committed
Commit 9325a21 · 1 Parent(s): 44521ed
Using gemma-2 model
app.py CHANGED
@@ -25,7 +25,7 @@ import spaces
 # ---------------------------
 # Configuration
 # ---------------------------
-DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "
+DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
 DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
 CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
 MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
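The new default, google/gemma-2-2b-it, is consumed by get_textgen_pipeline(), which is not part of this diff. A minimal sketch of what such a helper could look like with the transformers pipeline API; the caching, dtype, and device placement below are assumptions, not the app's actual implementation:

import os
from functools import lru_cache

import torch
from transformers import pipeline

DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")

@lru_cache(maxsize=1)
def get_textgen_pipeline():
    # Load the chat model once and reuse it across calls.
    return pipeline(
        "text-generation",
        model=DEFAULT_CHAT_MODEL_ID,
        torch_dtype=torch.bfloat16,  # assumption: half precision to fit small GPUs
        device_map="auto",           # assumption: let accelerate place the model
    )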
@@ -191,18 +191,21 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
         "\n\nRespond with a single short clinician-style question for the patient."
     )
     pipe = get_textgen_pipeline()
-
-
-
+    tokenizer = pipe.tokenizer
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt},
+    ]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    gen = pipe(
+        prompt,
+        max_new_tokens=96,
         temperature=0.7,
         do_sample=True,
-        pad_token_id=
-
-
-
-    reply = out.split("<|assistant|>")[-1].strip()
-    # Post-process to avoid trailing special tokens
-    reply = re.split(r"</s>|<\|endoftext\|>", reply)[0].strip()
+        pad_token_id=tokenizer.eos_token_id,
+        return_full_text=False,
+    )
+    reply = gen[0]["generated_text"].strip()
     # Ensure it's a single concise question/sentence
     if len(reply) > 300:
         reply = reply[:300].rstrip() + "…"
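Both agents now render their prompts through the tokenizer's chat template instead of hand-built special tokens. A self-contained sketch of that call path under the same settings; the model id and message content here are illustrative only. The behavior relied on is that return_full_text=False makes the text-generation pipeline return just the newly generated continuation, which is why the old split on "<|assistant|>" and the trailing-token cleanup are gone:

from transformers import pipeline

# Illustrative model id; the app reads it from LLM_MODEL_ID.
pipe = pipeline("text-generation", model="google/gemma-2-2b-it")
tokenizer = pipe.tokenizer

messages = [
    {"role": "user", "content": "Ask me one short screening question."},
]
# Render the chat into the model's own prompt format and append the
# generation prefix for the assistant turn.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

gen = pipe(
    prompt,
    max_new_tokens=96,
    temperature=0.7,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    return_full_text=False,  # return only the generated continuation, not the prompt
)
reply = gen[0]["generated_text"].strip()
print(reply)

Depending on the chat template shipped with a given Gemma checkpoint, a separate system role may be rejected by apply_chat_template; if so, folding the system text into the first user message is a common workaround.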
@@ -227,14 +230,22 @@ def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str,
         "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
     )
     pipe = get_textgen_pipeline()
-
-
+    tokenizer = pipe.tokenizer
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt},
+    ]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    gen = pipe(
+        prompt,
         max_new_tokens=256,
         temperature=0.2,
         do_sample=True,
-        pad_token_id=
-
-
+        pad_token_id=tokenizer.eos_token_id,
+        return_full_text=False,
+    )
+    out_text = gen[0]["generated_text"]
+    parsed = safe_json_extract(out_text)
 
     # Validate and coerce
     if parsed is None or "PHQ9_Scores" not in parsed:
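The scoring path now hands the raw generation to safe_json_extract(), which this diff does not show. A hypothetical sketch of such a helper, assuming it pulls the first JSON object out of the model output and returns None when parsing fails (the name and behavior are inferred from how it is called above):

import json
import re
from typing import Any, Dict, Optional

def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object embedded in model output."""
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if not match:
        return None
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None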