# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

import numpy as np
import tensorflow as tf

FLAGS = tf.flags.FLAGS
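

# Note on the vocabulary file format (inferred from Vocabulary.__init__ below):
# one token per line, listed in id order, with '<S>', '</S>' and '<UNK>'
# appearing among the entries; a '!!!MAXTERMID' sentinel line, if present,
# is skipped.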


class Vocabulary(object):
  """Class that holds a vocabulary for the dataset."""

  def __init__(self, filename):
    self._id_to_word = []
    self._word_to_id = {}
    self._unk = -1
    self._bos = -1
    self._eos = -1

    with tf.gfile.Open(filename) as f:
      idx = 0
      for line in f:
        word_name = line.strip()
        if word_name == '<S>':
          self._bos = idx
        elif word_name == '</S>':
          self._eos = idx
        elif word_name == '<UNK>':
          self._unk = idx
        if word_name == '!!!MAXTERMID':
          continue

        self._id_to_word.append(word_name)
        self._word_to_id[word_name] = idx
        idx += 1

  @property
  def bos(self):
    return self._bos

  @property
  def eos(self):
    return self._eos

  @property
  def unk(self):
    return self._unk

  @property
  def size(self):
    return len(self._id_to_word)

  def word_to_id(self, word):
    if word in self._word_to_id:
      return self._word_to_id[word]
    else:
      if word.lower() in self._word_to_id:
        return self._word_to_id[word.lower()]
    return self.unk

  def id_to_word(self, cur_id):
    if cur_id < self.size:
      return self._id_to_word[int(cur_id)]
    return '<ERROR_out_of_vocab_id>'

  def decode(self, cur_ids):
    """Convert a list of ids to a space-separated sentence."""
    return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])

  def encode(self, sentence):
    """Convert a sentence to a list of ids, with special tokens added."""
    word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
    return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
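

# Example usage (the vocabulary path below is illustrative):
#   vocab = Vocabulary('/path/to/vocab.txt')
#   ids = vocab.encode('He went to the store .')   # [bos_id, word_ids..., eos_id]
#   text = vocab.decode(ids)                       # roughly '<S> He went to the store . </S>'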


class CharsVocabulary(Vocabulary):
  """Vocabulary containing character-level information."""

  def __init__(self, filename, max_word_length):
    super(CharsVocabulary, self).__init__(filename)
    self._max_word_length = max_word_length
    chars_set = set()

    for word in self._id_to_word:
      chars_set |= set(word)

    free_ids = []
    for i in range(256):
      if chr(i) in chars_set:
        continue
      free_ids.append(chr(i))

    if len(free_ids) < 5:
      raise ValueError('Not enough free char ids: %d' % len(free_ids))

    self.bos_char = free_ids[0]  # <begin sentence>
    self.eos_char = free_ids[1]  # <end sentence>
    self.bow_char = free_ids[2]  # <begin word>
    self.eow_char = free_ids[3]  # <end word>
    self.pad_char = free_ids[4]  # <padding>

    chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
                  self.pad_char}
    self._char_set = chars_set
    num_words = len(self._id_to_word)

    self._word_char_ids = np.zeros([num_words, max_word_length], dtype=np.int32)

    self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
    self.eos_chars = self._convert_word_to_char_ids(self.eos_char)

    for i, word in enumerate(self._id_to_word):
      if i == self.bos:
        self._word_char_ids[i] = self.bos_chars
      elif i == self.eos:
        self._word_char_ids[i] = self.eos_chars
      else:
        self._word_char_ids[i] = self._convert_word_to_char_ids(word)

  @property
  def max_word_length(self):
    return self._max_word_length

  def _convert_word_to_char_ids(self, word):
    code = np.zeros([self.max_word_length], dtype=np.int32)
    code[:] = ord(self.pad_char)

    if len(word) > self.max_word_length - 2:
      word = word[:self.max_word_length - 2]
    cur_word = self.bow_char + word + self.eow_char
    for j in range(len(cur_word)):
      code[j] = ord(cur_word[j])
    return code

  def word_to_char_ids(self, word):
    if word in self._word_to_id:
      return self._word_char_ids[self._word_to_id[word]]
    else:
      return self._convert_word_to_char_ids(word)

  def encode_chars(self, sentence):
    chars_ids = [self.word_to_char_ids(cur_word)
                 for cur_word in sentence.split()]
    return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
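

# Character ids are the byte values of bow_char + word + eow_char, right-padded
# with pad_char up to max_word_length (words longer than max_word_length - 2
# characters are truncated first).
#
# Example usage (the vocabulary path and max word length are illustrative):
#   vocab = CharsVocabulary('/path/to/vocab.txt', max_word_length=50)
#   char_ids = vocab.encode_chars('He went home .')  # shape [num_words + 2, 50]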


_SPECIAL_CHAR_MAP = {
    '\xe2\x80\x98': '\'',
    '\xe2\x80\x99': '\'',
    '\xe2\x80\x9c': '"',
    '\xe2\x80\x9d': '"',
    '\xe2\x80\x93': '-',
    '\xe2\x80\x94': '-',
    '\xe2\x88\x92': '-',
    '\xce\x84': '\'',
    '\xc2\xb4': '\'',
    '`': '\''
}

_START_SPECIAL_CHARS = ['.', ',', '?', '!', ';', ':', '[', ']', '\'', '+', '/',
                        '\xc2\xa3', '$', '~', '*', '%', '{', '}', '#', '&', '-',
                        '"', '(', ')', '='] + list(_SPECIAL_CHAR_MAP.keys())

_SPECIAL_CHARS = _START_SPECIAL_CHARS + [
    '\'s', '\'m', '\'t', '\'re', '\'d', '\'ve', '\'ll']
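
# The escaped keys in _SPECIAL_CHAR_MAP are UTF-8 byte sequences for curly
# quotes, dashes and similar punctuation; within this file they are only used
# as additional split points in tokenize() below.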


def tokenize(sentence):
  """Tokenize a sentence."""
  sentence = str(sentence)
  words = sentence.strip().split()
  tokenized = []  # return this

  for word in words:
    if word.lower() in ['mr.', 'ms.']:
      tokenized.append(word)
      continue

    # Split special chars at the start of word
    will_split = True
    while will_split:
      will_split = False
      for char in _START_SPECIAL_CHARS:
        if word.startswith(char):
          tokenized.append(char)
          word = word[len(char):]
          will_split = True

    # Split special chars at the end of word
    special_end_tokens = []
    will_split = True
    while will_split:
      will_split = False
      for char in _SPECIAL_CHARS:
        if word.endswith(char):
          special_end_tokens = [char] + special_end_tokens
          word = word[:-len(char)]
          will_split = True

    if word:
      tokenized.append(word)
    tokenized += special_end_tokens

  # Add necessary end of sentence token.
  if tokenized[-1] not in ['.', '!', '?']:
    tokenized += ['.']
  return tokenized
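

# For example, tokenize("Don't blame me!") yields
# ['Don', "'t", 'blame', 'me', '!'], and a final '.' is appended whenever the
# sentence does not already end with '.', '!' or '?'.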


def parse_commonsense_reasoning_test(test_data_name):
  """Read JSON test data."""
  with tf.gfile.Open(os.path.join(
      FLAGS.data_dir, 'commonsense_test',
      '{}.json'.format(test_data_name)), 'r') as f:
    data = json.load(f)

  question_ids = [d['question_id'] for d in data]
  sentences = [tokenize(d['substitution']) for d in data]
  labels = [d['correctness'] for d in data]
  return question_ids, sentences, labels
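

# Each JSON record is expected to provide at least the fields read above, e.g.
# (values illustrative):
#   {"question_id": 0,
#    "substitution": "The trophy does not fit in the suitcase because the suitcase is too small .",
#    "correctness": true}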


PAD = '<padding>'


def cut_to_patches(sentences, batch_size, num_timesteps):
  """Cut sentences into patches of shape (batch_size, num_timesteps).

  Args:
    sentences: a list of sentences, each sentence is a list of str tokens.
    batch_size: batch size.
    num_timesteps: number of backprop steps.

  Returns:
    patches: a 2D list of patches; each entry is a token matrix of shape
      (batch_size, num_timesteps + 1), or None when every column after the
      first contains only padding.
  """
  preprocessed = [['<S>'] + sentence + ['</S>'] for sentence in sentences]
  max_len = max([len(sent) for sent in preprocessed])

  # Pad to shape [height, width]
  # where height is a multiple of batch_size
  # and width is a multiple of num_timesteps
  nrow = int(np.ceil(len(preprocessed) * 1.0 / batch_size))
  ncol = int(np.ceil(max_len * 1.0 / num_timesteps))
  height, width = nrow * batch_size, ncol * num_timesteps + 1
  preprocessed = [sent + [PAD] * (width - len(sent)) for sent in preprocessed]
  preprocessed += [[PAD] * width] * (height - len(preprocessed))

  # Cut preprocessed into patches of shape [batch_size, num_timesteps]
  patches = []
  for row in range(nrow):
    patches.append([])
    for col in range(ncol):
      patch = [sent[col * num_timesteps:
                    (col + 1) * num_timesteps + 1]
               for sent in preprocessed[row * batch_size:
                                        (row + 1) * batch_size]]
      if np.all(np.array(patch)[:, 1:] == PAD):
        patch = None  # no need to process this patch.
      patches[-1].append(patch)
  return patches
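

# Consecutive patches along a row overlap by one column: each patch holds
# num_timesteps + 1 tokens per sentence, presumably so that a language model
# can read inputs patch[:, :-1] and shifted targets patch[:, 1:] from the same
# patch.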


def _substitution_mask(sent1, sent2):
  """Binary mask identifying the substituted part in two sentences.

  Example sentences and their masks (1 marks the substituted part):
    First sentence  = "I like the cat 's color"
      mask          =  0  0    0   1   0  0
    Second sentence = "I like the yellow dog 's color"
      mask          =  0  0    0   1      1   0  0

  Args:
    sent1: first sentence
    sent2: second sentence

  Returns:
    mask1: mask for first sentence
    mask2: mask for second sentence
  """
  mask1_start, mask2_start = [], []
  while sent1[0] == sent2[0]:
    sent1 = sent1[1:]
    sent2 = sent2[1:]
    mask1_start.append(0.)
    mask2_start.append(0.)

  mask1_end, mask2_end = [], []
  while sent1[-1] == sent2[-1]:
    if (len(sent1) == 1) or (len(sent2) == 1):
      break
    sent1 = sent1[:-1]
    sent2 = sent2[:-1]
    mask1_end = [0.] + mask1_end
    mask2_end = [0.] + mask2_end

  assert sent1 or sent2, 'Two sentences are identical.'
  return (mask1_start + [1.] * len(sent1) + mask1_end,
          mask2_start + [1.] * len(sent2) + mask2_end)


def _convert_to_partial(scoring1, scoring2):
  """Convert full scoring into partial scoring."""
  mask1, mask2 = _substitution_mask(
      scoring1['sentence'], scoring2['sentence'])

  def _partial_score(scoring, mask):
    word_probs = [max(_) for _ in zip(scoring['word_probs'], mask)]
    scoring.update(word_probs=word_probs,
                   joint_prob=np.prod(word_probs))

  _partial_score(scoring1, mask1)
  _partial_score(scoring2, mask2)
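

# Taking max(word_prob, mask_value) sets the probability of every masked
# (substituted) position to 1.0, so the partial joint probability only reflects
# the words outside the substitution, presumably so the two candidates are
# compared on how well they explain the surrounding context rather than on
# their own likelihood.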


def compare_substitutions(question_ids, scorings, mode='full'):
  """Return accuracy by comparing two consecutive scorings."""
  prediction_correctness = []
  # Compare two consecutive substitutions
  for i in range(len(scorings) // 2):
    scoring1, scoring2 = scorings[2 * i: 2 * i + 2]
    if mode == 'partial':  # convert joint prob into partial prob
      _convert_to_partial(scoring1, scoring2)

    prediction_correctness.append(
        (scoring2['joint_prob'] > scoring1['joint_prob']) ==
        scoring2['correctness'])

  # Two consecutive substitutions always belong to the same question
  question_ids = [qid for i, qid in enumerate(question_ids) if i % 2 == 0]
  assert len(question_ids) == len(prediction_correctness)
  num_questions = len(set(question_ids))

  # A question is correctly answered only if all predictions
  # for the same question_id are correct.
  num_correct_answer = 0
  previous_qid = None
  correctly_answered = False
  for predict, qid in zip(prediction_correctness, question_ids):
    if qid != previous_qid:
      previous_qid = qid
      num_correct_answer += int(correctly_answered)
      correctly_answered = True
    correctly_answered = correctly_answered and predict
  num_correct_answer += int(correctly_answered)
  return num_correct_answer / num_questions
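

# Each scoring dict is expected to carry at least the keys used above:
# 'sentence' (token list), 'word_probs' (per-token probabilities), 'joint_prob'
# (overall sentence score) and 'correctness' (whether this substitution is the
# correct one), with the two candidate substitutions of a question stored back
# to back in `scorings`.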