Akis Giannoukos committed on
Commit
aec1268
·
1 Parent(s): 30f47d7

Implement Coqui TTS integration with model and speaker selection in demo interface; update requirements to include coqui-tts package.

Browse files
Files changed (2) hide show
  1. app.py +54 -5
  2. requirements.txt +1 -1
app.py CHANGED
@@ -220,12 +220,33 @@ def detect_explicit_suicidality(text: Optional[str]) -> bool:
220
  return False
221
 
222
 
223
- def synthesize_tts(text: Optional[str]) -> Optional[str]:
 
 
 
 
 
224
  if not text:
225
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  try:
227
- # Save MP3 to tmp and return filepath
228
- ts = int(time.time() * 1000)
229
  out_path = f"/tmp/tts_{ts}.mp3"
230
  tts = gTTS(text=text, lang="en")
231
  tts.save(out_path)
@@ -234,6 +255,21 @@ def synthesize_tts(text: Optional[str]) -> Optional[str]:
234
  return None
235
 
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  def severity_from_total(total_score: int) -> str:
238
  if total_score <= 4:
239
  return "Minimal Depression"
@@ -660,6 +696,10 @@ def create_demo():
660
  severity_label = gr.Label(label="Severity")
661
  threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
662
  tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
 
 
 
 
663
  tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
664
  model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
665
  with gr.Row():
@@ -681,9 +721,18 @@ def create_demo():
681
  intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])
682
 
683
  # Wire interactions
 
 
 
 
 
 
 
 
 
684
  audio_main.stop_recording(
685
- fn=process_turn,
686
- inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
687
  outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main],
688
  queue=True,
689
  api_name="message",
 
220
  return False
221
 
222
 
223
+ def synthesize_tts(
224
+ text: Optional[str],
225
+ provider: str = "Coqui",
226
+ coqui_model_name: Optional[str] = None,
227
+ coqui_speaker: Optional[str] = None,
228
+ ) -> Optional[str]:
229
  if not text:
230
  return None
231
+ ts = int(time.time() * 1000)
232
+ provider_norm = (provider or "Coqui").strip().lower()
233
+ # Try Coqui first if requested
234
+ if provider_norm == "coqui":
235
+ try:
236
+ # coqui-tts uses the same import path TTS.api
237
+ from TTS.api import TTS as CoquiTTS # type: ignore
238
+ model_name = (coqui_model_name or os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")).strip()
239
+ engine = CoquiTTS(model_name=model_name, progress_bar=False)
240
+ out_path_wav = f"/tmp/tts_{ts}.wav"
241
+ kwargs = {}
242
+ if coqui_speaker:
243
+ kwargs["speaker"] = coqui_speaker
244
+ engine.tts_to_file(text=text, file_path=out_path_wav, **kwargs)
245
+ return out_path_wav
246
+ except Exception:
247
+ pass
248
+ # Fallback to gTTS
249
  try:
 
 
250
  out_path = f"/tmp/tts_{ts}.mp3"
251
  tts = gTTS(text=text, lang="en")
252
  tts.save(out_path)
 
255
  return None
256
 
257
 
258
def list_coqui_speakers(model_name: str) -> List[str]:
    """Return the speaker names offered by a Coqui TTS model.

    The model is loaded through ``TTS.api.TTS`` and its speakers are read
    from ``engine.speakers`` or, failing that, from
    ``engine.speaker_manager.speaker_names``. This is a best-effort lookup:
    if the package cannot be imported, the model cannot be loaded, or no
    speaker names are found, a short list of VCTK speaker ids is returned
    instead, so the result is never empty.

    Args:
        model_name: Coqui model identifier, e.g. ``tts_models/en/vctk/vits``.

    Returns:
        A non-empty list of speaker names as strings.
    """
    # Local import: the file's top-level import block is not visible from here.
    import logging

    # Reasonable defaults for VCTK multi-speaker
    fallback = ["p225", "p227", "p231", "p233", "p236"]
    try:
        # coqui-tts uses the same import path TTS.api
        from TTS.api import TTS as CoquiTTS  # type: ignore

        engine = CoquiTTS(model_name=model_name, progress_bar=False)

        # Try common attributes; an empty list falls through to the next source.
        speakers = getattr(engine, "speakers", None)
        if isinstance(speakers, list) and speakers:
            return [str(s) for s in speakers]

        manager = getattr(engine, "speaker_manager", None)
        names = getattr(manager, "speaker_names", None)
        if names is not None:
            names = [str(s) for s in names]
            if names:
                return names
    except Exception as exc:
        # Deliberately best-effort: report the failure, then use the defaults
        # rather than propagating the error to the caller.
        logging.getLogger(__name__).warning(
            "Could not list Coqui speakers for model %r: %s", model_name, exc
        )
    return fallback
271
+
272
+
273
  def severity_from_total(total_score: int) -> str:
274
  if total_score <= 4:
275
  return "Minimal Depression"
 
696
  severity_label = gr.Label(label="Severity")
697
  threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
698
  tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
699
+ with gr.Row():
700
+ tts_provider_dd = gr.Dropdown(choices=["Coqui", "gTTS"], value="Coqui", label="TTS Provider")
701
+ coqui_model_tb = gr.Textbox(value=os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits"), label="Coqui Model")
702
+ coqui_speaker_dd = gr.Dropdown(choices=list_coqui_speakers(os.getenv("COQUI_MODEL", "tts_models/en/vctk/vits")), value="p225", label="Coqui Speaker")
703
  tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
704
  model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
705
  with gr.Row():
 
721
  intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])
722
 
723
  # Wire interactions
724
def _process_with_tts(audio, text, chat, th, tts_on, finished, turns, scores, meta, provider, coqui_model, coqui_speaker):
    """Run one turn through ``process_turn`` and voice the latest reply.

    The first five values returned by ``process_turn`` are passed through
    unchanged. The sixth and seventh outputs are always ``None``. The last
    two outputs both carry the path returned by ``synthesize_tts`` for the
    newest reply, or ``None`` when TTS is off or there is no reply text.

    ``provider``, ``coqui_model`` and ``coqui_speaker`` are forwarded to
    ``synthesize_tts`` as ``provider``, ``coqui_model_name`` and
    ``coqui_speaker``.
    """
    turn_outputs = process_turn(audio, text, chat, th, tts_on, finished, turns, scores, meta)
    # process_turn is expected to yield exactly nine values; the last four
    # are discarded and rebuilt below.
    history, scores_view, severity_text, is_finished, turn_count, _, _, _, _ = turn_outputs

    speech_path = None
    if tts_on and history:
        # Second element of the newest history entry is the text to speak.
        reply = history[-1][1]
        if reply:
            speech_path = synthesize_tts(
                reply,
                provider=provider,
                coqui_model_name=coqui_model,
                coqui_speaker=coqui_speaker,
            )

    return (
        history,
        scores_view,
        severity_text,
        is_finished,
        turn_count,
        None,
        None,
        speech_path,
        speech_path,
    )
732
+
733
  audio_main.stop_recording(
734
+ fn=_process_with_tts,
735
+ inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state, tts_provider_dd, coqui_model_tb, coqui_speaker_dd],
736
  outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main],
737
  queue=True,
738
  api_name="message",
requirements.txt CHANGED
@@ -10,4 +10,4 @@ scipy>=1.11.4
10
  protobuf>=4.25.3
11
  gTTS>=2.5.3
12
  spaces>=0.27.1
13
-
 
10
  protobuf>=4.25.3
11
  gTTS>=2.5.3
12
  spaces>=0.27.1
13
+ coqui-tts>=0.27.2