# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library for loading 1B word benchmark dataset."""

import random

import numpy as np
import tensorflow as tf

class Vocabulary(object):
  """Class that holds a vocabulary for the dataset."""

  def __init__(self, filename):
    """Initialize vocabulary.

    Args:
      filename: Vocabulary file name.
    """
    self._id_to_word = []
    self._word_to_id = {}
    self._unk = -1
    self._bos = -1
    self._eos = -1

    with tf.gfile.Open(filename) as f:
      idx = 0
      for line in f:
        word_name = line.strip()
        if word_name == '<S>':
          self._bos = idx
        elif word_name == '</S>':
          self._eos = idx
        elif word_name == '<UNK>':
          self._unk = idx
        if word_name == '!!!MAXTERMID':
          continue

        self._id_to_word.append(word_name)
        self._word_to_id[word_name] = idx
        idx += 1

  # The accessors below are properties so call sites can read them as plain
  # attributes (e.g. self.bos, self.size).
  @property
  def bos(self):
    return self._bos

  @property
  def eos(self):
    return self._eos

  @property
  def unk(self):
    return self._unk

  @property
  def size(self):
    return len(self._id_to_word)

  def word_to_id(self, word):
    if word in self._word_to_id:
      return self._word_to_id[word]
    return self.unk

  def id_to_word(self, cur_id):
    if cur_id < self.size:
      return self._id_to_word[cur_id]
    return 'ERROR'

  def decode(self, cur_ids):
    """Convert a list of ids to a sentence, with space inserted."""
    return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])

  def encode(self, sentence):
    """Convert a sentence to a list of ids, with special tokens added."""
    word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
    return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
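

# Illustrative sketch only, not part of the original library: a round trip
# through Vocabulary.encode/decode. `vocab_path` is a hypothetical path to a
# one-token-per-line vocabulary file containing <S>, </S> and <UNK> entries.
def _example_vocabulary_round_trip(vocab_path):
  vocab = Vocabulary(vocab_path)
  ids = vocab.encode('hello world')  # int32 ids: [<S>, hello, world, </S>]
  # Strip the <S>/</S> ids before decoding; out-of-vocabulary words decode
  # back as '<UNK>'.
  text = vocab.decode(ids[1:-1])
  return ids, text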

class CharsVocabulary(Vocabulary):
  """Vocabulary containing character-level information."""

  def __init__(self, filename, max_word_length):
    super(CharsVocabulary, self).__init__(filename)
    self._max_word_length = max_word_length
    chars_set = set()

    for word in self._id_to_word:
      chars_set |= set(word)

    free_ids = []
    for i in range(256):
      if chr(i) in chars_set:
        continue
      free_ids.append(chr(i))

    if len(free_ids) < 5:
      raise ValueError('Not enough free char ids: %d' % len(free_ids))

    self.bos_char = free_ids[0]  # <begin sentence>
    self.eos_char = free_ids[1]  # <end sentence>
    self.bow_char = free_ids[2]  # <begin word>
    self.eow_char = free_ids[3]  # <end word>
    self.pad_char = free_ids[4]  # <padding>

    chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
                  self.pad_char}

    self._char_set = chars_set
    num_words = len(self._id_to_word)

    self._word_char_ids = np.zeros([num_words, max_word_length],
                                   dtype=np.int32)

    self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
    self.eos_chars = self._convert_word_to_char_ids(self.eos_char)

    for i, word in enumerate(self._id_to_word):
      self._word_char_ids[i] = self._convert_word_to_char_ids(word)

  @property
  def word_char_ids(self):
    return self._word_char_ids

  @property
  def max_word_length(self):
    return self._max_word_length

  def _convert_word_to_char_ids(self, word):
    # Layout: <begin word>, the word's chars (truncated to fit), <end word>,
    # then padding up to max_word_length.
    code = np.zeros([self.max_word_length], dtype=np.int32)
    code[:] = ord(self.pad_char)

    if len(word) > self.max_word_length - 2:
      word = word[:self.max_word_length - 2]
    cur_word = self.bow_char + word + self.eow_char
    for j in range(len(cur_word)):
      code[j] = ord(cur_word[j])
    return code

  def word_to_char_ids(self, word):
    if word in self._word_to_id:
      return self._word_char_ids[self._word_to_id[word]]
    else:
      return self._convert_word_to_char_ids(word)

  def encode_chars(self, sentence):
    """Convert a sentence to a [num_words + 2, max_word_length] char-id array."""
    chars_ids = [self.word_to_char_ids(cur_word)
                 for cur_word in sentence.split()]
    return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])

def get_batch(generator, batch_size, num_steps, max_word_length, pad=False):
  """Read batches of input."""
  cur_stream = [None] * batch_size

  inputs = np.zeros([batch_size, num_steps], np.int32)
  char_inputs = np.zeros([batch_size, num_steps, max_word_length], np.int32)
  global_word_ids = np.zeros([batch_size, num_steps], np.int32)
  targets = np.zeros([batch_size, num_steps], np.int32)
  weights = np.ones([batch_size, num_steps], np.float32)

  no_more_data = False
  while True:
    inputs[:] = 0
    char_inputs[:] = 0
    global_word_ids[:] = 0
    targets[:] = 0
    weights[:] = 0.0

    for i in range(batch_size):
      cur_pos = 0

      while cur_pos < num_steps:
        if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
          try:
            cur_stream[i] = list(next(generator))
          except StopIteration:
            # No more data, exhaust current streams and quit
            no_more_data = True
            break

        how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
        next_pos = cur_pos + how_many

        inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
        char_inputs[i, cur_pos:next_pos] = cur_stream[i][1][:how_many]
        global_word_ids[i, cur_pos:next_pos] = cur_stream[i][2][:how_many]
        targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many + 1]
        weights[i, cur_pos:next_pos] = 1.0

        cur_pos = next_pos
        cur_stream[i][0] = cur_stream[i][0][how_many:]
        cur_stream[i][1] = cur_stream[i][1][how_many:]
        cur_stream[i][2] = cur_stream[i][2][how_many:]

        if pad:
          break

    if no_more_data and np.sum(weights) == 0:
      # There is no more data and this is an empty batch. Done!
      break
    yield inputs, char_inputs, global_word_ids, targets, weights

class LM1BDataset(object):
  """Utility class for 1B word benchmark dataset.

  The current implementation reads the data from the tokenized text files.
  """

  def __init__(self, filepattern, vocab):
    """Initialize LM1BDataset reader.

    Args:
      filepattern: Dataset file pattern.
      vocab: Vocabulary.
    """
    self._vocab = vocab
    self._all_shards = tf.gfile.Glob(filepattern)
    tf.logging.info('Found %d shards at %s', len(self._all_shards), filepattern)

  def _load_random_shard(self):
    """Randomly select a file and read it."""
    return self._load_shard(random.choice(self._all_shards))

  def _load_shard(self, shard_name):
    """Read one file and convert to ids.

    Args:
      shard_name: file path.

    Returns:
      list of (id, char_id, global_word_id) tuples.
    """
    tf.logging.info('Loading data from: %s', shard_name)
    with tf.gfile.Open(shard_name) as f:
      sentences = f.readlines()
    chars_ids = [self.vocab.encode_chars(sentence) for sentence in sentences]
    ids = [self.vocab.encode(sentence) for sentence in sentences]

    global_word_ids = []
    current_idx = 0
    for word_ids in ids:
      current_size = len(word_ids) - 1  # without <BOS> symbol
      cur_ids = np.arange(current_idx, current_idx + current_size)
      global_word_ids.append(cur_ids)
      current_idx += current_size

    tf.logging.info('Loaded %d words.', current_idx)
    tf.logging.info('Finished loading')
    return list(zip(ids, chars_ids, global_word_ids))

  def _get_sentence(self, forever=True):
    while True:
      ids = self._load_random_shard()
      for current_ids in ids:
        yield current_ids
      if not forever:
        break

  def get_batch(self, batch_size, num_steps, pad=False, forever=True):
    return get_batch(self._get_sentence(forever), batch_size, num_steps,
                     self.vocab.max_word_length, pad=pad)

  @property
  def vocab(self):
    return self._vocab
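

# Minimal usage sketch, not part of the original file: the vocabulary path,
# data file pattern and max_word_length below are placeholders / assumptions
# (the released 1B word benchmark artifacts use a similar shard layout).
if __name__ == '__main__':
  char_vocab = CharsVocabulary('/path/to/vocab.txt', max_word_length=50)
  dataset = LM1BDataset('/path/to/heldout-*-of-00050', char_vocab)
  # Each batch yields word ids, char ids, global word ids, next-word targets
  # and a 0/1 weight mask, shaped [batch_size, num_steps] (char ids carry an
  # extra max_word_length dimension).
  for inputs, char_inputs, _, targets, weights in dataset.get_batch(
      batch_size=2, num_steps=20, forever=False):
    print(inputs.shape, char_inputs.shape, targets.shape, weights.shape)
    break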