Spaces:
Running
on
Zero
Running
on
Zero
Akis Giannoukos
committed on
Commit
·
aec1268
1
Parent(s):
30f47d7
Implement Coqui TTS integration with model and speaker selection in demo interface; update requirements to include coqui-tts package.
Browse files- app.py +54 -5
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -220,12 +220,33 @@ def detect_explicit_suicidality(text: Optional[str]) -> bool:
|
|
| 220 |
return False
|
| 221 |
|
| 222 |
|
| 223 |
-
def synthesize_tts(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
if not text:
|
| 225 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
try:
|
| 227 |
-
# Save MP3 to tmp and return filepath
|
| 228 |
-
ts = int(time.time() * 1000)
|
| 229 |
out_path = f"/tmp/tts_{ts}.mp3"
|
| 230 |
tts = gTTS(text=text, lang="en")
|
| 231 |
tts.save(out_path)
|
|
@@ -234,6 +255,21 @@ def synthesize_tts(text: Optional[str]) -> Optional[str]:
|
|
| 234 |
return None
|
| 235 |
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
def severity_from_total(total_score: int) -> str:
|
| 238 |
if total_score <= 4:
|
| 239 |
return "Minimal Depression"
|
|
@@ -660,6 +696,10 @@ def create_demo():
|
|
| 660 |
severity_label = gr.Label(label="Severity")
|
| 661 |
threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
|
| 662 |
tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
|
| 664 |
model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
|
| 665 |
with gr.Row():
|
|
@@ -681,9 +721,18 @@ def create_demo():
|
|
| 681 |
intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])
|
| 682 |
|
| 683 |
# Wire interactions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 684 |
audio_main.stop_recording(
|
| 685 |
-
fn=
|
| 686 |
-
inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
|
| 687 |
outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main],
|
| 688 |
queue=True,
|
| 689 |
api_name="message",
|
|
|
|
| 220 |
return False
|
| 221 |
|
| 222 |
|
| 223 |
+
def synthesize_tts(
|
| 224 |
+
text: Optional[str],
|
| 225 |
+
provider: str = "Coqui",
|
| 226 |
+
coqui_model_name: Optional[str] = None,
|
| 227 |
+
coqui_speaker: Optional[str] = None,
|
| 228 |
+
) -> Optional[str]:
|
| 229 |
if not text:
|
| 230 |
return None
|
| 231 |
+
ts = int(time.time() * 1000)
|
| 232 |
+
provider_norm = (provider or "Coqui").strip().lower()
|
| 233 |
+
# Try Coqui first if requested
|
| 234 |
+
if provider_norm == "coqui":
|
| 235 |
+
try:
|
| 236 |
+
# coqui-tts uses the same import path TTS.api
|
| 237 |
+
from TTS.api import TTS as CoquiTTS # type: ignore
|
| 238 |
+
model_name = (coqui_model_name or os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")).strip()
|
| 239 |
+
engine = CoquiTTS(model_name=model_name, progress_bar=False)
|
| 240 |
+
out_path_wav = f"/tmp/tts_{ts}.wav"
|
| 241 |
+
kwargs = {}
|
| 242 |
+
if coqui_speaker:
|
| 243 |
+
kwargs["speaker"] = coqui_speaker
|
| 244 |
+
engine.tts_to_file(text=text, file_path=out_path_wav, **kwargs)
|
| 245 |
+
return out_path_wav
|
| 246 |
+
except Exception:
|
| 247 |
+
pass
|
| 248 |
+
# Fallback to gTTS
|
| 249 |
try:
|
|
|
|
|
|
|
| 250 |
out_path = f"/tmp/tts_{ts}.mp3"
|
| 251 |
tts = gTTS(text=text, lang="en")
|
| 252 |
tts.save(out_path)
|
|
|
|
| 255 |
return None
|
| 256 |
|
| 257 |
|
| 258 |
+
def list_coqui_speakers(model_name: str) -> List[str]:
    """Return the speaker IDs available for a Coqui TTS model.

    Loads the model and probes the attributes Coqui commonly exposes for
    multi-speaker models. Any failure — including coqui-tts not being
    installed — silently falls back to a fixed set of VCTK speaker IDs so
    the UI dropdown always has choices.
    """
    fallback_ids = ["p225", "p227", "p231", "p233", "p236"]
    try:
        # coqui-tts keeps the historical import path TTS.api
        from TTS.api import TTS as CoquiTTS  # type: ignore

        engine = CoquiTTS(model_name=model_name, progress_bar=False)
        speakers = getattr(engine, "speakers", None)
        if isinstance(speakers, list):
            return [str(speaker) for speaker in speakers]
        manager = getattr(engine, "speaker_manager", None)
        names = getattr(manager, "speaker_names", None)
        if names is not None:
            return list(names)
    except Exception:
        # Best effort only: fall through to the VCTK defaults below.
        pass
    return fallback_ids
|
| 271 |
+
|
| 272 |
+
|
| 273 |
def severity_from_total(total_score: int) -> str:
|
| 274 |
if total_score <= 4:
|
| 275 |
return "Minimal Depression"
|
|
|
|
| 696 |
severity_label = gr.Label(label="Severity")
|
| 697 |
threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
|
| 698 |
tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
|
| 699 |
+
with gr.Row():
|
| 700 |
+
tts_provider_dd = gr.Dropdown(choices=["Coqui", "gTTS"], value="Coqui", label="TTS Provider")
|
| 701 |
+
coqui_model_tb = gr.Textbox(value=os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits"), label="Coqui Model")
|
| 702 |
+
coqui_speaker_dd = gr.Dropdown(choices=list_coqui_speakers(os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")), value="p225", label="Coqui Speaker")
|
| 703 |
tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
|
| 704 |
model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
|
| 705 |
with gr.Row():
|
|
|
|
| 721 |
intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])
|
| 722 |
|
| 723 |
# Wire interactions
|
| 724 |
+
def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker):
    """Run one conversation turn, then voice the clinician reply with the chosen TTS provider.

    Wraps process_turn so the provider/model/speaker dropdown values reach
    synthesize_tts; the audio/text inputs are cleared (None) and the synthesized
    path is routed to both audio output components.
    """
    (chat_history, display_json, severity, finished_o, turns_o,
     _audio_in, _text_in, _tts_path, _tts_main_path) = process_turn(
        audio, text, chat, th, tts_on, finished, turns, scores, meta
    )
    new_path = None
    if tts_on and chat_history:
        last_reply = chat_history[-1][1]
        if last_reply:
            new_path = synthesize_tts(
                last_reply,
                provider=provider,
                coqui_model_name=coqui_model,
                coqui_speaker=coqui_speaker,
            )
    return chat_history, display_json, severity, finished_o, turns_o, None, None, new_path, new_path
|
| 732 |
+
|
| 733 |
audio_main.stop_recording(
|
| 734 |
+
fn=_process_with_tts,
|
| 735 |
+
inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
|
| 736 |
outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main],
|
| 737 |
queue=True,
|
| 738 |
api_name="message",
|
requirements.txt
CHANGED
|
@@ -10,4 +10,4 @@ scipy>=1.11.4
|
|
| 10 |
protobuf>=4.25.3
|
| 11 |
gTTS>=2.5.3
|
| 12 |
spaces>=0.27.1
|
| 13 |
-
|
|
|
|
| 10 |
protobuf>=4.25.3
|
| 11 |
gTTS>=2.5.3
|
| 12 |
spaces>=0.27.1
|
| 13 |
+
coqui-tts>=0.27.2
|