Spaces:
Runtime error
Runtime error
| from models.transformer import TransformerWithTR | |
| from models.collator import * | |
| from transformers import AutoTokenizer | |
| import transformers | |
| from models.tokenizer import TokenAligner | |
| from dataset.vocab import Vocab | |
class ModelWrapper:
    """Bundle a tokenizer, model, and data collator for a named model variant.

    Only the ``"tfmwtr"`` variant is currently supported. For that variant the
    constructor sets these instance attributes:

        model_name:  the ``model`` argument exactly as passed in.
        tokenizer:   ``AutoTokenizer`` loaded from ``checkpoint``.
        tokenAligner: ``TokenAligner`` built from ``tokenizer`` and ``vocab``.
        bart:        ``transformers.MBartForConditionalGeneration`` loaded from
                     ``checkpoint``.
        model:       ``TransformerWithTR`` wrapping ``bart``; it receives the
                     tokenizer's ``pad_token_id``.
        collator:    ``DataCollatorForCharacterTransformer`` using
                     ``tokenAligner``.
    """

    # A single checkpoint name for both the tokenizer and the MBart weights.
    # Keeping it in one place guarantees the two are always loaded as a pair.
    DEFAULT_CHECKPOINT = "vinai/bartpho-word-base"

    def __init__(self, model, vocab: "Vocab", *, checkpoint: str = DEFAULT_CHECKPOINT):
        """Load the components for the model variant ``model``.

        Args:
            model: Name of the model variant. Only ``"tfmwtr"`` is implemented.
            vocab: Project ``Vocab`` passed through to ``TokenAligner``.
            checkpoint: Pretrained checkpoint name or path used for both the
                tokenizer and the MBart model. Defaults to
                ``DEFAULT_CHECKPOINT`` (the previously hard-coded value).

        Raises:
            NotImplementedError: If ``model`` is not a supported variant.
                This subclasses ``Exception``, so callers that caught the
                previous generic ``Exception`` still catch it.
        """
        # Guard clause: fail fast on unknown variants, before any (slow)
        # checkpoint download or model construction happens.
        if model != "tfmwtr":
            raise NotImplementedError(f"Model {model} isn't implemented!")

        self.model_name = model
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.tokenAligner = TokenAligner(self.tokenizer, vocab)
        self.bart = transformers.MBartForConditionalGeneration.from_pretrained(checkpoint)
        # The wrapper needs the pad id, presumably to mask padded positions.
        # TODO confirm against TransformerWithTR.
        self.model = TransformerWithTR(self.bart, self.tokenizer.pad_token_id)
        self.collator = DataCollatorForCharacterTransformer(self.tokenAligner)
        # NOTE(review): a call to
        # `self.model.resize_token_embeddings(self.tokenAligner)` was left
        # commented out here. Confirm whether the embeddings must be resized
        # to match the aligner's vocabulary.