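"""CARE Notes - conversational PHQ-9 screening demo for Hugging Face Spaces.

The app records a spoken (or typed) conversation, transcribes it with Whisper,
asks an LLM to both drive the dialogue and score the nine PHQ-9 items with
per-item confidences, extracts lightweight prosodic features from the audio,
and speaks the clinician turns via Coqui TTS or gTTS. When the assessment ends
it builds patient- and clinician-facing summaries with heuristic explainability
and a self-reflection pass.
"""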
import os

# Disable torch compile/dynamo globally to avoid cudagraph assertion errors
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import json
import re
import threading
import time
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np

# Audio processing
import soundfile as sf
import librosa

# Models
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)

from gtts import gTTS
import spaces

# ---------------------------
# Configuration
# ---------------------------
DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
USE_TTS_DEFAULT = os.getenv("USE_TTS", "true").strip().lower() == "true"
CONFIG_PATH = os.getenv("MODEL_CONFIG_PATH", "model_config.json")


def _load_model_id_from_config() -> str:
    try:
        if os.path.exists(CONFIG_PATH):
            with open(CONFIG_PATH, "r") as f:
                data = json.load(f)
            if isinstance(data, dict) and data.get("model_id"):
                return str(data["model_id"])
    except Exception:
        pass
    return DEFAULT_CHAT_MODEL_ID


current_model_id = _load_model_id_from_config()

# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
_asr_pipe = None
_gen_pipe = None
_tokenizer = None


def _hf_device() -> int:
    return 0 if torch.cuda.is_available() else -1


def get_asr_pipeline():
    global _asr_pipe
    if _asr_pipe is None:
        _asr_pipe = pipeline(
            "automatic-speech-recognition",
            model=DEFAULT_ASR_MODEL_ID,
            device=_hf_device(),
        )
    return _asr_pipe

def get_textgen_pipeline():
    global _gen_pipe
    if _gen_pipe is None:
        # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
        if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
            _dtype = torch.bfloat16
        elif torch.cuda.is_available():
            _dtype = torch.float16
        else:
            _dtype = torch.float32
        _gen_pipe = pipeline(
            task="text-generation",
            model=current_model_id,
            tokenizer=current_model_id,
            device=_hf_device(),
            torch_dtype=_dtype,
        )
    return _gen_pipe

def set_current_model_id(new_model_id: str) -> str:
    global current_model_id, _gen_pipe
    new_model_id = (new_model_id or "").strip()
    if not new_model_id:
        return "Model id is empty; keeping current model."
    if new_model_id == current_model_id:
        return f"Model unchanged: `{current_model_id}`"
    current_model_id = new_model_id
    _gen_pipe = None  # force reload on next use
    return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."

def persist_model_id(new_model_id: str) -> None:
    try:
        with open(CONFIG_PATH, "w") as f:
            json.dump({"model_id": new_model_id}, f)
    except Exception:
        pass
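

# Note: the restart below relies on the hosting runtime (e.g., the Hugging Face Spaces
# supervisor) bringing the process back up after it exits; os._exit(0) then reloads the
# app with the model id persisted in CONFIG_PATH.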
def apply_model_and_restart(new_model_id: str) -> str:
    mid = (new_model_id or "").strip()
    if not mid:
        return "Model id is empty; not restarting."
    persist_model_id(mid)
    set_current_model_id(mid)

    # Graceful delayed exit so the response can flush before the process exits
    def _exit_later():
        time.sleep(0.25)
        os._exit(0)

    threading.Thread(target=_exit_later, daemon=True).start()
    return f"Restarting with model `{mid}`..."

# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
    """Extract the first JSON object from text."""
    if not text:
        return None
    try:
        return json.loads(text)
    except Exception:
        pass
    # Fallback: find the first {...} block
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group(0))
        except Exception:
            return None
    return None

def compute_audio_features(audio_path: str) -> Dict[str, float]:
    """Compute lightweight prosodic features as a proxy for OpenSMILE.

    Returns a dictionary with summary statistics.
    """
    try:
        y, sr = librosa.load(audio_path, sr=16000, mono=True)
        if len(y) == 0:
            return {}
        # Frame-based features
        hop_length = 512
        frame_length = 1024
        rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
        centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]
        # Pitch estimation (coarse)
        f0 = None
        try:
            f0 = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
            f0 = f0[np.isfinite(f0)]
        except Exception:
            f0 = None
        # Speaking-rate rough proxy: voiced ratio per second
        energy = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        voiced = energy > (np.median(energy) * 1.2)
        voiced_ratio = float(np.mean(voiced))
        features = {
            "rms_mean": float(np.mean(rms)),
            "rms_std": float(np.std(rms)),
            "zcr_mean": float(np.mean(zcr)),
            "zcr_std": float(np.std(zcr)),
            "centroid_mean": float(np.mean(centroid)),
            "centroid_std": float(np.std(centroid)),
            "voiced_ratio": voiced_ratio,
            "duration_sec": float(len(y) / sr),
        }
        if f0 is not None and f0.size > 0:
            features.update({
                "f0_median": float(np.median(f0)),
                "f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
            })
        return features
    except Exception:
        return {}

def detect_explicit_suicidality(text: Optional[str]) -> bool:
    if not text:
        return False
    t = text.lower()
    patterns = [
        r"\bkill myself\b",
        r"\bend my life\b",
        r"\bend it all\b",
        r"\bcommit suicide\b",
        r"\bsuicide\b",
        r"\bself[-\s]?harm\b",
        r"\bhurt myself\b",
        r"\bno reason to live\b",
        r"\bwant to die\b",
        r"\bending it\b",
    ]
    for pat in patterns:
        if re.search(pat, t):
            return True
    return False

def synthesize_tts(
    text: Optional[str],
    provider: str = "Coqui",
    coqui_model_name: Optional[str] = None,
    coqui_speaker: Optional[str] = None,
) -> Optional[str]:
    if not text:
        return None
    ts = int(time.time() * 1000)
    provider_norm = (provider or "Coqui").strip().lower()
    # Try Coqui first if requested
    if provider_norm == "coqui":
        try:
            # coqui-tts uses the same import path TTS.api
            from TTS.api import TTS as CoquiTTS  # type: ignore

            model_name = (coqui_model_name or os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")).strip()
            engine = CoquiTTS(model_name=model_name, progress_bar=False)
            out_path_wav = f"/tmp/tts_{ts}.wav"
            kwargs = {}
            if coqui_speaker:
                kwargs["speaker"] = coqui_speaker
            engine.tts_to_file(text=text, file_path=out_path_wav, **kwargs)
            return out_path_wav
        except Exception:
            pass
    # Fall back to gTTS
    try:
        out_path = f"/tmp/tts_{ts}.mp3"
        tts = gTTS(text=text, lang="en")
        tts.save(out_path)
        return out_path
    except Exception:
        return None

def list_coqui_speakers(model_name: str) -> List[str]:
    try:
        from TTS.api import TTS as CoquiTTS  # type: ignore

        engine = CoquiTTS(model_name=model_name, progress_bar=False)
        # Try common attributes
        if hasattr(engine, "speakers") and isinstance(engine.speakers, list):
            return [str(s) for s in engine.speakers]
        if hasattr(engine, "speaker_manager") and hasattr(engine.speaker_manager, "speaker_names"):
            return list(engine.speaker_manager.speaker_names)
    except Exception:
        pass
    # Reasonable defaults for the multi-speaker VCTK model
    return ["p225", "p227", "p231", "p233", "p236"]
def severity_from_total(total_score: int) -> str:
    if total_score <= 4:
        return "Minimal Depression"
    if total_score <= 9:
        return "Mild Depression"
    if total_score <= 14:
        return "Moderate Depression"
    if total_score <= 19:
        return "Moderately Severe Depression"
    return "Severe Depression"

# ---------------------------
# PHQ-9 schema and helpers
# ---------------------------
PHQ9_KEYS_ORDERED: List[str] = [
    "interest",
    "mood",
    "sleep",
    "energy",
    "appetite",
    "self_worth",
    "concentration",
    "motor",
    "suicidal_thoughts",
]

# Lightweight keyword lexicon per item for evidence extraction.
# Placeholder for future SHAP/attention-based attributions.
PHQ9_KEYWORDS: Dict[str, List[str]] = {
    "interest": ["interest", "pleasure", "enjoy", "motivation", "hobbies"],
    "mood": ["depressed", "down", "sad", "hopeless", "blue", "mood"],
    "sleep": ["sleep", "insomnia", "awake", "wake up", "night", "dream"],
    "energy": ["tired", "fatigue", "energy", "exhausted", "worn out"],
    "appetite": ["appetite", "eat", "eating", "hungry", "food", "weight"],
    "self_worth": ["worthless", "failure", "guilty", "guilt", "self-esteem", "ashamed"],
    "concentration": ["concentrate", "focus", "attention", "distracted", "remember"],
    "motor": ["restless", "slow", "slowed", "agitated", "fidget", "move"],
    "suicidal_thoughts": ["suicide", "kill myself", "die", "end my life", "self-harm", "hurt myself"],
}

def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
    """Convert chatbot history [(user, assistant), ...] to a plain-text transcript."""
    lines = []
    for user, assistant in chat_history:
        if user:
            lines.append(f"Patient: {user}")
        if assistant:
            lines.append(f"Clinician: {assistant}")
    return "\n".join(lines)


def _patient_sentences(chat_history: List[Tuple[str, str]]) -> List[str]:
    """Extract patient-only sentences from the chat history."""
    sentences: List[str] = []
    for user, _assistant in chat_history:
        if not user:
            continue
        parts = re.split(r"(?<=[.!?])\s+", user.strip())
        for p in parts:
            p = p.strip()
            if p:
                sentences.append(p)
    return sentences


def _extract_quotes_per_item(chat_history: List[Tuple[str, str]]) -> Dict[str, List[str]]:
    """Heuristic extraction of per-item evidence quotes from patient sentences based on keywords."""
    quotes: Dict[str, List[str]] = {k: [] for k in PHQ9_KEYS_ORDERED}
    sentences = _patient_sentences(chat_history)
    for sent in sentences:
        s_low = sent.lower()
        for item, kws in PHQ9_KEYWORDS.items():
            if any(kw in s_low for kw in kws):
                if len(quotes[item]) < 5:
                    quotes[item].append(sent)
    return quotes

def explainability_light(
    chat_history: List[Tuple[str, str]],
    scores: Dict[str, int],
    confidences: List[float],
    threshold: float,
) -> Dict[str, Any]:
    """Lightweight explainability per turn.

    - Inspects the transcript for keyword-based evidence per PHQ-9 item.
    - Classifies evidence strength as strong/weak/missing using keyword hits and confidence.
    - Suggests the next focus item based on lowest-confidence or missing evidence.

    Returns a JSON-serializable dict.
    """
    quotes = _extract_quotes_per_item(chat_history)
    conf_map: Dict[str, float] = {}
    for idx, key in enumerate(PHQ9_KEYS_ORDERED):
        conf_map[key] = float(confidences[idx] if idx < len(confidences) else 0.0)
    evidence_strength: Dict[str, str] = {}
    for key in PHQ9_KEYS_ORDERED:
        hits = len(quotes.get(key, []))
        conf = conf_map.get(key, 0.0)
        if hits >= 2 and conf >= max(0.6, threshold - 0.1):
            evidence_strength[key] = "strong"
        elif hits >= 1 or conf >= 0.4:
            evidence_strength[key] = "weak"
        else:
            evidence_strength[key] = "missing"
    low_items = sorted(
        PHQ9_KEYS_ORDERED,
        key=lambda k: (evidence_strength[k] != "missing", conf_map.get(k, 0.0))
    )
    recommended = low_items[0] if low_items else None
    return {
        "evidence_strength": evidence_strength,
        "low_confidence_items": [k for k in sorted(PHQ9_KEYS_ORDERED, key=lambda x: conf_map.get(x, 0.0))],
        "recommended_focus": recommended,
        "quotes": quotes,
        "confidences": conf_map,
    }

def explainability_full(
    chat_history: List[Tuple[str, str]],
    confidences: List[float],
    features_history: Optional[List[Dict[str, float]]],
) -> Dict[str, Any]:
    """Aggregate linguistic and acoustic attributions at session end.

    - Linguistic: keyword-based quotes per item (placeholder for SHAP/attention).
    - Acoustic: mean of per-turn prosodic features, returned as name=value strings.
    """
    def _aggregate_prosody(history: List[Dict[str, float]]) -> Dict[str, float]:
        agg: Dict[str, float] = {}
        if not history:
            return agg
        keys = set().union(*[d.keys() for d in history if isinstance(d, dict)])
        for k in keys:
            vals = [float(d[k]) for d in history if isinstance(d, dict) and k in d]
            if vals:
                agg[k] = float(np.mean(vals))
        return agg

    quotes = _extract_quotes_per_item(chat_history)
    conf_map = {k: float(confidences[i] if i < len(confidences) else 0.0) for i, k in enumerate(PHQ9_KEYS_ORDERED)}
    prosody_agg = _aggregate_prosody(list(features_history or []))
    prosody_pairs = sorted(list(prosody_agg.items()), key=lambda kv: -abs(kv[1]))
    prosody_names = [f"{k}={v:.3f}" for k, v in prosody_pairs[:8]]
    items = []
    for k in PHQ9_KEYS_ORDERED:
        items.append({
            "item": k,
            "confidence": conf_map.get(k, 0.0),
            "evidence": quotes.get(k, [])[:5],
            "prosody": prosody_names,
        })
    return {
        "items": items,
        "notes": "Heuristic keyword and prosody aggregation; plug in SHAP/attention later.",
    }

def reflection_module(
    scores: Dict[str, int],
    confidences: List[float],
    exp_light: Optional[Dict[str, Any]],
    exp_full: Optional[Dict[str, Any]],
    threshold: float,
) -> Dict[str, Any]:
    """Self-reflection / output re-evaluation.

    Heuristic: if the confidence for an item is below the threshold and its evidence is
    missing, reduce the item score by 1 (minimum 0).

    Returns a `reflection_report` dict with corrected scores and a final summary.
    """
    corrected = dict(scores or {})
    strength = (exp_light or {}).get("evidence_strength", {}) if isinstance(exp_light, dict) else {}
    changes: List[Tuple[str, int, int]] = []
    for i, k in enumerate(PHQ9_KEYS_ORDERED):
        conf = float(confidences[i] if i < len(confidences) else 0.0)
        if conf < float(threshold) and strength.get(k) == "missing":
            new_val = max(0, int(corrected.get(k, 0)) - 1)
            if new_val != corrected.get(k, 0):
                changes.append((k, int(corrected.get(k, 0)), new_val))
                corrected[k] = new_val
    final_total = int(sum(corrected.values()))
    final_sev = severity_from_total(final_total)
    consistency = float(1.0 - (len(changes) / max(1, len(PHQ9_KEYS_ORDERED))))
    if changes:
        notes = ", ".join([f"{k}: {old}->{new}" for k, old, new in changes])
        notes = f"Model revised items due to low confidence and missing evidence: {notes}."
    else:
        notes = "No score revisions; explanations consistent with outputs."
    return {
        "corrected_scores": corrected,
        "final_total": final_total,
        "severity_label": final_sev,
        "consistency_score": consistency,
        "notes": notes,
    }

def build_patient_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
    severity = meta.get("Severity") or display_json.get("Severity")
    total = meta.get("Total_Score") or display_json.get("Total_Score")
    transcript_text = transcript_to_text(chat_history)
    # Optional enriched content
    exp_full = display_json.get("Explainability_Full") or {}
    reflection = display_json.get("Reflection_Report") or {}
    lines = []
    lines.append("# Summary for Patient\n")
    if total is not None and severity:
        lines.append(f"- PHQ‑9 Total: **{total}** ")
        lines.append(f"- Severity: **{severity}**\n")
    # Highlights: show one quote per item if available
    if exp_full and isinstance(exp_full, dict):
        items = exp_full.get("items", [])
        if isinstance(items, list) and items:
            lines.append("### Highlights from our conversation\n")
            for it in items:
                item = it.get("item")
                ev = it.get("evidence", [])
                if item and ev:
                    lines.append(f"- {item}: \"{ev[0]}\"")
            lines.append("")
    if reflection:
        note = reflection.get("notes")
        if note:
            lines.append("### Reflection\n")
            lines.append(note)
            lines.append("")
    lines.append("### Conversation Transcript\n\n")
    lines.append(f"```\n{transcript_text}\n```")
    return "\n".join(lines)

def build_clinician_summary(chat_history: List[Tuple[str, str]], meta: Dict[str, Any], display_json: Dict[str, Any]) -> str:
    scores = display_json.get("PHQ9_Scores", {})
    confidences = display_json.get("Confidences", [])
    severity = meta.get("Severity") or display_json.get("Severity")
    total = meta.get("Total_Score") or display_json.get("Total_Score")
    risk = display_json.get("High_Risk")
    transcript_text = transcript_to_text(chat_history)
    scores_lines = "\n".join([f"- {k}: {v}" for k, v in scores.items()])
    conf_str = ", ".join([f"{c:.2f}" for c in confidences]) if confidences else ""
    # Optional explainability
    exp_light = display_json.get("Explainability_Light") or {}
    exp_full = display_json.get("Explainability_Full") or {}
    reflection = display_json.get("Reflection_Report") or {}
    md = []
    md.append("# Summary for Clinician\n")
    md.append(f"- Severity: **{severity}** ")
    md.append(f"- PHQ‑9 Total: **{total}** ")
    if risk is not None:
        md.append(f"- High Risk: **{risk}** ")
    md.append("")
    md.append("### Item Scores\n" + scores_lines + "\n")
    # Confidence bars
    if confidences:
        bars = []
        for i, k in enumerate(scores.keys()):
            c = confidences[i] if i < len(confidences) else 0.0
            bar_len = int(round(c * 20))
            bars.append(f"- {k}: [{'#'*bar_len}{'.'*(20-bar_len)}] {c:.2f}")
        md.append("### Confidence by item\n" + "\n".join(bars) + "\n")
    # Light explainability snapshot
    if exp_light:
        strength = exp_light.get("evidence_strength", {})
        recommended = exp_light.get("recommended_focus")
        if strength:
            md.append("### Evidence strength (light)\n")
            md.extend([f"- {k}: {v}" for k, v in strength.items()])
            md.append("")
        if recommended:
            md.append(f"- Next focus (if continuing): **{recommended}**\n")
    # Full explainability excerpts
    if exp_full and isinstance(exp_full, dict):
        md.append("### Explainability (final)\n")
        items = exp_full.get("items", [])
        for it in items:
            item = it.get("item")
            conf = it.get("confidence")
            ev = it.get("evidence", [])
            pros = it.get("prosody", [])
            if item:
                md.append(f"- {item} (conf {conf:.2f}):")
                for q in ev[:2]:
                    md.append(f"  - \"{q}\"")
                if pros:
                    md.append(f"  - prosody: {', '.join([str(p) for p in pros[:4]])}")
        md.append("")
    # Reflection summary
    if reflection:
        md.append("### Self-reflection\n")
        notes = reflection.get("notes")
        if notes:
            md.append(notes)
        corr = reflection.get("corrected_scores") or {}
        if corr and corr != scores:
            changed = [k for k in scores.keys() if corr.get(k) != scores.get(k)]
            if changed:
                md.append("- Adjusted items: " + ", ".join(changed))
        md.append("")
    md.append("### Conversation Transcript\n\n")
    md.append(f"```\n{transcript_text}\n```")
    return "\n".join(md)
def generate_recording_agent_reply(chat_history: List[Tuple[str, str]], guidance: Optional[Dict[str, Any]] = None) -> str:
    transcript = transcript_to_text(chat_history)
    system_prompt = (
        "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
        "without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
        "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
        "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
    )
    focus_text = ""
    if guidance and isinstance(guidance, dict):
        rec = guidance.get("recommended_focus")
        if rec:
            focus_text = (
                f"\n\nGuidance: Focus the next question on the patient's {str(rec).replace('_', ' ')}. "
                "Ask naturally about recent changes and their impact on daily life."
            )
    user_prompt = (
        "Conversation so far (Patient and Clinician turns):\n\n" + transcript +
        f"{focus_text}\n\nRespond with a single short clinician-style question for the patient."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    combined_prompt = system_prompt + "\n\n" + user_prompt
    messages = [
        {"role": "user", "content": combined_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Avoid TorchInductor graph capture issues on some environments
    try:
        import torch._dynamo as _dynamo  # type: ignore
    except Exception:
        _dynamo = None
    gen = pipe(
        prompt,
        max_new_tokens=96,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    reply = gen[0]["generated_text"].strip()
    # Ensure it's a single concise question/sentence
    if len(reply) > 300:
        reply = reply[:300].rstrip() + "…"
    return reply
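

# Illustrative shape of the JSON the scoring agent is prompted to return
# (example values, not real model output):
# {
#   "PHQ9_Scores": {"interest": 1, "mood": 2, ..., "suicidal_thoughts": 0},
#   "Confidences": [0.7, 0.8, ...],   # nine floats, same order as PHQ9_KEYS_ORDERED
#   "Total_Score": 12,
#   "Severity": "Moderate Depression",
#   "Confidence": 0.7,                # min of the per-item confidences
#   "High_Risk": false
# }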
def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
    """Ask the LLM to produce PHQ-9 scores and confidences as JSON. Fall back if parsing fails."""
    transcript = transcript_to_text(chat_history)
    features_json = json.dumps(features, ensure_ascii=False)
    system_prompt = (
        "You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
        "Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
        "Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
        "and High_Risk (boolean, true if any suicidal risk)."
    )
    user_prompt = (
        "Conversation transcript:"
        f"\n{transcript}\n\n"
        f"Acoustic features summary (approximate):\n{features_json}\n\n"
        "Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
        "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    combined_prompt = system_prompt + "\n\n" + user_prompt
    messages = [
        {"role": "user", "content": combined_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Use deterministic decoding to avoid CUDA sampling edge cases on some models
    try:
        import torch._dynamo as _dynamo  # type: ignore
    except Exception:
        _dynamo = None
    gen = pipe(
        prompt,
        max_new_tokens=256,
        temperature=0.0,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    out_text = gen[0]["generated_text"]
    parsed = safe_json_extract(out_text)
    # Validate and coerce
    if parsed is None or "PHQ9_Scores" not in parsed:
        # Simple fallback heuristic: neutral scores with low confidence
        scores = {
            "interest": 1,
            "mood": 1,
            "sleep": 1,
            "energy": 1,
            "appetite": 1,
            "self_worth": 1,
            "concentration": 1,
            "motor": 1,
            "suicidal_thoughts": 0,
        }
        confidences = [0.5] * 9
        total = int(sum(scores.values()))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity_from_total(total),
            "Confidence": float(min(confidences)),
            "High_Risk": False,
        }
    try:
        # Coerce types and compute derived values if missing
        scores = parsed.get("PHQ9_Scores", {})
        # Ensure all keys are present
        keys = [
            "interest", "mood", "sleep", "energy", "appetite", "self_worth", "concentration", "motor", "suicidal_thoughts",
        ]
        for k in keys:
            scores[k] = int(max(0, min(3, int(scores.get(k, 0)))))
        confidences = parsed.get("Confidences", [])
        if not isinstance(confidences, list) or len(confidences) != 9:
            confidences = [float(parsed.get("Confidence", 0.5))] * 9
        confidences = [float(max(0.0, min(1.0, c))) for c in confidences]
        total = int(sum(scores.values()))
        severity = parsed.get("Severity") or severity_from_total(total)
        overall_conf = float(parsed.get("Confidence", min(confidences)))
        # Conservative high-risk detection: require explicit language or a high suicidal_thoughts score
        # Extract the last patient message
        last_patient = ""
        for user_text, assistant_text in reversed(chat_history):
            if user_text:
                last_patient = user_text
                break
        explicit_flag = detect_explicit_suicidality(last_patient) or detect_explicit_suicidality(transcript)
        high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity,
            "Confidence": overall_conf,
            "High_Risk": high_risk,
        }
    except Exception:
        # Final fallback
        scores = parsed.get("PHQ9_Scores", {}) if isinstance(parsed, dict) else {}
        if not scores:
            scores = {k: 1 for k in [
                "interest", "mood", "sleep", "energy", "appetite", "self_worth", "concentration", "motor", "suicidal_thoughts",
            ]}
        confidences = [0.5] * 9
        total = int(sum(scores.values()))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity_from_total(total),
            "Confidence": float(min(confidences)),
            "High_Risk": False,
        }

def transcribe_audio(audio_path: Optional[str]) -> str:
    if not audio_path:
        return ""
    try:
        asr = get_asr_pipeline()
        result = asr(audio_path)
        if isinstance(result, dict) and "text" in result:
            return result["text"].strip()
        if isinstance(result, list) and len(result) > 0 and "text" in result[0]:
            return result[0]["text"].strip()
    except Exception:
        pass
    return ""

# ---------------------------
# Gradio app logic
# ---------------------------
INTRO_MESSAGE = (
    "Hi, I'm an assistant, and I will ask you some questions about how you've been doing. "
    "We'll record our conversation, and we will give you a written copy of it. "
    "From our conversation, we will send the clinician a written copy, a summary of what you are experiencing "
    "based on a questionnaire called the Patient Health Questionnaire (PHQ-9), and a summary of what your voice is like. "
    "The clinician will then follow up with you. "
    "To start, how has your mood been over the past couple of weeks?"
)

def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
    chat_history: List[Tuple[str, str]] = [("", INTRO_MESSAGE)]
    scores = {}
    meta = {"Severity": None, "Total_Score": None, "Confidence": 0.0}
    finished = False
    turns = 0
    return chat_history, scores, meta, finished, turns
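

# process_turn() returns a 9-tuple consumed by _process_with_tts in create_demo():
# (chat_history, display_json, severity, finished, turns,
#  cleared audio input, cleared text input, tts_path, tts_path)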
def process_turn(
    audio_path: Optional[str],
    text_input: Optional[str],
    chat_history: List[Tuple[str, str]],
    threshold: float,
    tts_enabled: bool,
    finished: Optional[bool],
    turns: Optional[int],
    prev_scores: Dict[str, Any],
    prev_meta: Dict[str, Any],
):
    # If already finished, do nothing
    finished = bool(finished) if finished is not None else False
    turns = int(turns) if isinstance(turns, int) else 0
    if finished:
        return (
            chat_history,
            {"info": "Assessment complete."},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )
    patient_text = (text_input or "").strip()
    audio_features: Dict[str, float] = {}
    if audio_path:
        # Transcribe first
        transcribed = transcribe_audio(audio_path)
        if transcribed:
            patient_text = (patient_text + " " + transcribed).strip() if patient_text else transcribed
        # Extract features
        audio_features = compute_audio_features(audio_path)
    if not patient_text:
        # Ask the user for input
        chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
        return (
            chat_history,
            prev_scores or {},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )
    # Add the patient's message
    chat_history.append((patient_text, None))
    # Scoring agent
    scoring = scoring_agent_infer(chat_history, audio_features)
    scores = scoring.get("PHQ9_Scores", {})
    confidences = scoring.get("Confidences", [])
    total = scoring.get("Total_Score", 0)
    severity = scoring.get("Severity", severity_from_total(total))
    overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))
    # Override high-risk to reduce false positives: rely on explicit text or a high item score
    # Extract the last patient message
    last_patient = ""
    for user_text, assistant_text in reversed(chat_history):
        if user_text:
            last_patient = user_text
            break
    explicit_flag = detect_explicit_suicidality(last_patient)
    high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
    meta = {"Severity": severity, "Total_Score": total, "Confidence": overall_conf}
    # Termination conditions
    min_conf = float(min(confidences)) if confidences else 0.0
    turns += 1
    done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)
    if high_risk:
        closing = (
            "I’m concerned about your safety based on what you shared. "
            "If you are in danger or need immediate help, please call 116 123 in the UK, 988 in the U.S., or your local emergency number. "
            "I'll end the assessment now and display emergency resources."
        )
        chat_history[-1] = (chat_history[-1][0], closing)
        finished = True
    elif done:
        summary = (
            "Thank you for your time. The clinician will review your conversation and follow up with you. "
            "Here is a copy of our conversation so you can review it later."
        )
        chat_history[-1] = (chat_history[-1][0], summary)
        finished = True
    else:
        # Iterative explainability (light) to guide the next question
        light_exp = explainability_light(chat_history, scores, confidences, float(threshold))
        # Generate the next clinician question with guidance
        reply = generate_recording_agent_reply(chat_history, guidance=light_exp)
        chat_history[-1] = (chat_history[-1][0], reply)
    # TTS for the latest clinician message, if enabled
    tts_path = synthesize_tts(chat_history[-1][1]) if tts_enabled else None
    # Build a compact JSON for display
    display_json = {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity,
        "Confidence": overall_conf,
        "High_Risk": high_risk,
        # Include the last audio features and light explainability for downstream modules/UI
        "Last_Audio_Features": audio_features,
        "Explainability_Light": explainability_light(chat_history, scores, confidences, float(threshold)),
    }
    # Clear inputs after processing
    return (
        chat_history,
        display_json,
        severity,
        finished,
        turns,
        None,
        None,
        tts_path,
        tts_path,
    )

def reset_app():
    return init_state()


# ---------------------------
# UI
# ---------------------------
def _on_load_init():
    return init_state()


def _on_load_init_with_tts(tts_on: bool):
    chat_history, scores_state, meta_state, finished_state, turns_state = init_state()
    # Play the intro message via TTS if enabled
    tts_path = synthesize_tts(chat_history[-1][1]) if bool(tts_on) else None
    return chat_history, scores_state, meta_state, finished_state, turns_state, tts_path


def _play_intro_tts(tts_on: bool):
    if not bool(tts_on):
        return None
    try:
        return synthesize_tts(INTRO_MESSAGE)
    except Exception:
        return None

def create_demo():
    with gr.Blocks(
        theme=gr.themes.Soft(),
        css='''
        /* Voice bubble styles - clean and centered */
        #voice-bubble {
            width: 240px; height: 240px; border-radius: 9999px; margin: 40px auto;
            display: flex; align-items: center; justify-content: center;
            background: linear-gradient(135deg, #6ee7b7 0%, #34d399 100%);
            box-shadow: 0 20px 40px rgba(16,185,129,0.3), 0 0 0 1px rgba(255,255,255,0.1) inset;
            transition: all 250ms cubic-bezier(0.4, 0, 0.2, 1);
            cursor: default; /* green circle itself is not clickable */
            pointer-events: none; /* ignore clicks on the green circle */
            position: relative;
        }
        #voice-bubble:hover {
            transform: translateY(-2px) scale(1.02);
            box-shadow: 0 25px 50px rgba(16,185,129,0.4), 0 0 0 1px rgba(255,255,255,0.15) inset;
        }
        #voice-bubble:active { transform: translateY(0px) scale(0.98); }
        #voice-bubble.listening {
            animation: bubble-pulse 1.5s ease-in-out infinite;
            background: linear-gradient(135deg, #fb7185 0%, #ef4444 100%);
            box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 1px rgba(255,255,255,0.1) inset;
        }
        @keyframes bubble-pulse {
            0%, 100% { transform: scale(1.0); box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 0 rgba(239,68,68,0.5); }
            50% { transform: scale(1.05); box-shadow: 0 25px 50px rgba(239,68,68,0.5), 0 0 0 15px rgba(239,68,68,0.0); }
        }
        /* Hide the microphone dropdown selector only */
        #voice-bubble select { display: none !important; }
        #voice-bubble .source-selection { display: none !important; }
        #voice-bubble label[for] { display: none !important; }
        /* Make the inner button the only clickable target */
        #voice-bubble button { pointer-events: auto; cursor: pointer; }
        /* Hide the TTS player UI but keep it in the DOM for autoplay */
        #tts-player { width: 0 !important; height: 0 !important; opacity: 0 !important; position: absolute; pointer-events: none; }
        /* Settings and back buttons in the top right - compact */
        #settings-btn {
            position: absolute; top: 16px; right: 16px; z-index: 10;
            width: auto !important; min-width: 100px !important; max-width: 120px !important;
            padding: 8px 16px !important;
        }
        #back-btn {
            position: absolute; top: 8px; right: 8px; z-index: 10;
            width: auto !important; min-width: 80px !important; max-width: 100px !important;
            padding: 8px 16px !important;
        }
        /* Play Intro button positioned under Settings */
        #intro-btn {
            position: absolute; top: 60px; right: 16px; z-index: 10;
            width: auto !important; min-width: 100px !important; max-width: 120px !important;
            padding: 8px 16px !important;
        }
        '''
    ) as demo:
        # Main view
        with gr.Column(visible=True) as main_view:
            # Settings button (top right, only in the main view)
            settings_btn = gr.Button("⚙️ Settings", elem_id="settings-btn", size="sm")
            gr.Markdown(
                """
                ### Conversational Assessment for Responsive Engagement (CARE) Notes
                Tap on 'Record' to start speaking, then tap on 'Stop' to stop recording.
                """
            )
            intro_play_btn = gr.Button("▶️ Start", elem_id="intro-btn", variant="secondary", size="sm")
            # Microphone component styled as the central bubble (tap to record/stop)
            audio_main = gr.Microphone(type="filepath", label=None, elem_id="voice-bubble", show_label=False)
            # Hidden text input placeholder for pipeline compatibility
            text_main = gr.Textbox(value="", visible=False)
            # Autoplay clinician voice output (player hidden with CSS)
            tts_audio_main = gr.Audio(label=None, interactive=False, autoplay=True, show_label=False, elem_id="tts-player")
            # Final summaries (shown after the assessment ends)
            main_summary = gr.Markdown(visible=False)

        # Settings view (initially hidden)
        with gr.Column(visible=False) as settings_view:
            back_btn = gr.Button("← Back", elem_id="back-btn", size="sm")
            gr.Markdown("## Settings")
            chatbot = gr.Chatbot(height=360, type="tuples", label="Conversation")
            with gr.Row():
                text_adv = gr.Textbox(placeholder="Type your message and press Enter", scale=4)
                send_adv_btn = gr.Button("Send", scale=1)
            score_json = gr.JSON(label="PHQ-9 Assessment (live)")
            severity_label = gr.Label(label="Severity")
            threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
            tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
            with gr.Row():
                tts_provider_dd = gr.Dropdown(choices=["Coqui", "gTTS"], value="Coqui", label="TTS Provider")
                coqui_model_tb = gr.Textbox(value=os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits"), label="Coqui Model")
                coqui_speaker_dd = gr.Dropdown(choices=list_coqui_speakers(os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")), value="p225", label="Coqui Speaker")
            tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
            model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
            with gr.Row():
                apply_model_btn = gr.Button("Apply model (no restart)")
                # apply_model_restart_btn = gr.Button("Apply model and restart")
            model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")

        # App state
        chat_state = gr.State()
        scores_state = gr.State()
        meta_state = gr.State()
        finished_state = gr.State()
        turns_state = gr.State()
        feats_state = gr.State()

        # Initialize on load (no autoplay due to browser policies)
        demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])

        # View navigation
        settings_btn.click(
            fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
            inputs=None,
            outputs=[main_view, settings_view]
        )
        back_btn.click(
            fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
            inputs=None,
            outputs=[main_view, settings_view]
        )

        # Explicit user gesture to play the intro TTS (works across browsers)
        intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])

        # Wire interactions
        def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
            result = process_turn(audio, text, chat, th, tts_on, finished, turns, scores, meta)
            chat_history, display_json, severity, finished_o, turns_o, _, _, _, last_tts = result
            # Accumulate the last audio features
            feats_hist = feats_hist or []
            last_feats = (display_json or {}).get("Last_Audio_Features") or {}
            if isinstance(last_feats, dict) and last_feats:
                feats_hist = list(feats_hist) + [last_feats]
            if tts_on and chat_history and chat_history[-1][1]:
                new_path = synthesize_tts(chat_history[-1][1], provider=provider, coqui_model_name=coqui_model, coqui_speaker=coqui_speaker)
            else:
                new_path = None
            # If finished, hide the mic and display summaries in the main view
            if finished_o:
                # Run full explainability and reflection
                exp_full = explainability_full(chat_history, display_json.get("Confidences", []), feats_hist)
                reflect = reflection_module(display_json.get("PHQ9_Scores", {}), display_json.get("Confidences", []), display_json.get("Explainability_Light", {}), exp_full, float(th))
                display_json["Explainability_Full"] = exp_full
                display_json["Reflection_Report"] = reflect
                # Use the reflection outputs to set the final meta
                final_sev = reflect.get("severity_label") or severity
                final_total = reflect.get("final_total") or display_json.get("Total_Score")
                patient_md = build_patient_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
                clinician_md = build_clinician_summary(chat_history, {"Severity": final_sev, "Total_Score": final_total}, display_json)
                summary_md = patient_md + "\n\n---\n\n" + clinician_md
                return chat_history, display_json, severity, finished_o, turns_o, gr.update(visible=False), None, new_path, new_path, gr.update(value=summary_md, visible=True), feats_hist
            return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path, gr.update(visible=False), feats_hist

        audio_main.stop_recording(
            fn=_process_with_tts,
            inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state],
            queue=True,
            api_name="message",
        )

        # Text input flow from the Settings view
        def _process_text_and_clear(text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist):
            res = _process_with_tts(None, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker, feats_hist)
            return (*res, "")

        text_adv.submit(
            fn=_process_text_and_clear,
            inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
            queue=True,
        )
        send_adv_btn.click(
            fn=_process_text_and_clear,
            inputs=[text_adv, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd, feats_state],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main, main_summary, feats_state, text_adv],
            queue=True,
        )

        # Tap bubble to toggle microphone record/stop via JS.
        # This JS is no longer needed, as the bubble itself is the mic component.
        # voice_bubble.click(
        #     None,
        #     inputs=None,
        #     outputs=None,
        #     js="() => {\n const bubble = document.getElementById('voice-bubble');\n const root = document.getElementById('hidden-mic');\n if (!root) return;\n let didClick = false;\n const wc = root.querySelector && root.querySelector('gradio-audio');\n if (wc && wc.shadowRoot) {\n const btns = Array.from(wc.shadowRoot.querySelectorAll('button')).filter(b => !b.disabled);\n const txt = (b) => ((b.getAttribute('aria-label')||'') + ' ' + (b.textContent||'')).toLowerCase();\n const stopBtn = btns.find(b => txt(b).includes('stop'));\n const recBtn = btns.find(b => { const t = txt(b); return t.includes('record') || t.includes('start') || t.includes('microphone') || t.includes('mic'); });\n if (stopBtn) { stopBtn.click(); didClick = true; } else if (recBtn) { recBtn.click(); didClick = true; } else if (btns[0]) { btns[0].click(); didClick = true; }\n }\n if (!didClick) {\n const candidates = Array.from(root.querySelectorAll('button, [role=\\'button\\']')).filter(el => !el.disabled);\n if (candidates.length) { candidates[0].click(); didClick = true; }\n }\n if (bubble && didClick) bubble.classList.toggle('listening');\n }",
        # )

        # No reset button in the Main tab anymore

        # Model switch handlers
        def _on_apply_model(mid: str):
            msg = set_current_model_id(mid)
            return f"Current model: `{current_model_id}`\n\n{msg}"

        def _on_apply_model_restart(mid: str):
            msg = apply_model_and_restart(mid)
            return f"{msg}"

        apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
        # apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])

    return demo

demo = create_demo()


if __name__ == "__main__":
    # For local dev
    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))