Spaces:
Sleeping
Sleeping
| """ | |
| Utility classes and functions for the GuardBench Leaderboard display. | |
| """ | |
| from dataclasses import dataclass, field, fields | |
| from enum import Enum, auto | |
| from typing import List, Optional | |
class ModelType(Enum):
    """Kinds of models that can be listed on the leaderboard."""

    Unknown = auto()
    OpenSource = auto()
    ClosedSource = auto()
    API = auto()

    def to_str(self, separator: str = " ") -> str:
        """Return a human-readable label for this model type.

        Args:
            separator: Text placed between the two words of the
                "Open Source" / "Closed Source" labels.

        Returns:
            "Open<separator>Source", "Closed<separator>Source", "API",
            or "Unknown" for every other member.
        """
        # Only the two-word labels make use of the separator.
        two_word_prefixes = {
            ModelType.OpenSource: "Open",
            ModelType.ClosedSource: "Closed",
        }
        prefix = two_word_prefixes.get(self)
        if prefix is not None:
            return f"{prefix}{separator}Source"
        return "API" if self is ModelType.API else "Unknown"
class Precision(Enum):
    """Numeric precisions a model may be reported in."""

    Unknown = auto()
    float16 = auto()
    bfloat16 = auto()
    float32 = auto()
    int8 = auto()
    int4 = auto()

    def __str__(self) -> str:
        """Return the bare member name (e.g. "float16"), without the class prefix."""
        return self.name
class WeightType(Enum):
    """Forms in which a model's weights may be provided."""

    Original = auto()
    Delta = auto()
    Adapter = auto()

    def __str__(self) -> str:
        """Return the bare member name (e.g. "Adapter"), without the class prefix."""
        return self.name
@dataclass
class ColumnInfo:
    """Metadata describing a single leaderboard column.

    Declared as a dataclass so that instances can be built with keyword
    arguments (``ColumnInfo(name=..., display_name=...)``); without the
    decorator no ``__init__`` is generated and such calls raise ``TypeError``.
    """

    # Internal identifier of the column.
    name: str
    # Label associated with the column for display.
    display_name: str
    # Column value kind; "number" marks metric columns, the default is "text".
    type: str = "text"
    # Whether the column is flagged as hidden.
    hidden: bool = False
    # Whether the column is flagged as never hideable.
    never_hidden: bool = False
    # Whether the column is part of the default display set.
    displayed_by_default: bool = True
@dataclass
class GuardBenchColumn:
    """Column definitions for the GuardBench leaderboard.

    Every attribute is a dataclass field whose default is produced by a
    ``default_factory``, so each instance gets its own ``ColumnInfo`` objects.
    The ``@dataclass`` decorator is required: without it the ``field(...)``
    defaults stay as raw ``Field`` objects and ``dataclasses.fields()`` raises
    ``TypeError`` when applied to the class or its instances.
    """

    # Identification columns
    model_name: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="model_name",
        display_name="Model",
        never_hidden=True,
        displayed_by_default=True
    ))
    model_type: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="model_type",
        display_name="Type",
        displayed_by_default=True
    ))

    # F1 metric for each test type
    default_prompts_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_prompts_f1",
        display_name="Default Prompts F1",
        type="number",
        displayed_by_default=True
    ))
    jailbreaked_prompts_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_prompts_f1",
        display_name="Jailbreaked Prompts F1",
        type="number",
        displayed_by_default=True
    ))
    default_answers_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_answers_f1",
        display_name="Default Answers F1",
        type="number",
        displayed_by_default=True
    ))
    jailbreaked_answers_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_answers_f1",
        display_name="Jailbreaked Answers F1",
        type="number",
        displayed_by_default=True
    ))

    # Average metrics (only the F1 average is displayed by default)
    average_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="average_f1",
        display_name="Average F1",
        type="number",
        displayed_by_default=True,
        never_hidden=True
    ))
    average_recall: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="average_recall",
        display_name="Average Recall",
        type="number",
        displayed_by_default=False
    ))
    average_precision: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="average_precision",
        display_name="Average Precision",
        type="number",
        displayed_by_default=False
    ))

    # Additional metadata
    submission_date: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="submission_date",
        display_name="Submission Date",
        displayed_by_default=False
    ))
# Shared instance whose attributes hold the metadata of every column.
GUARDBENCH_COLUMN = GuardBenchColumn()

# Look up each field's ColumnInfo once, in declaration order; the views
# below are all filtered from this temporary list.
_column_infos = [getattr(GUARDBENCH_COLUMN, spec.name) for spec in fields(GUARDBENCH_COLUMN)]

# Field names of all columns.
COLS = [spec.name for spec in fields(GUARDBENCH_COLUMN)]
# Names of columns flagged as displayed by default.
DISPLAY_COLS = [info.name for info in _column_infos if info.displayed_by_default]
# Names of columns whose type is "number".
METRIC_COLS = [info.name for info in _column_infos if info.type == "number"]
# Names of columns flagged as hidden.
HIDDEN_COLS = [info.name for info in _column_infos if info.hidden]
# Names of columns flagged as never hidden.
NEVER_HIDDEN_COLS = [info.name for info in _column_infos if info.never_hidden]

# Drop the temporary so the module namespace stays as before.
del _column_infos
# Category names used by GuardBench.
# NOTE: "Self–Harm" is spelled with an en dash (U+2013), not an ASCII hyphen.
CATEGORIES = [
    "Criminal, Violent, and Terrorist Activity",
    "Manipulation, Deception, and Misinformation",
    "Creative Content Involving Illicit Themes",
    "Sexual Content and Violence",
    "Political Corruption and Legal Evasion",
    "Labor Exploitation and Human Trafficking",
    "Environmental and Industrial Harm",
    "Animal Cruelty and Exploitation",
    "Self–Harm and Suicidal Ideation",
    "Safe Prompts",
]
# Test type identifiers used by GuardBench.
TEST_TYPES = [
    "default_prompts",
    "jailbreaked_prompts",
    "default_answers",
    "jailbreaked_answers",
]
# Metric identifiers used by GuardBench.
METRICS = [
    "f1_binary",
    "recall_binary",
    "precision_binary",
    "error_ratio",
    "avg_runtime_ms",
]