import math
from pathlib import Path
from typing import Union

import pandas as pd
from plotnine import (
    labs,
    theme,
    theme_bw,
    guides,
    position_nudge,
    aes,
    geom_violin,
    geom_line,
    geom_jitter,
    scale_x_discrete,
    ggplot,
)


class GrewTSEVisualiser:
    """
    A basic visualisation class that creates a violin plot based on a
    syntactic evaluation.
    """

    def __init__(self) -> None:
        # Not read anywhere in this file; kept for interface compatibility.
        self.data = None

    def visualise_syntactic_performance(
        self,
        results: pd.DataFrame,
        title: str,
        target_x_label: str,
        alt_x_label: str,
        x_axis_label: str,
        y_axis_label: str,
        filename: str,
    ) -> None:
        """
        Visualise a syntactic performance evaluation result.

        Thin wrapper that forwards its arguments to :func:`visualise_slope`.

        :param results: pass the results DataFrame created by the
            GrewTSEEvaluator.
        :param title: Give the diagram a main title.
        :param target_x_label: Give the original target word and hence first
            word in the minimal pair a label e.g. 'Accusative'.
        :param alt_x_label: Give the second element in the minimal pair a
            label e.g. 'Dative'.
        :param x_axis_label: Give the X Axis a title.
        :param y_axis_label: Give the Y Axis a title.
        :param filename: A filename to save the visualisation.
        :return: None. The plot is written to ``filename``.
        """
        visualise_slope(
            filename,
            results,
            target_x_label,
            alt_x_label,
            x_axis_label,
            y_axis_label,
            title,
        )


def _surprisal(p: float) -> float:
    """Return ``-log2(p)``; ``math.log2`` raises ``ValueError`` for ``p <= 0``."""
    return -math.log2(p)


def _to_long_surprisal_frame(
    results: pd.DataFrame,
    target_x_label: str,
    alt_x_label: str,
) -> pd.DataFrame:
    """Reshape ``results`` to one row per (subject, condition) with surprisal."""
    ungrammatical_form = results["form_ungrammatical"]
    has_alternative = ungrammatical_form.notna() & (
        ungrammatical_form.str.strip() != ""
    )

    # Explicit copy: adding a column to a boolean-mask slice is the chained
    # assignment pattern that triggers SettingWithCopyWarning, and it must
    # never leak back into the caller's DataFrame.
    paired = results[has_alternative].copy()
    # The row index ties the two measurements of a pair together so a line
    # can be drawn between them.
    paired["subject_id"] = paired.index

    df_long = pd.melt(
        paired,
        id_vars=["subject_id"],
        value_vars=["p_grammatical", "p_ungrammatical"],
        var_name="source",
        # Despite its name this column is treated as a raw probability: it
        # is fed straight into -log2 below.
        value_name="log_prob",
    )

    # Map source to fixed x-axis labels
    df_long["x_label"] = df_long["source"].map(
        {"p_grammatical": target_x_label, "p_ungrammatical": alt_x_label}
    )
    df_long["surprisal"] = df_long["log_prob"].apply(_surprisal)
    return df_long


def _half_violin(
    df_long: pd.DataFrame,
    label: str,
    nudge_x: float,
    style: str,
    fill_alpha: float,
    line_size: float,
):
    """Build the half-violin layer for the rows whose ``x_label`` is ``label``."""
    return geom_violin(
        mapping=aes(x="x_label", y="surprisal", group="x_label"),
        data=df_long[df_long["x_label"] == label],
        position=position_nudge(x=nudge_x),
        style=style,
        alpha=fill_alpha,
        size=line_size,
    )


def visualise_slope(
    path: Union[str, Path],
    results: pd.DataFrame,
    target_x_label: str,
    alt_x_label: str,
    x_axis_label: str,
    y_axis_label: str,
    title: str,
) -> None:
    """
    Draw a paired slope/violin plot of surprisal and save it to ``path``.

    Rows of ``results`` whose ``form_ungrammatical`` is missing or blank are
    dropped. For the remaining rows, ``p_grammatical`` and
    ``p_ungrammatical`` are converted to surprisal (``-log2(p)``) and
    plotted side by side: jittered points, a gray line joining the two
    values of each row, and one half violin per condition.

    :param path: Destination file for the rendered plot.
    :param results: DataFrame with the columns ``form_ungrammatical``,
        ``p_grammatical`` and ``p_ungrammatical``.
    :param target_x_label: X-axis label for the ``p_grammatical`` values.
    :param alt_x_label: X-axis label for the ``p_ungrammatical`` values.
    :param x_axis_label: Title of the X axis.
    :param y_axis_label: Title of the Y axis.
    :param title: Main title of the plot.
    :return: None.
    :raises ValueError: if a retained probability is zero or negative
        (raised by ``math.log2``).
    """
    lsize = 0.65
    fill_alpha = 0.7

    # X-axis: the two conditions (e.g. Acc, Gen)
    # Y-axis: surprisal
    df_long = _to_long_surprisal_frame(results, target_x_label, alt_x_label)

    p = (
        ggplot(df_long, aes(x="x_label", y="surprisal", fill="x_label"))
        # Fix the left-to-right order of the two conditions.
        + scale_x_discrete(limits=[target_x_label, alt_x_label])
        + geom_jitter(aes(color="x_label"), width=0.01, alpha=0.7)
        + geom_line(aes(group="subject_id"), color="gray", alpha=0.7, size=0.2)
        # Each half violin is nudged away from the points, towards the
        # outside of the plot.
        + _half_violin(
            df_long, target_x_label, -0.2, "left-right", fill_alpha, lsize
        )
        + _half_violin(
            df_long, alt_x_label, 0.2, "right-left", fill_alpha, lsize
        )
        + guides(fill=False)
        + theme_bw()
        + theme(figure_size=(8, 4), legend_position="none")
        + labs(x=x_axis_label, y=y_axis_label, title=title)
    )
    # NOTE(review): width/height passed here presumably take precedence over
    # the theme's figure_size above — confirm against plotnine's save().
    p.save(path, width=14, height=8, dpi=300)