|
|
import altair as alt |
|
|
import fev |
|
|
import pandas as pd |
|
|
import pandas.io.formats.style |
|
|
|
|
|
|
|
|
|
|
|
# Central palette shared by the leaderboard table, the charts and the pivot tables.
COLORS = dict(
    # Text color of the model name in the leaderboard, chosen by model type.
    dl_text="#5A7FA5",  # used when the MODEL_CONFIG type is "DL"
    st_text="#A5795A",  # used for every other configured type
    # Bar chart layers.
    bar_fill="#8d5eb7",
    error_bar="#222222",
    point="#111111",
    # Text drawn on top of heatmap cells / inside table cells.
    text_white="white",
    text_black="black",
    text_default="#111",
    # Background colors for the cells ranked 1st, 2nd and 3rd within a task row.
    gold="#F7D36B",
    silver="#E5E7EB",
    bronze="#E6B089",
    # Text colors marking pivot-table cells whose value was imputed.
    leakage_impute="#3B82A0",
    failure_impute="#E07B39",
)

# Named color scheme handed to the color scale of the pairwise heatmap.
HEATMAP_COLOR_SCHEME = "purplegreen"
|
|
|
|
|
|
|
|
# Registry of known models, keyed by lower-case model name.
# Every value is a 4-tuple:
#   (Hugging Face repo id or full URL, organization, zero-shot flag, model type)
# where the model type is "DL" or "ST".
MODEL_CONFIG = {
    # Entries flagged zero-shot with type "DL"; rows are (key, repo id, organization).
    **{
        name: (repo_id, organization, True, "DL")
        for name, repo_id, organization in (
            # Chronos
            ("chronos_tiny", "amazon/chronos-t5-tiny", "AWS"),
            ("chronos_mini", "amazon/chronos-t5-mini", "AWS"),
            ("chronos_small", "amazon/chronos-t5-small", "AWS"),
            ("chronos_base", "amazon/chronos-t5-base", "AWS"),
            ("chronos_large", "amazon/chronos-t5-large", "AWS"),
            ("chronos_bolt_tiny", "amazon/chronos-bolt-tiny", "AWS"),
            ("chronos_bolt_mini", "amazon/chronos-bolt-mini", "AWS"),
            ("chronos_bolt_small", "amazon/chronos-bolt-small", "AWS"),
            ("chronos_bolt_base", "amazon/chronos-bolt-base", "AWS"),
            ("chronos-bolt", "amazon/chronos-bolt-base", "AWS"),
            # Moirai
            ("moirai_large", "Salesforce/moirai-1.1-R-large", "Salesforce"),
            ("moirai_base", "Salesforce/moirai-1.1-R-base", "Salesforce"),
            ("moirai_small", "Salesforce/moirai-1.1-R-small", "Salesforce"),
            ("moirai-2.0", "Salesforce/moirai-2.0-R-small", "Salesforce"),
            # TimesFM
            ("timesfm", "google/timesfm-1.0-200m-pytorch", "Google"),
            ("timesfm-2.0", "google/timesfm-2.0-500m-pytorch", "Google"),
            ("timesfm-2.5", "google/timesfm-2.5-200m-pytorch", "Google"),
            # Toto
            ("toto-1.0", "Datadog/Toto-Open-Base-1.0", "Datadog"),
            # Other models
            ("tirex", "NX-AI/TiRex", "NX-AI"),
            ("tabpfn-ts", "Prior-Labs/TabPFN-v2-reg", "Prior Labs"),
            ("sundial-base", "thuml/sundial-base-128m", "Tsinghua University"),
            ("ttm-r2", "ibm-granite/granite-timeseries-ttm-r2", "IBM"),
        )
    },
    # Entries of type "ST": all share the same link, no organization, not zero-shot.
    **dict.fromkeys(
        (
            "stat. ensemble",
            "autoarima",
            "autotheta",
            "autoets",
            "seasonalnaive",
            "seasonal naive",
            "drift",
            "naive",
        ),
        ("https://nixtlaverse.nixtla.io/statsforecast/", "—", False, "ST"),
    ),
}
|
|
|
|
|
|
|
|
# Supported evaluation metrics, keyed by metric identifier.
# Every value is a pair: (display title, markdown description with a link to
# the AutoGluon forecasting-metrics documentation).
ALL_METRICS = {
    "SQL": (
        "SQL: Scaled Quantile Loss",
        "The [Scaled Quantile Loss (SQL)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.SQL) is a **scale-invariant** metric for evaluating **probabilistic** forecasts.",
    ),
    "MASE": (
        "MASE: Mean Absolute Scaled Error",
        "The [Mean Absolute Scaled Error (MASE)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.MASE) is a **scale-invariant** metric for evaluating **point** forecasts.",
    ),
    "WQL": (
        "WQL: Weighted Quantile Loss",
        # No comma after the link, so the sentence reads like the other three.
        "The [Weighted Quantile Loss (WQL)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.WQL) is a **scale-dependent** metric for evaluating **probabilistic** forecasts.",
    ),
    "WAPE": (
        "WAPE: Weighted Absolute Percentage Error",
        "The [Weighted Absolute Percentage Error (WAPE)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.WAPE) is a **scale-dependent** metric for evaluating **point** forecasts.",
    ),
}
|
|
|
|
|
|
|
|
def format_metric_name(metric_name: str):
    """Return the display title stored for ``metric_name`` in ``ALL_METRICS``.

    Raises ``KeyError`` if the metric is not registered.
    """
    entry = ALL_METRICS[metric_name]
    # The title is the first element of the (title, description) pair.
    return entry[0]
|
|
|
|
|
|
|
|
def get_metric_description(metric_name: str):
    """Return the markdown description stored for ``metric_name`` in ``ALL_METRICS``.

    Raises ``KeyError`` if the metric is not registered.
    """
    entry = ALL_METRICS[metric_name]
    # The description is the second element of the (title, description) pair.
    return entry[1]
|
|
|
|
|
|
|
|
def get_model_link(model_name):
    """Return the URL to link for ``model_name``.

    The name is lower-cased and looked up in ``MODEL_CONFIG``. If the model is
    unknown, or its link field is empty, an empty string is returned. A link
    field that is already an absolute ``http``/``https`` URL is returned as is;
    anything else is treated as a Hugging Face repo id and prefixed with
    ``https://huggingface.co/``.
    """
    config = MODEL_CONFIG.get(model_name.lower())
    if not config or not config[0]:
        return ""

    url = config[0]
    # Accept both schemes: a plain "http:" link must not be mistaken for a
    # repo id and glued onto the Hugging Face host.
    if url.startswith(("http:", "https:")):
        return url
    return f"https://huggingface.co/{url}"
|
|
|
|
|
|
|
|
def get_model_organization(model_name):
    """Return the organization recorded for ``model_name`` in ``MODEL_CONFIG``.

    The lookup uses the lower-cased name; unknown models yield "—".
    """
    entry = MODEL_CONFIG.get(model_name.lower())
    if not entry:
        return "—"
    return entry[1]
|
|
|
|
|
|
|
|
def get_zero_shot_status(model_name):
    """Return "✓" if ``MODEL_CONFIG`` flags ``model_name`` as zero-shot, else "×".

    The lookup uses the lower-cased name; unknown models also yield "×".
    """
    entry = MODEL_CONFIG.get(model_name.lower())
    if not entry:
        return "×"
    if entry[2]:
        return "✓"
    return "×"
|
|
|
|
|
|
|
|
def get_model_type(model_name):
    """Return the type code recorded for ``model_name`` in ``MODEL_CONFIG``.

    The lookup uses the lower-cased name; unknown models yield "—".
    """
    entry = MODEL_CONFIG.get(model_name.lower())
    if not entry:
        return "—"
    return entry[3]
|
|
|
|
|
|
|
|
def highlight_model_type_color(cell):
    """Return the CSS used to render a model-name cell.

    The text is always bold. When the lower-cased cell value is a key of
    ``MODEL_CONFIG``, a text color is added: ``COLORS["dl_text"]`` for type
    "DL", ``COLORS["st_text"]`` for any other type.
    """
    entry = MODEL_CONFIG.get(cell.lower())
    if not entry:
        # Unknown model: bold only, no color.
        return "font-weight: bold"

    if entry[3] == "DL":
        color = COLORS["dl_text"]
    else:
        color = COLORS["st_text"]
    return f"font-weight: bold; color: {color}"
|
|
|
|
|
|
|
|
def format_leaderboard(df: pd.DataFrame):
    """Build the styled leaderboard table from a summary frame.

    The input is copied, not modified. It must contain the columns
    ``model_name``, ``win_rate``, ``skill_score``, ``median_inference_time_s``,
    ``training_corpus_overlap`` and ``num_failures``.

    Steps:
      * ``skill_score`` and ``win_rate`` are rounded to one decimal;
      * ``zero_shot``, ``link`` and ``org`` are derived from ``model_name``;
      * ``training_corpus_overlap`` is multiplied by 100 and rounded to an
        integer (presumably a fraction turned into a percentage; confirm
        with the producer of the frame), and forced to 0 for models that are
        not marked zero-shot;
      * columns are reduced to the display order.

    Returns a pandas ``Styler`` with coloured bold model names, bold zero-shot
    marks and a light background on every second row.

    Raises ``ValueError`` if a zero-shot model has a missing overlap value.
    """
    df = df.copy()
    df["skill_score"] = df["skill_score"].round(1)
    df["win_rate"] = df["win_rate"].round(1)
    df["zero_shot"] = df["model_name"].apply(get_zero_shot_status)

    # Vectorised instead of a row-wise ``df.apply(..., axis=1)``: on an empty
    # frame the row-wise apply returns a DataFrame, which cannot be assigned
    # to a single column. Non-zero-shot rows are zeroed before the integer
    # cast, so a missing overlap there is harmless.
    is_zero_shot = df["zero_shot"] == "✓"
    df["training_corpus_overlap"] = (
        (df["training_corpus_overlap"] * 100)
        .where(is_zero_shot, 0)
        .round()
        .astype("int64")
    )

    df["link"] = df["model_name"].apply(get_model_link)
    df["org"] = df["model_name"].apply(get_model_organization)

    display_columns = [
        "model_name",
        "win_rate",
        "skill_score",
        "median_inference_time_s",
        "training_corpus_overlap",
        "num_failures",
        "zero_shot",
        "org",
        "link",
    ]
    df = df[display_columns]

    def stripe_rows(column):
        # Called once per column (axis=0); shades rows at odd positions.
        return [
            "background-color: #f8f9fa" if position % 2 == 1 else ""
            for position in range(len(column))
        ]

    return (
        df.style.map(highlight_model_type_color, subset=["model_name"])
        .map(lambda x: "font-weight: bold", subset=["zero_shot"])
        .apply(stripe_rows, axis=0)
    )
|
|
|
|
|
|
|
|
def construct_bar_chart(df: pd.DataFrame, col: str, metric_name: str):
    """Build a horizontal bar chart of ``col`` per model with interval marks.

    ``df`` must provide ``model_name``, ``col``, ``{col}_lower`` and
    ``{col}_upper``. The label is "Skill Score" when ``col`` is
    ``"skill_score"`` and "Win Rate" for any other value. Models are drawn in
    the row order of ``df`` (``sort=None``).

    Returns a layered Altair chart (bars, error bars, points) titled with the
    label and ``metric_name``.
    """
    if col == "skill_score":
        label = "Skill Score"
    else:
        label = "Win Rate"
    axis_title = f"{label} (%)"

    hover = [
        alt.Tooltip("model_name:N"),
        alt.Tooltip(f"{col}:Q", format=".2f"),
        alt.Tooltip(f"{col}_lower:Q", title="95% CI Lower", format=".2f"),
        alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".2f"),
    ]

    # Encodings common to the bar layer and the point layer.
    shared_encoding = {
        "y": alt.Y("model_name:N", title="Forecasting Model", sort=None),
        "tooltip": hover,
    }

    source = alt.Chart(df)

    bar_layer = source.mark_bar(color=COLORS["bar_fill"], cornerRadius=4).encode(
        x=alt.X(f"{col}:Q", title=axis_title, scale=alt.Scale(zero=False)),
        **shared_encoding,
    )

    # Interval from the lower to the upper bound column.
    interval_layer = source.mark_errorbar(
        ticks={"height": 5}, color=COLORS["error_bar"]
    ).encode(
        y=alt.Y("model_name:N", title=None, sort=None),
        x=alt.X(f"{col}_lower:Q", title=axis_title),
        x2=alt.X2(f"{col}_upper:Q"),
        tooltip=hover,
    )

    point_layer = source.mark_point(filled=True, color=COLORS["point"]).encode(
        x=alt.X(f"{col}:Q", title=axis_title),
        **shared_encoding,
    )

    layered = bar_layer + interval_layer + point_layer
    layered = layered.properties(
        height=500, title=f"{label} ({metric_name}) with 95% CIs"
    )
    return layered.configure_title(fontSize=16)
|
|
|
|
|
|
|
|
def construct_pairwise_chart(df: pd.DataFrame, col: str, metric_name: str):
    """Build a model-vs-model heatmap of ``col`` with the value printed per cell.

    ``df`` must provide ``model_1``, ``model_2``, ``col``, ``{col}_lower`` and
    ``{col}_upper``. ``col`` must be ``"win_rate"`` or ``"skill_score"``; any
    other value raises ``KeyError``. The three value columns are multiplied by
    100 on a copy of ``df`` (presumably fractions turned into percentages;
    confirm with the caller).

    Both axes are ordered by the mean of ``col`` per ``model_1``, descending.

    Returns a layered Altair chart (heatmap + text).
    """
    # Per column: label, color-scale domain, domain midpoint, and the Vega
    # expression selecting cells whose text is drawn in white.
    per_column_settings = {
        "win_rate": ("Win Rate", [0, 100], 50, f"abs(datum.{col} - 50) > 30"),
        "skill_score": ("Skill Score", [-15, 15], 0, f"abs(datum.{col}) > 10"),
    }
    settings = per_column_settings[col]
    cbar_label = settings[0]
    domain = settings[1]
    domain_mid = settings[2]
    text_condition = settings[3]

    scaled = df.copy()
    for value_column in (col, f"{col}_lower", f"{col}_upper"):
        scaled[value_column] = scaled[value_column] * 100

    mean_by_model = scaled.groupby("model_1")[col].mean()
    model_order = mean_by_model.sort_values(ascending=False).index.tolist()

    # The value tooltip is titled with the first word of the label only.
    value_title = cbar_label.split(" ")[0]
    hover = [
        alt.Tooltip("model_1:N", title="Model 1"),
        alt.Tooltip("model_2:N", title="Model 2"),
        alt.Tooltip(f"{col}:Q", title=value_title, format=".1f"),
        alt.Tooltip(f"{col}_lower:Q", title="95% CI Lower", format=".1f"),
        alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".1f"),
    ]

    x_axis = alt.X(
        "model_2:N",
        sort=model_order,
        title="Model 2",
        axis=alt.Axis(orient="top", labelAngle=-90),
    )
    y_axis = alt.Y("model_1:N", sort=model_order, title="Model 1")
    grid = alt.Chart(scaled).encode(x=x_axis, y=y_axis)

    # Values outside the domain are clamped to its ends.
    color_scale = alt.Scale(
        scheme=HEATMAP_COLOR_SCHEME,
        domain=domain,
        domainMid=domain_mid,
        clamp=True,
    )
    cells = grid.mark_rect().encode(
        color=alt.Color(f"{col}:Q", legend=None, scale=color_scale),
        tooltip=hover,
    )

    label_color = alt.condition(
        text_condition,
        alt.value(COLORS["text_white"]),
        alt.value(COLORS["text_black"]),
    )
    labels = grid.mark_text(dy=-8, fontSize=8, baseline="top", yOffset=5).encode(
        text=alt.Text(f"{col}:Q", format=".1f"),
        color=label_color,
        tooltip=hover,
    )

    chart_title = {
        "text": f"Pairwise {cbar_label} ({metric_name}) with 95% CIs",
        "fontSize": 16,
    }
    layered = (cells + labels).properties(height=550, title=chart_title)
    layered = layered.configure_axis(
        labelFontSize=11, titleFontSize=13, titleFontWeight="bold"
    )
    return layered.resolve_scale(color="independent")
|
|
|
|
|
|
|
|
def construct_pivot_table_from_df(
    errors: pd.DataFrame, metric_name: str
) -> pd.io.formats.style.Styler:
    """Construct styled pivot table from precomputed DataFrame.

    Within every row the values are ranked ascending (ties share the lowest
    rank). Cells ranked 1, 2 and 3 get the gold, silver and bronze background
    from ``COLORS``; every other cell (including missing values) gets the
    default text color. Values are displayed with three decimals.

    Args:
        errors: table of values, one row per task and one column per model
            (presumably metric errors where lower is better; confirm with
            the caller).
        metric_name: not used by this function; kept so the signature stays
            compatible with existing callers.

    Returns:
        A pandas ``Styler`` wrapping ``errors``.
    """
    rank_colors = {1: COLORS["gold"], 2: COLORS["silver"], 3: COLORS["bronze"]}
    default_style = f"color: {COLORS['text_default']}"

    def cell_style(rank):
        # A NaN rank fails the comparison and falls through to the default.
        if rank <= 3:
            return f"background-color: {rank_colors[int(rank)]}"
        return default_style

    if errors.empty:
        return errors.style.format(precision=3)

    # Build every cell's CSS up front and register ONE styling pass, instead
    # of one ``Styler.map`` call per cell (rows x columns passes at render).
    row_ranks = errors.rank(axis=1, method="min").to_numpy()
    styles = pd.DataFrame(
        [[cell_style(rank) for rank in ranks] for ranks in row_ranks],
        index=errors.index,
        columns=errors.columns,
        dtype=object,
    )
    return errors.style.apply(lambda _: styles, axis=None).format(precision=3)
|
|
|
|
|
|
|
|
def construct_pivot_table(
    summaries: pd.DataFrame,
    metric_name: str,
    baseline_model: str,
    leakage_imputation_model: str,
) -> pd.io.formats.style.Styler:
    """Build the styled task-by-model table of ``metric_name`` from ``summaries``.

    Steps:
      1. Pivot ``summaries`` with ``fev.pivot_table`` into a table of
         ``metric_name`` values, and pivot ``trained_on_this_dataset`` into a
         boolean overlap table (missing entries count as ``False``).
      2. Where the overlap table is true, replace the value with the value
         of ``leakage_imputation_model`` for the same row.
      3. Fill remaining missing values in every other column with the value
         of ``baseline_model`` for the same row.
      4. Order columns by their mean rank across rows (best first) and name
         the index "Task name".

    Styling: cells ranked 1-3 within a row get the gold/silver/bronze
    background. The text color marks leakage-imputed cells, otherwise cells
    that were missing before imputation, otherwise (for cells ranked below 3)
    the default text color. Values are displayed with three decimals.

    Returns:
        A pandas ``Styler`` wrapping the imputed, reordered table.
    """
    errors = fev.pivot_table(
        summaries=summaries, metric_column=metric_name, task_columns=["task_name"]
    )
    train_overlap = (
        fev.pivot_table(
            summaries=summaries,
            metric_column="trained_on_this_dataset",
            task_columns=["task_name"],
        )
        .fillna(False)
        .astype(bool)
    )

    # Record which cells get imputed BEFORE the values are overwritten.
    is_imputed_baseline = errors.isna()
    is_leakage_imputed = train_overlap

    # Overlapping cells take the leakage-imputation model's value of that row.
    errors = errors.mask(train_overlap, errors[leakage_imputation_model], axis=0)
    for col in errors.columns:
        if col != baseline_model:
            errors[col] = errors[col].fillna(errors[baseline_model])

    errors = errors[errors.rank(axis=1).mean().sort_values().index]
    errors.index.rename("Task name", inplace=True)

    rank_colors = {1: COLORS["gold"], 2: COLORS["silver"], 3: COLORS["bronze"]}
    leakage_style = f"color: {COLORS['leakage_impute']}"
    failure_style = f"color: {COLORS['failure_impute']}"
    default_style = f"color: {COLORS['text_default']}"

    def cell_style(rank, leaked, failed):
        parts = []
        # A NaN rank fails the comparison and gets no background.
        if rank <= 3:
            parts.append(f"background-color: {rank_colors[int(rank)]}")
        # Leakage takes precedence over failure; the default text color is
        # only applied to cells without a rank background.
        if leaked:
            parts.append(leakage_style)
        elif failed:
            parts.append(failure_style)
        elif not parts:
            parts.append(default_style)
        return "; ".join(parts)

    if errors.empty:
        return errors.style.format(precision=3)

    # Align the masks to the reordered columns, then build every cell's CSS
    # up front so that a single styling pass is registered instead of one
    # ``Styler.map`` call per cell.
    rank_rows = errors.rank(axis=1, method="min").to_numpy()
    leak_rows = is_leakage_imputed.loc[errors.index, errors.columns].to_numpy()
    fail_rows = is_imputed_baseline.loc[errors.index, errors.columns].to_numpy()
    styles = pd.DataFrame(
        [
            [cell_style(*cell) for cell in zip(ranks, leaks, fails)]
            for ranks, leaks, fails in zip(rank_rows, leak_rows, fail_rows)
        ],
        index=errors.index,
        columns=errors.columns,
        dtype=object,
    )
    return errors.style.apply(lambda _: styles, axis=None).format(precision=3)
|
|
|