import random

from mtranslate import translate
import streamlit as st
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline

LOGO = "bertin.png"

# Display name -> Hugging Face Hub model ID for each BERTIN variant.
MODELS = {
    "RoBERTa Base Gaussian Seq Len 512": {
        "url": "bertin-project/bertin-base-gaussian-exp-512seqlen"
    },
    "RoBERTa Base Gaussian Seq Len 128": {
        "url": "bertin-project/bertin-base-gaussian"
    },
    "RoBERTa Base Random Seq Len 128": {
        "url": "bertin-project/bertin-base-random"
    },
}

# Spanish example sentences; each contains exactly one <mask> token to fill.
PROMPT_LIST = [
    "Fui a la librería a comprar un <mask>.",
    "¡Qué buen <mask> hace hoy!",
    "Hoy empiezan las vacaciones así que vamos a la <mask>.",
    "Mi color favorito es el <mask>.",
    "Voy a <mask> porque estoy muy cansada.",
    "Mañana vienen mis amigos de <mask>.",
    "¿Te apetece venir a <mask> conmigo?",
    "En verano hace mucho <mask>.",
    "En el bosque había <mask>.",
    "El ministro dijo que <mask> los impuestos.",
    "Si no estuviera afónica, <mask> esa canción.",
]
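

@st.cache_resource  # Requires Streamlit >= 1.18; builds each pipeline once per model URL.
def load_pipeline(model_url):
    """Download the model and tokenizer and wrap them in a fill-mask pipeline."""
    model = AutoModelForMaskedLM.from_pretrained(model_url)
    tokenizer = AutoTokenizer.from_pretrained(model_url)
    return pipeline("fill-mask", model=model, tokenizer=tokenizer)


def load_model(masked_text, model_url):
    """Return the fill-mask predictions for `masked_text` from the model at `model_url`."""
    nlp = load_pipeline(model_url)
    return nlp(masked_text)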


# Page
st.set_page_config(page_title="BERTIN Demo", page_icon=LOGO)
st.title("BERTIN")

# Sidebar
st.sidebar.image(LOGO)

# Body
st.markdown(
    """
BERTIN is a series of BERT-based models for Spanish.

The models are trained with Flax on TPUs sponsored by Google as part of the
[Flax/JAX Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104)
organised by Hugging Face.

All models are variations of **RoBERTa-base** trained from scratch in **Spanish** on the **mC4 dataset**.
We reduced the dataset to 50 million documents to keep training times shorter, and also to be able to bias the selection of training examples by their perplexity.
The idea is to favour documents whose perplexity is neither too low (short, repetitive texts) nor too high (potentially poor-quality texts).

* **Random** sampling simply keeps documents at random to reduce the dataset size.
* **Gaussian** sampling weights each document by a Gaussian function of its perplexity, so documents with very low or very high perplexity are more likely to be rejected.

The first models were trained for 250,000 steps with a sequence length of 128; training of the Gaussian model then switched to a sequence length of 512 for the last 25,000 steps.
"""
)

model_name = st.selectbox("Model", list(MODELS.keys()))
model_url = MODELS[model_name]["url"]

prompt = st.selectbox("Prompt", ["Random", "Custom"])
if prompt == "Custom":
    prompt_box = "Enter your masked text here..."
else:
    prompt_box = random.choice(PROMPT_LIST)
text = st.text_area("Enter text", prompt_box)
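
if st.button("Fill the mask"):
    # The fill-mask pipeline raises an error when the input has no mask token.
    if "<mask>" not in text:
        st.warning("Please include a `<mask>` token in the text.")
    else:
        with st.spinner(text="Filling the mask..."):
            st.subheader("Result")
            result = load_model(text, model_url)
            # With a single <mask>, the pipeline returns candidates sorted by score.
            result_sequence = result[0]["sequence"]
            st.write(result_sequence)
            # mtranslate: translate(text, to_language, from_language), here Spanish -> English.
            st.write("_English translation:_", translate(result_sequence, "en", "es"))
            st.write(result)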

st.markdown(
    """
### Team members

- Eduardo González ([edugp](https://huggingface.co/edugp))
- Javier de la Rosa ([versae](https://huggingface.co/versae))
- Manu Romero ([mrm8488](https://huggingface.co/mrm8488))
- María Grandury ([mariagrandury](https://huggingface.co/mariagrandury))
- Pablo González de Prado ([Pablogps](https://huggingface.co/Pablogps))
- Paulo Villegas ([paulo](https://huggingface.co/paulo))

### More information

You can find more information about these models
[here](https://huggingface.co/bertin-project/bertin-roberta-base-spanish).
"""
)
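
The **Random** and **Gaussian** sampling strategies described in the app's intro text can be sketched as simple rejection sampling over documents whose perplexity has already been computed. This is a minimal illustration under stated assumptions, not the BERTIN project's actual code: the helper names (`gaussian_keep_probability`, `gaussian_sample`, `random_sample`), the parameters (`mean`, `std`, `factor`, `keep_ratio`) and the exact formula `min(1, factor * exp(-(ppl - mean)^2 / (2 * std^2)))` are hypothetical choices. A document is kept with a probability that peaks at `mean` and falls off for very low and very high perplexities.

```python
import math
import random
from typing import Iterable, Iterator, Tuple

# Hypothetical document representation: (text, precomputed perplexity).
Doc = Tuple[str, float]


def gaussian_keep_probability(perplexity: float, mean: float, std: float,
                              factor: float = 1.0) -> float:
    """Hypothetical Gaussian acceptance probability: highest at `mean`, capped at 1."""
    weight = math.exp(-((perplexity - mean) ** 2) / (2 * std ** 2))
    return min(1.0, factor * weight)


def gaussian_sample(docs: Iterable[Doc], mean: float, std: float,
                    factor: float = 1.0, seed: int = 0) -> Iterator[Doc]:
    """Keep each document with its Gaussian acceptance probability (rejection sampling)."""
    rng = random.Random(seed)
    for text, perplexity in docs:
        if rng.random() < gaussian_keep_probability(perplexity, mean, std, factor):
            yield text, perplexity


def random_sample(docs: Iterable[Doc], keep_ratio: float, seed: int = 0) -> Iterator[Doc]:
    """Baseline: keep every document with the same probability, ignoring perplexity."""
    rng = random.Random(seed)
    for text, perplexity in docs:
        if rng.random() < keep_ratio:
            yield text, perplexity


# Arbitrary toy data: very low, mid-range and very high perplexity.
docs = [("muy repetitivo", 5.0), ("texto normal", 100.0), ("ruido", 5000.0)]
# The mid-range document is always kept (probability 1); the 5000.0 one never is.
print([text for text, _ in gaussian_sample(docs, mean=100.0, std=50.0)])
```

The `factor` parameter rescales the whole curve so that more documents pass overall; presumably `mean`, `std` and `factor` would be tuned until roughly the target number of documents survives, but the values above are arbitrary.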