# Adapted from https://github.com/nicola-decao/KnowledgeEditor/blob/main/src/models/one_shot_learner.py
"""
@inproceedings{decao2020editing,
    title={Editing Factual Knowledge in Language Models},
    author={Nicola De Cao and Wilker Aziz and Ivan Titov},
    booktitle={arXiv pre-print 2104.08164},
    url={https://arxiv.org/abs/2104.08164},
    year={2021},
}
"""
import torch
import copy
import higher
from higher.patch import monkeypatch as make_functional
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
import logging

from editable_model import EditableModel
from utils import _logits, _inner_params
from models import BertClassifier
from transformers import BartForConditionalGeneration, T5ForConditionalGeneration

LOG = logging.getLogger(__name__)
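
# KE wraps a base model together with a OneShotLearner editor. The editor is a
# hypernetwork that, given the tokenized edit (or an optional condition sequence) and
# the gradients of the edit loss w.r.t. config.model.inner_params, predicts an
# additive update for each of those parameters. If no editor is passed in, one is
# built using the wrapped model's input embeddings as the conditioner's embedding init.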
class KE(EditableModel):
    def __init__(self, model, config, model_constructor, editor=None):
        super().__init__(model, config, model_constructor)

        if editor is None:
            if isinstance(model, BertClassifier):
                embedding = model.model.embeddings.word_embeddings.weight.data
            elif isinstance(model, BartForConditionalGeneration):
                embedding = model.model.shared.weight.data
            elif isinstance(model, T5ForConditionalGeneration):
                embedding = model.shared.weight.data
            else:
                embedding = model.transformer.wte.weight.data

            editor = OneShotLearner(model, vocab_dim=model.config.vocab_size,
                                    include_set=config.model.inner_params,
                                    embedding_dim=embedding.shape[-1],
                                    embedding_init=embedding.clone().to(torch.float32),
                                    max_scale=1)
        self.editor = editor

    def outer_parameters(self, grouped=False):
        if grouped:
            return [
                dict(params=self.editor.parameters(), lr=self.config.lr)
            ]
        else:
            return list(self.editor.parameters())
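
    # Checkpointing: only the editor's weights (plus the wrapped model's config, kept
    # for a sanity check at load time) are saved; the base model's parameters are
    # stripped from the state dict and are expected to be restored separately.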
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
        model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys()  # Remove model params
        for k in model_keys:
            del state_dict[f"model.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        config = state_dict["model_config"]
        del state_dict["model_config"]
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        res = super().load_state_dict(state_dict, False)
        # We should only have missing keys for the model, and no unexpected keys
        assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res
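
    # One edit step: run the base model on the edit example, compute the edit NLL,
    # take gradients w.r.t. the designated inner parameters, and let the editor map
    # (condition text, gradients) to an additive update for each of those parameters.
    # The updates are applied to a functional (monkeypatched) view of the model via
    # update_params, so gradients can flow back into the editor; new_param transposes
    # a predicted update whose leading dimension doesn't match the stored weight
    # (e.g. GPT-2's transposed Conv1D weights).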
    def edit(self, batch, condition, detach_history=False):
        outputs = _logits(self.model(**batch))
        loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]

        names = set([n for n, p in self.model.named_parameters()])
        pset = set(self.config.model.inner_params)
        for p in pset:
            assert p in names, f"inner param {p} not in model"

        grads = torch.autograd.grad(
            loss,
            [p for (n, p) in _inner_params(self.model.named_parameters(), self.config.model.inner_params)]
        )

        params_dict = self.editor(
            condition["input_ids"] if condition is not None else batch["input_ids"],
            condition["attention_mask"] if condition is not None else batch["attention_mask"],
            {n: g.to(torch.float32) for (n, g) in zip(self.config.model.inner_params, grads)},
        )

        edited_model = self.model
        if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
            edited_model = make_functional(edited_model, in_place=True)

        def new_param(n, p):
            if n not in params_dict:
                return p

            if p.shape[0] == params_dict[n].shape[0]:
                return p + params_dict[n]
            else:
                return p + params_dict[n].T

        edited_model.update_params(
            [new_param(n, p) for (n, p) in edited_model.named_parameters()]
        )

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edited_model.state_dict())
            edited_model = new_model

        return KE(edited_model, self.config, self.model_constructor, editor=self.editor), {}
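
# ConditionedParameter predicts the update for a single weight tensor. For a 2-D
# weight it emits two row/column vector pairs plus a scalar gate; each pair's outer
# product (with a softmax over the row factor) gives a rank-1 matrix with the same
# shape as the weight, one acting as an elementwise scale on the gradient (a) and one
# as an additive shift (b). The output is max_scale * sigmoid(gate) * (grad * a + b).
# For 1-D parameters (e.g. biases), a and b are predicted directly.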
class ConditionedParameter(torch.nn.Module):
    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        super().__init__()
        self.parameter_shape = parameter.shape

        if len(self.parameter_shape) == 2:
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(
                        hidden_dim, 2 * (parameter.shape[0] + parameter.shape[1]) + 1
                    )
                ),
            )
        elif len(self.parameter_shape) == 1:
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(hidden_dim, 2 * parameter.shape[0] + 1)
                ),
            )
        else:
            raise RuntimeError()

        self.max_scale = max_scale
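
    # Only a single condition vector at a time is supported (batch size 1). The
    # conditioner output is split into the row/column factors for a and b plus a
    # one-element gate; the final transpose branch handles gradients whose layout is
    # transposed relative to the predicted factors.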
    def forward(self, inputs, grad):
        if inputs.shape[0] > 1:
            raise RuntimeError("Can only condition on batches of size 1")

        if len(self.parameter_shape) == 2:
            (
                conditioner_cola,
                conditioner_rowa,
                conditioner_colb,
                conditioner_rowb,
                conditioner_norm,
            ) = self.conditioners(inputs).split(
                [
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    1,
                ],
                dim=-1,
            )

            a = conditioner_rowa.softmax(-1).T @ conditioner_cola
            b = conditioner_rowb.softmax(-1).T @ conditioner_colb

        elif len(self.parameter_shape) == 1:
            a, b, conditioner_norm = self.conditioners(inputs).split(
                [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1
            )
        else:
            raise RuntimeError()

        if a.squeeze().shape[0] != grad.shape[0]:
            return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a.squeeze().T + b.squeeze().T)
        else:
            return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a.squeeze() + b.squeeze())
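
# LSTMConditioner encodes the edit (or condition) token sequence into a single
# condition vector: token embeddings (optionally initialized from the wrapped model's
# input embeddings) are fed to a bidirectional LSTM, pooled to a fixed-size vector by
# AllenNLP's PytorchSeq2VecWrapper, and projected to output_dim with a Tanh feedforward.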
class LSTMConditioner(torch.nn.Module):
    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_dim,
            embedding_dim=embedding_dim,
            padding_idx=0,
            _weight=embedding_init,
        )
        self.lstm = PytorchSeq2VecWrapper(
            torch.nn.LSTM(
                input_size=embedding_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=True,
                batch_first=True,
            )
        )
        self.linear = FeedForward(
            input_dim=hidden_dim * 2,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        return self.linear(self.lstm(self.embedding(inputs), masks))
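
# OneShotLearner is the full editor hypernetwork: one ConditionedParameter module per
# weight named in include_set, all sharing a single LSTMConditioner. Parameter names
# are mangled (dots replaced with underscores) so they can serve as ModuleDict keys.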
class OneShotLearner(torch.nn.Module):
    def __init__(
        self,
        model,
        vocab_dim,
        embedding_dim=768,
        hidden_dim=512,
        condition_dim=768,
        include_set={},
        max_scale=1e-3,
        embedding_init=None,
    ):
        super().__init__()

        self.param2conditioner_map = {
            n: "{}_conditioner".format(n).replace(".", "_")
            for n, p in model.named_parameters()
            if n in include_set
        }

        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[n]: ConditionedParameter(
                    p,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for n, p in model.named_parameters()
                if n in include_set
            }
        )

        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )
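
    # Returns {parameter_name: predicted additive update}, conditioning every
    # ConditionedParameter on the same encoded sequence and on that parameter's gradient.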
    def forward(self, inputs, masks, grads=None):
        condition = self.condition(inputs, masks)
        return {
            p: self.conditioners[self.param2conditioner_map[p]](
                condition,
                grad=grads[p] if grads else None,
            )
            for p, c in self.param2conditioner_map.items()
        }
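
# Minimal smoke test (assumes a CUDA device and network access to download gpt2):
# build a KE wrapper around GPT-2, apply one edit on a toy batch, and check that the
# original model's outputs are unchanged while the edited copy's weights differ.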
if __name__ == '__main__':
    import transformers
    import types

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace()
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]

    efk = KE(model, config, lambda: copy.deepcopy(model)).cuda()

    x = torch.arange(20).view(1, 20).cuda() + 1000
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}

    orig_logits = efk(x).logits
    edited, _ = efk.edit(batch, condition=None)
    post_logits = efk(x).logits

    assert torch.allclose(orig_logits, post_logits)

    orig_param = [p for (n, p) in efk.model.named_parameters() if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]

    print((orig_param - edited_param).abs().max())

    edited.eval()
    print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])

    edited2, _ = edited.edit(batch, condition=None)
    print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)

    import pdb; pdb.set_trace()