Spaces:
Running
Running
| import os | |
| import sys | |
| import tempfile | |
| import os.path as osp | |
| from PIL import Image | |
| from io import BytesIO | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from PIL import ImageOps | |
| from matplotlib import pyplot as plt | |
| import altair as alt | |
# Directory that contains this file; appended to sys.path so the sibling
# project modules imported below (registry_utils, app_utils) resolve no matter
# which working directory the app is launched from.
root_path = osp.dirname(osp.abspath(__file__))
sys.path.append(root_path)
| from registry_utils import import_registered_modules | |
| from app_utils import ( | |
| extract_frames, | |
| is_image, | |
| is_video, | |
| convert_diameter, | |
| overlay_text_on_frame, | |
| process_frames, | |
| process_video, | |
| resize_frame, | |
| ) | |
| import_registered_modules() | |
# --- Option lists --------------------------------------------------------
# Class-activation-map methods; main() passes the last entry as `cam_method`.
CAM_METHODS = [
    "CAM",
]

# Backbone names offered in the "Classification model" select box.
TV_MODELS = [
    "ResNet18",
    "ResNet50",
]

# Super-resolution settings. Not referenced by the code visible in this file.
SR_METHODS = [
    "GFPGAN",
    "CodeFormer",
    "RealESRGAN",
    "SRResNet",
    "HAT",
]
UPSCALE = [
    2,
    4,
]
UPSCALE_METHODS = [
    "BILINEAR",
    "BICUBIC",
]

# Per-eye labels; main() offers them (plus "both") in the "Pupil Selection" box.
LABEL_MAP = [
    "left_pupil",
    "right_pupil",
]
def _render_image_results(cols, input_frames, output_frames, predicted_diameters):
    """Render the per-eye results of a single-image prediction in ``cols[1]``.

    For every eye key this shows, stacked in one sub-column: the last input
    frame, the last output frame, and a black canvas of the same shape with the
    predicted diameter (two decimals) drawn on it by ``overlay_text_on_frame``.

    Args:
        cols: The two page columns created in ``main``; only ``cols[1]`` is used.
        input_frames: Mapping of eye key -> sequence of input images.
        output_frames: Mapping of eye key -> sequence of output image arrays.
        predicted_diameters: Mapping of eye key -> sequence of diameters.
    """
    if not input_frames:
        # st.columns(0) raises, so there is nothing to lay out.
        return

    eye_cols = cols[1].columns(len(input_frames))
    for i, eye_type in enumerate(input_frames):
        eye_cols[i].image(input_frames[eye_type][-1], use_column_width=True)

    for i, eye_type in enumerate(output_frames):
        height, width, channels = output_frames[eye_type][0].shape
        eye_cols[i].image(output_frames[eye_type][-1], use_column_width=True)

        # Blank canvas matching the output frame, used only to display the number.
        canvas = np.zeros((height, width, channels), dtype=np.uint8)
        text = f"{predicted_diameters[eye_type][0]:.2f}"
        canvas = overlay_text_on_frame(canvas, text)
        eye_cols[i].image(canvas, use_column_width=True)


def _render_diameter_charts(predicted_diameters, colors):
    """Plot one diameter-vs-frame line chart per category, side by side.

    Args:
        predicted_diameters: Mapping of category name -> per-frame diameters.
        colors: Line colours; reused cyclically if there are more categories.
    """
    if not predicted_diameters:
        # st.columns(0) raises, so skip charting when nothing was predicted.
        return

    chart_cols = st.columns(len(predicted_diameters))
    for i, (category, raw_values) in enumerate(predicted_diameters.items()):
        with chart_cols[i]:
            # Per the original code comment, convert_diameter turns non-numeric
            # entries into None so they appear as gaps.
            values = [convert_diameter(value) for value in raw_values]

            df = pd.DataFrame(values, columns=[category])
            df["Frame"] = range(1, len(values) + 1)  # 1-based frame index

            # Clamp the y-axis to the observed range, but only when at least
            # one numeric value exists; otherwise let Altair pick the domain.
            numeric_values = [value for value in values if value is not None]
            y_options = {"title": "Diameter"}
            if numeric_values:
                y_options["scale"] = alt.Scale(domain=[min(numeric_values), max(numeric_values)])

            chart = (
                alt.Chart(df)
                .mark_line(point=True, color=colors[i % len(colors)])
                .encode(
                    x=alt.X("Frame:Q", title="Frame Number"),
                    y=alt.Y(f"{category}:Q", **y_options),
                    tooltip=[
                        alt.Tooltip("Frame:Q", title="Frame Number"),
                        alt.Tooltip(f"{category}:Q", title="Diameter"),
                    ],
                )
                .properties(title=f"{category} - Predicted Diameters")
                .configure_axis(grid=True)
            )
            st.altair_chart(chart, use_container_width=True)


def _render_ear_chart(eyes_ratios, color):
    """Plot the per-frame eye aspect ratios with reference rules at 0.22 and 0.25.

    Args:
        eyes_ratios: Sequence of eye-aspect-ratio values, or None.
        color: Colour of the EAR line.
    """
    # Explicit None/len checks: truthiness is ambiguous if this is an array.
    if eyes_ratios is None or len(eyes_ratios) == 0:
        return

    df = pd.DataFrame(eyes_ratios, columns=["Eyes Aspect Ratio"])
    df["Frame"] = range(1, len(eyes_ratios) + 1)  # 1-based frame index

    line_chart = (
        alt.Chart(df)
        .mark_line(point=True, color=color)
        .encode(
            x=alt.X("Frame:Q", title="Frame Number"),
            y=alt.Y("Eyes Aspect Ratio:Q", title="Eyes Aspect Ratio"),
            tooltip=[
                alt.Tooltip("Frame:Q", title="Frame Number"),
                alt.Tooltip("Eyes Aspect Ratio:Q", title="Eyes Aspect Ratio"),
            ],
        )
    )

    # Horizontal reference lines.
    # NOTE(review): 0.22 / 0.25 look like blink thresholds - confirm.
    line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
    line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="blue").encode(y="y:Q")

    # Axis configuration must be applied to the layered chart, not its parts.
    final_chart = line_chart.properties(title="Eyes Aspect Ratios (EARs)") + line1 + line2
    final_chart = final_chart.configure_axis(grid=True)

    st.subheader("Eyes Aspect Ratios (EARs)")
    st.altair_chart(final_chart, use_container_width=True)


def main():
    """Build the Streamlit page: upload an image/video, configure, predict, plot.

    The left column shows the uploaded input and the right column the
    prediction. Pressing the sidebar button runs ``process_frames`` (image) or
    ``process_video`` (video); for videos the predicted diameters and eye
    aspect ratios are additionally charted per frame.
    """
    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")

    st.markdown(
        """
        <style>
            /* Remove the top margin/padding */
            .block-container {
                padding-top: 0rem;
                padding-bottom: 1rem; /* Adjust this as needed */
            }
        </style>
        """,
        unsafe_allow_html=True,
    )
    st.title("EyeDentify Playground")

    cols = st.columns((1, 1))
    cols[0].header("Input")
    cols[-1].header("Prediction")

    st.sidebar.title("Upload Face or Eye")
    uploaded_file = st.sidebar.file_uploader(
        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
    )

    if uploaded_file is not None:
        file_extension = uploaded_file.name.split(".")[-1]

        if is_image(file_extension):
            input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
            # Phone cameras often store rotation in an EXIF orientation tag
            # instead of rotating the pixels; exif_transpose applies that
            # rotation so the image is upright.
            input_img = ImageOps.exif_transpose(input_img)
            # The original called resize_frame twice with identical arguments;
            # one call is kept (assumes the helper is idempotent - confirm).
            input_img = resize_frame(input_img, max_width=640, max_height=480)
            cols[0].image(input_img, use_column_width=True)
            st.session_state.total_frames = 1

        elif is_video(file_extension):
            # Close the temp file before handing its path on, so the data is
            # flushed to disk and the file can be reopened by name.
            # NOTE(review): a new temp file is created on every rerun and only
            # removed after a prediction - consider cleaning up earlier.
            with tempfile.NamedTemporaryFile(delete=False) as tfile:
                tfile.write(uploaded_file.read())
            video_path = tfile.name

            video_frames = extract_frames(video_path)
            cols[0].video(video_path)

            st.session_state.total_frames = len(video_frames)
            st.session_state.current_frame = 0
            st.session_state.frame_placeholder = cols[0].empty()
            txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
            st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)

    st.sidebar.title("Setup")
    pupil_selection = st.sidebar.selectbox(
        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
    )
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
    blink_detection = st.sidebar.checkbox("Detect Blinks")

    if st.sidebar.button("Predict Diameter & Compute CAM"):
        if uploaded_file is None:
            st.sidebar.error("Please upload an image or video")
        else:
            with st.spinner("Analyzing..."):
                if is_image(file_extension):
                    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
                        cols,
                        [input_img],
                        tv_model,
                        pupil_selection,
                        cam_method=CAM_METHODS[-1],
                        blink_detection=blink_detection,
                    )
                    _render_image_results(cols, input_frames, output_frames, predicted_diameters)

                elif is_video(file_extension):
                    output_video_path = f"{root_path}/tmp.webm"
                    try:
                        input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
                            cols,
                            video_frames,
                            tv_model,
                            pupil_selection,
                            output_video_path,
                            cam_method=CAM_METHODS[-1],
                            blink_detection=blink_detection,
                        )
                    finally:
                        # Always drop the uploaded copy, even if processing fails.
                        os.remove(video_path)

                    # NOTE(review): charts are placed in the video branch; the
                    # source's indentation was lost, confirm against the original.
                    colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
                    _render_diameter_charts(predicted_diameters, colors)
                    _render_ear_chart(eyes_ratios, colors[-1])


if __name__ == "__main__":
    main()