import json
import os
import tempfile
import uuid

import gradio as gr
import librosa
import scipy.io.wavfile as wav
import soundfile as sf
import torch
from transformers import VitsModel, AutoTokenizer, set_seed
from nemo.collections.asr.models import EncDecMultiTaskModel
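
# Dependency note (an assumption; nothing is pinned in this file): the Space
# needs gradio, torch, librosa, soundfile, scipy, transformers, and
# nemo_toolkit[asr] installed for the imports above to resolve.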

# Constants
SAMPLE_RATE = 16000  # Hz; Canary expects 16 kHz mono audio

# Load the ASR/AST model (NVIDIA Canary 1B) and switch it to greedy decoding
canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')
decode_cfg = canary_model.cfg.decoding
decode_cfg.beam.beam_size = 1  # beam size 1 == greedy decoding (fastest)
canary_model.change_decoding_strategy(decode_cfg)
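
# Quick sanity check (assumption about defaults: with no task arguments,
# Canary runs English ASR on a plain list of audio paths):
# print(canary_model.transcribe(['sample_en.wav'], batch_size=1)[0])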

# Convert audio to text (ASR) or to translated text (AST) with Canary
def gen_text(audio_filepath, action, source_lang, target_lang):
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")

    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        # Resample to 16 kHz mono, which the model expects
        data, sr = librosa.load(audio_filepath, sr=None, mono=True)
        if sr != SAMPLE_RATE:
            data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
        converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
        sf.write(converted_audio_filepath, data, SAMPLE_RATE)

        # Build a single-entry NeMo manifest describing the task
        duration = len(data) / SAMPLE_RATE
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "taskname": action,
            "source_lang": source_lang,
            "target_lang": source_lang if action == "asr" else target_lang,
            "pnc": "no",  # no punctuation and capitalization
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, 'w') as fout:
            fout.write(json.dumps(manifest_data))

        # Transcribe (or translate) the audio described by the manifest
        predicted_text = canary_model.transcribe(manifest_filepath)[0]
        # Buffered inference for long recordings, currently disabled:
        # if duration < 40:
        #     predicted_text = canary_model.transcribe(manifest_filepath)[0]
        # else:
        #     predicted_text = get_buffered_pred_feat_multitaskAED(
        #         frame_asr,
        #         canary_model.cfg.preprocessor,
        #         model_stride_in_secs,
        #         canary_model.device,
        #         manifest=manifest_filepath,
        #     )[0].text

    return predicted_text
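
# NOTE: the buffered path commented out above relies on helpers that are never
# defined in this file. A minimal sketch of how they might be wired up, based
# on NeMo's buffered multitask-AED utilities (module paths and all numeric
# values below are assumptions, not taken from this Space):
#
# from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
# from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
#
# model_stride_in_secs = 0.08   # assumed encoder output stride in seconds
# frame_asr = FrameBatchMultiTaskAED(
#     asr_model=canary_model,
#     frame_len=40.0,           # seconds of audio processed per frame
#     total_buffer=40.0,        # total context buffer in seconds
#     batch_size=16,
# )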

# Convert text to speech with Facebook MMS-TTS (VITS)
def gen_speech(text, lang):
    set_seed(555)  # Make the stochastic VITS output deterministic

    # Map the language code to its single-language MMS-TTS checkpoint
    # (match/case requires Python >= 3.10)
    match lang:
        case "en":
            model = "facebook/mms-tts-eng"
        case "fr":
            model = "facebook/mms-tts-fra"
        case "de":
            model = "facebook/mms-tts-deu"
        case "es":
            model = "facebook/mms-tts-spa"
        case _:
            model = "facebook/mms-tts"  # fallback; unreachable with the UI's fixed choices

    # Load the TTS model and tokenizer for the selected language
    tts_model = VitsModel.from_pretrained(model)
    tts_tokenizer = AutoTokenizer.from_pretrained(model)

    input_text = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = tts_model(**input_text)

    waveform_np = outputs.waveform[0].cpu().numpy()
    output_file = f"{str(uuid.uuid4())}.wav"
    wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
    return output_file
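
# Usage sketch (hypothetical text; "fr" selects facebook/mms-tts-fra):
# speech_path = gen_speech("Bonjour tout le monde", "fr")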

# Root function wired to the Gradio interface
def start_process(audio_filepath, source_lang, target_lang):
    transcription = gen_text(audio_filepath, "asr", source_lang, target_lang)
    print("Done transcribing")
    translation = gen_text(audio_filepath, "s2t_translation", source_lang, target_lang)
    print("Done translating")
    audio_output_filepath = gen_speech(translation, target_lang)
    print("Done speaking")
    return transcription, translation, audio_output_filepath
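
# Usage sketch, assuming sample_en.wav is one of the example clips bundled below:
# transcription, translation, speech_path = start_process("sample_en.wav", "en", "fr")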

# Create the Gradio interface
playground = gr.Blocks()

with playground:
    with gr.Row():
        gr.Markdown("""
        ## Your AI Translate Assistant
        ### Takes audio input from the user, transcribes and translates it, then converts the translation back to speech.

        - category: [Automatic Speech Recognition](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition), model: [nvidia/canary-1b](https://huggingface.co/nvidia/canary-1b)
        - category: [Text-to-Speech](https://huggingface.co/models?pipeline_tag=text-to-speech), model: [facebook/mms-tts](https://huggingface.co/facebook/mms-tts)
        """)

    with gr.Row():
        with gr.Column():
            source_lang = gr.Dropdown(
                choices=["en", "de", "es", "fr"], value="en", label="Source Language"
            )
        with gr.Column():
            target_lang = gr.Dropdown(
                choices=["en", "de", "es", "fr"], value="fr", label="Target Language"
            )

    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")
        with gr.Column():
            translated_speech = gr.Audio(type="filepath", label="Generated Speech")

    with gr.Row():
        with gr.Column():
            transcribed_text = gr.Textbox(label="Transcription")
        with gr.Column():
            translated_text = gr.Textbox(label="Translation")

    with gr.Row():
        with gr.Column():
            submit_button = gr.Button(value="Start Process", variant="primary")
        with gr.Column():
            clear_button = gr.ClearButton(
                components=[input_audio, source_lang, target_lang, transcribed_text, translated_text, translated_speech],
                value="Clear",
            )

    with gr.Row():
        gr.Examples(
            examples=[
                ["sample_en.wav", "en", "fr"],
                ["sample_fr.wav", "fr", "de"],
                ["sample_de.wav", "de", "es"],
                ["sample_es.wav", "es", "en"],
            ],
            inputs=[input_audio, source_lang, target_lang],
            outputs=[transcribed_text, translated_text, translated_speech],
            run_on_click=True,
            cache_examples=True,
            fn=start_process,
        )

    submit_button.click(
        start_process,
        inputs=[input_audio, source_lang, target_lang],
        outputs=[transcribed_text, translated_text, translated_speech],
    )

playground.launch()