import torch
import torch.nn as nn
import higher
from higher.patch import monkeypatch as make_functional
import time

from editable_model import EditableModel
from utils import _logits, _inner_params
from losses import kl_loc_loss
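# Note: editable_model, utils, and losses are project-local modules; they provide
# EditableModel, _logits, _inner_params, and kl_loc_loss used below.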


class FT(EditableModel):
    """
    Fine-tuning approach. Does not require training.
    """

    def __init__(self, model, config, model_constructor, edit_loss_fn=None):
        super().__init__(model, config, model_constructor)

        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        self.locality_loss_fn = kl_loc_loss
        self.loc_ids = None
        self.loc_masks = None
        self.loc_sampler = None

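    # Loss for one edit step: when locality is enabled, the KL locality term
    # (divergence between pre- and post-edit predictions on a held-out batch)
    # plus config.ft.locality.cedit times the edit NLL; otherwise just the edit NLL.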
    def _edit_loss(self, model, p0, p_edited, edit_batch):
        output = _logits(model(**edit_batch, params=p_edited))
        loss_dict = self.edit_loss_fn(output, edit_batch["labels"])
        l_edit, acc = loss_dict["nll"], loss_dict["acc"]

        if self.config.ft.locality.enabled:
            if self.config.ft.locality.oracle:
                loc_batch = next(self.loc_sampler)["loc"]
            else:
                raise NotImplementedError

            with torch.no_grad():
                original_base_logits = _logits(model(**loc_batch, params=p0))
            edited_base_logits = _logits(model(**loc_batch, params=p_edited))
            kl_mask = loc_batch.get("decoder_attention_mask", loc_batch["attention_mask"])
            l_loc = self.locality_loss_fn(original_base_logits, edited_base_logits, mask=kl_mask)
            loss = l_loc + self.config.ft.locality.cedit * l_edit
        else:
            l_loc = torch.tensor(float('nan'))
            loss = l_edit

        return loss, l_edit, l_loc, acc

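    # Accuracy helper: token-level accuracy over shifted labels (ignoring the -100
    # padding index) for sequence outputs, or sign accuracy against binary labels
    # when the model emits a single logit.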
    def accuracy(self, output, labels):
        if output.shape[-1] != 1:
            shifted_output = output.argmax(-1)[:, :-1]
            shifted_labels = labels[:, 1:]
            to_predict = (shifted_labels != -100).sum()
            correct = (shifted_output == shifted_labels).sum()
            acc = correct.float() / to_predict.float()
        else:
            acc = ((output > 0) == labels.bool()).float().mean()

        return acc

    def _edit_status(self, step, loss, l_edit, l_loc, acc, res_p):
        return (
            f"step: {step}".ljust(14) +
            f"loss: {loss.item():.5f}".ljust(18) +
            f"l_edit: {l_edit.item():.5f}".ljust(18) +
            f"l_loc: {l_loc.item():.5f}".ljust(18) +
            f"acc: {acc.item():.2f}".ljust(14) +
            f"norm: {res_p.view(-1).norm().item():.5f}"
        )

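    # Performs an edit by gradient descent on residual parameters that are added to
    # the selected inner weights; the base model weights themselves stay fixed. When
    # config.ft.rank is set, each residual is factored as u @ v (low rank).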
    def edit(self, batch, condition=None, detach_history=False):
        edit_model = self.model.eval()
        p0 = list(edit_model.named_parameters())

        if not isinstance(edit_model, higher.patch._MonkeyPatchBase):
            edit_model = make_functional(self.model, track_higher_grads=False, in_place=True)

        packed_residuals = {}
        opt_params = []
        for n, p in _inner_params(edit_model.named_parameters(), self.config.model.inner_params):
            if self.config.ft.rank is not None:
                u = nn.Parameter(torch.randn(p.shape[0], self.config.ft.rank, device=p.device) * self.config.ft.init_std)
                v = nn.Parameter(torch.zeros(self.config.ft.rank, p.shape[1], device=p.device))
                res = [u, v]
            else:
                res = [nn.Parameter(torch.zeros_like(p, device=p.device))]

            packed_residuals[n] = res
            opt_params.extend(res)

        # One residual per edited weight, or two factors per weight in the low-rank case.
        assert len(opt_params) == len(self.config.model.inner_params) * (2 if self.config.ft.rank is not None else 1)

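        # Only the residuals are optimized; the optimizer class (e.g. "Adam") is
        # looked up by name from torch.optim using config.ft.opt.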
        OptClass = getattr(torch.optim, self.config.ft.opt)
        opt = OptClass(opt_params, lr=self.config.edit_lr)

        start_time = time.time()
        for edit_step in range(self.config.ft.max_edit_steps):
            if self.config.ft.time_limit is not None and (time.time() - start_time > self.config.ft.time_limit):
                break

            residuals = {k: v[0] @ v[1] if len(v) == 2 else v[0] for k, v in packed_residuals.items()}
            edited_params = [p if n not in residuals else p.detach() + residuals[n] for n, p in p0]
            loss, l_edit, l_loc, acc = self._edit_loss(edit_model, [p for n, p in p0], edited_params, batch)

            if self.config.ft.verbose:
                residual = list(residuals.values())[-1]
                print(self._edit_status(edit_step, loss, l_edit, l_loc, acc, residual), end="\r")

            if acc == 1.0:
                break

            for p, g in zip(opt_params, torch.autograd.grad(loss, opt_params)):
                p.grad = g
            torch.nn.utils.clip_grad_norm_(opt_params, self.config.grad_clip)
            opt.step()
            opt.zero_grad()

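        # Optionally load the edited model's weights into a freshly constructed model
        # so the returned editor does not keep the functional-patching (higher)
        # machinery or any gradient history.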
        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edit_model.state_dict())
            edit_model = new_model

        edit_model.train(self.training)

        return FT(edit_model, self.config, self.model_constructor, self.edit_loss_fn), {}
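

# ---------------------------------------------------------------------------
# Illustration only (not part of the original module): a minimal, self-contained
# sketch of the residual fine-tuning idea used by FT.edit, written in plain
# PyTorch. The weight matrix, data, rank, learning rate, and step count below are
# made up for the example; the real class instead reads them from the repository's
# config object and edits a wrapped language model through the loss above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # A frozen "pretrained" weight that we want to edit without modifying in place.
    w0 = torch.randn(16, 16)

    # A single made-up edit example: inputs x should now map (approximately) to y.
    x = torch.randn(4, 16)
    y = torch.randn(4, 16)

    # Low-rank residual u @ v, analogous to the config.ft.rank branch above.
    rank = 2
    u = nn.Parameter(torch.randn(16, rank) * 0.01)
    v = nn.Parameter(torch.zeros(rank, 16))
    opt = torch.optim.Adam([u, v], lr=1e-2)

    for step in range(200):
        # Edited weights = frozen base + low-rank residual; only u and v get gradients.
        w_edited = w0.detach() + u @ v
        loss = ((x @ w_edited.t() - y) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    print(f"final edit loss after {step + 1} steps: {loss.item():.4f}")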