import os

import streamlit as st
import torch
from transformers import BertForSequenceClassification, BertTokenizer
# Caching (the source only marks this spot with a bare "caching" comment;
# st.cache_resource is assumed here so the model is loaded once per process).
@st.cache_resource
def get_model(unzip_root: str = './'):
    """
    unzip_root ~ in the test environment, `unzip archive.zip` will be run on
    the submitted archive and the path reported by `realpath .` will be
    passed to this function.
    """
    checkpoint_path = os.path.join(unzip_root, "model.pth")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_path = 'cointegrated/rubert-tiny'
    model = BertForSequenceClassification.from_pretrained(model_path)
    # Hidden size of the last encoder layer (rubert-tiny has two layers);
    # the classifier head is resized to the label vocabulary, which
    # get_vocab() must have loaded into dict_term_to_int beforehand.
    out_features = model.bert.encoder.layer[1].output.dense.out_features
    model.classifier = torch.nn.Linear(out_features, len(dict_term_to_int))
    model.load_state_dict(checkpoint)
    model.eval()  # inference only
    return model
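# For reference, a checkpoint compatible with load_state_dict() above would
# have been produced with something like this (illustrative, not part of the
# app):
#     torch.save(model.state_dict(), "model.pth")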
# Caching (same assumption as above).
@st.cache_resource
def get_tokenizer():
    tokenizer_path = 'cointegrated/rubert-tiny'
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
    return tokenizer
# Caching (plain data, so st.cache_data is assumed here).
@st.cache_data
def get_vocab(unzip_root: str = './'):
    """
    unzip_root ~ in the test environment, `unzip archive.zip` will be run on
    the submitted archive and the path reported by `realpath .` will be
    passed to this function.
    """
    path = os.path.join(unzip_root, "vocab.tsv")
    with open(path, 'r') as f:
        # First block: entry count, then alternating lines of term and id.
        size_dict = int(f.readline())
        dict_term_to_int = dict()
        for _ in range(size_dict):
            key = f.readline()[:-1]  # strip the trailing newline
            dict_term_to_int[key] = int(f.readline())
        # Second block: entry count, then alternating lines of id and term.
        size_dict = int(f.readline())
        dict_int_to_term = dict()
        for _ in range(size_dict):
            key = int(f.readline())
            dict_int_to_term[key] = f.readline()[:-1]
    return dict_term_to_int, dict_int_to_term
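# For illustration: a hypothetical vocab.tsv matching the layout parsed above
# (label names and counts are invented, not taken from the real file):
#
#     2
#     cs.CL
#     0
#     cs.LG
#     1
#     2
#     0
#     cs.CL
#     1
#     cs.LG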
softmax = torch.nn.Softmax(dim=1)

# Load the vocabulary first: get_model() sizes its classifier head from it.
dict_term_to_int, dict_int_to_term = get_vocab()
model = get_model()
tokenizer = get_tokenizer()
def predict(text, device='cpu'):
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    # encode_plus with return_tensors='pt' already yields (1, 512) tensors.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    probs = softmax(outputs.logits)
    # Class indices sorted by descending logit (and hence probability).
    ranking = torch.argsort(outputs.logits, dim=1, descending=True).cpu()[0]
    # Collect the top-ranked topics until their cumulative probability
    # reaches 0.95.
    cumulative = 0.0
    answer = []
    idx = 0
    while cumulative < 0.95:
        cumulative += probs[0][ranking[idx]].item()
        answer.append(dict_int_to_term[ranking[idx].item()])
        idx += 1
    return answer
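# A vectorized sketch of the same 0.95 cutoff, for illustration only (the
# helper name is hypothetical and nothing in the app calls it):
def predict_topk_mass(probs, mass=0.95):
    """Return labels whose cumulative probability first reaches `mass`."""
    sorted_probs, ranking = torch.sort(probs[0], descending=True)
    k = int(torch.searchsorted(torch.cumsum(sorted_probs, dim=0), mass)) + 1
    return [dict_int_to_term[i.item()] for i in ranking[:k]]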
st.title("We will help you determine what topic your article belongs to :)")
st.header("Please enter a title and/or introduction")
title = st.text_input(label="Title", value="")
abstract = st.text_input(label="Abstract", value="")
if st.button('Show result'):
    # Use a separate name so the predict() function is not shadowed.
    topics = ' '.join(predict(title.title() + ' ' + abstract.title()))
    result = 'Suggested answer:\n' + topics
    st.success(result)
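# To run locally: `streamlit run app.py` (file name assumed; on Hugging Face
# Spaces the app is launched automatically).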