Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import pipeline, set_seed | |
| from transformers import AutoTokenizer | |
| from normalizer import Normalizer | |
| import random | |
| import meta | |
| import examples | |
| from utils import ( | |
| remote_css, | |
| local_css | |
| ) | |
class TextGeneration:
    """Thin wrapper around a Hugging Face ``text-generation`` pipeline.

    Nothing is loaded on construction; call :meth:`load` before
    :meth:`generate`. When ``debug`` is true, :meth:`load` is a no-op and
    :meth:`generate` returns ``dummy_output`` instead of running the model.
    """

    def __init__(self):
        self.debug = False
        self.dummy_output = None
        self.tokenizer = None
        self.generator = None
        self.task = "text-generation"
        self.model_name_or_path = "HamidRezaAttar/gpt2-product-description-generator"
        # Fixed seed, set once per instance, so sampling starts from a known state.
        set_seed(42)

    def load(self):
        """Load the tokenizer and the generation pipeline (skipped in debug mode)."""
        if self.debug:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        # Hand the already-loaded tokenizer to the pipeline instead of the
        # model name, so it is not loaded a second time and the prompt-length
        # computation in generate() uses the very same tokenizer.
        self.generator = pipeline(
            self.task,
            model=self.model_name_or_path,
            tokenizer=self.tokenizer,
        )

    def generate(self, prompt, generation_kwargs):
        """Generate a continuation for ``prompt``.

        Args:
            prompt: Text to continue.
            generation_kwargs: Mapping of generation options forwarded to the
                pipeline. Must contain ``"max_length"``, interpreted here as
                the number of tokens to add on top of the prompt. The mapping
                is NOT modified.

        Returns:
            The ``"generated_text"`` of the first returned sequence, or
            ``self.dummy_output`` when ``debug`` is true.

        Raises:
            RuntimeError: If called (outside debug mode) before :meth:`load`.
            KeyError: If ``generation_kwargs`` has no ``"max_length"`` entry.
        """
        if self.debug:
            return self.dummy_output

        if self.tokenizer is None or self.generator is None:
            raise RuntimeError("Model is not loaded; call load() before generate().")

        # Work on a copy: writing the adjusted max_length back into the
        # caller's dict would make it grow by the prompt length on every call
        # that reuses the same dict.
        kwargs = dict(generation_kwargs)
        prompt_token_count = len(self.tokenizer(prompt)["input_ids"])
        kwargs["max_length"] = prompt_token_count + kwargs["max_length"]
        kwargs["num_return_sequences"] = 1
        # Ask the pipeline for the continuation only, without the prompt.
        kwargs["return_full_text"] = False

        return self.generator(prompt, **kwargs)[0]["generated_text"]
def load_text_generator():
    """Build a ``TextGeneration`` instance, call its ``load()``, and return it."""
    text_generation = TextGeneration()
    text_generation.load()
    return text_generation
def _sidebar_generation_kwargs():
    """Render the sidebar generation controls and return their values as a dict."""
    max_length = st.sidebar.slider(
        label='Max Length',
        help="The maximum length of the sequence to be generated.",
        min_value=1,
        max_value=128,
        value=50,
        step=1
    )
    top_k = st.sidebar.slider(
        label='Top-k',
        help="The number of highest probability vocabulary tokens to keep for top-k-filtering",
        min_value=40,
        max_value=80,
        value=50,
        step=1
    )
    top_p = st.sidebar.slider(
        label='Top-p',
        help="Only the most probable tokens with probabilities that add up to `top_p` or higher are kept for "
             "generation.",
        min_value=0.0,
        max_value=1.0,
        value=0.95,
        step=0.01
    )
    temperature = st.sidebar.slider(
        label='Temperature',
        help="The value used to modulate the next token probabilities",
        min_value=0.1,
        max_value=10.0,
        value=1.0,
        step=0.05
    )
    do_sample = st.sidebar.selectbox(
        label='Sampling ?',
        options=(True, False),
        help="Whether or not to use sampling; use greedy decoding otherwise.",
    )
    return {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
    }


def _result_html(prompt_text, generated_text):
    """Return the HTML paragraph showing the prompt followed by the generated text.

    Both pieces are HTML-escaped because the snippet is rendered with
    ``unsafe_allow_html=True``: the prompt is typed by the user and the
    continuation comes from the model, so neither may inject markup.
    """
    import html  # local import: keeps this block self-contained

    safe_prompt = html.escape(str(prompt_text))
    safe_generated = html.escape(str(generated_text))
    return (
        f'<p class="ltr ltr-box">'
        f'<span class="result-text">{safe_prompt} </span>'
        f'<span class="result-text generated-text">{safe_generated}</span>'
        f'</p>'
    )


def main():
    """Render the Streamlit page: sidebar controls, prompt input, and generated output."""
    st.set_page_config(
        page_title="GPT2 - Home",
        page_icon="🏡",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Plain stylesheet URL; the stray URL-encoded `" rel="stylesheet"` suffix
    # that used to trail `display=swap` corrupted the query string.
    remote_css("https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap")
    local_css("assets/ltr.css")

    generator = load_text_generator()

    st.sidebar.markdown(meta.SIDEBAR_INFO)
    generation_kwargs = _sidebar_generation_kwargs()

    st.markdown(meta.HEADER_INFO)

    # "Custom" is appended last and pre-selected via index=len(prompts) - 1.
    prompts = list(examples.EXAMPLES.keys()) + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
    if prompt == "Custom":
        prompt_box = meta.PROMPT_BOX
    else:
        prompt_box = random.choice(examples.EXAMPLES[prompt])

    text = st.text_area("Enter text", prompt_box)
    # Placeholder declared above the button so the parameter summary is
    # rendered before the result paragraph.
    generation_kwargs_ph = st.empty()
    cleaner = Normalizer()

    if st.button("Generate !"):
        with st.spinner(text="Generating ..."):
            generation_kwargs_ph.markdown(
                ", ".join(f"`{k}`: {v}" for k, v in generation_kwargs.items())
            )
            if text:
                generated_text = generator.generate(text, generation_kwargs)
                generated_text = cleaner.clean_txt(generated_text)
                st.markdown(
                    _result_html(text, generated_text),
                    unsafe_allow_html=True
                )
# Entry point: build the page when this file is executed as a script.
if __name__ == "__main__":
    main()