"""XLNet SQUAD finetuning runner in tf2.0.""" |
|
|
|
|
|
from __future__ import absolute_import |
|
|
from __future__ import division |
|
|
|
|
|
from __future__ import print_function |
|
|
|
|
|
import functools |
|
|
import json |
|
|
import os |
|
|
import pickle |
|
|
|
|
|
from absl import app |
|
|
from absl import flags |
|
|
from absl import logging |
|
|
|
|
|
import tensorflow as tf |
|
|
|
|
|
import sentencepiece as spm |
|
|
from official.nlp.xlnet import common_flags |
|
|
from official.nlp.xlnet import data_utils |
|
|
from official.nlp.xlnet import optimization |
|
|
from official.nlp.xlnet import squad_utils |
|
|
from official.nlp.xlnet import training_utils |
|
|
from official.nlp.xlnet import xlnet_config |
|
|
from official.nlp.xlnet import xlnet_modeling as modeling |
|
|
from official.utils.misc import tpu_lib |
|
|
|
|
|

flags.DEFINE_string(
    "test_feature_path", default=None,
    help="Path to the pickled features of the test set.")
flags.DEFINE_integer("query_len", default=64, help="Max query length.")
flags.DEFINE_integer("start_n_top", default=5, help="Beam size for span start.")
flags.DEFINE_integer("end_n_top", default=5, help="Beam size for span end.")
flags.DEFINE_string(
    "predict_dir", default=None, help="Path to write predictions.")
flags.DEFINE_string(
    "predict_file", default=None, help="Path to the JSON file of the test set.")
flags.DEFINE_integer(
    "n_best_size", default=5, help="n-best size for predictions.")
flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")

# The flags below are only used when eval features are built on the fly,
# i.e. when --test_feature_path is not provided.
flags.DEFINE_string(
    "spiece_model_file", default=None, help="SentencePiece model path.")
flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.")
flags.DEFINE_integer("max_query_length", default=64, help="Max query length.")
flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.")

FLAGS = flags.FLAGS
|
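# Example invocation (illustrative paths and values; flags such as
# strategy_type, train_batch_size, seq_len, train_steps and model_dir are
# registered by the shared XLNet flag modules, see common_flags, rather than
# defined in this file):
#
#   python run_squad.py \
#     --strategy_type=mirror \
#     --init_checkpoint=/path/to/xlnet_checkpoint \
#     --model_dir=/tmp/squad_model \
#     --train_tfrecord_path=/path/to/train.tfrecord \
#     --test_tfrecord_path=/path/to/eval.tfrecord \
#     --test_feature_path=/path/to/eval_features.pickle \
#     --predict_file=/path/to/squad_dev.json \
#     --predict_dir=/tmp/squad_predictions \
#     --train_batch_size=16 --test_batch_size=16 \
#     --seq_len=512 --query_len=64 --train_steps=8000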

class InputFeatures(object):
  """A single set of features for one doc span of a SQuAD example."""

  def __init__(self,
               unique_id,
               example_index,
               doc_span_index,
               tok_start_to_orig_index,
               tok_end_to_orig_index,
               token_is_max_context,
               input_ids,
               input_mask,
               p_mask,
               segment_ids,
               paragraph_len,
               cls_index,
               start_position=None,
               end_position=None,
               is_impossible=None):
    self.unique_id = unique_id
    self.example_index = example_index
    self.doc_span_index = doc_span_index
    self.tok_start_to_orig_index = tok_start_to_orig_index
    self.tok_end_to_orig_index = tok_end_to_orig_index
    self.token_is_max_context = token_is_max_context
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.p_mask = p_mask
    self.segment_ids = segment_ids
    self.paragraph_len = paragraph_len
    self.cls_index = cls_index
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible
|
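# Evaluation flow: run_evaluation streams the test set through the model under
# the given distribution strategy, gathers the per-replica outputs into
# squad_utils.RawResult records, and lets squad_utils.write_predictions turn
# them into n-best answer spans plus the best_f1 / best_exact metrics.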
def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
                   original_data, eval_steps, input_meta_data, model,
                   current_step, eval_summary_writer):
  """Runs evaluation for the SQuAD task.

  Args:
    strategy: distribution strategy.
    test_input_fn: input function for evaluation data.
    eval_examples: examples of the evaluation set, as returned by
      squad_utils.read_squad_examples.
    eval_features: feature objects of the evaluation set.
    original_data: the original json data for the evaluation set.
    eval_steps: total number of evaluation steps.
    input_meta_data: dict of evaluation settings (prediction paths, beam
      sizes, test batch size, etc.).
    model: keras model object.
    current_step: current training step.
    eval_summary_writer: summary writer used to record evaluation metrics.

  Returns:
    The best F1 score (a float).
  """

  def _test_step_fn(inputs):
    """Replicated validation step."""
    # Evaluation runs without cached Transformer-XL memories.
    inputs["mems"] = None
    res = model(inputs, training=False)
    return res, inputs["unique_ids"]

  @tf.function
  def _run_evaluation(test_iterator):
    """Runs validation steps."""
    res, unique_ids = strategy.run(
        _test_step_fn, args=(next(test_iterator),))
    return res, unique_ids

  test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
  cur_results = []
  for _ in range(eval_steps):
    results, unique_ids = _run_evaluation(test_iterator)
    unique_ids = strategy.experimental_local_results(unique_ids)

    # Convert each model output to a tuple of per-replica tensors.
    for result_key in results:
      results[result_key] = (
          strategy.experimental_local_results(results[result_key]))
    # Per-replica evaluation batch size.
    bsz = int(input_meta_data["test_batch_size"] /
              strategy.num_replicas_in_sync)
    # Flatten the per-replica, per-example outputs into RawResult records.
    for core_i in range(strategy.num_replicas_in_sync):
      for j in range(bsz):
        result = {}
        for result_key in results:
          result[result_key] = results[result_key][core_i].numpy()[j]
        result["unique_ids"] = unique_ids[core_i].numpy()[j]

        # This unique id is a sentinel; skip it.
        if result["unique_ids"] == 1000012047:
          continue
        unique_id = int(result["unique_ids"])

        start_top_log_probs = [
            float(x) for x in result["start_top_log_probs"].flat
        ]
        start_top_index = [int(x) for x in result["start_top_index"].flat]
        end_top_log_probs = [
            float(x) for x in result["end_top_log_probs"].flat
        ]
        end_top_index = [int(x) for x in result["end_top_index"].flat]

        cls_logits = float(result["cls_logits"].flat[0])
        cur_results.append(
            squad_utils.RawResult(
                unique_id=unique_id,
                start_top_log_probs=start_top_log_probs,
                start_top_index=start_top_index,
                end_top_log_probs=end_top_log_probs,
                end_top_index=end_top_index,
                cls_logits=cls_logits))
        if len(cur_results) % 1000 == 0:
          logging.info("Processing example: %d", len(cur_results))

  output_prediction_file = os.path.join(input_meta_data["predict_dir"],
                                        "predictions.json")
  output_nbest_file = os.path.join(input_meta_data["predict_dir"],
                                   "nbest_predictions.json")
  output_null_log_odds_file = os.path.join(input_meta_data["predict_dir"],
                                           "null_odds.json")

  # Convert the raw results into final predictions and compute the SQuAD
  # metrics against the original json data.
  results = squad_utils.write_predictions(
      eval_examples, eval_features, cur_results, input_meta_data["n_best_size"],
      input_meta_data["max_answer_length"], output_prediction_file,
      output_nbest_file, output_null_log_odds_file, original_data,
      input_meta_data["start_n_top"], input_meta_data["end_n_top"])

  # Log all returned metrics and record the headline ones in TensorBoard.
  log_str = "Result | "
  for key, val in results.items():
    log_str += "{} {} | ".format(key, val)
  logging.info(log_str)
  with eval_summary_writer.as_default():
    tf.summary.scalar("best_f1", results["best_f1"], step=current_step)
    tf.summary.scalar("best_exact", results["best_exact"], step=current_step)
    eval_summary_writer.flush()
  return results["best_f1"]

def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
  """Builds the XLNet QA (span prediction) keras model."""
  model = modeling.QAXLNetModel(
      model_config,
      run_config,
      start_n_top=start_n_top,
      end_n_top=end_n_top,
      name="model")
  return model

def main(unused_argv):
  del unused_argv
  if FLAGS.strategy_type == "mirror":
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == "tpu":
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  else:
    raise ValueError("The distribution strategy type is not supported: %s" %
                     FLAGS.strategy_type)
  if strategy:
    logging.info("***** Number of cores used: %d",
                 strategy.num_replicas_in_sync)

  # Input pipelines for training and evaluation.
  train_input_fn = functools.partial(data_utils.get_squad_input_data,
                                     FLAGS.train_batch_size, FLAGS.seq_len,
                                     FLAGS.query_len, strategy, True,
                                     FLAGS.train_tfrecord_path)

  test_input_fn = functools.partial(data_utils.get_squad_input_data,
                                    FLAGS.test_batch_size, FLAGS.seq_len,
                                    FLAGS.query_len, strategy, False,
                                    FLAGS.test_tfrecord_path)

  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations
  eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      total_training_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)
  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)

  # Metadata consumed by the training loop and by run_evaluation.
  input_meta_data = {}
  input_meta_data["start_n_top"] = FLAGS.start_n_top
  input_meta_data["end_n_top"] = FLAGS.end_n_top
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  input_meta_data["predict_dir"] = FLAGS.predict_dir
  input_meta_data["n_best_size"] = FLAGS.n_best_size
  input_meta_data["max_answer_length"] = FLAGS.max_answer_length
  input_meta_data["test_batch_size"] = FLAGS.test_batch_size
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["mem_len"] = FLAGS.mem_len
  model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                               FLAGS.start_n_top, FLAGS.end_n_top)

  # Load pre-computed eval features from a pickle file if provided; otherwise
  # build them from the raw predict file with SentencePiece tokenization.
  eval_examples = squad_utils.read_squad_examples(
      FLAGS.predict_file, is_training=False)
  if FLAGS.test_feature_path:
    logging.info("start reading pickle file...")
    with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
      eval_features = pickle.load(f)
    logging.info("finished reading pickle file...")
  else:
    sp_model = spm.SentencePieceProcessor()
    sp_model.LoadFromSerializedProto(
        tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
    spm_basename = os.path.basename(FLAGS.spiece_model_file)
    eval_features = squad_utils.create_eval_data(
        spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
        FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)

  # The original SQuAD json is needed for final answer and metric computation.
  with tf.io.gfile.GFile(FLAGS.predict_file) as f:
    original_data = json.load(f)["data"]
  # The remaining run_evaluation arguments (model, current_step,
  # eval_summary_writer) are expected to be supplied by training_utils.train.
  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              eval_examples, eval_features, original_data,
                              eval_steps, input_meta_data)

  training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=eval_fn,
      metric_fn=None,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)

if __name__ == "__main__":
  app.run(main)