Spaces:
Running
Running
Zachary Greathouse
Update probability distribution for provider selection for comparison. (#23)
e07c653
unverified
| # Standard Library Imports | |
| import asyncio | |
| import random | |
| from typing import Tuple | |
| # Local Application Imports | |
| from src.common import Config, Option, OptionMap, TTSProviderName, logger | |
| from src.common.constants import ELEVENLABS, HUME_AI, OPENAI | |
| from src.integrations import ( | |
| text_to_speech_with_elevenlabs, | |
| text_to_speech_with_hume, | |
| text_to_speech_with_openai, | |
| ) | |
class TTSService:
    """
    Service for coordinating text-to-speech generation across different providers.

    This class handles the logic for selecting TTS providers, making concurrent API calls,
    and processing the responses into a unified format for the frontend.
    """

    # Provider pairs eligible when the text is unmodified, each stored next to its relative
    # selection weight. Keeping pair and weight in one tuple means the distribution can be
    # edited without two parallel lists drifting out of order. Equal weights -> 1/3 per pair.
    _PROVIDER_PAIR_WEIGHTS = (
        ((HUME_AI, OPENAI), 1),
        ((HUME_AI, ELEVENLABS), 1),
        ((OPENAI, ELEVENLABS), 1),
    )

    def __init__(self, config: Config):
        """
        Initialize the TTS service with application configuration.

        Args:
            config (Config): Application configuration containing API settings; it is
                forwarded unchanged to every provider synthesis call.
        """
        self.config = config
        # Dispatch table from provider name to the function that performs synthesis with it.
        self.tts_provider_functions = {
            HUME_AI: text_to_speech_with_hume,
            ELEVENLABS: text_to_speech_with_elevenlabs,
            OPENAI: text_to_speech_with_openai,
        }

    def __select_providers(self, text_modified: bool) -> Tuple[TTSProviderName, TTSProviderName]:
        """
        Select 2 TTS providers based on whether the text has been modified.

        Probabilities (defined by ``_PROVIDER_PAIR_WEIGHTS``):
            - 1/3 HUME_AI & OPENAI
            - 1/3 HUME_AI & ELEVENLABS
            - 1/3 OPENAI & ELEVENLABS

        If ``text_modified`` is True, the result is always (HUME_AI, HUME_AI).

        Args:
            text_modified (bool): A flag indicating whether the text has been modified.

        Returns:
            tuple: A tuple (TTSProviderName, TTSProviderName).
        """
        if text_modified:
            return HUME_AI, HUME_AI

        # Split the combined table into the population and weight sequences random.choices expects.
        provider_pairs, weights = zip(*self._PROVIDER_PAIR_WEIGHTS)
        return random.choices(provider_pairs, weights=weights, k=1)[0]

    async def synthesize_speech(
        self,
        character_description: str,
        text: str,
        text_modified: bool
    ) -> OptionMap:
        """
        Generate speech for the given text using two different TTS providers.

        This method selects appropriate providers based on the text modification status,
        makes concurrent API calls to those providers, and returns the results in random order.

        Args:
            character_description (str): Description of the character/voice for synthesis.
            text (str): The text to synthesize into speech.
            text_modified (bool): Whether the text has been modified from the original.

        Returns:
            OptionMap: A mapping of shuffled TTS options ("option_a", "option_b"), where each
                option includes its provider, audio file path, and generation ID.

        Raises:
            Any exception raised by either provider call is propagated unchanged; the other,
            still-running provider call is cancelled first.
        """
        provider_a, provider_b = self.__select_providers(text_modified)
        logger.info(f"Starting speech synthesis with providers: {provider_a} and {provider_b}")

        coroutine_a = self.tts_provider_functions[provider_a](character_description, text, self.config)
        coroutine_b = self.tts_provider_functions[provider_b](character_description, text, self.config)
        # Explicit tasks so both calls run concurrently and can be cancelled individually.
        task_a = asyncio.ensure_future(coroutine_a)
        task_b = asyncio.ensure_future(coroutine_b)
        try:
            (generation_id_a, audio_a), (generation_id_b, audio_b) = await asyncio.gather(task_a, task_b)
        finally:
            # gather() does not cancel the remaining awaitables when one of them raises, so a
            # failure would otherwise leave the other API call running unobserved. cancel() is
            # a no-op on tasks that have already finished (the normal success path).
            for task in (task_a, task_b):
                task.cancel()

        logger.info(f"Synthesis succeeded for providers: {provider_a} and {provider_b}")

        option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
        option_b = Option(provider=provider_b, audio=audio_b, generation_id=generation_id_b)

        # Shuffle so the position of each provider in the response reveals nothing about it.
        options = [option_a, option_b]
        random.shuffle(options)
        shuffled_option_a, shuffled_option_b = options

        return {
            "option_a": {
                "provider": shuffled_option_a.provider,
                "generation_id": shuffled_option_a.generation_id,
                "audio_file_path": shuffled_option_a.audio,
            },
            "option_b": {
                "provider": shuffled_option_b.provider,
                "generation_id": shuffled_option_b.generation_id,
                "audio_file_path": shuffled_option_b.audio,
            },
        }