from pathlib import Path

import streamlit as st
from datasets import load_metric

# from .nmt_bleu import compute_bleu  # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py

def read_markdown_file(markdown_file):
    return Path(markdown_file).read_text()
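
# Metric cards are markdown files expected to sit alongside this script
# (e.g. "rouge_metric_card.md"); they are rendered in the single-metric view below.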

metrics = ["rouge", "bleu"]

st.sidebar.markdown("Choose a functionality below:")

with st.sidebar.expander("Compare one or more metrics", expanded=True):
    metric_names = st.multiselect(
        "Choose metrics to explore:",
        metrics,
        default=["rouge"])

loaded_metrics = []
for name in metric_names:
    loaded_metrics.append(load_metric(name))
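
# A note on input formats (illustrative; `bleu_metric` / `rouge_metric` stand in
# for the metrics loaded above): the BLEU metric takes tokenized inputs, one
# token list per prediction plus a list of reference token lists per prediction,
# while the ROUGE metric takes plain strings:
#   bleu_metric.compute(predictions=[["hello", "world"]],
#                       references=[[["hello", "world"]]])
#   rouge_metric.compute(predictions=["hello world"],
#                        references=["hello world"])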

### Single-metric mode
if not metric_names:
    st.markdown("## Please choose one or more metrics.")
elif len(metric_names) == 1:
    metric = loaded_metrics[0]
    metric_name = metric_names[0]
    st.markdown("# You chose " + metric_name.upper())
    st.markdown("## You can test it out below:")
    reference = st.text_input(label='Input a reference sentence here:', value="hello world")
    prediction = st.text_input(label='Input a prediction sentence here:', value="goodnight moon")

    if metric_name == "bleu":
        # BLEU expects tokenized text: a token list per prediction and a list
        # of reference token lists per prediction.
        score = metric.compute(predictions=[prediction.split()],
                               references=[[reference.split()]])
        col1, col2, col3 = st.columns(3)
        col1.metric("BLEU", score['bleu'])
        col2.metric("Brevity penalty", score['brevity_penalty'])
        col3.metric("Length ratio", score['length_ratio'])
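        # If more detail is wanted, the datasets BLEU result also carries
        # per-order n-gram precisions under the 'precisions' key.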
    elif metric_name == "rouge":
        # ROUGE expects untokenized strings, one per prediction/reference.
        score = metric.compute(predictions=[prediction], references=[reference])
        col1, col2, col3 = st.columns(3)
        col1.metric("ROUGE-1 Precision", score['rouge1'].mid.precision)
        col2.metric("ROUGE-1 Recall", score['rouge1'].mid.recall)
        col3.metric("ROUGE-1 FMeasure", score['rouge1'].mid.fmeasure)
        col4, col5, col6 = st.columns(3)
        col4.metric("ROUGE-2 Precision", score['rouge2'].mid.precision)
        col5.metric("ROUGE-2 Recall", score['rouge2'].mid.recall)
        col6.metric("ROUGE-2 FMeasure", score['rouge2'].mid.fmeasure)
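        # Each ROUGE entry is an AggregateScore with low/mid/high bootstrap
        # estimates; `.mid` is the point estimate shown here.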

    st.markdown("---")
    st.markdown(read_markdown_file(metric_name + "_metric_card.md"))

# Multiple-metric mode: with two available metrics, reaching this branch
# means both were selected.
else:
    metric1 = metric_names[0]
    metric2 = metric_names[1]
    st.markdown("# You chose " + metric1.upper() + " and " + metric2.upper())
    st.markdown("## You can test it out below:")
    reference = st.text_input(label='Input a reference sentence here:', value="hello world")
    prediction = st.text_input(label='Input a prediction sentence here:', value="goodnight moon")
| if "bleu" in metric_names: | |
| bleu_ix = metric_names.index("bleu") | |
| bleu_score = loaded_metrics[bleu_ix].compute(predictions=predictions, references=[references]) | |
| col1, col2, col3 = st.columns(3) | |
| col1.metric("BLEU", bleu_score['bleu']) | |
| col2.metric("Brevity penalty", bleu_score['brevity_penalty']) | |
| col3.metric('Length Ratio', bleu_score['length_ratio']) | |
| if "rouge" in metric_names: | |
| rouge_ix = metric_names.index("rouge") | |
| rouge_score = loaded_metrics[rouge_ix].compute(predictions=predictions, references=references) | |
| #print(score) | |
| col1, col2, col3 = st.columns(3) | |
| col1.metric("Rouge 1 Precision", rouge_score['rouge1'].mid.precision) | |
| col2.metric("Rouge 1 Recall", rouge_score['rouge1'].mid.recall) | |
| col3.metric("Rouge 1 FMeasure", rouge_score['rouge1'].mid.fmeasure) | |
| col4, col5, col6 = st.columns(3) | |
| col4.metric("Rouge 2 Precision", rouge_score['rouge2'].mid.precision) | |
| col5.metric("Rouge 2 Recall", rouge_score['rouge2'].mid.recall) | |
| col6.metric("Rouge 2 FMeasure", rouge_score['rouge2'].mid.fmeasure) | |
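
# To try the app locally (filename illustrative; assumes streamlit and
# datasets are installed):
#   streamlit run app.py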