Spaces:
Runtime error
Runtime error
| import csv | |
| import json | |
| import torch | |
| from transformers import BertTokenizer | |
class CNerTokenizer(BertTokenizer):
    """Character-level BERT tokenizer.

    ``tokenize`` emits exactly one token per input character, so token
    positions line up one-to-one with character-level labels. Characters that
    are not in the vocabulary become ``'[UNK]'``.
    """

    def __init__(self, vocab_file, do_lower_case=True):
        super().__init__(vocab_file=str(vocab_file), do_lower_case=do_lower_case)
        self.vocab_file = str(vocab_file)
        # NOTE(review): some transformers releases define ``do_lower_case`` on
        # BertTokenizer as a read-only property (backed by the basic
        # tokenizer), so plain assignment raises AttributeError there. In that
        # case the property presumably already reflects the value passed to
        # super().__init__ above -- verify against the installed version.
        try:
            self.do_lower_case = do_lower_case
        except AttributeError:
            pass

    def tokenize(self, text, **kwargs):
        """Split ``text`` into single-character tokens.

        Args:
            text: the string to tokenize; it is iterated character by character.
            **kwargs: extra keyword arguments that the base class may forward
                (e.g. ``is_split_into_words``); accepted and ignored.

        Returns:
            list[str]: one entry per character of ``text``. Each entry is the
            character (lower-cased when ``do_lower_case`` is set) if it is in
            ``self.vocab``, otherwise ``'[UNK]'``.
        """
        lower = self.do_lower_case
        vocab = self.vocab
        tokens = []
        for ch in text:
            if lower:
                ch = ch.lower()
            tokens.append(ch if ch in vocab else '[UNK]')
        return tokens
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file.

        Args:
            input_file: path of the file to read (decoded as UTF-8, BOM tolerated).
            quotechar: optional quote character passed to ``csv.reader``.

        Returns:
            list[list[str]]: one list of fields per row.
        """
        with open(input_file, "r", encoding="utf-8-sig") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            # Materialise while the file is still open.
            return list(reader)

    def _read_text(self, input_file):
        """Reads a space-separated ``token label`` file, one token per line.

        Sentences are separated by blank lines or lines starting with
        ``-DOCSTART-``. The token is the first space-separated field and the
        label is the last one. A line with a single field gets label ``"O"``
        (e.g. an unlabeled test set).

        Returns:
            list[dict]: ``{"words": [...], "labels": [...]}`` per sentence.
        """
        lines = []
        words = []
        labels = []
        with open(input_file, "r", encoding="utf-8") as f:
            for raw in f:
                # Strip the terminator up front so it can never end up inside
                # a token when the line carries no label.
                line = raw.rstrip("\n")
                if line.startswith("-DOCSTART-") or line == "":
                    if words:
                        lines.append({"words": words, "labels": labels})
                        words = []
                        labels = []
                    continue
                splits = line.split(" ")
                words.append(splits[0])
                if len(splits) > 1:
                    labels.append(splits[-1])
                else:
                    # Examples could have no label for mode = "test"
                    labels.append("O")
        if words:
            lines.append({"words": words, "labels": labels})
        return lines

    def _read_json(self, input_file):
        """Reads a JSON-lines file and converts entity spans to BIOES-style tags.

        Each non-blank line is a JSON object with a ``"text"`` string and an
        optional ``"label"`` mapping ``entity_type -> {entity_text: [[start,
        end], ...]}``. ``end`` is inclusive. Each character of ``text`` is one
        word. Single-character spans are tagged ``S-``; longer spans are tagged
        ``B-``, ``I-``..., ``E-``. Everything else is ``O``.

        Returns:
            list[dict]: ``{"words": [...], "labels": [...]}`` per line.
        """
        lines = []
        with open(input_file, 'r', encoding='utf8') as f:
            for raw in f:
                raw = raw.strip()
                if not raw:
                    # Tolerate blank lines such as a trailing empty line.
                    continue
                record = json.loads(raw)
                words = list(record['text'])
                labels = ['O'] * len(words)
                label_entities = record.get('label', None)
                if label_entities is not None:
                    for key, value in label_entities.items():
                        for sub_name, sub_index in value.items():
                            for start_index, end_index in sub_index:
                                assert ''.join(words[start_index:end_index + 1]) == sub_name, (
                                    "entity span does not match text: %r" % (sub_name,))
                                if start_index == end_index:
                                    labels[start_index] = 'S-' + key
                                else:
                                    labels[start_index] = 'B-' + key
                                    # Size the I- run from the indices, not from
                                    # len(sub_name), so the slice assignment can
                                    # never change len(labels) (the assert above
                                    # is stripped under -O).
                                    labels[start_index + 1:end_index] = (
                                        ['I-' + key] * (end_index - start_index - 1))
                                    labels[end_index] = 'E-' + key
                lines.append({"words": words, "labels": labels})
        return lines
def get_entity_bios(seq, id2label, middle_prefix='I-'):
    """Gets entities from a BIOS-tagged sequence.

    Args:
        seq (list): sequence of labels; entries that are not strings are
            translated through ``id2label``.
        id2label (dict): maps integer label ids to tag strings.
        middle_prefix (str): prefix of the tag that continues an entity.

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` (end inclusive).

    Notes:
        * An ``S-`` tag yields a single-token entity.
        * A ``B-`` tag only becomes an entity once a middle tag of the same
          type follows it; a lone ``B-`` is dropped.
        * A middle tag whose type differs from the open entity is ignored
          (the entity stays open).
        * An entity still open at the end of ``seq`` is emitted.
        * The type is everything after the first ``-``, so types containing
          ``-`` are kept intact.

    Example:
        >>> get_entity_bios(['B-PER', 'I-PER', 'O', 'S-LOC'], {})
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]; end == -1 means not yet an entity
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunks.append([tag.split('-', 1)[1], indx, indx])
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [tag.split('-', 1)[1], indx, -1]
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            if tag.split('-', 1)[1] == chunk[0]:
                chunk[2] = indx
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    # Flush an entity that runs to the end of the sequence.
    if chunk[2] != -1:
        chunks.append(chunk)
    return chunks
def get_entity_bio(seq, id2label, middle_prefix='I-'):
    """Gets entities from a BIO-tagged sequence.

    Args:
        seq (list): sequence of labels; entries that are not strings are
            translated through ``id2label``.
        id2label (dict): maps integer label ids to tag strings.
        middle_prefix (str): prefix of the tag that continues an entity.

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` (end inclusive).

    Notes:
        * Every ``B-`` tag starts an entity (a lone ``B-`` is a 1-token entity).
        * A middle tag whose type differs from the open entity is ignored
          (the entity stays open); a middle tag with no open entity closes
          nothing and is skipped.
        * An entity still open at the end of ``seq`` is emitted.
        * The type is everything after the first ``-``, so types containing
          ``-`` are kept intact.

    Example:
        >>> get_entity_bio(['B-PER', 'I-PER', 'O', 'B-LOC'], {})
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [tag.split('-', 1)[1], indx, indx]
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            if tag.split('-', 1)[1] == chunk[0]:
                chunk[2] = indx
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    # Flush an entity that runs to the end of the sequence.
    if chunk[2] != -1:
        chunks.append(chunk)
    return chunks
def get_entity_bioes(seq, id2label, middle_prefix='I-'):
    """Gets entities from a BIOES-tagged sequence.

    Args:
        seq (list): sequence of labels; entries that are not strings are
            translated through ``id2label``.
        id2label (dict): maps integer label ids to tag strings.
        middle_prefix (str): prefix of the tag that continues an entity.

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` (end inclusive).

    Notes:
        * An ``S-`` tag yields a single-token entity.
        * ``E-`` is handled like a middle tag: it extends the open entity but
          does not close it; the entity is emitted at the next ``O``/``B-``/
          ``S-`` tag or at the end of ``seq``.
        * A lone ``B-`` (no following middle/``E-`` tag of the same type) is
          dropped; a middle/``E-`` tag of a different type is ignored.
        * The type is everything after the first ``-``, so types containing
          ``-`` are kept intact.

    Example:
        >>> get_entity_bioes(['B-PER', 'I-PER', 'E-PER', 'O', 'S-LOC'], {})
        [['PER', 0, 2], ['LOC', 4, 4]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]; end == -1 means not yet an entity
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunks.append([tag.split('-', 1)[1], indx, indx])
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [tag.split('-', 1)[1], indx, -1]
        elif (tag.startswith(middle_prefix) or tag.startswith("E-")) and chunk[1] != -1:
            if tag.split('-', 1)[1] == chunk[0]:
                chunk[2] = indx
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    # Flush an entity that runs to the end of the sequence.
    if chunk[2] != -1:
        chunks.append(chunk)
    return chunks
def get_entities(seq, id2label, markup='bio', middle_prefix='I-'):
    """Decode entity spans from a tag sequence using the requested scheme.

    Args:
        seq: sequence of tag strings or label ids.
        id2label: mapping from label id to tag string.
        markup: tagging scheme, one of ``'bio'``, ``'bios'`` or ``'bioes'``.
        middle_prefix: prefix of the tag that continues an entity.

    Returns:
        The list of ``[chunk_type, chunk_start, chunk_end]`` produced by the
        decoder matching ``markup``.
    """
    assert markup in ('bio', 'bios', 'bioes')
    if markup == 'bios':
        decoder = get_entity_bios
    elif markup == 'bio':
        decoder = get_entity_bio
    else:
        decoder = get_entity_bioes
    return decoder(seq, id2label, middle_prefix)
def bert_extract_item(start_logits, end_logits):
    """Pair predicted start/end labels into ``(label, start, end)`` spans.

    Takes the argmax over the last axis of each logits tensor and reads only
    the first element of the leading axis. It then drops the first and last
    positions (presumably [CLS]/[SEP] -- confirm against the caller). Logits
    are presumably shaped (batch, seq_len, num_labels).

    For every position whose start label is non-zero (label 0 is treated as
    "no entity"), the span ends at the first position at or after it whose
    end label is the same. Starts with no matching end are dropped.

    Returns:
        list of ``(label, start, end)`` tuples; positions are relative to the
        trimmed sequence and ``end`` is inclusive.
    """
    starts = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
    ends = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
    spans = []
    for begin, label in enumerate(starts):
        if label == 0:
            continue
        for stop in range(begin, len(ends)):
            if ends[stop] == label:
                spans.append((label, begin, stop))
                break
    return spans