Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.hub | |
| from huggingface_hub import hf_hub_download | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import numpy as np | |
| from datasets import load_dataset | |
| from main import DinoRegressionHeteroImages | |
| import mediapipe as mp | |
| import time, random | |
| # ============================ | |
| # Model definition | |
| # ============================ | |
| # class DinoRegressionHeteroImages(nn.Module): | |
| # def __init__(self, dino_model, hidden_dim=128, dropout=0.1, dino_dim=1024): | |
| # super().__init__() | |
| # self.dino = dino_model | |
| # for p in self.dino.parameters(): | |
| # p.requires_grad = False | |
| # self.embedding_to_hidden = nn.Linear(dino_dim, hidden_dim) | |
| # self.leaky_relu = nn.LeakyReLU() | |
| # self.dropout = nn.Dropout(dropout) | |
| # self.hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim) | |
| # self.out_mu = nn.Linear(hidden_dim, 1) | |
| # self.out_logvar = nn.Linear(hidden_dim, 1) | |
| # def forward(self, x): | |
| # h = self.dino(x) | |
| # h = self.embedding_to_hidden(h) | |
| # h = self.leaky_relu(h) | |
| # h = self.dropout(h) | |
| # h = self.hidden_to_hidden(h) | |
| # h = self.leaky_relu(h) | |
| # mu = self.out_mu(h).squeeze(1) | |
| # logvar = self.out_logvar(h).squeeze(1) | |
| # logvar = torch.clamp(logvar, -10.0, 3.0) | |
| # return mu, logvar | |
# ============================
# Load model checkpoint
# ============================
repo_id = "SiyaYan/lifespan-predictor"  # change this to load a different checkpoint repo
filename = "dino_finetuned_faces_l1_1024_best.pth"
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Auto-select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

# Load the DINOv2 backbone ("dinov2_vitl14_reg") from torch.hub.
dino_backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg").to(device)

# Build the regression head around the backbone; hyperparameters must match the checkpoint.
model = DinoRegressionHeteroImages(
    dino_backbone,
    hidden_dim=128,
    dropout=0.01,
    dino_dim=1024,
).to(device)

# Load weights. The checkpoint is a downloaded file and only a state dict (tensors in a
# mapping) is needed, so restrict unpickling with weights_only=True instead of allowing
# arbitrary pickled objects to execute code on load.
model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))
model.eval()
# ============================
# Dataset stats (denormalization)
# ============================
# Mean and (population, ddof=0) standard deviation of the "remaining_lifespan" column over
# the first 2000 training rows. They are used to map the model's output back to years.
# NOTE(review): presumably these must match the normalization used at training time — confirm.
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train[:2000]")
remaining = np.array(ds["remaining_lifespan"], dtype=np.float32)
lifespan_mean, lifespan_std = float(np.mean(remaining)), float(np.std(remaining))
# ============================
# Preprocessing
# ============================
# Convert to RGB *before* resizing. Pillow always resamples "P" (palette) and "1" mode
# images with NEAREST, so resizing first gives such uploads a blockier input than RGB
# uploads get. Converting first also guarantees a 3-channel tensor for Normalize.
imgtransform = transforms.Compose([
    transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils  # not used by crop_face


def crop_face(img: Image.Image):
    """Crop the largest detected face from *img*, falling back to the full image.

    Runs MediaPipe face detection (``model_selection=1``, confidence threshold 0.5)
    on an RGB copy of the image. Among all detections, the one whose relative
    bounding box has the largest area is cropped from the original *img*.

    Args:
        img: Input PIL image, in any mode; detection runs on an RGB conversion.

    Returns:
        The cropped face as a PIL image. Returns *img* unchanged when no face is
        detected, or when the chosen box has no overlap with the image.
    """
    img_rgb = np.array(img.convert("RGB"))
    h, w, _ = img_rgb.shape
    with mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as face_det:
        results = face_det.process(img_rgb)

    # Fallback: return the original image if no face was found.
    if not results.detections:
        return img

    def _area(det):
        box = det.location_data.relative_bounding_box
        return box.width * box.height

    # MediaPipe does not document any ordering by size, so choose the largest box
    # explicitly rather than trusting detections[0].
    bbox = max(results.detections, key=_area).location_data.relative_bounding_box

    # Relative coordinates may fall outside [0, 1] for faces at the frame edge;
    # clamp every edge into the image bounds.
    x1 = min(max(int(bbox.xmin * w), 0), w)
    y1 = min(max(int(bbox.ymin * h), 0), h)
    x2 = min(max(int((bbox.xmin + bbox.width) * w), 0), w)
    y2 = min(max(int((bbox.ymin + bbox.height) * h), 0), h)

    # An empty or inverted box would make crop() fail or produce a 0-pixel image.
    if x2 <= x1 or y2 <= y1:
        return img
    return img.crop((x1, y1, x2, y2))
# ============================
# Prediction function
# ============================
def predict(img):
    """Stream HTML snippets estimating remaining lifespan for one face photo.

    This is a generator used as the Gradio callback. It first yields three
    placeholder "Estimating..." frames, half a second apart, containing random
    numbers for UI effect only. It then yields the model's estimate.

    Args:
        img: PIL image from the Gradio input, or None when nothing was provided.

    Yields:
        HTML strings rendered by the ``gr.HTML`` output component.
    """
    if img is None:
        # Gradio passes None when the user submits without an image; crop_face
        # would raise on it, so tell the user instead of crashing the request.
        yield "<div style='font-size:36px; font-weight:bold; color:#888; text-align:center;'>Please upload or take a face photo.</div>"
        return

    # Fake progress updates (random values; not model output).
    for _ in range(3):
        rand_val = random.randint(50, 120)
        yield f"<div style='font-size:36px; font-weight:bold; color:#888; text-align:center;'>Estimating... {rand_val} yrs</div>"
        time.sleep(0.5)

    # Feed the detected face crop (or the whole image if none was found) to the model.
    face = crop_face(img)
    x = imgtransform(face).unsqueeze(0).to(device)
    with torch.no_grad():
        mu, _logvar = model(x)  # the second output (logvar) is not displayed

    # De-normalize with the dataset statistics computed at startup.
    pred_years = mu.item() * lifespan_std + lifespan_mean
    # NOTE(review): a fixed 65-year offset is added, yet the label below says "remaining
    # lifespan"; the reason for this offset is not visible here — confirm intent.
    final_years = pred_years + 65
    yield f"<div style='font-size:42px; font-weight:bold; color:#2E86AB; text-align:center;'>Estimated remaining lifespan: <br> {final_years:.1f} years</div>"
# ============================
# Gradio UI (Webcam + Upload)
# ============================
# Write dataset rows 1-3 to example1.jpg..example3.jpg so the Interface can offer them
# as clickable examples. Only the first four rows are needed.
ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train[:4]")
for example_idx in (1, 2, 3):
    # JPEG cannot store alpha or palette modes, so normalize to RGB before saving.
    ds[example_idx]["image"].convert("RGB").save(f"example{example_idx}.jpg")
# Initial content of the prediction panel, shown before the first submission.
_PLACEHOLDER_HTML = (
    "<div style='font-size:36px; font-weight:bold; color:#2E86AB; text-align:center;'>"
    "Predict remaining lifespan: <br> ??? years</div>"
)

# Styling for the HTML output so it looks like a card the same height as the image input.
_PRED_BOX_CSS = """
#pred-box {
  /* size & layout to mirror the image area */
  height: 500px; /* match inputs=gr.Image(height=500) */
  display: flex;
  align-items: center;
  justify-content: center;
  text-align: center;
  /* typography */
  font-size: 42px;
  font-weight: 700;
  color: var(--body-text-color);
  /* make it look like a Gradio card */
  background: var(--block-background-fill);
  border: 1px solid var(--border-color);
  border-radius: var(--radius-lg);
  box-shadow: var(--shadow-drop);
  padding: 16px;
}
/* tighten the label spacing so it feels like the input card */
#pred-box + .wrap.svelte-1ipelgc, /* older gradio */
#pred-box + .container.svelte-1ipelgc, /* fallback */
#pred-box:has(+ *) { /* future-proof: keep label close */
  margin-top: 0;
}
"""

_face_input = gr.Image(
    height=500,
    type="pil",
    sources=["upload", "webcam"],
    label="Face Photo",
)
_prediction_output = gr.HTML(
    label="Prediction",
    elem_id="pred-box",
    value=_PLACEHOLDER_HTML,
)

# The example files are written to disk from the dataset when the app starts.
demo = gr.Interface(
    fn=predict,
    inputs=_face_input,
    outputs=_prediction_output,
    live=False,
    title="Lifespan Predictor (Demo)",
    description="Upload or take a face photo. Model predicts remaining lifespan (years).",
    examples=[[f"example{n}.jpg"] for n in (1, 2, 3)],
    css=_PRED_BOX_CSS,
)
demo.queue()

if __name__ == "__main__":
    demo.launch()