# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """A multi-task and semi-supervised NLP model.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import tensorflow as tf | |
| from model import encoder | |
| from model import shared_inputs | |


class Inference(object):
  """Builds the shared encoder plus one prediction module per task."""

  def __init__(self, config, inputs, pretrained_embeddings, tasks):
    with tf.variable_scope('encoder'):
      self.encoder = encoder.Encoder(config, inputs, pretrained_embeddings)
    self.modules = {}
    for task in tasks:
      with tf.variable_scope(task.name):
        self.modules[task.name] = task.get_module(inputs, self.encoder)


class Model(object):
  """Multi-task model with optional EMA weights for the teacher and tester."""

  def __init__(self, config, pretrained_embeddings, tasks):
    self._config = config
    self._tasks = tasks
    self._global_step, self._optimizer = self._get_optimizer()
    self._inputs = shared_inputs.Inputs(config)

    with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
      inference = Inference(config, self._inputs, pretrained_embeddings,
                            tasks)
      # By default the same graph (and hence the same weights) serves for
      # training, testing, and teacher predictions.
      self._trainer = inference
      self._tester = inference
      self._teacher = inference
      if config.ema_test or config.ema_teacher:
        ema = tf.train.ExponentialMovingAverage(config.ema_decay)
        model_vars = tf.get_collection("trainable_variables", "model")
        ema_op = ema.apply(model_vars)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)
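
        # Swapping the scope's variable getter makes the second Inference
        # graph below read the EMA shadow variables instead of the raw
        # training weights.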
        def ema_getter(getter, name, *args, **kwargs):
          var = getter(name, *args, **kwargs)
          return ema.average(var)

        scope.set_custom_getter(ema_getter)
        inference_ema = Inference(
            config, self._inputs, pretrained_embeddings, tasks)
        if config.ema_teacher:
          self._teacher = inference_ema
        if config.ema_test:
          self._tester = inference_ema
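
    # One consistency (unsupervised) train op summed over all tasks, plus one
    # supervised train op per task.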
    self._unlabeled_loss = self._get_consistency_loss(tasks)
    self._unlabeled_train_op = self._get_train_op(self._unlabeled_loss)
    self._labeled_train_ops = {}
    for task in self._tasks:
      task_loss = self._trainer.modules[task.name].supervised_loss
      self._labeled_train_ops[task.name] = self._get_train_op(task_loss)

  def _get_consistency_loss(self, tasks):
    return sum([self._trainer.modules[task.name].unsupervised_loss
                for task in tasks])
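
  # Learning-rate schedule built in _get_optimizer below: linear warm-up
  # followed by inverse-square-root decay,
  #   lr(step) = lr * min(step, warm_up_steps) / warm_up_steps
  #                 * 1 / (1 + lr_decay * sqrt(step))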
  def _get_optimizer(self):
    global_step = tf.get_variable('global_step', initializer=0,
                                  trainable=False)
    warm_up_multiplier = (tf.minimum(tf.to_float(global_step),
                                     self._config.warm_up_steps) /
                          self._config.warm_up_steps)
    decay_multiplier = 1.0 / (1 + self._config.lr_decay *
                              tf.sqrt(tf.to_float(global_step)))
    lr = self._config.lr * warm_up_multiplier * decay_multiplier
    optimizer = tf.train.MomentumOptimizer(lr, self._config.momentum)
    return global_step, optimizer

  def _get_train_op(self, loss):
    grads, vs = zip(*self._optimizer.compute_gradients(loss))
    grads, _ = tf.clip_by_global_norm(grads, self._config.grad_clip)
    # Run UPDATE_OPS (e.g., the EMA update) alongside each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      return self._optimizer.apply_gradients(
          zip(grads, vs), global_step=self._global_step)

  def _create_feed_dict(self, mb, model, is_training=True):
    feed = self._inputs.create_feed_dict(mb, is_training)
    if mb.task_name in model.modules:
      model.modules[mb.task_name].update_feed_dict(feed, mb)
    else:
      # Minibatches whose task name has no module (e.g., unlabeled data)
      # feed every module.
      for module in model.modules.values():
        module.update_feed_dict(feed, mb)
    return feed
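
  # The train_* methods below run one optimization step and return the scalar
  # loss (the [1] index drops the train op's None result).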
  def train_unlabeled(self, sess, mb):
    return sess.run([self._unlabeled_train_op, self._unlabeled_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]

  def train_labeled(self, sess, mb):
    return sess.run([self._labeled_train_ops[mb.task_name],
                     self._trainer.modules[mb.task_name].supervised_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]
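
  # run_teacher stores the (possibly EMA) teacher's per-task output
  # distributions on the minibatch; these presumably serve as targets for the
  # unsupervised consistency loss when the same minibatch is later passed to
  # train_unlabeled.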
  def run_teacher(self, sess, mb):
    result = sess.run({task.name: self._teacher.modules[task.name].probs
                       for task in self._tasks},
                      feed_dict=self._create_feed_dict(mb, self._teacher,
                                                       False))
    # Store the predictions as float16 to save memory.
    for task_name, probs in result.items():
      mb.teacher_predictions[task_name] = probs.astype('float16')

  def test(self, sess, mb):
    return sess.run(
        [self._tester.modules[mb.task_name].supervised_loss,
         self._tester.modules[mb.task_name].preds],
        feed_dict=self._create_feed_dict(mb, self._tester, False))

  def get_global_step(self, sess):
    return sess.run(self._global_step)
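

# A rough sketch of how a training loop might drive Model (not part of the
# original file; `config`, `embeddings`, `tasks`, and the minibatch streams
# `labeled_mbs` / `unlabeled_mbs` are assumed to come from the surrounding
# codebase):
#
#   model = Model(config, embeddings, tasks)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for labeled_mb, unlabeled_mb in zip(labeled_mbs, unlabeled_mbs):
#       supervised_loss = model.train_labeled(sess, labeled_mb)
#       model.run_teacher(sess, unlabeled_mb)  # fills mb.teacher_predictions
#       consistency_loss = model.train_unlabeled(sess, unlabeled_mb)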