Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| import os | |
| import pathlib | |
| import tempfile | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| from fairseq2.assets import InProcAssetMetadataProvider, asset_store | |
| from fairseq2.data import Collater, SequenceData, VocabularyInfo | |
| from fairseq2.data.audio import ( | |
| AudioDecoder, | |
| WaveformToFbankConverter, | |
| WaveformToFbankOutput, | |
| ) | |
| from seamless_communication.inference import SequenceGeneratorOptions | |
| from fairseq2.generation import NGramRepeatBlockProcessor | |
| from fairseq2.memory import MemoryBlock | |
| from fairseq2.typing import DataType, Device | |
| from huggingface_hub import snapshot_download | |
| from seamless_communication.inference import BatchedSpeechOutput, Translator, SequenceGeneratorOptions | |
| from seamless_communication.models.generator.loader import load_pretssel_vocoder_model | |
| from seamless_communication.models.unity import ( | |
| UnitTokenizer, | |
| load_gcmvn_stats, | |
| load_unity_text_tokenizer, | |
| load_unity_unit_tokenizer, | |
| ) | |
| from torch.nn import Module | |
| from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import PretsselGenerator | |
| from utils import LANGUAGE_CODE_TO_NAME | |
DESCRIPTION = """\
# Seamless Expressive
[SeamlessExpressive](https://github.com/facebookresearch/seamless_communication/blob/main/docs/expressive/README.md) is a speech-to-speech translation model that captures certain underexplored aspects of prosody such as speech rate and pauses, while preserving the style of one's voice and high content translation quality.
"""

# Pre-compute the example outputs only when explicitly requested via the
# CACHE_EXAMPLES=1 environment variable AND a GPU is available.
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()

# Local directory holding the checkpoints referenced by the asset metadata below.
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))

# Hub repositories whose snapshots are merged into CHECKPOINTS_PATH, each paired
# with one checkpoint file that the asset metadata below expects to find there
# (NOTE(review): pairing assumed from the file names; a wrong pairing only causes
# an extra download on first start, never a skipped one).
_CHECKPOINT_REPOS = {
    "facebook/seamless-expressive": "m2m_expressive_unity.pt",
    "facebook/seamless-m4t-v2-large": "seamlessM4T_v2_large.pt",
}

# Check per repository instead of "does the directory exist". Once the first
# download has created the directory, a failed second download would otherwise
# never be retried and every later start-up would fail while loading the models.
for _repo_id, _marker_file in _CHECKPOINT_REPOS.items():
    if not (CHECKPOINTS_PATH / _marker_file).exists():
        snapshot_download(repo_id=_repo_id, repo_type="model", local_dir=CHECKPOINTS_PATH)
# Ensure that we do not have any other environment resolvers and always return
# "demo" for demo purposes. With the environment pinned to "demo", the
# "<name>@demo" entries registered below are the ones that take effect.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")
# Construct an `InProcAssetMetadataProvider` with environment-specific metadata
# that just overrides the regular metadata for the "demo" environment. Note the "@demo" suffix.
# Each entry points an asset card's checkpoint (and, where given, its character
# tokenizer) at a local file under CHECKPOINTS_PATH, the directory populated by the
# snapshot downloads above. Presumably this replaces the card's default download
# location; confirm against the fairseq2 asset-store docs.
demo_metadata = [
    {
        "name": "seamless_expressivity@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/m2m_expressive_unity.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_pretssel@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/pretssel_melhifigan_wm-final.pt",
    },
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
]
# Register the overrides. This must run before any model/tokenizer is loaded by card name.
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
# Reverse lookup table: human-readable language name -> language code.
LANGUAGE_NAME_TO_CODE = dict(zip(LANGUAGE_CODE_TO_NAME.values(), LANGUAGE_CODE_TO_NAME.keys()))

# Run in half precision on the first GPU when one exists, otherwise fall back to
# full precision on the CPU.
device, dtype = (
    (torch.device("cuda:0"), torch.float16)
    if torch.cuda.is_available()
    else (torch.device("cpu"), torch.float32)
)
# Asset card names of the expressive translation model and its PRETSSEL vocoder.
MODEL_NAME = "seamless_expressivity"
VOCODER_NAME = "vocoder_pretssel"
# used for ASR for toxicity: `run` uses this SeamlessM4T v2 translator (no vocoder
# loaded) to transcribe the input. The transcript is passed to `translator.predict`
# as `src_text`.
m4t_translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
)
# Unit tokenizer of the expressive model; its vocab info configures `pretssel_generator`.
unit_tokenizer = load_unity_unit_tokenizer(MODEL_NAME)
# Mean/std feature statistics shipped with the vocoder card, moved to the model's
# device/dtype. `normalize_fbank` uses them to build the prosody-encoder input.
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(VOCODER_NAME)
gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
# Expressive translator. No vocoder is loaded here; speech is synthesized
# separately by `pretssel_generator`.
# NOTE(review): apply_mintox=False, yet `run` still transcribes the input with
# `m4t_translator` and passes the transcript as `src_text` "for mintox check". If
# src_text is only consulted when mintox is enabled, that ASR pass is wasted work.
# Confirm the intent: either enable mintox or drop the ASR step.
translator = Translator(
    MODEL_NAME,
    vocoder_name_or_card=None,
    device=device,
    dtype=dtype,
    apply_mintox=False,
)
def _make_text_generation_opts(soft_max_seq_len):
    """Build the beam-search options used for text decoding.

    Both translators share beam size 5, ``unk_penalty=torch.inf`` and an
    ``NGramRepeatBlockProcessor`` with ``ngram_size=10``; only the
    ``soft_max_seq_len`` pair differs. Each call creates its own processor
    instance, so the two option objects never share one.
    """
    return SequenceGeneratorOptions(
        beam_size=5,
        unk_penalty=torch.inf,
        soft_max_seq_len=soft_max_seq_len,
        step_processor=NGramRepeatBlockProcessor(ngram_size=10),
    )


# Decoding options for the expressive S2ST translator.
text_generation_opts = _make_text_generation_opts((0, 200))
# Decoding options for the SeamlessM4T v2 transcription pass.
m4t_text_generation_opts = _make_text_generation_opts((1, 200))
# PRETSSEL vocoder wrapper. `run` feeds it the translator's unit output plus the
# gcmvn-normalized prosody features to produce the output waveform.
pretssel_generator = PretsselGenerator(
    VOCODER_NAME,
    vocab_info=unit_tokenizer.vocab_info,
    device=device,
    dtype=dtype,
)
# Decodes the raw bytes read from the input audio file (see `run`) into a
# float32 waveform on `device`.
decode_audio = AudioDecoder(dtype=torch.float32, device=device)
# 80-bin filterbank extraction. Built-in standardization is disabled because
# `normalize_fbank` produces both the utterance-standardized and the
# gcmvn-standardized features itself.
convert_to_fbank = WaveformToFbankConverter(
    num_mel_bins=80,
    # NOTE(review): presumably rescales [-1, 1] float samples to the 16-bit integer range; confirm.
    waveform_scale=2**15,
    # NOTE(review): presumably means the waveform is laid out (frames, channels); confirm.
    channel_last=True,
    standardize=False,
    device=device,
    dtype=dtype,
)
def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
    """Add standardized filterbank features to ``data`` in place and return it.

    Both new entries are computed from the ORIGINAL ``data["fbank"]``:

    * ``"fbank"`` is overwritten with the features standardized by this
      utterance's own mean/std, with statistics taken over dim 0 (presumably
      the frame axis).
    * ``"gcmvn_fbank"`` holds the features standardized by the precomputed
      vocoder statistics ``gcmvn_mean`` / ``gcmvn_std``. `run` uses these as
      the prosody-encoder input.
    """
    raw = data["fbank"]
    utt_std, utt_mean = torch.std_mean(raw, dim=0)
    # NOTE(review): a feature bin with zero variance (e.g. digital silence)
    # makes this division produce NaN/inf.
    data["fbank"] = (raw - utt_mean) / utt_std
    data["gcmvn_fbank"] = (raw - gcmvn_mean) / gcmvn_std
    return data
# Collates a single decoded example into a batch, padding with zeros.
collate = Collater(pad_value=0, pad_to_multiple=1)

# Sample rate (Hz) that `preprocess_audio` resamples every input to.
AUDIO_SAMPLE_RATE = 16_000
# Inputs longer than this many seconds are truncated by `preprocess_audio`.
MAX_INPUT_AUDIO_LENGTH = 10
def remove_prosody_tokens_from_text(text):
    """Return ``text`` without prosody markers and with normalized whitespace.

    The model's text output may contain only two prosody tokens: emphasis
    ``*`` and pause ``=``. Every occurrence of both is deleted. Runs of
    whitespace are then collapsed to single spaces, and leading/trailing
    whitespace is dropped.
    """
    stripped = text.translate(str.maketrans("", "", "*="))
    return " ".join(stripped.split())
def preprocess_audio(input_audio_path: str) -> None:
    """Resample an audio file to AUDIO_SAMPLE_RATE and cap its duration, in place.

    The file at ``input_audio_path`` is overwritten with the processed audio.
    Anything beyond MAX_INPUT_AUDIO_LENGTH seconds is cut off, and a Gradio
    warning tells the user so.
    """
    waveform, source_rate = torchaudio.load(input_audio_path)
    waveform = torchaudio.functional.resample(
        waveform,
        orig_freq=source_rate,
        new_freq=AUDIO_SAMPLE_RATE,
    )

    # torchaudio.load returns (channels, frames) by default, so dim 1 counts samples.
    sample_limit = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if waveform.size(1) > sample_limit:
        waveform = waveform[:, :sample_limit]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")

    torchaudio.save(input_audio_path, waveform, sample_rate=AUDIO_SAMPLE_RATE)
def _load_features(input_audio_path: str):
    """Decode an audio file into a collated batch of normalized filterbank features.

    The result carries the ``"fbank"`` and ``"gcmvn_fbank"`` entries added by
    `normalize_fbank`, batched by `collate`.
    """
    block = MemoryBlock(pathlib.Path(input_audio_path).read_bytes())
    example = decode_audio(block)
    example = convert_to_fbank(example)
    example = normalize_fbank(example)
    return collate(example)


def _save_waveform(waveform: torch.Tensor, sample_rate: int) -> str:
    """Write ``waveform`` as float32 to a new temporary ``.wav`` file and return its path.

    The file is created with ``delete=False`` so it still exists when Gradio
    serves the returned path.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        torchaudio.save(f.name, waveform.to(torch.float32).cpu(), sample_rate=sample_rate)
    return f.name


def run(
    input_audio_path: str,
    source_language: str,
    target_language: str,
) -> tuple[str, str]:
    """Translate a speech recording into expressive speech and text.

    Args:
        input_audio_path: Path of the recording from the Gradio audio input.
            `preprocess_audio` rewrites this file in place (resampled/truncated).
        source_language: Display name of the spoken language (a key of
            LANGUAGE_NAME_TO_CODE).
        target_language: Display name of the language to translate into.

    Returns:
        ``(wav_path, text)``: the path of a temporary WAV file holding the
        translated speech, and the translated text with prosody markers removed.

    Raises:
        gr.Error: if no input recording was provided.
        KeyError: if a language name is not in LANGUAGE_NAME_TO_CODE.
    """
    # Gradio passes None when the button is clicked before any audio is given;
    # fail with a user-visible message instead of an opaque torchaudio error.
    if not input_audio_path:
        raise gr.Error("Please record or upload some input speech first.")

    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]

    preprocess_audio(input_audio_path)
    example = _load_features(input_audio_path)

    # Transcribe the input (S2TT into the source language itself) so the
    # transcript can be handed to the expressive translator as `src_text`.
    # NOTE(review): `translator` is built with apply_mintox=False. If src_text is
    # only consulted for mintox, this pass is wasted compute; confirm the intent.
    source_sentences, _ = m4t_translator.predict(
        input=example["fbank"],
        task_str="S2TT",
        tgt_lang=source_language_code,
        text_generation_opts=m4t_text_generation_opts,
    )
    source_text = str(source_sentences[0])

    # The gcmvn-standardized features condition the prosody encoders of both the
    # translator and the vocoder.
    prosody_encoder_input = example["gcmvn_fbank"]
    text_output, unit_output = translator.predict(
        example["fbank"],
        "S2ST",
        tgt_lang=target_language_code,
        src_lang=source_language_code,
        text_generation_opts=text_generation_opts,
        unit_generation_ngram_filtering=False,
        duration_factor=1.0,
        prosody_encoder_input=prosody_encoder_input,
        src_text=source_text,  # for mintox check
    )
    speech_output = pretssel_generator.predict(
        unit_output.units,
        tgt_lang=target_language_code,
        prosody_encoder_input=prosody_encoder_input,
    )

    # NOTE(review): [0][0] presumably yields a (channels, samples) tensor for the
    # first batch item, as torchaudio.save expects; confirm.
    wav_path = _save_waveform(speech_output.audio_wavs[0][0], speech_output.sample_rate)
    text_out = remove_prosody_tokens_from_text(str(text_output[0]))
    return wav_path, text_out
# Languages offered in both the source and target dropdowns.
TARGET_LANGUAGE_NAMES = [
    "English",
    "French",
    "German",
    "Spanish",
]

# For each source language, the valid target choices: every other supported
# language, in TARGET_LANGUAGE_NAMES order.
UPDATED_LANGUAGE_LIST = {
    source: [target for target in TARGET_LANGUAGE_NAMES if target != source]
    for source in TARGET_LANGUAGE_NAMES
}
def rs_change(rs, current_target=None):
    """Restrict the target-language dropdown to the languages other than ``rs``.

    Args:
        rs: The newly selected source language name.
        current_target: The target dropdown's current value, if known. It is kept
            while it is still a valid choice. This matters for programmatic
            updates: clicking an example sets source and target together, and the
            example's target must not be overwritten. When omitted or invalid, the
            first valid choice is selected (the original behaviour).

    Returns:
        A Gradio update with the new choices and selected value for the target dropdown.
    """
    choices = UPDATED_LANGUAGE_LIST[rs]
    value = current_target if current_target in choices else choices[0]
    return gr.update(choices=choices, value=value)


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="English",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=TARGET_LANGUAGE_NAMES,
                    value="French",
                    interactive=True,
                )
                # `.change` also fires on programmatic updates (e.g. clicking an
                # example). Passing the current target lets rs_change keep it
                # instead of resetting it to the first valid choice.
                source_language.change(
                    fn=rs_change,
                    inputs=[source_language, target_language],
                    outputs=[target_language],
                )
            btn = gr.Button()
        with gr.Column():
            with gr.Group():
                output_audio = gr.Audio(label="Translated speech")
                output_text = gr.Textbox(label="Translated text")

    gr.Examples(
        examples=[
            ["assets/Excited-English.wav", "English", "Spanish"],
            ["assets/Whisper-English.wav", "English", "German"],
            ["assets/FastTalking-French.wav", "French", "English"],
            ["assets/Sad-English.wav", "English", "Spanish"],
        ],
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        fn=run,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run,
        inputs=[input_audio, source_language, target_language],
        outputs=[output_audio, output_text],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable request queuing (at most 50 pending requests), then start the server.
    queued_app = demo.queue(max_size=50)
    queued_app.launch()