Spaces:
Runtime error
Runtime error
Commit
·
df3eae0
1
Parent(s):
dfd8182
Create new file
Browse files- qasrl_model_pipeline.py +182 -0
qasrl_model_pipeline.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import json
|
| 3 |
+
from argparse import Namespace
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
| 6 |
+
|
| 7 |
+
def get_markers_for_model(is_t5_model: bool) -> Namespace:
|
| 8 |
+
special_tokens_constants = Namespace()
|
| 9 |
+
if is_t5_model:
|
| 10 |
+
# T5 model have 100 special tokens by default
|
| 11 |
+
special_tokens_constants.separator_input_question_predicate = "<extra_id_1>"
|
| 12 |
+
special_tokens_constants.separator_output_answers = "<extra_id_3>"
|
| 13 |
+
special_tokens_constants.separator_output_questions = "<extra_id_5>" # if using only questions
|
| 14 |
+
special_tokens_constants.separator_output_question_answer = "<extra_id_7>"
|
| 15 |
+
special_tokens_constants.separator_output_pairs = "<extra_id_9>"
|
| 16 |
+
special_tokens_constants.predicate_generic_marker = "<extra_id_10>"
|
| 17 |
+
special_tokens_constants.predicate_verb_marker = "<extra_id_11>"
|
| 18 |
+
special_tokens_constants.predicate_nominalization_marker = "<extra_id_12>"
|
| 19 |
+
|
| 20 |
+
else:
|
| 21 |
+
special_tokens_constants.separator_input_question_predicate = "<question_predicate_sep>"
|
| 22 |
+
special_tokens_constants.separator_output_answers = "<answers_sep>"
|
| 23 |
+
special_tokens_constants.separator_output_questions = "<question_sep>" # if using only questions
|
| 24 |
+
special_tokens_constants.separator_output_question_answer = "<question_answer_sep>"
|
| 25 |
+
special_tokens_constants.separator_output_pairs = "<qa_pairs_sep>"
|
| 26 |
+
special_tokens_constants.predicate_generic_marker = "<predicate_marker>"
|
| 27 |
+
special_tokens_constants.predicate_verb_marker = "<verbal_predicate_marker>"
|
| 28 |
+
special_tokens_constants.predicate_nominalization_marker = "<nominalization_predicate_marker>"
|
| 29 |
+
return special_tokens_constants
|
| 30 |
+
|
| 31 |
+
def load_trained_model(name_or_path):
|
| 32 |
+
import huggingface_hub as HFhub
|
| 33 |
+
tokenizer = AutoTokenizer.from_pretrained(name_or_path)
|
| 34 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
|
| 35 |
+
# load preprocessing_kwargs from the model repo on HF hub, or from the local model directory
|
| 36 |
+
kwargs_filename = None
|
| 37 |
+
if name_or_path.startswith("kleinay/"): # and 'preprocessing_kwargs.json' in HFhub.list_repo_files(name_or_path): # the supported version of HFhub doesn't support list_repo_files
|
| 38 |
+
kwargs_filename = HFhub.hf_hub_download(repo_id=name_or_path, filename="preprocessing_kwargs.json")
|
| 39 |
+
elif Path(name_or_path).is_dir() and (Path(name_or_path) / "experiment_kwargs.json").exists():
|
| 40 |
+
kwargs_filename = Path(name_or_path) / "experiment_kwargs.json"
|
| 41 |
+
|
| 42 |
+
if kwargs_filename:
|
| 43 |
+
preprocessing_kwargs = json.load(open(kwargs_filename))
|
| 44 |
+
# integrate into model.config (for decoding args, e.g. "num_beams"), and save also as standalone object for preprocessing
|
| 45 |
+
model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
|
| 46 |
+
model.config.update(preprocessing_kwargs)
|
| 47 |
+
return model, tokenizer
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class QASRL_Pipeline(Text2TextGenerationPipeline):
|
| 51 |
+
def __init__(self, model_repo: str, **kwargs):
|
| 52 |
+
model, tokenizer = load_trained_model(model_repo)
|
| 53 |
+
super().__init__(model, tokenizer, framework="pt")
|
| 54 |
+
self.is_t5_model = "t5" in model.config.model_type
|
| 55 |
+
self.special_tokens = get_markers_for_model(self.is_t5_model)
|
| 56 |
+
self.data_args = model.config.preprocessing_kwargs
|
| 57 |
+
# backward compatibility - default keyword values implemeted in `run_summarization`, thus not saved in `preprocessing_kwargs`
|
| 58 |
+
if "predicate_marker_type" not in vars(self.data_args):
|
| 59 |
+
self.data_args.predicate_marker_type = "generic"
|
| 60 |
+
if "use_bilateral_predicate_marker" not in vars(self.data_args):
|
| 61 |
+
self.data_args.use_bilateral_predicate_marker = True
|
| 62 |
+
if "append_verb_form" not in vars(self.data_args):
|
| 63 |
+
self.data_args.append_verb_form = True
|
| 64 |
+
self._update_config(**kwargs)
|
| 65 |
+
|
| 66 |
+
def _update_config(self, **kwargs):
|
| 67 |
+
" Update self.model.config with initialization parameters and necessary defaults. "
|
| 68 |
+
# set default values that will always override model.config, but can overriden by __init__ kwargs
|
| 69 |
+
kwargs["max_length"] = kwargs.get("max_length", 80)
|
| 70 |
+
# override model.config with kwargs
|
| 71 |
+
for k,v in kwargs.items():
|
| 72 |
+
self.model.config.__dict__[k] = v
|
| 73 |
+
|
| 74 |
+
def _sanitize_parameters(self, **kwargs):
|
| 75 |
+
preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
|
| 76 |
+
if "predicate_marker" in kwargs:
|
| 77 |
+
preprocess_kwargs["predicate_marker"] = kwargs["predicate_marker"]
|
| 78 |
+
if "predicate_type" in kwargs:
|
| 79 |
+
preprocess_kwargs["predicate_type"] = kwargs["predicate_type"]
|
| 80 |
+
if "verb_form" in kwargs:
|
| 81 |
+
preprocess_kwargs["verb_form"] = kwargs["verb_form"]
|
| 82 |
+
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
|
| 83 |
+
|
| 84 |
+
def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
|
| 85 |
+
# Here, inputs is string or list of strings; apply string postprocessing
|
| 86 |
+
if isinstance(inputs, str):
|
| 87 |
+
processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
|
| 88 |
+
elif hasattr(inputs, "__iter__"):
|
| 89 |
+
processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form) for s in inputs]
|
| 90 |
+
else:
|
| 91 |
+
raise ValueError("inputs must be str or Iterable[str]")
|
| 92 |
+
# Now pass to super.preprocess for tokenization
|
| 93 |
+
return super().preprocess(processed_inputs)
|
| 94 |
+
|
| 95 |
+
def _preprocess_string(self, seq: str, predicate_marker: str, predicate_type: Optional[str], verb_form: Optional[str]) -> str:
|
| 96 |
+
sent_tokens = seq.split(" ")
|
| 97 |
+
assert predicate_marker in sent_tokens, f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word"
|
| 98 |
+
predicate_idx = sent_tokens.index(predicate_marker)
|
| 99 |
+
sent_tokens.remove(predicate_marker)
|
| 100 |
+
sentence_before_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx)])
|
| 101 |
+
predicate = sent_tokens[predicate_idx]
|
| 102 |
+
sentence_after_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx+1, len(sent_tokens))])
|
| 103 |
+
|
| 104 |
+
if self.data_args.predicate_marker_type == "generic":
|
| 105 |
+
predicate_marker = self.special_tokens.predicate_generic_marker
|
| 106 |
+
# In case we want special marker for each predicate type: """
|
| 107 |
+
elif self.data_args.predicate_marker_type == "pred_type":
|
| 108 |
+
assert predicate_type is not None, "For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it"
|
| 109 |
+
assert predicate_type in ("verbal", "nominal"), f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'"
|
| 110 |
+
predicate_marker = {"verbal": self.special_tokens.predicate_verb_marker ,
|
| 111 |
+
"nominal": self.special_tokens.predicate_nominalization_marker
|
| 112 |
+
}[predicate_type]
|
| 113 |
+
|
| 114 |
+
if self.data_args.use_bilateral_predicate_marker:
|
| 115 |
+
seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {predicate_marker} {sentence_after_predicate}"
|
| 116 |
+
else:
|
| 117 |
+
seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {sentence_after_predicate}"
|
| 118 |
+
|
| 119 |
+
# embed also verb_form
|
| 120 |
+
if self.data_args.append_verb_form and verb_form is None:
|
| 121 |
+
raise ValueError(f"For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
|
| 122 |
+
elif self.data_args.append_verb_form:
|
| 123 |
+
seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
|
| 124 |
+
else:
|
| 125 |
+
seq = f"{seq} "
|
| 126 |
+
|
| 127 |
+
# append source prefix (for t5 models)
|
| 128 |
+
prefix = self._get_source_prefix(predicate_type)
|
| 129 |
+
|
| 130 |
+
return prefix + seq
|
| 131 |
+
|
| 132 |
+
def _get_source_prefix(self, predicate_type: Optional[str]):
|
| 133 |
+
if not self.is_t5_model or self.data_args.source_prefix is None:
|
| 134 |
+
return ''
|
| 135 |
+
if not self.data_args.source_prefix.startswith("<"): # Regular prefix - not dependent on input row x
|
| 136 |
+
return self.data_args.source_prefix
|
| 137 |
+
if self.data_args.source_prefix == "<predicate-type>":
|
| 138 |
+
if predicate_type is None:
|
| 139 |
+
raise ValueError("source_prefix is '<predicate-type>' but input no `predicate_type`.")
|
| 140 |
+
else:
|
| 141 |
+
return f"Generate QAs for {predicate_type} QASRL: "
|
| 142 |
+
|
| 143 |
+
def _forward(self, *args, **kwargs):
|
| 144 |
+
outputs = super()._forward(*args, **kwargs)
|
| 145 |
+
return outputs
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def postprocess(self, model_outputs):
|
| 149 |
+
output_seq = self.tokenizer.decode(
|
| 150 |
+
model_outputs["output_ids"].squeeze(),
|
| 151 |
+
skip_special_tokens=False,
|
| 152 |
+
clean_up_tokenization_spaces=False,
|
| 153 |
+
)
|
| 154 |
+
output_seq = output_seq.strip(self.tokenizer.pad_token).strip(self.tokenizer.eos_token).strip()
|
| 155 |
+
qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
|
| 156 |
+
qas = [self._postrocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
|
| 157 |
+
return {"generated_text": output_seq,
|
| 158 |
+
"QAs": qas}
|
| 159 |
+
|
| 160 |
+
def _postrocess_qa(self, seq: str) -> str:
|
| 161 |
+
# split question and answers
|
| 162 |
+
if self.special_tokens.separator_output_question_answer in seq:
|
| 163 |
+
question, answer = seq.split(self.special_tokens.separator_output_question_answer)[:2]
|
| 164 |
+
else:
|
| 165 |
+
print("invalid format: no separator between question and answer found...")
|
| 166 |
+
return None
|
| 167 |
+
# question, answer = seq, '' # Or: backoff to only question
|
| 168 |
+
# skip "_" slots in questions
|
| 169 |
+
question = ' '.join(t for t in question.split(' ') if t != '_')
|
| 170 |
+
answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
|
| 171 |
+
return {"question": question, "answers": answers}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
|
| 176 |
+
res1 = pipe("The student was interested in Luke 's <predicate> research about sea animals .", verb_form="research", predicate_type="nominal")
|
| 177 |
+
res2 = pipe(["The doctor was interested in Luke 's <predicate> treatment .",
|
| 178 |
+
"The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."], verb_form="treat", predicate_type="nominal", num_beams=10)
|
| 179 |
+
res3 = pipe("A number of professions have <predicate> developed that specialize in the treatment of mental disorders .", verb_form="develop", predicate_type="verbal")
|
| 180 |
+
print(res1)
|
| 181 |
+
print(res2)
|
| 182 |
+
print(res3)
|