import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import transformers
import higher
import logging

from higher.patch import monkeypatch as make_functional
from collections import defaultdict

from editable_model import EditableModel
from hooks import hook_model
import nn as local_nn
from utils import _logits, _inner_params

LOG = logging.getLogger(__name__)
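

# update_counter implements Welford's online update for running statistics: m is the
# running mean and s accumulates the sum of squared deviations, so s / (k - 1) gives
# the unbiased sample variance after k observations.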
def update_counter(x, m, s, k):
    new_m = m + (x - m) / k
    new_s = s + (x - m) * (x - new_m)

    return new_m, new_s
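

# GradientTransform is the editor network at the heart of MEND: it takes the rank-1
# factors of a layer's gradient (the cached layer input u and the gradient of the
# loss w.r.t. the layer output, delta) and maps them through small MLPs to produce
# the factors of the transformed "pseudo-gradient" used to edit that layer.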
class GradientTransform(nn.Module):
    def __init__(self, x_dim: int, delta_dim: int, cfg, n_modes=None):
        super().__init__()

        self.x_dim = x_dim
        self.delta_dim = delta_dim
        self.cfg = cfg
        if cfg.combine and (cfg.one_sided or cfg.x_only or cfg.delta_only):
            raise ValueError("cfg.combine cannot be used with one-sided MEND variants")

        self.norm_init = False
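        # Running statistics for normalizing u and delta, initialized to NaN and
        # filled in online from the first training samples (see update_counter).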
        self.register_buffer("u_mean", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_mean", torch.full((delta_dim,), float("nan")))
        self.register_buffer("u_std", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_std", torch.full((delta_dim,), float("nan")))
        self.register_buffer("u_s", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_s", torch.full((delta_dim,), float("nan")))
        self.register_buffer("k", torch.full((1,), float("nan")))

        MlpClass = getattr(local_nn, cfg.mlp_class)
        LOG.info(f"Building Gradient Transform with MLP class {MlpClass}")

        def delta_net():
            return MlpClass(delta_dim, delta_dim, delta_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def x_net():
            return MlpClass(x_dim, x_dim, x_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def combined_net():
            return MlpClass(delta_dim + x_dim, delta_dim + x_dim, (delta_dim + x_dim) * 2,
                            cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def ID():
            return lambda x, mode=None: x
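
        # Choose the editor architecture: a single MLP over the concatenated factors
        # (combine), an MLP on only one factor with the identity on the other
        # (one_sided / x_only / delta_only), or the default of separate MLPs for u and delta.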
        if cfg.combine:
            self.mlp = combined_net()
        elif cfg.one_sided:
            if x_dim > delta_dim:
                self.mlp1, self.mlp2 = ID(), delta_net()
            else:
                self.mlp1, self.mlp2 = x_net(), ID()
        elif cfg.x_only:
            self.mlp1, self.mlp2 = x_net(), ID()
        elif cfg.delta_only:
            self.mlp1, self.mlp2 = ID(), delta_net()
        else:
            self.mlp1, self.mlp2 = x_net(), delta_net()
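
    # u: cached layer inputs, v: gradients of the loss w.r.t. the layer outputs.
    # Both are flattened to (tokens, dim); param_idx selects the mode when a single
    # transform is shared across several parameters of the same shape.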
    def forward(self, u, v, param_idx=None):
        u, v = u.to(torch.float32), v.to(torch.float32)

        u_ = u.view(-1, u.shape[-1])
        v_ = v.view(-1, v.shape[-1])

        nz_mask = (u_ != 0).any(-1) * (v_ != 0).any(-1)  # Skip batch elements with zero grad
        u_ = u_[nz_mask]
        v_ = v_[nz_mask]
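
        # In training mode, update the running mean/variance of u and delta one
        # sample at a time (Welford update) before normalizing.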
        if self.training:
            for idx in range(u_.shape[0]):
                if not self.norm_init:
                    self.u_mean = u_[idx].clone().detach()
                    self.v_mean = v_[idx].clone().detach()
                    self.u_s.zero_()
                    self.v_s.zero_()
                    self.k[:] = 1
                    self.norm_init = True
                else:
                    self.k += 1
                    self.u_mean, self.u_s = update_counter(u_[idx], self.u_mean, self.u_s, self.k)
                    self.v_mean, self.v_s = update_counter(v_[idx], self.v_mean, self.v_s, self.k)

            if self.k < 2:
                raise RuntimeError(f"Can't perform normalization with only {self.k} samples so far")
            self.u_std = (self.u_s / (self.k - 1)) ** 0.5
            self.v_std = (self.v_s / (self.k - 1)) ** 0.5
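
        # Normalize with the running statistics (small epsilon for numerical stability),
        # or pass the raw factors through if normalization is disabled.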
        if self.cfg.norm:
            u_input = (u_ - self.u_mean) / (self.u_std + 1e-7)
            v_input = (v_ - self.v_mean) / (self.v_std + 1e-7)
        else:
            u_input = u_
            v_input = v_

        if self.cfg.combine:
            output = self.mlp(torch.cat((u_input, v_input), -1), mode=param_idx)
            out1, out2 = output.split([u.shape[-1], v.shape[-1]], -1)
            return out1, out2
        else:
            return self.mlp1(u_input, mode=param_idx), self.mlp2(v_input, mode=param_idx)
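

# A minimal usage sketch for GradientTransform (hypothetical shapes; cfg must supply
# mlp_class, n_hidden, init, act, rank, norm, and the combine/one_sided/x_only/
# delta_only flags used above):
#
#   gt = GradientTransform(x_dim=768, delta_dim=3072, cfg=gtn_cfg)
#   u = torch.randn(16, 768)     # cached layer inputs
#   v = torch.randn(16, 3072)    # grads of the loss w.r.t. the layer outputs
#   u_t, v_t = gt(u, v)          # transformed pseudo-gradient factors
#
# MEND wraps a base model with one GradientTransform per edited weight matrix (or one
# per weight shape when config.gtn.shared is set) plus learned per-layer edit learning
# rates, and applies rank-1 pseudo-gradient updates to produce an edited model.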
class MEND(EditableModel):
    def get_shape(self, p):
        # We need to (annoyingly) flip the shapes since OpenAI gpt2 uses convs instead of linear
        return p.shape if isinstance(self.model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])

    def __init__(self, model, config, model_constructor, gtn=None, edit_lrs=None):
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
        self.edit_lrs = edit_lrs
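
        # hook_model registers hooks that cache each inner parameter's layer input
        # (p.__x__) and output gradient (p.__delta__) during the forward/backward pass.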
        if not hasattr(self.model, "handles"):
            hook_model(self.model, self.config.model.inner_params)
            LOG.info(f"Hooked {len(self.model.handles)//2} modules")

        if config.gtn.shared:
            shape_dict = defaultdict(list)
            for n, p in _inner_params(model.named_parameters(), self.config.model.inner_params):
                shape_dict[self.get_shape(p)].append(n)
            self.shape_dict = shape_dict

        if gtn is None:
            if not config.gtn.shared:
                self.gtn = nn.ModuleDict({
                    n.replace(".", "#"): GradientTransform(*self.get_shape(p), config.gtn)
                    for (n, p) in _inner_params(model.named_parameters(), self.config.model.inner_params)
                })
            else:
                self.gtn = nn.ModuleDict({
                    str(tuple(s)): GradientTransform(*s, config.gtn, len(shape_dict[s]))
                    for s in shape_dict.keys()
                })
        else:
            self.gtn = gtn
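
    # The base model's weights are excluded from the state dict; only the editor
    # networks, the edit learning rates, and the model config are saved/restored.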
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
        model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys()  # Remove model params
        for k in model_keys:
            del state_dict[f"model.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        config = state_dict["model_config"]
        del state_dict["model_config"]
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        res = super().load_state_dict(state_dict, False)
        # We should only have missing keys for the model, and no unexpected keys
        assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def outer_parameters(self, grouped=False):
        if grouped:
            return [
                dict(params=list(self.gtn.parameters()), lr=self.config.lr),
                dict(params=[self.edit_lrs], lr=self.config.lr_lr)
            ]
        else:
            return list(self.gtn.parameters()) + [self.edit_lrs]
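
    # edit() performs one model edit:
    #   1. run the edit batch through the model and backprop the edit loss so the
    #      hooks populate p.__x__ and p.__delta__ for each inner parameter;
    #   2. pass those factors through the GradientTransforms and combine the outputs
    #      with an outer product (summed over tokens) to form pseudo-gradients;
    #   3. scale by the learned per-layer edit_lrs and apply the updates to a
    #      functional (higher-patched) copy of the model;
    #   4. return a new MEND wrapper around the edited model plus gradient diagnostics.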
    def edit(self, batch, condition=None, detach_history=False):
        outputs = _logits(self.model(**batch))
        loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]

        names = set([n for n, p in self.model.named_parameters()])
        pset = set(self.config.model.inner_params)
        for p in pset:
            assert p in names, f"inner param {p} not in model"

        loss.backward()

        if self.config.gtn.shared:
            param_idx = lambda n, p: self.shape_dict[self.get_shape(p)].index(n) if self.config.gtn.shared else None  # noqa: E731
            transformed_factors = {
                n: self.gtn[str(tuple(self.get_shape(p)))](p.__x__, p.__delta__, param_idx(n, p))
                for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params)
            }
        else:
            transformed_factors = {
                n: self.gtn[n.replace(".", "#")](p.__x__, p.__delta__)
                for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params)
            }

        # Should be bi,bj->ji for nn.Linear, but [annoying] GPT2 uses Conv1d instead...
        if isinstance(self.model, transformers.GPT2LMHeadModel):
            targ = "ij"
        else:
            targ = "ji"
        mean_grads = {
            n: torch.einsum(f"bi,bj->{targ}", x, delta)
            for n, (x, delta) in transformed_factors.items()
        }

        info_dict = {}
        idx = 0
        for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params):
            info_dict[f"grad/true_mag{idx}"] = p.grad.norm(2).item()
            info_dict[f"grad/pseudo_mag{idx}"] = mean_grads[n].norm(2).item()
            info_dict[f"grad/true_std{idx}"] = p.grad.std().item()
            info_dict[f"grad/pseudo_std{idx}"] = mean_grads[n].std().item()
            info_dict[f"grad/diff{idx}"] = (p.grad - mean_grads[n]).norm(2).item()
            info_dict[f"grad/cos{idx}"] = F.cosine_similarity(p.grad.reshape(-1), mean_grads[n].reshape(-1), dim=0).item()
            idx += 1

        self.model.zero_grad()

        assert len(self.edit_lrs) == len(list(mean_grads.items()))
        updates = {n: lr * g for lr, (n, g) in zip(self.edit_lrs, mean_grads.items())}

        edited_model = self.model
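        # Patch the model with higher so its parameters can be replaced functionally;
        # this keeps the edited parameters in the computation graph, so the edit is
        # differentiable w.r.t. the editor networks and edit_lrs during meta-training.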
        if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
            edited_model = make_functional(edited_model, in_place=True)

        new_params = []
        for n, p in edited_model.named_parameters():
            if n in pset:
                if self.config.gtn.descent:
                    new_params.append(p - updates[n])
                else:
                    new_params.append(p + updates[n])
            else:
                new_params.append(p)

        edited_model.update_params(new_params)

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edited_model.state_dict())
            edited_model = new_model

        return MEND(edited_model, self.config, self.model_constructor, self.gtn, edit_lrs=self.edit_lrs), info_dict
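

# Quick debugging / smoke-test harness. Note that the gtn config built below only sets
# n_hidden; GradientTransform and MEND also read fields such as mlp_class, init, act,
# rank, norm, shared, descent, and the variant flags, so those must be filled in for
# this block to actually run.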
if __name__ == '__main__':
    import types

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace()
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]
    config.edit_lr = 0.0001

    config.gtn = types.SimpleNamespace()
    config.gtn.n_hidden = 1  # Incomplete: see the note above for the remaining gtn fields

    gtn = MEND(model, config, lambda: copy.deepcopy(model)).cuda()
    # torch.save(gtn.state_dict(), "test_state.pt")
    import pdb; pdb.set_trace()
    gtn.load_state_dict(torch.load("test_state.pt"))

    x = torch.arange(20).view(1, 20).cuda() + 1000
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}

    orig_logits = _logits(gtn(x))
    edited, _ = gtn.edit(batch)  # edit() returns (edited MEND wrapper, info dict)
    post_logits = _logits(gtn(x))
    assert torch.allclose(orig_logits, post_logits)  # Editing must not modify the original model

    orig_param = [p for (n, p) in gtn.model.named_parameters() if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]

    LOG.info((orig_param - edited_param).abs().max())

    edited.eval()
    LOG.info(f"{gtn(x, labels=x).loss} {edited(x, labels=x).loss} {edited.edit_loss_fn(edited(x).logits, x)['nll']}")
    edited2, _ = edited.edit(batch)
    LOG.info(f"{gtn(x, labels=x).loss} {edited(x, labels=x).loss} {edited2(x, labels=x).loss}")