Akis Giannoukos committed
Commit · 09716a4
Parent(s): 9d16b48

Added explainability
README.md
CHANGED
@@ -17,6 +17,9 @@ A lightweight research demo that simulates a clinician conducting a brief conver
 ## What it does
 - Conversational assessment to infer PHQ‑9 items from natural dialogue (no explicit questionnaire).
 - Live inference of PHQ‑9 item scores, confidences, total score, and severity.
+- Iterative light explainability after each turn to guide the next question (strong/weak/missing evidence by item).
+- Final explainability at session end aggregating linguistic quotes and acoustic prosody.
+- Self‑reflection step that checks consistency and may adjust low‑confidence item scores.
 - Automatic stop when minimum confidence across items reaches a threshold or risk is detected.
 - Optional TTS playback for clinician responses.

@@ -78,10 +81,50 @@ Notes:
 ## Safety
 This demo does not provide therapy or emergency counseling. If a user expresses suicidal intent or risk is inferred, the app ends the conversation and advises contacting emergency services (e.g., 988 in the U.S.).

+## Architecture
+RecordingAgent → ScoringAgent → ExplainabilityModule (light/full) → ReflectionModule → ReportGenerator
+
+- RecordingAgent: generates clinician follow‑ups; guided by light explainability when available.
+- ScoringAgent: infers PHQ‑9 item scores and per‑item confidences from the transcript (+ prosody summary).
+- Explainability (light): keyword‑based evidence strength per item; selects the next focus area.
+- Explainability (full): aggregates transcript quotes and averaged prosody features into per‑item objects.
+- Reflection: heuristic pass reduces scores by 1 for items with confidence < τ and missing evidence.
+- ReportGenerator: patient and clinician summaries, confidence bars, highlights, and reflection notes.
+
+### Output objects
+- Explainability (light):
+```json
+{
+  "evidence_strength": {"appetite": "missing", ...},
+  "recommended_focus": "appetite",
+  "quotes": {"appetite": ["..."], ...},
+  "confidences": {"appetite": 0.34, ...}
+}
+```
+- Explainability (full):
+```json
+{
+  "items": [
+    {"item": "appetite", "confidence": 0.42, "evidence": ["..."], "prosody": ["rms_mean=0.012", "zcr_mean=0.065", ...]}
+  ],
+  "notes": "Heuristic placeholder"
+}
+```
+- Reflection report:
+```json
+{
+  "corrected_scores": {"appetite": 1, ...},
+  "final_total": 12,
+  "severity_label": "Moderate Depression",
+  "consistency_score": 0.89,
+  "notes": "Model revised appetite score due to low confidence and missing evidence."
+}
+```
+
 ## Development notes
 - Framework: Gradio Blocks
 - ASR: Transformers pipeline (Whisper)
-- TTS: gTTS
-- Prosody features: librosa
+- TTS: gTTS or Coqui TTS
+- Prosody features: librosa proxies; replaceable by OpenSMILE

 PRs and experiments are welcome. This is a research prototype and not a clinical tool.
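Note: the "librosa proxies" mentioned above correspond to the feature names in the full‑explainability example (rms_mean, zcr_mean). A minimal sketch of how such per‑turn proxies can be computed with stock librosa calls; the helper name and exact feature set are illustrative, not part of this commit:

```python
import librosa
import numpy as np

def prosody_proxies(wav_path: str) -> dict:
    """Per-turn prosody summaries in the name=value style used by explainability_full."""
    y, sr = librosa.load(wav_path, sr=16000)          # mono audio at 16 kHz
    rms = librosa.feature.rms(y=y)[0]                 # frame-wise energy envelope
    zcr = librosa.feature.zero_crossing_rate(y)[0]    # frame-wise zero-crossing rate
    return {"rms_mean": float(np.mean(rms)), "zcr_mean": float(np.mean(zcr))}
```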
app.py
CHANGED
@@ -288,6 +288,36 @@ def severity_from_total(total_score: int) -> str:
     return "Severe Depression"


+# ---------------------------
+# PHQ-9 schema and helpers
+# ---------------------------
+PHQ9_KEYS_ORDERED: List[str] = [
+    "interest",
+    "mood",
+    "sleep",
+    "energy",
+    "appetite",
+    "self_worth",
+    "concentration",
+    "motor",
+    "suicidal_thoughts",
+]
+
+# Lightweight keyword lexicon per item for evidence extraction.
+# Placeholder for future SHAP/attention-based attributions.
+PHQ9_KEYWORDS: Dict[str, List[str]] = {
+    "interest": ["interest", "pleasure", "enjoy", "motivation", "hobbies"],
+    "mood": ["depressed", "down", "sad", "hopeless", "blue", "mood"],
+    "sleep": ["sleep", "insomnia", "awake", "wake up", "night", "dream"],
+    "energy": ["tired", "fatigue", "energy", "exhausted", "worn out"],
+    "appetite": ["appetite", "eat", "eating", "hungry", "food", "weight"],
+    "self_worth": ["worthless", "failure", "guilty", "guilt", "self-esteem", "ashamed"],
+    "concentration": ["concentrate", "focus", "attention", "distracted", "remember"],
+    "motor": ["restless", "slow", "slowed", "agitated", "fidget", "move"],
+    "suicidal_thoughts": ["suicide", "kill myself", "die", "end my life", "self-harm", "hurt myself"],
+}
+
+
 def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
     """Convert chatbot history [(user, assistant), ...] to a plain text transcript."""
     lines = []
@@ -299,15 +329,195 @@ def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
     return "\n".join(lines)


+def _patient_sentences(chat_history: List[Tuple[str, str]]) -> List[str]:
+    """Extract patient-only sentences from chat history."""
+    sentences: List[str] = []
+    for user, _assistant in chat_history:
+        if not user:
+            continue
+        parts = re.split(r"(?<=[.!?])\s+", user.strip())
+        for p in parts:
+            p = p.strip()
+            if p:
+                sentences.append(p)
+    return sentences
+
+
+def _extract_quotes_per_item(chat_history: List[Tuple[str, str]]) -> Dict[str, List[str]]:
+    """Heuristic extraction of per-item evidence quotes from patient sentences based on keywords."""
+    quotes: Dict[str, List[str]] = {k: [] for k in PHQ9_KEYS_ORDERED}
+    sentences = _patient_sentences(chat_history)
+    for sent in sentences:
+        s_low = sent.lower()
+        for item, kws in PHQ9_KEYWORDS.items():
+            if any(kw in s_low for kw in kws):
+                if len(quotes[item]) < 5:
+                    quotes[item].append(sent)
+    return quotes
+
+
+def explainability_light(
+    chat_history: List[Tuple[str, str]],
+    scores: Dict[str, int],
+    confidences: List[float],
+    threshold: float,
+) -> Dict[str, Any]:
+    """Lightweight explainability per turn.
+
+    - Inspects transcript for keyword-based evidence per PHQ-9 item.
+    - Classifies evidence strength as strong/weak/missing using keyword hits and confidence.
+    - Suggests next focus item based on lowest-confidence or missing evidence.
+
+    Returns a JSON-serializable dict.
+    """
+    quotes = _extract_quotes_per_item(chat_history)
+    conf_map: Dict[str, float] = {}
+    for idx, key in enumerate(PHQ9_KEYS_ORDERED):
+        conf_map[key] = float(confidences[idx] if idx < len(confidences) else 0.0)
+
+    evidence_strength: Dict[str, str] = {}
+    for key in PHQ9_KEYS_ORDERED:
+        hits = len(quotes.get(key, []))
+        conf = conf_map.get(key, 0.0)
+        if hits >= 2 and conf >= max(0.6, threshold - 0.1):
+            evidence_strength[key] = "strong"
+        elif hits >= 1 or conf >= 0.4:
+            evidence_strength[key] = "weak"
+        else:
+            evidence_strength[key] = "missing"
+
+    low_items = sorted(
+        PHQ9_KEYS_ORDERED,
+        key=lambda k: (evidence_strength[k] != "missing", conf_map.get(k, 0.0))
+    )
+    recommended = low_items[0] if low_items else None
+
+    return {
+        "evidence_strength": evidence_strength,
+        "low_confidence_items": [k for k in sorted(PHQ9_KEYS_ORDERED, key=lambda x: conf_map.get(x, 0.0))],
+        "recommended_focus": recommended,
+        "quotes": quotes,
+        "confidences": conf_map,
+    }
+
+
+def explainability_full(
+    chat_history: List[Tuple[str, str]],
+    confidences: List[float],
+    features_history: Optional[List[Dict[str, float]]],
+) -> Dict[str, Any]:
+    """Aggregate linguistic and acoustic attributions at session end.
+
+    - Linguistic: keyword-based quotes per item (placeholder for SHAP/attention).
+    - Acoustic: mean of per-turn prosodic features; returned as name=value strings.
+    """
+    def _aggregate_prosody(history: List[Dict[str, float]]) -> Dict[str, float]:
+        agg: Dict[str, float] = {}
+        if not history:
+            return agg
+        keys = set().union(*[d.keys() for d in history if isinstance(d, dict)])
+        for k in keys:
+            vals = [float(d[k]) for d in history if isinstance(d, dict) and k in d]
+            if vals:
+                agg[k] = float(np.mean(vals))
+        return agg
+
+    quotes = _extract_quotes_per_item(chat_history)
+    conf_map = {k: float(confidences[i] if i < len(confidences) else 0.0) for i, k in enumerate(PHQ9_KEYS_ORDERED)}
+    prosody_agg = _aggregate_prosody(list(features_history or []))
+    prosody_pairs = sorted(list(prosody_agg.items()), key=lambda kv: -abs(kv[1]))
+    prosody_names = [f"{k}={v:.3f}" for k, v in prosody_pairs[:8]]
+
+    items = []
+    for k in PHQ9_KEYS_ORDERED:
+        items.append({
+            "item": k,
+            "confidence": conf_map.get(k, 0.0),
+            "evidence": quotes.get(k, [])[:5],
+            "prosody": prosody_names,
+        })
+    return {
+        "items": items,
+        "notes": "Heuristic keyword and prosody aggregation; plug in SHAP/attention later.",
+    }
+
+
+def reflection_module(
+    scores: Dict[str, int],
+    confidences: List[float],
+    exp_light: Optional[Dict[str, Any]],
+    exp_full: Optional[Dict[str, Any]],
+    threshold: float,
+) -> Dict[str, Any]:
+    """Self-reflection / output reevaluation.
+
+    Heuristic: if confidence for an item < threshold and evidence is missing, reduce score by 1 (min 0).
+    Returns a `reflection_report` JSON with corrected scores and final summary.
+    """
+    corrected = dict(scores or {})
+    strength = (exp_light or {}).get("evidence_strength", {}) if isinstance(exp_light, dict) else {}
+    changes: List[Tuple[str, int, int]] = []
+    for i, k in enumerate(PHQ9_KEYS_ORDERED):
+        conf = float(confidences[i] if i < len(confidences) else 0.0)
+        if conf < float(threshold) and strength.get(k) == "missing":
+            new_val = max(0, int(corrected.get(k, 0)) - 1)
+            if new_val != corrected.get(k, 0):
+                changes.append((k, int(corrected.get(k, 0)), new_val))
+                corrected[k] = new_val
+    final_total = int(sum(corrected.values()))
+    final_sev = severity_from_total(final_total)
+    consistency = float(1.0 - (len(changes) / max(1, len(PHQ9_KEYS_ORDERED))))
+    if changes:
+        notes = ", ".join([f"{k}: {old}->{new}" for k, old, new in changes])
+        notes = f"Model revised items due to low confidence and missing evidence: {notes}."
+    else:
+        notes = "No score revisions; explanations consistent with outputs."
+    return {
+        "corrected_scores": corrected,
+        "final_total": final_total,
+        "severity_label": final_sev,
+        "consistency_score": consistency,
+        "notes": notes,
+    }
+
+
+
 def build_patient_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
     severity = meta.get("Severity") or display_json.get("Severity")
     total = meta.get("Total_Score") or display_json.get("Total_Score")
     transcript_text = transcript_to_text(chat_history)
-
-
-
-
-
+    # Optional enriched content
+    exp_full = display_json.get("Explainability_Full") or {}
+    reflection = display_json.get("Reflection_Report") or {}
+
+    lines = []
+    lines.append("# Summary for Patient\n")
+    if total is not None and severity:
+        lines.append(f"- PHQ‑9 Total: **{total}** ")
+        lines.append(f"- Severity: **{severity}**\n")
+
+    # Highlights: show one quote per item if available
+    if exp_full and isinstance(exp_full, dict):
+        items = exp_full.get("items", [])
+        if isinstance(items, list) and items:
+            lines.append("### Highlights from our conversation\n")
+            for it in items:
+                item = it.get("item")
+                ev = it.get("evidence", [])
+                if item and ev:
+                    lines.append(f"- {item}: \"{ev[0]}\"")
+            lines.append("")
+
+    if reflection:
+        note = reflection.get("notes")
+        if note:
+            lines.append("### Reflection\n")
+            lines.append(note)
+            lines.append("")
+
+    lines.append("### Conversation Transcript\n\n")
+    lines.append(f"```\n{transcript_text}\n```")
+    return "\n".join(lines)


 def build_clinician_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
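For reference, the new reflection heuristic above is plain arithmetic over the score and confidence maps. An illustrative call to `reflection_module` with made-up values and threshold 0.6: "appetite" is the only item that is both low-confidence and evidence-"missing", so its score drops from 2 to 1 and consistency is 1 - 1/9 ≈ 0.89:

```python
scores = {"interest": 2, "mood": 2, "sleep": 1, "energy": 2, "appetite": 2,
          "self_worth": 1, "concentration": 1, "motor": 1, "suicidal_thoughts": 0}
confidences = [0.8, 0.7, 0.7, 0.8, 0.3, 0.6, 0.7, 0.6, 0.9]  # ordered as PHQ9_KEYS_ORDERED
exp_light = {"evidence_strength": {k: ("missing" if k == "appetite" else "weak") for k in scores}}

report = reflection_module(scores, confidences, exp_light, exp_full=None, threshold=0.6)
# report["corrected_scores"]["appetite"] == 1, report["final_total"] == 11,
# report["consistency_score"] ≈ 0.89
```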
@@ -319,17 +529,75 @@ def build_clinician_summary(chat_history: List[Tuple[str, str]], meta: Dict[str,
     transcript_text = transcript_to_text(chat_history)
     scores_lines = "\n".join([f"- {k}: {v}" for k, v in scores.items()])
     conf_str = ", ".join([f"{c:.2f}" for c in confidences]) if confidences else ""
-
-
-
-
-
-
-
-
-    )
-
-def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
+    # Optional explainability
+    exp_light = display_json.get("Explainability_Light") or {}
+    exp_full = display_json.get("Explainability_Full") or {}
+    reflection = display_json.get("Reflection_Report") or {}
+
+    md = []
+    md.append("# Summary for Clinician\n")
+    md.append(f"- Severity: **{severity}** ")
+    md.append(f"- PHQ‑9 Total: **{total}** ")
+    if risk is not None:
+        md.append(f"- High Risk: **{risk}** ")
+    md.append("")
+    md.append("### Item Scores\n" + scores_lines + "\n")
+
+    # Confidence bars
+    if confidences:
+        bars = []
+        for i, k in enumerate(scores.keys()):
+            c = confidences[i] if i < len(confidences) else 0.0
+            bar_len = int(round(c * 20))
+            bars.append(f"- {k}: [{'#'*bar_len}{'.'*(20-bar_len)}] {c:.2f}")
+        md.append("### Confidence by item\n" + "\n".join(bars) + "\n")
+
+    # Light explainability snapshot
+    if exp_light:
+        strength = exp_light.get("evidence_strength", {})
+        recommended = exp_light.get("recommended_focus")
+        if strength:
+            md.append("### Evidence strength (light)\n")
+            md.extend([f"- {k}: {v}" for k, v in strength.items()])
+            md.append("")
+        if recommended:
+            md.append(f"- Next focus (if continuing): **{recommended}**\n")
+
+    # Full explainability excerpts
+    if exp_full and isinstance(exp_full, dict):
+        md.append("### Explainability (final)\n")
+        items = exp_full.get("items", [])
+        for it in items:
+            item = it.get("item")
+            conf = it.get("confidence")
+            ev = it.get("evidence", [])
+            pros = it.get("prosody", [])
+            if item:
+                md.append(f"- {item} (conf {conf:.2f}):")
+                for q in ev[:2]:
+                    md.append(f"  - \"{q}\"")
+                if pros:
+                    md.append(f"  - prosody: {', '.join([str(p) for p in pros[:4]])}")
+        md.append("")
+
+    # Reflection summary
+    if reflection:
+        md.append("### Self-reflection\n")
+        notes = reflection.get("notes")
+        if notes:
+            md.append(notes)
+        corr = reflection.get("corrected_scores") or {}
+        if corr and corr != scores:
+            changed = [k for k in scores.keys() if corr.get(k) != scores.get(k)]
+            if changed:
+                md.append("- Adjusted items: " + ", ".join(changed))
+        md.append("")
+
+    md.append("### Conversation Transcript\n\n")
+    md.append(f"```\n{transcript_text}\n```")
+    return "\n".join(md)
+
+def generate_recording_agent_reply(chat_history: List[Tuple[str, str]], guidance: Optional[Dict[str, Any]] = None) -> str:
     transcript = transcript_to_text(chat_history)
     system_prompt = (
         "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
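(The confidence bars above render as fixed-width ASCII gauges; e.g., for c = 0.42, bar_len = int(round(0.42 * 20)) = 8, giving:)

```
- appetite: [########............] 0.42
```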
@@ -337,9 +605,17 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
         "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
         "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
     )
+    focus_text = ""
+    if guidance and isinstance(guidance, dict):
+        rec = guidance.get("recommended_focus")
+        if rec:
+            focus_text = (
+                f"\n\nGuidance: Focus the next question on the patient's {str(rec).replace('_', ' ')}. "
+                "Ask naturally about recent changes and their impact on daily life."
+            )
     user_prompt = (
         "Conversation so far (Patient and Clinician turns):\n\n" + transcript +
-        "\n\nRespond with a single short clinician-style question for the patient."
+        f"{focus_text}\n\nRespond with a single short clinician-style question for the patient."
     )
     pipe = get_textgen_pipeline()
     tokenizer = pipe.tokenizer
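(Illustration: with recommended_focus = "self_worth", the guidance fragment appended to the user prompt reads as follows; underscores become spaces via str(rec).replace('_', ' '):)

```
Guidance: Focus the next question on the patient's self worth. Ask naturally about recent changes and their impact on daily life.
```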
@@ -620,8 +896,10 @@ def process_turn(
         chat_history[-1] = (chat_history[-1][0], summary)
         finished = True
     else:
-        # Generate next clinician question
-        reply = generate_recording_agent_reply(chat_history)
+        # Iterative explainability (light) to guide next question
+        light_exp = explainability_light(chat_history, scores, confidences, float(threshold))
+        # Generate next clinician question with guidance
+        reply = generate_recording_agent_reply(chat_history, guidance=light_exp)
         chat_history[-1] = (chat_history[-1][0], reply)

     # TTS for the latest clinician message, if enabled
@@ -635,6 +913,9 @@
         "Severity": severity,
         "Confidence": overall_conf,
         "High_Risk": high_risk,
+        # Include the last audio features and light explainability for downstream modules/UI
+        "Last_Audio_Features": audio_features,
+        "Explainability_Light": explainability_light(chat_history, scores, confidences, float(threshold)),
     }

     # Clear inputs after processing
@@ -782,6 +1063,7 @@ def create_demo():
         meta_state = gr.State()
         finished_state = gr.State()
         turns_state = gr.State()
+        feats_state = gr.State()

         # Initialize on load (no autoplay due to browser policies)
         demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
@@ -802,44 +1084,57 @@ def create_demo():
         intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])

         # Wire interactions
-        def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker):
+        def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
             result = process_turn(audio, text, chat, th, tts_on, finished, turns, scores, meta)
             chat_history, display_json, severity, finished_o, turns_o, _, _, _, last_tts = result
+            # Accumulate last audio features
+            feats_hist = feats_hist or []
+            last_feats = (display_json or {}).get("Last_Audio_Features") or {}
+            if isinstance(last_feats, dict) and last_feats:
+                feats_hist = list(feats_hist) + [last_feats]
             if tts_on and chat_history and chat_history[-1][1]:
                 new_path = synthesize_tts(chat_history[-1][1], provider=provider, coqui_model_name=coqui_model, coqui_speaker=coqui_speaker)
             else:
                 new_path = None
             # If finished, hide the mic and display summaries in Main
             if finished_o:
-
-
+                # Run full explainability and reflection
+                exp_full = explainability_full(chat_history, display_json.get("Confidences", []), feats_hist)
+                reflect = reflection_module(display_json.get("PHQ9_Scores", {}), display_json.get("Confidences", []), display_json.get("Explainability_Light", {}), exp_full, float(th))
+                display_json["Explainability_Full"] = exp_full
+                display_json["Reflection_Report"] = reflect
+                # Use reflection outputs to set final meta
+                final_sev = reflect.get("severity_label") or severity
+                final_total = reflect.get("final_total") or display_json.get("Total_Score")
+                patient_md = build_patient_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
+                clinician_md = build_clinician_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
                 summary_md = patient_md + "\n\n---\n\n" + clinician_md
-                return chat_history, display_json, severity, finished_o, turns_o, gr.update(visible=False), None, new_path, new_path, gr.update(value=summary_md, visible=True)
-            return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path, gr.update(visible=False)
+                return chat_history, display_json, severity, finished_o, turns_o, gr.update(visible=False), None, new_path, new_path, gr.update(value=summary_md, visible=True), feats_hist
+            return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path, gr.update(visible=False), feats_hist

         audio_main.stop_recording(
             fn=_process_with_tts,
-            inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
-            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary],
+            inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
+            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state],
             queue=True,
             api_name="message",
         )

         # Text input flow from Advanced tab
-        def _process_text_and_clear(text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker):
-            res = _process_with_tts(None, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker)
+        def _process_text_and_clear(text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
+            res = _process_with_tts(None, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist)
             return (*res, "")

         text_adv.submit(
             fn=_process_text_and_clear,
-            inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
-            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, text_adv],
+            inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
+            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
             queue=True,
         )
         send_adv_btn.click(
             fn=_process_text_and_clear,
-            inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
-            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, text_adv],
+            inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
+            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
             queue=True,
         )
