Akis Giannoukos
Added explainability
09716a4
raw
history blame
48.4 kB
import os
# Disable torch compile/dynamo globally to avoid cudagraph assertion errors
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"
import json
import re
import time
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
# Audio processing
import soundfile as sf
import librosa
# Models
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
)
from gtts import gTTS
import spaces
import threading
# ---------------------------
# Configuration
# ---------------------------
# Every setting below can be overridden through the environment variable named
# in its lookup; the second argument is the built-in default.
DEFAULT_CHAT_MODEL_ID = os.environ.get("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.environ.get("ASR_MODEL_ID", "openai/whisper-tiny.en")
# Default value for the UI's confidence-threshold slider.
CONFIDENCE_THRESHOLD_DEFAULT = float(os.environ.get("CONFIDENCE_THRESHOLD", "0.8"))
# Upper bound on patient turns before the assessment is ended.
MAX_TURNS = int(os.environ.get("MAX_TURNS", "12"))
# TTS is enabled only when USE_TTS is "true" (case-insensitive, surrounding spaces ignored).
USE_TTS_DEFAULT = os.environ.get("USE_TTS", "true").strip().lower() == "true"
# JSON file used to persist the selected chat model id.
CONFIG_PATH = os.environ.get("MODEL_CONFIG_PATH", "model_config.json")
def _load_model_id_from_config() -> str:
try:
if os.path.exists(CONFIG_PATH):
with open(CONFIG_PATH, "r") as f:
data = json.load(f)
if isinstance(data, dict) and data.get("model_id"):
return str(data["model_id"])
except Exception:
pass
return DEFAULT_CHAT_MODEL_ID
# Chat model id in effect for this process, read once at import time from the
# persisted config (or the default). set_current_model_id() may change it later.
current_model_id = _load_model_id_from_config()
# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
# All start as None. get_asr_pipeline() and get_textgen_pipeline() create the
# two pipelines on first use. _tokenizer is not assigned by any code visible in
# this section.
_asr_pipe = _gen_pipe = _tokenizer = None
def _hf_device() -> int:
return 0 if torch.cuda.is_available() else -1
def get_asr_pipeline():
    """Return the shared speech-recognition pipeline, building it on first call.

    Uses ``DEFAULT_ASR_MODEL_ID`` on the device chosen by ``_hf_device()``; the
    instance is cached in the module-level ``_asr_pipe``.
    """
    global _asr_pipe
    if _asr_pipe is not None:
        return _asr_pipe
    _asr_pipe = pipeline(
        "automatic-speech-recognition",
        model=DEFAULT_ASR_MODEL_ID,
        device=_hf_device(),
    )
    return _asr_pipe
def _select_torch_dtype() -> torch.dtype:
    """Pick the generation dtype: bfloat16 on CUDA when supported, float16 on other CUDA, float32 on CPU."""
    if not torch.cuda.is_available():
        return torch.float32
    bf16_ok = hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
    return torch.bfloat16 if bf16_ok else torch.float16


def get_textgen_pipeline():
    """Return the shared text-generation pipeline for ``current_model_id``, building it on first call.

    The default chat model is small enough for Spaces CPU; override it with the
    LLM_MODEL_ID environment variable. ``set_current_model_id`` resets the cached
    instance (``_gen_pipe``) so that the next call rebuilds it.
    """
    global _gen_pipe
    if _gen_pipe is None:
        model_id = current_model_id
        _gen_pipe = pipeline(
            task="text-generation",
            model=model_id,
            tokenizer=model_id,
            device=_hf_device(),
            torch_dtype=_select_torch_dtype(),
        )
    return _gen_pipe
def set_current_model_id(new_model_id: str) -> str:
    """Switch the chat model used by later generations and return a status message.

    A blank or None id leaves everything unchanged. When the id differs from the
    current one, the cached generation pipeline is dropped so that
    ``get_textgen_pipeline`` reloads it with the new model.
    """
    global current_model_id, _gen_pipe
    candidate = (new_model_id or "").strip()
    if not candidate:
        return "Model id is empty; keeping current model."
    if candidate != current_model_id:
        current_model_id = candidate
        _gen_pipe = None  # dropped so the next get_textgen_pipeline() call rebuilds it
        return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."
    return f"Model unchanged: `{current_model_id}`"
def persist_model_id(new_model_id: str, path: Optional[str] = None) -> None:
    """Save *new_model_id* to the JSON config file, on a best-effort basis.

    Writes ``{"model_id": new_model_id}`` to *path*, which defaults to
    ``CONFIG_PATH``. The data goes to a temporary sibling file first and is
    then moved into place with ``os.replace``, so the config is never left
    half-written. Failures (e.g. a read-only filesystem) are swallowed because
    persistence is optional; any temporary file is cleaned up.
    """
    config_path = path if path is not None else CONFIG_PATH
    tmp_path = f"{config_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump({"model_id": new_model_id}, f)
        # Same directory, so the rename stays on one filesystem and is atomic.
        os.replace(tmp_path, config_path)
    except (OSError, TypeError, ValueError):
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def apply_model_and_restart(new_model_id: str) -> str:
    """Persist and select *new_model_id*, then terminate the process shortly afterwards.

    Returns a status string right away. About 0.25 s later a daemon timer calls
    ``os._exit(0)``, which leaves time for the response to be delivered; a
    process restart is presumably handled by the hosting platform (not visible
    here). A blank or None id does nothing.
    """
    mid = (new_model_id or "").strip()
    if not mid:
        return "Model id is empty; not restarting."
    persist_model_id(mid)
    set_current_model_id(mid)
    exit_timer = threading.Timer(0.25, os._exit, args=(0,))
    exit_timer.daemon = True
    exit_timer.start()
    return f"Restarting with model `{mid}`..."
# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
    """Extract the first JSON object from *text* (typically LLM output).

    If the whole text parses as a JSON object, that object is returned.
    Otherwise each ``{`` is tried in turn with ``JSONDecoder.raw_decode``, and
    the first position that decodes to a complete object wins, with any
    trailing prose ignored. This also handles code fences, leading commentary
    and several objects in a row; a greedy ``{...}`` regex cannot, because it
    spans from the first brace to the last one.

    Returns None for empty text or when no JSON object is found. Non-object
    JSON values (numbers, strings, lists) are never returned.
    """
    if not text:
        return None
    try:
        whole = json.loads(text)
    except (TypeError, ValueError):
        whole = None
    if isinstance(whole, dict):
        return whole
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None
def compute_audio_features(audio_path: str) -> Dict[str, float]:
    """Compute lightweight prosodic features as a proxy for OpenSMILE.

    The file is loaded as 16 kHz mono and summarised frame by frame
    (1024-sample frames, 512-sample hop):

    - ``rms_mean`` / ``rms_std``: frame RMS energy.
    - ``zcr_mean`` / ``zcr_std``: zero-crossing rate.
    - ``centroid_mean`` / ``centroid_std``: spectral centroid.
    - ``voiced_ratio``: fraction of frames whose RMS exceeds 1.2x the median
      frame RMS. This is a rough activity proxy, not a true voicing decision
      and not a per-second rate.
    - ``duration_sec``: clip length in seconds.
    - ``f0_median`` / ``f0_iqr``: YIN pitch statistics over a 50-400 Hz search
      range. These two are omitted when pitch estimation fails or yields nothing.

    Returns an empty dict for empty audio or on any load/analysis error, so
    callers can treat acoustic features as best-effort.
    """
    try:
        y, sr = librosa.load(audio_path, sr=16000, mono=True)
        if len(y) == 0:
            return {}
        hop_length = 512
        frame_length = 1024
        rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
        centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]

        # Pitch estimation (coarse). Best-effort: a failure only drops the f0 stats.
        f0 = None
        try:
            f0 = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
            f0 = f0[np.isfinite(f0)]
        except Exception:
            f0 = None

        # Activity proxy: share of frames louder than 1.2x the median frame energy.
        # This reuses the RMS computed above; it was previously recomputed identically.
        voiced_ratio = float(np.mean(rms > (np.median(rms) * 1.2)))

        features = {
            "rms_mean": float(np.mean(rms)),
            "rms_std": float(np.std(rms)),
            "zcr_mean": float(np.mean(zcr)),
            "zcr_std": float(np.std(zcr)),
            "centroid_mean": float(np.mean(centroid)),
            "centroid_std": float(np.std(centroid)),
            "voiced_ratio": voiced_ratio,
            "duration_sec": float(len(y) / sr),
        }
        if f0 is not None and f0.size > 0:
            features.update({
                "f0_median": float(np.median(f0)),
                "f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
            })
        return features
    except Exception:
        # Acoustic features are optional; any decode/analysis failure yields {}.
        return {}
# Explicit self-harm / suicidal phrasings, matched case-insensitively on whole
# words. Inflected forms ("killing myself", "self-harming", "ending my life")
# are included because they are just as explicit as the base forms.
_EXPLICIT_SUICIDALITY_PATTERNS: Tuple[str, ...] = (
    r"\bkill(?:ing)? myself\b",
    r"\bend(?:ing)? my (?:own )?life\b",
    r"\bend it all\b",
    r"\btake my (?:own )?life\b",
    r"\bsuicid(?:e|al)\b",
    r"\bself[-\s]?harm(?:ing)?\b",
    r"\bhurt(?:ing)? myself\b",
    r"\bno reason to live\b",
    r"\bwant(?:ed)? to die\b",
    r"\bbetter off dead\b",
    r"\bending it\b",
)
_EXPLICIT_SUICIDALITY_RE = re.compile("|".join(f"(?:{p})" for p in _EXPLICIT_SUICIDALITY_PATTERNS))


def detect_explicit_suicidality(text: Optional[str]) -> bool:
    """Return True if *text* contains explicit suicidal or self-harm language.

    This is a keyword screen, not a classifier. It does not understand
    negation ("I'm not suicidal" still matches), so callers use it as a
    conservative trigger. Empty or None text returns False.
    """
    if not text:
        return False
    return _EXPLICIT_SUICIDALITY_RE.search(text.lower()) is not None
# Coqui engines cached by model name. Loading a Coqui model is expensive, so it
# happens once per model rather than once per synthesized message.
_COQUI_ENGINES: Dict[str, Any] = {}


def _get_coqui_engine(model_name: str) -> Any:
    """Return a cached Coqui ``TTS`` engine for *model_name*, loading it on first use.

    Propagates whatever the optional import or the model load raises (e.g.
    ImportError when coqui-tts is not installed). Failed loads are not cached.
    """
    engine = _COQUI_ENGINES.get(model_name)
    if engine is None:
        # coqui-tts uses the same import path TTS.api
        from TTS.api import TTS as CoquiTTS  # type: ignore

        engine = CoquiTTS(model_name=model_name, progress_bar=False)
        _COQUI_ENGINES[model_name] = engine
    return engine


def _new_tts_path(suffix: str) -> str:
    """Create a uniquely named empty file in the system temp directory and return its path."""
    import tempfile

    fd, path = tempfile.mkstemp(prefix="tts_", suffix=suffix)
    os.close(fd)
    return path


def _discard_file(path: Optional[str]) -> None:
    """Delete *path* if one was given, ignoring errors (drops partial TTS output)."""
    if path:
        try:
            os.remove(path)
        except OSError:
            pass


def synthesize_tts(
    text: Optional[str],
    provider: str = "Coqui",
    coqui_model_name: Optional[str] = None,
    coqui_speaker: Optional[str] = None,
) -> Optional[str]:
    """Synthesize *text* to an audio file and return the file path, or None.

    With ``provider="Coqui"`` (the default, case-insensitive), the Coqui model
    *coqui_model_name* is tried first and writes a ``.wav`` file. The model
    defaults to the COQUI_MODEL env var or ``tts_models/en/vctk/vits``;
    *coqui_speaker* is passed through when given. If Coqui is unavailable or
    fails, or another provider is requested, gTTS writes an ``.mp3`` instead.

    Output files get unique temp names so concurrent requests cannot overwrite
    each other's audio. Returns None for empty text or when every backend fails.
    """
    if not text:
        return None
    provider_norm = (provider or "Coqui").strip().lower()
    if provider_norm == "coqui":
        wav_path: Optional[str] = None
        try:
            model_name = (coqui_model_name or os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")).strip()
            engine = _get_coqui_engine(model_name)
            wav_path = _new_tts_path(".wav")
            kwargs = {"speaker": coqui_speaker} if coqui_speaker else {}
            engine.tts_to_file(text=text, file_path=wav_path, **kwargs)
            return wav_path
        except Exception:
            # Coqui is optional (it may be missing, or reject this model/speaker):
            # drop any partial file and fall back to gTTS.
            _discard_file(wav_path)
    mp3_path: Optional[str] = None
    try:
        mp3_path = _new_tts_path(".mp3")
        gTTS(text=text, lang="en").save(mp3_path)
        return mp3_path
    except Exception:
        # TTS is best-effort; report "no audio" rather than failing the turn.
        _discard_file(mp3_path)
        return None
def list_coqui_speakers(model_name: str) -> List[str]:
    """Return the speaker ids offered by the Coqui model *model_name*.

    Loads the model, then reads ``engine.speakers`` (when it is a list) or
    ``engine.speaker_manager.speaker_names``. Returns a fixed set of VCTK
    speaker ids in any of these cases: Coqui is not installed, the load fails,
    or neither attribute is present.
    """
    fallback = ("p225", "p227", "p231", "p233", "p236")  # reasonable VCTK multi-speaker defaults
    try:
        from TTS.api import TTS as CoquiTTS  # type: ignore

        engine = CoquiTTS(model_name=model_name, progress_bar=False)
        if hasattr(engine, "speakers"):
            direct = engine.speakers
            if isinstance(direct, list):
                return list(map(str, direct))
        manager = getattr(engine, "speaker_manager", None)
        if hasattr(manager, "speaker_names"):
            return list(manager.speaker_names)
    except Exception:
        pass
    return list(fallback)
def severity_from_total(total_score: int) -> str:
    """Map a PHQ-9 total score to its severity label.

    Inclusive upper bounds: <=4 minimal, <=9 mild, <=14 moderate,
    <=19 moderately severe; anything higher is severe.
    """
    bands = (
        (4, "Minimal Depression"),
        (9, "Mild Depression"),
        (14, "Moderate Depression"),
        (19, "Moderately Severe Depression"),
    )
    for upper_bound, label in bands:
        if total_score <= upper_bound:
            return label
    return "Severe Depression"
# ---------------------------
# PHQ-9 schema and helpers
# ---------------------------
# Canonical PHQ-9 item order. Per-item confidence lists in this file are
# positional and are indexed against this list.
PHQ9_KEYS_ORDERED: List[str] = [
    "interest", "mood", "sleep", "energy", "appetite",
    "self_worth", "concentration", "motor", "suicidal_thoughts",
]
# Lightweight keyword lexicon per item for evidence extraction.
# Placeholder for future SHAP/attention-based attributions.
PHQ9_KEYWORDS: Dict[str, List[str]] = dict(
    interest=["interest", "pleasure", "enjoy", "motivation", "hobbies"],
    mood=["depressed", "down", "sad", "hopeless", "blue", "mood"],
    sleep=["sleep", "insomnia", "awake", "wake up", "night", "dream"],
    energy=["tired", "fatigue", "energy", "exhausted", "worn out"],
    appetite=["appetite", "eat", "eating", "hungry", "food", "weight"],
    self_worth=["worthless", "failure", "guilty", "guilt", "self-esteem", "ashamed"],
    concentration=["concentrate", "focus", "attention", "distracted", "remember"],
    motor=["restless", "slow", "slowed", "agitated", "fidget", "move"],
    suicidal_thoughts=["suicide", "kill myself", "die", "end my life", "self-harm", "hurt myself"],
)
def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
    """Render chat history ``[(patient, clinician), ...]`` as a plain-text transcript.

    Each non-empty message becomes one line prefixed with "Patient: " or
    "Clinician: ". Within a turn the patient line comes first, and empty or
    None messages are skipped.
    """
    rendered: List[str] = []
    for turn in chat_history:
        patient_msg, clinician_msg = turn
        rendered.extend(
            f"{role}: {msg}"
            for role, msg in (("Patient", patient_msg), ("Clinician", clinician_msg))
            if msg
        )
    return "\n".join(rendered)
def _patient_sentences(chat_history: List[Tuple[str, str]]) -> List[str]:
"""Extract patient-only sentences from chat history."""
sentences: List[str] = []
for user, _assistant in chat_history:
if not user:
continue
parts = re.split(r"(?<=[.!?])\s+", user.strip())
for p in parts:
p = p.strip()
if p:
sentences.append(p)
return sentences
def _extract_quotes_per_item(chat_history: List[Tuple[str, str]]) -> Dict[str, List[str]]:
"""Heuristic extraction of per-item evidence quotes from patient sentences based on keywords."""
quotes: Dict[str, List[str]] = {k: [] for k in PHQ9_KEYS_ORDERED}
sentences = _patient_sentences(chat_history)
for sent in sentences:
s_low = sent.lower()
for item, kws in PHQ9_KEYWORDS.items():
if any(kw in s_low for kw in kws):
if len(quotes[item]) < 5:
quotes[item].append(sent)
return quotes
def explainability_light(
    chat_history: List[Tuple[str, str]],
    scores: Dict[str, int],
    confidences: List[float],
    threshold: float,
) -> Dict[str, Any]:
    """Per-turn, lightweight explanation of the current PHQ-9 estimate.

    Each item is graded "strong", "weak" or "missing":
    - strong: at least 2 keyword quotes and confidence >= max(0.6, threshold - 0.1);
    - weak: at least 1 quote, or confidence >= 0.4;
    - missing: otherwise.
    ``recommended_focus`` is the first item after sorting missing items before
    the rest, with ties broken by ascending confidence. Confidences are read
    positionally in ``PHQ9_KEYS_ORDERED`` order; missing entries count as 0.0.
    *scores* is accepted for interface symmetry but not used.
    Returns a JSON-serializable dict.
    """
    quotes = _extract_quotes_per_item(chat_history)
    conf_map: Dict[str, float] = {
        key: float(confidences[pos] if pos < len(confidences) else 0.0)
        for pos, key in enumerate(PHQ9_KEYS_ORDERED)
    }

    def _grade(item: str) -> str:
        n_quotes = len(quotes.get(item, []))
        item_conf = conf_map.get(item, 0.0)
        if n_quotes >= 2 and item_conf >= max(0.6, threshold - 0.1):
            return "strong"
        if n_quotes >= 1 or item_conf >= 0.4:
            return "weak"
        return "missing"

    evidence_strength: Dict[str, str] = {item: _grade(item) for item in PHQ9_KEYS_ORDERED}
    focus_order = sorted(
        PHQ9_KEYS_ORDERED,
        key=lambda item: (evidence_strength[item] != "missing", conf_map.get(item, 0.0)),
    )
    return {
        "evidence_strength": evidence_strength,
        # All items, ordered from lowest to highest confidence.
        "low_confidence_items": sorted(PHQ9_KEYS_ORDERED, key=lambda item: conf_map.get(item, 0.0)),
        "recommended_focus": focus_order[0] if focus_order else None,
        "quotes": quotes,
        "confidences": conf_map,
    }
def explainability_full(
    chat_history: List[Tuple[str, str]],
    confidences: List[float],
    features_history: Optional[List[Dict[str, float]]],
) -> Dict[str, Any]:
    """Aggregate linguistic and acoustic attributions at the end of a session.

    - Linguistic: up to 5 keyword-matched patient quotes per item (a
      placeholder for SHAP/attention attributions).
    - Acoustic: each prosodic feature is averaged over the turns that reported
      it. The 8 largest-magnitude means are rendered as "name=value" strings
      (3 decimals), and the same list is attached to every item.
    Confidences are read positionally in ``PHQ9_KEYS_ORDERED`` order; missing
    entries count as 0.0. Non-dict entries in *features_history* are ignored.
    """
    quotes = _extract_quotes_per_item(chat_history)
    conf_map = {
        key: float(confidences[pos] if pos < len(confidences) else 0.0)
        for pos, key in enumerate(PHQ9_KEYS_ORDERED)
    }

    samples: Dict[str, List[float]] = {}
    for turn_features in list(features_history or []):
        if isinstance(turn_features, dict):
            for name, value in turn_features.items():
                samples.setdefault(name, []).append(float(value))
    prosody_means = {name: float(np.mean(values)) for name, values in samples.items()}
    ranked = sorted(prosody_means.items(), key=lambda pair: -abs(pair[1]))
    prosody_names = [f"{name}={value:.3f}" for name, value in ranked[:8]]

    items = [
        {
            "item": key,
            "confidence": conf_map.get(key, 0.0),
            "evidence": quotes.get(key, [])[:5],
            "prosody": prosody_names,
        }
        for key in PHQ9_KEYS_ORDERED
    ]
    return {
        "items": items,
        "notes": "Heuristic keyword and prosody aggregation; plug in SHAP/attention later.",
    }
def reflection_module(
    scores: Dict[str, int],
    confidences: List[float],
    exp_light: Optional[Dict[str, Any]],
    exp_full: Optional[Dict[str, Any]],
    threshold: float,
) -> Dict[str, Any]:
    """Self-reflection / output re-evaluation.

    Heuristic: an item whose confidence is below *threshold* and whose light
    evidence strength is "missing" has its score lowered by 1 (never below 0).
    Returns a ``reflection_report`` dict with these keys: the corrected
    scores, the new total and its severity label, a consistency score
    (1 - fraction of items revised), and human-readable notes.
    *exp_full* is accepted but not used.
    """
    corrected = dict(scores or {})
    strength = exp_light.get("evidence_strength", {}) if isinstance(exp_light, dict) else {}
    revisions: List[Tuple[str, int, int]] = []
    for position, item in enumerate(PHQ9_KEYS_ORDERED):
        item_conf = float(confidences[position] if position < len(confidences) else 0.0)
        if not (item_conf < float(threshold) and strength.get(item) == "missing"):
            continue
        previous = corrected.get(item, 0)
        lowered = max(0, int(previous) - 1)
        if lowered != previous:
            revisions.append((item, int(previous), lowered))
            corrected[item] = lowered

    final_total = int(sum(corrected.values()))
    consistency = float(1.0 - (len(revisions) / max(1, len(PHQ9_KEYS_ORDERED))))
    if revisions:
        summary = ", ".join(f"{item}: {old}->{new}" for item, old, new in revisions)
        notes = f"Model revised items due to low confidence and missing evidence: {summary}."
    else:
        notes = "No score revisions; explanations consistent with outputs."
    return {
        "corrected_scores": corrected,
        "final_total": final_total,
        "severity_label": severity_from_total(final_total),
        "consistency_score": consistency,
        "notes": notes,
    }
def build_patient_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
    """Render the patient-facing Markdown summary.

    Sections, in order:
    - PHQ-9 total and severity, only when both are known;
    - one evidence quote per item from ``display_json["Explainability_Full"]``;
    - the notes from ``display_json["Reflection_Report"]``;
    - the full conversation transcript.
    Values in *meta* take precedence over those in *display_json*.
    """
    severity = meta.get("Severity") or display_json.get("Severity")
    # A total of 0 is a valid PHQ-9 result. Fall back only when the value is
    # missing; ``meta.get(...) or ...`` would treat 0 as absent.
    total = meta.get("Total_Score")
    if total is None:
        total = display_json.get("Total_Score")
    transcript_text = transcript_to_text(chat_history)
    exp_full = display_json.get("Explainability_Full") or {}
    reflection = display_json.get("Reflection_Report") or {}

    lines: List[str] = ["# Summary for Patient\n"]
    if total is not None and severity:
        lines.append(f"- PHQ‑9 Total: **{total}** ")
        lines.append(f"- Severity: **{severity}**\n")

    # Highlights: the first quote for each item that has evidence.
    items = exp_full.get("items", []) if isinstance(exp_full, dict) else []
    if isinstance(items, list) and items:
        lines.append("### Highlights from our conversation\n")
        for it in items:
            if not isinstance(it, dict):
                continue
            item = it.get("item")
            ev = it.get("evidence", [])
            if item and ev:
                lines.append(f"- {item}: \"{ev[0]}\"")
        lines.append("")

    note = reflection.get("notes") if isinstance(reflection, dict) else None
    if note:
        lines.append("### Reflection\n")
        lines.append(note)
        lines.append("")

    lines.append("### Conversation Transcript\n\n")
    lines.append(f"```\n{transcript_text}\n```")
    return "\n".join(lines)
def build_clinician_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
    """Render the clinician-facing Markdown summary.

    Sections, in order:
    - severity, total and high-risk flag;
    - per-item scores;
    - ASCII confidence bars;
    - the light-explainability snapshot (evidence strength and next focus);
    - per-item evidence and prosody from ``Explainability_Full``;
    - self-reflection notes and adjusted items;
    - the full conversation transcript.
    Values in *meta* take precedence over those in *display_json*.
    """
    scores = display_json.get("PHQ9_Scores") or {}
    confidences = display_json.get("Confidences") or []
    severity = meta.get("Severity") or display_json.get("Severity")
    # 0 is a valid PHQ-9 total; fall back only when the value is missing.
    total = meta.get("Total_Score")
    if total is None:
        total = display_json.get("Total_Score")
    risk = display_json.get("High_Risk")
    transcript_text = transcript_to_text(chat_history)
    exp_light = display_json.get("Explainability_Light") or {}
    exp_full = display_json.get("Explainability_Full") or {}
    reflection = display_json.get("Reflection_Report") or {}

    md: List[str] = ["# Summary for Clinician\n"]
    md.append(f"- Severity: **{severity}** ")
    md.append(f"- PHQ‑9 Total: **{total}** ")
    if risk is not None:
        md.append(f"- High Risk: **{risk}** ")
    md.append("")
    scores_lines = "\n".join(f"- {k}: {v}" for k, v in scores.items())
    md.append("### Item Scores\n" + scores_lines + "\n")

    # Confidence bars. Confidences are positional in PHQ9_KEYS_ORDERED order, so
    # label them with that order instead of relying on the key order of *scores*.
    if confidences:
        bars = []
        for i, k in enumerate(PHQ9_KEYS_ORDERED):
            c = float(confidences[i]) if i < len(confidences) else 0.0
            bar_len = min(20, max(0, int(round(c * 20))))  # keep the bar 20 chars wide
            bars.append(f"- {k}: [{'#' * bar_len}{'.' * (20 - bar_len)}] {c:.2f}")
        md.append("### Confidence by item\n" + "\n".join(bars) + "\n")

    # Light explainability snapshot
    if isinstance(exp_light, dict) and exp_light:
        strength = exp_light.get("evidence_strength", {})
        recommended = exp_light.get("recommended_focus")
        if strength:
            md.append("### Evidence strength (light)\n")
            md.extend(f"- {k}: {v}" for k, v in strength.items())
            md.append("")
        if recommended:
            md.append(f"- Next focus (if continuing): **{recommended}**\n")

    # Full explainability excerpts
    if isinstance(exp_full, dict) and exp_full:
        md.append("### Explainability (final)\n")
        for it in exp_full.get("items", []):
            if not isinstance(it, dict):
                continue
            item = it.get("item")
            if not item:
                continue
            conf = it.get("confidence")
            # A missing/non-numeric confidence must not crash the report.
            conf_txt = f"{conf:.2f}" if isinstance(conf, (int, float)) else "n/a"
            md.append(f"- {item} (conf {conf_txt}):")
            for q in it.get("evidence", [])[:2]:
                md.append(f" - \"{q}\"")
            pros = it.get("prosody", [])
            if pros:
                md.append(f" - prosody: {', '.join([str(p) for p in pros[:4]])}")
        md.append("")

    # Reflection summary
    if isinstance(reflection, dict) and reflection:
        md.append("### Self-reflection\n")
        notes = reflection.get("notes")
        if notes:
            md.append(notes)
        corr = reflection.get("corrected_scores") or {}
        if corr and corr != scores:
            changed = [k for k in scores if corr.get(k) != scores.get(k)]
            if changed:
                md.append("- Adjusted items: " + ", ".join(changed))
        md.append("")

    md.append("### Conversation Transcript\n\n")
    md.append(f"```\n{transcript_text}\n```")
    return "\n".join(md)
def generate_recording_agent_reply(chat_history: List[Tuple[str, str]], guidance: Optional[Dict[str, Any]] = None) -> str:
    """Generate the clinician's next conversational question with the chat LLM.

    The prompt combines the transcript with an optional focus hint taken from
    ``guidance["recommended_focus"]`` (for example, the output of
    ``explainability_light``). The reply is sampled with temperature 0.7,
    top-p 0.9 and top-k 50, up to 96 new tokens.

    Returns the stripped reply, cut to 300 characters with a trailing ellipsis
    when longer. If the model produces only whitespace, a generic follow-up
    question is returned so the patient never receives an empty turn.
    """
    fallback_question = "Could you tell me a bit more about how you've been feeling lately?"
    transcript = transcript_to_text(chat_history)
    system_prompt = (
        "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
        "without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
        "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
        "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
    )
    focus_text = ""
    if isinstance(guidance, dict):
        rec = guidance.get("recommended_focus")
        if rec:
            focus_text = (
                f"\n\nGuidance: Focus the next question on the patient's {str(rec).replace('_', ' ')}. "
                "Ask naturally about recent changes and their impact on daily life."
            )
    user_prompt = (
        "Conversation so far (Patient and Clinician turns):\n\n" + transcript +
        f"{focus_text}\n\nRespond with a single short clinician-style question for the patient."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    # Instructions travel in a single user turn rather than a "system" message,
    # presumably because some chat templates reject a system role (TODO confirm).
    messages = [
        {"role": "user", "content": system_prompt + "\n\n" + user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    gen = pipe(
        prompt,
        max_new_tokens=96,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    reply = gen[0]["generated_text"].strip()
    if not reply:
        return fallback_question
    # Cap rambling output; this text is shown to the patient and spoken via TTS.
    if len(reply) > 300:
        reply = reply[:300].rstrip() + "…"
    return reply
def _coerce_item_score(value: Any) -> int:
    """Parse one PHQ-9 item score and clamp it to 0..3.

    Accepts ints, floats and numeric strings; floats are truncated, as ``int()``
    does. Missing or unparseable values count as 0.
    """
    try:
        return max(0, min(3, int(float(value))))
    except (TypeError, ValueError, OverflowError):
        return 0


def _coerce_confidence(value: Any, default: float = 0.5) -> float:
    """Parse a confidence value and clamp it to [0, 1].

    Non-numeric or non-finite values (``json.loads`` accepts ``NaN``) become
    *default*. Without this check, ``max(0.0, min(1.0, nan))`` evaluates to 1.0
    and could end the assessment early.
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    if not math.isfinite(number):
        return default
    return max(0.0, min(1.0, number))


def _assessment_dict(scores: Dict[str, int], confidences: List[float], high_risk: bool) -> Dict[str, Any]:
    """Assemble a scoring result; total, severity and overall confidence are derived from the item values."""
    total = int(sum(scores.values()))
    return {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity_from_total(total),
        "Confidence": float(min(confidences)),
        "High_Risk": bool(high_risk),
    }


def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
    """Ask the LLM to produce PHQ-9 scores and confidences as JSON, then validate them.

    The model output is normalised as follows:
    - Only the nine canonical items (``PHQ9_KEYS_ORDERED``) are kept, each
      clamped to 0..3; missing or garbage values become 0.
    - ``Confidences`` must be a list of 9 values, otherwise the overall
      ``Confidence`` (default 0.5) is repeated; each value is clamped to [0, 1].
    - Total_Score, Severity and Confidence (the minimum per-item confidence)
      are recomputed from those values, so they always agree with each other.
    - High_Risk is true when the patient's own messages contain explicit
      suicidal language or ``suicidal_thoughts`` >= 2.
    If the output contains no ``PHQ9_Scores`` object, a neutral fallback is
    returned instead: every item 1, except ``suicidal_thoughts`` 0, with
    confidence 0.5.
    """
    transcript = transcript_to_text(chat_history)
    features_json = json.dumps(features, ensure_ascii=False)
    system_prompt = (
        "You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
        "Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
        "Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
        "and High_Risk (boolean, true if any suicidal risk)."
    )
    user_prompt = (
        "Conversation transcript:"
        f"\n{transcript}\n\n"
        f"Acoustic features summary (approximate):\n{features_json}\n\n"
        "Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
        "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    messages = [
        {"role": "user", "content": system_prompt + "\n\n" + user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Greedy decoding keeps the JSON reproducible. ``temperature`` is omitted
    # because it only applies to sampling; transformers warns when it is set
    # together with do_sample=False.
    gen = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    parsed = safe_json_extract(gen[0]["generated_text"])

    # Scan only the patient's own words for explicit risk language, so that
    # clinician questions such as "any thoughts of suicide?" do not trip it.
    patient_words = "\n".join(user for user, _ in chat_history if user)
    explicit_flag = detect_explicit_suicidality(patient_words)

    n_items = len(PHQ9_KEYS_ORDERED)
    raw_scores = parsed.get("PHQ9_Scores") if isinstance(parsed, dict) else None
    if not isinstance(raw_scores, dict):
        neutral = {k: (0 if k == "suicidal_thoughts" else 1) for k in PHQ9_KEYS_ORDERED}
        return _assessment_dict(neutral, [0.5] * n_items, explicit_flag)

    scores = {k: _coerce_item_score(raw_scores.get(k, 0)) for k in PHQ9_KEYS_ORDERED}
    raw_conf = parsed.get("Confidences")
    if isinstance(raw_conf, list) and len(raw_conf) == n_items:
        confidences = [_coerce_confidence(c) for c in raw_conf]
    else:
        confidences = [_coerce_confidence(parsed.get("Confidence", 0.5))] * n_items
    high_risk = explicit_flag or scores["suicidal_thoughts"] >= 2
    return _assessment_dict(scores, confidences, high_risk)
def transcribe_audio(audio_path: Optional[str]) -> str:
    """Transcribe the audio file at *audio_path* with the shared ASR pipeline.

    Handles both a dict result and a list-of-dicts result, each carrying a
    "text" key. Returns the stripped text, or "" when there is no path, the
    result has no text, or anything fails (ASR is best-effort).
    """
    if not audio_path:
        return ""
    try:
        output = get_asr_pipeline()(audio_path)
        if isinstance(output, dict):
            if "text" in output:
                return output["text"].strip()
        elif isinstance(output, list) and output and "text" in output[0]:
            return output[0]["text"].strip()
    except Exception:
        pass
    return ""
# ---------------------------
# Gradio app logic
# ---------------------------
# Opening clinician message, shown in the chat and optionally spoken via TTS.
# Each fragment ends with a space: adjacent string literals are joined verbatim,
# and without the spaces the sentences ran together ("doing.We'll").
INTRO_MESSAGE = (
    "Hi, I'm an assistant, and I will ask you some questions about how you've been doing. "
    "We'll record our conversation, and we will give you a written copy of it. "
    "From our conversation, we will send a written copy to the clinician, we will give a summary of what you are experiencing based on a questionnaire, called the Patient Health Questionnaire (PHQ-9), and we will give a summary of what your voice is like. "
    "We will send this to the clinician, and the clinician will follow up with you. "
    "To start, how has your mood been over the past couple of weeks?"
)
def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
    """Build a fresh session state.

    Returns ``(chat_history, scores, meta, finished, turns)``. The chat history
    holds only the clinician's intro message, scores are empty, meta fields are
    unset (zero confidence), and the session is unfinished at turn 0.
    """
    return (
        [("", INTRO_MESSAGE)],
        {},
        {"Severity": None, "Total_Score": None, "Confidence": 0.0},
        False,
        0,
    )
@spaces.GPU
def process_turn(
    audio_path: Optional[str],
    text_input: Optional[str],
    chat_history: List[Tuple[str, str]],
    threshold: float,
    tts_enabled: bool,
    finished: Optional[bool],
    turns: Optional[int],
    prev_scores: Dict[str, Any],
    prev_meta: Dict[str, Any],
):
    """Handle one patient turn: transcribe, score, decide whether to stop, and reply.

    Mutates *chat_history* in place by appending the new turn. Returns a
    9-tuple: ``(chat_history, score_json, severity, finished, turns, None,
    None, tts_path, tts_path)``. The two ``None`` values presumably clear the
    audio/text inputs (the UI wiring is not visible here).

    The assessment finishes when any of these hold:
    - the latest patient message contains explicit suicidal language, or
      ``suicidal_thoughts`` scores >= 2 (high risk);
    - every item confidence reaches *threshold*;
    - ``MAX_TURNS`` patient turns have been taken.
    """
    prev_meta = prev_meta or {}
    finished = bool(finished) if finished is not None else False
    turns = int(turns) if isinstance(turns, int) else 0
    # If already finished, do nothing
    if finished:
        return (
            chat_history,
            {"info": "Assessment complete."},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )

    patient_text = (text_input or "").strip()
    audio_features: Dict[str, float] = {}
    if audio_path:
        transcribed = transcribe_audio(audio_path)
        if transcribed:
            # Keep a space between typed and spoken input.
            patient_text = f"{patient_text} {transcribed}" if patient_text else transcribed
        audio_features = compute_audio_features(audio_path)

    if not patient_text:
        # Nothing usable was heard or typed: ask again without counting a turn.
        chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
        return (
            chat_history,
            prev_scores or {},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )

    # Add the patient's message; the clinician reply is filled in below.
    chat_history.append((patient_text, None))

    scoring = scoring_agent_infer(chat_history, audio_features)
    scores = scoring.get("PHQ9_Scores", {})
    confidences = scoring.get("Confidences", [])
    total = scoring.get("Total_Score", 0)
    severity = scoring.get("Severity", severity_from_total(total))
    overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))

    # High risk is decided here, not by the scoring agent's flag, to reduce
    # false positives. It uses explicit language in the latest patient message
    # (patient_text, just appended) or a high suicidal_thoughts item score.
    high_risk = bool(detect_explicit_suicidality(patient_text) or (scores.get("suicidal_thoughts", 0) >= 2))

    # Termination conditions
    min_conf = float(min(confidences)) if confidences else 0.0
    turns += 1
    done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)

    # Computed once: it guides the next question and is also shown in the UI JSON.
    light_exp = explainability_light(chat_history, scores, confidences, float(threshold))

    if high_risk:
        reply = (
            "I’m concerned about your safety based on what you shared. "
            "If you are in danger or need immediate help, please call 116 123 in the UK, 988 in the U.S. or your local emergency number. "
            "I'll end the assessment now and display emergency resources."
        )
        finished = True
    elif done:
        reply = (
            "Thank you for your time. The clinician will review your conversation and follow up with you. "
            "Here is a copy of our conversation so you can review it later."
        )
        finished = True
    else:
        reply = generate_recording_agent_reply(chat_history, guidance=light_exp)
    chat_history[-1] = (chat_history[-1][0], reply)

    # TTS for the latest clinician message, if enabled
    tts_path = synthesize_tts(reply) if tts_enabled else None

    display_json = {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity,
        "Confidence": overall_conf,
        "High_Risk": high_risk,
        # Last audio features and light explainability for downstream modules/UI
        "Last_Audio_Features": audio_features,
        "Explainability_Light": light_exp,
    }
    return (
        chat_history,
        display_json,
        severity,
        finished,
        turns,
        None,
        None,
        tts_path,
        tts_path,
    )
def reset_app():
    """Reset handler: discard the current session and return a fresh ``init_state()`` tuple."""
    fresh = init_state()
    return fresh
# ---------------------------
# UI
# ---------------------------
def _on_load_init():
    """Page-load handler: return a fresh session state (the same tuple as ``init_state()``)."""
    initial = init_state()
    return initial
def _on_load_init_with_tts(tts_on: bool):
    """Page-load handler that also voices the intro message.

    Returns the five ``init_state()`` values followed by the TTS audio path for
    the latest clinician message (the intro). The path is None when TTS is off
    or synthesis yields nothing.
    """
    history, scores, meta, done, turn_count = init_state()
    intro_audio = None
    if bool(tts_on):
        intro_audio = synthesize_tts(history[-1][1])
    return history, scores, meta, done, turn_count, intro_audio
def _play_intro_tts(tts_on: bool):
if not bool(tts_on):
return None
try:
return synthesize_tts(INTRO_MESSAGE)
except Exception:
return None
def create_demo():
with gr.Blocks(
theme=gr.themes.Soft(),
css='''
/* Voice bubble styles - clean and centered */
#voice-bubble {
width: 240px; height: 240px; border-radius: 9999px; margin: 40px auto;
display: flex; align-items: center; justify-content: center;
background: linear-gradient(135deg, #6ee7b7 0%, #34d399 100%);
box-shadow: 0 20px 40px rgba(16,185,129,0.3), 0 0 0 1px rgba(255,255,255,0.1) inset;
transition: all 250ms cubic-bezier(0.4, 0, 0.2, 1);
cursor: default; /* green circle itself is not clickable */
pointer-events: none; /* ignore clicks on the green circle */
position: relative;
}
#voice-bubble:hover {
transform: translateY(-2px) scale(1.02);
box-shadow: 0 25px 50px rgba(16,185,129,0.4), 0 0 0 1px rgba(255,255,255,0.15) inset;
}
#voice-bubble:active { transform: translateY(0px) scale(0.98); }
#voice-bubble.listening {
animation: bubble-pulse 1.5s ease-in-out infinite;
background: linear-gradient(135deg, #fb7185 0%, #ef4444 100%);
box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 1px rgba(255,255,255,0.1) inset;
}
@keyframes bubble-pulse {
0%, 100% { transform: scale(1.0); box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 0 rgba(239,68,68,0.5); }
50% { transform: scale(1.05); box-shadow: 0 25px 50px rgba(239,68,68,0.5), 0 0 0 15px rgba(239,68,68,0.0); }
}
/* Hide microphone dropdown selector only */
#voice-bubble select { display: none !important; }
#voice-bubble .source-selection { display: none !important; }
#voice-bubble label[for] { display: none !important; }
/* Make the inner button the only clickable target */
#voice-bubble button { pointer-events: auto; cursor: pointer; }
/* Hide TTS player UI but keep it in DOM for autoplay */
#tts-player { width: 0 !important; height: 0 !important; opacity: 0 !important; position: absolute; pointer-events: none; }
/* Settings and back buttons in top right - compact */
#settings-btn {
position: absolute; top: 16px; right: 16px; z-index: 10;
width: auto !important; min-width: 100px !important; max-width: 120px !important;
padding: 8px 16px !important;
}
#back-btn {
position: absolute; top: 8px; right: 8px; z-index: 10;
width: auto !important; min-width: 80px !important; max-width: 100px !important;
padding: 8px 16px !important;
}
/* Play Intro button positioned under Settings */
#intro-btn {
position: absolute; top: 60px; right: 16px; z-index: 10;
width: auto !important; min-width: 100px !important; max-width: 120px !important;
padding: 8px 16px !important;
}
'''
) as demo:
# Main view
with gr.Column(visible=True) as main_view:
# Settings button (top right, only in main view)
settings_btn = gr.Button("⚙️ Settings", elem_id="settings-btn", size="sm")
gr.Markdown(
"""
### Conversational Assessment for Responsive Engagement (CARE) Notes
Tap on 'Record' to start speaking, then tap on 'Stop' to stop recording.
"""
)
intro_play_btn = gr.Button("▶️ Start", elem_id="intro-btn", variant="secondary", size="sm")
# Microphone component styled as central bubble (tap to record/stop)
audio_main = gr.Microphone(type="filepath", label=None, elem_id="voice-bubble", show_label=False)
# Hidden text input placeholder for pipeline compatibility
text_main = gr.Textbox(value="", visible=False)
# Autoplay clinician voice output (player hidden with CSS)
tts_audio_main = gr.Audio(label=None, interactive=False, autoplay=True, show_label=False, elem_id="tts-player")
# Final summaries (shown after assessment ends)
main_summary = gr.Markdown(visible=False)
# Settings view (initially hidden)
with gr.Column(visible=False) as settings_view:
back_btn = gr.Button("← Back", elem_id="back-btn", size="sm")
gr.Markdown("## Settings")
chatbot = gr.Chatbot(height=360, type="tuples", label="Conversation")
with gr.Row():
text_adv = gr.Textbox(placeholder="Type your message and press Enter", scale=4)
send_adv_btn = gr.Button("Send", scale=1)
score_json = gr.JSON(label="PHQ-9 Assessment (live)")
severity_label = gr.Label(label="Severity")
threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
with gr.Row():
tts_provider_dd = gr.Dropdown(choices=["Coqui", "gTTS"], value="Coqui", label="TTS Provider")
coqui_model_tb = gr.Textbox(value=os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits"), label="Coqui Model")
coqui_speaker_dd = gr.Dropdown(choices=list_coqui_speakers(os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")), value="p225", label="Coqui Speaker")
tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
with gr.Row():
apply_model_btn = gr.Button("Apply model (no restart)")
# apply_model_restart_btn = gr.Button("Apply model and restart")
model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")
# App state
chat_state = gr.State()
scores_state = gr.State()
meta_state = gr.State()
finished_state = gr.State()
turns_state = gr.State()
feats_state = gr.State()
# Initialize on load (no autoplay due to browser policies)
demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
# View navigation
settings_btn.click(
fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
inputs=None,
outputs=[main_view, settings_view]
)
back_btn.click(
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
inputs=None,
outputs=[main_view, settings_view]
)
# Explicit user gesture to play intro TTS (works across browsers)
intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])
# Wire interactions
def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
result = process_turn(audio, text, chat, th, tts_on, finished, turns, scores, meta)
chat_history, display_json, severity, finished_o, turns_o, _, _, _, last_tts = result
# Accumulate last audio features
feats_hist = feats_hist or []
last_feats = (display_json or {}).get("Last_Audio_Features") or {}
if isinstance(last_feats, dict) and last_feats:
feats_hist = list(feats_hist) + [last_feats]
if tts_on and chat_history and chat_history[-1][1]:
new_path = synthesize_tts(chat_history[-1][1], provider=provider, coqui_model_name=coqui_model, coqui_speaker=coqui_speaker)
else:
new_path = None
# If finished, hide the mic and display summaries in Main
if finished_o:
# Run full explainability and reflection
exp_full = explainability_full(chat_history, display_json.get("Confidences", []), feats_hist)
reflect = reflection_module(display_json.get("PHQ9_Scores", {}), display_json.get("Confidences", []), display_json.get("Explainability_Light", {}), exp_full, float(th))
display_json["Explainability_Full"] = exp_full
display_json["Reflection_Report"] = reflect
# Use reflection outputs to set final meta
final_sev = reflect.get("severity_label") or severity
final_total = reflect.get("final_total") or display_json.get("Total_Score")
patient_md = build_patient_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
clinician_md = build_clinician_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
summary_md = patient_md + "\n\n---\n\n" + clinician_md
return chat_history, display_json, severity, finished_o, turns_o, gr.update(visible=False), None, new_path, new_path, gr.update(value=summary_md, visible=True), feats_hist
return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path, gr.update(visible=False), feats_hist
audio_main.stop_recording(
fn=_process_with_tts,
inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state],
queue=True,
api_name="message",
)
# Text input flow from Advanced tab
def _process_text_and_clear(text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
res = _process_with_tts(None, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist)
return (*res, "")
text_adv.submit(
fn=_process_text_and_clear,
inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
queue=True,
)
send_adv_btn.click(
fn=_process_text_and_clear,
inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
queue=True,
)
# Tap bubble to toggle microphone record/stop via JS
# This JS is no longer needed as the bubble is the mic
# voice_bubble.click(
# None,
# inputs=None,
# outputs=None,
# js="() => {\n const bubble = document.getElementById('voice-bubble');\n const root = document.getElementById('hidden-mic');\n if (!root) return;\n let didClick = false;\n const wc = root.querySelector && root.querySelector('gradio-audio');\n if (wc && wc.shadowRoot) {\n const btns = Array.from(wc.shadowRoot.querySelectorAll('button')).filter(b => !b.disabled);\n const txt = (b) => ((b.getAttribute('aria-label')||'') + ' ' + (b.textContent||'')).toLowerCase();\n const stopBtn = btns.find(b => txt(b).includes('stop'));\n const recBtn = btns.find(b => { const t = txt(b); return t.includes('record') || t.includes('start') || t.includes('microphone') || t.includes('mic'); });\n if (stopBtn) { stopBtn.click(); didClick = true; } else if (recBtn) { recBtn.click(); didClick = true; } else if (btns[0]) { btns[0].click(); didClick = true; }\n }\n if (!didClick) {\n const candidates = Array.from(root.querySelectorAll('button, [role=\\'button\\']')).filter(el => !el.disabled);\n if (candidates.length) { candidates[0].click(); didClick = true; }\n }\n if (bubble && didClick) bubble.classList.toggle('listening');\n }",
# )
# No reset button in Main tab anymore
# Model switch handlers
def _on_apply_model(mid: str):
msg = set_current_model_id(mid)
return f"Current model: `{current_model_id}`\n\n{msg}"
def _on_apply_model_restart(mid: str):
msg = apply_model_and_restart(mid)
return f"{msg}"
apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
# apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])
return demo
# Build the UI at import time so that the module-level ``demo`` object exists
# even when this file is imported rather than run directly (presumably how the
# hosted Spaces runtime, imported above via ``spaces``, picks it up; confirm).
demo = create_demo()

if __name__ == "__main__":
    # Local development: enable the request queue (max_size=16), then serve on
    # all interfaces at $PORT (default 7860).
    _queued = demo.queue(max_size=16)
    _port = int(os.getenv("PORT", "7860"))
    _queued.launch(server_name="0.0.0.0", server_port=_port)