Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,7 @@ import gradio as gr
|
|
| 8 |
import numpy as np
|
| 9 |
from datasets import load_dataset
|
| 10 |
from main import DinoRegressionHeteroImages
|
|
|
|
| 11 |
|
| 12 |
# ============================
|
| 13 |
# Model definition
|
|
@@ -84,15 +85,46 @@ imgtransform = transforms.Compose([
|
|
| 84 |
std=[0.229, 0.224, 0.225]),
|
| 85 |
])
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
# ============================
|
| 88 |
# Prediction function
|
| 89 |
# ============================
|
| 90 |
def predict(img):
|
|
|
|
|
|
|
|
|
|
| 91 |
x = imgtransform(img).unsqueeze(0).to(device)
|
| 92 |
with torch.no_grad():
|
| 93 |
mu, logvar = model(x)
|
| 94 |
pred_years = mu.item() * lifespan_std + lifespan_mean
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
# ============================
|
| 98 |
# Gradio UI (Webcam + Upload)
|
|
@@ -105,7 +137,7 @@ img.save("example.jpg")
|
|
| 105 |
demo = gr.Interface(
|
| 106 |
fn=predict,
|
| 107 |
inputs=gr.Image(type="pil", sources=["upload", "webcam"], label="Face Photo"),
|
| 108 |
-
outputs=gr.
|
| 109 |
live=False,
|
| 110 |
title="Lifespan Predictor (Demo)",
|
| 111 |
description="Upload or take a face photo. Model predicts remaining lifespan (years).",
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
from datasets import load_dataset
|
| 10 |
from main import DinoRegressionHeteroImages
|
| 11 |
+
import mediapipe as mp
|
| 12 |
|
| 13 |
# ============================
|
| 14 |
# Model definition
|
|
|
|
| 85 |
std=[0.229, 0.224, 0.225]),
|
| 86 |
])
|
| 87 |
|
| 88 |
+
|
| 89 |
+
mp_face_detection = mp.solutions.face_detection
|
| 90 |
+
mp_drawing = mp.solutions.drawing_utils
|
| 91 |
+
|
| 92 |
+
def crop_face(img: Image.Image):
|
| 93 |
+
"""Detects the largest face and crops it, fallback to full image."""
|
| 94 |
+
img_rgb = np.array(img.convert("RGB"))
|
| 95 |
+
h, w, _ = img_rgb.shape
|
| 96 |
+
|
| 97 |
+
with mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as face_det:
|
| 98 |
+
results = face_det.process(img_rgb)
|
| 99 |
+
|
| 100 |
+
if results.detections:
|
| 101 |
+
# take the first (largest) face
|
| 102 |
+
bbox = results.detections[0].location_data.relative_bounding_box
|
| 103 |
+
x1 = max(int(bbox.xmin * w), 0)
|
| 104 |
+
y1 = max(int(bbox.ymin * h), 0)
|
| 105 |
+
x2 = min(int((bbox.xmin + bbox.width) * w), w)
|
| 106 |
+
y2 = min(int((bbox.ymin + bbox.height) * h), h)
|
| 107 |
+
|
| 108 |
+
face = img.crop((x1, y1, x2, y2))
|
| 109 |
+
return face
|
| 110 |
+
|
| 111 |
+
# fallback: return original if no face found
|
| 112 |
+
return img
|
| 113 |
+
|
| 114 |
+
|
| 115 |
# ============================
|
| 116 |
# Prediction function
|
| 117 |
# ============================
|
| 118 |
def predict(img):
|
| 119 |
+
# Crop face
|
| 120 |
+
face = crop_face(img)
|
| 121 |
+
|
| 122 |
x = imgtransform(img).unsqueeze(0).to(device)
|
| 123 |
with torch.no_grad():
|
| 124 |
mu, logvar = model(x)
|
| 125 |
pred_years = mu.item() * lifespan_std + lifespan_mean
|
| 126 |
+
|
| 127 |
+
return f"<div style='font-size:32px; font-weight:bold; color:#2E86AB;'>Estimated remaining lifespan: {pred_years:.1f} years</div>"
|
| 128 |
|
| 129 |
# ============================
|
| 130 |
# Gradio UI (Webcam + Upload)
|
|
|
|
| 137 |
demo = gr.Interface(
|
| 138 |
fn=predict,
|
| 139 |
inputs=gr.Image(type="pil", sources=["upload", "webcam"], label="Face Photo"),
|
| 140 |
+
outputs=gr.HTML(label="Prediction"),
|
| 141 |
live=False,
|
| 142 |
title="Lifespan Predictor (Demo)",
|
| 143 |
description="Upload or take a face photo. Model predicts remaining lifespan (years).",
|