Spaces:
Running
Running
| import streamlit as st | |
| import os | |
| import io | |
| from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration | |
| import time | |
| import json | |
| from typing import List | |
| import torch | |
| import random | |
| import logging | |
# Select the inference device once at import time. The model and every input
# tensor are moved to this device further down in the script.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
    logging.warning("GPU not found, using CPU, translation will be very slow.")

# NOTE: a bare `st.cache(suppress_st_warning=True, allow_output_mutation=True)`
# statement previously sat here. Called without a function it only builds a
# decorator object and throws it away, so it cached nothing; it was removed.

# Streamlit expects the page configuration to be set before any other
# element is rendered, so keep this ahead of the rest of the UI code.
st.set_page_config(page_title="M2M100 Translator")
# Maps the human-readable language name shown in the UI select boxes to the
# language code handed to the M2M100 tokenizer (`src_lang` / `get_lang_id`).
# Entries are ordered by language code.
lang_id = {
    "Afrikaans": "af",
    "Amharic": "am",
    "Arabic": "ar",
    "Asturian": "ast",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bulgarian": "bg",
    "Bengali": "bn",
    "Breton": "br",
    "Bosnian": "bs",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Czech": "cs",
    "Welsh": "cy",
    "Danish": "da",
    "German": "de",
    "Greek": "el",  # was misspelled "Greeek" in the UI
    "English": "en",
    "Spanish": "es",
    "Estonian": "et",
    "Persian": "fa",
    "Fulah": "ff",
    "Finnish": "fi",
    "French": "fr",
    "Western Frisian": "fy",
    "Irish": "ga",
    "Gaelic": "gd",
    "Galician": "gl",
    "Gujarati": "gu",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Croatian": "hr",
    "Haitian": "ht",
    "Hungarian": "hu",
    "Armenian": "hy",
    "Indonesian": "id",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Icelandic": "is",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Georgian": "ka",
    "Kazakh": "kk",
    "Central Khmer": "km",
    "Kannada": "kn",
    "Korean": "ko",
    "Luxembourgish": "lb",
    "Ganda": "lg",
    "Lingala": "ln",
    "Lao": "lo",
    "Lithuanian": "lt",
    "Latvian": "lv",
    "Malagasy": "mg",
    "Macedonian": "mk",
    "Malayalam": "ml",
    "Mongolian": "mn",
    "Marathi": "mr",
    "Malay": "ms",
    "Burmese": "my",
    "Nepali": "ne",
    "Dutch": "nl",
    "Norwegian": "no",
    "Northern Sotho": "ns",
    "Occitan": "oc",
    "Oriya": "or",
    "Panjabi": "pa",
    "Polish": "pl",
    "Pushto": "ps",
    "Portuguese": "pt",
    "Romanian": "ro",
    "Russian": "ru",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Somali": "so",
    "Albanian": "sq",
    "Serbian": "sr",
    "Swati": "ss",
    "Sundanese": "su",
    "Swedish": "sv",
    "Swahili": "sw",
    "Tamil": "ta",
    "Thai": "th",
    "Tagalog": "tl",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Chinese": "zh",
    "Zulu": "zu",
}
# Cache the loaded tokenizer/model pair: this function is invoked from the
# "Run" button handler, and without the decorator the full checkpoint would
# be re-instantiated on every click. `allow_output_mutation=True` is used so
# Streamlit does not try to hash the returned model object.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(
    pretrained_model: str = "facebook/m2m100_1.2B",
    cache_dir: str = "models/",
):
    """Load the M2M100 tokenizer and model and prepare the model for inference.

    Args:
        pretrained_model: Model identifier (or path) passed to
            ``from_pretrained`` for both the tokenizer and the model.
        cache_dir: Directory passed to ``from_pretrained`` as ``cache_dir``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to the
        module-level ``device`` and switched to evaluation mode.
    """
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    ).to(device)
    # Inference only: put the module in eval mode before returning it.
    model.eval()
    return tokenizer, model
# ---------------------------------------------------------------------------
# Page body: title, description, input widgets and the translation handler.
# ---------------------------------------------------------------------------
st.title("M2M100 Translator")

st.write(
    "M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for "
    "Many-to-Many multilingual translation. It was introduced in this paper "
    "https://arxiv.org/abs/2010.11125 and first released in "
    "https://github.com/pytorch/fairseq/tree/master/examples/m2m_100 repository. "
    "The model can directly translate between the 9,900 directions of 100 "
    "languages.\n"
)

st.write(
    " This demo uses the facebook/m2m100_1.2B model. For local inference see "
    "https://github.com/ikergarcia1996/Easy-Translate"
)

user_input: str = st.text_area(
    "Input text",
    height=200,
    max_chars=5120,
)

# Both select boxes offer the same language names, in the dict's own order.
language_names = list(lang_id)
source_lang = st.selectbox(label="Source language", options=language_names)
target_lang = st.selectbox(label="Target language", options=language_names)

if st.button("Run"):
    if not user_input.strip():
        # Nothing to translate: do not send an empty prompt to the model.
        st.warning("Please enter some text to translate.")
    else:
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments. The measurement includes load_model().
        time_start = time.perf_counter()
        tokenizer, model = load_model()

        src_lang = lang_id[source_lang]
        trg_lang = lang_id[target_lang]

        # The tokenizer needs the source language set before encoding; the
        # target language is imposed as the forced first generated token.
        tokenizer.src_lang = src_lang
        with torch.no_grad():
            encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
            generated_tokens = model.generate(
                **encoded_input,
                forced_bos_token_id=tokenizer.get_lang_id(trg_lang),
            )
            # A single input string was encoded, so take the first decoded item.
            translated_text = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]

        elapsed = time.perf_counter() - time_start
        st.success(translated_text)
        st.write(f"Computation time: {round(elapsed, 3)} seconds")