Spaces:
Running
Running
| # schemas/visualize.py | |
| from pydantic import BaseModel, field_validator | |
| from typing import List, Optional, Union, Tuple | |
class VisualizePCARequest(BaseModel):
    """
    Schema for the /visualize-pca endpoint.

    ``prompt_pair`` must contain exactly two prompts; the
    ``must_be_two_prompts`` validator enforces this.
    """

    model_name: str
    prompt_pair: List[str]  # exactly two prompts, checked by must_be_two_prompts
    layer_key: str
    highlight_diff: bool = True
    figure_format: str = "png"
    pair_index: int = 0
    output_dir: Optional[str] = None  # None when the caller supplies no directory

    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Reject a ``prompt_pair`` that does not have exactly two items.

        This is an "after" validator (pydantic's default mode), so ``v`` has
        already passed the ``List[str]`` type validation.

        Returns:
            The unchanged list.

        Raises:
            ValueError: if ``v`` does not contain exactly two elements
                (pydantic reports it as a ``ValidationError``).
        """
        # The @field_validator decorator is what registers this check with
        # pydantic; an undecorated method would never be invoked.
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v
class VisualizeMeanDiffRequest(BaseModel):
    """
    Request schema for the mean-difference visualization.

    NOTE(review): the endpoint path is not visible in this module; confirm
    against the router that uses this schema.

    ``prompt_pair`` must contain exactly two prompts; the
    ``must_be_two_prompts`` validator enforces this.
    """

    model_name: str
    prompt_pair: List[str]  # exactly two prompts, checked by must_be_two_prompts
    # Named ``layer_type`` here, unlike the ``layer_key`` field used by the
    # PCA and heatmap request schemas in this module.
    layer_type: str
    figure_format: str = "png"
    output_dir: Optional[str] = None  # None when the caller supplies no directory
    pair_index: int = 0

    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Reject a ``prompt_pair`` that does not have exactly two items.

        This is an "after" validator (pydantic's default mode), so ``v`` has
        already passed the ``List[str]`` type validation.

        Returns:
            The unchanged list.

        Raises:
            ValueError: if ``v`` does not contain exactly two elements
                (pydantic reports it as a ``ValidationError``).
        """
        # The @field_validator decorator is what registers this check with
        # pydantic; an undecorated method would never be invoked.
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v
class VisualizeHeatmapRequest(BaseModel):
    """
    Schema for the /visualize/heatmap endpoint.

    ``prompt_pair`` must contain exactly two prompts; the
    ``must_be_two_prompts`` validator enforces this.
    """

    model_name: str
    prompt_pair: List[str]  # exactly two prompts, checked by must_be_two_prompts
    layer_key: str
    figure_format: str = "png"
    output_dir: Optional[str] = None  # None when the caller supplies no directory

    @field_validator("prompt_pair")
    @classmethod
    def must_be_two_prompts(cls, v: List[str]) -> List[str]:
        """Reject a ``prompt_pair`` that does not have exactly two items.

        This is an "after" validator (pydantic's default mode), so ``v`` has
        already passed the ``List[str]`` type validation.

        Returns:
            The unchanged list.

        Raises:
            ValueError: if ``v`` does not contain exactly two elements
                (pydantic reports it as a ``ValidationError``).
        """
        # The @field_validator decorator is what registers this check with
        # pydantic; an undecorated method would never be invoked.
        if len(v) != 2:
            raise ValueError("prompt_pair must be a list of exactly two strings")
        return v