import torch
import torch.nn as nn
import torch.hub
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from datasets import load_dataset
from main import DinoRegressionHeteroImages
import mediapipe as mp
import time
import random

# ============================
# Model definition
# ============================
# Reference copy only: the live class is imported from `main` above.
# NOTE(review): presumably this mirrors main.DinoRegressionHeteroImages --
# confirm it is still in sync, or delete it.
#
# class DinoRegressionHeteroImages(nn.Module):
#     def __init__(self, dino_model, hidden_dim=128, dropout=0.1, dino_dim=1024):
#         super().__init__()
#         self.dino = dino_model
#         for p in self.dino.parameters():
#             p.requires_grad = False
#         self.embedding_to_hidden = nn.Linear(dino_dim, hidden_dim)
#         self.leaky_relu = nn.LeakyReLU()
#         self.dropout = nn.Dropout(dropout)
#         self.hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim)
#         self.out_mu = nn.Linear(hidden_dim, 1)
#         self.out_logvar = nn.Linear(hidden_dim, 1)
#
#     def forward(self, x):
#         h = self.dino(x)
#         h = self.embedding_to_hidden(h)
#         h = self.leaky_relu(h)
#         h = self.dropout(h)
#         h = self.hidden_to_hidden(h)
#         h = self.leaky_relu(h)
#         mu = self.out_mu(h).squeeze(1)
#         logvar = self.out_logvar(h).squeeze(1)
#         logvar = torch.clamp(logvar, -10.0, 3.0)
#         return mu, logvar

# ============================
# Load model checkpoint
# ============================
repo_id = "SiyaYan/lifespan-predictor"  # 👈 change this
filename = "dino_finetuned_faces_l1_1024_best.pth"
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Auto-select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

# Load Dino backbone
dino_backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg").to(device)

# Build model
model = DinoRegressionHeteroImages(
    dino_backbone,
    hidden_dim=128,
    dropout=0.01,
    dino_dim=1024,
).to(device)

# Load weights
# NOTE(review): torch.load unpickles a file downloaded from the Hub. On
# PyTorch versions where `weights_only=True` is not the default, consider
# passing it explicitly so a tampered checkpoint cannot execute code.
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()  # inference only: disables dropout

# ============================
# Dataset stats (denormalization)
# ============================
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train[:2000]")
# Mean/std of the targets in the first 2000 training rows. predict() uses
# them to map the model output back to years (mu * std + mean).
remaining = np.array(ds["remaining_lifespan"], dtype=np.float32)
lifespan_mean = float(remaining.mean())
lifespan_std = float(remaining.std())

# ============================
# Preprocessing
# ============================
imgtransform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils


def crop_face(img: Image.Image):
    """Detect the largest face in ``img`` and return it cropped.

    Args:
        img: PIL image to search for faces.

    Returns:
        The crop of ``img`` around the detected face with the largest
        bounding box, or ``img`` itself when no face is detected or the
        clamped bounding box is empty.
    """
    img_rgb = np.array(img.convert("RGB"))
    h, w, _ = img_rgb.shape

    with mp_face_detection.FaceDetection(
        model_selection=1, min_detection_confidence=0.5
    ) as face_det:
        results = face_det.process(img_rgb)

    if not results.detections:
        # fallback: return original if no face found
        return img

    def _box_area(detection):
        box = detection.location_data.relative_bounding_box
        return box.width * box.height

    # Pick the largest box explicitly instead of relying on detection order.
    largest = max(results.detections, key=_box_area)
    bbox = largest.location_data.relative_bounding_box

    # Relative coordinates can fall outside [0, 1]; clamp to the image.
    x1 = max(int(bbox.xmin * w), 0)
    y1 = max(int(bbox.ymin * h), 0)
    x2 = min(int((bbox.xmin + bbox.width) * w), w)
    y2 = min(int((bbox.ymin + bbox.height) * h), h)

    if x2 <= x1 or y2 <= y1:
        # A zero-area crop would break the Resize step downstream.
        return img

    return img.crop((x1, y1, x2, y2))


# ============================
# Prediction function
# ============================
def predict(img):
    """Stream status messages, then the lifespan estimate, for ``img``.

    This is a generator so the UI can show intermediate output.

    Args:
        img: PIL image from the Gradio input, or ``None`` when the user
            submitted without providing one.

    Yields:
        Three placeholder "Estimating..." strings (random values, 0.5 s
        apart), followed by one string containing the final estimate. For
        ``None`` input a single prompt string is yielded instead.
    """
    if img is None:
        # Nothing to analyse; avoid crashing inside crop_face().
        yield "\nPlease upload or capture a face photo first.\n"
        return

    # Step 2: Fake progress updates
    for _ in range(3):
        rand_val = random.randint(50, 120)
        yield f"\nEstimating... {rand_val} yrs\n"
        time.sleep(0.5)

    # Crop face and run the model on the crop (not on the full frame).
    face = crop_face(img)
    x = imgtransform(face).unsqueeze(0).to(device)

    with torch.no_grad():
        mu, _ = model(x)  # second output (log-variance) is not displayed

    pred_years = mu.item() * lifespan_std + lifespan_mean
    # NOTE(review): the meaning of the +65 offset is not evident from this
    # file -- confirm it is intended before relying on the displayed value.
    final_years = pred_years + 65

    yield f"\nEstimated remaining lifespan:\n{final_years:.1f} years\n"


# ============================
# Gradio UI (Webcam + Upload)
# ============================
# Write rows 1-3 to disk as clickable examples. They are already part of the
# "train[:2000]" subset loaded above, so the dataset is not loaded again.
example_files = []
for example_idx in (1, 2, 3):
    example_path = f"example{example_idx}.jpg"
    ds[example_idx]["image"].save(example_path)
    example_files.append([example_path])

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(
        height=500,
        type="pil",
        sources=["upload", "webcam"],
        label="Face Photo",
    ),
    outputs=gr.HTML(
        label="Prediction",
        elem_id="pred-box",
        value="\nPredict remaining lifespan:\n??? years\n",
    ),
    live=False,
    title="Lifespan Predictor (Demo)",
    description="Upload or take a face photo. Model predicts remaining lifespan (years).",
    examples=example_files,  # 👈 files are generated above at startup
    css="""
    #pred-box {
        /* size & layout to mirror the image area */
        height: 500px; /* match inputs=gr.Image(height=500) */
        display: flex;
        align-items: center;
        justify-content: center;
        text-align: center;
        /* typography */
        font-size: 42px;
        font-weight: 700;
        color: var(--body-text-color);
        /* make it look like a Gradio card */
        background: var(--block-background-fill);
        border: 1px solid var(--border-color);
        border-radius: var(--radius-lg);
        box-shadow: var(--shadow-drop);
        padding: 16px;
    }
    /* tighten the label spacing so it feels like the input card */
    #pred-box + .wrap.svelte-1ipelgc, /* older gradio */
    #pred-box + .container.svelte-1ipelgc, /* fallback */
    #pred-box:has(+ *) { /* future-proof: keep label close */
        margin-top: 0;
    }
    """
)

demo.queue()

if __name__ == "__main__":
    demo.launch()