Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import React, { useEffect, useState } from 'react' | |
| import LeaderboardFilter from './LeaderboardFilter' | |
| import LoadingSpinner from './LoadingSpinner' | |
/**
 * Props for the `LeaderboardTable` component.
 */
type LeaderboardTableProps = {
  /** Model names (table columns) the user has chosen to display. */
  selectedModels: Set<string>
  /**
   * Raw benchmark payload; the table reads its `rows` and `groups` fields.
   * Left as `any` because its shape is not declared anywhere in this file.
   */
  benchmarkData: any
}
/**
 * One table row: the metric name under `metric`, plus one value per model
 * keyed by the model's name.
 */
interface Row extends Record<string, string | number> {
  metric: string
}
/** Group name → subgroup name → metric names that belong to that subgroup. */
type Groups = Record<string, Record<string, string[]>>
/**
 * Checkbox grid that lets the user pick which overall (aggregate) metrics are
 * shown as column groups in the leaderboard.
 *
 * The legend reports "selected/total". The incoming Set is never mutated: each
 * toggle builds a copy and hands it to `setSelectedOverallMetrics`.
 */
const OverallMetricFilter: React.FC<{
  overallMetrics: string[]
  selectedOverallMetrics: Set<string>
  setSelectedOverallMetrics: (metrics: Set<string>) => void
}> = (props) => {
  const { overallMetrics, selectedOverallMetrics, setSelectedOverallMetrics } = props

  const handleToggle = (name: string) => {
    const next = new Set(selectedOverallMetrics)
    // Set#delete returns false when the entry was absent, in which case it gets added instead.
    if (!next.delete(name)) {
      next.add(name)
    }
    setSelectedOverallMetrics(next)
  }

  const renderOption = (name: string) => (
    <label key={name} className="flex items-center gap-2 text-sm">
      <input
        type="checkbox"
        className="form-checkbox h-4 w-4"
        checked={selectedOverallMetrics.has(name)}
        onChange={() => handleToggle(name)}
      />
      <span className="truncate" title={name}>
        {name}
      </span>
    </label>
  )

  const selectedCount = selectedOverallMetrics.size
  const totalCount = overallMetrics.length

  return (
    <div className="w-full mb-4">
      <fieldset className="fieldset w-full p-4 rounded border border-gray-700">
        <legend className="fieldset-legend font-semibold">
          Metrics ({selectedCount}/{totalCount})
        </legend>
        <div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-1 max-h-48 overflow-y-auto pr-2">
          {overallMetrics.map((name) => renderOption(name))}
        </div>
      </fieldset>
    </div>
  )
}
/** Matches a strength token such as `0.5`, `10`, `2k` or `1.5M` (k/m/b suffix optional). */
const MAGNITUDE_PATTERN = /^(\d+(?:\.\d+)?)([kKmMbB]?)$/

/** Multiplier for each lower-cased magnitude suffix; no suffix means 1. */
const MAGNITUDE_MULTIPLIERS: { [suffix: string]: number } = { '': 1, k: 1e3, m: 1e6, b: 1e9 }

/**
 * Parses a strength token, honouring an optional k/m/b suffix (case-insensitive).
 * Returns NaN when the token is not a non-negative decimal with an optional suffix.
 */
const parseMagnitude = (token: string): number => {
  const match = MAGNITUDE_PATTERN.exec(token)
  if (!match) return NaN
  const [, digits, suffix] = match
  return parseFloat(digits) * MAGNITUDE_MULTIPLIERS[suffix.toLowerCase()]
}

/**
 * Converts a raw cell value to a number, treating missing values as NaN.
 *
 * `Number(null)` and `Number('')` are both 0, which would display blank cells
 * as "0" and pull averages toward zero; only real numbers and non-blank
 * strings are converted.
 */
const toNumeric = (value: unknown): number => {
  if (typeof value === 'number') return value
  if (typeof value === 'string' && value.trim() !== '') return Number(value)
  return NaN
}

/** Numeric cells are rounded to 3 decimals (trailing zeros dropped); anything else is shown as-is. */
const formatCell = (value: string | number): string | number => {
  const numeric = toNumeric(value)
  return isNaN(numeric) ? value : Number(numeric.toFixed(3))
}

/**
 * Population mean and standard deviation (divides by N, not N - 1).
 * Both are NaN when `values` is empty.
 */
const meanAndStdDev = (values: number[]): { avg: number; stdDev: number } => {
  if (values.length === 0) return { avg: NaN, stdDev: NaN }
  const avg = values.reduce((sum, value) => sum + value, 0) / values.length
  const variance =
    values.reduce((sum, value) => {
      const diff = value - avg
      return sum + diff * diff
    }, 0) / values.length
  return { avg, stdDev: Math.sqrt(variance) }
}

/** Renders an aggregate cell as "avg ± std" with 3 decimals, or "N/A" when there is no data. */
const formatStats = ({ avg, stdDev }: { avg: number; stdDev: number }): string =>
  isNaN(avg) ? 'N/A' : `${avg.toFixed(3)} ± ${stdDev.toFixed(3)}`

/**
 * Returns the overall metric that `metric` rolls up into, or null if none.
 *
 * A metric belongs to overall metric `o` when it equals `o` or ends with `_o`.
 * `candidatesLongestFirst` must be sorted by descending length so that when one
 * overall name is a suffix of another (e.g. `p_value` and `log10_p_value`) the
 * most specific match wins, instead of the row being counted under both.
 */
const findOverallMetric = (metric: string, candidatesLongestFirst: string[]): string | null => {
  for (const candidate of candidatesLongestFirst) {
    if (metric === candidate || metric.endsWith(`_${candidate}`)) return candidate
  }
  return null
}

/**
 * Splits a `{category}_{strength}_{overall}` metric name into its parts for
 * row ordering. Names not in that shape use the full name as category and strength.
 */
const describeMetric = (metric: string, candidatesLongestFirst: string[]) => {
  const overall = findOverallMetric(metric, candidatesLongestFirst) ?? ''
  if (!metric.endsWith(`_${overall}`)) {
    return { overall, category: metric, strength: metric }
  }
  const parts = metric.slice(0, metric.length - overall.length - 1).split('_')
  return {
    overall,
    category: parts.slice(0, -1).join('_'),
    strength: parts[parts.length - 1],
  }
}

/**
 * Orders individual metric rows by category, then overall metric name, then
 * strength; strengths are compared numerically when both parse (so `2k` < `10k`).
 */
const compareMetricRows = (a: string, b: string, candidatesLongestFirst: string[]): number => {
  const left = describeMetric(a, candidatesLongestFirst)
  const right = describeMetric(b, candidatesLongestFirst)
  if (left.category !== right.category) return left.category.localeCompare(right.category)
  if (left.overall !== right.overall) return left.overall.localeCompare(right.overall)
  const numA = parseMagnitude(left.strength)
  const numB = parseMagnitude(right.strength)
  if (!isNaN(numA) && !isNaN(numB)) return numA - numB
  return left.strength.localeCompare(right.strength)
}

/**
 * Leaderboard view for benchmark results.
 *
 * Reads `benchmarkData.rows` (each row has a `metric` name plus one value per
 * model key) and `benchmarkData.groups` (group name → metric names). The
 * `Overall` group supplies the aggregate metric names used as column groups
 * (each entry minus its first `_`-separated prefix). Every other group becomes a
 * collapsible row group, split into subgroups by the metric-name prefix before
 * the first `_`. Rows that match no overall metric are listed in a separate table.
 */
const LeaderboardTable: React.FC<LeaderboardTableProps> = ({ benchmarkData, selectedModels }) => {
  const [tableRows, setTableRows] = useState<Row[]>([])
  const [tableHeader, setTableHeader] = useState<string[]>([])
  const [error, setError] = useState<string | null>(null)
  const [groups, setGroups] = useState<Groups>({})
  const [openGroups, setOpenGroups] = useState<{ [key: string]: boolean }>({})
  const [openSubGroups, setOpenSubGroups] = useState<{ [key: string]: { [key: string]: boolean } }>(
    {}
  )
  const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
  const [overallMetrics, setOverallMetrics] = useState<string[]>([])
  const [selectedOverallMetrics, setSelectedOverallMetrics] = useState<Set<string>>(new Set())

  // Rebuild all derived state whenever a new payload arrives. State is only
  // committed after parsing succeeds; on failure just the error is set.
  useEffect(() => {
    if (!benchmarkData) {
      return
    }
    try {
      const rows: Row[] = benchmarkData['rows']
      const allGroups = benchmarkData['groups'] as { [key: string]: string[] }
      // Fail with a descriptive message rather than an opaque TypeError further down.
      if (!Array.isArray(rows)) {
        throw new Error('expected "rows" to be an array')
      }
      if (!allGroups || typeof allGroups !== 'object') {
        throw new Error('expected "groups" to be an object')
      }
      // "Overall" only supplies the column metrics; the remaining groups become row groups.
      const { Overall: overallGroup, ...categoryGroups } = allGroups

      // Column metric names are the Overall entries minus their first `_`-separated prefix.
      const uniqueMetrics = new Set<string>()
      overallGroup?.forEach((metric) => {
        if (metric.includes('_')) {
          uniqueMetrics.add(metric.split('_').slice(1).join('_'))
        }
      })

      // Groups sorted by name; within each, metrics are bucketed by their prefix before the first `_`.
      const groupsData: Groups = {}
      for (const group of Object.keys(categoryGroups).sort((a, b) => a.localeCompare(b))) {
        const bySubgroup: { [subgroup: string]: string[] } = {}
        for (const metric of [...categoryGroups[group]].sort()) {
          const prefix = metric.split('_')[0]
          if (!bySubgroup[prefix]) {
            bySubgroup[prefix] = []
          }
          bySubgroup[prefix].push(metric)
        }
        groupsData[group] = Object.fromEntries(
          Object.entries(bySubgroup).sort(([a], [b]) => a.localeCompare(b))
        )
      }

      // Every group and subgroup starts collapsed.
      const initialOpenGroups: { [key: string]: boolean } = {}
      const initialOpenSubGroups: { [key: string]: { [key: string]: boolean } } = {}
      for (const [group, subGroups] of Object.entries(groupsData)) {
        initialOpenGroups[group] = false
        initialOpenSubGroups[group] = Object.fromEntries(
          Object.keys(subGroups).map((subGroup) => [subGroup, false])
        )
      }

      // Model columns are every key seen on any row, except the metric name itself.
      const headers = Array.from(new Set(rows.flatMap((row) => Object.keys(row)))).filter(
        (key) => key !== 'metric'
      )

      setOverallMetrics(Array.from(uniqueMetrics).sort())
      setSelectedOverallMetrics(new Set(Array.from(uniqueMetrics)))
      setSelectedMetrics(new Set(Object.values(categoryGroups).flat()))
      setTableHeader(headers)
      setTableRows(rows)
      setGroups(groupsData)
      setOpenGroups(initialOpenGroups)
      setOpenSubGroups(initialOpenSubGroups)
      setError(null)
    } catch (err: unknown) {
      const message = err instanceof Error ? err.message : String(err)
      setError('Failed to parse benchmark data, please try again: ' + message)
    }
  }, [benchmarkData])

  // Longest names first so findOverallMetric prefers the most specific suffix.
  const overallLongestFirst = React.useMemo(
    () => [...overallMetrics].sort((a, b) => b.length - a.length),
    [overallMetrics]
  )

  // First row for each metric name (same result as Array#find when names repeat).
  const rowByMetric = React.useMemo(() => {
    const index = new Map<string, Row>()
    tableRows.forEach((row) => {
      if (!index.has(row.metric)) {
        index.set(row.metric, row)
      }
    })
    return index
  }, [tableRows])

  // Overall metric name → names of the rows that roll up into it.
  const metricsByOverall = React.useMemo(() => {
    const index = new Map<string, Set<string>>()
    tableRows.forEach((row) => {
      const overall = findOverallMetric(row.metric, overallLongestFirst)
      if (overall === null) return
      let members = index.get(overall)
      if (!members) {
        members = new Set<string>()
        index.set(overall, members)
      }
      members.add(row.metric)
    })
    return index
  }, [tableRows, overallLongestFirst])

  // Rows that belong to no overall metric, sorted by name, for the standalone table.
  const standaloneMetrics = React.useMemo(
    () =>
      tableRows
        .map((row) => row.metric)
        .filter((metric) => findOverallMetric(metric, overallLongestFirst) === null)
        .sort(),
    [tableRows, overallLongestFirst]
  )

  const toggleGroup = (group: string) => {
    setOpenGroups((prev) => ({ ...prev, [group]: !prev[group] }))
  }

  const toggleSubGroup = (group: string, subGroup: string) => {
    setOpenSubGroups((prev) => ({
      ...prev,
      [group]: {
        ...(prev[group] || {}),
        [subGroup]: !prev[group]?.[subGroup],
      },
    }))
  }

  // Mean ± std of one model column over the given metric rows; blank/non-numeric cells are skipped.
  const calculateStats = (metricNames: string[], columnKey: string) =>
    meanAndStdDev(
      metricNames
        .map((name) => toNumeric(rowByMetric.get(name)?.[columnKey]))
        .filter((value) => !isNaN(value))
    )

  /**
   * Keeps the selected metrics that belong to `group` (or only to `subgroup`
   * within it, when that subgroup exists). With no group, returns the input unfiltered.
   */
  const filterMetricsByGroupAndSubgroup = (
    metricNames: string[],
    group: string | null = null,
    subgroup: string | null = null
  ): string[] => {
    if (!group) return metricNames
    const subGroups = groups[group] || {}
    // An unknown subgroup falls back to every metric in the group.
    const allowed = new Set(
      subgroup && subGroups[subgroup] ? subGroups[subgroup] : Object.values(subGroups).flat()
    )
    return metricNames.filter((metric) => allowed.has(metric) && selectedMetrics.has(metric))
  }

  const visibleModels = tableHeader.filter((model) => selectedModels.has(model))
  const visibleOverallMetrics = overallMetrics.filter((metric) =>
    selectedOverallMetrics.has(metric)
  )

  // One "avg ± std" cell per (visible overall metric, visible model) for a group/subgroup row.
  const renderStatCells = (keyPrefix: string, candidateMetrics: string[]) =>
    visibleOverallMetrics.map((metric) => {
      const members = metricsByOverall.get(metric)
      // Independent of the model column, so computed once per overall metric.
      const matching = members ? candidateMetrics.filter((m) => members.has(m)) : []
      return (
        <React.Fragment key={`${keyPrefix}-${metric}`}>
          {visibleModels.map((col) => (
            <td
              key={`${keyPrefix}-${metric}-${col}`}
              className="font-medium text-center border-gray-700 border"
            >
              {formatStats(calculateStats(matching, col))}
            </td>
          ))}
        </React.Fragment>
      )
    })

  // A leaf row: values appear only under the overall metric the row rolls up into.
  const renderMetricRow = (metric: string) => {
    const row = rowByMetric.get(metric)
    if (!row) return null
    const rowOverall = findOverallMetric(metric, overallLongestFirst)
    return (
      <tr key={metric} className="hover:bg-base-100">
        <td className="sticky left-0 bg-base-100 z-10 pl-10 border-gray-700 border">{metric}</td>
        {visibleOverallMetrics.map((oMetric) => (
          <React.Fragment key={`${metric}-${oMetric}`}>
            {visibleModels.map((col) => (
              <td key={`${metric}-${oMetric}-${col}`} className="text-center border-gray-700 border">
                {oMetric === rowOverall ? formatCell(row[col]) : null}
              </td>
            ))}
          </React.Fragment>
        ))}
      </tr>
    )
  }

  return (
    <div className="rounded shadow">
      {error && <div className="text-red-500">{error}</div>}
      {!error && (
        <div className="flex flex-col gap-8">
          <div className="flex flex-col gap-4">
            <OverallMetricFilter
              overallMetrics={overallMetrics}
              selectedOverallMetrics={selectedOverallMetrics}
              setSelectedOverallMetrics={setSelectedOverallMetrics}
            />
            {/* <LeaderboardFilter
              groups={groups}
              selectedMetrics={selectedMetrics}
              setSelectedMetrics={setSelectedMetrics}
            /> */}
          </div>
          {selectedModels.size === 0 || selectedMetrics.size === 0 ? (
            <div className="text-center p-4 text-lg">
              Please select at least one model and one metric to display the data
            </div>
          ) : (
            <>
              {/* Standalone metrics table */}
              {standaloneMetrics.length > 0 && (
                <div className="overflow-x-auto max-h-[80vh] overflow-y-auto">
                  <table className="table w-full min-w-max border-gray-700 border">
                    <thead>
                      <tr>
                        <th className="sticky left-0 top-0 bg-base-100 z-20 border-gray-700 border">
                          Metric
                        </th>
                        {visibleModels.map((model) => (
                          <th
                            key={`standalone-${model}`}
                            className="sticky top-0 bg-base-100 z-10 text-center text-xs border-gray-700 border"
                          >
                            {model}
                          </th>
                        ))}
                      </tr>
                    </thead>
                    <tbody>
                      {standaloneMetrics.map((metric) => {
                        const row = rowByMetric.get(metric)
                        if (!row) return null
                        return (
                          <tr key={`standalone-${metric}`} className="hover:bg-base-100">
                            <td className="sticky left-0 bg-base-100 z-10 border-gray-700 border">
                              {metric}
                            </td>
                            {visibleModels.map((col) => (
                              <td
                                key={`standalone-${metric}-${col}`}
                                className="text-center border-gray-700 border"
                              >
                                {formatCell(row[col])}
                              </td>
                            ))}
                          </tr>
                        )
                      })}
                    </tbody>
                  </table>
                </div>
              )}
              {/* Main metrics table */}
              <div className="overflow-x-auto max-h-[80vh] overflow-y-auto">
                <table className="table w-full min-w-max border-gray-700 border">
                  <thead>
                    <tr>
                      <th className="sticky left-0 top-0 bg-base-100 z-20 border-gray-700 border">
                        Attack Category Metrics
                      </th>
                      {visibleOverallMetrics.map((metric) => (
                        <th
                          key={metric}
                          colSpan={visibleModels.length}
                          className="sticky top-0 bg-base-100 z-10 text-center border-x border-gray-300 border border-gray-700 border"
                        >
                          {metric}
                        </th>
                      ))}
                    </tr>
                    <tr>
                      <th className="sticky left-0 bg-base-100 z-10 border-gray-700 border"></th>
                      {visibleOverallMetrics.map((metric) => (
                        <React.Fragment key={`header-models-${metric}`}>
                          {visibleModels.map((model) => (
                            <th
                              key={`${metric}-${model}`}
                              className="sticky top-12 bg-base-100 z-10 text-center text-xs border-gray-700 border border-bottom-solid border-b-gray-700 border-b-2"
                            >
                              {model}
                            </th>
                          ))}
                        </React.Fragment>
                      ))}
                    </tr>
                  </thead>
                  <tbody>
                    {Object.entries(groups).map(([group, subGroups]) => {
                      // "Overall" is consumed as column headers, never as a row group.
                      if (group === 'Overall') return null
                      const visibleGroupMetrics = filterMetricsByGroupAndSubgroup(
                        Object.values(subGroups).flat(),
                        group
                      )
                      if (visibleGroupMetrics.length === 0) return null
                      return (
                        <React.Fragment key={group}>
                          {/* Group row with aggregate stats for the entire group */}
                          <tr
                            className="bg-base-200 cursor-pointer hover:bg-base-300"
                            onClick={() => toggleGroup(group)}
                          >
                            <td className="sticky left-0 bg-base-200 z-10 font-medium border-gray-700 border">
                              {openGroups[group] ? '▼ ' : '▶ '}
                              {group}
                            </td>
                            {renderStatCells(group, visibleGroupMetrics)}
                          </tr>
                          {openGroups[group] &&
                            Object.entries(subGroups).map(([subGroup, metrics]) => {
                              const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup(
                                metrics,
                                group,
                                subGroup
                              )
                              if (visibleSubgroupMetrics.length === 0) return null
                              const subGroupOpen = openSubGroups[group]?.[subGroup]
                              return (
                                <React.Fragment key={`${group}-${subGroup}`}>
                                  {/* Subgroup row with aggregate stats for the subgroup */}
                                  <tr
                                    className="bg-base-100 cursor-pointer hover:bg-base-200"
                                    onClick={() => toggleSubGroup(group, subGroup)}
                                  >
                                    <td className="sticky left-0 bg-base-100 z-10 pl-6 font-medium border-gray-700 border">
                                      {subGroupOpen ? '▼ ' : '▶ '}
                                      {subGroup}
                                    </td>
                                    {renderStatCells(`${group}-${subGroup}`, visibleSubgroupMetrics)}
                                  </tr>
                                  {/* Individual metric rows */}
                                  {subGroupOpen &&
                                    [...visibleSubgroupMetrics]
                                      .sort((a, b) => compareMetricRows(a, b, overallLongestFirst))
                                      .map((metric) => renderMetricRow(metric))}
                                </React.Fragment>
                              )
                            })}
                        </React.Fragment>
                      )
                    })}
                  </tbody>
                </table>
              </div>
            </>
          )}
        </div>
      )}
    </div>
  )
}

export default LeaderboardTable