"""Gradio app for a conversational PHQ-9 style assessment.

A text-generation model plays two roles: a "recording agent" that asks the
next clinician-style question and a "scoring agent" that turns the running
transcript (plus coarse acoustic features) into PHQ-9 item scores. The
conversation ends on high confidence, on a turn limit, or on any suicidality
signal.
"""

import json
import logging
import math
import os
import re
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np

# Audio processing
import soundfile as sf
import librosa

# Models
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
from gtts import gTTS
import spaces

logger = logging.getLogger(__name__)

# ---------------------------
# Configuration
# ---------------------------
DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
USE_TTS_DEFAULT = os.getenv("USE_TTS", "false").strip().lower() == "true"

# Order matters: the "Confidences" list is aligned with these keys.
PHQ9_ITEM_KEYS = (
    "interest",
    "mood",
    "sleep",
    "energy",
    "appetite",
    "self_worth",
    "concentration",
    "motor",
    "suicidal_thoughts",
)

# Chat history as exchanged with gr.Chatbot(type="tuples"): (user, assistant)
# pairs where either side may be empty/None.
ChatHistory = List[Tuple[Optional[str], Optional[str]]]

# Compiled once; matched against lower-cased text.
_SUICIDALITY_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r"\bkill myself\b",
        r"\bend my life\b",
        r"\bend it all\b",
        r"\bcommit suicide\b",
        r"\bsuicide\b",
        r"\bself[-\s]?harm\b",
        r"\bhurt myself\b",
        r"\bno reason to live\b",
        r"\bwant to die\b",
        r"\bending it\b",
    )
)

# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
_asr_pipe = None
_gen_pipe = None
_tokenizer = None  # Unused; retained so the module namespace is unchanged.


def _hf_device() -> int:
    """Return the pipeline device index: 0 when CUDA is available, else -1."""
    return 0 if torch.cuda.is_available() else -1


def get_asr_pipeline():
    """Return the shared speech-recognition pipeline, creating it on first use."""
    global _asr_pipe
    if _asr_pipe is None:
        _asr_pipe = pipeline(
            "automatic-speech-recognition",
            model=DEFAULT_ASR_MODEL_ID,
            device=_hf_device(),
        )
    return _asr_pipe


def get_textgen_pipeline():
    """Return the shared text-generation pipeline, creating it on first use."""
    global _gen_pipe
    if _gen_pipe is None:
        # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
        _gen_pipe = pipeline(
            task="text-generation",
            model=DEFAULT_CHAT_MODEL_ID,
            tokenizer=DEFAULT_CHAT_MODEL_ID,
            device=_hf_device(),
            torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
        )
    return _gen_pipe


# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
    """Extract the first JSON object embedded in ``text``.

    The whole string is tried first; otherwise every ``{`` is tried as the
    start of an object, so surrounding prose, code fences or a second brace
    group do not prevent extraction.

    Returns:
        The decoded object, or ``None`` when ``text`` is empty, is not a
        string, or contains no decodable JSON object. Non-object JSON values
        (lists, numbers, ...) are never returned.
    """
    if not text or not isinstance(text, str):
        return None

    try:
        whole = json.loads(text)
    except ValueError:
        whole = None
    if isinstance(whole, dict):
        return whole

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None


def compute_audio_features(audio_path: str) -> Dict[str, float]:
    """Compute lightweight prosodic features as a proxy for OpenSMILE.

    The audio is loaded as 16 kHz mono and summarised with mean/std of RMS
    energy, zero-crossing rate and spectral centroid, a voiced-frame ratio,
    the duration, and (when pitch tracking succeeds) median/IQR of f0.

    Returns:
        A dictionary of summary statistics, or ``{}`` when the file is empty
        or any step fails (feature extraction is best-effort).
    """
    try:
        y, sr = librosa.load(audio_path, sr=16000, mono=True)
        if y.size == 0:
            return {}

        # Frame-based features
        hop_length = 512
        frame_length = 1024
        rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        zcr = librosa.feature.zero_crossing_rate(
            y, frame_length=frame_length, hop_length=hop_length
        )[0]
        centroid = librosa.feature.spectral_centroid(
            y=y, sr=sr, n_fft=2048, hop_length=hop_length
        )[0]

        # Pitch estimation (coarse); optional, so failures only drop the f0 stats.
        f0 = None
        try:
            f0 = librosa.yin(
                y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length
            )
            f0 = f0[np.isfinite(f0)]
        except Exception:
            logger.debug("Pitch estimation failed for %s", audio_path, exc_info=True)
            f0 = None

        # Rough speaking-activity proxy: fraction of frames whose energy
        # exceeds 1.2x the median frame energy.
        voiced = rms > (np.median(rms) * 1.2)
        voiced_ratio = float(np.mean(voiced))

        features = {
            "rms_mean": float(np.mean(rms)),
            "rms_std": float(np.std(rms)),
            "zcr_mean": float(np.mean(zcr)),
            "zcr_std": float(np.std(zcr)),
            "centroid_mean": float(np.mean(centroid)),
            "centroid_std": float(np.std(centroid)),
            "voiced_ratio": voiced_ratio,
            "duration_sec": float(len(y) / sr),
        }
        if f0 is not None and f0.size > 0:
            features.update({
                "f0_median": float(np.median(f0)),
                "f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
            })
        return features
    except Exception:
        logger.exception("Audio feature extraction failed for %s", audio_path)
        return {}


def detect_explicit_suicidality(text: Optional[str]) -> bool:
    """Return True when ``text`` contains an explicit self-harm/suicide phrase.

    Matching is case-insensitive (the text is lower-cased first) and purely
    keyword based; empty or ``None`` input yields False.
    """
    if not text:
        return False
    lowered = text.lower()
    return any(pattern.search(lowered) for pattern in _SUICIDALITY_PATTERNS)


def synthesize_tts(text: Optional[str]) -> Optional[str]:
    """Render ``text`` to an MP3 file with gTTS and return its path.

    Returns ``None`` when ``text`` is empty or synthesis fails (TTS is an
    optional extra, so errors are logged rather than raised).
    """
    if not text:
        return None

    # A unique temp file avoids collisions between concurrent requests and
    # works on platforms without /tmp.
    fd, out_path = tempfile.mkstemp(prefix="tts_", suffix=".mp3")
    os.close(fd)
    try:
        gTTS(text=text, lang="en").save(out_path)
        return out_path
    except Exception:
        logger.warning("TTS synthesis failed", exc_info=True)
        try:
            os.remove(out_path)
        except OSError:
            pass
        return None


def severity_from_total(total_score: int) -> str:
    """Map a PHQ-9 total score to its severity band label."""
    if total_score <= 4:
        return "Minimal Depression"
    if total_score <= 9:
        return "Mild Depression"
    if total_score <= 14:
        return "Moderate Depression"
    if total_score <= 19:
        return "Moderately Severe Depression"
    return "Severe Depression"


def transcript_to_text(chat_history: ChatHistory) -> str:
    """Convert chatbot history [(user, assistant), ...] to a plain text transcript.

    Empty/None sides of a pair are skipped; within a pair the patient line
    precedes the clinician line.
    """
    lines = []
    for user, assistant in chat_history:
        if user:
            lines.append(f"Patient: {user}")
        if assistant:
            lines.append(f"Clinician: {assistant}")
    return "\n".join(lines)


def _patient_messages(chat_history: ChatHistory) -> List[str]:
    """Return the non-empty patient (user) messages in conversation order."""
    return [user for user, _assistant in chat_history if user]


def _generate_chat_completion(
    system_prompt: str,
    user_prompt: str,
    *,
    max_new_tokens: int,
    temperature: float,
) -> str:
    """Run one sampled chat completion and return only the newly generated text.

    The system and user prompts are merged into a single user message.
    NOTE(review): presumably because the default model's chat template has no
    system role - confirm before splitting them into separate roles.
    """
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    messages = [
        {"role": "user", "content": system_prompt + "\n\n" + user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    return outputs[0]["generated_text"]


def generate_recording_agent_reply(chat_history: ChatHistory) -> str:
    """Generate the clinician's next follow-up question for the conversation.

    Returns:
        The model's reply, stripped and truncated to 300 characters (with an
        ellipsis). A generic follow-up question is returned when the model
        produces only whitespace.
    """
    transcript = transcript_to_text(chat_history)
    system_prompt = (
        "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
        "without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
        "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
        "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
    )
    user_prompt = (
        "Conversation so far (Patient and Clinician turns):\n\n"
        + transcript
        + "\n\nRespond with a single short clinician-style question for the patient."
    )
    reply = _generate_chat_completion(
        system_prompt, user_prompt, max_new_tokens=96, temperature=0.7
    ).strip()

    if not reply:
        # Never show an empty clinician bubble.
        return "Could you tell me a bit more about how that has been affecting you day to day?"
    # Ensure it's a single concise question/sentence
    if len(reply) > 300:
        reply = reply[:300].rstrip() + "…"
    return reply


def _coerce_item_score(value: Any) -> int:
    """Coerce a model-provided item score to an int in [0, 3]; junk becomes 0."""
    try:
        number = int(float(value))
    except (TypeError, ValueError, OverflowError):
        return 0
    return max(0, min(3, number))


def _coerce_confidence(value: Any, default: float) -> float:
    """Coerce a model-provided confidence to a float in [0, 1], else ``default``."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    if math.isnan(number):
        return default
    return max(0.0, min(1.0, number))


def _neutral_assessment(high_risk: bool) -> Dict[str, Any]:
    """Build the low-confidence fallback assessment used when the LLM output is unusable."""
    scores = {key: 1 for key in PHQ9_ITEM_KEYS}
    scores["suicidal_thoughts"] = 0
    confidences = [0.5] * len(PHQ9_ITEM_KEYS)
    total = sum(scores.values())
    return {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity_from_total(total),
        "Confidence": min(confidences),
        "High_Risk": high_risk,
    }


def scoring_agent_infer(chat_history: ChatHistory, features: Dict[str, float]) -> Dict[str, Any]:
    """Ask the LLM to produce PHQ-9 scores and confidences as JSON.

    The model output is validated item by item: scores are clamped to 0-3,
    confidences to 0-1, and the total is always recomputed from the nine
    known items. If no usable JSON object with a ``PHQ9_Scores`` mapping is
    found, a neutral low-confidence assessment is returned instead.

    Args:
        chat_history: Conversation so far as (patient, clinician) pairs.
        features: Acoustic summary statistics included in the prompt.

    Returns:
        A dict with keys ``PHQ9_Scores``, ``Confidences``, ``Total_Score``,
        ``Severity``, ``Confidence`` and ``High_Risk``. ``High_Risk`` is True
        when any patient message contains explicit suicidal language or the
        ``suicidal_thoughts`` item scores 2 or higher.
    """
    transcript = transcript_to_text(chat_history)
    features_json = json.dumps(features, ensure_ascii=False)
    system_prompt = (
        "You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
        "Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
        "Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
        "and High_Risk (boolean, true if any suicidal risk)."
    )
    user_prompt = (
        "Conversation transcript:"
        f"\n{transcript}\n\n"
        f"Acoustic features summary (approximate):\n{features_json}\n\n"
        "Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
        "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
    )
    out_text = _generate_chat_completion(
        system_prompt, user_prompt, max_new_tokens=256, temperature=0.2
    )
    parsed = safe_json_extract(out_text)

    # Only the patient's own words count as explicit risk language; the
    # clinician's questions may legitimately mention these terms.
    explicit_flag = any(
        detect_explicit_suicidality(message) for message in _patient_messages(chat_history)
    )

    raw_scores = parsed.get("PHQ9_Scores") if parsed is not None else None
    if not isinstance(raw_scores, dict):
        logger.warning("Scoring agent returned no usable PHQ9_Scores; using neutral fallback")
        # The safety flag must survive an unparseable model response.
        return _neutral_assessment(high_risk=explicit_flag)

    # Rebuild from the known keys so stray keys cannot inflate the total.
    scores = {key: _coerce_item_score(raw_scores.get(key, 0)) for key in PHQ9_ITEM_KEYS}

    raw_confidences = parsed.get("Confidences")
    if isinstance(raw_confidences, list) and len(raw_confidences) == len(PHQ9_ITEM_KEYS):
        confidences = [_coerce_confidence(value, 0.5) for value in raw_confidences]
    else:
        fill = _coerce_confidence(parsed.get("Confidence"), 0.5)
        confidences = [fill] * len(PHQ9_ITEM_KEYS)

    total = sum(scores.values())
    severity = parsed.get("Severity")
    if not isinstance(severity, str) or not severity.strip():
        severity = severity_from_total(total)
    overall_conf = _coerce_confidence(parsed.get("Confidence"), min(confidences))

    # Conservative high-risk detection: require explicit language or high suicidal_thoughts score
    high_risk = bool(explicit_flag or scores["suicidal_thoughts"] >= 2)

    return {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity,
        "Confidence": overall_conf,
        "High_Risk": high_risk,
    }


def transcribe_audio(audio_path: Optional[str]) -> str:
    """Transcribe the audio file at ``audio_path``.

    Returns the stripped transcript, or ``""`` when no path is given, the
    pipeline result has no text, or transcription fails (failure is logged).
    """
    if not audio_path:
        return ""
    try:
        result = get_asr_pipeline()(audio_path)
        if isinstance(result, dict) and "text" in result:
            return result["text"].strip()
        if isinstance(result, list) and len(result) > 0 and "text" in result[0]:
            return result[0]["text"].strip()
    except Exception:
        logger.exception("Transcription failed for %s", audio_path)
    return ""


# ---------------------------
# Gradio app logic
# ---------------------------
INTRO_MESSAGE = (
    "Hello, I'm here to check in on how you've been feeling lately. "
    "To start, can you share how your mood has been over the past couple of weeks?"
)


def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
    """Return the initial (chat_history, scores, meta, finished, turns) state."""
    chat_history: List[Tuple[str, str]] = [("", INTRO_MESSAGE)]
    scores: Dict[str, Any] = {}
    meta = {"Severity": None, "Total_Score": None, "Confidence": 0.0}
    finished = False
    turns = 0
    return chat_history, scores, meta, finished, turns


@spaces.GPU
def process_turn(
    audio_path: Optional[str],
    text_input: Optional[str],
    chat_history: List[Tuple[str, str]],
    threshold: float,
    tts_enabled: bool,
    finished: bool,
    turns: int,
    prev_scores: Dict[str, Any],
    prev_meta: Dict[str, Any],
):
    """Handle one patient turn: transcribe, score, and produce the next reply.

    Args:
        audio_path: Path to the recorded answer, if any.
        text_input: Typed answer, if any; combined with the transcript when
            both are given.
        chat_history: Conversation so far as (patient, clinician) pairs.
        threshold: Stop once the minimum per-item confidence reaches this.
        tts_enabled: Whether to synthesize the clinician message to audio.
        finished: Whether the assessment has already ended.
        turns: Number of scored patient turns so far.
        prev_scores: Previous scores state. Accepted for interface
            compatibility; not read, because this handler's outputs never
            write it back.
        prev_meta: Previous meta state. Accepted for interface compatibility;
            not read, for the same reason.

    Returns:
        An 8-tuple matching the click handler's outputs: (chat history,
        score JSON, severity label, finished, turns, audio input value,
        text input value, TTS audio path). The two input values are always
        ``None`` to clear the widgets. When nothing was scored this turn the
        JSON and label slots are ``gr.update()`` so the current display is
        left unchanged.
    """
    # State may still be unset if the load event has not populated it yet.
    chat_history = list(chat_history or [])
    turns = int(turns or 0)
    if threshold is None:
        threshold = CONFIDENCE_THRESHOLD_DEFAULT

    # If already finished, do nothing (and keep the final assessment visible).
    if finished:
        return (
            chat_history,
            gr.update(),
            gr.update(),
            finished,
            turns,
            None,
            None,
            None,
        )

    patient_text = (text_input or "").strip()
    audio_features: Dict[str, float] = {}
    if audio_path:
        # Transcribe first
        transcribed = transcribe_audio(audio_path)
        if transcribed:
            patient_text = f"{patient_text} {transcribed}" if patient_text else transcribed
        # Extract features
        audio_features = compute_audio_features(audio_path)

    if not patient_text:
        # Ask user for input
        chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
        return (
            chat_history,
            gr.update(),
            gr.update(),
            finished,
            turns,
            None,
            None,
            None,
        )

    # Add patient's message; the clinician side is filled in below.
    chat_history.append((patient_text, None))

    # Scoring agent
    scoring = scoring_agent_infer(chat_history, audio_features)
    scores = scoring.get("PHQ9_Scores", {})
    confidences = scoring.get("Confidences", [])
    total = scoring.get("Total_Score", 0)
    severity = scoring.get("Severity", severity_from_total(total))
    overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))

    # Override high-risk to reduce false positives: rely on explicit text in
    # the message just received or a high item score.
    explicit_flag = detect_explicit_suicidality(patient_text)
    high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))

    # Termination conditions
    min_conf = float(min(confidences)) if confidences else 0.0
    turns += 1
    done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)

    if high_risk:
        clinician_message = (
            "I’m concerned about your safety based on what you shared. "
            "If you are in danger or need immediate help, please call 988 in the U.S. or your local emergency number. "
            "I'll end the assessment now and display emergency resources."
        )
        finished = True
    elif done:
        clinician_message = (
            f"Thank you for sharing. Based on our conversation, your responses suggest {severity.lower()}. "
            "We can stop here."
        )
        finished = True
    else:
        # Generate next clinician question
        clinician_message = generate_recording_agent_reply(chat_history)
    chat_history[-1] = (patient_text, clinician_message)

    # TTS for the latest clinician message, if enabled
    tts_path = synthesize_tts(clinician_message) if tts_enabled else None

    # Build a compact JSON for display
    display_json = {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity,
        "Confidence": overall_conf,
        "High_Risk": high_risk,
    }

    # Clear inputs after processing
    return (
        chat_history,
        display_json,
        severity,
        finished,
        turns,
        None,
        None,
        tts_path,
    )


def reset_app():
    """Return a fresh initial state for the Reset button."""
    return init_state()


def _clear_displays():
    """Return empty values for the score, severity, audio, text and TTS widgets."""
    return None, None, None, None, None


# ---------------------------
# UI
# ---------------------------
def _on_load_init():
    """Return the initial state when the page loads."""
    return init_state()


def create_demo():
    """Build the Gradio Blocks UI and wire its events; returns the Blocks app."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            ### PHQ-9 Conversational Clinician Agent
            Engage in a brief, empathetic conversation. Your audio is transcribed, analyzed, and used to infer PHQ-9 scores.
            The system stops when confidence is high enough or any safety risk is detected. It does not provide therapy or emergency counseling.
            """
        )
        with gr.Row():
            chatbot = gr.Chatbot(height=400, type="tuples")
            with gr.Column():
                score_json = gr.JSON(label="PHQ-9 Assessment (live)")
                severity_label = gr.Label(label="Severity")
                threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
                tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
                tts_audio = gr.Audio(label="Clinician voice", interactive=False)
        with gr.Row():
            audio = gr.Audio(sources=["microphone"], type="filepath", label="Speak your response (or use text)")
            text = gr.Textbox(lines=2, placeholder="Optional: type your response instead of audio")
        with gr.Row():
            send_btn = gr.Button("Send")
            reset_btn = gr.Button("Reset")

        # App state. chat_state is not wired to any event (the chatbot
        # component itself carries the history); it is kept so the component
        # graph is unchanged.
        chat_state = gr.State()
        scores_state = gr.State()
        meta_state = gr.State()
        finished_state = gr.State()
        turns_state = gr.State()

        state_outputs = [chatbot, scores_state, meta_state, finished_state, turns_state]

        # Initialize on load (top-level function to be pickle-safe under ZeroGPU)
        demo.load(_on_load_init, inputs=None, outputs=state_outputs)

        # Wire interactions
        send_btn.click(
            fn=process_turn,
            inputs=[audio, text, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio, text, tts_audio],
            queue=True,
            api_name="message",
        )
        # Reset the state, then clear the widgets that would otherwise keep
        # showing the previous assessment next to a fresh conversation.
        reset_btn.click(
            fn=reset_app, inputs=None, outputs=state_outputs
        ).then(
            fn=_clear_displays,
            inputs=None,
            outputs=[score_json, severity_label, audio, text, tts_audio],
        )

    return demo


demo = create_demo()

if __name__ == "__main__":
    # For local dev
    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))