Akis Giannoukos committed
Commit 09716a4
1 Parent(s): 9d16b48

Added explainability

Files changed (2)
  1. README.md +45 -2
  2. app.py +327 -32
README.md CHANGED
@@ -17,6 +17,9 @@ A lightweight research demo that simulates a clinician conducting a brief conver
  ## What it does
  - Conversational assessment to infer PHQ‑9 items from natural dialogue (no explicit questionnaire).
  - Live inference of PHQ‑9 item scores, confidences, total score, and severity.
+ - Iterative light explainability after each turn to guide the next question (strong/weak/missing evidence by item).
+ - Final explainability at session end aggregating linguistic quotes and acoustic prosody.
+ - Self‑reflection step that checks consistency and may adjust low‑confidence item scores.
  - Automatic stop when minimum confidence across items reaches a threshold or risk is detected.
  - Optional TTS playback for clinician responses.
 
 
@@ -78,10 +81,50 @@ Notes:
  ## Safety
  This demo does not provide therapy or emergency counseling. If a user expresses suicidal intent or risk is inferred, the app ends the conversation and advises contacting emergency services (e.g., 988 in the U.S.).

+ ## Architecture
+ RecordingAgent → ScoringAgent → ExplainabilityModule (light/full) → ReflectionModule → ReportGenerator
+
+ - RecordingAgent: generates clinician follow‑ups; guided by light explainability when available.
+ - ScoringAgent: infers PHQ‑9 item scores and per‑item confidences from the transcript (+ prosody summary).
+ - Explainability (light): keyword‑based evidence strength per item; selects the next focus area.
+ - Explainability (full): aggregates transcript quotes and averaged prosody features into per‑item objects.
+ - Reflection: heuristic pass that reduces scores by 1 for items with confidence < τ and missing evidence.
+ - ReportGenerator: patient and clinician summaries, confidence bars, highlights, and reflection notes.
+
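+ As a minimal sketch of the per‑turn loop (using the function signatures defined in app.py; scoring and error handling omitted):
+
+ ```python
+ # After each patient turn: explain current evidence, then ask a guided follow-up.
+ # chat_history, scores, confidences, threshold are assumed to be in scope.
+ light = explainability_light(chat_history, scores, confidences, threshold)
+ reply = generate_recording_agent_reply(chat_history, guidance=light)
+ ```
+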
+ ### Output objects
+ - Explainability (light):
+ ```json
+ {
+   "evidence_strength": {"appetite": "missing", ...},
+   "recommended_focus": "appetite",
+   "quotes": {"appetite": ["..."], ...},
+   "confidences": {"appetite": 0.34, ...}
+ }
+ ```
+ - Explainability (full):
+ ```json
+ {
+   "items": [
+     {"item": "appetite", "confidence": 0.42, "evidence": ["..."], "prosody": ["rms_mean=0.012", "zcr_mean=0.065", ...]}
+   ],
+   "notes": "Heuristic placeholder"
+ }
+ ```
+ - Reflection report:
+ ```json
+ {
+   "corrected_scores": {"appetite": 1, ...},
+   "final_total": 12,
+   "severity_label": "Moderate Depression",
+   "consistency_score": 0.89,
+   "notes": "Model revised appetite score due to low confidence and missing evidence."
+ }
+ ```
+
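+ For reference, `consistency_score` is computed as 1 - (revised items / 9 PHQ‑9 items); the single revision in this example gives 8/9 ≈ 0.89.
+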
  ## Development notes
  - Framework: Gradio Blocks
  - ASR: Transformers pipeline (Whisper)
- - TTS: gTTS
- - Prosody features: librosa (lightweight proxies) for the scoring prompt
+ - TTS: gTTS or Coqui TTS
+ - Prosody features: librosa proxies; replaceable by OpenSMILE
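+
+ A minimal sketch of the librosa-based proxies (assumptions: 16 kHz mono input; the feature names mirror the `prosody` strings in the example above):
+
+ ```python
+ import librosa
+ import numpy as np
+
+ def prosody_proxies(wav_path: str) -> dict:
+     """Cheap prosody summaries: mean RMS energy and mean zero-crossing rate."""
+     y, _sr = librosa.load(wav_path, sr=16000)
+     return {
+         "rms_mean": float(np.mean(librosa.feature.rms(y=y))),
+         "zcr_mean": float(np.mean(librosa.feature.zero_crossing_rate(y))),
+     }
+ ```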

  PRs and experiments are welcome. This is a research prototype and not a clinical tool.
app.py CHANGED
@@ -288,6 +288,36 @@ def severity_from_total(total_score: int) -> str:
      return "Severe Depression"


+ # ---------------------------
+ # PHQ-9 schema and helpers
+ # ---------------------------
+ PHQ9_KEYS_ORDERED: List[str] = [
+     "interest",
+     "mood",
+     "sleep",
+     "energy",
+     "appetite",
+     "self_worth",
+     "concentration",
+     "motor",
+     "suicidal_thoughts",
+ ]
+
+ # Lightweight keyword lexicon per item for evidence extraction.
+ # Placeholder for future SHAP/attention-based attributions.
+ PHQ9_KEYWORDS: Dict[str, List[str]] = {
+     "interest": ["interest", "pleasure", "enjoy", "motivation", "hobbies"],
+     "mood": ["depressed", "down", "sad", "hopeless", "blue", "mood"],
+     "sleep": ["sleep", "insomnia", "awake", "wake up", "night", "dream"],
+     "energy": ["tired", "fatigue", "energy", "exhausted", "worn out"],
+     "appetite": ["appetite", "eat", "eating", "hungry", "food", "weight"],
+     "self_worth": ["worthless", "failure", "guilty", "guilt", "self-esteem", "ashamed"],
+     "concentration": ["concentrate", "focus", "attention", "distracted", "remember"],
+     "motor": ["restless", "slow", "slowed", "agitated", "fidget", "move"],
+     "suicidal_thoughts": ["suicide", "kill myself", "die", "end my life", "self-harm", "hurt myself"],
+ }
+
+
  def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
      """Convert chatbot history [(user, assistant), ...] to a plain text transcript."""
      lines = []
 
@@ -299,15 +329,195 @@ def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
      return "\n".join(lines)


+ def _patient_sentences(chat_history: List[Tuple[str, str]]) -> List[str]:
+     """Extract patient-only sentences from chat history."""
+     sentences: List[str] = []
+     for user, _assistant in chat_history:
+         if not user:
+             continue
+         parts = re.split(r"(?<=[.!?])\s+", user.strip())
+         for p in parts:
+             p = p.strip()
+             if p:
+                 sentences.append(p)
+     return sentences
+
+
+ def _extract_quotes_per_item(chat_history: List[Tuple[str, str]]) -> Dict[str, List[str]]:
+     """Heuristic extraction of per-item evidence quotes from patient sentences based on keywords."""
+     quotes: Dict[str, List[str]] = {k: [] for k in PHQ9_KEYS_ORDERED}
+     sentences = _patient_sentences(chat_history)
+     for sent in sentences:
+         s_low = sent.lower()
+         for item, kws in PHQ9_KEYWORDS.items():
+             if any(kw in s_low for kw in kws):
+                 if len(quotes[item]) < 5:
+                     quotes[item].append(sent)
+     return quotes
+
+
+ def explainability_light(
+     chat_history: List[Tuple[str, str]],
+     scores: Dict[str, int],
+     confidences: List[float],
+     threshold: float,
+ ) -> Dict[str, Any]:
+     """Lightweight explainability per turn.
+
+     - Inspects transcript for keyword-based evidence per PHQ-9 item.
+     - Classifies evidence strength as strong/weak/missing using keyword hits and confidence.
+     - Suggests next focus item based on lowest-confidence or missing evidence.
+
+     Returns a JSON-serializable dict.
+     """
+     quotes = _extract_quotes_per_item(chat_history)
+     conf_map: Dict[str, float] = {}
+     for idx, key in enumerate(PHQ9_KEYS_ORDERED):
+         conf_map[key] = float(confidences[idx] if idx < len(confidences) else 0.0)
+
+     evidence_strength: Dict[str, str] = {}
+     for key in PHQ9_KEYS_ORDERED:
+         hits = len(quotes.get(key, []))
+         conf = conf_map.get(key, 0.0)
+         if hits >= 2 and conf >= max(0.6, threshold - 0.1):
+             evidence_strength[key] = "strong"
+         elif hits >= 1 or conf >= 0.4:
+             evidence_strength[key] = "weak"
+         else:
+             evidence_strength[key] = "missing"
+
+     low_items = sorted(
+         PHQ9_KEYS_ORDERED,
+         key=lambda k: (evidence_strength[k] != "missing", conf_map.get(k, 0.0))
+     )
+     recommended = low_items[0] if low_items else None
+
+     return {
+         "evidence_strength": evidence_strength,
+         "low_confidence_items": [k for k in sorted(PHQ9_KEYS_ORDERED, key=lambda x: conf_map.get(x, 0.0))],
+         "recommended_focus": recommended,
+         "quotes": quotes,
+         "confidences": conf_map,
+     }
+
+
+ def explainability_full(
+     chat_history: List[Tuple[str, str]],
+     confidences: List[float],
+     features_history: Optional[List[Dict[str, float]]],
+ ) -> Dict[str, Any]:
+     """Aggregate linguistic and acoustic attributions at session end.
+
+     - Linguistic: keyword-based quotes per item (placeholder for SHAP/attention).
+     - Acoustic: mean of per-turn prosodic features; returned as name=value strings.
+     """
+     def _aggregate_prosody(history: List[Dict[str, float]]) -> Dict[str, float]:
+         agg: Dict[str, float] = {}
+         if not history:
+             return agg
+         keys = set().union(*[d.keys() for d in history if isinstance(d, dict)])
+         for k in keys:
+             vals = [float(d[k]) for d in history if isinstance(d, dict) and k in d]
+             if vals:
+                 agg[k] = float(np.mean(vals))
+         return agg
+
+     quotes = _extract_quotes_per_item(chat_history)
+     conf_map = {k: float(confidences[i] if i < len(confidences) else 0.0) for i, k in enumerate(PHQ9_KEYS_ORDERED)}
+     prosody_agg = _aggregate_prosody(list(features_history or []))
+     prosody_pairs = sorted(list(prosody_agg.items()), key=lambda kv: -abs(kv[1]))
+     prosody_names = [f"{k}={v:.3f}" for k, v in prosody_pairs[:8]]
+
+     items = []
+     for k in PHQ9_KEYS_ORDERED:
+         items.append({
+             "item": k,
+             "confidence": conf_map.get(k, 0.0),
+             "evidence": quotes.get(k, [])[:5],
+             "prosody": prosody_names,
+         })
+     return {
+         "items": items,
+         "notes": "Heuristic keyword and prosody aggregation; plug in SHAP/attention later.",
+     }
+
+
+ def reflection_module(
+     scores: Dict[str, int],
+     confidences: List[float],
+     exp_light: Optional[Dict[str, Any]],
+     exp_full: Optional[Dict[str, Any]],
+     threshold: float,
+ ) -> Dict[str, Any]:
+     """Self-reflection / output reevaluation.
+
+     Heuristic: if confidence for an item < threshold and evidence is missing, reduce score by 1 (min 0).
+     Returns a `reflection_report` JSON with corrected scores and final summary.
+     """
+     corrected = dict(scores or {})
+     strength = (exp_light or {}).get("evidence_strength", {}) if isinstance(exp_light, dict) else {}
+     changes: List[Tuple[str, int, int]] = []
+     for i, k in enumerate(PHQ9_KEYS_ORDERED):
+         conf = float(confidences[i] if i < len(confidences) else 0.0)
+         if conf < float(threshold) and strength.get(k) == "missing":
+             new_val = max(0, int(corrected.get(k, 0)) - 1)
+             if new_val != corrected.get(k, 0):
+                 changes.append((k, int(corrected.get(k, 0)), new_val))
+                 corrected[k] = new_val
+     final_total = int(sum(corrected.values()))
+     final_sev = severity_from_total(final_total)
+     consistency = float(1.0 - (len(changes) / max(1, len(PHQ9_KEYS_ORDERED))))
+     if changes:
+         notes = ", ".join([f"{k}: {old}->{new}" for k, old, new in changes])
+         notes = f"Model revised items due to low confidence and missing evidence: {notes}."
+     else:
+         notes = "No score revisions; explanations consistent with outputs."
+     return {
+         "corrected_scores": corrected,
+         "final_total": final_total,
+         "severity_label": final_sev,
+         "consistency_score": consistency,
+         "notes": notes,
+     }
+
+
  def build_patient_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
      severity = meta.get("Severity") or display_json.get("Severity")
      total = meta.get("Total_Score") or display_json.get("Total_Score")
      transcript_text = transcript_to_text(chat_history)
-     return (
-         "# Summary for Patient\n\n"
-         "### Conversation Transcript\n\n"
-         f"```\n{transcript_text}\n```"
-     )
+     # Optional enriched content
+     exp_full = display_json.get("Explainability_Full") or {}
+     reflection = display_json.get("Reflection_Report") or {}
+
+     lines = []
+     lines.append("# Summary for Patient\n")
+     if total is not None and severity:
+         lines.append(f"- PHQ‑9 Total: **{total}** ")
+         lines.append(f"- Severity: **{severity}**\n")
+
+     # Highlights: show one quote per item if available
+     if exp_full and isinstance(exp_full, dict):
+         items = exp_full.get("items", [])
+         if isinstance(items, list) and items:
+             lines.append("### Highlights from our conversation\n")
+             for it in items:
+                 item = it.get("item")
+                 ev = it.get("evidence", [])
+                 if item and ev:
+                     lines.append(f"- {item}: \"{ev[0]}\"")
+             lines.append("")
+
+     if reflection:
+         note = reflection.get("notes")
+         if note:
+             lines.append("### Reflection\n")
+             lines.append(note)
+             lines.append("")
+
+     lines.append("### Conversation Transcript\n\n")
+     lines.append(f"```\n{transcript_text}\n```")
+     return "\n".join(lines)


  def build_clinician_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
 
@@ -319,17 +529,75 @@ def build_clinician_summary(chat_history: List[Tuple[str, str]], meta: Dict[str,
      transcript_text = transcript_to_text(chat_history)
      scores_lines = "\n".join([f"- {k}: {v}" for k, v in scores.items()])
      conf_str = ", ".join([f"{c:.2f}" for c in confidences]) if confidences else ""
-     return (
-         "# Summary for Clinician\n\n"
-         f"- Severity: **{severity}** \n"
-         f"- PHQ‑9 Total: **{total}** \n"
-         # f"- High Risk: **{risk}**\n\n"
-         f"### Item Scores\n{scores_lines}\n\n"
-         "### Conversation Transcript\n\n"
-         f"```\n{transcript_text}\n```"
-     )
+     # Optional explainability
+     exp_light = display_json.get("Explainability_Light") or {}
+     exp_full = display_json.get("Explainability_Full") or {}
+     reflection = display_json.get("Reflection_Report") or {}
+
+     md = []
+     md.append("# Summary for Clinician\n")
+     md.append(f"- Severity: **{severity}** ")
+     md.append(f"- PHQ‑9 Total: **{total}** ")
+     if risk is not None:
+         md.append(f"- High Risk: **{risk}** ")
+     md.append("")
+     md.append("### Item Scores\n" + scores_lines + "\n")
+
+     # Confidence bars
+     if confidences:
+         bars = []
+         for i, k in enumerate(scores.keys()):
+             c = confidences[i] if i < len(confidences) else 0.0
+             bar_len = int(round(c * 20))
+             bars.append(f"- {k}: [{'#'*bar_len}{'.'*(20-bar_len)}] {c:.2f}")
+         md.append("### Confidence by item\n" + "\n".join(bars) + "\n")
+
+     # Light explainability snapshot
+     if exp_light:
+         strength = exp_light.get("evidence_strength", {})
+         recommended = exp_light.get("recommended_focus")
+         if strength:
+             md.append("### Evidence strength (light)\n")
+             md.extend([f"- {k}: {v}" for k, v in strength.items()])
+             md.append("")
+         if recommended:
+             md.append(f"- Next focus (if continuing): **{recommended}**\n")
+
+     # Full explainability excerpts
+     if exp_full and isinstance(exp_full, dict):
+         md.append("### Explainability (final)\n")
+         items = exp_full.get("items", [])
+         for it in items:
+             item = it.get("item")
+             conf = it.get("confidence")
+             ev = it.get("evidence", [])
+             pros = it.get("prosody", [])
+             if item:
+                 md.append(f"- {item} (conf {conf:.2f}):")
+                 for q in ev[:2]:
+                     md.append(f" - \"{q}\"")
+                 if pros:
+                     md.append(f" - prosody: {', '.join([str(p) for p in pros[:4]])}")
+         md.append("")
+
+     # Reflection summary
+     if reflection:
+         md.append("### Self-reflection\n")
+         notes = reflection.get("notes")
+         if notes:
+             md.append(notes)
+         corr = reflection.get("corrected_scores") or {}
+         if corr and corr != scores:
+             changed = [k for k in scores.keys() if corr.get(k) != scores.get(k)]
+             if changed:
+                 md.append("- Adjusted items: " + ", ".join(changed))
+         md.append("")
+
+     md.append("### Conversation Transcript\n\n")
+     md.append(f"```\n{transcript_text}\n```")
+     return "\n".join(md)
+
- def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
+ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]], guidance: Optional[Dict[str, Any]] = None) -> str:
      transcript = transcript_to_text(chat_history)
      system_prompt = (
          "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
 
@@ -337,9 +605,17 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
          "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
          "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
      )
+     focus_text = ""
+     if guidance and isinstance(guidance, dict):
+         rec = guidance.get("recommended_focus")
+         if rec:
+             focus_text = (
+                 f"\n\nGuidance: Focus the next question on the patient's {str(rec).replace('_', ' ')}. "
+                 "Ask naturally about recent changes and their impact on daily life."
+             )
      user_prompt = (
          "Conversation so far (Patient and Clinician turns):\n\n" + transcript +
-         "\n\nRespond with a single short clinician-style question for the patient."
+         f"{focus_text}\n\nRespond with a single short clinician-style question for the patient."
      )
      pipe = get_textgen_pipeline()
      tokenizer = pipe.tokenizer
 
@@ -620,8 +896,10 @@ def process_turn(
          chat_history[-1] = (chat_history[-1][0], summary)
          finished = True
      else:
-         # Generate next clinician question
-         reply = generate_recording_agent_reply(chat_history)
+         # Iterative explainability (light) to guide next question
+         light_exp = explainability_light(chat_history, scores, confidences, float(threshold))
+         # Generate next clinician question with guidance
+         reply = generate_recording_agent_reply(chat_history, guidance=light_exp)
          chat_history[-1] = (chat_history[-1][0], reply)

      # TTS for the latest clinician message, if enabled
 
@@ -635,6 +913,9 @@ def process_turn(
          "Severity": severity,
          "Confidence": overall_conf,
          "High_Risk": high_risk,
+         # Include the last audio features and light explainability for downstream modules/UI
+         "Last_Audio_Features": audio_features,
+         "Explainability_Light": explainability_light(chat_history, scores, confidences, float(threshold)),
      }

      # Clear inputs after processing
 
@@ -782,6 +1063,7 @@ def create_demo():
      meta_state = gr.State()
      finished_state = gr.State()
      turns_state = gr.State()
+     feats_state = gr.State()

      # Initialize on load (no autoplay due to browser policies)
      demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
 
@@ -802,44 +1084,57 @@ def create_demo():
      intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])

      # Wire interactions
-     def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker):
+     def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
          result = process_turn(audio, text, chat, th, tts_on, finished, turns, scores, meta)
          chat_history, display_json, severity, finished_o, turns_o, _, _, _, last_tts = result
+         # Accumulate last audio features
+         feats_hist = feats_hist or []
+         last_feats = (display_json or {}).get("Last_Audio_Features") or {}
+         if isinstance(last_feats, dict) and last_feats:
+             feats_hist = list(feats_hist) + [last_feats]
          if tts_on and chat_history and chat_history[-1][1]:
              new_path = synthesize_tts(chat_history[-1][1], provider=provider, coqui_model_name=coqui_model, coqui_speaker=coqui_speaker)
          else:
              new_path = None
          # If finished, hide the mic and display summaries in Main
          if finished_o:
-             patient_md = build_patient_summary(chat_history, {"Severity": severity, "Total_Score": display_json.get("Total_Score")}, display_json)
-             clinician_md = build_clinician_summary(chat_history, {"Severity": severity, "Total_Score": display_json.get("Total_Score")}, display_json)
+             # Run full explainability and reflection
+             exp_full = explainability_full(chat_history, display_json.get("Confidences", []), feats_hist)
+             reflect = reflection_module(display_json.get("PHQ9_Scores", {}), display_json.get("Confidences", []), display_json.get("Explainability_Light", {}), exp_full, float(th))
+             display_json["Explainability_Full"] = exp_full
+             display_json["Reflection_Report"] = reflect
+             # Use reflection outputs to set final meta
+             final_sev = reflect.get("severity_label") or severity
+             final_total = reflect.get("final_total") or display_json.get("Total_Score")
+             patient_md = build_patient_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
+             clinician_md = build_clinician_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
              summary_md = patient_md + "\n\n---\n\n" + clinician_md
-             return chat_history, display_json, severity, finished_o, turns_o, gr.update(visible=False), None, new_path, new_path, gr.update(value=summary_md, visible=True)
+             return chat_history, display_json, severity, finished_o, turns_o, gr.update(visible=False), None, new_path, new_path, gr.update(value=summary_md, visible=True), feats_hist
-         return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path, gr.update(visible=False)
+         return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path, gr.update(visible=False), feats_hist

      audio_main.stop_recording(
          fn=_process_with_tts,
-         inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
-         outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary],
+         inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
+         outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state],
          queue=True,
          api_name="message",
      )

      # Text input flow from Advanced tab
-     def _process_text_and_clear(text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker):
-         res = _process_with_tts(None, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker)
+     def _process_text_and_clear(text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
+         res = _process_with_tts(None, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist)
          return (*res, "")

      text_adv.submit(
          fn=_process_text_and_clear,
-         inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
-         outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, text_adv],
+         inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
+         outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
          queue=True,
      )
      send_adv_btn.click(
          fn=_process_text_and_clear,
-         inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
-         outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, text_adv],
+         inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
+         outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
          queue=True,
      )