# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model construction."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import tensorflow as tf

from models import bidirectional
from models import bidirectional_vd
from models import bidirectional_zaremba
from models import cnn
from models import critic_vd
from models import feedforward
from models import rnn
from models import rnn_nas
from models import rnn_vd
from models import rnn_zaremba
from models import seq2seq
from models import seq2seq_nas
from models import seq2seq_vd
from models import seq2seq_zaremba

FLAGS = tf.app.flags.FLAGS

# TODO(adai): IMDB labels placeholder to model.
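
# Note: this module reads FLAGS.generator_model, FLAGS.discriminator_model
# and FLAGS.baseline_method; they are assumed to be defined by the training
# binary before these constructors are called.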


def create_generator(hparams,
                     inputs,
                     targets,
                     present,
                     is_training,
                     is_validating,
                     reuse=None):
  """Create the Generator model specified by the FLAGS and hparams.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    inputs: tf.int32 Tensor of the sequence input of shape [batch_size,
      sequence_length].
    targets: tf.int32 Tensor of the sequence targets of shape [batch_size,
      sequence_length].
    present: tf.bool Tensor indicating the presence or absence of the token
      of shape [batch_size, sequence_length].
    is_training: Whether the model is training.
    is_validating: Whether the model is being run in validation mode for
      calculating the perplexity.
    reuse (Optional): Whether to reuse the model.

  Returns:
    Tuple of (sequence, logits, log_probs, initial_state, final_state,
    encoder_states) for the Generator. Logits have shape [batch_size,
    sequence_length, vocab_size] and log_probs have shape [batch_size,
    sequence_length]; log_probs corresponds to the log probability of
    selecting the words. encoder_states is None for every architecture
    except 'seq2seq_vd'.
  """
  # `encoder_states` is only produced by the 'seq2seq_vd' generator; default
  # it to None so the return statement below is valid for every architecture.
  encoder_states = None

  if FLAGS.generator_model == 'rnn':
    (sequence, logits, log_probs, initial_state, final_state) = rnn.generator(
        hparams,
        inputs,
        targets,
        present,
        is_training=is_training,
        is_validating=is_validating,
        reuse=reuse)
  elif FLAGS.generator_model == 'rnn_zaremba':
    (sequence, logits, log_probs, initial_state,
     final_state) = rnn_zaremba.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_zaremba':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq_zaremba.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'rnn_nas':
    (sequence, logits, log_probs, initial_state,
     final_state) = rnn_nas.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_nas':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq_nas.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_vd':
    (sequence, logits, log_probs, initial_state, final_state,
     encoder_states) = seq2seq_vd.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  else:
    raise NotImplementedError
  return (sequence, logits, log_probs, initial_state, final_state,
          encoder_states)
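
# Example usage (an illustrative sketch, not part of the original file):
# `hparams`, `inputs`, `targets`, and `present` are assumed to be built by
# the caller, e.g. inputs/targets as tf.int32 [batch_size, sequence_length]
# and present as a tf.bool mask of the same shape.
#
#   (sequence, logits, log_probs, initial_state, final_state,
#    encoder_states) = create_generator(
#        hparams, inputs, targets, present,
#        is_training=True, is_validating=False)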


def create_discriminator(hparams,
                         sequence,
                         is_training,
                         reuse=None,
                         initial_state=None,
                         inputs=None,
                         present=None):
  """Create the Discriminator model specified by the FLAGS and hparams.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length].
    is_training: Whether the model is training.
    reuse (Optional): Whether to reuse the model.
    initial_state (Optional): Initial RNN state, consumed only by the
      recurrent '*_vd' discriminators.
    inputs (Optional): tf.int32 Tensor of the original sequence input of
      shape [batch_size, sequence_length], consumed only by the 'seq2seq_vd'
      discriminator.
    present (Optional): tf.bool Tensor indicating the presence or absence of
      the token, of shape [batch_size, sequence_length], consumed only by the
      'seq2seq_vd' discriminator.

  Returns:
    predictions: tf.float32 Tensor of predictions of shape [batch_size,
      sequence_length].
  """
  if FLAGS.discriminator_model == 'cnn':
    predictions = cnn.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'fnn':
    predictions = feedforward.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn':
    predictions = rnn.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'bidirectional':
    predictions = bidirectional.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'bidirectional_zaremba':
    predictions = bidirectional_zaremba.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'seq2seq_vd':
    predictions = seq2seq_vd.discriminator(
        hparams,
        inputs,
        present,
        sequence,
        is_training=is_training,
        reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_zaremba':
    predictions = rnn_zaremba.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_nas':
    predictions = rnn_nas.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_vd':
    predictions = rnn_vd.discriminator(
        hparams,
        sequence,
        is_training=is_training,
        reuse=reuse,
        initial_state=initial_state)
  elif FLAGS.discriminator_model == 'bidirectional_vd':
    predictions = bidirectional_vd.discriminator(
        hparams,
        sequence,
        is_training=is_training,
        reuse=reuse,
        initial_state=initial_state)
  else:
    raise NotImplementedError
  return predictions
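
# Example usage (an illustrative sketch): score a generated `sequence` from
# create_generator. `inputs` and `present` are only consumed by the
# 'seq2seq_vd' discriminator, and `initial_state` only by the recurrent
# '*_vd' variants; the other architectures ignore them.
#
#   predictions = create_discriminator(
#       hparams, sequence, is_training=True,
#       inputs=inputs, present=present)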


def create_critic(hparams, sequence, is_training, reuse=None):
  """Create the Critic model specified by the FLAGS and hparams.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length].
    is_training: Whether the model is training.
    reuse (Optional): Whether to reuse the model.

  Returns:
    values: tf.float32 Tensor of estimated values of shape [batch_size,
      sequence_length].
  """
  if FLAGS.baseline_method == 'critic':
    if FLAGS.discriminator_model == 'seq2seq_vd':
      values = critic_vd.critic_seq2seq_vd_derivative(
          hparams, sequence, is_training, reuse=reuse)
    else:
      raise NotImplementedError
  else:
    raise NotImplementedError
  return values
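
# Example usage (an illustrative sketch, assuming FLAGS.baseline_method ==
# 'critic' and FLAGS.discriminator_model == 'seq2seq_vd', the only supported
# combination above):
#
#   values = create_critic(hparams, sequence, is_training=True)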