import random
import time

import torch
import torch.nn as nn
import torch.hub
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
from datasets import load_dataset
import mediapipe as mp

from main import DinoRegressionHeteroImages

# ============================
# Model definition
# ============================
# The class actually used is imported from `main` above. The commented-out
# copy below is kept only as a reference.
# NOTE(review): confirm it still matches the definition in main.py.
#
# class DinoRegressionHeteroImages(nn.Module):
#     def __init__(self, dino_model, hidden_dim=128, dropout=0.1, dino_dim=1024):
#         super().__init__()
#         self.dino = dino_model
#         for p in self.dino.parameters():
#             p.requires_grad = False
#         self.embedding_to_hidden = nn.Linear(dino_dim, hidden_dim)
#         self.leaky_relu = nn.LeakyReLU()
#         self.dropout = nn.Dropout(dropout)
#         self.hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim)
#         self.out_mu = nn.Linear(hidden_dim, 1)
#         self.out_logvar = nn.Linear(hidden_dim, 1)
#
#     def forward(self, x):
#         h = self.dino(x)
#         h = self.embedding_to_hidden(h)
#         h = self.leaky_relu(h)
#         h = self.dropout(h)
#         h = self.hidden_to_hidden(h)
#         h = self.leaky_relu(h)
#         mu = self.out_mu(h).squeeze(1)
#         logvar = self.out_logvar(h).squeeze(1)
#         logvar = torch.clamp(logvar, -10.0, 3.0)
#         return mu, logvar

# ============================
# Load model checkpoint
# ============================
repo_id = "SiyaYan/lifespan-predictor"  # 👈 change this
filename = "dino_finetuned_faces_l1_1024_best.pth"
# Downloads the checkpoint file from the Hub and returns its local path.
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)

# Auto-select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

# Load Dino backbone
dino_backbone = torch.hub.load(
    "facebookresearch/dinov2", "dinov2_vitl14_reg"
).to(device)

# Build model
model = DinoRegressionHeteroImages(
    dino_backbone,
    hidden_dim=128,
    dropout=0.01,
    dino_dim=1024,
).to(device)

# Load weights
# The checkpoint comes from a remote repository and torch.load unpickles it.
# weights_only=True restricts unpickling to tensors and plain containers so a
# tampered file cannot run arbitrary code. The loaded object is passed straight
# to load_state_dict, so it is expected to be a plain state dict.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout

# ============================
# Dataset stats (denormalization)
============================ ds = load_dataset("TristanKE/RemainingLifespanPredictionFaces", split="train[:2000]") remaining = np.array(ds["remaining_lifespan"], dtype=np.float32) lifespan_mean = float(remaining.mean()) lifespan_std = float(remaining.std()) # ============================ # Preprocessing # ============================ imgtransform = transforms.Compose([ transforms.Resize((224, 224)), transforms.Lambda(lambda x: x.convert('RGB')), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) mp_face_detection = mp.solutions.face_detection mp_drawing = mp.solutions.drawing_utils def crop_face(img: Image.Image): """Detects the largest face and crops it, fallback to full image.""" img_rgb = np.array(img.convert("RGB")) h, w, _ = img_rgb.shape with mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as face_det: results = face_det.process(img_rgb) if results.detections: # take the first (largest) face bbox = results.detections[0].location_data.relative_bounding_box x1 = max(int(bbox.xmin * w), 0) y1 = max(int(bbox.ymin * h), 0) x2 = min(int((bbox.xmin + bbox.width) * w), w) y2 = min(int((bbox.ymin + bbox.height) * h), h) face = img.crop((x1, y1, x2, y2)) return face # fallback: return original if no face found return img # ============================ # Prediction function # ============================ def predict(img): # Step 2: Fake progress updates for i in range(3): rand_val = random.randint(50, 120) yield f"