| import torch | |
class BiEncoderCollator:
    """Collate a list of per-example feature dicts into one batch dict.

    Each feature dict must provide the token ids and attention masks for
    two texts, plus a ``'labels'`` entry. The four tensor fields are
    combined with ``torch.stack``, which adds a new leading (batch)
    dimension. The labels are gathered with ``torch.tensor`` into a float32
    tensor.
    """

    def __call__(self, features):
        """Build a batch from ``features``.

        Args:
            features: sequence of dicts, one per example. For each stacked
                key, the tensors must share the same shape (a ``torch.stack``
                requirement). Presumably they are pre-padded to a common
                length upstream; TODO confirm against the dataset/tokenizer.

        Returns:
            dict mapping each tensor key to its stacked tensor, plus
            ``'labels'`` as a ``torch.float`` tensor.
        """
        # Fields stacked element-wise. The order matches the output dict's key order.
        tensor_fields = (
            'input_ids_text1',
            'attention_mask_text1',
            'input_ids_text2',
            'attention_mask_text2',
        )
        collated = {
            name: torch.stack([example[name] for example in features])
            for name in tensor_fields
        }
        label_values = [example['labels'] for example in features]
        collated['labels'] = torch.tensor(label_values, dtype=torch.float)
        return collated