Spaces:
Sleeping
Sleeping
mode viz fix
Browse files
app.py
CHANGED
|
@@ -805,9 +805,17 @@ with demo:
|
|
| 805 |
interactive=True,
|
| 806 |
visible=False
|
| 807 |
)
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 811 |
multiselect=True,
|
| 812 |
interactive=True
|
| 813 |
)
|
|
@@ -830,18 +838,83 @@ with demo:
|
|
| 830 |
plot_output = gr.Plot()
|
| 831 |
|
| 832 |
# Update visualization when any selector changes
|
| 833 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 834 |
control.change(
|
| 835 |
-
fn=lambda
|
| 836 |
-
inputs=[
|
| 837 |
outputs=plot_output
|
| 838 |
)
|
| 839 |
|
| 840 |
-
# Update
|
| 841 |
viz_version_selector.change(
|
| 842 |
-
fn=
|
| 843 |
inputs=[viz_version_selector],
|
| 844 |
-
outputs=[
|
| 845 |
)
|
| 846 |
|
| 847 |
# with gr.TabItem("About", elem_id="guardbench-about-tab", id=2):
|
|
|
|
| 805 |
interactive=True,
|
| 806 |
visible=False
|
| 807 |
)
|
| 808 |
+
# New: Mode selector
|
| 809 |
+
def get_model_mode_choices(version):
|
| 810 |
+
df = get_leaderboard_df(version=version)
|
| 811 |
+
if df.empty:
|
| 812 |
+
return []
|
| 813 |
+
# Return list of tuples (model_name, mode)
|
| 814 |
+
return sorted([f"{row['model_name']} [{row['mode']}]" for _, row in df.drop_duplicates(subset=["model_name", "mode"]).iterrows()])
|
| 815 |
+
|
| 816 |
+
model_mode_selector = gr.Dropdown(
|
| 817 |
+
choices=get_model_mode_choices(CURRENT_VERSION),
|
| 818 |
+
label="Select Model(s) [Mode] to Compare",
|
| 819 |
multiselect=True,
|
| 820 |
interactive=True
|
| 821 |
)
|
|
|
|
| 838 |
plot_output = gr.Plot()
|
| 839 |
|
| 840 |
# Update visualization when any selector changes
|
| 841 |
+
def update_visualization_with_mode(selected_model_modes, selected_category, selected_metric, version):
|
| 842 |
+
if not selected_model_modes:
|
| 843 |
+
return go.Figure()
|
| 844 |
+
df = get_leaderboard_df(version=version) if selected_category == "All Results" else get_category_leaderboard_df(selected_category, version=version)
|
| 845 |
+
if df.empty:
|
| 846 |
+
return go.Figure()
|
| 847 |
+
# Parse selected_model_modes into model_name and mode
|
| 848 |
+
selected_pairs = [s.rsplit(" [", 1) for s in selected_model_modes]
|
| 849 |
+
selected_pairs = [(name.strip(), mode.strip("] ")) for name, mode in selected_pairs]
|
| 850 |
+
mask = df.apply(lambda row: (row['model_name'], str(row['mode'])) in selected_pairs, axis=1)
|
| 851 |
+
filtered_df = df[mask]
|
| 852 |
+
metric_cols = [col for col in filtered_df.columns if selected_metric in col]
|
| 853 |
+
fig = go.Figure()
|
| 854 |
+
colors = ['#8FCCCC', '#C2A4B6', '#98B4A6', '#B68F7C']
|
| 855 |
+
for idx, (model_name, mode) in enumerate(selected_pairs):
|
| 856 |
+
model_data = filtered_df[(filtered_df['model_name'] == model_name) & (filtered_df['mode'] == mode)]
|
| 857 |
+
if not model_data.empty:
|
| 858 |
+
values = model_data[metric_cols].values[0].tolist()
|
| 859 |
+
values = values + [values[0]]
|
| 860 |
+
categories = [col.replace(f'_{selected_metric}', '') for col in metric_cols]
|
| 861 |
+
categories = categories + [categories[0]]
|
| 862 |
+
fig.add_trace(go.Scatterpolar(
|
| 863 |
+
r=values,
|
| 864 |
+
theta=categories,
|
| 865 |
+
name=f"{model_name} [{mode}]",
|
| 866 |
+
line_color=colors[idx % len(colors)],
|
| 867 |
+
fill='toself'
|
| 868 |
+
))
|
| 869 |
+
fig.update_layout(
|
| 870 |
+
paper_bgcolor='#000000',
|
| 871 |
+
plot_bgcolor='#000000',
|
| 872 |
+
font={'color': '#ffffff'},
|
| 873 |
+
title={
|
| 874 |
+
'text': f'{selected_category} - {selected_metric.upper()} Score Comparison',
|
| 875 |
+
'font': {'color': '#ffffff', 'size': 24}
|
| 876 |
+
},
|
| 877 |
+
polar=dict(
|
| 878 |
+
bgcolor='#000000',
|
| 879 |
+
radialaxis=dict(
|
| 880 |
+
visible=True,
|
| 881 |
+
range=[0, 1],
|
| 882 |
+
gridcolor='#333333',
|
| 883 |
+
linecolor='#333333',
|
| 884 |
+
tickfont={'color': '#ffffff'},
|
| 885 |
+
),
|
| 886 |
+
angularaxis=dict(
|
| 887 |
+
gridcolor='#333333',
|
| 888 |
+
linecolor='#333333',
|
| 889 |
+
tickfont={'color': '#ffffff'},
|
| 890 |
+
)
|
| 891 |
+
),
|
| 892 |
+
height=600,
|
| 893 |
+
showlegend=True,
|
| 894 |
+
legend=dict(
|
| 895 |
+
yanchor="top",
|
| 896 |
+
y=0.99,
|
| 897 |
+
xanchor="right",
|
| 898 |
+
x=0.99,
|
| 899 |
+
bgcolor='rgba(0,0,0,0.5)',
|
| 900 |
+
font={'color': '#ffffff'}
|
| 901 |
+
)
|
| 902 |
+
)
|
| 903 |
+
return fig
|
| 904 |
+
|
| 905 |
+
# Connect selectors to update function
|
| 906 |
+
for control in [viz_version_selector, model_mode_selector, category_selector, metric_selector]:
|
| 907 |
control.change(
|
| 908 |
+
fn=lambda smm, sc, s_metric, v: update_visualization_with_mode(smm, CATEGORY_REVERSE_MAP.get(sc, sc), s_metric, v),
|
| 909 |
+
inputs=[model_mode_selector, category_selector, metric_selector, viz_version_selector],
|
| 910 |
outputs=plot_output
|
| 911 |
)
|
| 912 |
|
| 913 |
+
# Update model_mode_selector choices when version changes
|
| 914 |
viz_version_selector.change(
|
| 915 |
+
fn=get_model_mode_choices,
|
| 916 |
inputs=[viz_version_selector],
|
| 917 |
+
outputs=[model_mode_selector]
|
| 918 |
)
|
| 919 |
|
| 920 |
# with gr.TabItem("About", elem_id="guardbench-about-tab", id=2):
|