Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from losses import masked_log_probs | |
| from utils import _logits, shift_targets | |
class EditableModel(nn.Module):
    """Base wrapper that adds a model-editing interface around an ``nn.Module``.

    The wrapped model is registered as the submodule ``self.model``, so its
    parameters are included in ``self.parameters()``. Subclasses implement
    :meth:`edit`; this class supplies the shared plumbing (forward pass,
    default loss functions, optimizer parameter groups, generation).

    Args:
        model: The ``nn.Module`` to wrap.
        config: Configuration object. This class reads ``config.lr`` (in
            :meth:`outer_parameters`) and passes ``config`` to
            ``shift_targets`` when computing the default loss.
        model_constructor: Stored unchanged for use by subclasses; not used
            by this base class.
    """

    def __init__(self, model, config, model_constructor):
        super().__init__()

        self.model = model
        self.config = config
        self.model_constructor = model_constructor

        def _edit_loss_fn(pred, targ, **kwargs):
            # self.config is read at call time, so later config changes apply.
            return masked_log_probs(
                pred, targ, shift=shift_targets(self.config), **kwargs
            )

        # NOTE(review): these are closures over ``self``. A local function
        # cannot be pickled, and ``copy.deepcopy`` treats functions as atomic,
        # so a deep copy's loss functions still read the original's ``config``.
        self.edit_loss_fn = _edit_loss_fn
        # By default the locality loss is the same function as the edit loss;
        # subclasses may reassign either attribute.
        self.loc_loss_fn = _edit_loss_fn

    def edit(self, batch, condition=None, detach_history=False):
        """Apply an edit described by ``batch``; subclasses must implement this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def forward(self, *inputs, **kwargs):
        """Run the wrapped model and return ``_logits`` of its output.

        All positional and keyword arguments are passed unchanged to
        ``self.model``.
        """
        # _logits is a project helper (utils); presumably it extracts the
        # logits tensor from the model's output object -- confirm there.
        return _logits(self.model(*inputs, **kwargs))

    def outer_parameters(self, grouped=False):
        """Return this module's parameters for an outer-loop optimizer.

        Args:
            grouped: If true, return a one-element list holding an optimizer
                parameter-group dict (``params`` plus ``lr=self.config.lr``);
                otherwise return a flat list of parameters.

        Returns:
            A list containing one parameter-group dict, or a list of parameters.
        """
        params = list(self.parameters())
        if grouped:
            # Materialize the parameters: a bare generator is single-use, so
            # the first consumer would leave an empty group for the next one.
            return [dict(params=params, lr=self.config.lr)]
        return params

    def generate(self, *args, **kwargs):
        """Delegate to ``self.model.generate`` with all arguments unchanged."""
        return self.model.generate(*args, **kwargs)

    def base_loss(self, input_ids, attention_masks, label_ids):
        """Placeholder hook; this base implementation does nothing and returns None."""
        pass