Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from joblib import load | |
| import textwrap | |
| import streamlit as st | |
| device = 'cpu' | |
| class GenreNet(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| # параметры сетиnspose arrayroupout'] | |
| self.dropout = config['dropout'] | |
| self.out_range = config['out_range'] | |
| # финальный полносвязный слой для пронгоза оценки | |
| self.head = nn.Sequential( | |
| nn.Linear(312, 256), | |
| nn.Dropout(self.dropout[0]), | |
| nn.ReLU(), | |
| nn.Linear(256, 128), | |
| nn.Dropout(self.dropout[0]), | |
| nn.ReLU(), | |
| nn.Linear(128, 64), | |
| nn.Dropout(self.dropout[0]), | |
| nn.ReLU(), | |
| nn.Linear(64, 1), | |
| ) | |
| def forward(self, emb): | |
| x = torch.sigmoid(self.head(emb)) | |
| x = x * (self.out_range[1] - self.out_range[0]) + self.out_range[0] | |
| return(x) | |
| config = { | |
| 'dropout': [.5], | |
| 'out_range': [1.,5.] # для номировки выходных оценок | |
| } | |
| bert = load('./model.joblib') | |
| model = GenreNet(config) | |
| model.load_state_dict(torch.load('./pages/weights_los065_ep100_lr0001_lay256_128_64_1.pt', map_location=device)) | |
| tokenizer = load('./tokenizer.joblib') | |
| def embed_bert_cls(text, model, tokenizer): | |
| t = tokenizer(text, padding=True, truncation=True, return_tensors='pt') | |
| with torch.no_grad(): | |
| model_output = model(**{k: v.to(device) for k, v in t.items()}) | |
| embeddings = model_output.last_hidden_state[:, 0, :] | |
| embeddings = torch.nn.functional.normalize(embeddings) | |
| return embeddings[0] | |
| genre = {1 : 'Романтика', 2:'Поэзия', 3:'Детектив', 4:'Приключения', 5:'Фантастика', } | |
| prompt = st.text_input('Узнаем жанр!',) | |
| if len(prompt) > 1: | |
| with torch.inference_mode(): | |
| prompt_embedding = embed_bert_cls([prompt], bert, tokenizer) | |
| out = model(prompt_embedding).cpu().numpy() | |
| #for out_ in out: | |
| st.write('Предполагаемый жанр:', genre[int(round(out.item(), 0))]) | |