nazdridoy committed on
Commit
43333ad
·
verified ·
1 Parent(s): f76cd43

feat(tts): add dynamic model parameters and Fal.ai Chatterbox

Browse files

- [feat] Define TTS_MODEL_CONFIGS, TTS_EXAMPLE_AUDIO_URLS, and add "Chatterbox (Fal.ai)" preset (utils.py:75-90, 150-155)
- [feat] Update `generate_text_to_speech()` to accept dynamic parameters and conditionally include `extra_body` (tts_handler.py:29-33, 59-76)
- [feat] Update `handle_text_to_speech_generation()` to pass new TTS arguments (tts_handler.py:148, 158-161)
- [feat] Implement dynamic UI for TTS model selection and parameter inputs with `gr.Dropdown` and `gr.Group` components (ui_components.py:446-474)
- [feat] Add `on_model_change()` to dynamically update group visibility on model selection (ui_components.py:512-527)
- [feat] Extend generate button inputs and add Chatterbox examples (ui_components.py:533-535, 559-566)

Files changed (3) hide show
  1. tts_handler.py +32 -8
  2. ui_components.py +64 -11
  3. utils.py +25 -0
tts_handler.py CHANGED
@@ -15,7 +15,8 @@ from utils import (
15
  IMAGE_CONFIG,
16
  validate_proxy_key,
17
  format_error_message,
18
- format_success_message
 
19
  )
20
 
21
  # Timeout configuration for TTS generation
@@ -26,8 +27,12 @@ def generate_text_to_speech(
26
  text: str,
27
  model_name: str,
28
  provider: str,
29
- voice: str = "am_eric",
30
  speed: float = 1.0,
 
 
 
 
31
  ):
32
  """
33
  Generate speech from text using the specified model and provider through HF-Inferoxy.
@@ -56,16 +61,31 @@ def generate_text_to_speech(
56
 
57
  print(f"🚀 TTS: Client created, preparing generation params...")
58
 
 
 
 
 
59
  # Prepare generation parameters
60
  generation_params = {
61
  "text": text,
62
  "model": model_name,
63
- "extra_body": {
64
- "voice": voice,
65
- "speed": speed
66
- }
67
  }
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  print(f"📡 TTS: Making generation request with {TTS_GENERATION_TIMEOUT}s timeout...")
70
 
71
  # Create generation function for timeout handling
@@ -133,7 +153,7 @@ def generate_text_to_speech(
133
  return None, format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
134
 
135
 
136
- def handle_text_to_speech_generation(text_val, model_val, provider_val, voice_val, speed_val):
137
  """
138
  Handle text-to-speech generation request with validation.
139
  """
@@ -151,5 +171,9 @@ def handle_text_to_speech_generation(text_val, model_val, provider_val, voice_va
151
  model_name=model_val,
152
  provider=provider_val,
153
  voice=voice_val,
154
- speed=speed_val
 
 
 
 
155
  )
 
15
  IMAGE_CONFIG,
16
  validate_proxy_key,
17
  format_error_message,
18
+ format_success_message,
19
+ TTS_MODEL_CONFIGS
20
  )
21
 
22
  # Timeout configuration for TTS generation
 
27
  text: str,
28
  model_name: str,
29
  provider: str,
30
+ voice: str = "af_bella",
31
  speed: float = 1.0,
32
+ audio_url: str = "",
33
+ exaggeration: float = 0.25,
34
+ temperature: float = 0.7,
35
+ cfg: float = 0.5,
36
  ):
37
  """
38
  Generate speech from text using the specified model and provider through HF-Inferoxy.
 
61
 
62
  print(f"🚀 TTS: Client created, preparing generation params...")
63
 
64
+ # Get model configuration
65
+ model_config = TTS_MODEL_CONFIGS.get(model_name, {})
66
+ extra_body_params = model_config.get("extra_body_params", [])
67
+
68
  # Prepare generation parameters
69
  generation_params = {
70
  "text": text,
71
  "model": model_name,
72
+ "extra_body": {}
 
 
 
73
  }
74
 
75
+ # Add model-specific parameters to extra_body
76
+ if "voice" in extra_body_params:
77
+ generation_params["extra_body"]["voice"] = voice
78
+ if "speed" in extra_body_params:
79
+ generation_params["extra_body"]["speed"] = speed
80
+ if "audio_url" in extra_body_params:
81
+ generation_params["extra_body"]["audio_url"] = audio_url
82
+ if "exaggeration" in extra_body_params:
83
+ generation_params["extra_body"]["exaggeration"] = exaggeration
84
+ if "temperature" in extra_body_params:
85
+ generation_params["extra_body"]["temperature"] = temperature
86
+ if "cfg" in extra_body_params:
87
+ generation_params["extra_body"]["cfg"] = cfg
88
+
89
  print(f"📡 TTS: Making generation request with {TTS_GENERATION_TIMEOUT}s timeout...")
90
 
91
  # Create generation function for timeout handling
 
153
  return None, format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
154
 
155
 
156
+ def handle_text_to_speech_generation(text_val, model_val, provider_val, voice_val, speed_val, audio_url_val, exaggeration_val, temperature_val, cfg_val):
157
  """
158
  Handle text-to-speech generation request with validation.
159
  """
 
171
  model_name=model_val,
172
  provider=provider_val,
173
  voice=voice_val,
174
+ speed=speed_val,
175
+ audio_url=audio_url_val,
176
+ exaggeration=exaggeration_val,
177
+ temperature=temperature_val,
178
+ cfg=cfg_val
179
  )
ui_components.py CHANGED
@@ -9,8 +9,8 @@ from utils import (
9
  DEFAULT_IMAGE_TO_IMAGE_MODEL, DEFAULT_IMAGE_TO_IMAGE_PROVIDER,
10
  DEFAULT_TTS_MODEL, DEFAULT_TTS_PROVIDER,
11
  CHAT_CONFIG, IMAGE_CONFIG, IMAGE_PROVIDERS, IMAGE_MODEL_PRESETS,
12
- IMAGE_TO_IMAGE_MODEL_PRESETS, TTS_MODEL_PRESETS, TTS_VOICES,
13
- IMAGE_EXAMPLE_PROMPTS, IMAGE_TO_IMAGE_EXAMPLE_PROMPTS, TTS_EXAMPLE_TEXTS
14
  )
15
 
16
 
@@ -412,7 +412,7 @@ def create_image_to_image_tab(handle_image_to_image_generation_fn):
412
 
413
  def create_tts_tab(handle_tts_generation_fn):
414
  """
415
- Create the text-to-speech tab interface.
416
  """
417
  with gr.Tab("🎤 Text-to-Speech", id="tts"):
418
  with gr.Row():
@@ -430,7 +430,7 @@ def create_tts_tab(handle_tts_generation_fn):
430
  label="Generated Audio",
431
  type="numpy",
432
  interactive=False,
433
- autoplay=False,
434
  show_download_button=True
435
  )
436
  status_text = gr.Textbox(
@@ -443,10 +443,11 @@ def create_tts_tab(handle_tts_generation_fn):
443
  # Model and provider inputs
444
  with gr.Group():
445
  gr.Markdown("**🤖 Model & Provider**")
446
- tts_model_name = gr.Textbox(
 
447
  value=DEFAULT_TTS_MODEL,
448
- label="Model Name",
449
- placeholder="e.g., hexgrad/Kokoro-82M"
450
  )
451
  tts_provider = gr.Dropdown(
452
  choices=IMAGE_PROVIDERS,
@@ -455,9 +456,9 @@ def create_tts_tab(handle_tts_generation_fn):
455
  interactive=True
456
  )
457
 
458
- # Voice and speed settings
459
- with gr.Group():
460
- gr.Markdown("**🎤 Voice Settings**")
461
  tts_voice = gr.Dropdown(
462
  choices=list(TTS_VOICES.items()),
463
  value="af_bella",
@@ -469,6 +470,28 @@ def create_tts_tab(handle_tts_generation_fn):
469
  label="Speed", info="0.5 = slow, 2.0 = fast"
470
  )
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  # Generate and Stop buttons
473
  with gr.Row():
474
  generate_btn = gr.Button(
@@ -484,6 +507,25 @@ def create_tts_tab(handle_tts_generation_fn):
484
 
485
  # Examples for TTS generation
486
  create_tts_examples(tts_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
 
488
  # Connect TTS generation events
489
  # Show stop immediately when starting generation
@@ -497,7 +539,8 @@ def create_tts_tab(handle_tts_generation_fn):
497
  gen_event = generate_btn.click(
498
  fn=handle_tts_generation_fn,
499
  inputs=[
500
- tts_text, tts_model_name, tts_provider, tts_voice, tts_speed
 
501
  ],
502
  outputs=[output_audio, status_text]
503
  )
@@ -561,6 +604,16 @@ def create_tts_examples(tts_text):
561
  )
562
 
563
 
 
 
 
 
 
 
 
 
 
 
564
  def create_image_presets(img_model_name, img_provider):
565
  """Create quick model presets for image generation."""
566
  with gr.Group():
 
9
  DEFAULT_IMAGE_TO_IMAGE_MODEL, DEFAULT_IMAGE_TO_IMAGE_PROVIDER,
10
  DEFAULT_TTS_MODEL, DEFAULT_TTS_PROVIDER,
11
  CHAT_CONFIG, IMAGE_CONFIG, IMAGE_PROVIDERS, IMAGE_MODEL_PRESETS,
12
+ IMAGE_TO_IMAGE_MODEL_PRESETS, TTS_MODEL_PRESETS, TTS_VOICES, TTS_MODEL_CONFIGS,
13
+ IMAGE_EXAMPLE_PROMPTS, IMAGE_TO_IMAGE_EXAMPLE_PROMPTS, TTS_EXAMPLE_TEXTS, TTS_EXAMPLE_AUDIO_URLS
14
  )
15
 
16
 
 
412
 
413
  def create_tts_tab(handle_tts_generation_fn):
414
  """
415
+ Create the text-to-speech tab interface with dynamic model-specific settings.
416
  """
417
  with gr.Tab("🎤 Text-to-Speech", id="tts"):
418
  with gr.Row():
 
430
  label="Generated Audio",
431
  type="numpy",
432
  interactive=False,
433
+ autoplay=True,
434
  show_download_button=True
435
  )
436
  status_text = gr.Textbox(
 
443
  # Model and provider inputs
444
  with gr.Group():
445
  gr.Markdown("**🤖 Model & Provider**")
446
+ tts_model_name = gr.Dropdown(
447
+ choices=["hexgrad/Kokoro-82M", "ResembleAI/chatterbox"],
448
  value=DEFAULT_TTS_MODEL,
449
+ label="Model",
450
+ info="Select TTS model"
451
  )
452
  tts_provider = gr.Dropdown(
453
  choices=IMAGE_PROVIDERS,
 
456
  interactive=True
457
  )
458
 
459
+ # Kokoro-specific settings (initially visible)
460
+ with gr.Group(visible=True) as kokoro_settings:
461
+ gr.Markdown("**🎤 Kokoro Voice Settings**")
462
  tts_voice = gr.Dropdown(
463
  choices=list(TTS_VOICES.items()),
464
  value="af_bella",
 
470
  label="Speed", info="0.5 = slow, 2.0 = fast"
471
  )
472
 
473
+ # Chatterbox-specific settings (initially hidden)
474
+ with gr.Group(visible=False) as chatterbox_settings:
475
+ gr.Markdown("**🎭 Chatterbox Style Settings**")
476
+ tts_audio_url = gr.Textbox(
477
+ value=TTS_EXAMPLE_AUDIO_URLS[0],
478
+ label="Reference Audio URL",
479
+ placeholder="Enter URL to reference audio file",
480
+ info="Audio file to match style and tone"
481
+ )
482
+ tts_exaggeration = gr.Slider(
483
+ minimum=0.0, maximum=1.0, value=0.25, step=0.05,
484
+ label="Exaggeration", info="How much to exaggerate the style"
485
+ )
486
+ tts_temperature = gr.Slider(
487
+ minimum=0.0, maximum=1.0, value=0.7, step=0.1,
488
+ label="Temperature", info="Creativity level"
489
+ )
490
+ tts_cfg = gr.Slider(
491
+ minimum=0.0, maximum=1.0, value=0.5, step=0.1,
492
+ label="CFG", info="Guidance strength"
493
+ )
494
+
495
  # Generate and Stop buttons
496
  with gr.Row():
497
  generate_btn = gr.Button(
 
507
 
508
  # Examples for TTS generation
509
  create_tts_examples(tts_text)
510
+
511
+ # Create Chatterbox audio URL examples
512
+ create_chatterbox_examples(tts_audio_url)
513
+
514
+ # Model change handler to show/hide appropriate settings
515
+ def on_model_change(model_name):
516
+ if model_name == "hexgrad/Kokoro-82M":
517
+ return gr.update(visible=True), gr.update(visible=False)
518
+ elif model_name == "ResembleAI/chatterbox":
519
+ return gr.update(visible=False), gr.update(visible=True)
520
+ else:
521
+ return gr.update(visible=True), gr.update(visible=False)
522
+
523
+ # Connect model change event
524
+ tts_model_name.change(
525
+ fn=on_model_change,
526
+ inputs=[tts_model_name],
527
+ outputs=[kokoro_settings, chatterbox_settings]
528
+ )
529
 
530
  # Connect TTS generation events
531
  # Show stop immediately when starting generation
 
539
  gen_event = generate_btn.click(
540
  fn=handle_tts_generation_fn,
541
  inputs=[
542
+ tts_text, tts_model_name, tts_provider, tts_voice, tts_speed,
543
+ tts_audio_url, tts_exaggeration, tts_temperature, tts_cfg
544
  ],
545
  outputs=[output_audio, status_text]
546
  )
 
604
  )
605
 
606
 
607
+ def create_chatterbox_examples(tts_audio_url):
608
+ """Create example audio URLs for Chatterbox TTS."""
609
+ with gr.Group():
610
+ gr.Markdown("**🎵 Example Reference Audio URLs**")
611
+ chatterbox_examples = gr.Examples(
612
+ examples=[[url] for url in TTS_EXAMPLE_AUDIO_URLS],
613
+ inputs=tts_audio_url
614
+ )
615
+
616
+
617
  def create_image_presets(img_model_name, img_provider):
618
  """Create quick model presets for image generation."""
619
  with gr.Group():
utils.py CHANGED
@@ -72,8 +72,25 @@ IMAGE_TO_IMAGE_MODEL_PRESETS = [
72
  TTS_MODEL_PRESETS = [
73
  ("Kokoro (Fal.ai)", "hexgrad/Kokoro-82M", "fal-ai"),
74
  ("Kokoro (Replicate)", "hexgrad/Kokoro-82M", "replicate"),
 
75
  ]
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Voice options for Kokoro TTS (based on the reference app)
78
  TTS_VOICES = {
79
  '🇺🇸 🚺 Heart ❤️': 'af_heart',
@@ -142,6 +159,14 @@ TTS_EXAMPLE_TEXTS = [
142
  "Life is what happens when you're busy making other plans. Embrace every moment with gratitude."
143
  ]
144
 
 
 
 
 
 
 
 
 
145
 
146
  def get_proxy_key():
147
  """Get the proxy API key from environment variables."""
 
72
  TTS_MODEL_PRESETS = [
73
  ("Kokoro (Fal.ai)", "hexgrad/Kokoro-82M", "fal-ai"),
74
  ("Kokoro (Replicate)", "hexgrad/Kokoro-82M", "replicate"),
75
+ ("Chatterbox (Fal.ai)", "ResembleAI/chatterbox", "fal-ai"),
76
  ]
77
 
78
+ # Model-specific configurations for TTS
79
+ TTS_MODEL_CONFIGS = {
80
+ "hexgrad/Kokoro-82M": {
81
+ "type": "kokoro",
82
+ "supports_voice": True,
83
+ "supports_speed": True,
84
+ "extra_body_params": ["voice", "speed"]
85
+ },
86
+ "ResembleAI/chatterbox": {
87
+ "type": "chatterbox",
88
+ "supports_voice": False,
89
+ "supports_speed": False,
90
+ "extra_body_params": ["audio_url", "exaggeration", "temperature", "cfg"]
91
+ }
92
+ }
93
+
94
  # Voice options for Kokoro TTS (based on the reference app)
95
  TTS_VOICES = {
96
  '🇺🇸 🚺 Heart ❤️': 'af_heart',
 
159
  "Life is what happens when you're busy making other plans. Embrace every moment with gratitude."
160
  ]
161
 
162
+ # Example audio URLs for Chatterbox TTS
163
+ TTS_EXAMPLE_AUDIO_URLS = [
164
+ "https://github.com/nazdridoy/kokoro-tts/raw/main/previews/demo.mp3",
165
+ "https://huggingface.co/datasets/hf-internal-testing/fixtures/resolve/main/audio/sample_audio_1.mp3",
166
+ "https://huggingface.co/datasets/hf-internal-testing/fixtures/resolve/main/audio/sample_audio_2.mp3",
167
+ "https://www.soundjay.com/misc/sounds/bell-ringing-05.wav"
168
+ ]
169
+
170
 
171
  def get_proxy_key():
172
  """Get the proxy API key from environment variables."""