from collections import Counter

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import load_dataset
from matplotlib import pyplot as plt
from matplotlib_venn import venn2, venn3
from ngram import get_tuples_manual_sentences
from rich import print as rprint

from bigbio.dataloader import BigBioConfigHelpers

# from matplotlib_venn_wordcloud import venn2_wordcloud, venn3_wordcloud
# vanilla whitespace tokenizer (currently unused below; token counts come from
# get_tuples_manual_sentences instead, and the `counter` argument is never read)
def tokenizer(text, counter):
    if not text:
        return text, []
    text = text.strip()
    text = text.replace("\t", "")
    text = text.replace("\n", "")
    # split on single spaces
    text_list = text.split(" ")
    return text, text_list
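# Illustrative behavior of the naive split (note the empty string produced by
# a double space):
#   tokenizer("hello  world", Counter()) -> ("hello  world", ["hello", "", "world"])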
def norm(lengths):
    # mean and standard deviation of a list of token lengths
    mu = np.mean(lengths)
    sigma = np.std(lengths)
    return mu, sigma
def load_helper():
    # keep only remote, bigbio-schema configs; drop pubtator_central and
    # datasets that require local files
    conhelps = BigBioConfigHelpers()
    conhelps = conhelps.filtered(lambda x: x.dataset_name != "pubtator_central")
    conhelps = conhelps.filtered(lambda x: x.is_bigbio_schema)
    conhelps = conhelps.filtered(lambda x: not x.is_local)
    rprint(
        "loaded {} configs from {} datasets".format(
            len(conhelps),
            len(set(helper.dataset_name for helper in conhelps)),
        )
    )
    return conhelps
# which fields hold free text for each bigbio schema
_TEXT_MAPS = {
    "bigbio_kb": ["text"],
    "bigbio_text": ["text"],
    "bigbio_qa": ["question", "context"],
    "bigbio_te": ["premise", "hypothesis"],
    "bigbio_tp": ["text_1", "text_2"],
    "bigbio_pairs": ["text_1", "text_2"],
    "bigbio_t2t": ["text_1", "text_2"],
}
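# Illustrative: for a `bigbio_qa` entry, token_length_per_entry (below) measures
# each mapped field separately and reports one token count per field, e.g.
#   result = {"question": 9, "context": 154}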
# IBM color-blind-safe palette, padded with black and white
IBM_COLORS = [
    "#648fff",
    "#dc267f",
    "#ffb000",
    "#fe6100",
    "#785ef0",
    "#000000",
    "#ffffff",
]

N = 3  # n-gram order used throughout
def token_length_per_entry(entry, schema, counter):
    """Count tokens per text field of one entry; update `counter` with its n-grams."""
    result = {}
    if schema == "bigbio_kb":
        # kb entries carry their text in typed passages (e.g. title, abstract)
        for passage in entry["passages"]:
            result_key = passage["type"]
            for key in _TEXT_MAPS[schema]:
                text = passage[key][0]
                sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
                toks = [tok for sent in sents for tok in sent]
                tups = ["_".join(tup) for tup in ngrams]
                counter.update(tups)
                result[result_key] = len(toks)
    else:
        for key in _TEXT_MAPS[schema]:
            text = entry[key]
            sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
            toks = [tok for sent in sents for tok in sent]
            result[key] = len(toks)
            tups = ["_".join(tup) for tup in ngrams]
            counter.update(tups)
    return result, counter
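# Illustrative: a bigbio_kb entry with "title" and "abstract" passages might yield
#   result = {"title": 12, "abstract": 240}
# while the shared `counter` accumulates N-gram counts across entries.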
def parse_token_length_and_n_gram(dataset, data_config, st=None):
    # `st` here is a Streamlit container (e.g. st.sidebar), shadowing the module import
    hist_data = []
    n_gram_counters = []
    rprint(data_config)
    for split, data in dataset.items():
        my_bar = st.progress(0)
        total = len(data)
        n_gram_counter = Counter()
        for i, entry in enumerate(data):
            my_bar.progress(int(i / total * 100))
            result, n_gram_counter = token_length_per_entry(
                entry, data_config.schema, n_gram_counter
            )
            result["total_token_length"] = sum(result.values())
            result["split"] = split
            hist_data.append(result)
        # remove single count
        # n_gram_counter = Counter({x: count for x, count in n_gram_counter.items() if count > 1})
        n_gram_counters.append(n_gram_counter)
        my_bar.empty()
    st.write("token lengths complete!")
    return pd.DataFrame(hist_data), n_gram_counters
def center_title(fig):
    fig.update_layout(
        title={"y": 0.9, "x": 0.5, "xanchor": "center", "yanchor": "top"},
        font=dict(size=18),
    )
    return fig
def draw_histogram(hist_data, col_name, st=None):
    fig = px.histogram(
        hist_data,
        x=col_name,
        color="split",
        color_discrete_sequence=IBM_COLORS,
        marginal="box",  # or violin, rug
        barmode="group",
        hover_data=hist_data.columns,
        histnorm="probability",
        nbins=20,
        title=f"{col_name} distribution by split",
    )
    st.plotly_chart(center_title(fig), use_container_width=True)
def draw_bar(bar_data, x, y, st=None):
    fig = px.bar(
        bar_data,
        x=x,
        y=y,
        color="split",
        color_discrete_sequence=IBM_COLORS,
        # marginal="box",  # or violin, rug
        barmode="group",
        hover_data=bar_data.columns,
        title=f"{y} distribution by split",
    )
    st.plotly_chart(center_title(fig), use_container_width=True)
def parse_metrics(metadata, st=None):
    # show every positive integer metadata attribute as a sidebar metric
    for split, meta in metadata.items():
        for attr_name, attr in meta.__dict__.items():
            if type(attr) is int and attr > 0:
                st.metric(label=f"{split}-{attr_name}", value=attr)
def parse_counters(metadata):
    # use the training split's metadata to list the available counter names
    metadata = metadata["train"]
    counters = []
    for k, v in metadata.__dict__.items():
        if "counter" in k and len(v) > 0:
            counters.append(k)
    return counters
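# Illustrative: for a named-entity dataset this might return something like
#   ["entities_type_counter"], depending on which metadata counters are non-empty.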
# generate the df for the label-count histogram
def parse_label_counter(metadata, counter_type):
    hist_data = []
    for split, m in metadata.items():
        metadata_counter = getattr(m, counter_type)
        for k, v in metadata_counter.items():
            hist_data.append({"labels": k, counter_type: v, "split": split})
    return pd.DataFrame(hist_data)
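# Illustrative output for counter_type="entities_type_counter":
#       labels  entities_type_counter  split
#   0  Disease                    120  train
#   1     Gene                     87   test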
if __name__ == "__main__":
    # load helpers
    conhelps = load_helper()
    configs_set = set()
    for conhelper in conhelps:
        configs_set.add(conhelper.dataset_name)
    # st.write(sorted(configs_set))

    # set up page, sidebar, columns
    st.set_page_config(layout="wide")
    s = st.session_state
    if "pressed_first_button" not in s:
        s.pressed_first_button = False

    data_name = st.sidebar.selectbox("dataset", sorted(configs_set))
    st.sidebar.write("you selected:", data_name)
    st.header(f"Dataset stats for {data_name}")

    # set up data configs
    data_helpers = conhelps.for_dataset(data_name)
    data_configs = [d.config for d in data_helpers]
    data_config_names = [d.config.name for d in data_helpers]
    data_config_name = st.sidebar.selectbox("config", sorted(set(data_config_names)))
    if st.sidebar.button("fetch") or s.pressed_first_button:
        # remember the fetch so later widget interactions don't reset the app
        s.pressed_first_button = True
        helper = conhelps.for_config_name(data_config_name)
        metadata_helper = helper.get_metadata()
        parse_metrics(metadata_helper, st.sidebar)

        # load the HF dataset for the selected config
        data_idx = data_config_names.index(data_config_name)
        data_config = data_configs[data_idx]
        # st.write(data_name)
        dataset = load_dataset(f"bigbio/{data_name}", name=data_config_name)
        # preview assumes a "train" split exists
        ds = pd.DataFrame(dataset["train"])
        st.write(ds)
        # general token length
        tok_hist_data, ngram_counters = parse_token_length_and_n_gram(
            dataset, data_config, st.sidebar
        )
        # draw token distribution
        draw_histogram(tok_hist_data, "total_token_length", st)

        # general counter(s)
        col1, col2 = st.columns([1, 6])
        counters = parse_counters(metadata_helper)
        counter_type = col1.selectbox("counter_type", counters)
        label_df = parse_label_counter(metadata_helper, counter_type)
        # cap the slider one below the max count so the most frequent
        # label always survives the filter
        label_max = int(label_df[counter_type].max() - 1)
        label_min = int(label_df[counter_type].min())
        filter_value = col1.slider("counter_filter (min, max)", label_min, label_max)
        label_df = label_df[label_df[counter_type] >= filter_value]
        # draw bar chart for counter
        draw_bar(label_df, "labels", counter_type, col2)
        # n-gram overlap between splits (a single split has nothing to compare)
        venn_fig, ax = plt.subplots()
        ngram_counter_sets = [
            set(ngram_counter.keys()) for ngram_counter in ngram_counters
        ]
        if len(ngram_counters) == 2:
            union_counter = ngram_counters[0] + ngram_counters[1]
            print(ngram_counters[0].most_common(10))
            print(ngram_counters[1].most_common(10))
            total = len(union_counter.keys())
            venn2(
                ngram_counter_sets,
                dataset.keys(),
                set_colors=IBM_COLORS[:2],
                subset_label_formatter=lambda x: f"{(x / total):1.0%}",
            )
        elif len(ngram_counters) == 3:
            union_counter = ngram_counters[0] + ngram_counters[1] + ngram_counters[2]
            total = len(union_counter.keys())
            venn3(
                ngram_counter_sets,
                dataset.keys(),
                set_colors=IBM_COLORS[:3],
                subset_label_formatter=lambda x: f"{(x / total):1.0%}",
            )
        venn_fig.suptitle(f"{N}-gram intersection for {data_name}", fontsize=20)
        st.pyplot(venn_fig)

    st.sidebar.button("Re-run")
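# To run this dashboard locally (assuming this file is saved as app.py and that
# `bigbio`, `streamlit`, `plotly`, `matplotlib_venn`, and the local `ngram`
# helper module are installed):
#   streamlit run app.py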