File size: 6,353 Bytes
3cfb288
 
 
 
 
 
350a321
3cfb288
 
64e1853
c0d332e
0d97808
350a321
3cfb288
 
 
64e1853
 
 
 
 
 
350a321
64e1853
 
 
 
 
 
3cfb288
64e1853
 
 
 
 
 
 
 
 
 
 
3cfb288
 
 
 
64e1853
830f255
3cfb288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0d332e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cfb288
 
 
 
0d97808
 
 
 
 
 
 
c0d332e
 
 
3cfb288
 
 
 
fd99188
 
3cfb288
 
 
3c92847
6f68014
 
 
 
 
 
 
 
 
b2963cf
 
 
a640135
482c35e
b2963cf
 
 
6f68014
482c35e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2963cf
d91c488
3cfb288
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import torch
import torch.nn as nn
import torch.hub
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from datasets import load_dataset
from main import DinoRegressionHeteroImages
import mediapipe as mp
import time, random

# ============================
# Model definition
# ============================
# class DinoRegressionHeteroImages(nn.Module):
#     def __init__(self, dino_model, hidden_dim=128, dropout=0.1, dino_dim=1024):
#         super().__init__()
#         self.dino = dino_model
#         for p in self.dino.parameters():
#             p.requires_grad = False 

#         self.embedding_to_hidden = nn.Linear(dino_dim, hidden_dim)
#         self.leaky_relu = nn.LeakyReLU()
#         self.dropout = nn.Dropout(dropout)
#         self.hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim)
#         self.out_mu = nn.Linear(hidden_dim, 1)
#         self.out_logvar = nn.Linear(hidden_dim, 1)

#     def forward(self, x):
#         h = self.dino(x)  
#         h = self.embedding_to_hidden(h)
#         h = self.leaky_relu(h)
#         h = self.dropout(h)
#         h = self.hidden_to_hidden(h)
#         h = self.leaky_relu(h)
#         mu = self.out_mu(h).squeeze(1)
#         logvar = self.out_logvar(h).squeeze(1)
#         logvar = torch.clamp(logvar, -10.0, 3.0)
#         return mu, logvar

# ============================
# Load model checkpoint
# ============================
# Hugging Face Hub repository and file name of the fine-tuned checkpoint.
repo_id = "SiyaYan/lifespan-predictor"  # change this to point at your own Hub repo
filename = "dino_finetuned_faces_l1_1024_best.pth"

# Fetch the checkpoint via huggingface_hub; the result is handed to torch.load below.
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Auto-select device: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

# Load the DINOv2 backbone (the "dinov2_vitl14_reg" entry point) through torch.hub
dino_backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg").to(device)

# Build model: the regression head imported from main, wrapped around the backbone.
# NOTE(review): hidden_dim / dino_dim presumably have to match the values the
# checkpoint was trained with, otherwise load_state_dict below will reject the
# weights -- confirm against the training script.
model = DinoRegressionHeteroImages(
    dino_backbone,
    hidden_dim=128,
    dropout=0.01,
    dino_dim=1024
).to(device)

# Load weights
# NOTE(review): torch.load unpickles the downloaded file; if the checkpoint is
# a plain state dict and the installed torch supports it, consider passing
# weights_only=True so a tampered file cannot execute code.
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()  # switch to inference mode before serving predictions

# ============================
# Dataset stats (denormalization)
# ============================
# Mean and standard deviation of the "remaining_lifespan" column over the first
# 2000 training rows. predict() uses them to map the model output back to years.
# NOTE(review): presumably these must equal the statistics used to normalise
# the targets during training -- confirm the same slice was used there.
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train[:2000]")
remaining = np.array(ds["remaining_lifespan"], dtype=np.float32)
lifespan_mean = float(remaining.mean())
lifespan_std  = float(remaining.std())

# ============================
# Preprocessing
# ============================
# Pipeline applied to a PIL image before it is fed to the model:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. force 3-channel RGB,
#   3. convert to a tensor,
#   4. normalise each channel with the constants below
#      (these look like the usual ImageNet statistics -- TODO confirm they
#      match what the backbone/head were trained with).
imgtransform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# Short aliases for the MediaPipe solution modules. crop_face() uses
# mp_face_detection; mp_drawing is not referenced in the visible code.
mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils

def crop_face(img: Image.Image) -> Image.Image:
    """Crop *img* to the largest detected face, falling back to the full image.

    Runs MediaPipe face detection on an RGB copy of the image, converts every
    detection's relative bounding box to pixel coordinates clamped to the
    image bounds, and crops to the box with the largest area.

    Args:
        img: Input PIL image.

    Returns:
        The cropped face region, or ``img`` unchanged when no face is detected
        or every detected box is empty after clamping.
    """
    img_rgb = np.array(img.convert("RGB"))
    h, w, _ = img_rgb.shape

    with mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as face_det:
        detections = face_det.process(img_rgb).detections

    best_box = None
    best_area = 0
    # `detections` is falsy when nothing was found; iterate over nothing then.
    for detection in detections or ():
        bbox = detection.location_data.relative_bounding_box
        x1 = max(int(bbox.xmin * w), 0)
        y1 = max(int(bbox.ymin * h), 0)
        x2 = min(int((bbox.xmin + bbox.width) * w), w)
        y2 = min(int((bbox.ymin + bbox.height) * h), h)

        # A box that lies outside the frame (or collapses after clamping)
        # would produce an empty crop that the later resize cannot handle.
        if x2 <= x1 or y2 <= y1:
            continue

        # The first detection is not guaranteed to be the biggest face, so
        # pick by pixel area explicitly.
        area = (x2 - x1) * (y2 - y1)
        if area > best_area:
            best_box = (x1, y1, x2, y2)
            best_area = area

    # fallback: return original if no usable face was found
    if best_box is None:
        return img
    return img.crop(best_box)


# ============================
# Prediction function
# ============================
def predict(img):
    """Stream HTML snippets for the Gradio output: placeholders, then the estimate.

    This is a generator so the UI can update while it runs. It yields three
    "Estimating..." placeholders with random numbers (purely cosmetic), then
    crops the face, runs the model and yields the final formatted estimate.

    Args:
        img: PIL image supplied by the Gradio image input, or ``None`` when the
            form is submitted without an image.

    Yields:
        HTML strings to display in the prediction box.
    """
    # Bail out before the fake progress animation: without an image there is
    # nothing to crop, and crop_face() would fail on None.
    if img is None:
        yield "<div style='font-size:36px; font-weight:bold; color:#888; text-align:center;'>Please upload or capture a face photo first.</div>"
        return

    # Fake progress updates: random numbers shown for effect only; they are
    # unrelated to the model output.
    for _ in range(3):
        rand_val = random.randint(50, 120)
        yield f"<div style='font-size:36px; font-weight:bold; color:#888; text-align:center;'>Estimating... {rand_val} yrs</div>"
        time.sleep(0.5)

    # Crop face and feed the crop (not the raw photo) to the model; the
    # original code computed the crop and then ignored it.
    face = crop_face(img)

    x = imgtransform(face).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model returns (mean, log-variance); only the mean is displayed.
        mu, _logvar = model(x)

    # Undo the target normalisation using the dataset statistics.
    pred_years = mu.item() * lifespan_std + lifespan_mean
    # NOTE(review): a constant 65 is added to the denormalised prediction while
    # the label still reads "remaining lifespan" -- confirm this offset is
    # intended; it is kept unchanged here.
    final_years = pred_years + 65
    yield f"<div style='font-size:42px; font-weight:bold; color:#2E86AB; text-align:center;'>Estimated remaining lifespan: <br> {final_years:.1f} years</div>"
# ============================
# Gradio UI (Webcam + Upload)
# ============================
# Take rows 1-3 of the full training split and write their images next to the
# script; the Interface below lists these files as clickable examples.
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train")
sample1, sample2, sample3 = ds[1], ds[2], ds[3]
img1, img2, img3 = (row["image"] for row in (sample1, sample2, sample3))
for _example_path, _example_img in zip(
    ("example1.jpg", "example2.jpg", "example3.jpg"),
    (img1, img2, img3),
):
    _example_img.save(_example_path)

# Placeholder markup shown in the prediction box before the first run.
_INITIAL_PREDICTION_HTML = "<div style='font-size:36px; font-weight:bold; color:#2E86AB; text-align:center;'>Predict remaining lifespan: <br> ??? years</div>"

# Stylesheet for the output component (elem_id "pred-box").
_PRED_BOX_CSS = """
#pred-box {
    /* size & layout to mirror the image area */
    height: 500px;                      /* match inputs=gr.Image(height=500) */
    display: flex;
    align-items: center;
    justify-content: center;
    text-align: center;

    /* typography */
    font-size: 42px;
    font-weight: 700;
    color: var(--body-text-color);

    /* make it look like a Gradio card */
    background: var(--block-background-fill);
    border: 1px solid var(--border-color);
    border-radius: var(--radius-lg);
    box-shadow: var(--shadow-drop);
    padding: 16px;
}

/* tighten the label spacing so it feels like the input card */
#pred-box + .wrap.svelte-1ipelgc,       /* older gradio */
#pred-box + .container.svelte-1ipelgc,  /* fallback */
#pred-box:has(+ *) {                    /* future-proof: keep label close */
    margin-top: 0;
}
"""

# Input: a PIL image from file upload or webcam. Output: an HTML card.
_face_input = gr.Image(height=500, type="pil", sources=["upload", "webcam"], label="Face Photo")
_prediction_output = gr.HTML(label="Prediction", elem_id="pred-box", value=_INITIAL_PREDICTION_HTML)

demo = gr.Interface(
    fn=predict,
    inputs=_face_input,
    outputs=_prediction_output,
    live=False,
    title="Lifespan Predictor (Demo)",
    description="Upload or take a face photo. Model predicts remaining lifespan (years).",
    # One single-element row per example; the example*.jpg files must exist
    # next to the script when the interface is built.
    examples=[[path] for path in ("example1.jpg", "example2.jpg", "example3.jpg")],
    css=_PRED_BOX_CSS,
)
# Enable the request queue, then start the server only when run as a script.
demo.queue()
if __name__ == "__main__":
    demo.launch()