Spaces:
Sleeping
Sleeping
File size: 4,482 Bytes
c431adc 8103ea6 c431adc 8103ea6 20c3a43 c431adc d5455f4 c431adc d5455f4 c431adc d5455f4 c431adc d5455f4 c431adc d5455f4 c431adc d5455f4 c431adc d5455f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# routers/visualize.py
import logging
import mimetypes
import os

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse

from schemas.visualize import (
    VisualizeHeatmapRequest,
    VisualizeMeanDiffRequest,
    VisualizePCARequest,
)
from utils.visualize_pca import (
    run_visualize_heatmap,
    run_visualize_mean_diff,
    run_visualize_pca,
)

# Per-module logger. Its own threshold is pinned to INFO here; where the
# records end up still depends on handlers configured elsewhere.
# NOTE(review): setting levels at import time overrides app-wide logging
# config for this module — confirm that is intended.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

# Every route below is mounted under "/visualize" and carries the
# "visualization" tag (FastAPI uses tags to group operations in OpenAPI docs).
router = APIRouter(
    tags=["visualization"],
    prefix="/visualize",
)


@router.post(
    "/pca",
    summary="Generates and returns the PCA visualization of activations",
    response_class=FileResponse,
)
async def visualize_pca_endpoint(req: VisualizePCARequest):
    """
    Generate a PCA plot of activations for a prompt pair and return the image.

    Forwards the request fields to ``run_visualize_pca`` (the wrapper for
    ``optipfair.bias.visualize_pca``) and returns the generated file with an
    ``inline`` Content-Disposition header.

    Args:
        req: Validated request body providing model_name, prompt_pair,
            layer_key, highlight_diff, output_dir, figure_format and
            pair_index.

    Returns:
        FileResponse: the generated PNG/SVG image.

    Raises:
        HTTPException: 500 if the wrapper raises (detail is the exception
            message) or if it returns no path / a path with no file on disk.
    """
    # NOTE(review): run_visualize_pca is a synchronous call made directly
    # inside an ``async def``, so it blocks the event loop while it runs.
    # Consider fastapi.concurrency.run_in_threadpool once the plotting
    # helpers are confirmed thread-safe.
    # NOTE(review): output_dir is taken verbatim from the client; confirm the
    # wrapper confines it to an allowed directory.
    try:
        filepath = run_visualize_pca(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            highlight_diff=req.highlight_diff,
            output_dir=req.output_dir,
            figure_format=req.figure_format,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full traceback server-side, then surface the message to the client.
        logger.exception("❌ Error in visualize_pca_endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # The wrapper may return None/"" or a path that was never written.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(
            status_code=500, detail="Image file not found after generation"
        )

    filename = os.path.basename(filepath)
    # "image/<format>" is not a valid type for every format (SVG is
    # "image/svg+xml", JPG is "image/jpeg"), so derive it from the file
    # extension and fall back to the old guess only for unknown extensions.
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        # Explicit header so clients display the image rather than download it;
        # presumably it takes precedence over the disposition derived from
        # ``filename`` — confirm against the installed Starlette version.
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )


@router.post(
    "/mean-diff",
    summary="Generates and returns the mean-differences visualization of activations",
    response_class=FileResponse,
)
async def visualize_mean_diff_endpoint(req: VisualizeMeanDiffRequest):
    """
    Generate a mean-differences plot for a prompt pair and return the image.

    Forwards the request fields to ``run_visualize_mean_diff`` (the wrapper
    for ``optipfair.bias.visualize_mean_differences``) and returns the
    generated file with an ``inline`` Content-Disposition header.

    Args:
        req: Validated request body providing model_name, prompt_pair,
            layer_type, figure_format, output_dir and pair_index.

    Returns:
        FileResponse: the generated PNG/SVG image.

    Raises:
        HTTPException: 500 if the wrapper raises (detail is the exception
            message) or if it returns no path / a path with no file on disk.
    """
    # NOTE(review): synchronous, potentially heavy call inside ``async def``
    # blocks the event loop; output_dir comes straight from the client.
    try:
        filepath = run_visualize_mean_diff(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            # Unlike /pca, this wrapper selects layers by type, not by a single key.
            layer_type=req.layer_type,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full traceback server-side, then surface the message to the client.
        logger.exception("Error in mean-diff endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Check for a falsy path first: os.path.isfile(None) raises TypeError
    # instead of returning False, which would bypass this clean 500.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    filename = os.path.basename(filepath)
    # Derive the media type from the extension (SVG -> "image/svg+xml");
    # fall back to the old "image/<format>" guess for unknown extensions.
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        # Explicit inline disposition so clients display rather than download.
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )


@router.post(
    "/heatmap",
    summary="Generates and returns the heatmap visualization",
    response_class=FileResponse,
)
async def visualize_heatmap_endpoint(req: VisualizeHeatmapRequest):
    """
    Generate a heatmap for a prompt pair and return the image.

    Forwards the request fields to ``run_visualize_heatmap`` (the wrapper for
    ``optipfair.bias.visualize_heatmap``) and returns the generated file with
    an ``inline`` Content-Disposition header.

    Args:
        req: Validated request body providing model_name, prompt_pair,
            layer_key, figure_format and output_dir.

    Returns:
        FileResponse: the generated PNG/SVG image.

    Raises:
        HTTPException: 500 if the wrapper raises (detail is the exception
            message) or if it returns no path / a path with no file on disk.
    """
    # NOTE(review): unlike /pca and /mean-diff, no pair_index is forwarded
    # here — confirm this is intentional.
    # NOTE(review): synchronous, potentially heavy call inside ``async def``
    # blocks the event loop; output_dir comes straight from the client.
    try:
        filepath = run_visualize_heatmap(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
        )
    except Exception as e:
        # Log the full traceback server-side, then surface the message to the client.
        logger.exception("Error in heatmap endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Check for a falsy path first: os.path.isfile(None) raises TypeError
    # instead of returning False, which would bypass this clean 500.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    filename = os.path.basename(filepath)
    # Derive the media type from the extension (SVG -> "image/svg+xml");
    # fall back to the old "image/<format>" guess for unknown extensions.
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        # Explicit inline disposition so clients display rather than download.
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )

|