File size: 4,482 Bytes
c431adc
 
8103ea6
 
c431adc
 
8103ea6
20c3a43
 
 
 
 
 
 
 
 
 
c431adc
 
 
 
 
 
 
 
 
d5455f4
c431adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5455f4
 
 
c431adc
 
 
 
 
 
d5455f4
 
 
c431adc
 
d5455f4
c431adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5455f4
 
 
c431adc
 
d5455f4
c431adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5455f4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# routers/visualize.py
import logging
import os

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse

from schemas.visualize import (
    VisualizeHeatmapRequest,
    VisualizeMeanDiffRequest,
    VisualizePCARequest,
)
from utils.visualize_pca import (
    run_visualize_heatmap,
    run_visualize_mean_diff,
    run_visualize_pca,
)

# Module-level logger named after this module; its level is pinned to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Router for this module: every route registered on it is served under the
# "/visualize" prefix and tagged "visualization".
router = APIRouter(
    prefix="/visualize",
    tags=["visualization"],
)


@router.post(
    "/pca",
    summary="Generates and returns the PCA visualization of activations",
    response_class=FileResponse,
)
async def visualize_pca_endpoint(req: VisualizePCARequest):
    """
    Generate the PCA visualization and return it as an image file.

    Receives the parameters, calls the wrapper for optipfair.bias.visualize_pca,
    and returns the resulting PNG/SVG image.

    Args:
        req: Request body; its fields are forwarded to ``run_visualize_pca``.

    Returns:
        A ``FileResponse`` serving the generated image inline.

    Raises:
        HTTPException: 500 if the wrapper raises, or if no file exists at the
            path it returned.
    """
    # 1. Execute the image generation and get the file path
    # NOTE(review): the wrapper is called synchronously from an async
    # endpoint; if it is long-running it will hold up the event loop —
    # confirm whether it should be offloaded to a threadpool.
    try:
        filepath = run_visualize_pca(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            highlight_diff=req.highlight_diff,
            output_dir=req.output_dir,
            figure_format=req.figure_format,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full trace for debugging
        logger.exception("❌ Error in visualize_pca_endpoint")
        # And return the message to the client
        raise HTTPException(status_code=500, detail=str(e)) from e

    # 2. Verify that the file exists
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(
            status_code=500, detail="Image file not found after generation"
        )

    # 3. Resolve the media type. The format name is not always the MIME
    #    subtype (SVG is "image/svg+xml", not "image/svg"), so known
    #    exceptions are mapped explicitly; anything else keeps "image/<format>".
    figure_format = f"{req.figure_format}"
    media_type = {
        "svg": "image/svg+xml",
        "jpg": "image/jpeg",
        "pdf": "application/pdf",
    }.get(figure_format.lower(), f"image/{figure_format}")

    # 4. Return the file directly to the client
    basename = os.path.basename(filepath)
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=basename,
        # An explicit header keeps the disposition "inline" so the image is
        # displayed rather than downloaded.
        headers={"Content-Disposition": f'inline; filename="{basename}"'},
    )


@router.post(
    "/mean-diff",
    summary="Generates and returns the mean-difference visualization",
    response_class=FileResponse,
)
async def visualize_mean_diff_endpoint(req: VisualizeMeanDiffRequest):
    """
    Generate the mean-difference visualization and return it as an image file.

    Receives the parameters, calls the wrapper for
    optipfair.bias.visualize_mean_differences, and returns the resulting
    PNG/SVG image.

    Args:
        req: Request body; its fields are forwarded to
            ``run_visualize_mean_diff``.

    Returns:
        A ``FileResponse`` serving the generated image inline.

    Raises:
        HTTPException: 500 if the wrapper raises, or if no file exists at the
            path it returned.
    """
    # NOTE(review): the wrapper is called synchronously from an async
    # endpoint; if it is long-running it will hold up the event loop —
    # confirm whether it should be offloaded to a threadpool.
    try:
        filepath = run_visualize_mean_diff(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            # This wrapper is given a layer type, not a layer key.
            layer_type=req.layer_type,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full trace for debugging
        logger.exception("Error in mean-diff endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Verify that the file exists. The falsy check comes first because
    # os.path.isfile(None) raises TypeError instead of returning False.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    # Resolve the media type. The format name is not always the MIME subtype
    # (SVG is "image/svg+xml", not "image/svg"), so known exceptions are
    # mapped explicitly; anything else keeps "image/<format>".
    figure_format = f"{req.figure_format}"
    media_type = {
        "svg": "image/svg+xml",
        "jpg": "image/jpeg",
        "pdf": "application/pdf",
    }.get(figure_format.lower(), f"image/{figure_format}")

    # Return the file directly to the client
    basename = os.path.basename(filepath)
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=basename,
        # An explicit header keeps the disposition "inline" so the image is
        # displayed rather than downloaded.
        headers={"Content-Disposition": f'inline; filename="{basename}"'},
    )


@router.post(
    "/heatmap",
    summary="Generates and returns the heatmap visualization",
    response_class=FileResponse,
)
async def visualize_heatmap_endpoint(req: VisualizeHeatmapRequest):
    """
    Generate the heatmap visualization and return it as an image file.

    Receives the parameters, calls the wrapper for
    optipfair.bias.visualize_heatmap, and returns the resulting PNG/SVG image.

    Args:
        req: Request body; its fields are forwarded to
            ``run_visualize_heatmap``.

    Returns:
        A ``FileResponse`` serving the generated image inline.

    Raises:
        HTTPException: 500 if the wrapper raises, or if no file exists at the
            path it returned.
    """
    # NOTE(review): the wrapper is called synchronously from an async
    # endpoint; if it is long-running it will hold up the event loop —
    # confirm whether it should be offloaded to a threadpool.
    try:
        filepath = run_visualize_heatmap(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
        )
    except Exception as e:
        # Log the full trace for debugging
        logger.exception("Error in heatmap endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Verify that the file exists. The falsy check comes first because
    # os.path.isfile(None) raises TypeError instead of returning False.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    # Resolve the media type. The format name is not always the MIME subtype
    # (SVG is "image/svg+xml", not "image/svg"), so known exceptions are
    # mapped explicitly; anything else keeps "image/<format>".
    figure_format = f"{req.figure_format}"
    media_type = {
        "svg": "image/svg+xml",
        "jpg": "image/jpeg",
        "pdf": "application/pdf",
    }.get(figure_format.lower(), f"image/{figure_format}")

    # Return the file directly to the client
    basename = os.path.basename(filepath)
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=basename,
        # An explicit header keeps the disposition "inline" so the image is
        # displayed rather than downloaded.
        headers={"Content-Disposition": f'inline; filename="{basename}"'},
    )