### LIBRARIES ###
# Data
import numpy as np
import pandas as pd
import torch
import json
from tqdm import tqdm
from math import floor
from datasets import load_dataset
from collections import defaultdict
from transformers import AutoTokenizer

pd.options.display.float_format = '${:,.2f}'.format

# Analysis
# from gensim.models.doc2vec import Doc2Vec
# from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import nltk
from nltk.cluster import KMeansClusterer
import scipy.spatial.distance as sdist
from scipy.spatial import distance_matrix
# nltk.download('punkt')  # make sure that punkt is downloaded

# App & Visualization
import streamlit as st
import altair as alt
import plotly.graph_objects as go
from streamlit_vega_lite import altair_component

# utils
from random import sample
from error_analysis import utils as ut

def down_samp(embedding):
    """Down-sample a data frame for Altair visualization."""
    # total number of positive and negative sentiments in the class
    # embedding = embedding.groupby('slice').apply(lambda x: x.sample(frac=0.3))
    total_size = embedding.groupby(['slice', 'label'], as_index=False).count()
    user_data = 0
    # if 'Your Sentences' in str(total_size['slice']):
    #     tmp = embedding.groupby(['slice'], as_index=False).count()
    #     val = int(tmp[tmp['slice'] == "Your Sentences"]['source'])
    #     user_data = val
    max_sample = total_size.groupby('slice').max()['content']
    # down-sample to stay under Altair's row limit,
    # but keep the proportional representation of groups
    sample_ratio = 1 / (sum(max_sample.astype(float)) / (1000 - user_data))
    max_samp = max_sample.apply(lambda x: floor(x * sample_ratio)).astype(int).to_dict()
    max_samp['Your Sentences'] = user_data
    # sample down each group in the data frame
    embedding = embedding.groupby('slice').apply(lambda x: x.sample(n=max_samp.get(x.name))).reset_index(drop=True)
    return embedding

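# NOTE: the ~1000-point budget above keeps the scatter plot responsive; Altair's default
# data transformer raises MaxRowsError for charts with more than 5000 rows, so sampling
# each slice proportionally keeps the chart both valid and legible.
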
def data_comparison(df):
    """Scatter plot of the 2-D embedding, colored by error cluster, with a clickable legend."""
    # selections take plain column names, not encoding shorthands like 'cluster:N'
    selection = alt.selection_multi(fields=['cluster', 'label'])
    color = alt.condition(alt.datum.slice == 'high-loss',
                          alt.Color('cluster:N', scale=alt.Scale(domain=df.cluster.unique().tolist())),
                          alt.value("lightgray"))
    opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))
    # basic chart
    scatter = alt.Chart(df).mark_point(size=100, filled=True).encode(
        x=alt.X('x:Q', axis=None),
        y=alt.Y('y:Q', axis=None),
        color=color,
        shape=alt.Shape('label:O', scale=alt.Scale(range=['circle', 'diamond'])),
        tooltip=['cluster:N', 'slice:N', 'content:N', 'label:O', 'pred:O'],
        opacity=opacity
    ).properties(
        width=1000,
        height=800
    ).interactive()
    legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
        x=alt.X("label:O"),
        y=alt.Y('cluster:N', axis=alt.Axis(orient='right'), title=""),
        shape=alt.Shape('label:O', scale=alt.Scale(
            range=['circle', 'diamond']), legend=None),
        color=color,
    ).add_selection(
        selection
    )
    layered = scatter | legend
    layered = layered.configure_axis(
        grid=False
    ).configure_view(
        strokeOpacity=0
    )
    return layered

def quant_panel(embedding_df):
    """Quantitative panel layout."""
    all_metrics = {}
    st.warning("**Error slice visualization**")
    with st.expander("How to read this chart:"):
        st.markdown("* Each **point** is an input example.")
        st.markdown("* Gray points have low loss and the colored ones have high loss. High-loss instances are clustered using **k-means**, and each color represents a cluster.")
        st.markdown("* The **shape** of each point reflects the label category -- positive (diamond) or negative sentiment (circle).")
    st.altair_chart(data_comparison(down_samp(embedding_df)), use_container_width=True)

def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
    """Rank tokens by how much more frequent they are in the high-loss slice than overall."""
    unique_tokens = []
    tokens = []
    for row in tqdm(data['content']):
        tokenized = tokenizer(row, padding=True, return_tensors='pt')
        tokens.append(tokenized['input_ids'].flatten())
        unique_tokens.append(torch.unique(tokenized['input_ids']))
    # weight each example by whether it falls in the high-loss slice, normalized to sum to 1
    losses = data['loss'].astype(float)
    high_loss = losses.quantile(loss_quantile)
    loss_weights = (losses > high_loss)
    loss_weights = loss_weights / loss_weights.sum()
    token_frequencies = defaultdict(float)
    token_frequencies_error = defaultdict(float)
    weights_uniform = np.full_like(loss_weights, 1 / len(loss_weights))
    num_examples = len(data)
    for i in tqdm(range(num_examples)):
        for token in unique_tokens[i]:
            token_frequencies[token.item()] += weights_uniform[i]
            token_frequencies_error[token.item()] += loss_weights[i]
    token_lrs = {k: (smoothing + token_frequencies_error[k]) / (smoothing + token_frequencies[k])
                 for k in token_frequencies}
    tokens_sorted = list(map(lambda x: x[0], sorted(token_lrs.items(), key=lambda x: x[1])[::-1]))
    top_tokens = []
    for i, token in enumerate(tokens_sorted[:top_k]):
        top_tokens.append(['%10s' % tokenizer.decode(token),
                           '%.4f' % token_frequencies[token],
                           '%.4f' % token_frequencies_error[token],
                           '%4.2f' % token_lrs[token]])
    return pd.DataFrame(top_tokens, columns=['Token', 'Freq', 'Freq error slice', 'lrs'])

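# In the table returned above, the 'lrs' column is the smoothed ratio
# (frequency in the error slice) / (overall frequency); values well above 1 flag tokens
# that are over-represented among high-loss examples.
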
def get_data(inference, emb):
    # NOTE: relies on a module-level `dataset` object (a loaded Hugging Face dataset);
    # the current app flow reads precomputed parquet files instead of calling this.
    preds = inference.outputs.numpy()
    losses = inference.losses.numpy()
    embeddings = pd.DataFrame(emb, columns=['x', 'y'])
    num_examples = len(losses)
    # dataset_labels = [dataset[i]['label'] for i in range(num_examples)]
    return pd.concat([pd.DataFrame(np.transpose(np.vstack([dataset[:num_examples]['content'],
                                                           dataset[:num_examples]['label'],
                                                           preds, losses])),
                                   columns=['content', 'label', 'pred', 'loss']),
                      embeddings], axis=1)

def clustering(data, num_clusters):
    """Cluster examples with NLTK k-means (cosine distance) and attach cluster id and centroid."""
    X = np.array(data['embedding'].tolist())
    kclusterer = KMeansClusterer(
        num_clusters, distance=nltk.cluster.util.cosine_distance,
        repeats=25, avoid_empty_clusters=True)
    assigned_clusters = kclusterer.cluster(X, assign_clusters=True)
    data['cluster'] = pd.Series(assigned_clusters, index=data.index).astype('int')
    data['centroid'] = data['cluster'].apply(lambda x: kclusterer.means()[x])
    return data, assigned_clusters

def kmeans(df, num_clusters=3):
    """Cluster only the high-loss slice, then merge cluster ids back into the full frame."""
    data_hl = df.loc[df['slice'] == 'high-loss']
    data_kmeans, clusters = clustering(data_hl, num_clusters)
    merged = pd.merge(df, data_kmeans, left_index=True, right_index=True, how='outer', suffixes=('', '_y'))
    merged.drop(merged.filter(regex='_y$').columns.tolist(), axis=1, inplace=True)
    # low-loss rows have no cluster assignment; give them one catch-all id (= num_clusters)
    merged['cluster'] = merged['cluster'].fillna(num_clusters).astype('int')
    return merged

def distance_from_centroid(row):
    # Euclidean distance between an example's embedding and its cluster centroid
    return sdist.euclidean(row['embedding'], row['centroid'])

def topic_distribution(weights, smoothing=0.01):
    # NOTE: relies on a module-level `dataset` whose examples expose a 'title' field.
    topic_frequencies = defaultdict(float)
    topic_frequencies_spotlight = defaultdict(float)
    weights_uniform = np.full_like(weights, 1 / len(weights))
    num_examples = len(weights)
    for i in range(num_examples):
        example = dataset[i]
        category = example['title']
        topic_frequencies[category] += weights_uniform[i]
        topic_frequencies_spotlight[category] += weights[i]
    topic_ratios = {c: (smoothing + topic_frequencies_spotlight[c]) / (
        smoothing + topic_frequencies[c]) for c in topic_frequencies}
    categories_sorted = map(lambda x: x[0], sorted(
        topic_ratios.items(), key=lambda x: x[1], reverse=True))
    topic_distr = []
    for category in categories_sorted:
        topic_distr.append(['%.3f' % topic_frequencies[category],
                            '%.3f' % topic_frequencies_spotlight[category],
                            '%.2f' % topic_ratios[category],
                            '%s' % category])
    return pd.DataFrame(topic_distr, columns=['Overall frequency', 'Error frequency', 'Ratio', 'Category'])
    # for category in categories_sorted:
    #     return (topic_frequencies[category], topic_frequencies_spotlight[category], topic_ratios[category], category)

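# Hypothetical usage: pass per-example weights that sum to 1 (for instance, a normalized
# high-loss indicator like the `loss_weights` built inside frequent_tokens) to see which
# categories are over-represented in the error slice.
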
def populate_session(dataset, model):
    data_df = read_file_to_df('./assets/data/' + dataset + '_' + model + '.parquet')
    if model == 'albert-base-v2-yelp-polarity':
        tokenizer = AutoTokenizer.from_pretrained('textattack/' + model)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model)
    if "user_data" not in st.session_state:
        st.session_state["user_data"] = data_df
    if "selected_slice" not in st.session_state:
        st.session_state["selected_slice"] = None


def read_file_to_df(file):
    return pd.read_parquet(file)

if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="Interactive Error Analysis")
    ut.init_style()

    lcol, rcol = st.columns([2, 2])
    # ******* loading the model and the data
    # st.sidebar.markdown("<h4>Interactive Error Analysis</h4>", unsafe_allow_html=True)
    dataset = st.sidebar.selectbox(
        "Dataset",
        ["amazon_polarity", "yelp_polarity"],
        index=1,
    )
    model = st.sidebar.selectbox(
        "Model",
        ["distilbert-base-uncased-finetuned-sst-2-english",
         "albert-base-v2-yelp-polarity"],
    )

    ### LOAD DATA AND SESSION VARIABLES ###
    # uncomment the next line to run dynamically instead of from file
    # populate_session(dataset, model)
    data_df = read_file_to_df('./assets/data/' + dataset + '_' + model + '.parquet')

    loss_quantile = st.sidebar.slider(
        "Loss Quantile", min_value=0.5, max_value=1.0, step=0.01, value=0.95
    )
    # label every example above the loss quantile as the high-loss (error) slice
    data_df['loss'] = data_df['loss'].astype(float)
    losses = data_df['loss']
    high_loss = losses.quantile(loss_quantile)
    data_df['slice'] = 'high-loss'
    data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')

    with rcol:
        with st.spinner(text='loading...'):
            st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
            # uncomment the next line to run dynamically instead of from file
            # commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
            commontokens = read_file_to_df('./assets/data/' + dataset + '_' + model + '_commontokens.parquet')
            with st.expander("How to read the table:"):
                st.markdown("* The table displays the most frequent tokens in the error slice, relative to their frequencies in the validation set.")
            st.write(commontokens)

    run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
    num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
    if run_kmeans == 'True':
        with st.spinner(text='running kmeans...'):
            merged = kmeans(data_df, num_clusters=num_clusters)

    with lcol:
        st.markdown('<h3>Error Slices</h3>', unsafe_allow_html=True)
        with st.expander("How to read the table:"):
            st.markdown("* *Error slice* refers to the subset of the evaluation dataset the model performs poorly on.")
            st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
            st.markdown("* Each row is an input example that includes the label, model prediction, loss, and error cluster.")
        with st.spinner(text='loading error slice...'):
            dataframe = read_file_to_df('./assets/data/' + dataset + '_' + model + '_error-slices.parquet')
            # uncomment the next lines to run dynamically instead of from file
            # dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
            #     by=['loss'], ascending=False)
            # table_html = dataframe.to_html(
            #     columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
            # table_html = table_html.replace("<th>", '<th align="left">')  # left-align the headers
            st.write(dataframe, width=900, height=300)

    # NOTE: quant_panel expects the clustered frame produced above, so the
    # "Cluster error slice?" toggle needs to stay on 'True' for this panel.
    with st.spinner(text='loading visualization...'):
        quant_panel(merged)
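
# To try the app locally (assuming this file is saved as app.py and the precomputed
# parquet assets exist under ./assets/data/), run:
#   streamlit run app.py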