import streamlit as st
from speechbrain.pretrained import GraphemeToPhoneme
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import MultitaskPhonemeASRModel
import json
import random
import openai
from gtts import gTTS
from io import BytesIO

openai.api_key = os.getenv("OPENAI_KEY")

# Text-to-speech via gTTS: https://gtts.readthedocs.io/en/latest/
def tts_gtts(text):
    # Synthesise the text with Google TTS and return an in-memory MP3 buffer.
    mp3_fp = BytesIO()
    tts = gTTS(text, lang="en")
    tts.write_to_fp(mp3_fp)
    mp3_fp.seek(0)  # rewind so readers such as st.audio start from byte 0
    return mp3_fp

def pronounce(text):
    # Return synthesised audio for non-empty text, otherwise None.
    if len(text) > 0:
        return tts_gtts(text)
    return None
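
# Minimal usage sketch (the buffer can be passed straight to st.audio):
#   sample = pronounce("The quick brown fox")
#   if sample is not None:
#       st.audio(sample, format="audio/mp3")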

@st.cache_resource
def load_model():
    # Load the multitask wav2vec 2.0 phoneme ASR model and the SoundChoice
    # grapheme-to-phoneme model, then wrap both in the mispronunciation detector.
    # st.cache_resource keeps the models in memory across reruns instead of
    # reloading them on every button click (requires Streamlit >= 1.18).
    path = os.path.join(os.getcwd(), "wav2vecasr", "model", "multitask_best_ctc.pt")
    vocab_path = os.path.join(os.getcwd(), "wav2vecasr", "model", "vocab")
    device = "cpu"
    asr_model = MultitaskPhonemeASRModel(path, vocab_path, device)
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    return MispronounciationDetector(asr_model, g2p, device)
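
# Fields this app reads from MispronounciationDetector.detect() -- inferred
# from the usage in mispronunciation_detection_section() below, not from the
# library's documentation:
#   raw_info["per"]            - phoneme error rate
#   raw_info["ref"]            - reference phoneme sequence derived from the text
#   raw_info["hyp"]            - phoneme sequence recognised from the audio
#   raw_info["phoneme_errors"] - per-phoneme error annotations
#   raw_info["words"]          - words of the prompt
#   raw_info["word_errors"]    - per-word error flags
#   raw_info["wer"]            - word error rate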

def save_file(sound_file):
    # Save the uploaded recording under ./audio_files so it can be reloaded
    # with torchaudio.
    audio_folder_path = os.path.join(os.getcwd(), 'audio_files')
    if not os.path.exists(audio_folder_path):
        os.makedirs(audio_folder_path)
    with open(os.path.join(audio_folder_path, sound_file.name), 'wb') as f:
        f.write(sound_file.getbuffer())
    return sound_file.name

def get_audio(saved_sound_filename):
    # Load the saved recording and resample it to the 16 kHz rate expected by
    # wav2vec 2.0.
    audio_path = os.path.join('audio_files', saved_sound_filename)
    audio, org_sr = torchaudio.load(audio_path)
    audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
    audio = audio.view(audio.shape[1])  # drop the channel dimension (assumes a mono recording)
    return audio

def get_prompts():
    # Read the practice prompts bundled with the app.
    prompts_path = os.path.join(os.getcwd(), "wav2vecasr", "data", "prompts.json")
    with open(prompts_path) as f:
        data = json.load(f)
    return data["prompts"]
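
# Expected shape of wav2vecasr/data/prompts.json, inferred from get_prompts():
#   {"prompts": ["A sentence to read.", "Another sentence.", ...]}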

def get_articulation_videos():
    # Note: not every ARPAbet phoneme could be mapped to a video visualising
    # its articulation.
    path = os.path.join(os.getcwd(), "wav2vecasr", "data", "videos.json")
    with open(path) as f:
        data = json.load(f)
    return data
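
# Expected shape of wav2vecasr/data/videos.json, inferred from video_section()
# below (phonemes without a good video keep an empty "link"):
#   {"AH": {"link": "https://..."}, "TH": {"link": ""}, ...}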

def get_prompts_from_l2_arctic(prompts, current_prompt, num_to_get):
    # Sample unique practice prompts, excluding the one the user just read.
    # Capping at the number of candidates avoids an infinite loop when the
    # prompt pool is small.
    candidates = [prompt for prompt in prompts if prompt != current_prompt]
    return random.sample(candidates, min(num_to_get, len(candidates)))

def get_prompt_from_openai(words_with_error_list):
    # Ask ChatGPT for a short practice sentence that reuses the mispronounced
    # words; return an empty string if the API call fails.
    try:
        words_with_errors = ", ".join(words_with_error_list)
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are writing practice reading prompts for learners of English to practice pronunciation. These prompts should be short, easy to understand and useful."},
                {"role": "user", "content": f"Write a short sentence of less than 10 words and include the following words in the sentence: {words_with_errors}. No numbers."}
            ]
        )
        return response['choices'][0]['message']['content']
    except Exception:
        return ""

def mispronunciation_detection_section():
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇",
        label_visibility='collapsed'
    )
    if st.button('Predict'):
        if uploaded_file is not None and len(text) > 0:
            # get audio from the uploaded file
            save_file(uploaded_file)
            audio = get_audio(uploaded_file.name)

            # load model
            mispronunciation_detector = load_model()

            st.write('# Detection Results')
            with st.spinner('Predicting...'):
                # detect
                raw_info = mispronunciation_detector.detect(audio, text, phoneme_error_threshold=0.25)

                # display prediction results for phonemes
                st.write('#### Phoneme Level Analysis')
                st.write(f"Phoneme Error Rate: {round(raw_info['per'], 2)}")
                st.markdown(
                    f"""
                    <style>
                    textarea {{
                        white-space: nowrap;
                    }}
                    </style>
                    ```
                    {raw_info['ref']}
                    {raw_info['hyp']}
                    {raw_info['phoneme_errors']}
                    ```
                    """,
                    unsafe_allow_html=True,
                )
                st.divider()

                # display word errors
                md = []
                words_with_errors = []
                for word, has_error in zip(raw_info["words"], raw_info["word_errors"]):
                    if has_error:
                        words_with_errors.append(word)
                        md.append(f"**{word}**")
                    else:
                        md.append(word)
                st.write('#### Word Level Analysis')
                st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
                st.markdown(" ".join(md))
                st.divider()

                st.write('#### What is next?')
                # play back the learner's recording next to a synthesised reference
                st.write("Compare your pronunciation to a pronounced sample")
                st.audio(f'audio_files/{uploaded_file.name}', format="audio/wav", start_time=0)
                pronounced_sample = pronounce(text)
                st.audio(pronounced_sample, format="audio/mp3", start_time=0)  # gTTS returns MP3, not WAV

                # suggest more prompts to practice: one from ChatGPT based on the
                # user's mistakes, the rest from L2-ARCTIC
                st.write('Here are some more prompts for you to practice:')
                selected_prompts = []
                unique_words_with_errors = list(set(words_with_errors))
                prompt_for_mistakes_made = get_prompt_from_openai(unique_words_with_errors)
                if prompt_for_mistakes_made:
                    selected_prompts.append(prompt_for_mistakes_made)
                prompts = get_prompts()
                l2_arctic_prompts = get_prompts_from_l2_arctic(prompts, text, 3 - len(selected_prompts))
                selected_prompts.extend(l2_arctic_prompts)
                for prompt in selected_prompts:
                    st.code(prompt, language=None)  # plain text, no syntax highlighting
        else:
            st.error('The audio or text has not been properly input', icon="🚨")

def video_section():
    st.write('# Get helpful videos on phoneme articulation')
    problem_phoneme = st.text_input(
        "Enter the phoneme you had problems with 👇"
    )
    arpabet_to_video_map = get_articulation_videos()
    if st.button('Look up'):
        if not problem_phoneme:
            st.error('Please enter a phoneme to look up', icon="🚨")
        elif problem_phoneme in arpabet_to_video_map:
            video_link = arpabet_to_video_map[problem_phoneme]["link"]
            if video_link:
                st.video(video_link)
            else:
                st.write("Sorry, we couldn't find a good enough video yet :( we are working on it!")
        else:
            st.write("Sorry, we don't have a video for that phoneme yet.")

if __name__ == '__main__':
    st.write('___')

    # create a sidebar
    st.sidebar.title('Pronunciation Evaluation')
    select = st.sidebar.selectbox('', ['Main Page', 'Mispronunciation Detection', 'Helpful Videos for Problem Phonemes'], key='1', label_visibility='collapsed')
    st.sidebar.write(select)
    if select == 'Mispronunciation Detection':
        mispronunciation_detection_section()
    elif select == "Helpful Videos for Problem Phonemes":
        video_section()
    else:
        st.write('# Pronunciation Evaluation')
        st.write('This app detects mispronunciations of English words by English learners from Asian language backgrounds such as Korean, Mandarin and Vietnamese.')
        st.write('Wav2Vec 2.0 recognises the phonemes in the learner\'s recording, and this output is compared with the correct phoneme sequence generated from the input text.')
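
# To try the app locally, launch it with the Streamlit CLI (assuming this file
# is named app.py and the checkpoint under wav2vecasr/model/ plus the JSON data
# files under wav2vecasr/data/ are in place):
#   streamlit run app.py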