# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
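
# A note on the lookup style used throughout this file: every mapping function
# locates variables by exact op name with a list comprehension. A minimal
# illustrative helper capturing that pattern (an addition for exposition only;
# the functions below keep their inline lookups):
#
#   def _get_trainable_variable(name):
#     # Raises IndexError if no trainable variable has this exact op name,
#     # mirroring the inline `[...][0]` lookups below.
#     return [v for v in tf.trainable_variables() if v.op.name == name][0]
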
def rnn_nas(hparams, model):
  """Returns the NAS Variable name to MaskGAN Variable dictionary mapping.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    model: Model type, one of ['gen', 'dis'].

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert model == 'gen' or model == 'dis'

  # This logic is only valid for rnn_nas.
  if model == 'gen':
    assert FLAGS.generator_model == 'rnn_nas'
    assert hparams.gen_num_layers == 2

  if model == 'dis':
    assert FLAGS.discriminator_model == 'rnn_nas'
    assert hparams.dis_num_layers == 2

  # Output variables only for the Generator. Discriminator output biases
  # will begin randomly initialized.
  if model == 'gen':
    softmax_b = [
        v for v in tf.trainable_variables() if v.op.name == 'gen/rnn/softmax_b'
    ][0]

  # Common elements to Generator and Discriminator.
  embedding = [
      v for v in tf.trainable_variables()
      if v.op.name == str(model) + '/rnn/embedding'
  ][0]
  lstm_w_0 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      str(model) + '/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat'
  ][0]
  lstm_b_0 = [
      v for v in tf.trainable_variables()
      if v.op.name == str(model) +
      '/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat'
  ][0]
  lstm_w_1 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      str(model) + '/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat'
  ][0]
  lstm_b_1 = [
      v for v in tf.trainable_variables()
      if v.op.name == str(model) +
      '/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat'
  ][0]

  # Dictionary mapping.
  if model == 'gen':
    variable_mapping = {
        'Model/embeddings/input_embedding':
            embedding,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
            lstm_w_0,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
            lstm_b_0,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
            lstm_w_1,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
            lstm_b_1,
        'Model/softmax_b':
            softmax_b
    }
  else:
    variable_mapping = {
        'Model/embeddings/input_embedding':
            embedding,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
            lstm_w_0,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
            lstm_b_0,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
            lstm_w_1,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
            lstm_b_1
    }
  return variable_mapping

def cnn():
  """Variable mapping for the CNN embedding.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  # This logic is only valid for cnn.
  assert FLAGS.discriminator_model == 'cnn'

  # Retrieve CNN embedding.
  embedding = [
      v for v in tf.trainable_variables() if v.op.name == 'dis/embedding'
  ][0]

  # Variable mapping.
  variable_mapping = {'Model/embedding': embedding}

  return variable_mapping
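
# Example usage (a sketch; the session and checkpoint path are hypothetical):
# the returned dictionary maps checkpoint tensor names to in-graph variables,
# which is the `var_list` form accepted by tf.train.Saver for restoring
# pretrained weights by name.
#
#   saver = tf.train.Saver(var_list=cnn())
#   saver.restore(sess, '/path/to/pretrained_lm.ckpt')
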
def rnn_zaremba(hparams, model):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing. This will need to be
  generalized.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    model: Model type, one of ['gen', 'dis'].

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert model == 'gen' or model == 'dis'

  # This logic is only valid for rnn_zaremba.
  if model == 'gen':
    assert FLAGS.generator_model == 'rnn_zaremba'
    assert hparams.gen_num_layers == 2

  if model == 'dis':
    assert (FLAGS.discriminator_model == 'rnn_zaremba' or
            FLAGS.discriminator_model == 'rnn_vd')
    assert hparams.dis_num_layers == 2

  # Output variables only for the Generator. Discriminator output weights
  # and biases will begin randomly initialized.
  if model == 'gen':
    softmax_w = [
        v for v in tf.trainable_variables() if v.op.name == 'gen/rnn/softmax_w'
    ][0]
    softmax_b = [
        v for v in tf.trainable_variables() if v.op.name == 'gen/rnn/softmax_b'
    ][0]

  # Common elements to Generator and Discriminator.
  if not FLAGS.dis_share_embedding or model != 'dis':
    embedding = [
        v for v in tf.trainable_variables()
        if v.op.name == str(model) + '/rnn/embedding'
    ][0]
  lstm_w_0 = [
      v for v in tf.trainable_variables() if v.op.name == str(model) +
      '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  lstm_b_0 = [
      v for v in tf.trainable_variables() if v.op.name == str(model) +
      '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  lstm_w_1 = [
      v for v in tf.trainable_variables() if v.op.name == str(model) +
      '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  lstm_b_1 = [
      v for v in tf.trainable_variables() if v.op.name == str(model) +
      '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]

  # Dictionary mapping.
  if model == 'gen':
    variable_mapping = {
        'Model/embedding': embedding,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1,
        'Model/softmax_w': softmax_w,
        'Model/softmax_b': softmax_b
    }
  else:
    if FLAGS.dis_share_embedding:
      variable_mapping = {
          'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0,
          'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0,
          'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1,
          'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1
      }
    else:
      variable_mapping = {
          'Model/embedding': embedding,
          'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0,
          'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0,
          'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1,
          'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1
      }
  return variable_mapping

def gen_encoder_seq2seq_nas(hparams):
  """Returns the NAS Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing. This is for the
  *unidirectional* seq2seq_nas encoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.generator_model == 'seq2seq_nas'
  assert hparams.gen_num_layers == 2

  ## Encoder forward variables.
  if not FLAGS.seq2seq_share_embedding:
    encoder_embedding = [
        v for v in tf.trainable_variables()
        if v.op.name == 'gen/encoder/rnn/embedding'
    ][0]
  encoder_lstm_w_0 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      'gen/encoder/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat'
  ][0]
  encoder_lstm_b_0 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      'gen/encoder/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat'
  ][0]
  encoder_lstm_w_1 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      'gen/encoder/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat'
  ][0]
  encoder_lstm_b_1 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      'gen/encoder/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat'
  ][0]

  if not FLAGS.seq2seq_share_embedding:
    variable_mapping = {
        'Model/embeddings/input_embedding':
            encoder_embedding,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
            encoder_lstm_w_0,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
            encoder_lstm_b_0,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
            encoder_lstm_w_1,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
            encoder_lstm_b_1
    }
  else:
    variable_mapping = {
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
            encoder_lstm_w_0,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
            encoder_lstm_b_0,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
            encoder_lstm_w_1,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
            encoder_lstm_b_1
    }
  return variable_mapping

def gen_decoder_seq2seq_nas(hparams):
  """Returns the NAS Variable name to MaskGAN Variable dictionary mapping
  for the seq2seq_nas decoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.generator_model == 'seq2seq_nas'
  assert hparams.gen_num_layers == 2
  decoder_embedding = [
      v for v in tf.trainable_variables()
      if v.op.name == 'gen/decoder/rnn/embedding'
  ][0]
  decoder_lstm_w_0 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      'gen/decoder/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat'
  ][0]
  decoder_lstm_b_0 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      'gen/decoder/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat'
  ][0]
  decoder_lstm_w_1 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      'gen/decoder/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat'
  ][0]
  decoder_lstm_b_1 = [
      v for v in tf.trainable_variables()
      if v.op.name ==
      'gen/decoder/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat'
  ][0]
  decoder_softmax_b = [
      v for v in tf.trainable_variables()
      if v.op.name == 'gen/decoder/rnn/softmax_b'
  ][0]

  variable_mapping = {
      'Model/embeddings/input_embedding':
          decoder_embedding,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat':
          decoder_lstm_w_0,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat':
          decoder_lstm_b_0,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat':
          decoder_lstm_w_1,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat':
          decoder_lstm_b_1,
      'Model/softmax_b':
          decoder_softmax_b
  }
  return variable_mapping

def gen_encoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  This is a highly restrictive function just for testing. This is for the
  *unidirectional* seq2seq_zaremba encoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.generator_model == 'seq2seq_zaremba' or
          FLAGS.generator_model == 'seq2seq_vd')
  assert hparams.gen_num_layers == 2

  ## Encoder forward variables.
  if not FLAGS.seq2seq_share_embedding:
    encoder_embedding = [
        v for v in tf.trainable_variables()
        if v.op.name == 'gen/encoder/rnn/embedding'
    ][0]
  encoder_lstm_w_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  encoder_lstm_b_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  encoder_lstm_w_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  encoder_lstm_b_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]

  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  if not FLAGS.seq2seq_share_embedding:
    variable_mapping = {
        str(model_str) + '/embedding':
            encoder_embedding,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
            encoder_lstm_w_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
            encoder_lstm_b_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
            encoder_lstm_w_1,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
            encoder_lstm_b_1
    }
  else:
    variable_mapping = {
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
            encoder_lstm_w_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
            encoder_lstm_b_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
            encoder_lstm_w_1,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
            encoder_lstm_b_1
    }
  return variable_mapping

def gen_decoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping
  for the seq2seq_zaremba and seq2seq_vd decoders.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.generator_model == 'seq2seq_zaremba' or
          FLAGS.generator_model == 'seq2seq_vd')
  assert hparams.gen_num_layers == 2
  decoder_embedding = [
      v for v in tf.trainable_variables()
      if v.op.name == 'gen/decoder/rnn/embedding'
  ][0]
  decoder_lstm_w_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  decoder_lstm_b_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  decoder_lstm_w_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  decoder_lstm_b_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]
  decoder_softmax_b = [
      v for v in tf.trainable_variables()
      if v.op.name == 'gen/decoder/rnn/softmax_b'
  ][0]

  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  variable_mapping = {
      str(model_str) + '/embedding':
          decoder_embedding,
      str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          decoder_lstm_w_0,
      str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          decoder_lstm_b_0,
      str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          decoder_lstm_w_1,
      str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          decoder_lstm_b_1,
      str(model_str) + '/softmax_b':
          decoder_softmax_b
  }
  return variable_mapping
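
# Example usage (a sketch; `sess` and `ckpt_path` are hypothetical): the
# encoder and decoder mappings reuse the same checkpoint names for different
# in-graph variables, so they are restored with separate Savers from the same
# pretrained language-model checkpoint. The same applies to the *_nas
# variants above.
#
#   enc_saver = tf.train.Saver(var_list=gen_encoder_seq2seq(hparams))
#   dec_saver = tf.train.Saver(var_list=gen_decoder_seq2seq(hparams))
#   enc_saver.restore(sess, ckpt_path)
#   dec_saver.restore(sess, ckpt_path)
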
def dis_fwd_bidirectional(hparams):
  """Returns the *forward* PTB Variable name to MaskGAN Variable dictionary
  mapping.

  This is a highly restrictive function just for testing. This is for the
  bidirectional_zaremba discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd')
  assert hparams.dis_num_layers == 2

  # Forward Discriminator Elements.
  if not FLAGS.dis_share_embedding:
    embedding = [
        v for v in tf.trainable_variables() if v.op.name == 'dis/embedding'
    ][0]
  fw_lstm_w_0 = [
      v for v in tf.trainable_variables()
      if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  fw_lstm_b_0 = [
      v for v in tf.trainable_variables()
      if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  fw_lstm_w_1 = [
      v for v in tf.trainable_variables()
      if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  fw_lstm_b_1 = [
      v for v in tf.trainable_variables()
      if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]
  if FLAGS.dis_share_embedding:
    variable_mapping = {
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': fw_lstm_w_0,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': fw_lstm_b_0,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': fw_lstm_w_1,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': fw_lstm_b_1
    }
  else:
    variable_mapping = {
        'Model/embedding': embedding,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': fw_lstm_w_0,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': fw_lstm_b_0,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': fw_lstm_w_1,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': fw_lstm_b_1
    }
  return variable_mapping

def dis_bwd_bidirectional(hparams):
  """Returns the *backward* PTB Variable name to MaskGAN Variable dictionary
  mapping.

  This is a highly restrictive function just for testing. This is for the
  bidirectional_zaremba discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd')
  assert hparams.dis_num_layers == 2

  # Backward Discriminator Elements.
  bw_lstm_w_0 = [
      v for v in tf.trainable_variables()
      if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  bw_lstm_b_0 = [
      v for v in tf.trainable_variables()
      if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  bw_lstm_w_1 = [
      v for v in tf.trainable_variables()
      if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  bw_lstm_b_1 = [
      v for v in tf.trainable_variables()
      if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]
  variable_mapping = {
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': bw_lstm_w_0,
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': bw_lstm_b_0,
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': bw_lstm_w_1,
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': bw_lstm_b_1
  }
  return variable_mapping
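
# Example usage (a sketch; `sess` and `ckpt_path` are hypothetical): the
# forward and backward mappings likewise share checkpoint names, so each
# direction of the bidirectional discriminator gets its own Saver restored
# from the same unidirectional checkpoint.
#
#   fw_saver = tf.train.Saver(var_list=dis_fwd_bidirectional(hparams))
#   bw_saver = tf.train.Saver(var_list=dis_bwd_bidirectional(hparams))
#   fw_saver.restore(sess, ckpt_path)
#   bw_saver.restore(sess, ckpt_path)
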
def dis_encoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  ## Encoder forward variables.
  encoder_lstm_w_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  encoder_lstm_b_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  encoder_lstm_w_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  encoder_lstm_b_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]

  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  variable_mapping = {
      str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          encoder_lstm_w_0,
      str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          encoder_lstm_b_0,
      str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          encoder_lstm_w_1,
      str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          encoder_lstm_b_1
  }
  return variable_mapping

def dis_decoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping
  for the seq2seq_vd discriminator decoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2
  if not FLAGS.dis_share_embedding:
    decoder_embedding = [
        v for v in tf.trainable_variables()
        if v.op.name == 'dis/decoder/rnn/embedding'
    ][0]
  decoder_lstm_w_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  decoder_lstm_b_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  decoder_lstm_w_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  decoder_lstm_b_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]

  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  if not FLAGS.dis_share_embedding:
    variable_mapping = {
        str(model_str) + '/embedding':
            decoder_embedding,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
            decoder_lstm_w_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
            decoder_lstm_b_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
            decoder_lstm_w_1,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
            decoder_lstm_b_1
    }
  else:
    variable_mapping = {
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
            decoder_lstm_w_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
            decoder_lstm_b_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
            decoder_lstm_w_1,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
            decoder_lstm_b_1
    }
  return variable_mapping

def dis_seq2seq_vd(hparams):
  """Returns the Generator Variable name to seq2seq_vd discriminator Variable
  dictionary mapping.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2
  if not FLAGS.dis_share_embedding:
    decoder_embedding = [
        v for v in tf.trainable_variables()
        if v.op.name == 'dis/decoder/rnn/embedding'
    ][0]

  ## Encoder variables.
  encoder_lstm_w_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  encoder_lstm_b_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  encoder_lstm_w_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  encoder_lstm_b_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]

  ## Attention.
  if FLAGS.attention_option is not None:
    decoder_attention_keys = [
        v for v in tf.trainable_variables()
        if v.op.name == 'dis/decoder/attention_keys/weights'
    ][0]
    decoder_attention_construct_weights = [
        v for v in tf.trainable_variables()
        if v.op.name == 'dis/decoder/rnn/attention_construct/weights'
    ][0]

  ## Decoder.
  decoder_lstm_w_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'
  ][0]
  decoder_lstm_b_0 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'
  ][0]
  decoder_lstm_w_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'
  ][0]
  decoder_lstm_b_1 = [
      v for v in tf.trainable_variables() if v.op.name ==
      'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'
  ][0]

  # Standard variable mappings.
  variable_mapping = {
      'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          encoder_lstm_w_0,
      'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          encoder_lstm_b_0,
      'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          encoder_lstm_w_1,
      'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          encoder_lstm_b_1,
      'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel':
          decoder_lstm_w_0,
      'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias':
          decoder_lstm_b_0,
      'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel':
          decoder_lstm_w_1,
      'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias':
          decoder_lstm_b_1
  }

  # Optional variable mappings.
  if not FLAGS.dis_share_embedding:
    variable_mapping['gen/decoder/rnn/embedding'] = decoder_embedding
  if FLAGS.attention_option is not None:
    variable_mapping[
        'gen/decoder/attention_keys/weights'] = decoder_attention_keys
    variable_mapping[
        'gen/decoder/rnn/attention_construct/weights'] = decoder_attention_construct_weights
  return variable_mapping
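
# Note that the keys above are generator ('gen/...') op names rather than
# 'Model/...' checkpoint names, which suggests this mapping initializes the
# seq2seq discriminator from a trained MaskGAN generator checkpoint, e.g.
# (a sketch; `sess` and `gen_ckpt_path` are hypothetical):
#
#   dis_saver = tf.train.Saver(var_list=dis_seq2seq_vd(hparams))
#   dis_saver.restore(sess, gen_ckpt_path)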