Spaces:
Runtime error
Runtime error
| import torch | |
| from joblib import load | |
| import textwrap | |
| import streamlit as st | |
| device = 'cpu' | |
| tokenizer = load('./pages/tokenizer.joblib') | |
| model = load('./pages/model.joblib') | |
| model.load_state_dict(torch.load('./pages/model_weights.pt', map_location=device)) | |
| temperature = st.slider('Градус дичи:', min_value = 1., max_value = 20., value = 3.) | |
| num_beams = st.slider('Число веток для поиска:', min_value = 1, max_value = 15, value = 7) | |
| max_length = st.slider('Максимальная длина генерации:', min_value = 50, max_value = 150, value = 70) | |
| prompt = st.text_input('Дайте волю фантазии!',) | |
| if len(prompt) > 1: | |
| with torch.inference_mode(): | |
| prompt = tokenizer.encode(prompt, return_tensors='pt').to(device) | |
| out = model.generate( | |
| input_ids=prompt, | |
| max_length=max_length, | |
| num_beams=num_beams, | |
| do_sample=True, | |
| temperature=temperature, | |
| top_k=50, | |
| top_p=0.6, | |
| no_repeat_ngram_size=3, | |
| num_return_sequences=3, | |
| ).cpu().numpy() | |
| for out_ in out: | |
| st.write(textwrap.fill(tokenizer.decode(out_), 40), end='\n------------------\n') | |