# care-notes / app.py
# Last commit aec1268 (Akis Giannoukos): Implement Coqui TTS integration with model and speaker selection in demo interface; update requirements to include coqui-tts package.
import os
import json
import re
import time
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
# Audio processing
import soundfile as sf
import librosa
# Models
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
)
from gtts import gTTS
import spaces
import threading
# ---------------------------
# Configuration (each value can be overridden via an environment variable)
# ---------------------------
# Hugging Face model ids for the chat LLM and the speech recognizer.
DEFAULT_CHAT_MODEL_ID = os.environ.get("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.environ.get("ASR_MODEL_ID", "openai/whisper-tiny.en")
# Default for the UI threshold: the assessment may stop once the minimum
# per-item confidence reaches this value.
CONFIDENCE_THRESHOLD_DEFAULT = float(os.environ.get("CONFIDENCE_THRESHOLD", "0.8"))
# Hard cap on the number of patient turns in one assessment.
MAX_TURNS = int(os.environ.get("MAX_TURNS", "12"))
# TTS starts enabled only when USE_TTS is exactly "true" (case/whitespace-insensitive).
USE_TTS_DEFAULT = os.environ.get("USE_TTS", "true").strip().lower() == "true"
# JSON file used to persist the selected chat model id between runs.
CONFIG_PATH = os.environ.get("MODEL_CONFIG_PATH", "model_config.json")
def _load_model_id_from_config() -> str:
    """Return the chat model id persisted in CONFIG_PATH, or DEFAULT_CHAT_MODEL_ID.

    The config file is expected to be a JSON object such as
    ``{"model_id": "google/gemma-2-2b-it"}``. The default is returned when the
    file is missing, unreadable, or not valid JSON, and when it is not an object.
    It is also returned when ``model_id`` is absent or blank. Startup therefore
    never fails because of the config file.
    """
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return DEFAULT_CHAT_MODEL_ID
    if isinstance(data, dict):
        # Strip so a whitespace-only id does not count as a real model id.
        model_id = str(data.get("model_id") or "").strip()
        if model_id:
            return model_id
    return DEFAULT_CHAT_MODEL_ID
# Chat model id used by the text-generation pipeline. It is resolved once at
# import time (persisted config or default) and changed at runtime by
# set_current_model_id().
current_model_id = _load_model_id_from_config()
# ---------------------------
# Lazily created, process-wide model singletons (None until first use)
# ---------------------------
# _asr_pipe: speech-recognition pipeline (see get_asr_pipeline).
# _gen_pipe: text-generation pipeline (see get_textgen_pipeline).
# _tokenizer: appears unused in this file.
_asr_pipe = _gen_pipe = _tokenizer = None
def _hf_device() -> int:
    """Return the transformers pipeline device: 0 (first GPU) with CUDA, else -1 (CPU)."""
    if torch.cuda.is_available():
        return 0
    return -1
def get_asr_pipeline():
    """Return the shared speech-recognition pipeline, creating it on the first call.

    Uses DEFAULT_ASR_MODEL_ID on the device chosen by _hf_device(); the
    instance is cached in the module-level ``_asr_pipe``.
    """
    global _asr_pipe
    if _asr_pipe is not None:
        return _asr_pipe
    _asr_pipe = pipeline(
        "automatic-speech-recognition",
        model=DEFAULT_ASR_MODEL_ID,
        device=_hf_device(),
    )
    return _asr_pipe
def get_textgen_pipeline():
    """Return the shared text-generation pipeline for ``current_model_id``.

    The pipeline is built lazily on the first call and cached in ``_gen_pipe``.
    set_current_model_id() clears that cache so the next call loads the newly
    selected model. The default model is small enough for CPU Spaces, and
    LLM_MODEL_ID overrides it.

    Weight dtype follows the hardware:
    - bfloat16 on CUDA devices that support it,
    - float16 on other CUDA devices,
    - float32 on CPU.
    """
    global _gen_pipe
    if _gen_pipe is not None:
        return _gen_pipe

    has_cuda = torch.cuda.is_available()
    bf16_ok = (
        has_cuda
        and hasattr(torch.cuda, "is_bf16_supported")
        and torch.cuda.is_bf16_supported()
    )
    if bf16_ok:
        weight_dtype = torch.bfloat16
    elif has_cuda:
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32

    _gen_pipe = pipeline(
        task="text-generation",
        model=current_model_id,
        tokenizer=current_model_id,
        device=_hf_device(),
        torch_dtype=weight_dtype,
    )
    return _gen_pipe
def set_current_model_id(new_model_id: str) -> str:
    """Switch the chat model used for generation and return a status message.

    The id is stripped first. An empty id, or the id that is already active,
    leaves everything unchanged. Otherwise the cached pipeline is dropped, so
    the next get_textgen_pipeline() call loads the new model.
    """
    global current_model_id, _gen_pipe
    candidate = (new_model_id or "").strip()
    if candidate == "":
        return "Model id is empty; keeping current model."
    if candidate == current_model_id:
        return f"Model unchanged: `{current_model_id}`"
    # Drop the cached pipeline; it is rebuilt lazily for the new id.
    current_model_id, _gen_pipe = candidate, None
    return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."
def persist_model_id(new_model_id: str) -> None:
    """Best-effort save of *new_model_id* to CONFIG_PATH as ``{"model_id": ...}``.

    The JSON is written to a sibling ``.tmp`` file and then swapped in with
    os.replace(). A crash or forced exit mid-write therefore cannot leave a
    truncated config behind; a corrupt file would make the loader silently fall
    back to the default model.

    Failures are swallowed on purpose because persisting is optional. The
    previous config, if any, is left untouched.
    """
    tmp_path = f"{CONFIG_PATH}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump({"model_id": new_model_id}, f)
        os.replace(tmp_path, CONFIG_PATH)
    except Exception:
        # Best effort: remove any partial temp file and keep the old config.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def apply_model_and_restart(new_model_id: str) -> str:
    """Persist and activate *new_model_id*, then terminate the process shortly after.

    Returns a status message immediately. A daemon thread calls os._exit(0)
    about 0.25 s later so the response can still be delivered. Relaunching is
    not done here; presumably the hosting environment restarts the process.
    An empty or blank id does nothing.
    """
    model_id = (new_model_id or "").strip()
    if not model_id:
        return "Model id is empty; not restarting."

    persist_model_id(model_id)
    set_current_model_id(model_id)

    def _delayed_exit() -> None:
        # Short grace period so the status response can flush before exiting.
        time.sleep(0.25)
        os._exit(0)

    exit_thread = threading.Thread(target=_delayed_exit, daemon=True)
    exit_thread.start()
    return f"Restarting with model `{model_id}`..."
# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object found in *text*, or None.

    The whole string is tried first. Otherwise each ``{`` is tried in turn as
    the start of an object, and the first one that decodes to a dict is
    returned. This finds an object surrounded by prose or Markdown code fences,
    and it also works when several objects or stray braces are present.

    Only dicts are returned. Valid JSON of another type (a number, a list) yields
    None, so callers can safely use ``in`` and ``.get`` on the result.
    """
    if not text:
        return None
    try:
        whole = json.loads(text)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        whole = None
    if isinstance(whole, dict):
        return whole

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores the rest.
            candidate, _end = decoder.raw_decode(text, start)
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None
def compute_audio_features(audio_path: str) -> Dict[str, float]:
    """Compute lightweight prosodic summary statistics for an audio file.

    This is a cheap stand-in for an OpenSMILE feature set. The file is loaded
    as 16 kHz mono. Frame-level RMS energy, zero-crossing rate and spectral
    centroid are each summarized by mean and std.

    ``voiced_ratio`` is the fraction of frames whose RMS exceeds 1.2x the median
    RMS. It is a rough speech-activity proxy, not a per-second rate.
    ``duration_sec`` is the clip length. When YIN pitch tracking yields finite
    values, ``f0_median`` and ``f0_iqr`` (Hz) are added as well.

    Args:
        audio_path: Path to an audio file readable by librosa.

    Returns:
        The feature dict. It is empty for empty audio or on any error, because
        feature extraction is best-effort and must not break a turn.
    """
    try:
        y, sr = librosa.load(audio_path, sr=16000, mono=True)
        if len(y) == 0:
            return {}

        # Frame-based features
        hop_length = 512
        frame_length = 1024
        rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
        centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]

        # Coarse pitch estimate; optional, so failures just omit the f0 features.
        f0 = None
        try:
            f0_track = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
            f0 = f0_track[np.isfinite(f0_track)]
        except Exception:
            f0 = None

        # Speech-activity proxy: share of frames noticeably above median energy.
        # Reuses the RMS track above instead of recomputing an identical one.
        voiced_ratio = float(np.mean(rms > (np.median(rms) * 1.2)))

        features = {
            "rms_mean": float(np.mean(rms)),
            "rms_std": float(np.std(rms)),
            "zcr_mean": float(np.mean(zcr)),
            "zcr_std": float(np.std(zcr)),
            "centroid_mean": float(np.mean(centroid)),
            "centroid_std": float(np.std(centroid)),
            "voiced_ratio": voiced_ratio,
            "duration_sec": float(len(y) / sr),
        }
        if f0 is not None and f0.size > 0:
            features.update({
                "f0_median": float(np.median(f0)),
                "f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
            })
        return features
    except Exception:
        return {}
def detect_explicit_suicidality(text: Optional[str]) -> bool:
    """Return True if *text* contains explicit suicidal or self-harm wording.

    The match is case-insensitive and word-bounded against a fixed phrase list
    (e.g. "kill myself", "suicide", "self-harm" / "self harm" / "selfharm",
    "want to die"). None or empty text is never flagged.
    """
    if not text:
        return False
    lowered = text.lower()
    risk_patterns = (
        r"\bkill myself\b",
        r"\bend my life\b",
        r"\bend it all\b",
        r"\bcommit suicide\b",
        r"\bsuicide\b",
        r"\bself[-\s]?harm\b",
        r"\bhurt myself\b",
        r"\bno reason to live\b",
        r"\bwant to die\b",
        r"\bending it\b",
    )
    return any(re.search(pattern, lowered) for pattern in risk_patterns)
# Loaded Coqui engines keyed by model name. Building a Coqui TTS object loads
# the model (downloading it on first use), so doing that on every call is
# very slow. At most one model is kept resident.
_coqui_engine_cache: Dict[str, Any] = {}


def _get_coqui_engine(model_name: str) -> Any:
    """Return a cached Coqui TTS engine for *model_name*, loading it on first use."""
    engine = _coqui_engine_cache.get(model_name)
    if engine is None:
        # coqui-tts uses the same import path TTS.api
        from TTS.api import TTS as CoquiTTS  # type: ignore
        engine = CoquiTTS(model_name=model_name, progress_bar=False)
        _coqui_engine_cache.clear()  # keep only the most recently used model
        _coqui_engine_cache[model_name] = engine
    return engine


def _tts_output_path(suffix: str) -> str:
    """Return a fresh, collision-free path in the system temp dir ending in *suffix*."""
    import tempfile
    import uuid

    # A millisecond timestamp alone can collide between concurrent requests.
    name = f"tts_{int(time.time() * 1000)}_{uuid.uuid4().hex}{suffix}"
    return os.path.join(tempfile.gettempdir(), name)


def synthesize_tts(
    text: Optional[str],
    provider: str = "Coqui",
    coqui_model_name: Optional[str] = None,
    coqui_speaker: Optional[str] = None,
) -> Optional[str]:
    """Synthesize *text* to an audio file and return its path.

    Args:
        text: Text to speak. None or empty returns None.
        provider: "Coqui" (default, case-insensitive) tries Coqui TTS first;
            any other value goes straight to gTTS.
        coqui_model_name: Coqui model id. When empty or blank, the COQUI_MODEL
            env var is used, falling back to "tts_models/en/vctk/vits".
        coqui_speaker: Optional speaker name passed to Coqui.

    Returns:
        A ``.wav`` path from Coqui, or a ``.mp3`` path from gTTS. Coqui falls
        back to gTTS if it is missing or fails. Returns None if both fail.
    """
    if not text:
        return None
    provider_norm = (provider or "Coqui").strip().lower()

    if provider_norm == "coqui":
        try:
            model_name = (coqui_model_name or "").strip() or os.getenv(
                "COQUI_MODEL", "tts_models/en/vctk/vits"
            ).strip()
            engine = _get_coqui_engine(model_name)
            out_path_wav = _tts_output_path(".wav")
            kwargs = {}
            if coqui_speaker:
                kwargs["speaker"] = coqui_speaker
            engine.tts_to_file(text=text, file_path=out_path_wav, **kwargs)
            return out_path_wav
        except Exception:
            # Coqui unavailable, model load failed, or synthesis failed
            # (e.g. a multi-speaker model called without a speaker): use gTTS.
            pass

    # Fallback to gTTS
    try:
        out_path = _tts_output_path(".mp3")
        gTTS(text=text, lang="en").save(out_path)
        return out_path
    except Exception:
        return None
def list_coqui_speakers(model_name: str) -> List[str]:
    """Return the speaker names offered by Coqui model *model_name*.

    Loads the model and reads ``speakers`` (when it is a list) or
    ``speaker_manager.speaker_names``. If Coqui is unavailable, the model
    cannot be loaded, or neither attribute is usable, a few VCTK
    multi-speaker ids are returned instead.
    """
    vctk_defaults = ["p225", "p227", "p231", "p233", "p236"]
    try:
        from TTS.api import TTS as CoquiTTS  # type: ignore
        tts_engine = CoquiTTS(model_name=model_name, progress_bar=False)
        if hasattr(tts_engine, "speakers") and isinstance(tts_engine.speakers, list):
            return list(map(str, tts_engine.speakers))
        manager = getattr(tts_engine, "speaker_manager", None)
        if hasattr(tts_engine, "speaker_manager") and hasattr(manager, "speaker_names"):
            return list(manager.speaker_names)
    except Exception:
        pass
    return vctk_defaults
def severity_from_total(total_score: int) -> str:
    """Map a PHQ-9 total score to its severity band label.

    Bands (inclusive upper bounds): 0-4 minimal, 5-9 mild, 10-14 moderate,
    15-19 moderately severe, anything higher severe.
    """
    bands = (
        (4, "Minimal Depression"),
        (9, "Mild Depression"),
        (14, "Moderate Depression"),
        (19, "Moderately Severe Depression"),
    )
    for upper_bound, label in bands:
        if total_score <= upper_bound:
            return label
    return "Severe Depression"
def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
    """Render chatbot history ``[(user, assistant), ...]`` as a plain-text transcript.

    Each non-empty user message becomes a ``Patient:`` line and each non-empty
    assistant message a ``Clinician:`` line, in conversation order. Lines are
    joined with newlines.
    """
    rendered: List[str] = []
    for patient_msg, clinician_msg in chat_history:
        rendered.extend(
            f"{label}{message}"
            for label, message in (("Patient: ", patient_msg), ("Clinician: ", clinician_msg))
            if message
        )
    return "\n".join(rendered)
def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
    """Generate the clinician's next conversational question with the chat LLM.

    The interviewer instructions and the full transcript are sent together as a
    single user message. Sampling (temperature 0.7, top-p 0.9, top-k 50) keeps
    the questions varied.

    Args:
        chat_history: Chatbot history as ``[(patient, clinician), ...]``.

    Returns:
        One reply string, truncated to 300 characters plus an ellipsis. A
        leading "Clinician:" label is removed, since the model may echo the
        transcript's turn labels. An empty generation is replaced by a generic
        follow-up question, so the patient always gets a prompt.
    """
    transcript = transcript_to_text(chat_history)
    system_prompt = (
        "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
        "without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
        "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
        "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
    )
    user_prompt = (
        "Conversation so far (Patient and Clinician turns):\n\n" + transcript +
        "\n\nRespond with a single short clinician-style question for the patient."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    messages = [
        {"role": "user", "content": system_prompt + "\n\n" + user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    gen = pipe(
        prompt,
        max_new_tokens=96,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    reply = gen[0]["generated_text"].strip()
    # The transcript labels turns "Clinician:"; drop that label if the model echoes it.
    reply = re.sub(r"^\s*clinician\s*:\s*", "", reply, flags=re.IGNORECASE).strip()
    if not reply:
        # Never leave the patient without a question (and without anything to speak).
        return "Could you tell me a bit more about how you've been feeling lately?"
    # Keep it to a single concise question/sentence
    if len(reply) > 300:
        reply = reply[:300].rstrip() + "…"
    return reply
# PHQ-9 item keys, in the order the per-item Confidences list follows.
_PHQ9_KEYS = [
    "interest", "mood", "sleep", "energy", "appetite",
    "self_worth", "concentration", "motor", "suicidal_thoughts",
]


def _as_float(value: Any, default: float) -> float:
    """Return *value* as a finite float, or *default* if it is not numeric or not finite."""
    import math

    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if math.isfinite(number) else default


def _clamp01(value: float) -> float:
    """Clamp *value* into the closed interval [0.0, 1.0]."""
    return max(0.0, min(1.0, value))


def _patient_mentions_suicidality(chat_history: List[Tuple[str, str]]) -> bool:
    """Return True if any patient (user) message contains explicit suicidal language.

    Clinician turns are deliberately excluded. The interviewer is prompted to ask
    about suicidal thoughts, and its own wording (e.g. "self-harm") must not
    flag the patient.
    """
    return any(detect_explicit_suicidality(user_text) for user_text, _ in chat_history if user_text)


def _neutral_phq9_result(high_risk: bool) -> Dict[str, Any]:
    """Build the low-confidence neutral result used when model output is unusable."""
    scores = {k: 1 for k in _PHQ9_KEYS}
    scores["suicidal_thoughts"] = 0
    confidences = [0.5] * len(_PHQ9_KEYS)
    total = sum(scores.values())
    return {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity_from_total(total),
        "Confidence": float(min(confidences)),
        "High_Risk": bool(high_risk),
    }


def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
    """Ask the LLM for PHQ-9 scores and confidences as JSON, then validate them.

    Args:
        chat_history: Chatbot history as ``[(patient, clinician), ...]``.
        features: Acoustic summary features, included in the prompt as JSON.

    Returns:
        A dict with these keys:
        - ``PHQ9_Scores``: all nine items, each an int in 0-3.
        - ``Confidences``: nine floats in [0, 1].
        - ``Total_Score``: the sum of the item scores.
        - ``Severity``: always derived from Total_Score.
        - ``Confidence``: overall confidence in [0, 1].
        - ``High_Risk``: explicit suicidal language in any patient message, or
          a suicidal_thoughts score of 2 or more.

        Malformed model output never raises. Unusable output yields a neutral,
        low-confidence result that still honours explicit patient risk.
    """
    transcript = transcript_to_text(chat_history)
    features_json = json.dumps(features, ensure_ascii=False)
    system_prompt = (
        "You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
        "Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
        "Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
        "and High_Risk (boolean, true if any suicidal risk)."
    )
    user_prompt = (
        "Conversation transcript:"
        f"\n{transcript}\n\n"
        f"Acoustic features summary (approximate):\n{features_json}\n\n"
        "Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
        "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    messages = [
        {"role": "user", "content": system_prompt + "\n\n" + user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Use deterministic decoding to avoid CUDA sampling edge cases on some models
    gen = pipe(
        prompt,
        max_new_tokens=256,
        temperature=0.0,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    parsed = safe_json_extract(gen[0]["generated_text"])

    # Explicit patient language counts toward high risk even when the model
    # output is unusable.
    explicit_flag = _patient_mentions_suicidality(chat_history)

    raw_scores = parsed.get("PHQ9_Scores") if isinstance(parsed, dict) else None
    if not isinstance(raw_scores, dict):
        return _neutral_phq9_result(explicit_flag)

    # Build a fresh dict with exactly the nine items, so extra keys the model
    # adds cannot leak into the total. Values are clamped to 0-3 and truncated
    # to int; unparseable values count as 0.
    scores = {
        k: int(max(0.0, min(3.0, _as_float(raw_scores.get(k, 0), 0.0))))
        for k in _PHQ9_KEYS
    }

    raw_confidences = parsed.get("Confidences")
    if isinstance(raw_confidences, list) and len(raw_confidences) == len(_PHQ9_KEYS):
        confidences = [_clamp01(_as_float(c, 0.5)) for c in raw_confidences]
    else:
        confidences = [_clamp01(_as_float(parsed.get("Confidence", 0.5), 0.5))] * len(_PHQ9_KEYS)

    total = sum(scores.values())
    overall_conf = _clamp01(_as_float(parsed.get("Confidence"), min(confidences)))
    # Conservative high-risk detection: explicit patient language or a high item score.
    high_risk = explicit_flag or scores["suicidal_thoughts"] >= 2
    return {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        # Derived from the validated total so it can never disagree with
        # Total_Score or be a non-string value produced by the model.
        "Severity": severity_from_total(total),
        "Confidence": overall_conf,
        "High_Risk": bool(high_risk),
    }
def transcribe_audio(audio_path: Optional[str]) -> str:
    """Transcribe the audio file at *audio_path* with the shared ASR pipeline.

    Returns the stripped transcript. It returns "" when no path is given, when
    the pipeline output (a dict, or a non-empty list of dicts) has no ``text``
    field, or when recognition raises. Callers therefore cannot tell a failure
    from silence.
    """
    if not audio_path:
        return ""
    transcript = ""
    try:
        result = get_asr_pipeline()(audio_path)
        if isinstance(result, dict) and "text" in result:
            transcript = result["text"].strip()
        elif isinstance(result, list) and result and "text" in result[0]:
            transcript = result[0]["text"].strip()
    except Exception:
        pass
    return transcript
# ---------------------------
# Gradio app logic
# ---------------------------
# Opening clinician message, shown (and optionally spoken) at the start of every
# session. Adjacent string literals are concatenated, so every sentence except
# the last needs a trailing space.
INTRO_MESSAGE = (
    "Hi, I'm an assistant, and I will ask you some questions about how you've been doing. "
    "We'll record our conversation, and we will give you a written copy of it. "
    "From our conversation, we will send a written copy to the clinician, we will give a summary of what you are experiencing based on a questionnaire, called the Patient Health Questionnaire (PHQ-9), and we will give a summary of what your voice is like. "
    "We will send this to the clinician, and the clinician will follow up with you. "
    "To start, how has your mood been over the past couple of weeks?"
)
def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
    """Return a fresh session state as (chat_history, scores, meta, finished, turns).

    The chat history starts with INTRO_MESSAGE as the clinician's first turn,
    with an empty patient side. Scores are empty. Meta has no severity or
    total and zero confidence. The session is unfinished with zero turns taken.
    """
    return (
        [("", INTRO_MESSAGE)],
        {},
        {"Severity": None, "Total_Score": None, "Confidence": 0.0},
        False,
        0,
    )
@spaces.GPU
def process_turn(
    audio_path: Optional[str],
    text_input: Optional[str],
    chat_history: List[Tuple[str, str]],
    threshold: float,
    tts_enabled: bool,
    finished: Optional[bool],
    turns: Optional[int],
    prev_scores: Dict[str, Any],
    prev_meta: Dict[str, Any],
):
    """Run one patient turn: transcribe, score, decide whether to stop, and reply.

    Args:
        audio_path: Recorded audio file, or None. It is transcribed, and
            acoustic features are extracted from it.
        text_input: Optional typed text. Transcribed speech is appended to it.
        chat_history: Chatbot history ``[(patient, clinician), ...]``. It is
            mutated in place and also returned.
        threshold: The assessment stops once the minimum per-item confidence
            reaches this value.
        tts_enabled: When truthy, the new clinician message is synthesized.
        finished / turns: Session state. None is treated as False / 0.
        prev_scores / prev_meta: Previous state, echoed back on no-op turns.
            None is treated as an empty dict.

    Returns:
        A 9-tuple ``(chat_history, score_json, severity, finished, turns, None,
        None, tts_path, tts_path)``. The two Nones clear the audio and text
        inputs in the UI.

    The assessment ends on any of: high risk, sufficient confidence, or
    MAX_TURNS turns.
    """
    finished = bool(finished) if finished is not None else False
    turns = int(turns) if isinstance(turns, int) else 0
    prev_scores = prev_scores or {}
    prev_meta = prev_meta or {}

    # If already finished, do nothing
    if finished:
        return (
            chat_history,
            {"info": "Assessment complete."},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )

    patient_text = (text_input or "").strip()
    audio_features: Dict[str, float] = {}
    if audio_path:
        transcribed = transcribe_audio(audio_path)
        if transcribed:
            # Join typed and spoken text with a single space.
            patient_text = f"{patient_text} {transcribed}" if patient_text else transcribed
        audio_features = compute_audio_features(audio_path)

    if not patient_text:
        # Nothing usable was said or typed: ask again without consuming a turn.
        chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
        return (
            chat_history,
            prev_scores,
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )

    # Add the patient's message; the clinician reply is filled in below.
    chat_history.append((patient_text, None))

    # Scoring agent
    scoring = scoring_agent_infer(chat_history, audio_features)
    scores = scoring.get("PHQ9_Scores", {})
    confidences = scoring.get("Confidences", [])
    total = scoring.get("Total_Score", 0)
    # `or` also covers a present-but-None Severity.
    severity = scoring.get("Severity") or severity_from_total(total)
    overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))

    # High risk, kept conservative to limit false positives: explicit language in
    # the latest patient message (the one just appended) or a high item score.
    explicit_flag = detect_explicit_suicidality(patient_text)
    high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))

    # Termination conditions
    min_conf = float(min(confidences)) if confidences else 0.0
    turns += 1
    done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)

    if high_risk:
        reply = (
            "I’m concerned about your safety based on what you shared. "
            "If you are in danger or need immediate help, please call 988 in the U.S. or your local emergency number. "
            "I'll end the assessment now and display emergency resources."
        )
        finished = True
    elif done:
        reply = (
            f"Thank you for sharing. Based on our conversation, your responses suggest {str(severity).lower()}. "
            "We can stop here."
        )
        finished = True
    else:
        # Generate next clinician question
        reply = generate_recording_agent_reply(chat_history)
    chat_history[-1] = (patient_text, reply)

    # TTS for the latest clinician message, if enabled
    tts_path = synthesize_tts(reply) if tts_enabled else None

    display_json = {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity,
        "Confidence": overall_conf,
        "High_Risk": high_risk,
    }
    # The two Nones clear the audio and text inputs after processing.
    return (
        chat_history,
        display_json,
        severity,
        finished,
        turns,
        None,
        None,
        tts_path,
        tts_path,
    )
def reset_app():
    """Return a brand-new session state tuple, identical to init_state()."""
    fresh_state = init_state()
    return fresh_state
# ---------------------------
# UI
# ---------------------------
def _on_load_init():
    """Page-load handler: give the UI a fresh session state from init_state()."""
    chat_history, scores, meta, finished, turns = init_state()
    return chat_history, scores, meta, finished, turns
def _on_load_init_with_tts(tts_on: bool):
    """Return a fresh session state plus intro audio.

    The result is the five init_state() values followed by the path of the
    synthesized intro message. The path is None when *tts_on* is falsy.
    """
    state = init_state()
    if tts_on:
        # The intro is the last (and only) clinician message in a fresh history.
        intro_audio = synthesize_tts(state[0][-1][1])
    else:
        intro_audio = None
    return (*state, intro_audio)
def _play_intro_tts(tts_on: bool):
    """Synthesize INTRO_MESSAGE for the "Play Intro" button.

    Returns the audio file path. Returns None when *tts_on* is falsy or when
    synthesis raises.
    """
    if tts_on:
        try:
            return synthesize_tts(INTRO_MESSAGE)
        except Exception:
            return None
    return None
def create_demo():
    """Build and return the Gradio Blocks UI.

    Main tab: a microphone styled as a central "voice bubble". Stopping a
    recording submits the turn. A hidden, autoplaying player speaks the
    clinician's reply.

    Advanced tab:
    - the conversation transcript, the live PHQ-9 JSON and the severity label;
    - the confidence-threshold slider;
    - TTS settings (on/off, provider, Coqui model and speaker);
    - chat-model switching.
    """
    default_coqui_model = os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")
    coqui_speakers = list_coqui_speakers(default_coqui_model)
    # Prefer the historical default speaker; otherwise use the model's first
    # speaker so the dropdown value is always one of its choices.
    if "p225" in coqui_speakers:
        default_speaker = "p225"
    else:
        default_speaker = coqui_speakers[0] if coqui_speakers else None

    with gr.Blocks(
        theme=gr.themes.Soft(),
        css='''
/* Voice bubble styles - clean and centered */
#voice-bubble {
width: 240px; height: 240px; border-radius: 9999px; margin: 40px auto;
display: flex; align-items: center; justify-content: center;
background: linear-gradient(135deg, #6ee7b7 0%, #34d399 100%);
box-shadow: 0 20px 40px rgba(16,185,129,0.3), 0 0 0 1px rgba(255,255,255,0.1) inset;
transition: all 250ms cubic-bezier(0.4, 0, 0.2, 1);
cursor: default; /* green circle itself is not clickable */
pointer-events: none; /* ignore clicks on the green circle */
position: relative;
}
#voice-bubble:hover {
transform: translateY(-2px) scale(1.02);
box-shadow: 0 25px 50px rgba(16,185,129,0.4), 0 0 0 1px rgba(255,255,255,0.15) inset;
}
#voice-bubble:active { transform: translateY(0px) scale(0.98); }
#voice-bubble.listening {
animation: bubble-pulse 1.5s ease-in-out infinite;
background: linear-gradient(135deg, #fb7185 0%, #ef4444 100%);
box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 1px rgba(255,255,255,0.1) inset;
}
@keyframes bubble-pulse {
0%, 100% { transform: scale(1.0); box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 0 rgba(239,68,68,0.5); }
50% { transform: scale(1.05); box-shadow: 0 25px 50px rgba(239,68,68,0.5), 0 0 0 15px rgba(239,68,68,0.0); }
}
/* Hide microphone dropdown selector only */
#voice-bubble select { display: none !important; }
#voice-bubble .source-selection { display: none !important; }
#voice-bubble label[for] { display: none !important; }
/* Make the inner button the only clickable target */
#voice-bubble button { pointer-events: auto; cursor: pointer; }
/* Hide TTS player UI but keep it in DOM for autoplay */
#tts-player { width: 0 !important; height: 0 !important; opacity: 0 !important; position: absolute; pointer-events: none; }
'''
    ) as demo:
        gr.Markdown(
            """
### Conversational Assessment for Responsive Engagement (CARE) Notes
Tap on 'Record' to start speaking, then tap on 'Stop' to stop recording.
"""
        )
        intro_play_btn = gr.Button("▶️ Play Intro", variant="secondary")
        with gr.Tabs():
            with gr.TabItem("Main"):
                with gr.Column():
                    # Microphone component styled as central bubble (tap to record/stop)
                    audio_main = gr.Microphone(type="filepath", label=None, elem_id="voice-bubble", show_label=False)
                    # Hidden text input placeholder for pipeline compatibility
                    text_main = gr.Textbox(value="", visible=False)
                    # Autoplay clinician voice output (player hidden with CSS)
                    tts_audio_main = gr.Audio(label=None, interactive=False, autoplay=True, show_label=False, elem_id="tts-player")
            with gr.TabItem("Advanced"):
                with gr.Column():
                    chatbot = gr.Chatbot(height=360, type="tuples", label="Conversation")
                    score_json = gr.JSON(label="PHQ-9 Assessment (live)")
                    severity_label = gr.Label(label="Severity")
                    threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
                    tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
                    with gr.Row():
                        tts_provider_dd = gr.Dropdown(choices=["Coqui", "gTTS"], value="Coqui", label="TTS Provider")
                        coqui_model_tb = gr.Textbox(value=default_coqui_model, label="Coqui Model")
                        coqui_speaker_dd = gr.Dropdown(choices=coqui_speakers, value=default_speaker, label="Coqui Speaker")
                    tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
                    model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
                    with gr.Row():
                        apply_model_btn = gr.Button("Apply model (no restart)")
                        # apply_model_restart_btn = gr.Button("Apply model and restart")
                    model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")

        # App state
        chat_state = gr.State()
        scores_state = gr.State()
        meta_state = gr.State()
        finished_state = gr.State()
        turns_state = gr.State()

        # Initialize on load (no autoplay due to browser policies)
        demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
        # Explicit user gesture to play intro TTS (works across browsers)
        intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])

        # Wire interactions
        def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker):
            # Run the turn with TTS disabled. Speech is synthesized exactly once
            # below, using the provider/model/speaker from the Advanced tab.
            # Enabling TTS in process_turn would also synthesize audio (default
            # provider, no speaker) that is then discarded.
            result = process_turn(audio, text, chat, th, False, finished, turns, scores, meta)
            chat_history, display_json, severity, finished_o, turns_o, _, _, _, _ = result
            if tts_on and chat_history and chat_history[-1][1]:
                new_path = synthesize_tts(chat_history[-1][1], provider=provider, coqui_model_name=coqui_model, coqui_speaker=coqui_speaker)
            else:
                new_path = None
            return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path

        # The microphone component is itself the voice bubble, so stopping a
        # recording submits the turn directly; no extra click wiring is needed.
        audio_main.stop_recording(
            fn=_process_with_tts,
            inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main],
            queue=True,
            api_name="message",
        )

        # Model switch handlers
        def _on_apply_model(mid: str):
            msg = set_current_model_id(mid)
            return f"Current model: `{current_model_id}`\n\n{msg}"

        def _on_apply_model_restart(mid: str):
            msg = apply_model_and_restart(mid)
            return f"{msg}"

        apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
        # apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])
    return demo
demo = create_demo()

if __name__ == "__main__":
    # Local development: listen on all interfaces on $PORT (default 7860),
    # with at most 16 queued requests.
    listen_port = int(os.getenv("PORT", "7860"))
    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=listen_port)