vijul.shah committed
Commit · 99cd14f
Parent(s): 3733e70

Can download data as csv file and some UI changes

Files changed:
- app.py +25 -216
- app_utils.py +370 -5
- requirements.txt +1 -1
app.py CHANGED
@@ -1,251 +1,60 @@
-import os
 import sys
-import tempfile
 import os.path as osp
-from PIL import Image
-from io import BytesIO
-import numpy as np
-import pandas as pd
 import streamlit as st
-from PIL import ImageOps
-from matplotlib import pyplot as plt
-import altair as alt
 
 root_path = osp.abspath(osp.join(__file__, osp.pardir))
 sys.path.append(root_path)
 
 from registry_utils import import_registered_modules
 from app_utils import (
-    extract_frames,
     is_image,
     is_video,
-
-
-
-
-
+    process_image_and_vizualize_data,
+    process_video_and_visualize_data,
+    set_frames_processed_count_placeholder,
+    set_input_image_on_ui,
+    set_input_video_on_ui,
+    set_page_info,
+    set_sidebar_info,
 )
 
 import_registered_modules()
 
-CAM_METHODS = ["CAM"]
-TV_MODELS = ["ResNet18", "ResNet50"]
-SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
-UPSCALE = [2, 4]
-UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
-LABEL_MAP = ["left_pupil", "right_pupil"]
-
 
 def main():
-
-
-    cols = st.columns((1, 1))
-    cols[0].header("Input")
-    cols[-1].header("Prediction")
-
-    st.sidebar.title("Upload Face or Eye")
-    uploaded_file = st.sidebar.file_uploader(
-        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
-    )
+    cols = set_page_info()
+    uploaded_file, pupil_selection, tv_model, blink_detection = set_sidebar_info()
 
     if uploaded_file is not None:
         file_extension = uploaded_file.name.split(".")[-1]
+        st.session_state["file_extension"] = file_extension
 
         if is_image(file_extension):
-            input_img =
-
-            input_img = ImageOps.exif_transpose(input_img)
-            input_img = resize_frame(input_img, max_width=640, max_height=480)
-            input_img = resize_frame(input_img, max_width=640, max_height=480)
-            cols[0].image(input_img, use_column_width=True)
-            st.session_state.total_frames = 1
-
+            input_img = set_input_image_on_ui(uploaded_file, cols)
+            st.session_state["input_img"] = input_img
         elif is_video(file_extension):
-
-
-            video_path =
-            video_frames = extract_frames(video_path)
-            cols[0].video(video_path)
-            st.session_state.total_frames = len(video_frames)
-
-        st.session_state.current_frame = 0
-        st.session_state.frame_placeholder = cols[0].empty()
-        txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
-        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
-
-    st.sidebar.title("Setup")
-    pupil_selection = st.sidebar.selectbox(
-        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
-    )
-    tv_model = st.sidebar.selectbox("Classification model", ["ResNet18", "ResNet50"], help="Supported Models")
+            video_frames, video_path = set_input_video_on_ui(uploaded_file, cols)
+            st.session_state["video_frames"] = video_frames
+            st.session_state["video_path"] = video_path
 
-
-
-    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)
+        set_frames_processed_count_placeholder(cols)
 
     if st.sidebar.button("Predict Diameter & Compute CAM"):
        if uploaded_file is None:
            st.sidebar.error("Please upload an image or video")
        else:
            with st.spinner("Analyzing..."):
-
-
-
-
-
-
-
-
-                        blink_detection=blink_detection,
+                if is_image(st.session_state.get("file_extension")):
+                    input_img = st.session_state.get("input_img")
+                    process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection)
+                elif is_video(st.session_state.get("file_extension")):
+                    video_frames = st.session_state.get("video_frames")
+                    video_path = st.session_state.get("video_path")
+                    process_video_and_visualize_data(
+                        cols, video_frames, tv_model, pupil_selection, blink_detection, video_path
                    )
-                    # for ff in face_frames:
-                    # if ff["has_face"]:
-                    # cols[1].image(face_frames[0]["img"], use_column_width=True)
-
-                    input_frames_keys = input_frames.keys()
-                    video_cols = cols[1].columns(len(input_frames_keys))
-                    for i, eye_type in enumerate(input_frames_keys):
-                        video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
-
-                    output_frames_keys = output_frames.keys()
-                    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
-                    for i, eye_type in enumerate(output_frames_keys):
-                        height, width, c = output_frames[eye_type][0].shape
-                        video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
-
-                        frame = np.zeros((height, width, c), dtype=np.uint8)
-                        text = f"{predicted_diameters[eye_type][0]:.2f}"
-                        frame = overlay_text_on_frame(frame, text)
-                        video_cols[i].image(frame, use_column_width=True)
-
-                elif is_video(file_extension):
-                    output_video_path = f"{root_path}/tmp.webm"
-                    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
-                        cols,
-                        video_frames,
-                        tv_model,
-                        pupil_selection,
-                        output_video_path,
-                        cam_method=CAM_METHODS[-1],
-                        blink_detection=blink_detection,
-                    )
-                    os.remove(video_path)
-
-                    num_columns = len(predicted_diameters)
-
-                    # Create a layout for the charts
-                    cols = st.columns(num_columns)
-
-                    # colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
-                    colors = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray
-
-                    # Iterate through categories and assign charts to columns
-                    for i, (category, values) in enumerate(predicted_diameters.items()):
-                        with cols[i]:  # Directly use the column index
-                            # st.subheader(category)  # Add a subheader for the category
-
-                            # Convert values to numeric, replacing non-numeric values with None
-                            values = [convert_diameter(value) for value in values]
-
-                            # Create a DataFrame from the values for Altair
-                            df = pd.DataFrame(values, columns=[category])
-                            df["Frame"] = range(1, len(values) + 1)  # Create a frame column starting from 1
-
-                            # Get the min and max values for y-axis limits, ignoring None
-                            min_value = min(filter(lambda x: x is not None, values), default=None)
-                            max_value = max(filter(lambda x: x is not None, values), default=None)
-
-                            # Create an Altair chart with y-axis limits
-                            line_chart = (
-                                alt.Chart(df)
-                                .mark_line(color=colors[i])
-                                .encode(
-                                    x=alt.X("Frame:Q", title="Frame Number"),
-                                    y=alt.Y(
-                                        f"{category}:Q",
-                                        title="Diameter",
-                                        scale=alt.Scale(domain=[min_value, max_value]),
-                                    ),
-                                    tooltip=[
-                                        "Frame",
-                                        alt.Tooltip(f"{category}:Q", title="Diameter"),
-                                    ],
-                                )
-                                # .properties(title=f"{category} - Predicted Diameters")
-                                # .configure_axis(grid=True)
-                            )
-                            points_chart = line_chart.mark_point(color=colors[i], filled=True)
-
-                            final_chart = (
-                                line_chart.properties(title=f"{category} - Predicted Diameters") + points_chart
-                            ).interactive()
-
-                            final_chart = final_chart.configure_axis(grid=True)
-
-                            # Display the Altair chart
-                            st.altair_chart(final_chart, use_container_width=True)
-
-                    if eyes_ratios is not None and len(eyes_ratios) > 0:
-                        df = pd.DataFrame(eyes_ratios, columns=["EAR"])
-                        df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1
-
-                        # Create an Altair chart for eyes_ratios
-                        line_chart = (
-                            alt.Chart(df)
-                            .mark_line(color=colors[-1])  # Set color of the line
-                            .encode(
-                                x=alt.X("Frame:Q", title="Frame Number"),
-                                y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
-                                tooltip=["Frame", "EAR"],
-                            )
-                            # .properties(title="Eyes Aspect Ratios (EARs)")
-                            # .configure_axis(grid=True)
-                        )
-                        points_chart = line_chart.mark_point(color=colors[-1], filled=True)
-
-                        # Create a horizontal rule at y=0.22
-                        line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
-
-                        line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")
-
-                        # Add text annotations for the lines
-                        text1 = (
-                            alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
-                            .mark_text(align="left", dx=100, dy=9, color="red", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        text2 = (
-                            alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
-                            .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        # Add gray area text for the region between red and green lines
-                        gray_area_text = (
-                            alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
-                            .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        # Combine all elements: line chart, points, rules, and text annotations
-                        final_chart = (
-                            line_chart.properties(title="Eyes Aspect Ratios (EARs)")
-                            + points_chart
-                            + line1
-                            + line2
-                            + text1
-                            + text2
-                            + gray_area_text
-                        ).interactive()
-
-                        # Configure axis properties at the chart level
-                        final_chart = final_chart.configure_axis(grid=True)
-
-                        # Display the Altair chart
-                        # st.subheader("Eyes Aspect Ratios (EARs)")
-                        st.altair_chart(final_chart, use_container_width=True)
 
 
 if __name__ == "__main__":
     main()
+# run: streamlit run app.py --server.enableXsrfProtection false
app_utils.py CHANGED
@@ -1,16 +1,21 @@
 import base64
 from io import BytesIO
+import io
 import os
 import sys
 import cv2
 from matplotlib import pyplot as plt
 import numpy as np
+import pandas as pd
 import streamlit as st
 import torch
 import tempfile
 from PIL import Image
 from torchvision.transforms.functional import to_pil_image
 from torchvision import transforms
+from PIL import ImageOps
+import altair as alt
+
 
 from torchcam.methods import CAM
 from torchcam import methods as torchcam_methods
@@ -23,6 +28,10 @@ sys.path.append(root_path)
 from preprocessing.dataset_creation import EyeDentityDatasetCreation
 from utils import get_model
 
+CAM_METHODS = ["CAM"]
+# colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
+colors = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray
+
 
 @torch.no_grad()
 def load_model(model_configs, device="cpu"):
@@ -234,12 +243,12 @@ def process_frames(
    )

    preprocess_steps = [
-        transforms.ToTensor(),
        transforms.Resize(
            [32, 64],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
+        transforms.ToTensor(),
    ]
    preprocess_function = transforms.Compose(preprocess_steps)

@@ -368,7 +377,11 @@ def process_frames(

    combined_frame = np.vstack((input_img_np, output_img_np, frame))

-
+    img_base64 = pil_image_to_base64(Image.fromarray(combined_frame))
+    image_html = f'<div style="width: {str(50*len(selected_eyes))}%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+    video_placeholders[eye_type].markdown(image_html, unsafe_allow_html=True)
+
+    # video_placeholders[eye_type].image(combined_frame, use_column_width=True)

    st.session_state.current_frame = idx + 1
    txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
@@ -383,9 +396,9 @@ def process_frames(


 # Function to display video with autoplay and loop
-def display_video_with_autoplay(video_col, video_path):
+def display_video_with_autoplay(video_col, video_path, width):
    video_html = f"""
-        <video width="
+        <video width="{str(width)}%" height="auto" autoplay loop muted>
        <source src="data:video/mp4;base64,{video_path}" type="video/mp4">
        </video>
    """
@@ -458,7 +471,359 @@ def combine_and_show_frames(input_frames, cam_frames, pred_diameters_frames, out
    video_base64 = base64.b64encode(video_bytes).decode("utf-8")

    # Display the combined video
-    display_video_with_autoplay(video_cols[eye_type], video_base64)
+    display_video_with_autoplay(video_cols[eye_type], video_base64, width=len(video_cols) * 50)

    # Clean up
    os.remove(output_path)
+
+
+def set_input_image_on_ui(uploaded_file, cols):
+    input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
+    # NOTE: images taken with phone camera has an EXIF data field which often rotates images taken with the phone in a tilted position. PIL has a utility function that removes this data and ‘uprights’ the image.
+    input_img = ImageOps.exif_transpose(input_img)
+    input_img = resize_frame(input_img, max_width=640, max_height=480)
+    input_img = resize_frame(input_img, max_width=640, max_height=480)
+    cols[0].image(input_img, use_column_width=True)
+    st.session_state.total_frames = 1
+    return input_img
+
+
+def set_input_video_on_ui(uploaded_file, cols):
+    tfile = tempfile.NamedTemporaryFile(delete=False)
+    tfile.write(uploaded_file.read())
+    video_path = tfile.name
+    video_frames = extract_frames(video_path)
+    cols[0].video(video_path)
+    st.session_state.total_frames = len(video_frames)
+    return video_frames, video_path
+
+
+def set_frames_processed_count_placeholder(cols):
+    st.session_state.current_frame = 0
+    st.session_state.frame_placeholder = cols[0].empty()
+    txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
+    st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
+
+
+def set_page_info():
+    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
+    st.title("EyeDentify Playground")
+    cols = st.columns((1, 1))
+    cols[0].header("Input")
+    cols[-1].header("Prediction")
+    return cols
+
+
+def set_sidebar_info():
+    LABEL_MAP = ["left_pupil", "right_pupil"]
+    TV_MODELS = ["ResNet18", "ResNet50"]
+
+    st.sidebar.title("Upload Face or Eye")
+    uploaded_file = st.sidebar.file_uploader(
+        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
+    )
+    st.sidebar.title("Setup")
+    pupil_selection = st.sidebar.selectbox(
+        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
+    )
+    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
+
+    blink_detection = st.sidebar.checkbox("Detect Blinks")
+
+    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)
+
+    return (uploaded_file, pupil_selection, tv_model, blink_detection)
+
+
+def pil_image_to_base64(img):
+    """Convert a PIL Image to a base64 encoded string."""
+    buffered = io.BytesIO()
+    img.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+    return img_str
+
+
+def process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection):
+    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
+        cols,
+        [input_img],
+        tv_model,
+        pupil_selection,
+        cam_method=CAM_METHODS[-1],
+        blink_detection=blink_detection,
+    )
+    # for ff in face_frames:
+    # if ff["has_face"]:
+    # cols[1].image(face_frames[0]["img"], use_column_width=True)
+
+    input_frames_keys = input_frames.keys()
+    video_cols = cols[1].columns(len(input_frames_keys))
+
+    for i, eye_type in enumerate(input_frames_keys):
+        # Check the pupil_selection and set the width accordingly
+        if pupil_selection == "both":
+            video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
+        else:
+            img_base64 = pil_image_to_base64(Image.fromarray(input_frames[eye_type][-1]))
+            image_html = f'<div style="width: 50%; margin-bottom: 1.2%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+
+    output_frames_keys = output_frames.keys()
+    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
+    for i, eye_type in enumerate(output_frames_keys):
+        height, width, c = output_frames[eye_type][0].shape
+        frame = np.zeros((height, width, c), dtype=np.uint8)
+        text = f"{predicted_diameters[eye_type][0]:.2f}"
+        frame = overlay_text_on_frame(frame, text)
+
+        if pupil_selection == "both":
+            video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
+            video_cols[i].image(frame, use_column_width=True)
+        else:
+            img_base64 = pil_image_to_base64(Image.fromarray(output_frames[eye_type][-1]))
+            image_html = f'<div style="width: 50%; margin-top: 1.2%; margin-bottom: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+            img_base64 = pil_image_to_base64(Image.fromarray(frame))
+            image_html = f'<div style="width: 50%; margin-top: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+
+    return None
+
+
+def plot_ears(eyes_ratios, eyes_df):
+    eyes_df["EAR"] = eyes_ratios
+    df = pd.DataFrame(eyes_ratios, columns=["EAR"])
+    df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1
+
+    # Create an Altair chart for eyes_ratios
+    line_chart = (
+        alt.Chart(df)
+        .mark_line(color=colors[-1])  # Set color of the line
+        .encode(
+            x=alt.X("Frame:Q", title="Frame Number"),
+            y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
+            tooltip=["Frame", "EAR"],
+        )
+        # .properties(title="Eyes Aspect Ratios (EARs)")
+        # .configure_axis(grid=True)
+    )
+    points_chart = line_chart.mark_point(color=colors[-1], filled=True)
+
+    # Create a horizontal rule at y=0.22
+    line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
+
+    line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")
+
+    # Add text annotations for the lines
+    text1 = (
+        alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
+        .mark_text(align="left", dx=100, dy=9, color="red", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    text2 = (
+        alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
+        .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    # Add gray area text for the region between red and green lines
+    gray_area_text = (
+        alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
+        .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    # Combine all elements: line chart, points, rules, and text annotations
+    final_chart = (
+        line_chart.properties(title="Eyes Aspect Ratios (EARs)")
+        + points_chart
+        + line1
+        + line2
+        + text1
+        + text2
+        + gray_area_text
+    ).interactive()
+
+    # Configure axis properties at the chart level
+    final_chart = final_chart.configure_axis(grid=True)
+
+    # Display the Altair chart
+    # st.subheader("Eyes Aspect Ratios (EARs)")
+    st.altair_chart(final_chart, use_container_width=True)
+    return eyes_df
+
+
+def plot_individual_charts(predicted_diameters, cols):
+    # Iterate through categories and assign charts to columns
+    for i, (category, values) in enumerate(predicted_diameters.items()):
+        with cols[i]:  # Directly use the column index
+            # st.subheader(category)  # Add a subheader for the category
+            if "left" in category:
+                selected_color = colors[0]
+            elif "right" in category:
+                selected_color = colors[1]
+            else:
+                selected_color = colors[i]
+
+            # Convert values to numeric, replacing non-numeric values with None
+            values = [convert_diameter(value) for value in values]
+
+            if "left" in category:
+                category_name = "Left Pupil Diameter"
+            else:
+                category_name = "Right Pupil Diameter"
+
+            # Create a DataFrame from the values for Altair
+            df = pd.DataFrame(
+                {
+                    "Frame": range(1, len(values) + 1),
+                    category_name: values,
+                }
+            )
+
+            # Get the min and max values for y-axis limits, ignoring None
+            min_value = min(filter(lambda x: x is not None, values), default=None)
+            max_value = max(filter(lambda x: x is not None, values), default=None)
+
+            # Create an Altair chart with y-axis limits
+            line_chart = (
+                alt.Chart(df)
+                .mark_line(color=selected_color)
+                .encode(
+                    x=alt.X("Frame:Q", title="Frame Number"),
+                    y=alt.Y(
+                        f"{category_name}:Q",
+                        title="Diameter",
+                        scale=alt.Scale(domain=[min_value, max_value]),
+                    ),
+                    tooltip=[
+                        "Frame",
+                        alt.Tooltip(f"{category_name}:Q", title="Diameter"),
+                    ],
+                )
+                # .properties(title=f"{category} - Predicted Diameters")
+                # .configure_axis(grid=True)
+            )
+            points_chart = line_chart.mark_point(color=selected_color, filled=True)
+
+            final_chart = (
+                line_chart.properties(
+                    title=f"{'Left Pupil' if 'left' in category else 'Right Pupil'} - Predicted Diameters"
+                )
+                + points_chart
+            ).interactive()
+
+            final_chart = final_chart.configure_axis(grid=True)
+
+            # Display the Altair chart
+            st.altair_chart(final_chart, use_container_width=True)
+    return df
+
+
+def plot_combined_charts(predicted_diameters):
+    all_min_values = []
+    all_max_values = []
+
+    # Create an empty DataFrame to store combined data for plotting
+    combined_df = pd.DataFrame()
+
+    # Iterate through categories and collect data
+    for category, values in predicted_diameters.items():
+        # Convert values to numeric, replacing non-numeric values with None
+        values = [convert_diameter(value) for value in values]
+
+        # Get the min and max values for y-axis limits, ignoring None
+        min_value = min(filter(lambda x: x is not None, values), default=None)
+        max_value = max(filter(lambda x: x is not None, values), default=None)
+
+        all_min_values.append(min_value)
+        all_max_values.append(max_value)
+
+        category = "left_pupil" if "left" in category else "right_pupil"
+
+        # Create a DataFrame from the values
+        df = pd.DataFrame(
+            {
+                "Diameter": values,
+                "Frame": range(1, len(values) + 1),  # Create a frame column starting from 1
+                "Category": category,  # Add a column to specify the category
+            }
+        )
+
+        # Append to combined DataFrame
+        combined_df = pd.concat([combined_df, df], ignore_index=True)
+
+    combined_chart = (
+        alt.Chart(combined_df)
+        .mark_line()
+        .encode(
+            x=alt.X("Frame:Q", title="Frame Number"),
+            y=alt.Y(
+                "Diameter:Q",
+                title="Diameter",
+                scale=alt.Scale(domain=[min(all_min_values), max(all_max_values)]),
+            ),
+            color=alt.Color("Category:N", scale=alt.Scale(range=colors), title="Pupil Type"),
+            tooltip=["Frame", "Diameter:Q", "Category:N"],
+        )
+    )
+    points_chart = combined_chart.mark_point(filled=True)
+
+    final_chart = (combined_chart.properties(title="Predicted Diameters") + points_chart).interactive()
+
+    final_chart = final_chart.configure_axis(grid=True)
+
+    # Display the combined chart
+    st.altair_chart(final_chart, use_container_width=True)
+
+    # --------------------------------------------
+    # Convert to a DataFrame
+    left_pupil_values = [convert_diameter(value) for value in predicted_diameters["left_eye"]]
+    right_pupil_values = [convert_diameter(value) for value in predicted_diameters["right_eye"]]
+
+    df = pd.DataFrame(
+        {
+            "Frame": range(1, len(left_pupil_values) + 1),
+            "Left Pupil Diameter": left_pupil_values,
+            "Right Pupil Diameter": right_pupil_values,
+        }
+    )
+
+    # Calculate the difference between left and right pupil diameters
+    df["Difference Value"] = df["Left Pupil Diameter"] - df["Right Pupil Diameter"]
+
+    # Determine the status of the difference
+    df["Difference Status"] = df.apply(
+        lambda row: "L>R" if row["Left Pupil Diameter"] > row["Right Pupil Diameter"] else "L<R",
+        axis=1,
+    )
+
+    return df
+
+
+def process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path):
+    output_video_path = f"{root_path}/tmp.webm"
+    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
+        cols,
+        video_frames,
+        tv_model,
+        pupil_selection,
+        output_video_path,
+        cam_method=CAM_METHODS[-1],
+        blink_detection=blink_detection,
+    )
+    os.remove(video_path)
+
+    num_columns = len(predicted_diameters)
+    cols = st.columns(num_columns)
+
+    if num_columns == 2:
+        df = plot_combined_charts(predicted_diameters)
+    else:
+        df = plot_individual_charts(predicted_diameters, cols)
+
+    if eyes_ratios is not None and len(eyes_ratios) > 0:
+        df = plot_ears(eyes_ratios, df)
+
+    st.dataframe(df, hide_index=True, use_container_width=True)
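
Note: the CSV export named in the commit message appears to come from the new st.dataframe(df, hide_index=True, use_container_width=True) call at the end of process_video_and_visualize_data; recent Streamlit releases (including the 1.38.0 pin added below in requirements.txt) render a toolbar on dataframes with a built-in "Download as CSV" action. A minimal sketch of an explicit alternative, in case a dedicated button were preferred; the helper name and file name here are illustrative and not part of this commit:

import pandas as pd
import streamlit as st


def offer_csv_download(df: pd.DataFrame, file_name: str = "pupil_data.csv") -> None:
    # Serialize the already-assembled results dataframe and expose it as a download.
    csv_bytes = df.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="Download data as CSV",
        data=csv_bytes,
        file_name=file_name,
        mime="text/csv",
    )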
requirements.txt CHANGED
@@ -20,7 +20,7 @@ dlib
 einops
 transformers
 gfpgan
-
+streamlit==1.38.0
 mediapipe
 imutils
 scipy