Spaces:
Runtime error
Runtime error
| # Copyright 2017, 2018 Google, Inc. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Common stuff used with LexNET.""" | |
| # pylint: disable=bad-whitespace | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import os | |
| import numpy as np | |
| from sklearn import metrics | |
| import tensorflow as tf | |
| # Part of speech tags used in the paths. | |
| POSTAGS = [ | |
| 'PAD', 'VERB', 'CONJ', 'NOUN', 'PUNCT', | |
| 'ADP', 'ADJ', 'DET', 'ADV', 'PART', | |
| 'NUM', 'X', 'INTJ', 'SYM', | |
| ] | |
| POSTAG_TO_ID = {tag: tid for tid, tag in enumerate(POSTAGS)} | |
| # Dependency labels used in the paths. | |
| DEPLABELS = [ | |
| 'PAD', 'UNK', 'ROOT', 'abbrev', 'acomp', 'advcl', | |
| 'advmod', 'agent', 'amod', 'appos', 'attr', 'aux', | |
| 'auxpass', 'cc', 'ccomp', 'complm', 'conj', 'cop', | |
| 'csubj', 'csubjpass', 'dep', 'det', 'dobj', 'expl', | |
| 'infmod', 'iobj', 'mark', 'mwe', 'nc', 'neg', | |
| 'nn', 'npadvmod', 'nsubj', 'nsubjpass', 'num', 'number', | |
| 'p', 'parataxis', 'partmod', 'pcomp', 'pobj', 'poss', | |
| 'preconj', 'predet', 'prep', 'prepc', 'prt', 'ps', | |
| 'purpcl', 'quantmod', 'rcmod', 'ref', 'rel', 'suffix', | |
| 'title', 'tmod', 'xcomp', 'xsubj', | |
| ] | |
| DEPLABEL_TO_ID = {label: lid for lid, label in enumerate(DEPLABELS)} | |
| # Direction codes used in the paths. | |
| DIRS = '_^V<>' | |
| DIR_TO_ID = {dir: did for did, dir in enumerate(DIRS)} | |
| def load_word_embeddings(embedding_filename): | |
| """Loads pretrained word embeddings from a binary file and returns the matrix. | |
| Adds the <PAD>, <UNK>, <X>, and <Y> tokens to the beginning of the vocab. | |
| Args: | |
| embedding_filename: filename of the binary NPY data | |
| Returns: | |
| The word embeddings matrix | |
| """ | |
| embeddings = np.load(embedding_filename) | |
| dim = embeddings.shape[1] | |
| # Four initially random vectors for the special tokens: <PAD>, <UNK>, <X>, <Y> | |
| special_embeddings = np.random.normal(0, 0.1, (4, dim)) | |
| embeddings = np.vstack((special_embeddings, embeddings)) | |
| embeddings = embeddings.astype(np.float32) | |
| return embeddings | |
| def full_evaluation(model, session, instances, labels, set_name, classes): | |
| """Prints a full evaluation on the current set. | |
| Performance (recall, precision and F1), classification report (per | |
| class performance), and confusion matrix). | |
| Args: | |
| model: The currently trained path-based model. | |
| session: The current TensorFlow session. | |
| instances: The current set instances. | |
| labels: The current set labels. | |
| set_name: The current set name (train/validation/test). | |
| classes: The class label names. | |
| Returns: | |
| The model's prediction for the given instances. | |
| """ | |
| # Predict the labels | |
| pred = model.predict(session, instances) | |
| # Print the performance | |
| precision, recall, f1, _ = metrics.precision_recall_fscore_support( | |
| labels, pred, average='weighted') | |
| print('%s set: Precision: %.3f, Recall: %.3f, F1: %.3f' % ( | |
| set_name, precision, recall, f1)) | |
| # Print a classification report | |
| print('%s classification report:' % set_name) | |
| print(metrics.classification_report(labels, pred, target_names=classes)) | |
| # Print the confusion matrix | |
| print('%s confusion matrix:' % set_name) | |
| cm = metrics.confusion_matrix(labels, pred, labels=range(len(classes))) | |
| cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100 | |
| print_cm(cm, labels=classes) | |
| return pred | |
| def print_cm(cm, labels): | |
| """Pretty print for confusion matrices. | |
| From: https://gist.github.com/zachguo/10296432. | |
| Args: | |
| cm: The confusion matrix. | |
| labels: The class names. | |
| """ | |
| columnwidth = 10 | |
| empty_cell = ' ' * columnwidth | |
| short_labels = [label[:12].rjust(10, ' ') for label in labels] | |
| # Print header | |
| header = empty_cell + ' ' | |
| header += ''.join([' %{0}s '.format(columnwidth) % label | |
| for label in short_labels]) | |
| print(header) | |
| # Print rows | |
| for i, label1 in enumerate(short_labels): | |
| row = '%{0}s '.format(columnwidth) % label1[:10] | |
| for j in range(len(short_labels)): | |
| value = int(cm[i, j]) if not np.isnan(cm[i, j]) else 0 | |
| cell = ' %{0}d '.format(10) % value | |
| row += cell + ' ' | |
| print(row) | |
| def load_all_labels(records): | |
| """Reads TensorFlow examples from a RecordReader and returns only the labels. | |
| Args: | |
| records: a record list with TensorFlow examples. | |
| Returns: | |
| The labels | |
| """ | |
| curr_features = tf.parse_example(records, { | |
| 'rel_id': tf.FixedLenFeature([1], dtype=tf.int64), | |
| }) | |
| labels = tf.squeeze(curr_features['rel_id'], [-1]) | |
| return labels | |
| def load_all_pairs(records): | |
| """Reads TensorFlow examples from a RecordReader and returns the word pairs. | |
| Args: | |
| records: a record list with TensorFlow examples. | |
| Returns: | |
| The word pairs | |
| """ | |
| curr_features = tf.parse_example(records, { | |
| 'pair': tf.FixedLenFeature([1], dtype=tf.string) | |
| }) | |
| word_pairs = curr_features['pair'] | |
| return word_pairs | |
| def write_predictions(pairs, labels, predictions, classes, predictions_file): | |
| """Write the predictions to a file. | |
| Args: | |
| pairs: the word pairs (list of tuple of two strings). | |
| labels: the gold-standard labels for these pairs (array of rel ID). | |
| predictions: the predicted labels for these pairs (array of rel ID). | |
| classes: a list of relation names. | |
| predictions_file: where to save the predictions. | |
| """ | |
| with open(predictions_file, 'w') as f_out: | |
| for pair, label, pred in zip(pairs, labels, predictions): | |
| w1, w2 = pair | |
| f_out.write('\t'.join([w1, w2, classes[label], classes[pred]]) + '\n') | |