feat: remove unused metrics
Former-commit-id: 00a582c7b2dc2f5d8c86bc8818bf8968d4903a70
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -218,7 +218,7 @@ class DataTrainingArguments:
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
     predict_with_generate: bool = field(
-        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+        default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
     )
     num_beams: Optional[int] = field(
         default=None,
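For context on the field changed above: in the transformers Flax examples this dataclass is wired to the command line through HfArgumentParser, so `predict_with_generate` becomes a `--predict_with_generate` flag. A minimal sketch of that pattern, assuming the script follows the standard example layout (the `num_beams` help text below is paraphrased, not quoted from the script):

# Sketch: how HfArgumentParser turns dataclass fields into CLI flags.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
    )
    num_beams: Optional[int] = field(
        default=None, metadata={"help": "Number of beams to use for evaluation."}
    )

parser = HfArgumentParser((DataTrainingArguments,))
# e.g. `python run_seq2seq_flax.py --predict_with_generate --num_beams 4`
(data_args,) = parser.parse_args_into_dataclasses()
print(data_args.predict_with_generate, data_args.num_beams)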
@@ -605,35 +605,6 @@ def main():
             desc="Running tokenizer on prediction dataset",
         )
 
-    # Metric
-    #metric = load_metric("rouge")
-
-    def postprocess_text(preds, labels):
-        preds = [pred.strip() for pred in preds]
-        labels = [label.strip() for label in labels]
-
-        # rougeLSum expects newline after each sentence
-        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
-        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
-
-        return preds, labels
-
-    def compute_metrics(preds, labels):
-        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-
-        # Some simple post-processing
-        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
-
-        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
-        # Extract a few results from ROUGE
-        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
-        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
-        result["gen_len"] = np.mean(prediction_lens)
-        result = {k: round(v, 4) for k, v in result.items()}
-        return result
-
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
     rng, dropout_rng = jax.random.split(rng)
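The block removed above may still be useful to anyone who wants to restore ROUGE scoring later. Here is a self-contained sketch of the same logic, with `tokenizer` passed explicitly instead of being captured from `main()`'s scope; it assumes `nltk` is installed and a `datasets` version that still ships `load_metric`:

# Standalone version of the ROUGE computation this commit removes.
import nltk
import numpy as np
from datasets import load_metric

nltk.download("punkt", quiet=True)  # sentence tokenizer used below
metric = load_metric("rouge")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]
    # rougeLSum expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
    return preds, labels

def compute_metrics(preds, labels, tokenizer):
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Keep only the mid f-measure from each aggregated ROUGE score
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    # Average generated length, counting non-padding tokens
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}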
@@ -819,15 +790,8 @@ def main():
             # log metrics
             wandb_log(eval_metrics, step=global_step, prefix='eval')
 
-            # compute ROUGE metrics
-            rouge_desc = ""
-            # if data_args.predict_with_generate:
-            #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
-            #     eval_metrics.update(rouge_metrics)
-            #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
-
             # Print metrics and update progress bar
-            desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
+            desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
             epochs.write(desc)
             epochs.desc = desc
 
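`wandb_log` is a helper defined elsewhere in this script. A plausible minimal version is sketched below, only to make the eval-logging lines above self-explanatory; the real implementation may differ, and it assumes `wandb.init()` has already been called:

import wandb

def wandb_log(metrics, step=None, prefix=None):
    # Namespace keys (e.g. 'eval/loss') so train/eval curves stay separate.
    log_metrics = {f"{prefix}/{k}" if prefix else k: v for k, v in metrics.items()}
    wandb.log(log_metrics, step=step)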
@@ -952,15 +916,8 @@ def main():
         pred_metrics = get_metrics(pred_metrics)
         pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
 
-        # compute ROUGE metrics
-        rouge_desc = ""
-        if data_args.predict_with_generate:
-            rouge_metrics = compute_metrics(pred_generations, pred_labels)
-            pred_metrics.update(rouge_metrics)
-            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
-
         # Print metrics
-        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
+        desc = f"Predict Loss: {pred_metrics['loss']})"
         logger.info(desc)
 
 
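The `get_metrics` / `jax.tree_map(jnp.mean, ...)` pair kept above averages per-batch metric dicts leaf by leaf; in the Flax examples, `get_metrics` un-replicates and stacks per-step metrics. A tiny illustration of that reduction, with a plain `jnp.stack` standing in for `get_metrics` (an assumption for the sake of the sketch):

import jax
import jax.numpy as jnp

# Two per-batch metric dicts, as produced by the prediction loop.
batch_metrics = [{"loss": jnp.array(2.0)}, {"loss": jnp.array(1.0)}]
# Stack matching leaves across batches, then reduce each leaf with mean.
stacked = jax.tree_map(lambda *xs: jnp.stack(xs), *batch_metrics)
mean_metrics = jax.tree_map(jnp.mean, stacked)
print(mean_metrics)  # {'loss': 1.5}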