import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import MBartForConditionalGeneration, AutoTokenizer
from params import DEVICE
from models.tokenizer import TokenAligner
from dataset.vocab import Vocab


class TransformerWithTR(nn.Module):
    """Wrapper around an mBART model with helpers for training, vocabulary resizing, and beam-search inference."""

    def __init__(self, bart_model, padding_index) -> None:
        super().__init__()
        self.bart: MBartForConditionalGeneration = bart_model
        self.pad_token_id = padding_index

    def forward(self, src_ids, attn_masks, labels=None):
        src_ids = src_ids.to(DEVICE)
        attn_masks = attn_masks.to(DEVICE)
        if labels is not None:
            # Mask padding positions with -100 so the cross-entropy loss ignores them.
            labels = labels.clone()
            labels[labels == self.pad_token_id] = -100
            labels = labels.to(DEVICE)

        output = self.bart(input_ids=src_ids, attention_mask=attn_masks, labels=labels)
        logits = output.logits

        out = dict()
        out['loss'] = output.loss
        out['logits'] = logits
        # Greedy per-position predictions from the output distribution.
        probs = F.softmax(logits, dim=-1)
        preds = torch.argmax(probs, dim=-1)
        out['preds'] = preds.cpu().detach().numpy()
        return out

    def resize_token_embeddings(self, tokenAligner: TokenAligner):
        vocab: Vocab = tokenAligner.vocab
        tokenizer: AutoTokenizer = tokenAligner.tokenizer
        char_vocab = []
        for i, key in enumerate(vocab.chartoken2idx.keys()):
            # Skip the first four (reserved) entries of the character vocabulary.
            if i < 4:
                continue
            # Add each character both as a standalone token and with the "@@" continuation suffix.
            char_vocab.append(key)
            char_vocab.append(key + "@@")
        tokenizer.add_tokens(char_vocab)
        self.bart.resize_token_embeddings(len(tokenizer.get_vocab()))
        print("Resized token embeddings!")
        return

    def inference(self, src_ids, num_beams=2, tokenAligner: TokenAligner = None):
        assert tokenAligner is not None, "A TokenAligner is required for decoding."
        src_ids = src_ids.to(DEVICE)
        # Beam-search generation, then decode the token ids back to text.
        output = self.bart.generate(src_ids, num_beams=num_beams, max_new_tokens=256)
        predict_text = tokenAligner.tokenizer.batch_decode(output,
                                                           skip_special_tokens=True,
                                                           clean_up_tokenization_spaces=False)
        return predict_text
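

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the wrapper might be exercised, assuming
# `params.DEVICE` is a torch.device and that the public
# "facebook/mbart-large-50" checkpoint is a stand-in for the Space's
# fine-tuned weights. The repo-specific TokenAligner construction is
# omitted here, so only the forward/training path is shown.
if __name__ == "__main__":
    checkpoint = "facebook/mbart-large-50"
    bart = MBartForConditionalGeneration.from_pretrained(checkpoint).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    model = TransformerWithTR(bart, padding_index=tokenizer.pad_token_id).to(DEVICE)

    # Toy source/target pair; real data would come from the project's dataset pipeline.
    batch = tokenizer(["an example noisy sentence"], return_tensors="pt", padding=True)
    targets = tokenizer(["an example clean sentence"], return_tensors="pt", padding=True)

    out = model(batch["input_ids"], batch["attention_mask"], labels=targets["input_ids"])
    print(out["loss"].item(), out["preds"].shape)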