import base64
from io import BytesIO
import io
import os
import sys
import cv2
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
import tempfile
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from torchvision import transforms
from PIL import ImageOps
import altair as alt
import streamlit.components.v1 as components
from torchcam.methods import CAM
from torchcam import methods as torchcam_methods
from torchcam.utils import overlay_mask
import os.path as osp

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model

CAM_METHODS = ["CAM"]
# colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
colors = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray

def load_model(model_configs, device="cpu"):
    """Loads the pre-trained model."""
    model_path = os.path.join(root_path, model_configs["model_path"])
    model_dict = torch.load(model_path, map_location=device)
    model = get_model(model_configs=model_configs)
    model.load_state_dict(model_dict)
    model = model.to(device).eval()
    return model

def extract_frames(video_path):
    """Extracts frames from a video file."""
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    success, image = vidcap.read()
    while success:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        frames.append(image_rgb)
        success, image = vidcap.read()
    vidcap.release()
    return frames
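
# Note: each returned frame is an RGB uint8 numpy array of shape (H, W, 3);
# OpenCV decodes BGR, so the cvtColor call above flips the channel order.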

def resize_frame(image, max_width=640, max_height=480):
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    original_size = image.size
    # Resize the frame similarly to the image resizing logic
    if original_size[0] == original_size[1] and original_size[0] >= 256:
        max_size = (256, 256)
    else:
        max_size = list(original_size)
        if original_size[0] >= max_width:
            max_size[0] = max_width
        elif original_size[0] < 64:
            max_size[0] = 64
        if original_size[1] >= max_height:
            max_size[1] = max_height
        elif original_size[1] < 32:
            max_size[1] = 32
    image.thumbnail(max_size)
    # image = image.resize(max_size)
    return image

def is_image(file_extension):
    """Checks if the file is an image."""
    return file_extension.lower() in ["png", "jpeg", "jpg"]


def is_video(file_extension):
    """Checks if the file is a video."""
    return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]

def get_codec_and_extension(file_format):
    """Return codec and file extension based on the format."""
    if file_format == "mp4":
        return "H264", ".mp4"
    elif file_format == "avi":
        return "MJPG", ".avi"
    elif file_format == "webm":
        return "VP80", ".webm"
    else:
        return "MJPG", ".avi"

def display_results(input_image, cam_frame, pupil_diameter, cols):
    """Displays the input image and overlayed CAM result."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(input_image)
    axs[0].axis("off")
    axs[0].set_title("Input Image")
    axs[1].imshow(cam_frame)
    axs[1].axis("off")
    axs[1].set_title("Overlayed CAM")
    cols[-1].pyplot(fig)
    cols[-1].text(f"Pupil Diameter: {pupil_diameter:.2f} mm")

def preprocess_image(input_img, max_size=(256, 256)):
    """Resizes and preprocesses an image."""
    input_img.thumbnail(max_size)
    preprocess_steps = [
        transforms.ToTensor(),
        transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    ]
    return transforms.Compose(preprocess_steps)(input_img).unsqueeze(0)
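
# For an RGB PIL image this returns a float tensor of shape (1, 3, 32, 64)
# (ToTensor first, then the 32x64 bicubic resize), matching the eye-patch
# input size used by the pupil models below.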

def overlay_text_on_frame(frame, text, position=(16, 20)):
    """Write text on the image frame using OpenCV."""
    return cv2.putText(frame, text, position, cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, cv2.LINE_AA)
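
# Note: cv2.putText draws in place, so the input array is modified and also
# returned; callers here pass freshly created zero-filled frames.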

def get_configs(blink_detection=False):
    upscale = "-"
    upscale_method_or_model = "-"
    if upscale == "-":
        sr_configs = None
    else:
        sr_configs = {
            "method": upscale_method_or_model,
            "params": {"upscale": upscale},
        }
    config_file = {
        "sr_configs": sr_configs,
        "feature_extraction_configs": {
            "blink_detection": blink_detection,
            "upscale": upscale,
            "extraction_library": "mediapipe",
        },
    }
    return config_file
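
# With the defaults above (super-resolution disabled), get_configs(True) returns:
#   {"sr_configs": None,
#    "feature_extraction_configs": {"blink_detection": True, "upscale": "-",
#                                   "extraction_library": "mediapipe"}}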

def setup(cols, pupil_selection, tv_model, output_path):
    left_pupil_model = None
    left_pupil_cam_extractor = None
    right_pupil_model = None
    right_pupil_cam_extractor = None
    output_frames = {}
    input_frames = {}
    predicted_diameters = {}
    pred_diameters_frames = {}

    if pupil_selection == "both":
        selected_eyes = ["left_eye", "right_eye"]
    elif pupil_selection == "left_pupil":
        selected_eyes = ["left_eye"]
    elif pupil_selection == "right_pupil":
        selected_eyes = ["right_eye"]

    for i, eye_type in enumerate(selected_eyes):
        model_configs = {
            "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
            "registered_model_name": tv_model,
            "num_classes": 1,
        }
        if eye_type == "left_eye":
            left_pupil_model = load_model(model_configs)
            left_pupil_cam_extractor = None
            output_frames[eye_type] = []
            input_frames[eye_type] = []
            predicted_diameters[eye_type] = []
            pred_diameters_frames[eye_type] = []
        else:
            right_pupil_model = load_model(model_configs)
            right_pupil_cam_extractor = None
            output_frames[eye_type] = []
            input_frames[eye_type] = []
            predicted_diameters[eye_type] = []
            pred_diameters_frames[eye_type] = []

    video_placeholders = {}
    if output_path:
        video_cols = cols[1].columns(len(input_frames.keys()))
        for i, eye_type in enumerate(list(input_frames.keys())):
            video_placeholders[eye_type] = video_cols[i].empty()

    return (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        pred_diameters_frames,
        video_placeholders,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    )

def process_frames(
    cols, input_imgs, tv_model, pupil_selection, cam_method, output_path=None, codec=None, blink_detection=False
):
    config_file = get_configs(blink_detection)
    face_frames = []

    (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        pred_diameters_frames,
        video_placeholders,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    ) = setup(cols, pupil_selection, tv_model, output_path)

    ds_creation = EyeDentityDatasetCreation(
        feature_extraction_configs=config_file["feature_extraction_configs"],
        sr_configs=config_file["sr_configs"],
    )

    preprocess_steps = [
        transforms.Resize(
            [32, 64],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        transforms.ToTensor(),
    ]
    preprocess_function = transforms.Compose(preprocess_steps)

    eyes_ratios = []
    for idx, input_img in enumerate(input_imgs):
        img = np.array(input_img)
        ds_results = ds_creation(img)

        left_eye = None
        right_eye = None
        blinked = False
        eyes_ratio = None

        if ds_results is not None and "face" in ds_results:
            face_img = to_pil_image(ds_results["face"])
            has_face = True
        else:
            face_img = to_pil_image(np.zeros((256, 256, 3), dtype=np.uint8))
            has_face = False
        face_frames.append({"has_face": has_face, "img": face_img})

        if ds_results is not None and "eyes" in ds_results.keys():
            blinked = ds_results["eyes"]["blinked"]
            eyes_ratio = ds_results["eyes"]["eyes_ratio"]
            if eyes_ratio is not None:
                eyes_ratios.append(eyes_ratio)
            if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
                left_eye = ds_results["eyes"]["left_eye"]
                left_eye = to_pil_image(left_eye).convert("RGB")
                left_eye = preprocess_function(left_eye)
                left_eye = left_eye.unsqueeze(0)
            if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
                right_eye = ds_results["eyes"]["right_eye"]
                right_eye = to_pil_image(right_eye).convert("RGB")
                right_eye = preprocess_function(right_eye)
                right_eye = right_eye.unsqueeze(0)
        else:
            input_img = preprocess_function(input_img)
            input_img = input_img.unsqueeze(0)
            if pupil_selection == "left_pupil":
                left_eye = input_img
            elif pupil_selection == "right_pupil":
                right_eye = input_img
            else:
                left_eye = input_img
                right_eye = input_img

        for i, eye_type in enumerate(selected_eyes):
            if blinked:
                if left_eye is not None and eye_type == "left_eye":
                    _, height, width = left_eye.squeeze(0).shape
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    _, height, width = right_eye.squeeze(0).shape
                    input_image_pil = to_pil_image(right_eye.squeeze(0))
                input_img_np = np.array(input_image_pil)
                zeros_img = to_pil_image(np.zeros((height, width, 3), dtype=np.uint8))
                output_img_np = overlay_text_on_frame(np.array(zeros_img), "blink")
                predicted_diameter = "blink"
            else:
                if left_eye is not None and eye_type == "left_eye":
                    if left_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        left_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                            left_pupil_model,
                            target_layer=target_layer,
                            fc_layer=left_pupil_model.resnet.fc,
                            input_shape=left_eye.shape,
                        )
                    output = left_pupil_model(left_eye)
                    predicted_diameter = output[0].item()
                    act_maps = left_pupil_cam_extractor(0, output)
                    activation_map = act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    if right_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        right_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                            right_pupil_model,
                            target_layer=target_layer,
                            fc_layer=right_pupil_model.resnet.fc,
                            input_shape=right_eye.shape,
                        )
                    output = right_pupil_model(right_eye)
                    predicted_diameter = output[0].item()
                    act_maps = right_pupil_cam_extractor(0, output)
                    activation_map = (
                        act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
                    )
                    input_image_pil = to_pil_image(right_eye.squeeze(0))

                # Create CAM overlay
                activation_map_pil = to_pil_image(activation_map, mode="F")
                result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
                input_img_np = np.array(input_image_pil)
                output_img_np = np.array(result)

            # Add frame and predicted diameter to lists
            input_frames[eye_type].append(input_img_np)
            output_frames[eye_type].append(output_img_np)
            predicted_diameters[eye_type].append(predicted_diameter)

            if output_path:
                height, width, _ = output_img_np.shape
                frame = np.zeros((height, width, 3), dtype=np.uint8)
                if not isinstance(predicted_diameter, str):
                    text = f"{predicted_diameter:.2f}"
                else:
                    text = predicted_diameter
                frame = overlay_text_on_frame(frame, text)
                pred_diameters_frames[eye_type].append(frame)

                combined_frame = np.vstack((input_img_np, output_img_np, frame))
                img_base64 = pil_image_to_base64(Image.fromarray(combined_frame))
                image_html = f'<div style="width: {str(50*len(selected_eyes))}%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
                video_placeholders[eye_type].markdown(image_html, unsafe_allow_html=True)
                # video_placeholders[eye_type].image(combined_frame, use_column_width=True)

        st.session_state.current_frame = idx + 1
        txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)

    if output_path:
        combine_and_show_frames(
            input_frames, output_frames, pred_diameters_frames, output_path, codec, video_placeholders
        )

    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios
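
# A note on the torchcam usage above: each CAM extractor is built once per eye
# and reused across frames; it is called as `extractor(0, output)` because the
# regression head has a single output (num_classes=1), and `fuse_cams` is only
# needed when the extractor returns maps for several target layers.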

# Function to display video with autoplay and loop
def display_video_with_autoplay(video_col, video_path, width):
    video_html = f"""
        <video width="{str(width)}%" height="auto" autoplay loop muted>
            <source src="data:video/mp4;base64,{video_path}" type="video/mp4">
        </video>
    """
    video_col.markdown(video_html, unsafe_allow_html=True)

def process_video(cols, video_frames, tv_model, pupil_selection, output_path, cam_method, blink_detection=False):
    resized_frames = []
    for i, frame in enumerate(video_frames):
        input_img = resize_frame(frame, max_width=640, max_height=480)
        resized_frames.append(input_img)

    file_format = output_path.split(".")[-1]
    codec, extension = get_codec_and_extension(file_format)

    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
        cols, resized_frames, tv_model, pupil_selection, cam_method, output_path, codec, blink_detection
    )
    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios

# Function to convert string values to float or None
def convert_diameter(value):
    try:
        return float(value)
    except (ValueError, TypeError):
        return None  # Return None if conversion fails
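
# Examples: convert_diameter("2.97") -> 2.97, convert_diameter("blink") -> None
# (blink frames are stored as the string "blink"), convert_diameter(None) -> None.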

def combine_and_show_frames(input_frames, cam_frames, pred_diameters_frames, output_path, codec, video_cols):
    # Assuming all frames have the same keys (eye types)
    eye_types = input_frames.keys()
    for i, eye_type in enumerate(eye_types):
        in_frames = input_frames[eye_type]
        cam_out_frames = cam_frames[eye_type]
        pred_diameters_text_frames = pred_diameters_frames[eye_type]

        # Get frame properties (assuming all frames have the same dimensions)
        height, width, _ = in_frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*codec)
        fps = 10.0
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height * 3))  # Height is tripled for the stacked frames

        # Loop through each set of frames and stack them
        for j in range(len(in_frames)):
            input_frame = in_frames[j]
            cam_frame = cam_out_frames[j]
            pred_frame = pred_diameters_text_frames[j]

            # Convert frames to BGR if necessary
            input_frame_bgr = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
            cam_frame_bgr = cv2.cvtColor(cam_frame, cv2.COLOR_RGB2BGR)
            pred_frame_bgr = cv2.cvtColor(pred_frame, cv2.COLOR_RGB2BGR)

            # Stack frames vertically (input, CAM, prediction)
            combined_frame = np.vstack((input_frame_bgr, cam_frame_bgr, pred_frame_bgr))

            # Write the combined frame to the video
            out.write(combined_frame)

        # Release the video writer
        out.release()

        # Read the video and encode it in base64 for displaying
        with open(output_path, "rb") as video_file:
            video_bytes = video_file.read()
            video_base64 = base64.b64encode(video_bytes).decode("utf-8")

        # Display the combined video
        display_video_with_autoplay(video_cols[eye_type], video_base64, width=len(video_cols) * 50)

        # Clean up
        os.remove(output_path)

def set_input_image_on_ui(uploaded_file, cols):
    input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
    # NOTE: images taken with a phone camera carry an EXIF orientation field that often
    # rotates photos shot with the phone tilted. PIL's exif_transpose utility applies and
    # removes this data so the image ends up upright.
    input_img = ImageOps.exif_transpose(input_img)
    input_img = resize_frame(input_img, max_width=640, max_height=480)
    cols[0].image(input_img, use_column_width=True)
    st.session_state.total_frames = 1
    return input_img

def set_input_video_on_ui(uploaded_file, cols):
    tfile = tempfile.NamedTemporaryFile(delete=False)
    try:
        # Uploaded via the file uploader: a file-like object
        tfile.write(uploaded_file.read())
    except Exception:
        # Selected from the sample library: raw bytes
        tfile.write(uploaded_file)
    video_path = tfile.name
    video_frames = extract_frames(video_path)
    cols[0].video(video_path)
    st.session_state.total_frames = len(video_frames)
    return video_frames, video_path

def set_frames_processed_count_placeholder(cols):
    st.session_state.current_frame = 0
    st.session_state.frame_placeholder = cols[0].empty()
    txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
    st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)

def video_to_bytes(video_path):
    # Open the video file in binary mode and return the bytes
    with open(video_path, "rb") as video_file:
        return video_file.read()

def display_video_library(video_folder="./sample_videos"):
    # Get all video files from the folder
    video_files = [f for f in os.listdir(video_folder) if f.endswith(".webm")]

    # Store the selected video path
    selected_video_path = None

    # Calculate number of columns (adjust based on your layout preferences)
    num_columns = 3  # For a grid of 3 videos per row

    # Display videos in a grid layout with a 'Select' button for each video
    for i in range(0, len(video_files), num_columns):
        cols = st.columns(num_columns)
        for idx, video_file in enumerate(video_files[i : i + num_columns]):
            with cols[idx]:
                st.subheader(video_file.split(".")[0])  # Use the file name as the title
                video_path = os.path.join(video_folder, video_file)
                st.video(video_path)  # Show the video
                if st.button(f"Select {video_file.split('.')[0]}", key=video_file, type="primary"):
                    st.session_state.clear()
                    st.toast("Scroll Down to see the input and predictions", icon="⏬")
                    selected_video_path = video_path  # Store the path of the selected video
    return selected_video_path

def set_page_info_and_sidebar_info():
    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
    st.title("👁️ PupilSense 👁️🕵️‍♂️")
    # st.markdown("Upload your own images or video **OR** select from our sample library below")
    st.markdown(
        "<p style='font-size: 30px;'>"
        "Upload your own image 🖼️ or video 🎞️ <strong>OR</strong> select from our sample videos 📚"
        "</p>",
        unsafe_allow_html=True,
    )

    # video_path = display_video_library()
    show_demo_videos = st.sidebar.checkbox("Show Sample Videos", value=False)
    if show_demo_videos:
        video_path = display_video_library()
    else:
        video_path = None

    st.markdown("<hr id='target_element' style='border: 1px solid #6d6d6d; margin: 20px 0;'>", unsafe_allow_html=True)
    cols = st.columns((1, 1))
    cols[0].header("Input")
    cols[-1].header("Prediction")
    st.markdown("<hr style='border: 1px solid #6d6d6d; margin: 20px 0;'>", unsafe_allow_html=True)

    LABEL_MAP = ["left_pupil", "right_pupil"]
    TV_MODELS = ["ResNet18", "ResNet50"]

    if "uploader_key" not in st.session_state:
        st.session_state["uploader_key"] = 1

    st.sidebar.title("Upload Face 👨‍🦱 or Eye 👁️")
    uploaded_file = st.sidebar.file_uploader(
        "Upload Image or Video",
        type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"],
        key=st.session_state["uploader_key"],
    )
    if uploaded_file is not None:
        st.session_state["uploaded_file"] = uploaded_file

    st.sidebar.title("Setup")
    pupil_selection = st.sidebar.selectbox(
        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
    )
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
    blink_detection = st.sidebar.checkbox("Detect Blinks", value=True)

    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)

    if "uploaded_file" not in st.session_state:
        st.session_state["uploaded_file"] = None
    if "og_video_path" not in st.session_state:
        st.session_state["og_video_path"] = None

    if uploaded_file is None and video_path is not None:
        video_bytes = video_to_bytes(video_path)
        uploaded_file = video_bytes
        st.session_state["uploaded_file"] = uploaded_file
        st.session_state["og_video_path"] = video_path
        st.session_state["uploader_key"] = 0

    return (
        cols,
        st.session_state["og_video_path"],
        st.session_state["uploaded_file"],
        pupil_selection,
        tv_model,
        blink_detection,
    )

def pil_image_to_base64(img):
    """Convert a PIL Image to a base64 encoded string."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

def process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection):
    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
        cols,
        [input_img],
        tv_model,
        pupil_selection,
        cam_method=CAM_METHODS[-1],
        blink_detection=blink_detection,
    )

    # for ff in face_frames:
    #     if ff["has_face"]:
    #         cols[1].image(face_frames[0]["img"], use_column_width=True)

    input_frames_keys = input_frames.keys()
    video_cols = cols[1].columns(len(input_frames_keys))
    for i, eye_type in enumerate(input_frames_keys):
        # Check the pupil_selection and set the width accordingly
        if pupil_selection == "both":
            video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
        else:
            img_base64 = pil_image_to_base64(Image.fromarray(input_frames[eye_type][-1]))
            image_html = f'<div style="width: 50%; margin-bottom: 1.2%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)

    output_frames_keys = output_frames.keys()
    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
    for i, eye_type in enumerate(output_frames_keys):
        height, width, c = output_frames[eye_type][0].shape
        frame = np.zeros((height, width, c), dtype=np.uint8)
        text = f"{predicted_diameters[eye_type][0]:.2f}"
        frame = overlay_text_on_frame(frame, text)
        if pupil_selection == "both":
            video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
            video_cols[i].image(frame, use_column_width=True)
        else:
            img_base64 = pil_image_to_base64(Image.fromarray(output_frames[eye_type][-1]))
            image_html = f'<div style="width: 50%; margin-top: 1.2%; margin-bottom: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)
            img_base64 = pil_image_to_base64(Image.fromarray(frame))
            image_html = f'<div style="width: 50%; margin-top: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)
    return None

def plot_ears(eyes_ratios, eyes_df):
    eyes_df["EAR"] = eyes_ratios
    df = pd.DataFrame(eyes_ratios, columns=["EAR"])
    df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1

    # Create an Altair chart for eyes_ratios
    line_chart = (
        alt.Chart(df)
        .mark_line(color=colors[-1])  # Set color of the line
        .encode(
            x=alt.X("Frame:Q", title="Frame Number"),
            y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
            tooltip=["Frame", "EAR"],
        )
        # .properties(title="Eyes Aspect Ratios (EARs)")
        # .configure_axis(grid=True)
    )
    points_chart = line_chart.mark_point(color=colors[-1], filled=True)

    # Create a horizontal rule at y=0.22
    line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
    line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")

    # Add text annotations for the lines
    text1 = (
        alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
        .mark_text(align="left", dx=100, dy=9, color="red", size=16)
        .encode(y="y:Q", text="label:N")
    )
    text2 = (
        alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
        .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
        .encode(y="y:Q", text="label:N")
    )

    # Add gray area text for the region between red and green lines
    gray_area_text = (
        alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
        .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
        .encode(y="y:Q", text="label:N")
    )

    # Combine all elements: line chart, points, rules, and text annotations
    final_chart = (
        line_chart.properties(title="Eyes Aspect Ratios (EARs)")
        + points_chart
        + line1
        + line2
        + text1
        + text2
        + gray_area_text
    ).interactive()

    # Configure axis properties at the chart level
    final_chart = final_chart.configure_axis(grid=True)

    # Display the Altair chart
    # st.subheader("Eyes Aspect Ratios (EARs)")
    st.altair_chart(final_chart, use_container_width=True)
    return eyes_df

def plot_individual_charts(predicted_diameters, cols):
    # Iterate through categories and assign charts to columns
    for i, (category, values) in enumerate(predicted_diameters.items()):
        with cols[i]:  # Directly use the column index
            # st.subheader(category)  # Add a subheader for the category
            if "left" in category:
                selected_color = colors[0]
            elif "right" in category:
                selected_color = colors[1]
            else:
                selected_color = colors[i]

            # Convert values to numeric, replacing non-numeric values with None
            values = [convert_diameter(value) for value in values]

            if "left" in category:
                category_name = "Left Pupil Diameter"
            else:
                category_name = "Right Pupil Diameter"

            # Create a DataFrame from the values for Altair
            df = pd.DataFrame(
                {
                    "Frame": range(1, len(values) + 1),
                    category_name: values,
                }
            )

            # Get the min and max values for y-axis limits, ignoring None
            min_value = min(filter(lambda x: x is not None, values), default=None)
            max_value = max(filter(lambda x: x is not None, values), default=None)

            # Create an Altair chart with y-axis limits
            line_chart = (
                alt.Chart(df)
                .mark_line(color=selected_color)
                .encode(
                    x=alt.X("Frame:Q", title="Frame Number"),
                    y=alt.Y(
                        f"{category_name}:Q",
                        title="Diameter",
                        scale=alt.Scale(domain=[min_value, max_value]),
                    ),
                    tooltip=[
                        "Frame",
                        alt.Tooltip(f"{category_name}:Q", title="Diameter"),
                    ],
                )
                # .properties(title=f"{category} - Predicted Diameters")
                # .configure_axis(grid=True)
            )
            points_chart = line_chart.mark_point(color=selected_color, filled=True)

            final_chart = (
                line_chart.properties(
                    title=f"{'Left Pupil' if 'left' in category else 'Right Pupil'} - Predicted Diameters"
                )
                + points_chart
            ).interactive()
            final_chart = final_chart.configure_axis(grid=True)

            # Display the Altair chart
            st.altair_chart(final_chart, use_container_width=True)
    return df

def plot_combined_charts(predicted_diameters):
    all_min_values = []
    all_max_values = []

    # Create an empty DataFrame to store combined data for plotting
    combined_df = pd.DataFrame()

    # Iterate through categories and collect data
    for category, values in predicted_diameters.items():
        # Convert values to numeric, replacing non-numeric values with None
        values = [convert_diameter(value) for value in values]

        # Get the min and max values for y-axis limits, ignoring None
        min_value = min(filter(lambda x: x is not None, values), default=None)
        max_value = max(filter(lambda x: x is not None, values), default=None)
        all_min_values.append(min_value)
        all_max_values.append(max_value)

        category = "left_pupil" if "left" in category else "right_pupil"

        # Create a DataFrame from the values
        df = pd.DataFrame(
            {
                "Diameter": values,
                "Frame": range(1, len(values) + 1),  # Create a frame column starting from 1
                "Category": category,  # Add a column to specify the category
            }
        )

        # Append to combined DataFrame
        combined_df = pd.concat([combined_df, df], ignore_index=True)

    combined_chart = (
        alt.Chart(combined_df)
        .mark_line()
        .encode(
            x=alt.X("Frame:Q", title="Frame Number"),
            y=alt.Y(
                "Diameter:Q",
                title="Diameter",
                scale=alt.Scale(domain=[min(all_min_values), max(all_max_values)]),
            ),
            color=alt.Color("Category:N", scale=alt.Scale(range=colors), title="Pupil Type"),
            tooltip=["Frame", "Diameter:Q", "Category:N"],
        )
    )
    points_chart = combined_chart.mark_point(filled=True)
    final_chart = (combined_chart.properties(title="Predicted Diameters") + points_chart).interactive()
    final_chart = final_chart.configure_axis(grid=True)

    # Display the combined chart
    st.altair_chart(final_chart, use_container_width=True)

    # --------------------------------------------
    # Convert to a DataFrame
    left_pupil_values = [convert_diameter(value) for value in predicted_diameters["left_eye"]]
    right_pupil_values = [convert_diameter(value) for value in predicted_diameters["right_eye"]]
    df = pd.DataFrame(
        {
            "Frame": range(1, len(left_pupil_values) + 1),
            "Left Pupil Diameter": left_pupil_values,
            "Right Pupil Diameter": right_pupil_values,
        }
    )

    # Calculate the difference between left and right pupil diameters
    df["Difference Value"] = df["Left Pupil Diameter"] - df["Right Pupil Diameter"]

    # Determine the status of the difference
    df["Difference Status"] = df.apply(
        lambda row: "L>R" if row["Left Pupil Diameter"] > row["Right Pupil Diameter"] else "L<R",
        axis=1,
    )
    return df

def process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path):
    output_video_path = f"{root_path}/tmp.webm"

    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
        cols,
        video_frames,
        tv_model,
        pupil_selection,
        output_video_path,
        cam_method=CAM_METHODS[-1],
        blink_detection=blink_detection,
    )
    os.remove(video_path)

    num_columns = len(predicted_diameters)
    cols = st.columns(num_columns)
    if num_columns == 2:
        df = plot_combined_charts(predicted_diameters)
    else:
        df = plot_individual_charts(predicted_diameters, cols)

    if eyes_ratios is not None and len(eyes_ratios) > 0:
        df = plot_ears(eyes_ratios, df)

    st.dataframe(df, hide_index=True, use_container_width=True)
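

# ---------------------------------------------------------------------------
# Entry-point sketch (assumption): how the helpers above would typically be
# wired together when this module is run with `streamlit run`. The actual
# driver script is not shown in this file, so treat this as illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cols, og_video_path, uploaded_file, pupil_selection, tv_model, blink_detection = set_page_info_and_sidebar_info()

    if uploaded_file is not None:
        # Uploaded files expose a name; sample-library selections arrive as raw
        # bytes, so fall back to the stored original video path for the extension.
        file_name = getattr(uploaded_file, "name", og_video_path or "")
        file_extension = file_name.split(".")[-1]

        if is_image(file_extension):
            input_img = set_input_image_on_ui(uploaded_file, cols)
            set_frames_processed_count_placeholder(cols)
            process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection)
        elif is_video(file_extension):
            video_frames, video_path = set_input_video_on_ui(uploaded_file, cols)
            set_frames_processed_count_placeholder(cols)
            process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path)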