import torch
import torch.nn as nn
import torch.hub
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from datasets import load_dataset
from main import DinoRegressionHeteroImages
import mediapipe as mp
import time, random
# ============================
# Model definition (imported from main.py; commented reference copy below)
# ============================
# class DinoRegressionHeteroImages(nn.Module):
#     def __init__(self, dino_model, hidden_dim=128, dropout=0.1, dino_dim=1024):
#         super().__init__()
#         self.dino = dino_model
#         for p in self.dino.parameters():
#             p.requires_grad = False
#         self.embedding_to_hidden = nn.Linear(dino_dim, hidden_dim)
#         self.leaky_relu = nn.LeakyReLU()
#         self.dropout = nn.Dropout(dropout)
#         self.hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim)
#         self.out_mu = nn.Linear(hidden_dim, 1)
#         self.out_logvar = nn.Linear(hidden_dim, 1)
#
#     def forward(self, x):
#         h = self.dino(x)
#         h = self.embedding_to_hidden(h)
#         h = self.leaky_relu(h)
#         h = self.dropout(h)
#         h = self.hidden_to_hidden(h)
#         h = self.leaky_relu(h)
#         mu = self.out_mu(h).squeeze(1)
#         logvar = self.out_logvar(h).squeeze(1)
#         logvar = torch.clamp(logvar, -10.0, 3.0)
#         return mu, logvar
# ============================
# Load model checkpoint
# ============================
repo_id = "SiyaYan/lifespan-predictor" # π change this
filename = "dino_finetuned_faces_l1_1024_best.pth"
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
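# hf_hub_download caches the file locally, so repeated launches reuse the
# downloaded checkpoint instead of re-fetching it.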
# Auto-select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)
# Load Dino backbone
dino_backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg").to(device)
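# dinov2_vitl14_reg is the DINOv2 ViT-L/14 variant with register tokens; its
# 1024-dim embedding matches dino_dim=1024 below.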
# Build model
model = DinoRegressionHeteroImages(
    dino_backbone,
    hidden_dim=128,
    dropout=0.01,
    dino_dim=1024,
).to(device)
# Load weights
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()
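# Optional sanity check (a minimal sketch, safe to delete): a dummy batch
# should produce (mu, logvar) of shape (batch,).
# with torch.no_grad():
#     _mu, _logvar = model(torch.zeros(1, 3, 224, 224, device=device))
#     assert _mu.shape == _logvar.shape == (1,)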
# ============================
# Dataset stats (denormalization)
# ============================
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train[:2000]")
remaining = np.array(ds["remaining_lifespan"], dtype=np.float32)
lifespan_mean = float(remaining.mean())
lifespan_std = float(remaining.std())
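# The model outputs z-scored targets; predict() below denormalizes them via
# years = mu * lifespan_std + lifespan_mean.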
# ============================
# Preprocessing
# ============================
imgtransform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x.convert("RGB")),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
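# The mean/std above are the standard ImageNet statistics, which DINOv2
# backbones expect at inference time.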
mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils
def crop_face(img: Image.Image):
    """Detect a face and crop to it; fall back to the full image."""
    img_rgb = np.array(img.convert("RGB"))
    h, w, _ = img_rgb.shape
    with mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as face_det:
        results = face_det.process(img_rgb)
        if results.detections:
            # take the first detection (MediaPipe does not guarantee size ordering)
            bbox = results.detections[0].location_data.relative_bounding_box
            x1 = max(int(bbox.xmin * w), 0)
            y1 = max(int(bbox.ymin * h), 0)
            x2 = min(int((bbox.xmin + bbox.width) * w), w)
            y2 = min(int((bbox.ymin + bbox.height) * h), h)
            return img.crop((x1, y1, x2, y2))
    # fallback: return the original image if no face was found
    return img
# ============================
# Prediction function
# ============================
def predict(img):
    # Fake progress updates so the UI shows activity while inference runs
    for _ in range(3):
        rand_val = random.randint(50, 120)
        yield f"<div style='font-size:36px; font-weight:bold; color:#888; text-align:center;'>Estimating... {rand_val} yrs</div>"
        time.sleep(0.5)
    # Crop to the detected face and preprocess the crop (not the full frame)
    face = crop_face(img)
    x = imgtransform(face).unsqueeze(0).to(device)
    with torch.no_grad():
        mu, logvar = model(x)
    # Denormalize the z-scored output back to years
    pred_years = mu.item() * lifespan_std + lifespan_mean
    final_years = pred_years + 65  # constant offset applied to the denormalized prediction
    yield f"<div style='font-size:42px; font-weight:bold; color:#2E86AB; text-align:center;'>Estimated remaining lifespan: <br> {final_years:.1f} years</div>"
# ============================
# Gradio UI (Webcam + Upload)
# ============================
# Save a few dataset samples to disk so they can be used as clickable examples
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train")
example_paths = []
for i in (1, 2, 3):
    path = f"example{i}.jpg"
    ds[i]["image"].save(path)
    example_paths.append(path)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(height=500, type="pil", sources=["upload", "webcam"], label="Face Photo"),
    outputs=gr.HTML(
        label="Prediction",
        elem_id="pred-box",
        value="<div style='font-size:36px; font-weight:bold; color:#2E86AB; text-align:center;'>Predict remaining lifespan: <br> ??? years</div>",
    ),
    live=False,
    title="Lifespan Predictor (Demo)",
    description="Upload or take a face photo. The model predicts remaining lifespan (years).",
    examples=[[p] for p in example_paths],  # saved above, so the files exist in the repo
    css="""
    #pred-box {
        /* size & layout to mirror the image area */
        height: 500px;  /* match inputs=gr.Image(height=500) */
        display: flex;
        align-items: center;
        justify-content: center;
        text-align: center;
        /* typography */
        font-size: 42px;
        font-weight: 700;
        color: var(--body-text-color);
        /* make it look like a Gradio card */
        background: var(--block-background-fill);
        border: 1px solid var(--border-color);
        border-radius: var(--radius-lg);
        box-shadow: var(--shadow-drop);
        padding: 16px;
    }
    /* tighten the label spacing so it feels like the input card */
    #pred-box + .wrap.svelte-1ipelgc,      /* older gradio */
    #pred-box + .container.svelte-1ipelgc, /* fallback */
    #pred-box:has(+ *) {                   /* future-proof: keep label close */
        margin-top: 0;
    }
    """,
)
demo.queue()

if __name__ == "__main__":
    demo.launch()