| # Copyright 2017 Google Inc. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Create TFRecord files of SequenceExample protos from dataset. | |
| Constructs 3 datasets: | |
| 1. Labeled data for the LSTM classification model, optionally with label gain. | |
| "*_classification.tfrecords" (for both unidirectional and bidirectional | |
| models). | |
| 2. Data for the unsupervised LM-LSTM model that predicts the next token. | |
| "*_lm.tfrecords" (generates forward and reverse data). | |
| 3. Data for the unsupervised SA-LSTM model that uses Seq2Seq. | |
| "*_sa.tfrecords". | |
| """ | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import os | |
| import string | |
| # Dependency imports | |
| import tensorflow as tf | |
| from data import data_utils | |
| from data import document_generators | |
# Short alias: the rest of this file refers to data_utils as `data`.
data = data_utils

flags = tf.app.flags
FLAGS = flags.FLAGS

# Flags for input data are in document_generators.py
flags.DEFINE_string('vocab_file', '', 'Path to the vocabulary file. Defaults '
                    'to FLAGS.output_dir/vocab.txt.')
flags.DEFINE_string('output_dir', '', 'Path to save tfrecords.')

# Config
flags.DEFINE_boolean('label_gain', False,
                     'Enable linear label gain. If True, sentiment label will '
                     'be included at each timestep with linear weight '
                     'increase.')
def build_shuffling_tf_record_writer(fname):
  """Returns a shuffling TFRecord writer for `fname` under FLAGS.output_dir."""
  output_path = os.path.join(FLAGS.output_dir, fname)
  return data.ShufflingTFRecordWriter(output_path)
def build_tf_record_writer(fname):
  """Returns a plain (non-shuffling) TFRecord writer under FLAGS.output_dir."""
  output_path = os.path.join(FLAGS.output_dir, fname)
  return tf.python_io.TFRecordWriter(output_path)
def build_input_sequence(doc, vocab_ids):
  """Builds input sequence from file.

  Splits lines on whitespace. Treats punctuation as whitespace. For word-level
  sequences, only keeps terms that are in the vocab.

  Terms are added as token in the SequenceExample. The EOS_TOKEN is also
  appended. Label and weight features are set to 0.

  Args:
    doc: Document (defined in `document_generators`) from which to build the
      sequence.
    vocab_ids: dict<term, id>.

  Returns:
    SequenceExampleWrapper.
  """
  sequence = data.SequenceWrapper()
  in_vocab = (t for t in document_generators.tokens(doc) if t in vocab_ids)
  for term in in_vocab:
    sequence.add_timestep().set_token(vocab_ids[term])

  # Every sequence is terminated with the end-of-sequence marker.
  sequence.add_timestep().set_token(vocab_ids[data.EOS_TOKEN])

  return sequence
def make_vocab_ids(vocab_filename):
  """Maps vocabulary terms to integer ids.

  Args:
    vocab_filename: Path to the vocabulary file, one term per line. Ignored
      when FLAGS.output_char is set (char-level vocab is built in-memory).

  Returns:
    dict<term, id>. In char mode, ids enumerate string.printable with the
    EOS token assigned the final id; otherwise ids are the 0-based line
    numbers of the vocabulary file.
  """
  if FLAGS.output_char:
    # Character-level vocabulary: every printable ASCII char, EOS id last.
    ret = {char: i for i, char in enumerate(string.printable)}
    ret[data.EOS_TOKEN] = len(string.printable)
    return ret
  with open(vocab_filename, encoding='utf-8') as vocab_f:
    # Line number doubles as the term id; trailing whitespace is stripped.
    return {line.strip(): i for i, line in enumerate(vocab_f)}
def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
  """Generates training data.

  Writes train (and validation) TFRecord files for the LM, reverse-LM,
  sequence-autoencoder, and classification (uni- and bidirectional) datasets,
  and also appends every LM / seq-AE example to the combined "all" writers.

  Args:
    vocab_ids: dict<term, id>, as built by make_vocab_ids.
    writer_lm_all: writer accumulating LM examples across all splits.
    writer_seq_ae_all: writer accumulating seq-AE examples across all splits.
  """
  # Construct training data writers
  writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM)
  writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA)
  writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS)
  writer_valid_class = build_tf_record_writer(data.VALID_CLASS)
  writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM)
  writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS)
  # NOTE(review): the bidirectional validation writer shuffles while the
  # unidirectional one (writer_valid_class) does not — confirm intentional.
  writer_bd_valid_class = build_shuffling_tf_record_writer(data.VALID_BD_CLASS)
  for doc in document_generators.documents(
      dataset='train', include_unlabeled=True, include_validation=True):
    input_seq = build_input_sequence(doc, vocab_ids)
    # Skip documents with no in-vocab terms (sequence is only the EOS token).
    if len(input_seq) < 2:
      continue
    rev_seq = data.build_reverse_sequence(input_seq)
    lm_seq = data.build_lm_sequence(input_seq)
    rev_lm_seq = data.build_lm_sequence(rev_seq)
    seq_ae_seq = data.build_seq_ae_sequence(input_seq)
    if doc.label is not None:
      # Used for sentiment classification. Label gain is never applied to
      # validation documents.
      label_seq = data.build_labeled_sequence(
          input_seq,
          doc.label,
          label_gain=(FLAGS.label_gain and not doc.is_validation))
      bd_label_seq = data.build_labeled_sequence(
          data.build_bidirectional_seq(input_seq, rev_seq),
          doc.label,
          label_gain=(FLAGS.label_gain and not doc.is_validation))
      # Route labeled examples to the validation or train writer.
      class_writer = writer_valid_class if doc.is_validation else writer_class
      bd_class_writer = (writer_bd_valid_class
                         if doc.is_validation else writer_bd_class)
      class_writer.write(label_seq.seq.SerializeToString())
      bd_class_writer.write(bd_label_seq.seq.SerializeToString())
    # Write
    lm_seq_ser = lm_seq.seq.SerializeToString()
    seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
    writer_lm_all.write(lm_seq_ser)
    writer_seq_ae_all.write(seq_ae_seq_ser)
    # Validation documents feed only the combined "all" writers above.
    if not doc.is_validation:
      writer_lm.write(lm_seq_ser)
      writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
      writer_seq_ae.write(seq_ae_seq_ser)
  # Close writers
  writer_lm.close()
  writer_seq_ae.close()
  writer_class.close()
  writer_valid_class.close()
  writer_rev_lm.close()
  writer_bd_class.close()
  writer_bd_valid_class.close()
def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
  """Generates test data.

  Writes test TFRecord files for the LM, reverse-LM, sequence-autoencoder,
  and classification (uni- and bidirectional) datasets, and appends LM /
  seq-AE examples to the combined "all" writers.

  Args:
    vocab_ids: dict<term, id>, as built by make_vocab_ids.
    writer_lm_all: writer accumulating LM examples across all splits.
    writer_seq_ae_all: writer accumulating seq-AE examples across all splits.
  """
  # Construct test data writers
  writer_lm = build_shuffling_tf_record_writer(data.TEST_LM)
  writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM)
  writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA)
  writer_class = build_tf_record_writer(data.TEST_CLASS)
  writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS)
  for doc in document_generators.documents(
      dataset='test', include_unlabeled=False, include_validation=True):
    input_seq = build_input_sequence(doc, vocab_ids)
    # Skip documents with no in-vocab terms (sequence is only the EOS token).
    if len(input_seq) < 2:
      continue
    rev_seq = data.build_reverse_sequence(input_seq)
    lm_seq = data.build_lm_sequence(input_seq)
    rev_lm_seq = data.build_lm_sequence(rev_seq)
    seq_ae_seq = data.build_seq_ae_sequence(input_seq)
    # Test docs are always labeled (include_unlabeled=False); no label gain.
    label_seq = data.build_labeled_sequence(input_seq, doc.label)
    bd_label_seq = data.build_labeled_sequence(
        data.build_bidirectional_seq(input_seq, rev_seq), doc.label)
    # Write
    writer_class.write(label_seq.seq.SerializeToString())
    writer_bd_class.write(bd_label_seq.seq.SerializeToString())
    lm_seq_ser = lm_seq.seq.SerializeToString()
    seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
    writer_lm.write(lm_seq_ser)
    writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
    writer_seq_ae.write(seq_ae_seq_ser)
    writer_lm_all.write(lm_seq_ser)
    writer_seq_ae_all.write(seq_ae_seq_ser)
  # Close test writers
  writer_lm.close()
  writer_rev_lm.close()
  writer_seq_ae.close()
  writer_class.close()
  writer_bd_class.close()
def main(_):
  """Builds the vocabulary id map and writes all TFRecord datasets."""
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('Assigning vocabulary ids...')
  vocab_path = FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt')
  vocab_ids = make_vocab_ids(vocab_path)
  with build_shuffling_tf_record_writer(data.ALL_LM) as writer_lm_all, \
      build_shuffling_tf_record_writer(data.ALL_SA) as writer_seq_ae_all:
    tf.logging.info('Generating training data...')
    generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all)
    tf.logging.info('Generating test data...')
    generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all)
| if __name__ == '__main__': | |
| tf.app.run() | |