Spaces: Runtime error

Charles Lin committed · Commit 8335d0c
1 Parent(s): a9853a7

All algs except KE working.
algs/lu.py CHANGED

@@ -15,56 +15,45 @@ class LU(EditableModel):
     def __init__(self, model, config, model_constructor, memory=None):
         super().__init__(model, config, model_constructor)
 
+        if "t5" not in self.config.model.name.lower():
+            raise NotImplementedError
         self.memory = memory
 
-    def forward(self, *inputs, **kwargs):
-        ...
-        if self.memory is not None:
-            for i, encoder_state in enumerate(encoder_states):
-                if "gpt2" in self.config.model.name.lower():
-                    # NOTE: broken
-                    memory_prefixes, memory_labels = self.memory
-                    prefix_means = encoder_state.cumsum(0).detach() / torch.arange(1, encoder_state.shape[0] + 1, device=encoder_state.device).view(-1, 1)
-                    dist_mat = (prefix_means.unsqueeze(1) - memory_prefixes.unsqueeze(0)).norm(2, dim=-1)
-                    ...
-                    closest_v = memory_labels[closest_idx]
-
-                if closest_dist < self.config.lu.threshold:
-                    output[i] = torch.zeros((1, kwargs['labels'].shape[1], output.shape[2]), device=output.device)
-                    for j, idx in enumerate(closest_v):
-                        if j >= output.shape[1]:
-                            break
-                        output[i, j, idx] = self.config.lu.onehot_logit
-                    if "t5" not in self.config.model.name.lower():
-                        # T5 does not shift targets in the loss
-                        output[i] = output[i].roll(-1, -2)
-                else:
-                    avg_encoder_state = encoder_state.detach().mean(0)
-                    memory_keys, memory_labels = self.memory
-                    dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
-                    closest_dist = dists.min()
-                    closest_idx = dists.argmin()
-                    closest_v = memory_labels[closest_idx]
+    def lookup_replace(self, output, encoder_states):
+        for i, encoder_state in enumerate(encoder_states):
+            avg_encoder_state = encoder_state.detach().mean(0)
+            memory_keys, memory_labels = self.memory
+            dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
+            closest_dist = dists.min()
+            closest_idx = dists.argmin()
+            closest_v = memory_labels[closest_idx]
 
+            if closest_dist < self.config.lu.threshold:
+                output[i] = torch.zeros((1, output.shape[1], output.shape[2]), device=output.device)
+                for j, idx in enumerate(closest_v):
+                    if j >= output.shape[1]:
+                        break
+                    output[i, j, idx] = self.config.lu.onehot_logit
+                if "t5" not in self.config.model.name.lower():
+                    # T5 does not shift targets in the loss
+                    output[i] = output[i].roll(-1, -2)
+        return output
+
+    def generate(self, *inputs, **kwargs):
+        model_output = self.model.generate(*inputs, **kwargs, output_hidden_states=True,
+                                           output_scores=True, return_dict_in_generate=True)
+        encoder_states = _last_encoder_state(model_output)
+        output = _logits(model_output)
+        if self.memory is not None:
+            output = self.lookup_replace(output, encoder_states)
+        return output.argmax(-1)
+
+    def forward(self, *inputs, **kwargs):
+        model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
+        encoder_states = _last_encoder_state(model_output)
+        output = _logits(model_output)
+        if self.memory is not None:
+            output = self.lookup_replace(output, encoder_states)
         return output
 
     def edit(self, batch, condition=None, detach_history=False):

@@ -77,14 +66,9 @@ class LU(EditableModel):
         memory_keys = []
         memory_labels = []
         for encoder_state, label in zip(encoder_states, batch["labels"]):
-            ...
-                memory = (avg_encoder_states, label[-10:])
-            else:
-                avg_encoder_state = encoder_state.detach().mean(0)
-                memory_keys.append(avg_encoder_state)
-                memory_labels.append(label)
+            avg_encoder_state = encoder_state.detach().mean(0)
+            memory_keys.append(avg_encoder_state)
+            memory_labels.append(label)
 
         memory = (torch.stack(memory_keys), torch.stack(memory_labels))
         return LU(self.model.eval(), self.config, self.model_constructor, memory), {}
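Taken together, the lu.py changes replace the old inline (and partly broken) lookup with a reusable lookup_replace pass: cache keys are mean-pooled encoder states, cache values are label token ids, and any query whose nearest key falls within lu.threshold has its logits overwritten by a one-hot sequence. A minimal self-contained sketch of that mechanism, with illustrative shapes and constants rather than the repo's exact API:

    import torch

    hidden, vocab = 8, 50
    threshold, onehot_logit = 2.75, 1.0  # mirrors lu_config in config.py

    # Cache built at edit time: one mean-pooled encoder state per edit.
    memory_keys = torch.randn(3, hidden)             # (n_edits, hidden)
    memory_labels = torch.randint(0, vocab, (3, 5))  # (n_edits, label_len)

    def lookup_replace(output, encoder_state):
        # output: (seq_len, vocab) logits for a single example.
        query = encoder_state.detach().mean(0)           # (hidden,)
        dists = torch.norm(query - memory_keys, dim=-1)  # (n_edits,)
        closest_dist, closest_idx = dists.min(0)
        if closest_dist < threshold:  # close enough to a cached edit: replay its label
            output = torch.zeros_like(output)
            for j, idx in enumerate(memory_labels[closest_idx]):
                if j >= output.shape[0]:
                    break
                output[j, idx] = onehot_logit
        return output

    new_logits = lookup_replace(torch.randn(5, vocab), torch.randn(7, hidden))
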
app.py CHANGED

@@ -8,6 +8,7 @@ from torch.cuda import is_available as use_cuda
 import algs
 import config
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import utils
 
 
 EDIT_ALGS = [
@@ -19,6 +20,26 @@ EDIT_ALGS = [
     "LU: Lookup Cache",
 ]
 
+def get_alg_class(alg_abbrv):
+    alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
+    alg_class = getattr(alg_module, alg_abbrv.upper())
+    return alg_class
+
+def load_editable_model(alg_abbrv):
+    alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
+    alg_class = getattr(alg_module, alg_abbrv.upper())
+    st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
+    with st.spinner('Loading model...'):
+        st.session_state.editable_model = alg_class(
+            st.session_state.model,
+            st.session_state.config,
+            lambda: copy.deepcopy(st.session_state.model),
+        ).eval()
+    if "archive" in st.session_state.config:
+        archive, st.session_state.config.archive = utils.load_archive(str(st.session_state.config.archive))
+        print(f"Loading archive from {st.session_state.config.archive}")
+        st.session_state.editable_model.load_state_dict(archive["model"])
+
 def generate(ids):
     output_ids = st.session_state.editable_model.generate(input_ids=ids, max_new_tokens=20, min_length=1,
                                                           num_return_sequences=1, num_beams=3)
@@ -30,15 +51,7 @@ def reset():
 
     selected_alg = st.session_state.alg_selector
     alg_abbrv = selected_alg[:selected_alg.index(":")]
-    alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
-    alg_class = getattr(alg_module, alg_abbrv.upper())
-    st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
-    with st.spinner('Loading model...'):
-        st.session_state.editable_model = alg_class(
-            st.session_state.model,
-            st.session_state.config,
-            lambda: copy.deepcopy(st.session_state.model),
-        ).eval()
+    load_editable_model(alg_abbrv)
 
 def apply_edit():
     st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
@@ -67,12 +80,13 @@ if "init" not in st.session_state:
     st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
     st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
     st.session_state.init = True
-    st.session_state.
-    st.session_state.device = "cuda" if use_cuda() else "cpu"
+    st.session_state.device = "cpu"  # "cuda" if use_cuda() else "cpu"
     with st.spinner('Loading model...'):
         st.session_state.tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
         st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").to(st.session_state.device).eval()
-
+    # There is a "Loading model..." spinner in load_editable_model
+    alg_abbrv = "MEND"  # Default initial alg of dropdown selector
+    load_editable_model(alg_abbrv)
 
 ########################
 #### Interface code ####
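The two new helpers rely on a naming convention that maps the dropdown string to a module and class; a small sketch of that dispatch, derived directly from the code above (the selected string is the only input):

    import importlib

    # "<ABBRV>: <Description>"  ->  module algs.<abbrv>  ->  class <ABBRV>
    selected_alg = "LU: Lookup Cache"
    alg_abbrv = selected_alg[:selected_alg.index(":")]                 # "LU"
    alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")  # algs/lu.py
    alg_class = getattr(alg_module, alg_abbrv.upper())                 # class LU

Note that load_editable_model re-derives alg_module and alg_class instead of calling get_alg_class; folding those two lines into a get_alg_class call would remove the duplication.
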
config.py CHANGED

@@ -21,7 +21,7 @@ model_config = {
 }
 
 ft_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "edit_lr": 5e-6,
     "train_base": False,
     "grad_clip": 100,
@@ -43,7 +43,7 @@ ft_config = OmegaConf.create({
 })
 
 lu_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "lu": {
         "threshold": 2.75,
         "onehot_logit": 1,
@@ -52,14 +52,14 @@ lu_config = OmegaConf.create({
 })
 
 ke_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "train_base": False,
     "lr": 1e-5,
     "model": model_config,
 })
 
 enn_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "lr": 1e-5,
     "edit_lr": 1e-2,
     "lr_lr": 1e-3,
@@ -72,10 +72,11 @@ enn_config = OmegaConf.create({
         "n_edit_steps": 1,
     },
     "model": model_config,
+    "archive": 8684705655,  # "/iris/u/clin/code/efk/outputs/2022-02-09_05-48-20_8684705655/models/t5-large-ssm-nq.2022-02-09_05-48-20_8684705655",
 })
 
 mend_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",
     "lr": 1e-6,
     "edit_lr": 1e-4,
     "lr_lr": 1e-4,
@@ -99,10 +100,11 @@ mend_config = OmegaConf.create({
         "descent": False,
     },
     "model": model_config,
+    "archive": 5940349945,  # "/iris/u/clin/code/efk/outputs/2022-02-09_11-47-28_5940349945/models/t5-large-ssm-nq.2022-02-09_11-47-28_5940349945",
 })
 
 serac_config = OmegaConf.create({
-    "device": "cuda" if use_cuda() else "cpu",
+    "device": "cpu",  # "device": "cuda" if use_cuda() else "cpu",
     "lr": 1e-5,
     "edit_lr": 1e-2,
     "lr_lr": 0,
@@ -128,4 +130,5 @@ serac_config = OmegaConf.create({
         "cache_embeds": True,
     },
     "model": model_config,
+    "archive": 4719776130,  # "/iris/u/clin/code/efk/outputs/2022-02-09_14-05-56_4719776130/models/t5-large-ssm-nq.2022-02-09_14-05-56_4719776130",
 })
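The new "archive" entries store a run id, with the original checkpoint path preserved in the trailing comment; load_editable_model in app.py keys off the presence of "archive" to decide whether a trained editor checkpoint must be fetched and loaded. A minimal sketch of that convention (illustrative values):

    from omegaconf import OmegaConf

    lu_config = OmegaConf.create({"device": "cpu"})  # no trained editor needed
    mend_config = OmegaConf.create({"device": "cpu",
                                    "archive": 5940349945})

    for name, cfg in [("lu", lu_config), ("mend", mend_config)]:
        print(name, "archive" in cfg)  # lu False, mend True
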
utils.py CHANGED

@@ -156,12 +156,18 @@ def safe_backward(loss, parameters, accumulate=1, allow_unused=False, backward=False):
 
 
 def _logits(x):
-    ...
+    if hasattr(x, "logits"):
+        return x.logits
+    elif hasattr(x, "scores"):
+        return torch.cat(x.scores).unsqueeze(0)
+    return x
 
 
 def _last_encoder_state(x):
     if hasattr(x, "encoder_last_hidden_state"):
         return x.encoder_last_hidden_state
+    elif hasattr(x, "encoder_hidden_states"):
+        return x.encoder_hidden_states[-1]
     else:
         return x.hidden_states[-1]
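The new "scores" branch in _logits exists because model.generate(..., output_scores=True, return_dict_in_generate=True) returns a tuple of per-step score tensors rather than a single logits tensor. A hedged sanity check of the two layouts being unified (the checkpoint matches the one app.py loads; any seq2seq model would do):

    import torch
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tok = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
    model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").eval()
    ids = tok("who wrote hamlet?", return_tensors="pt").input_ids

    fwd = model(input_ids=ids, decoder_input_ids=ids[:, :1])
    print(fwd.logits.shape)  # forward(): a single (batch, seq, vocab) tensor

    gen = model.generate(input_ids=ids, output_scores=True, return_dict_in_generate=True)
    print(len(gen.scores), gen.scores[0].shape)  # generate(): per-step (batch, vocab) tensors
    # torch.cat(gen.scores).unsqueeze(0) -> (1, steps, vocab), which matches the
    # forward() layout only when the effective batch size is 1, as in this demo.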