Spaces:
Running
Running
update mmlu chart
Browse files- data_handler.py +36 -9
data_handler.py
CHANGED
|
@@ -82,13 +82,25 @@ def unified_exam_chart(unified_exam_df, plot_column):
|
|
| 82 |
|
| 83 |
def mmlu_chart(mmlu_df, plot_column):
|
| 84 |
df = mmlu_df.copy()
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
df['Average'] = df[subject_cols].mean(axis=1)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
x=x_col,
|
| 93 |
y='Model',
|
| 94 |
color=x_col,
|
|
@@ -96,15 +108,30 @@ def mmlu_chart(mmlu_df, plot_column):
|
|
| 96 |
labels={x_col: 'Accuracy', 'Model': 'Model'},
|
| 97 |
title=title,
|
| 98 |
orientation='h',
|
| 99 |
-
range_color=[0,1]
|
| 100 |
)
|
| 101 |
|
| 102 |
fig.update_layout(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
xaxis=dict(range=[0, x_range_max]),
|
| 104 |
title=dict(text=title, font=dict(size=16)),
|
| 105 |
xaxis_title=dict(font=dict(size=12)),
|
| 106 |
yaxis_title=dict(font=dict(size=12)),
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
)
|
|
|
|
|
|
|
|
|
|
| 110 |
return fig
|
|
|
|
|
|
| 82 |
|
| 83 |
def mmlu_chart(mmlu_df, plot_column):
|
| 84 |
df = mmlu_df.copy()
|
| 85 |
+
|
| 86 |
+
subject_cols = [
|
| 87 |
+
'Biology', 'Business', 'Chemistry', 'Computer Science', 'Economics',
|
| 88 |
+
'Engineering', 'Health', 'History', 'Law', 'Math', 'Other',
|
| 89 |
+
'Philosophy', 'Physics', 'Psychology'
|
| 90 |
+
]
|
| 91 |
df['Average'] = df[subject_cols].mean(axis=1)
|
| 92 |
+
|
| 93 |
+
df = df.sort_values(by=[plot_column, 'Model'],
|
| 94 |
+
ascending=[False, True]
|
| 95 |
+
).reset_index(drop=True)
|
| 96 |
+
|
| 97 |
+
x_col = plot_column
|
| 98 |
+
title = f'{plot_column}'
|
| 99 |
+
x_range_max = 1.0
|
| 100 |
+
bar_height_px = 28
|
| 101 |
+
|
| 102 |
+
fig = px.bar(
|
| 103 |
+
df,
|
| 104 |
x=x_col,
|
| 105 |
y='Model',
|
| 106 |
color=x_col,
|
|
|
|
| 108 |
labels={x_col: 'Accuracy', 'Model': 'Model'},
|
| 109 |
title=title,
|
| 110 |
orientation='h',
|
| 111 |
+
range_color=[0, 1]
|
| 112 |
)
|
| 113 |
|
| 114 |
fig.update_layout(
|
| 115 |
+
height=bar_height_px * len(df) + 120,
|
| 116 |
+
margin=dict(l=220, r=40, t=60, b=40),
|
| 117 |
+
width=1000,
|
| 118 |
+
|
| 119 |
xaxis=dict(range=[0, x_range_max]),
|
| 120 |
title=dict(text=title, font=dict(size=16)),
|
| 121 |
xaxis_title=dict(font=dict(size=12)),
|
| 122 |
yaxis_title=dict(font=dict(size=12)),
|
| 123 |
+
|
| 124 |
+
yaxis=dict(
|
| 125 |
+
automargin=True,
|
| 126 |
+
tickmode='array',
|
| 127 |
+
tickvals=df['Model'],
|
| 128 |
+
ticktext=df['Model'],
|
| 129 |
+
dtick=1,
|
| 130 |
+
autorange='reversed'
|
| 131 |
+
)
|
| 132 |
)
|
| 133 |
+
|
| 134 |
+
fig.update_yaxes(tickfont=dict(size=10))
|
| 135 |
+
|
| 136 |
return fig
|
| 137 |
+
|