import pandas as pd
import plotly.express as px
from config.constants import (
    CC_BENCHMARKS,
    DISCARDED_MODELS,
    LC_BENCHMARKS,
    NON_RTL_METRICS,
    RTL_METRICS,
    S2R_BENCHMARKS,
    SCATTER_PLOT_X_TICKS,
    TYPE_COLORS,
    Y_AXIS_LIMITS,
)
from utils import filter_bench, filter_bench_all, filter_RTLRepo, handle_special_cases


# Simple state holder that returns the data for whichever simulator is currently selected.
class Simulator:
    def __init__(self, icarus_df, icarus_agg, verilator_df, verilator_agg):
        self.icarus_df = icarus_df
        self.icarus_agg = icarus_agg
        self.verilator_df = verilator_df
        self.verilator_agg = verilator_agg
        self.current_simulator = "Icarus"

    def get_current_df(self):
        if self.current_simulator == "Icarus":
            return self.icarus_df
        else:
            return self.verilator_df

    def get_current_agg(self):
        if self.current_simulator == "Icarus":
            return self.icarus_agg
        else:
            return self.verilator_agg

    def set_simulator(self, simulator):
        self.current_simulator = simulator
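
# Example usage (a hypothetical sketch; the Space builds `state` once from its own
# pre-loaded per-simulator dataframes, and the callbacks below read from it):
#
#     state = Simulator(icarus_df, icarus_agg, verilator_df, verilator_agg)
#     state.set_simulator("Verilator")
#     df = state.get_current_df()      # Verilator per-benchmark results
#     agg = state.get_current_agg()    # Verilator aggregate scores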


# Main filtering function for the leaderboard body.
def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state, name):
    """Filter leaderboard data based on user selections."""
    subset = state.get_current_df().copy()

    # Restrict to the task-specific benchmarks when 'All' benchmarks is selected.
    if task == "Spec-to-RTL":
        valid_benchmarks = S2R_BENCHMARKS
        if benchmark == "All":
            subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
    elif task == "Code Completion":
        valid_benchmarks = CC_BENCHMARKS
        if benchmark == "All":
            subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
    elif task == "Line Completion †":
        valid_benchmarks = LC_BENCHMARKS
        if benchmark == "All":
            subset = subset[subset["Benchmark"].isin(valid_benchmarks)]

    # A specific benchmark overrides the task-level selection.
    if benchmark != "All":
        subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark].copy()

    if model_type != "All":
        # Drop the emoji suffix from the dropdown label before comparing.
        subset = subset[subset["Model Type"] == model_type.split(" ")[0]]

    if search_query:
        subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)]

    max_params = float(max_params)
    if max_params < 995:  # after a reset the slider never reports exactly 1000 again, so >= 995 means "no limit"
        subset = subset[subset["Params"] <= max_params]
    else:
        subset["Params"] = subset["Params"].fillna("Unknown")

    if name == "Other Models":
        subset = subset[subset["Model"].isin(DISCARDED_MODELS)]
    else:
        subset = subset[~subset["Model"].isin(DISCARDED_MODELS)]

    if benchmark == "All":
        if task == "Spec-to-RTL":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R", name=name)
        elif task == "Code Completion":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC", name=name)
        elif task == "Line Completion †":
            return filter_RTLRepo(subset, name=name)
    elif benchmark == "RTL-Repo":
        return filter_RTLRepo(subset, name=name)
    else:
        agg_column = None
        if benchmark == "VerilogEval S2R":
            agg_column = "Agg VerilogEval S2R"
        elif benchmark == "VerilogEval MC":
            agg_column = "Agg VerilogEval MC"
        elif benchmark == "RTLLM":
            agg_column = "Agg RTLLM"
        elif benchmark == "VeriGen":
            agg_column = "Agg VeriGen"
        return filter_bench(subset, state.get_current_agg(), agg_column, name=name)
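
# Example call (hypothetical argument values; `state` is the Simulator instance above):
#
#     table = filter_leaderboard(
#         task="Spec-to-RTL", benchmark="All", model_type="All",
#         search_query="", max_params=1000, state=state, name="Leaderboard",
#     )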


def generate_scatter_plot(benchmark, metric, state):
    """Generate a scatter plot for the given benchmark and metric."""
    benchmark, metric = handle_special_cases(benchmark, metric)

    subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark]
    subset = subset[~subset["Model"].isin(DISCARDED_MODELS)]

    if benchmark == "RTL-Repo":
        # RTL-Repo only reports exact-match scores, so average them per model.
        subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)]
        detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean()
        detailed_scores.rename(columns={"Score": "Exact Matching (EM)"}, inplace=True)
    else:
        # One row per model, one column per metric.
        detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index()

    details = state.get_current_df()[["Model", "Params", "Model Type"]].drop_duplicates("Model")
    scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna(
        subset=["Params", metric]
    )

    scatter_data["x"] = scatter_data["Params"]
    scatter_data["y"] = scatter_data[metric]
    scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40
    scatter_data["color"] = scatter_data["Model Type"].map(TYPE_COLORS).fillna("gray")
    y_range = Y_AXIS_LIMITS.get(metric, [0, 80])

    fig = px.scatter(
        scatter_data,
        x="x",
        y="y",
        log_x=True,
        size="size",
        color="Model Type",
        text="Model",
        hover_data={metric: ":.2f"},
        title=f"Params vs. {metric} for {benchmark}",
        labels={"x": "# Params (Log Scale)", "y": metric},
        template="plotly_white",
        height=600,
        width=1200,
    )
    fig.update_traces(
        textposition="top center",
        textfont_size=10,
        marker=dict(opacity=0.8, line=dict(width=0.5, color="black")),
    )
    fig.update_layout(
        xaxis=dict(
            showgrid=True,
            type="log",
            tickmode="array",
            tickvals=SCATTER_PLOT_X_TICKS["tickvals"],
            ticktext=SCATTER_PLOT_X_TICKS["ticktext"],
        ),
        showlegend=False,
        yaxis=dict(range=y_range),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
    )
    return fig
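
# Example call (hypothetical; "RTL-Repo" / "Exact Matching (EM)" are names used above,
# and any benchmark/metric pair accepted by handle_special_cases should work):
#
#     fig = generate_scatter_plot("RTL-Repo", "Exact Matching (EM)", state)
#     fig.show()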