import os

# Disable torch compile/dynamo globally to avoid cudagraph assertion errors
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import json
import re
import time
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np

# Audio processing
import soundfile as sf
import librosa

# Models
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
from gtts import gTTS
import spaces
import threading

# ---------------------------
# Configuration
# ---------------------------
DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
USE_TTS_DEFAULT = os.getenv("USE_TTS", "true").strip().lower() == "true"
CONFIG_PATH = os.getenv("MODEL_CONFIG_PATH", "model_config.json")
DEBUG_MODE = os.getenv("DEBUG", "false").strip().lower() in ("1", "true", "yes", "on")


def _load_model_id_from_config() -> str:
    try:
        if os.path.exists(CONFIG_PATH):
            with open(CONFIG_PATH, "r") as f:
                data = json.load(f)
                if isinstance(data, dict) and data.get("model_id"):
                    return str(data["model_id"])
    except Exception:
        pass
    return DEFAULT_CHAT_MODEL_ID


current_model_id = _load_model_id_from_config()

# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
_asr_pipe = None
_gen_pipe = None
_tokenizer = None


def _hf_device() -> int:
    return 0 if torch.cuda.is_available() else -1


def get_asr_pipeline():
    global _asr_pipe
    if _asr_pipe is None:
        _asr_pipe = pipeline(
            "automatic-speech-recognition",
            model=DEFAULT_ASR_MODEL_ID,
            device=_hf_device(),
        )
    return _asr_pipe


def get_textgen_pipeline():
    global _gen_pipe
    if _gen_pipe is None:
        # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
        if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
            _dtype = torch.bfloat16
        elif torch.cuda.is_available():
            _dtype = torch.float16
        else:
            _dtype = torch.float32
        _gen_pipe = pipeline(
            task="text-generation",
            model=current_model_id,
            tokenizer=current_model_id,
            device=_hf_device(),
            torch_dtype=_dtype,
        )
    return _gen_pipe


def set_current_model_id(new_model_id: str) -> str:
    global current_model_id, _gen_pipe
    new_model_id = (new_model_id or "").strip()
    if not new_model_id:
        return "Model id is empty; keeping current model."
    if new_model_id == current_model_id:
        return f"Model unchanged: `{current_model_id}`"
    current_model_id = new_model_id
    _gen_pipe = None  # force reload on next use
    return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."


def persist_model_id(new_model_id: str) -> None:
    try:
        with open(CONFIG_PATH, "w") as f:
            json.dump({"model_id": new_model_id}, f)
    except Exception:
        pass


def apply_model_and_restart(new_model_id: str) -> str:
    mid = (new_model_id or "").strip()
    if not mid:
        return "Model id is empty; not restarting."
    persist_model_id(mid)
    set_current_model_id(mid)

    # Graceful delayed exit so the response can flush before the process restarts
    def _exit_later():
        time.sleep(0.25)
        os._exit(0)

    threading.Thread(target=_exit_later, daemon=True).start()
    return f"Restarting with model `{mid}`..."
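
# Usage sketch (hypothetical values; not executed at import time): the text-generation
# pipeline is a lazy singleton, so switching models only marks it for reload, and the
# next call to get_textgen_pipeline() rebuilds it with the new id.
# >>> set_current_model_id("google/medgemma-4b-it")
# 'Model switched to `google/medgemma-4b-it` (pipeline will reload on next generation).'
# >>> apply_model_and_restart("google/medgemma-4b-it")  # persists to CONFIG_PATH, then exits the process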

# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
    """Extract first JSON object from text."""
    if not text:
        return None
    try:
        return json.loads(text)
    except Exception:
        pass
    # Fallback: find the first {...} block
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group(0))
        except Exception:
            return None
    return None


def compute_audio_features(audio_path: str) -> Dict[str, float]:
    """Compute lightweight prosodic features as a proxy for OpenSMILE.

    Returns a dictionary with summary statistics.
    """
    try:
        y, sr = librosa.load(audio_path, sr=16000, mono=True)
        if len(y) == 0:
            return {}
        # Frame-based features
        hop_length = 512
        frame_length = 1024
        rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
        centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]

        # Pitch estimation (coarse)
        f0 = None
        try:
            f0 = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
            f0 = f0[np.isfinite(f0)]
        except Exception:
            f0 = None

        # Speaking rate rough proxy: voiced ratio per second
        energy = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        voiced = energy > (np.median(energy) * 1.2)
        voiced_ratio = float(np.mean(voiced))

        features = {
            "rms_mean": float(np.mean(rms)),
            "rms_std": float(np.std(rms)),
            "zcr_mean": float(np.mean(zcr)),
            "zcr_std": float(np.std(zcr)),
            "centroid_mean": float(np.mean(centroid)),
            "centroid_std": float(np.std(centroid)),
            "voiced_ratio": voiced_ratio,
            "duration_sec": float(len(y) / sr),
        }
        if f0 is not None and f0.size > 0:
            features.update({
                "f0_median": float(np.median(f0)),
                "f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
            })
        return features
    except Exception:
        return {}


def detect_explicit_suicidality(text: Optional[str]) -> bool:
    if not text:
        return False
    t = text.lower()
    patterns = [
        r"\bkill myself\b",
        r"\bend my life\b",
        r"\bend it all\b",
        r"\bcommit suicide\b",
        r"\bsuicide\b",
        r"\bself[-\s]?harm\b",
        r"\bhurt myself\b",
        r"\bno reason to live\b",
        r"\bwant to die\b",
        r"\bending it\b",
    ]
    for pat in patterns:
        if re.search(pat, t):
            return True
    return False


def synthesize_tts(
    text: Optional[str],
    provider: str = "Coqui",
    coqui_model_name: Optional[str] = None,
    coqui_speaker: Optional[str] = None,
) -> Optional[str]:
    if not text:
        return None
    ts = int(time.time() * 1000)
    provider_norm = (provider or "Coqui").strip().lower()

    # Try Coqui first if requested
    if provider_norm == "coqui":
        try:
            # coqui-tts uses the same import path TTS.api
            from TTS.api import TTS as CoquiTTS  # type: ignore

            model_name = (coqui_model_name or os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")).strip()
            engine = CoquiTTS(model_name=model_name, progress_bar=False)
            out_path_wav = f"/tmp/tts_{ts}.wav"
            kwargs = {}
            if coqui_speaker:
                kwargs["speaker"] = coqui_speaker
            engine.tts_to_file(text=text, file_path=out_path_wav, **kwargs)
            return out_path_wav
        except Exception:
            pass

    # Fallback to gTTS
    try:
        out_path = f"/tmp/tts_{ts}.mp3"
        tts = gTTS(text=text, lang="en")
        tts.save(out_path)
        return out_path
    except Exception:
        return None


def list_coqui_speakers(model_name: str) -> List[str]:
    try:
        from TTS.api import TTS as CoquiTTS  # type: ignore

        engine = CoquiTTS(model_name=model_name, progress_bar=False)
        # Try common attributes
        if hasattr(engine, "speakers") and isinstance(engine.speakers, list):
            return [str(s) for s in engine.speakers]
        if hasattr(engine, "speaker_manager") and hasattr(engine.speaker_manager, "speaker_names"):
            return list(engine.speaker_manager.speaker_names)
    except Exception:
        pass
    # Reasonable defaults for VCTK multi-speaker
    return ["p225", "p227", "p231", "p233", "p236"]


def severity_from_total(total_score: int) -> str:
    if total_score <= 4:
        return "Minimal Depression"
    if total_score <= 9:
        return "Mild Depression"
    if total_score <= 14:
        return "Moderate Depression"
    if total_score <= 19:
        return "Moderately Severe Depression"
    return "Severe Depression"


# ---------------------------
# PHQ-9 schema and helpers
# ---------------------------
PHQ9_KEYS_ORDERED: List[str] = [
    "interest",
    "mood",
    "sleep",
    "energy",
    "appetite",
    "self_worth",
    "concentration",
    "motor",
    "suicidal_thoughts",
]

# Lightweight keyword lexicon per item for evidence extraction.
# Placeholder for future SHAP/attention-based attributions.
PHQ9_KEYWORDS: Dict[str, List[str]] = {
    "interest": ["interest", "pleasure", "enjoy", "motivation", "hobbies"],
    "mood": ["depressed", "down", "sad", "hopeless", "blue", "mood"],
    "sleep": ["sleep", "insomnia", "awake", "wake up", "night", "dream"],
    "energy": ["tired", "fatigue", "energy", "exhausted", "worn out"],
    "appetite": ["appetite", "eat", "eating", "hungry", "food", "weight"],
    "self_worth": ["worthless", "failure", "guilty", "guilt", "self-esteem", "ashamed"],
    "concentration": ["concentrate", "focus", "attention", "distracted", "remember"],
    "motor": ["restless", "slow", "slowed", "agitated", "fidget", "move"],
    "suicidal_thoughts": ["suicide", "kill myself", "die", "end my life", "self-harm", "hurt myself"],
}


def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
    """Convert chatbot history [(user, assistant), ...] to a plain text transcript."""
    lines = []
    for user, assistant in chat_history:
        if user:
            lines.append(f"Patient: {user}")
        if assistant:
            lines.append(f"Clinician: {assistant}")
    return "\n".join(lines)


def _patient_sentences(chat_history: List[Tuple[str, str]]) -> List[str]:
    """Extract patient-only sentences from chat history."""
    sentences: List[str] = []
    for user, _assistant in chat_history:
        if not user:
            continue
        parts = re.split(r"(?<=[.!?])\s+", user.strip())
        for p in parts:
            p = p.strip()
            if p:
                sentences.append(p)
    return sentences


def _extract_quotes_per_item(chat_history: List[Tuple[str, str]]) -> Dict[str, List[str]]:
    """Heuristic extraction of per-item evidence quotes from patient sentences based on keywords."""
    quotes: Dict[str, List[str]] = {k: [] for k in PHQ9_KEYS_ORDERED}
    sentences = _patient_sentences(chat_history)
    for sent in sentences:
        s_low = sent.lower()
        for item, kws in PHQ9_KEYWORDS.items():
            if any(kw in s_low for kw in kws):
                if len(quotes[item]) < 5:
                    quotes[item].append(sent)
    return quotes


def _extract_quotes_with_turns(chat_history: List[Tuple[str, str]]) -> Dict[str, List[Dict[str, Any]]]:
    """Like _extract_quotes_per_item, but includes the patient turn index for each quote.

    Turn index here is 1-based and corresponds to the internal `turns` counter in the loop.
    """
    quotes: Dict[str, List[Dict[str, Any]]] = {k: [] for k in PHQ9_KEYS_ORDERED}
    turn_idx = 0
    for user, _assistant in chat_history:
        if not user:
            continue
        turn_idx += 1
        parts = re.split(r"(?<=[.!?])\s+", user.strip())
        for p in parts:
            sent = p.strip()
            if not sent:
                continue
            s_low = sent.lower()
            for item, kws in PHQ9_KEYWORDS.items():
                if any(kw in s_low for kw in kws):
                    if len(quotes[item]) < 5:
                        quotes[item].append({"quote": sent, "turn": turn_idx})
    return quotes


def explainability_light(
    chat_history: List[Tuple[str, str]],
    scores: Dict[str, int],
    confidences: List[float],
    threshold: float,
) -> Dict[str, Any]:
    """Lightweight explainability per turn.

    - Inspects transcript for keyword-based evidence per PHQ-9 item.
    - Classifies evidence strength as strong/weak/missing using keyword hits and confidence.
    - Suggests next focus item based on lowest-confidence or missing evidence.

    Returns a JSON-serializable dict.
    """
    quotes = _extract_quotes_per_item(chat_history)
    conf_map: Dict[str, float] = {}
    for idx, key in enumerate(PHQ9_KEYS_ORDERED):
        conf_map[key] = float(confidences[idx] if idx < len(confidences) else 0.0)

    evidence_strength: Dict[str, str] = {}
    for key in PHQ9_KEYS_ORDERED:
        hits = len(quotes.get(key, []))
        conf = conf_map.get(key, 0.0)
        if hits >= 2 and conf >= max(0.6, threshold - 0.1):
            evidence_strength[key] = "strong"
        elif hits >= 1 or conf >= 0.4:
            evidence_strength[key] = "weak"
        else:
            evidence_strength[key] = "missing"

    low_items = sorted(
        PHQ9_KEYS_ORDERED,
        key=lambda k: (evidence_strength[k] != "missing", conf_map.get(k, 0.0)),
    )
    recommended = low_items[0] if low_items else None

    return {
        "evidence_strength": evidence_strength,
        "low_confidence_items": [k for k in sorted(PHQ9_KEYS_ORDERED, key=lambda x: conf_map.get(x, 0.0))],
        "recommended_focus": recommended,
        "quotes": quotes,
        "confidences": conf_map,
    }


def explainability_full(
    chat_history: List[Tuple[str, str]],
    confidences: List[float],
    features_history: Optional[List[Dict[str, Any]]],
) -> Dict[str, Any]:
    """Aggregate linguistic and acoustic attributions at session end.

    - Linguistic: keyword-based quotes per item (placeholder for SHAP/attention).
    - Acoustic: mean of per-turn prosodic features; returned as name=value strings.
    """

    def _aggregate_prosody(history: List[Dict[str, Any]]) -> Dict[str, float]:
        agg: Dict[str, float] = {}
        if not history:
            return agg
        # history entries may be {"features": {...}, "turn": int, "patient_text": str}
        feature_dicts = [d.get("features", {}) if isinstance(d, dict) else {} for d in history]
        keys = set().union(*[fd.keys() for fd in feature_dicts if isinstance(fd, dict)])
        for k in keys:
            vals = [float(fd[k]) for fd in feature_dicts if isinstance(fd, dict) and k in fd]
            if vals:
                agg[k] = float(np.mean(vals))
        return agg

    # Build turn->features map for per-turn descriptors
    turn_to_feats: Dict[int, Dict[str, float]] = {}
    for entry in (features_history or []):
        if isinstance(entry, dict) and isinstance(entry.get("turn"), int) and isinstance(entry.get("features"), dict):
            turn_to_feats[int(entry["turn"])] = dict(entry["features"])  # shallow copy

    quotes_with_turns = _extract_quotes_with_turns(chat_history)
    conf_map = {k: float(confidences[i] if i < len(confidences) else 0.0) for i, k in enumerate(PHQ9_KEYS_ORDERED)}

    prosody_agg = _aggregate_prosody(list(features_history or []))
    prosody_pairs = sorted(list(prosody_agg.items()), key=lambda kv: -abs(kv[1]))
    prosody_names = [f"{k}={v:.3f}" for k, v in prosody_pairs[:8]]

    # Prepare session statistics for relative descriptors
    def _compute_stats(values: List[float]) -> Tuple[float, float]:
        if not values:
            return 0.0, 1.0
        mu = float(np.mean(values))
        sd = float(np.std(values))
        return mu, (sd if sd > 1e-8 else 1.0)

    feat_keys = ["rms_mean", "f0_iqr", "centroid_mean", "voiced_ratio"]
    stats: Dict[str, Tuple[float, float]] = {}
    for k in feat_keys:
        vals = [
            float(fd.get("features", {}).get(k))
            for fd in (features_history or [])
            if isinstance(fd, dict) and isinstance(fd.get("features"), dict) and k in fd.get("features", {})
        ]
        stats[k] = _compute_stats(vals)

    def _z(v: Optional[float], mu: float, sd: float) -> float:
        if v is None:
            return 0.0
        return (float(v) - mu) / sd

    def prosody_descriptors(feats: Optional[Dict[str, float]]) -> List[str]:
        if not feats:
            return []
        desc: List[str] = []
        z_rms = _z(feats.get("rms_mean"), *stats["rms_mean"]) if "rms_mean" in stats else 0.0
        if z_rms <= -0.5:
            desc.append("low volume")
        elif z_rms >= 0.5:
            desc.append("raised volume")
        else:
            desc.append("moderate volume")
        z_f0iqr = _z(feats.get("f0_iqr"), *stats["f0_iqr"]) if "f0_iqr" in stats else 0.0
        if z_f0iqr <= -0.5:
            desc.append("flat intonation")
        elif z_f0iqr >= 0.5:
            desc.append("expressive pitch")
        else:
            desc.append("normal intonation")
        z_cent = _z(feats.get("centroid_mean"), *stats["centroid_mean"]) if "centroid_mean" in stats else 0.0
        if z_cent <= -0.5:
            desc.append("darker tone")
        elif z_cent >= 0.5:
            desc.append("brighter tone")
        z_vr = _z(feats.get("voiced_ratio"), *stats["voiced_ratio"]) if "voiced_ratio" in stats else 0.0
        if z_vr <= -0.5:
            desc.append("more pauses")
        elif z_vr >= 0.5:
            desc.append("continuous speech")
        else:
            desc.append("typical pacing")
        return desc

    items = []
    for k in PHQ9_KEYS_ORDERED:
        ev_turns = quotes_with_turns.get(k, [])[:5]
        evidence_texts = [et.get("quote") for et in ev_turns if isinstance(et, dict) and et.get("quote")]
        evidence_details = []
        for et in ev_turns:
            turn_id = int(et.get("turn", 0)) if isinstance(et, dict) else 0
            feats = turn_to_feats.get(turn_id)
            descriptors = prosody_descriptors(feats)
            evidence_details.append({
                "quote": et.get("quote"),
                "turn": turn_id,
                "prosody_descriptors": descriptors,
            })
        items.append({
            "item": k,
            "confidence": conf_map.get(k, 0.0),
            "evidence": evidence_texts,
            "evidence_details": evidence_details,
            "prosody_session": prosody_names,
        })

    return {
        "items": items,
        "notes": "Heuristic keyword and prosody aggregation; plug in SHAP/attention later.",
    }


def reflection_module(
    scores: Dict[str, int],
    confidences: List[float],
    exp_light: Optional[Dict[str, Any]],
    exp_full: Optional[Dict[str, Any]],
    threshold: float,
) -> Dict[str, Any]:
    """Self-reflection / output reevaluation.

    Heuristic: if confidence for an item < threshold and evidence is missing,
    reduce the score by 1 (min 0). Returns a `reflection_report` JSON with
    corrected scores and a final summary.
    """
    corrected = dict(scores or {})
    strength = (exp_light or {}).get("evidence_strength", {}) if isinstance(exp_light, dict) else {}
    changes: List[Tuple[str, int, int]] = []
    for i, k in enumerate(PHQ9_KEYS_ORDERED):
        conf = float(confidences[i] if i < len(confidences) else 0.0)
        if conf < float(threshold) and strength.get(k) == "missing":
            new_val = max(0, int(corrected.get(k, 0)) - 1)
            if new_val != corrected.get(k, 0):
                changes.append((k, int(corrected.get(k, 0)), new_val))
                corrected[k] = new_val

    final_total = int(sum(corrected.values()))
    final_sev = severity_from_total(final_total)
    consistency = float(1.0 - (len(changes) / max(1, len(PHQ9_KEYS_ORDERED))))
    if changes:
        notes = ", ".join([f"{k}: {old}->{new}" for k, old, new in changes])
        notes = f"Model revised items due to low confidence and missing evidence: {notes}."
    else:
        notes = "No score revisions; explanations consistent with outputs."

    return {
        "corrected_scores": corrected,
        "final_total": final_total,
        "severity_label": final_sev,
        "consistency_score": consistency,
        "notes": notes,
    }


def build_patient_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
    severity = meta.get("Severity") or display_json.get("Severity")
    total = meta.get("Total_Score") or display_json.get("Total_Score")
    transcript_text = transcript_to_text(chat_history)

    # Optional enriched content
    exp_full = display_json.get("Explainability_Full") or {}
    reflection = display_json.get("Reflection_Report") or {}

    lines = []
    lines.append("# Summary for Patient\n")
    # if total is not None and severity:
    #     lines.append(f"- PHQ‑9 Total: **{total}** ")
    #     lines.append(f"- Severity: **{severity}**\n")
    # # Highlights: show one quote per item if available
    # if exp_full and isinstance(exp_full, dict):
    #     items = exp_full.get("items", [])
    #     if isinstance(items, list) and items:
    #         lines.append("### Highlights from our conversation\n")
    #         for it in items:
    #             item = it.get("item")
    #             ev = it.get("evidence", [])
    #             if item and ev:
    #                 lines.append(f"- {item}: \"{ev[0]}\"")
    #         lines.append("")
    # if reflection:
    #     note = reflection.get("notes")
    #     if note:
    #         lines.append("### Reflection\n")
    #         lines.append(note)
    #         lines.append("")
    lines.append("### Conversation Transcript\n\n")
    lines.append(f"```\n{transcript_text}\n```")
    return "\n".join(lines)


def build_clinician_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
    scores = display_json.get("PHQ9_Scores", {})
    confidences = display_json.get("Confidences", [])
    severity = meta.get("Severity") or display_json.get("Severity")
    total = meta.get("Total_Score") or display_json.get("Total_Score")
    risk = display_json.get("High_Risk")
    transcript_text = transcript_to_text(chat_history)

    scores_lines = "\n".join([f"- {k}: {v}" for k, v in scores.items()])
    conf_str = ", ".join([f"{c:.2f}" for c in confidences]) if confidences else ""

    # Optional explainability
    exp_light = display_json.get("Explainability_Light") or {}
    exp_full = display_json.get("Explainability_Full") or {}
    reflection = display_json.get("Reflection_Report") or {}

    md = []
    md.append("# Summary for Clinician\n")
    md.append(f"- Severity: **{severity}** ")
    md.append(f"- PHQ‑9 Total: **{total}** ")
    if risk is not None:
        md.append(f"- High Risk: **{risk}** ")
    md.append("")
    md.append("### Item Scores\n" + scores_lines + "\n")
    # # Confidence bars
    # if confidences:
    #     bars = []
    #     for i, k in enumerate(scores.keys()):
    #         c = confidences[i] if i < len(confidences) else 0.0
    #         bar_len = int(round(c * 20))
    #         bars.append(f"- {k}: [{'#'*bar_len}{'.'*(20-bar_len)}] {c:.2f}")
    #     md.append("### Confidence by item\n" + "\n".join(bars) + "\n")
    # # Light explainability snapshot
    # if exp_light:
    #     strength = exp_light.get("evidence_strength", {})
    #     recommended = exp_light.get("recommended_focus")
    #     if strength:
    #         md.append("### Evidence strength (light)\n")
    #         md.extend([f"- {k}: {v}" for k, v in strength.items()])
    #         md.append("")
    #     if recommended:
    #         md.append(f"- Next focus (if continuing): **{recommended}**\n")

    # Full explainability excerpts
    if exp_full and isinstance(exp_full, dict):
        md.append("### Explainability\n")
        items = exp_full.get("items", [])
        for it in items:
            item = it.get("item")
            conf = it.get("confidence")
            ev = it.get("evidence", [])
            pros = it.get("prosody_session", [])
            ev_details = it.get("evidence_details", [])
            if item:
                md.append(f"- {item}:")
                # Show per-quote with turn and descriptors if available
                for ed in ev_details[:2]:
                    q = ed.get("quote")
                    t = ed.get("turn")
                    pd = ed.get("prosody_descriptors", [])
                    if q:
                        md.append(f"  - Turn {t}: \"{q}\" ({', '.join(pd)})")
                # if pros:
                #     md.append(f"  - session prosody: {', '.join([str(p) for p in pros[:4]])}")
        md.append("")

    # Reflection summary
    if reflection:
        md.append("### Self-reflection\n")
        notes = reflection.get("notes")
        if notes:
            md.append(notes)
        corr = reflection.get("corrected_scores") or {}
        if corr and corr != scores:
            changed = [k for k in scores.keys() if corr.get(k) != scores.get(k)]
            if changed:
                md.append("- Adjusted items: " + ", ".join(changed))
        md.append("")

    md.append("### Conversation Transcript\n\n")
    md.append(f"```\n{transcript_text}\n```")
    return "\n".join(md)


def generate_recording_agent_reply(chat_history: List[Tuple[str, str]], guidance: Optional[Dict[str, Any]] = None) -> str:
    transcript = transcript_to_text(chat_history)
    system_prompt = (
        "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
        "without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
        "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
        "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
    )
    focus_text = ""
    if guidance and isinstance(guidance, dict):
        rec = guidance.get("recommended_focus")
        if rec:
            focus_text = (
                f"\n\nGuidance: Focus the next question on the patient's {str(rec).replace('_', ' ')}. "
                "Ask naturally about recent changes and their impact on daily life."
            )
    user_prompt = (
        "Conversation so far (Patient and Clinician turns):\n\n"
        + transcript
        + f"{focus_text}\n\nRespond with a single short clinician-style question for the patient."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    combined_prompt = system_prompt + "\n\n" + user_prompt
    messages = [
        {"role": "user", "content": combined_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Avoid TorchInductor graph capture issues on some environments
    try:
        import torch._dynamo as _dynamo  # type: ignore
    except Exception:
        _dynamo = None
    gen = pipe(
        prompt,
        max_new_tokens=96,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    reply = gen[0]["generated_text"].strip()
    # Ensure it's a single concise question/sentence
    if len(reply) > 300:
        reply = reply[:300].rstrip() + "…"
    return reply


def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
    """Ask the LLM to produce PHQ-9 scores and confidences as JSON. Fallback if parsing fails."""
    transcript = transcript_to_text(chat_history)
    features_json = json.dumps(features, ensure_ascii=False)
    system_prompt = (
        "You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
        "Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
        "Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
        "and High_Risk (boolean, true if any suicidal risk)."
    )
    user_prompt = (
        "Conversation transcript:"
        f"\n{transcript}\n\n"
        f"Acoustic features summary (approximate):\n{features_json}\n\n"
        "Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
        "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    combined_prompt = system_prompt + "\n\n" + user_prompt
    messages = [
        {"role": "user", "content": combined_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Use deterministic decoding to avoid CUDA sampling edge cases on some models
    try:
        import torch._dynamo as _dynamo  # type: ignore
    except Exception:
        _dynamo = None
    gen = pipe(
        prompt,
        max_new_tokens=256,
        temperature=0.0,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    out_text = gen[0]["generated_text"]
    parsed = safe_json_extract(out_text)

    # Validate and coerce
    if parsed is None or "PHQ9_Scores" not in parsed:
        # Simple fallback heuristic: neutral scores with low confidence
        scores = {
            "interest": 1,
            "mood": 1,
            "sleep": 1,
            "energy": 1,
            "appetite": 1,
            "self_worth": 1,
            "concentration": 1,
            "motor": 1,
            "suicidal_thoughts": 0,
        }
        confidences = [0.5] * 9
        total = int(sum(scores.values()))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity_from_total(total),
            "Confidence": float(min(confidences)),
            "High_Risk": False,
        }

    try:
        # Coerce types and compute derived values if missing
        scores = parsed.get("PHQ9_Scores", {})
        # Ensure all keys present
        keys = [
            "interest", "mood", "sleep", "energy", "appetite",
            "self_worth", "concentration", "motor", "suicidal_thoughts",
        ]
        for k in keys:
            scores[k] = int(max(0, min(3, int(scores.get(k, 0)))))
        confidences = parsed.get("Confidences", [])
        if not isinstance(confidences, list) or len(confidences) != 9:
            confidences = [float(parsed.get("Confidence", 0.5))] * 9
        confidences = [float(max(0.0, min(1.0, c))) for c in confidences]
        total = int(sum(scores.values()))
        severity = parsed.get("Severity") or severity_from_total(total)
        overall_conf = float(parsed.get("Confidence", min(confidences)))
        # Conservative high-risk detection: require explicit language or a high suicidal_thoughts score
        # Extract last patient message
        last_patient = ""
        for user_text, assistant_text in reversed(chat_history):
            if user_text:
                last_patient = user_text
                break
        explicit_flag = detect_explicit_suicidality(last_patient) or detect_explicit_suicidality(transcript)
        high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity,
            "Confidence": overall_conf,
            "High_Risk": high_risk,
        }
    except Exception:
        # Final fallback
        scores = parsed.get("PHQ9_Scores", {}) if isinstance(parsed, dict) else {}
        if not scores:
            scores = {k: 1 for k in [
                "interest", "mood", "sleep", "energy", "appetite",
                "self_worth", "concentration", "motor", "suicidal_thoughts",
            ]}
        confidences = [0.5] * 9
        total = int(sum(scores.values()))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity_from_total(total),
            "Confidence": float(min(confidences)),
            "High_Risk": False,
        }


def transcribe_audio(audio_path: Optional[str]) -> str:
    if not audio_path:
        return ""
    try:
        asr = get_asr_pipeline()
        result = asr(audio_path)
        if isinstance(result, dict) and "text" in result:
            return result["text"].strip()
        if isinstance(result, list) and len(result) > 0 and "text" in result[0]:
            return result[0]["text"].strip()
    except Exception:
        pass
    return ""


# ---------------------------
# Gradio app logic
# ---------------------------
INTRO_MESSAGE = (
    "Hi, I'm an assistant, and I will ask you some questions about how you've been doing. "
    "We'll record our conversation, and we will give you a written copy of it. "
    "From our conversation, we will send a written copy to the clinician, we will give a summary of what you are "
    "experiencing based on a questionnaire called the Patient Health Questionnaire (PHQ-9), and we will give a "
    "summary of what your voice is like. "
    "We will send this to the clinician, and the clinician will follow up with you. "
    "To start, how has your mood been over the past couple of weeks?"
)


def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
    chat_history: List[Tuple[str, str]] = [("", INTRO_MESSAGE)]
    scores = {}
    meta = {"Severity": None, "Total_Score": None, "Confidence": 0.0}
    finished = False
    turns = 0
    return chat_history, scores, meta, finished, turns


@spaces.GPU
def process_turn(
    audio_path: Optional[str],
    text_input: Optional[str],
    chat_history: List[Tuple[str, str]],
    threshold: float,
    tts_enabled: bool,
    finished: Optional[bool],
    turns: Optional[int],
    prev_scores: Dict[str, Any],
    prev_meta: Dict[str, Any],
):
    # If already finished, do nothing
    finished = bool(finished) if finished is not None else False
    turns = int(turns) if isinstance(turns, int) else 0
    if finished:
        return (
            chat_history,
            {"info": "Assessment complete."},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )

    patient_text = (text_input or "").strip()
    audio_features: Dict[str, float] = {}
    if audio_path:
        # Transcribe first
        transcribed = transcribe_audio(audio_path)
        if transcribed:
            # Join typed text and transcription with a separating space
            patient_text = f"{patient_text} {transcribed}".strip() if patient_text else transcribed
        # Extract features
        audio_features = compute_audio_features(audio_path)

    if not patient_text:
        # Ask user for input
        chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
        return (
            chat_history,
            prev_scores or {},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )

    # Add patient's message
    chat_history.append((patient_text, None))

    # Scoring agent
    scoring = scoring_agent_infer(chat_history, audio_features)
    scores = scoring.get("PHQ9_Scores", {})
    confidences = scoring.get("Confidences", [])
    total = scoring.get("Total_Score", 0)
    severity = scoring.get("Severity", severity_from_total(total))
    overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))
    # Override high-risk to reduce false positives: rely on explicit text or a high item score
    # Extract last patient message
    last_patient = ""
    for user_text, assistant_text in reversed(chat_history):
        if user_text:
            last_patient = user_text
            break
    explicit_flag = detect_explicit_suicidality(last_patient)
    high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
    meta = {"Severity": severity, "Total_Score": total, "Confidence": overall_conf}

    # Termination conditions
    min_conf = float(min(confidences)) if confidences else 0.0
    turns += 1
    done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)

    if high_risk:
        closing = (
            "I’m concerned about your safety based on what you shared. "
            "If you are in danger or need immediate help, please call 116 123 in the UK, 988 in the U.S., or your local emergency number. "
            "I'll end the assessment now and display emergency resources."
        )
        chat_history[-1] = (chat_history[-1][0], closing)
        finished = True
    elif done:
        summary = (
            "Thank you for your time. The clinician will review your conversation and follow up with you. "
            "Here is a copy of our conversation so you can review it later."
        )
        chat_history[-1] = (chat_history[-1][0], summary)
        finished = True
    else:
        # Iterative explainability (light) to guide next question
        light_exp = explainability_light(chat_history, scores, confidences, float(threshold))
        # Generate next clinician question with guidance
        reply = generate_recording_agent_reply(chat_history, guidance=light_exp)
        chat_history[-1] = (chat_history[-1][0], reply)

    # TTS for the latest clinician message, if enabled
    tts_path = synthesize_tts(chat_history[-1][1]) if tts_enabled else None

    # Build a compact JSON for display
    display_json = {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity,
        "Confidence": overall_conf,
        "High_Risk": high_risk,
        # Include the last audio features and light explainability for downstream modules/UI
        "Last_Audio_Features": audio_features,
        "Explainability_Light": explainability_light(chat_history, scores, confidences, float(threshold)),
    }

    # Clear inputs after processing
    return (
        chat_history,
        display_json,
        severity,
        finished,
        turns,
        None,
        None,
        tts_path,
        tts_path,
    )


def reset_app():
    return init_state()


# ---------------------------
# UI
# ---------------------------
def _on_load_init():
    return init_state()


def _on_load_init_with_tts(tts_on: bool):
    chat_history, scores_state, meta_state, finished_state, turns_state = init_state()
    # Play the intro message via TTS if enabled
    tts_path = synthesize_tts(chat_history[-1][1]) if bool(tts_on) else None
    return chat_history, scores_state, meta_state, finished_state, turns_state, tts_path


def _play_intro_tts(tts_on: bool):
    if not bool(tts_on):
        return None
    try:
        return synthesize_tts(INTRO_MESSAGE)
    except Exception:
        return None


def create_demo():
    with gr.Blocks(
        theme=gr.themes.Soft(),
        css='''
        /* Voice bubble styles - clean and centered */
        #voice-bubble {
            width: 240px;
            height: 240px;
            border-radius: 9999px;
            margin: 40px auto;
            display: flex;
            align-items: center;
            justify-content: center;
            background: linear-gradient(135deg, #6ee7b7 0%, #34d399 100%);
            box-shadow: 0 20px 40px rgba(16,185,129,0.3), 0 0 0 1px rgba(255,255,255,0.1) inset;
            transition: all 250ms cubic-bezier(0.4, 0, 0.2, 1);
            cursor: default; /* green circle itself is not clickable */
            pointer-events: none; /* ignore clicks on the green circle */
            position: relative;
        }
        #voice-bubble:hover {
            transform: translateY(-2px) scale(1.02);
            box-shadow: 0 25px 50px rgba(16,185,129,0.4), 0 0 0 1px rgba(255,255,255,0.15) inset;
        }
        #voice-bubble:active {
            transform: translateY(0px) scale(0.98);
        }
        #voice-bubble.listening {
            animation: bubble-pulse 1.5s ease-in-out infinite;
            background: linear-gradient(135deg, #fb7185 0%, #ef4444 100%);
            box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 1px rgba(255,255,255,0.1) inset;
        }
        @keyframes bubble-pulse {
            0%, 100% {
                transform: scale(1.0);
                box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 0 rgba(239,68,68,0.5);
            }
            50% {
                transform: scale(1.05);
                box-shadow: 0 25px 50px rgba(239,68,68,0.5), 0 0 0 15px rgba(239,68,68,0.0);
            }
        }
        /* Hide microphone dropdown selector only */
        #voice-bubble select { display: none !important; }
        #voice-bubble .source-selection { display: none !important; }
        #voice-bubble label[for] { display: none !important; }
        /* Make the inner button the only clickable target */
        #voice-bubble button { pointer-events: auto; cursor: pointer; }
        /* Hide TTS player UI but keep it in DOM for autoplay */
        #tts-player {
            width: 0 !important;
            height: 0 !important;
            opacity: 0 !important;
            position: absolute;
            pointer-events: none;
        }
        /* Settings and back buttons in top right - compact */
        #settings-btn {
            position: absolute;
            top: 16px;
            right: 16px;
            z-index: 10;
            width: auto !important;
            min-width: 100px !important;
            max-width: 120px !important;
            padding: 8px 16px !important;
        }
        #back-btn {
            position: absolute;
            top: 8px;
            right: 8px;
            z-index: 10;
            width: auto !important;
            min-width: 80px !important;
            max-width: 100px !important;
            padding: 8px 16px !important;
        }
        /* Play Intro button positioned under Settings */
        #intro-btn {
            position: absolute;
            top: 60px;
            right: 16px;
            z-index: 10;
            width: auto !important;
            min-width: 100px !important;
            max-width: 120px !important;
            padding: 8px 16px !important;
        }
        '''
    ) as demo:
        # Main view
        with gr.Column(visible=True) as main_view:
            # Settings button (top right, only in main view)
            settings_btn = gr.Button("⚙️ Settings", elem_id="settings-btn", size="sm", visible=DEBUG_MODE)
            gr.Markdown(
                """
                ### Conversational Assessment for Responsive Engagement (CARE) Notes
                Tap on 'Record' to start speaking, then tap on 'Stop' to stop recording.
                """
            )
            intro_play_btn = gr.Button("▶️ Start", elem_id="intro-btn", variant="secondary", size="sm")
            # Microphone component styled as central bubble (tap to record/stop)
            audio_main = gr.Microphone(type="filepath", label=None, elem_id="voice-bubble", show_label=False)
            # Hidden text input placeholder for pipeline compatibility
            text_main = gr.Textbox(value="", visible=False)
            # Autoplay clinician voice output (player hidden with CSS)
            tts_audio_main = gr.Audio(label=None, interactive=False, autoplay=True, show_label=False, elem_id="tts-player")
            # Final summaries (shown after assessment ends)
            main_summary = gr.Markdown(visible=False)

        # Settings view (initially hidden)
        with gr.Column(visible=False) as settings_view:
            back_btn = gr.Button("← Back", elem_id="back-btn", size="sm")
            gr.Markdown("## Settings")
            chatbot = gr.Chatbot(height=360, type="tuples", label="Conversation")
            with gr.Row():
                text_adv = gr.Textbox(placeholder="Type your message and press Enter", scale=4)
                send_adv_btn = gr.Button("Send", scale=1)
            score_json = gr.JSON(label="PHQ-9 Assessment (live)")
            severity_label = gr.Label(label="Severity")
            threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
            tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
            with gr.Row():
                tts_provider_dd = gr.Dropdown(choices=["Coqui", "gTTS"], value="Coqui", label="TTS Provider")
                coqui_model_tb = gr.Textbox(value=os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits"), label="Coqui Model")
                coqui_speaker_dd = gr.Dropdown(choices=list_coqui_speakers(os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")), value="p225", label="Coqui Speaker")
            tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
            model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
            with gr.Row():
                apply_model_btn = gr.Button("Apply model (no restart)")
                # apply_model_restart_btn = gr.Button("Apply model and restart")
            model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")

        # App state
        chat_state = gr.State()
        scores_state = gr.State()
        meta_state = gr.State()
        finished_state = gr.State()
        turns_state = gr.State()
        feats_state = gr.State()

        # Initialize on load (no autoplay due to browser policies)
        demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])

        # View navigation
        settings_btn.click(
            fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
            inputs=None,
            outputs=[main_view, settings_view],
        )
        back_btn.click(
            fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
            inputs=None,
            outputs=[main_view, settings_view],
        )

        # Explicit user gesture to play intro TTS (works across browsers)
        intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])

        # Wire interactions
        def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
            result = process_turn(audio, text, chat, th, tts_on, finished, turns, scores, meta)
            chat_history, display_json, severity, finished_o, turns_o, _, _, _, last_tts = result

            # Accumulate last audio features
            feats_hist = feats_hist or []
            last_feats = (display_json or {}).get("Last_Audio_Features") or {}
            if isinstance(last_feats, dict) and last_feats:
                # Store with turn index and patient text for descriptive mapping
                turn_record = {
                    "turn": int(turns_o),
                    "features": last_feats,
                    "patient_text": next((u for (u, a) in chat_history[::-1] if u), ""),
                }
                feats_hist = list(feats_hist) + [turn_record]

            if tts_on and chat_history and chat_history[-1][1]:
                new_path = synthesize_tts(chat_history[-1][1], provider=provider, coqui_model_name=coqui_model, coqui_speaker=coqui_speaker)
            else:
                new_path = None

            # If finished, hide the mic and display summaries in Main
            if finished_o:
                # Run full explainability and reflection
                exp_full = explainability_full(chat_history, display_json.get("Confidences", []), feats_hist)
                reflect = reflection_module(
                    display_json.get("PHQ9_Scores", {}),
                    display_json.get("Confidences", []),
                    display_json.get("Explainability_Light", {}),
                    exp_full,
                    float(th),
                )
                display_json["Explainability_Full"] = exp_full
                display_json["Reflection_Report"] = reflect
                # Use reflection outputs to set final meta
                final_sev = reflect.get("severity_label") or severity
                final_total = reflect.get("final_total") or display_json.get("Total_Score")
                patient_md = build_patient_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
                clinician_md = build_clinician_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
                summary_md = patient_md + "\n\n---\n\n" + clinician_md
                return chat_history, display_json, severity, finished_o, turns_o, gr.update(visible=False), None, new_path, new_path, gr.update(value=summary_md, visible=True), feats_hist

            return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path, gr.update(visible=False), feats_hist

        audio_main.stop_recording(
            fn=_process_with_tts,
            inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state],
            queue=True,
            api_name="message",
        )

        # Text input flow from Advanced tab
        def _process_text_and_clear(text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
            res = _process_with_tts(None, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist)
            return (*res, "")

        text_adv.submit(
            fn=_process_text_and_clear,
            inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
            queue=True,
        )
        send_adv_btn.click(
            fn=_process_text_and_clear,
            inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
            queue=True,
        )

        # Tap bubble to toggle microphone record/stop via JS
        # This JS is no longer needed as the bubble is the mic
        # voice_bubble.click(
        #     None,
        #     inputs=None,
        #     outputs=None,
        #     js="() => {\n const bubble = document.getElementById('voice-bubble');\n const root = document.getElementById('hidden-mic');\n if (!root) return;\n let didClick = false;\n const wc = root.querySelector && root.querySelector('gradio-audio');\n if (wc && wc.shadowRoot) {\n const btns = Array.from(wc.shadowRoot.querySelectorAll('button')).filter(b => !b.disabled);\n const txt = (b) => ((b.getAttribute('aria-label')||'') + ' ' + (b.textContent||'')).toLowerCase();\n const stopBtn = btns.find(b => txt(b).includes('stop'));\n const recBtn = btns.find(b => { const t = txt(b); return t.includes('record') || t.includes('start') || t.includes('microphone') || t.includes('mic'); });\n if (stopBtn) { stopBtn.click(); didClick = true; } else if (recBtn) { recBtn.click(); didClick = true; } else if (btns[0]) { btns[0].click(); didClick = true; }\n }\n if (!didClick) {\n const candidates = Array.from(root.querySelectorAll('button, [role=\\'button\\']')).filter(el => !el.disabled);\n if (candidates.length) { candidates[0].click(); didClick = true; }\n }\n if (bubble && didClick) bubble.classList.toggle('listening');\n }",
        # )

        # No reset button in Main tab anymore

        # Model switch handlers
        def _on_apply_model(mid: str):
            msg = set_current_model_id(mid)
            return f"Current model: `{current_model_id}`\n\n{msg}"

        def _on_apply_model_restart(mid: str):
            msg = apply_model_and_restart(mid)
            return f"{msg}"

        apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
        # apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])

    return demo


demo = create_demo()

if __name__ == "__main__":
    # For local dev
    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
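
# Example local run with environment overrides (assumed entry-point name `app.py`;
# adjust to this repo's actual filename):
#   LLM_MODEL_ID=google/gemma-2-2b-it ASR_MODEL_ID=openai/whisper-tiny.en DEBUG=true PORT=7860 python app.py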