# Author: SiyaYan
# Commit fd99188 ("minor fix", verified)
import torch
import torch.nn as nn
import torch.hub
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from datasets import load_dataset
from main import DinoRegressionHeteroImages
import mediapipe as mp
import time, random
# ============================
# Model definition (reference copy only — the class actually used is imported from `main`)
# ============================
# class DinoRegressionHeteroImages(nn.Module):
# def __init__(self, dino_model, hidden_dim=128, dropout=0.1, dino_dim=1024):
# super().__init__()
# self.dino = dino_model
# for p in self.dino.parameters():
# p.requires_grad = False
# self.embedding_to_hidden = nn.Linear(dino_dim, hidden_dim)
# self.leaky_relu = nn.LeakyReLU()
# self.dropout = nn.Dropout(dropout)
# self.hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim)
# self.out_mu = nn.Linear(hidden_dim, 1)
# self.out_logvar = nn.Linear(hidden_dim, 1)
# def forward(self, x):
# h = self.dino(x)
# h = self.embedding_to_hidden(h)
# h = self.leaky_relu(h)
# h = self.dropout(h)
# h = self.hidden_to_hidden(h)
# h = self.leaky_relu(h)
# mu = self.out_mu(h).squeeze(1)
# logvar = self.out_logvar(h).squeeze(1)
# logvar = torch.clamp(logvar, -10.0, 3.0)
# return mu, logvar
# ============================
# Load model checkpoint
# ============================
# Hugging Face Hub repo and file holding the fine-tuned regression weights.
repo_id = "SiyaYan/lifespan-predictor"
filename = "dino_finetuned_faces_l1_1024_best.pth"
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
# Auto-select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)
# Load the DINOv2 ViT-L/14 (with registers) backbone via torch.hub
dino_backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg").to(device)
# Build model. hidden_dim / dino_dim determine tensor shapes and must match the
# checkpoint, because load_state_dict() is strict by default.
model = DinoRegressionHeteroImages(
    dino_backbone,
    hidden_dim=128,
    dropout=0.01,
    dino_dim=1024,
).to(device)
# Load weights. weights_only=True restricts unpickling to tensors and plain
# containers, so a tampered remote checkpoint cannot execute arbitrary code.
# A state_dict saved with torch.save(model.state_dict()) loads fine this way.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Switch to inference behaviour (e.g. dropout disabled).
model.eval()
# ============================
# Dataset stats (denormalization)
# ============================
# The model regresses a standardized target; predict() maps it back to years
# with these statistics. NOTE(review): they come from the first 2000 training
# rows only — confirm this matches the normalization used during training.
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train[:2000]")
remaining = np.asarray(ds["remaining_lifespan"], dtype=np.float32)
lifespan_mean, lifespan_std = float(np.mean(remaining)), float(np.std(remaining))
# ============================
# Preprocessing
# ============================
# Input pipeline for the DINOv2 ViT-L/14 backbone: fixed 224x224 resize
# (224 = 16 patches of 14 px), force 3-channel RGB, then normalize with the
# standard ImageNet channel mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
imgtransform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        # PIL-level conversion so grayscale / RGBA / palette inputs become RGB.
        transforms.Lambda(lambda pil_img: pil_img.convert('RGB')),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils  # NOTE(review): not used anywhere in this file


def crop_face(img: Image.Image):
    """Crop the largest detected face from *img*, falling back to the full image.

    Runs MediaPipe face detection (full-range model, ``model_selection=1``,
    confidence >= 0.5) on an RGB copy of the image. If any face is found,
    the detection whose bounding box has the largest area is chosen. Its
    relative box is converted to pixel coordinates, clamped to the image
    bounds, and used to crop the ORIGINAL image.

    Args:
        img: Input PIL image in any mode. It is converted to RGB only for
            detection; the crop is taken from ``img`` itself.

    Returns:
        The cropped face as a PIL image, or ``img`` unchanged when no face is
        detected or the clamped box is empty.
    """
    img_rgb = np.array(img.convert("RGB"))
    h, w, _ = img_rgb.shape
    with mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as face_det:
        results = face_det.process(img_rgb)
        if not results.detections:
            # fallback: return original if no face found
            return img
        # Detections are not guaranteed to be ordered by size, so select the
        # largest box explicitly instead of taking detections[0].
        bbox = max(
            (det.location_data.relative_bounding_box for det in results.detections),
            key=lambda box: box.width * box.height,
        )
        x1 = max(int(bbox.xmin * w), 0)
        y1 = max(int(bbox.ymin * h), 0)
        x2 = min(int((bbox.xmin + bbox.width) * w), w)
        y2 = min(int((bbox.ymin + bbox.height) * h), h)
    if x2 <= x1 or y2 <= y1:
        # Relative coordinates may lie outside [0, 1]; after clamping the box
        # can be empty or inverted, which would make the crop (or the later
        # Resize) fail. Fall back to the full image.
        return img
    return img.crop((x1, y1, x2, y2))
# ============================
# Prediction function
# ============================
def predict(img):
    """Stream an HTML remaining-lifespan estimate for a face photo.

    This is a generator used as a Gradio ``fn``. It first yields three
    cosmetic "Estimating..." messages containing random numbers, spaced
    0.5 s apart. It then crops the face, runs the model, de-normalizes the
    predicted mean and yields the final estimate.

    Args:
        img: PIL image from the Gradio input, or ``None`` when the user
            submits without a photo.

    Yields:
        HTML strings for the ``gr.HTML`` output component.
    """
    if img is None:
        # Nothing to run on; avoid crashing inside crop_face(None).
        yield "<div style='font-size:36px; font-weight:bold; color:#888; text-align:center;'>Please upload or take a face photo.</div>"
        return
    # Fake progress updates (purely cosmetic; the numbers are random).
    for _ in range(3):
        rand_val = random.randint(50, 120)
        yield f"<div style='font-size:36px; font-weight:bold; color:#888; text-align:center;'>Estimating... {rand_val} yrs</div>"
        time.sleep(0.5)
    # Crop the face and feed the crop (not the full frame) to the model.
    face = crop_face(img)
    x = imgtransform(face).unsqueeze(0).to(device)
    with torch.no_grad():
        mu, _logvar = model(x)  # the predicted log-variance is not displayed
    # mu is in standardized units; map it back to years with the dataset stats.
    pred_years = mu.item() * lifespan_std + lifespan_mean
    # NOTE(review): fixed +65-year offset added on top of the de-normalized
    # prediction; its rationale is not visible here — confirm it is intended.
    final_years = pred_years + 65
    yield f"<div style='font-size:42px; font-weight:bold; color:#2E86AB; text-align:center;'>Estimated remaining lifespan: <br> {final_years:.1f} years</div>"
# ============================
# Gradio UI (Webcam + Upload)
# ============================
# Write a few training images to disk as example1.jpg .. example3.jpg so the
# Interface below can reference them by path as clickable examples.
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train")
for idx in (1, 2, 3):
    # JPEG cannot store alpha or palette modes; normalize to RGB before saving.
    ds[idx]["image"].convert("RGB").save(f"example{idx}.jpg")
# Initial content of the prediction panel, shown before the first run.
_PLACEHOLDER_HTML = "<div style='font-size:36px; font-weight:bold; color:#2E86AB; text-align:center;'>Predict remaining lifespan: <br> ??? years</div>"

# Card-style CSS for the prediction panel (elem_id "pred-box"), sized to
# match the 500px-high image input.
_PRED_BOX_CSS = """
#pred-box {
/* size & layout to mirror the image area */
height: 500px; /* match inputs=gr.Image(height=500) */
display: flex;
align-items: center;
justify-content: center;
text-align: center;
/* typography */
font-size: 42px;
font-weight: 700;
color: var(--body-text-color);
/* make it look like a Gradio card */
background: var(--block-background-fill);
border: 1px solid var(--border-color);
border-radius: var(--radius-lg);
box-shadow: var(--shadow-drop);
padding: 16px;
}
/* tighten the label spacing so it feels like the input card */
#pred-box + .wrap.svelte-1ipelgc, /* older gradio */
#pred-box + .container.svelte-1ipelgc, /* fallback */
#pred-box:has(+ *) { /* future-proof: keep label close */
margin-top: 0;
}
"""

face_input = gr.Image(height=500, type="pil", sources=["upload", "webcam"], label="Face Photo")
prediction_output = gr.HTML(label="Prediction", elem_id="pred-box", value=_PLACEHOLDER_HTML)

demo = gr.Interface(
    fn=predict,
    inputs=face_input,
    outputs=prediction_output,
    live=False,
    title="Lifespan Predictor (Demo)",
    description="Upload or take a face photo. Model predicts remaining lifespan (years).",
    # example1.jpg .. example3.jpg are written to the working directory at startup.
    examples=[[f"example{n}.jpg"] for n in (1, 2, 3)],
    css=_PRED_BOX_CSS,
)
# Route requests through the queue; predict() is a generator that streams
# intermediate outputs.
demo.queue()

if __name__ == "__main__":
    demo.launch()