SiyaYan commited on
Commit
c0d332e
·
verified ·
1 Parent(s): f5f7a9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -2
app.py CHANGED
@@ -8,6 +8,7 @@ import gradio as gr
8
  import numpy as np
9
  from datasets import load_dataset
10
  from main import DinoRegressionHeteroImages
 
11
 
12
  # ============================
13
  # Model definition
@@ -84,15 +85,46 @@ imgtransform = transforms.Compose([
84
  std=[0.229, 0.224, 0.225]),
85
  ])
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # ============================
88
  # Prediction function
89
  # ============================
90
  def predict(img):
 
 
 
91
  x = imgtransform(img).unsqueeze(0).to(device)
92
  with torch.no_grad():
93
  mu, logvar = model(x)
94
  pred_years = mu.item() * lifespan_std + lifespan_mean
95
- return f"Estimated remaining lifespan: {pred_years:.1f} years"
 
96
 
97
  # ============================
98
  # Gradio UI (Webcam + Upload)
@@ -105,7 +137,7 @@ img.save("example.jpg")
105
  demo = gr.Interface(
106
  fn=predict,
107
  inputs=gr.Image(type="pil", sources=["upload", "webcam"], label="Face Photo"),
108
- outputs=gr.Textbox(label="Predicted Lifespan"),
109
  live=False,
110
  title="Lifespan Predictor (Demo)",
111
  description="Upload or take a face photo. Model predicts remaining lifespan (years).",
 
8
  import numpy as np
9
  from datasets import load_dataset
10
  from main import DinoRegressionHeteroImages
11
+ import mediapipe as mp
12
 
13
  # ============================
14
  # Model definition
 
85
  std=[0.229, 0.224, 0.225]),
86
  ])
87
 
88
+
89
+ mp_face_detection = mp.solutions.face_detection
90
+ mp_drawing = mp.solutions.drawing_utils
91
+
92
+ def crop_face(img: Image.Image):
93
+ """Detects the largest face and crops it, fallback to full image."""
94
+ img_rgb = np.array(img.convert("RGB"))
95
+ h, w, _ = img_rgb.shape
96
+
97
+ with mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5) as face_det:
98
+ results = face_det.process(img_rgb)
99
+
100
+ if results.detections:
101
+ # take the first (largest) face
102
+ bbox = results.detections[0].location_data.relative_bounding_box
103
+ x1 = max(int(bbox.xmin * w), 0)
104
+ y1 = max(int(bbox.ymin * h), 0)
105
+ x2 = min(int((bbox.xmin + bbox.width) * w), w)
106
+ y2 = min(int((bbox.ymin + bbox.height) * h), h)
107
+
108
+ face = img.crop((x1, y1, x2, y2))
109
+ return face
110
+
111
+ # fallback: return original if no face found
112
+ return img
113
+
114
+
115
  # ============================
116
  # Prediction function
117
  # ============================
118
  def predict(img):
119
+ # Crop face
120
+ face = crop_face(img)
121
+
122
  x = imgtransform(img).unsqueeze(0).to(device)
123
  with torch.no_grad():
124
  mu, logvar = model(x)
125
  pred_years = mu.item() * lifespan_std + lifespan_mean
126
+
127
+ return f"<div style='font-size:32px; font-weight:bold; color:#2E86AB;'>Estimated remaining lifespan: {pred_years:.1f} years</div>"
128
 
129
  # ============================
130
  # Gradio UI (Webcam + Upload)
 
137
  demo = gr.Interface(
138
  fn=predict,
139
  inputs=gr.Image(type="pil", sources=["upload", "webcam"], label="Face Photo"),
140
+ outputs=gr.HTML(label="Prediction"),
141
  live=False,
142
  title="Lifespan Predictor (Demo)",
143
  description="Upload or take a face photo. Model predicts remaining lifespan (years).",