# robodiff_exp/app.py: Streamlit dashboard for browsing experiment results.
import json
from pathlib import Path
from typing import Dict, List

import pandas as pd
import streamlit as st
from st_aggrid import AgGrid, GridOptionsBuilder
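# Third-party dependencies: streamlit, pandas, and streamlit-aggrid (the PyPI
# package that provides st_aggrid).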


def create_config_dataframe(flattened_configs: List[Dict], ids: List[str]) -> pd.DataFrame:
    """Build a DataFrame from flattened per-run dicts, keyed by run directory id."""
    df = pd.DataFrame(flattened_configs)
    df.columns = [str(col).strip() for col in df.columns]
    df.insert(0, 'id', ids)
    return df
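# For illustration (hypothetical values): create_config_dataframe(
#     [{'model': 'unet', 'task': 'lift'}], ['run_000'])
# returns a one-row DataFrame with columns ['id', 'model', 'task'].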


def flatten_dict(d, parent_key='', sep='.'):
    """Recursively flatten a nested dict, joining key paths with `sep`."""
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
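# For example, with a hypothetical nested config (values for illustration only):
#     flatten_dict({'Filter': {'name': 'topk', 'threshold': 0.5}})
#     -> {'Filter.name': 'topk', 'Filter.threshold': 0.5}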


@st.cache_data
def load_config_data():
    """Flatten config.json from every run directory under ./logs into one DataFrame."""
    log_dir = Path(__file__).parent / 'logs'
    configs = []
    dir_ids = []
    for dir_name in log_dir.glob('*'):
        if not dir_name.is_dir():
            continue
        config_path = dir_name / 'config.json'
        if not config_path.exists():
            continue
        with open(config_path, 'r') as f:
            config = json.load(f)
        flattened_config = flatten_dict(config)
        configs.append(flattened_config)
        dir_ids.append(dir_name.name)
    return create_config_dataframe(configs, dir_ids)


@st.cache_data
def load_eval_data():
    """Collect the test score from eval_log.json in every run directory."""
    log_dir = Path(__file__).parent / 'logs'
    scores = []
    dir_ids = []
    for dir_name in log_dir.glob('*'):
        if not dir_name.is_dir():
            continue
        eval_path = dir_name / 'eval_log.json'
        if not eval_path.exists():
            continue
        with open(eval_path, 'r') as f:
            eval_data = json.load(f)
        score_dict = {'test/mean_score': eval_data.get('test/mean_score')}
        scores.append(score_dict)
        dir_ids.append(dir_name.name)
    return create_config_dataframe(scores, dir_ids)
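# Assumed (illustrative) shape of eval_log.json; only 'test/mean_score' is read
# and any other keys are ignored:
#     {"test/mean_score": 0.87}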


@st.cache_data
def load_meta_data():
    """Load meta.json from every run directory."""
    log_dir = Path(__file__).parent / 'logs'
    metas = []
    dir_ids = []
    for dir_name in log_dir.glob('*'):
        if not dir_name.is_dir():
            continue
        meta_path = dir_name / 'meta.json'
        if not meta_path.exists():
            continue
        with open(meta_path, 'r') as f:
            meta = json.load(f)
        metas.append(meta)
        dir_ids.append(dir_name.name)
    return create_config_dataframe(metas, dir_ids)


def configure_grid(df):
    """Shared AgGrid options: auto pagination, sidebar, groupable read-only columns."""
    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_pagination(paginationAutoPageSize=True)
    gb.configure_side_bar()
    gb.configure_default_column(groupable=True, value=True, enableRowGroup=True,
                                aggFunc='sum', editable=False)
    return gb.build()

# Streamlit app
st.set_page_config(layout="wide")
st.title("Experiment Results Dashboard")
# Load data
config_df = load_config_data()
score_df = load_eval_data()
meta_df = load_meta_data()
experiments_df = pd.merge(config_df, score_df, on='id', how='inner')
# Preprocess data
columns_to_keep = ['id', 'Filter.name', 'checkpoint', 'model', 'task', 'test/mean_score',
                   'tags', 'start_time', 'Filter.threshold', 'Filter.seed', 'dataset']
filtered_df = experiments_df[columns_to_keep].copy()
# Use the string 'None' so runs without a filter are kept as their own group
# in the groupby below (NaN keys would be dropped).
filtered_df['Filter.threshold'] = filtered_df['Filter.threshold'].fillna('None')
filtered_df['Filter.seed'] = filtered_df['Filter.seed'].fillna('None')
filtered_df['start_time'] = pd.to_datetime(filtered_df['start_time'], format='%Y%m%d_%H%M%S')
# Grouped view: aggregate repeated runs that share a configuration.
grouped_df = filtered_df.groupby(
    ['model', 'Filter.name', 'tags', 'task', 'Filter.threshold']
).agg({
    'test/mean_score': ['mean', list],
    'checkpoint': ['count', list],
    'start_time': ['max', lambda x: sorted(x, reverse=True)],
    'Filter.seed': ['count', list],
})
# Multi-function aggregation produces MultiIndex columns, e.g. ('start_time', 'max'),
# which GridOptionsBuilder.from_dataframe cannot serialize; flatten them to strings.
grouped_df.columns = ['_'.join(col) for col in grouped_df.columns]
grouped_df = grouped_df.reset_index()
tab1, tab2, tab3 = st.tabs(["Meta Data", "Experiment Results", "Grouped Analysis"])
with tab1:
    st.header("Experiment Metadata")
    AgGrid(meta_df.sort_values(['start_time'], ascending=False),
           gridOptions=configure_grid(meta_df),
           height=400,
           fit_columns_on_grid_load=True)

with tab2:
    st.header("Filtered Experiment Results")
    AgGrid(filtered_df.sort_values(['start_time'], ascending=False),
           gridOptions=configure_grid(filtered_df),
           height=600,
           fit_columns_on_grid_load=True)

with tab3:
    st.header("Grouped Performance Analysis")
    # Sort by the flattened 'start_time_max' column created above.
    AgGrid(grouped_df.sort_values(['start_time_max'], ascending=False),
           gridOptions=configure_grid(grouped_df),
           height=600,
           fit_columns_on_grid_load=True)