import os
import json
import re
import time
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np

# Audio processing
import soundfile as sf
import librosa

# Models
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
from gtts import gTTS
import spaces
import threading

# ---------------------------
# Configuration
# ---------------------------
DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
USE_TTS_DEFAULT = os.getenv("USE_TTS", "true").strip().lower() == "true"
CONFIG_PATH = os.getenv("MODEL_CONFIG_PATH", "model_config.json")
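# Example environment overrides (illustrative values, not requirements; any compatible HF model ids work):
#   LLM_MODEL_ID=google/medgemma-4b-it ASR_MODEL_ID=openai/whisper-small.en USE_TTS=false CONFIDENCE_THRESHOLD=0.9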
def _load_model_id_from_config() -> str:
    try:
        if os.path.exists(CONFIG_PATH):
            with open(CONFIG_PATH, "r") as f:
                data = json.load(f)
            if isinstance(data, dict) and data.get("model_id"):
                return str(data["model_id"])
    except Exception:
        pass
    return DEFAULT_CHAT_MODEL_ID


current_model_id = _load_model_id_from_config()
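# For reference, a minimal sketch of the optional config file read above; the only key the loader
# looks for is "model_id", and the filename comes from MODEL_CONFIG_PATH (default model_config.json):
#
#   {"model_id": "google/gemma-2-2b-it"}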
# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
_asr_pipe = None
_gen_pipe = None
_tokenizer = None


def _hf_device() -> int:
    return 0 if torch.cuda.is_available() else -1


def get_asr_pipeline():
    global _asr_pipe
    if _asr_pipe is None:
        _asr_pipe = pipeline(
            "automatic-speech-recognition",
            model=DEFAULT_ASR_MODEL_ID,
            device=_hf_device(),
        )
    return _asr_pipe
def get_textgen_pipeline():
    global _gen_pipe
    if _gen_pipe is None:
        # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
        if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
            _dtype = torch.bfloat16
        elif torch.cuda.is_available():
            _dtype = torch.float16
        else:
            _dtype = torch.float32
        _gen_pipe = pipeline(
            task="text-generation",
            model=current_model_id,
            tokenizer=current_model_id,
            device=_hf_device(),
            torch_dtype=_dtype,
        )
    return _gen_pipe
def set_current_model_id(new_model_id: str) -> str:
    global current_model_id, _gen_pipe
    new_model_id = (new_model_id or "").strip()
    if not new_model_id:
        return "Model id is empty; keeping current model."
    if new_model_id == current_model_id:
        return f"Model unchanged: `{current_model_id}`"
    current_model_id = new_model_id
    _gen_pipe = None  # force reload on next use
    return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."


def persist_model_id(new_model_id: str) -> None:
    try:
        with open(CONFIG_PATH, "w") as f:
            json.dump({"model_id": new_model_id}, f)
    except Exception:
        pass
def apply_model_and_restart(new_model_id: str) -> str:
    mid = (new_model_id or "").strip()
    if not mid:
        return "Model id is empty; not restarting."
    persist_model_id(mid)
    set_current_model_id(mid)

    # Graceful delayed exit so the response can flush before the process restarts
    def _exit_later():
        time.sleep(0.25)
        os._exit(0)

    threading.Thread(target=_exit_later, daemon=True).start()
    return f"Restarting with model `{mid}`..."
# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
    """Extract the first JSON object from text."""
    if not text:
        return None
    try:
        return json.loads(text)
    except Exception:
        pass
    # Fallback: find the first {...} block
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group(0))
        except Exception:
            return None
    return None
def compute_audio_features(audio_path: str) -> Dict[str, float]:
    """Compute lightweight prosodic features as a proxy for OpenSMILE.

    Returns a dictionary with summary statistics.
    """
    try:
        y, sr = librosa.load(audio_path, sr=16000, mono=True)
        if len(y) == 0:
            return {}
        # Frame-based features
        hop_length = 512
        frame_length = 1024
        rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
        centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]
        # Pitch estimation (coarse)
        f0 = None
        try:
            f0 = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
            f0 = f0[np.isfinite(f0)]
        except Exception:
            f0 = None
        # Speaking-rate rough proxy: fraction of frames whose RMS energy exceeds 1.2x the median
        voiced = rms > (np.median(rms) * 1.2)
        voiced_ratio = float(np.mean(voiced))
        features = {
            "rms_mean": float(np.mean(rms)),
            "rms_std": float(np.std(rms)),
            "zcr_mean": float(np.mean(zcr)),
            "zcr_std": float(np.std(zcr)),
            "centroid_mean": float(np.mean(centroid)),
            "centroid_std": float(np.std(centroid)),
            "voiced_ratio": voiced_ratio,
            "duration_sec": float(len(y) / sr),
        }
        if f0 is not None and f0.size > 0:
            features.update({
                "f0_median": float(np.median(f0)),
                "f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
            })
        return features
    except Exception:
        return {}
def detect_explicit_suicidality(text: Optional[str]) -> bool:
    if not text:
        return False
    t = text.lower()
    patterns = [
        r"\bkill myself\b",
        r"\bend my life\b",
        r"\bend it all\b",
        r"\bcommit suicide\b",
        r"\bsuicide\b",
        r"\bself[-\s]?harm\b",
        r"\bhurt myself\b",
        r"\bno reason to live\b",
        r"\bwant to die\b",
        r"\bending it\b",
    ]
    for pat in patterns:
        if re.search(pat, t):
            return True
    return False
def synthesize_tts(
    text: Optional[str],
    provider: str = "Coqui",
    coqui_model_name: Optional[str] = None,
    coqui_speaker: Optional[str] = None,
) -> Optional[str]:
    if not text:
        return None
    ts = int(time.time() * 1000)
    provider_norm = (provider or "Coqui").strip().lower()
    # Try Coqui first if requested
    if provider_norm == "coqui":
        try:
            # coqui-tts uses the same import path, TTS.api
            from TTS.api import TTS as CoquiTTS  # type: ignore

            model_name = (coqui_model_name or os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")).strip()
            engine = CoquiTTS(model_name=model_name, progress_bar=False)
            out_path_wav = f"/tmp/tts_{ts}.wav"
            kwargs = {}
            if coqui_speaker:
                kwargs["speaker"] = coqui_speaker
            engine.tts_to_file(text=text, file_path=out_path_wav, **kwargs)
            return out_path_wav
        except Exception:
            pass
    # Fall back to gTTS
    try:
        out_path = f"/tmp/tts_{ts}.mp3"
        tts = gTTS(text=text, lang="en")
        tts.save(out_path)
        return out_path
    except Exception:
        return None
def list_coqui_speakers(model_name: str) -> List[str]:
    try:
        from TTS.api import TTS as CoquiTTS  # type: ignore

        engine = CoquiTTS(model_name=model_name, progress_bar=False)
        # Try common attributes
        if hasattr(engine, "speakers") and isinstance(engine.speakers, list):
            return [str(s) for s in engine.speakers]
        if hasattr(engine, "speaker_manager") and hasattr(engine.speaker_manager, "speaker_names"):
            return list(engine.speaker_manager.speaker_names)
    except Exception:
        pass
    # Reasonable defaults for the VCTK multi-speaker model
    return ["p225", "p227", "p231", "p233", "p236"]
def severity_from_total(total_score: int) -> str:
    if total_score <= 4:
        return "Minimal Depression"
    if total_score <= 9:
        return "Mild Depression"
    if total_score <= 14:
        return "Moderate Depression"
    if total_score <= 19:
        return "Moderately Severe Depression"
    return "Severe Depression"


def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
    """Convert chatbot history [(user, assistant), ...] to a plain-text transcript."""
    lines = []
    for user, assistant in chat_history:
        if user:
            lines.append(f"Patient: {user}")
        if assistant:
            lines.append(f"Clinician: {assistant}")
    return "\n".join(lines)
def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
    transcript = transcript_to_text(chat_history)
    system_prompt = (
        "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
        "without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
        "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
        "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
    )
    user_prompt = (
        "Conversation so far (Patient and Clinician turns):\n\n" + transcript +
        "\n\nRespond with a single short clinician-style question for the patient."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    combined_prompt = system_prompt + "\n\n" + user_prompt
    messages = [
        {"role": "user", "content": combined_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    gen = pipe(
        prompt,
        max_new_tokens=96,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    reply = gen[0]["generated_text"].strip()
    # Ensure it's a single concise question/sentence
    if len(reply) > 300:
        reply = reply[:300].rstrip() + "…"
    return reply
def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
    """Ask the LLM to produce PHQ-9 scores and confidences as JSON. Fall back to a heuristic if parsing fails."""
    transcript = transcript_to_text(chat_history)
    features_json = json.dumps(features, ensure_ascii=False)
    system_prompt = (
        "You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
        "Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
        "Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
        "and High_Risk (boolean, true if any suicidal risk)."
    )
    user_prompt = (
        "Conversation transcript:"
        f"\n{transcript}\n\n"
        f"Acoustic features summary (approximate):\n{features_json}\n\n"
        "Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
        "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    combined_prompt = system_prompt + "\n\n" + user_prompt
    messages = [
        {"role": "user", "content": combined_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Use deterministic (greedy) decoding to avoid CUDA sampling edge cases on some models;
    # temperature is omitted because it is ignored when do_sample=False.
    gen = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    out_text = gen[0]["generated_text"]
    parsed = safe_json_extract(out_text)
    # Validate and coerce
    if parsed is None or "PHQ9_Scores" not in parsed:
        # Simple fallback heuristic: neutral scores with low confidence
        scores = {
            "interest": 1,
            "mood": 1,
            "sleep": 1,
            "energy": 1,
            "appetite": 1,
            "self_worth": 1,
            "concentration": 1,
            "motor": 1,
            "suicidal_thoughts": 0,
        }
        confidences = [0.5] * 9
        total = int(sum(scores.values()))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity_from_total(total),
            "Confidence": float(min(confidences)),
            "High_Risk": False,
        }
    try:
        # Coerce types and compute derived values if missing
        scores = parsed.get("PHQ9_Scores", {})
        # Ensure all keys are present and each item is clamped to 0-3
        keys = [
            "interest", "mood", "sleep", "energy", "appetite", "self_worth", "concentration", "motor", "suicidal_thoughts"
        ]
        for k in keys:
            scores[k] = int(max(0, min(3, int(scores.get(k, 0)))))
        confidences = parsed.get("Confidences", [])
        if not isinstance(confidences, list) or len(confidences) != 9:
            confidences = [float(parsed.get("Confidence", 0.5))] * 9
        confidences = [float(max(0.0, min(1.0, c))) for c in confidences]
        total = int(sum(scores.values()))
        severity = parsed.get("Severity") or severity_from_total(total)
        overall_conf = float(parsed.get("Confidence", min(confidences)))
        # Conservative high-risk detection: require explicit language or a high suicidal_thoughts score.
        # Extract the last patient message.
        last_patient = ""
        for user_text, assistant_text in reversed(chat_history):
            if user_text:
                last_patient = user_text
                break
        explicit_flag = detect_explicit_suicidality(last_patient) or detect_explicit_suicidality(transcript)
        high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity,
            "Confidence": overall_conf,
            "High_Risk": high_risk,
        }
    except Exception:
        # Final fallback: reuse whatever parsed scores exist, coercing defensively since the
        # block above may have failed on malformed values
        scores = parsed.get("PHQ9_Scores", {}) if isinstance(parsed, dict) else {}
        if not scores:
            scores = {k: 1 for k in [
                "interest", "mood", "sleep", "energy", "appetite", "self_worth", "concentration", "motor", "suicidal_thoughts"
            ]}
        clean_scores = {}
        for k, v in scores.items():
            try:
                clean_scores[k] = int(max(0, min(3, int(v))))
            except Exception:
                clean_scores[k] = 1
        scores = clean_scores
        confidences = [0.5] * 9
        total = int(sum(scores.values()))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity_from_total(total),
            "Confidence": float(min(confidences)),
            "High_Risk": False,
        }
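# For reference, an illustrative (hand-written, not model-generated) example of the JSON shape the
# scoring prompt requests and the coercion above expects:
#
#   {
#     "PHQ9_Scores": {"interest": 1, "mood": 2, "sleep": 2, "energy": 1, "appetite": 0,
#                     "self_worth": 1, "concentration": 1, "motor": 0, "suicidal_thoughts": 0},
#     "Confidences": [0.8, 0.7, 0.9, 0.6, 0.5, 0.6, 0.7, 0.5, 0.9],
#     "Total_Score": 8,
#     "Severity": "Mild Depression",
#     "Confidence": 0.5,
#     "High_Risk": false
#   }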
def transcribe_audio(audio_path: Optional[str]) -> str:
    if not audio_path:
        return ""
    try:
        asr = get_asr_pipeline()
        result = asr(audio_path)
        if isinstance(result, dict) and "text" in result:
            return result["text"].strip()
        if isinstance(result, list) and len(result) > 0 and "text" in result[0]:
            return result[0]["text"].strip()
    except Exception:
        pass
    return ""
# ---------------------------
# Gradio app logic
# ---------------------------
INTRO_MESSAGE = (
    "Hi, I'm an assistant, and I will ask you some questions about how you've been doing. "
    "We'll record our conversation and give you a written copy of it. "
    "We will also send the written copy to your clinician, together with a summary of what you are experiencing "
    "based on a questionnaire called the Patient Health Questionnaire (PHQ-9) and a summary of what your voice sounds like. "
    "Your clinician will then follow up with you. "
    "To start, how has your mood been over the past couple of weeks?"
)
def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
    chat_history: List[Tuple[str, str]] = [("", INTRO_MESSAGE)]
    scores = {}
    meta = {"Severity": None, "Total_Score": None, "Confidence": 0.0}
    finished = False
    turns = 0
    return chat_history, scores, meta, finished, turns
def process_turn(
    audio_path: Optional[str],
    text_input: Optional[str],
    chat_history: List[Tuple[str, str]],
    threshold: float,
    tts_enabled: bool,
    finished: Optional[bool],
    turns: Optional[int],
    prev_scores: Dict[str, Any],
    prev_meta: Dict[str, Any],
):
    # Normalize state values that may arrive as None from Gradio
    finished = bool(finished) if finished is not None else False
    turns = int(turns) if isinstance(turns, int) else 0
    prev_scores = prev_scores or {}
    prev_meta = prev_meta or {}
    # If already finished, do nothing
    if finished:
        return (
            chat_history,
            {"info": "Assessment complete."},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )
    patient_text = (text_input or "").strip()
    audio_features: Dict[str, float] = {}
    if audio_path:
        # Transcribe first
        transcribed = transcribe_audio(audio_path)
        if transcribed:
            patient_text = f"{patient_text} {transcribed}".strip() if patient_text else transcribed
        # Extract acoustic features
        audio_features = compute_audio_features(audio_path)
    if not patient_text:
        # Ask the user for input again
        chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
        return (
            chat_history,
            prev_scores,
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )
    # Add the patient's message
    chat_history.append((patient_text, None))
    # Scoring agent
    scoring = scoring_agent_infer(chat_history, audio_features)
    scores = scoring.get("PHQ9_Scores", {})
    confidences = scoring.get("Confidences", [])
    total = scoring.get("Total_Score", 0)
    severity = scoring.get("Severity", severity_from_total(total))
    overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))
    # Override high-risk to reduce false positives: rely on explicit text or a high item score.
    # Extract the last patient message.
    last_patient = ""
    for user_text, assistant_text in reversed(chat_history):
        if user_text:
            last_patient = user_text
            break
    explicit_flag = detect_explicit_suicidality(last_patient)
    high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
    meta = {"Severity": severity, "Total_Score": total, "Confidence": overall_conf}
    # Termination conditions
    min_conf = float(min(confidences)) if confidences else 0.0
    turns += 1
    done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)
    if high_risk:
        closing = (
            "I’m concerned about your safety based on what you shared. "
            "If you are in danger or need immediate help, please call 988 in the U.S. or your local emergency number. "
            "I'll end the assessment now and display emergency resources."
        )
        chat_history[-1] = (chat_history[-1][0], closing)
        finished = True
    elif done:
        summary = (
            f"Thank you for sharing. Based on our conversation, your responses suggest {severity.lower()}. "
            "We can stop here."
        )
        chat_history[-1] = (chat_history[-1][0], summary)
        finished = True
    else:
        # Generate the next clinician question
        reply = generate_recording_agent_reply(chat_history)
        chat_history[-1] = (chat_history[-1][0], reply)
    # TTS for the latest clinician message, if enabled
    tts_path = synthesize_tts(chat_history[-1][1]) if tts_enabled else None
    # Build a compact JSON for display
    display_json = {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity,
        "Confidence": overall_conf,
        "High_Risk": high_risk,
    }
    # Clear inputs after processing
    return (
        chat_history,
        display_json,
        severity,
        finished,
        turns,
        None,
        None,
        tts_path,
        tts_path,
    )
def reset_app():
    return init_state()


# ---------------------------
# UI
# ---------------------------
def _on_load_init():
    return init_state()


def _on_load_init_with_tts(tts_on: bool):
    chat_history, scores_state, meta_state, finished_state, turns_state = init_state()
    # Play the intro message via TTS if enabled
    tts_path = synthesize_tts(chat_history[-1][1]) if bool(tts_on) else None
    return chat_history, scores_state, meta_state, finished_state, turns_state, tts_path


def _play_intro_tts(tts_on: bool):
    if not bool(tts_on):
        return None
    try:
        return synthesize_tts(INTRO_MESSAGE)
    except Exception:
        return None
def create_demo():
    with gr.Blocks(
        theme=gr.themes.Soft(),
        css='''
        /* Voice bubble styles - clean and centered */
        #voice-bubble {
            width: 240px; height: 240px; border-radius: 9999px; margin: 40px auto;
            display: flex; align-items: center; justify-content: center;
            background: linear-gradient(135deg, #6ee7b7 0%, #34d399 100%);
            box-shadow: 0 20px 40px rgba(16,185,129,0.3), 0 0 0 1px rgba(255,255,255,0.1) inset;
            transition: all 250ms cubic-bezier(0.4, 0, 0.2, 1);
            cursor: default; /* green circle itself is not clickable */
            pointer-events: none; /* ignore clicks on the green circle */
            position: relative;
        }
        #voice-bubble:hover {
            transform: translateY(-2px) scale(1.02);
            box-shadow: 0 25px 50px rgba(16,185,129,0.4), 0 0 0 1px rgba(255,255,255,0.15) inset;
        }
        #voice-bubble:active { transform: translateY(0px) scale(0.98); }
        #voice-bubble.listening {
            animation: bubble-pulse 1.5s ease-in-out infinite;
            background: linear-gradient(135deg, #fb7185 0%, #ef4444 100%);
            box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 1px rgba(255,255,255,0.1) inset;
        }
        @keyframes bubble-pulse {
            0%, 100% { transform: scale(1.0); box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 0 rgba(239,68,68,0.5); }
            50% { transform: scale(1.05); box-shadow: 0 25px 50px rgba(239,68,68,0.5), 0 0 0 15px rgba(239,68,68,0.0); }
        }
        /* Hide microphone dropdown selector only */
        #voice-bubble select { display: none !important; }
        #voice-bubble .source-selection { display: none !important; }
        #voice-bubble label[for] { display: none !important; }
        /* Make the inner button the only clickable target */
        #voice-bubble button { pointer-events: auto; cursor: pointer; }
        /* Hide TTS player UI but keep it in DOM for autoplay */
        #tts-player { width: 0 !important; height: 0 !important; opacity: 0 !important; position: absolute; pointer-events: none; }
        '''
    ) as demo:
        gr.Markdown(
            """
            ### Conversational Assessment for Responsive Engagement (CARE) Notes
            Tap 'Record' to start speaking, then tap 'Stop' to stop recording.
            """
        )
        intro_play_btn = gr.Button("▶️ Play Intro", variant="secondary")
        with gr.Tabs():
            with gr.TabItem("Main"):
                with gr.Column():
                    # Microphone component styled as the central bubble (tap to record/stop)
                    audio_main = gr.Microphone(type="filepath", label=None, elem_id="voice-bubble", show_label=False)
                    # Hidden text input placeholder for pipeline compatibility
                    text_main = gr.Textbox(value="", visible=False)
                    # Autoplay clinician voice output (player hidden with CSS)
                    tts_audio_main = gr.Audio(label=None, interactive=False, autoplay=True, show_label=False, elem_id="tts-player")
            with gr.TabItem("Advanced"):
                with gr.Column():
                    chatbot = gr.Chatbot(height=360, type="tuples", label="Conversation")
                    score_json = gr.JSON(label="PHQ-9 Assessment (live)")
                    severity_label = gr.Label(label="Severity")
                    threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
                    tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
                    with gr.Row():
                        tts_provider_dd = gr.Dropdown(choices=["Coqui", "gTTS"], value="Coqui", label="TTS Provider")
                        coqui_model_tb = gr.Textbox(value=os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits"), label="Coqui Model")
                        coqui_speaker_dd = gr.Dropdown(choices=list_coqui_speakers(os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")), value="p225", label="Coqui Speaker")
                    tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
                    model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
                    with gr.Row():
                        apply_model_btn = gr.Button("Apply model (no restart)")
                        # apply_model_restart_btn = gr.Button("Apply model and restart")
                    model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")
        # App state
        chat_state = gr.State()
        scores_state = gr.State()
        meta_state = gr.State()
        finished_state = gr.State()
        turns_state = gr.State()
        # Initialize on load (no autoplay due to browser policies)
        demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
        # Explicit user gesture to play the intro TTS (works across browsers)
        intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])

        # Wire interactions
        def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker):
            # Disable the default TTS inside process_turn; synthesis happens below with the selected provider/speaker
            result = process_turn(audio, text, chat, th, False, finished, turns, scores, meta)
            chat_history, display_json, severity, finished_o, turns_o, _, _, _, _ = result
            if tts_on and chat_history and chat_history[-1][1]:
                new_path = synthesize_tts(chat_history[-1][1], provider=provider, coqui_model_name=coqui_model, coqui_speaker=coqui_speaker)
            else:
                new_path = None
            return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path

        audio_main.stop_recording(
            fn=_process_with_tts,
            inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main],
            queue=True,
            api_name="message",
        )
        # Tapping the bubble toggles record/stop directly: the Microphone component itself is the bubble,
        # so the earlier voice_bubble.click JS handler that forwarded clicks to a hidden mic is no longer
        # needed and has been removed.
        # No reset button in the Main tab anymore

        # Model switch handlers
        def _on_apply_model(mid: str):
            msg = set_current_model_id(mid)
            return f"Current model: `{current_model_id}`\n\n{msg}"

        def _on_apply_model_restart(mid: str):
            msg = apply_model_and_restart(mid)
            return msg

        apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
        # apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])
    return demo
demo = create_demo()

if __name__ == "__main__":
    # For local dev
    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))