Spaces:
Running
Running
| import sys | |
| import os | |
| sys.path.append("./") | |
| sys.path.append("../") | |
| sys.path.append("../../") | |
| sys.path.append("../../../") | |
| from models import SPAN_EXTRACTION_MODEL_CLASSES | |
| from models import TOKENIZER_CLASSES | |
| import numpy as np | |
| import torch | |
class HugIEAPI:
    """Inference wrapper around a HugIE global-pointer span-extraction model.

    An entity type (and optionally a relation) is turned into a Chinese
    instruction prompt, the model is run on that prompt, and the flattened
    top-k (start, end) indices it returns are decoded back into substrings
    of the prompt.
    """

    def __init__(self, model_type, hugie_model_name_or_path,
                 max_seq_length=512) -> None:
        """Load the model and tokenizer registered for ``model_type``.

        Args:
            model_type: Key into
                ``SPAN_EXTRACTION_MODEL_CLASSES["global_pointer"]`` and
                ``TOKENIZER_CLASSES``.
            hugie_model_name_or_path: Identifier passed to ``from_pretrained``
                for both the model and the tokenizer.
            max_seq_length: Length every prompt is padded/truncated to. It is
                also the side of the square (start, end) grid used to decode
                the model's flattened indices, so it presumably has to match
                what the checkpoint expects -- confirm before changing it.

        Raises:
            KeyError: If ``model_type`` is not a registered global-pointer
                model.
        """
        global_pointer_models = SPAN_EXTRACTION_MODEL_CLASSES["global_pointer"]
        if model_type not in global_pointer_models:
            raise KeyError(
                "You must choose one of the following model: {}".format(
                    ", ".join(global_pointer_models.keys())))
        self.model_type = model_type
        self.model = global_pointer_models[self.model_type].from_pretrained(
            hugie_model_name_or_path)
        self.tokenizer = TOKENIZER_CLASSES[self.model_type].from_pretrained(
            hugie_model_name_or_path)
        self.max_seq_length = max_seq_length

    def fush_multi_answer(self, has_answer, new_answer):
        """Merge ``new_answer`` into ``has_answer`` in place and return it.

        Intended for the case where one test id yields several examples
        (e.g. the same sample rendered through several templates) whose
        top-k results should be combined.

        Both arguments map an answer string to
        ``{"prob": float, "pos": [(start, end), ...]}``. For an answer that is
        already present the probabilities are summed and the position lists
        concatenated; otherwise the new entry is stored as-is (not copied).

        Args:
            has_answer: Results merged so far; mutated by this call.
            new_answer: Results of the current example.

        Returns:
            The updated ``has_answer`` object.
        """
        for ans, value in new_answer.items():
            if ans in has_answer:
                has_answer[ans]["prob"] += value["prob"]
                has_answer[ans]["pos"].extend(value["pos"])
            else:
                has_answer[ans] = value
        return has_answer

    def get_predict_result(self, probs, indices, examples, threshold=0.1):
        """Decode top-k model outputs into answer strings.

        Args:
            probs: Top-k probabilities; dimension 1 is squeezed away, after
                which each row is treated as the candidates of one example
                (assumed shape ``[n, 1, m]`` -- TODO confirm against the
                model's output).
            indices: Top-k indices with the same layout as ``probs``. Each
                value is a flattened position in a
                ``(max_seq_length, max_seq_length)`` grid of
                (start token, end token).
            examples: Dict with ``"content"`` (one string per example) and
                ``"offset_mapping"`` (per example, per token, a
                ``(char_start, char_end)`` pair).
            threshold: Candidates whose probability is not strictly greater
                than this value are dropped.

        Returns:
            A ``(predictions, topk_predictions)`` pair keyed by the example's
            position in the batch: ``predictions[i]`` is the list of distinct
            answer strings, and ``topk_predictions[i]`` is a list of
            ``{"answer": str, "prob": float, "pos": [(start, end)]}`` dicts
            in the same order.
        """
        probs = probs.squeeze(1)  # probabilities of the top-k candidates
        indices = indices.squeeze(1)  # flattened indices of the candidates
        grid_shape = (self.max_seq_length, self.max_seq_length)

        predictions = {}
        topk_predictions = {}
        # TODO: 1. tune the threshold 2. handle overlapping output entities
        for idx, (prob, index) in enumerate(zip(probs, indices)):
            content = examples["content"][idx]
            offsets = examples["offset_mapping"][idx]
            keep = prob > threshold
            spans = {}  # answer text -> {"prob", "pos"}, in first-seen order
            for score, flat_index in zip(prob[keep], index[keep]):
                # Flattened 1D index -> (start token, end token).
                start_tok, end_tok = np.unravel_index(int(flat_index),
                                                      grid_shape)
                # Token positions -> character offsets into the prompt.
                s = int(offsets[int(start_tok)][0])
                e = int(offsets[int(end_tok)][1])
                ans = content[s:e]
                # Only the first candidate producing a given text is kept.
                if ans not in spans:
                    spans[ans] = {"prob": float(score), "pos": [(s, e)]}
            predictions[idx] = list(spans)
            # idx is unique per row, so there is never an earlier entry to
            # merge with; the list form can be built directly.
            topk_predictions[idx] = [{
                "answer": ans,
                "prob": value["prob"],
                "pos": value["pos"]
            } for ans, value in spans.items()]
        return predictions, topk_predictions

    def request(self, text: str, entity_type: str, relation: str = None,
                threshold=0.1):
        """Extract entities, or a relation's value, from ``text``.

        Args:
            text: Passage to extract from.
            entity_type: Entity type to look for when ``relation`` is None;
                otherwise the subject entity whose ``relation`` is asked for.
            relation: Optional relation name; selects the second prompt
                template.
            threshold: Probability threshold forwarded to
                ``get_predict_result``.

        Returns:
            The ``(predictions, topk_predictions)`` pair produced by
            ``get_predict_result``. Reported character positions index into
            the full instruction prompt, not into ``text`` alone.
        """
        assert text is not None and entity_type is not None
        if relation is None:
            instruction = "找到文章中所有【{}】类型的实体?文章:【{}】".format(entity_type, text)
        else:
            instruction = "找到文章中【{}】的【{}】?文章:【{}】".format(
                entity_type, relation, text)
        # truncation=True keeps the encoded length at max_seq_length, which
        # the (max_seq_length x max_seq_length) decoding grid in
        # get_predict_result relies on. Text past the limit is cut off, so
        # spans located there cannot be returned.
        inputs = self.tokenizer(instruction,
                                max_length=self.max_seq_length,
                                padding="max_length",
                                truncation=True,
                                return_tensors="pt",
                                return_offsets_mapping=True)
        examples = {
            "content": [instruction],
            "offset_mapping": inputs["offset_mapping"]
        }
        batch_input = {
            "input_ids": inputs["input_ids"],
            "token_type_ids": inputs["token_type_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(**batch_input)
        probs, indices = outputs["topk_probs"], outputs["topk_indices"]
        predictions, topk_predictions = self.get_predict_result(
            probs, indices, examples=examples, threshold=threshold)
        return predictions, topk_predictions
if __name__ == "__main__":
    # Demo: run one named-entity query and several relation queries against
    # a news sentence and print the raw predictions.
    # NOTE(review): this re-imports HugIEAPI through a package path and
    # shadows the class defined above; presumably this file is that module --
    # confirm before removing the import.
    from applications.information_extraction.HugIE.api_test import HugIEAPI

    model_type = "bert"
    hugie_model_name_or_path = "wjn1996/wjn1996-hugnlp-hugie-large-zh"
    hugie = HugIEAPI(model_type, hugie_model_name_or_path)
    text = "央广网北京2月23日消息 据中国地震台网正式测定,2月23日8时37分在塔吉克斯坦发生7.2级地震,震源深度10公里,震中位于北纬37.98度,东经73.29度,距我国边境线最近约82公里,地震造成新疆喀什等地震感强烈。"

    ## named entity recognition
    entity_type = "国家"
    predictions, topk_predictions = hugie.request(text, entity_type)
    print("entity_type:{}".format(entity_type))
    print("predictions:\n{}".format(predictions))
    print("topk_predictions:\n{}".format(topk_predictions))
    print("\n\n")

    ## event extraction: the same subject entity queried for each relation
    entity = "塔吉克斯坦地震"
    for relation in ("震源深度", "震源位置", "时间", "影响"):
        predictions, topk_predictions = hugie.request(text,
                                                      entity,
                                                      relation=relation)
        print("entity:{}, relation:{}".format(entity, relation))
        print("predictions:\n{}".format(predictions))
        print("topk_predictions:\n{}".format(topk_predictions))
        print("\n\n")
| """ | |
| Output results: | |
| entity_type:国家 | |
| predictions: | |
| {0: ["塔吉克斯坦"]} | |
| topk_predictions: | |
| {0: [{"answer": "塔吉克斯坦", "prob": 0.9999997615814209, "pos": [(tensor(57), tensor(62))]}]} | |
| entity:塔吉克斯坦地震, relation:震源深度 | |
| predictions: | |
| {0: ["10公里"]} | |
| topk_predictions: | |
| {0: [{"answer": "10公里", "prob": 0.999994158744812, "pos": [(tensor(80), tensor(84))]}]} | |
| entity:塔吉克斯坦地震, relation:震源位置 | |
| predictions: | |
| {0: ["10公里", "距我国边境线最近约82公里", "北纬37.98度,东经73.29度", "北纬37.98度,东经73.29度,距我国边境线最近约82公里"]} | |
| topk_predictions: | |
| {0: [{"answer": "10公里", "prob": 0.9895901083946228, "pos": [(tensor(80), tensor(84))]}, {"answer": "距我国边境线最近约82公里", "prob": 0.8584909439086914, "pos": [(tensor(107), tensor(120))]}, {"answer": "北纬37.98度,东经73.29度", "prob": 0.7202121615409851, "pos": [(tensor(89), tensor(106))]}, {"answer": "北纬37.98度,东经73.29度,距我国边境线最近约82公里", "prob": 0.11628123372793198, "pos": [(tensor(89), tensor(120))]}]} | |
| entity:塔吉克斯坦地震, relation:时间 | |
| predictions: | |
| {0: ["2月23日8时37分"]} | |
| topk_predictions: | |
| {0: [{"answer": "2月23日8时37分", "prob": 0.9999995231628418, "pos": [(tensor(49), tensor(59))]}]} | |
| entity:塔吉克斯坦地震, relation:影响 | |
| predictions: | |
| {0: ["新疆喀什等地震感强烈"]} | |
| topk_predictions: | |
| {0: [{"answer": "新疆喀什等地震感强烈", "prob": 0.9525265693664551, "pos": [(tensor(123), tensor(133))]}]} | |
| """ | |