from abc import ABC, abstractmethod

from models.tokenizer import TokenAligner


class PTCollator(ABC):
    """Base collator: dispatches a raw dataloader batch to a split-specific handler."""

    def __init__(self, tokenAligner: TokenAligner):
        self.tokenAligner = tokenAligner

    def collate(self, dataloader_batch, type="train") -> dict:
        if type == "train":
            return self.collate_train(dataloader_batch)
        elif type == "test":
            return self.collate_test(dataloader_batch)
        elif type == "correct":
            return self.collate_correct(dataloader_batch)
        raise ValueError(f"Unknown collate type: {type!r}")

    @abstractmethod
    def collate_train(self, dataloader_batch):
        ...

    @abstractmethod
    def collate_test(self, dataloader_batch):
        ...

    @abstractmethod
    def collate_correct(self, dataloader_batch):
        ...
class DataCollatorForCharacterTransformer(PTCollator):
    def __init__(self, tokenAligner: TokenAligner):
        super().__init__(tokenAligner)

    def collate_train(self, dataloader_batch):
        # Each sample is a (label_text, noised_text) pair.
        noised, labels = [], []
        for sample in dataloader_batch:
            labels.append(sample[0])
            noised.append(sample[1])
        batch_srcs, batch_tgts, batch_lengths, batch_attention_masks = \
            self.tokenAligner.tokenize_for_transformer_with_tokenization(noised, labels)
        return {
            'batch_src': batch_srcs,
            'batch_tgt': batch_tgts,
            'attn_masks': batch_attention_masks,
            'lengths': batch_lengths,
        }
    def collate_test(self, dataloader_batch):
        # Labels are kept as raw text for evaluation; only the noised side is tokenized.
        noised, labels = [], []
        for sample in dataloader_batch:
            labels.append(sample[0])
            noised.append(sample[1])
        batch_srcs, batch_attention_masks = \
            self.tokenAligner.tokenize_for_transformer_with_tokenization(noised, None)
        return {
            'batch_src': batch_srcs,
            'noised_texts': noised,
            'label_texts': labels,
            'attn_masks': batch_attention_masks,
        }
    def collate_correct(self, dataloader_batch):
        # Inference-only path: no labels are available, so only the noised texts are collected.
        noised = []
        for sample in dataloader_batch:
            noised.append(sample[1])
        batch_srcs, batch_attention_masks = \
            self.tokenAligner.tokenize_for_transformer_with_tokenization(noised)
        return {
            'batch_src': batch_srcs,
            'noised_texts': noised,
            'attn_masks': batch_attention_masks,
        }
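
A minimal usage sketch of how this collator would plug into a torch.utils.data.DataLoader. Assumptions: the dataset yields (label_text, noised_text) pairs, matching the indexing in the collators above, and TokenAligner accepts default construction; its actual constructor arguments are not shown in this file.

from functools import partial

from torch.utils.data import DataLoader

from models.tokenizer import TokenAligner

tokenAligner = TokenAligner()  # assumed default construction; real arguments may differ
collator = DataCollatorForCharacterTransformer(tokenAligner)

# A plain list works as a map-style dataset: each item is (label_text, noised_text).
train_pairs = [("the cat sat", "teh cat sta"), ("hello world", "helo wrold")]

train_loader = DataLoader(
    train_pairs,
    batch_size=2,
    shuffle=True,
    # partial binds the split, so the DataLoader's collate_fn(batch) call
    # becomes collator.collate(batch, type="train").
    collate_fn=partial(collator.collate, type="train"),
)

for batch in train_loader:
    # For the "train" split, batch has keys: batch_src, batch_tgt, attn_masks, lengths.
    print(batch.keys())

Binding the split with partial keeps one collator instance reusable across train, test, and correct loaders; only the collate_fn argument changes per DataLoader.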