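"""Streamlit helpers for the EyeDentity pupil-diameter demo: model loading,
frame extraction and preprocessing, CAM visualisation, and video export."""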
import base64
import os
import os.path as osp
import sys
import tempfile
from io import BytesIO

import cv2
import numpy as np
import streamlit as st
import torch
from matplotlib import pyplot as plt
from PIL import Image
from torchcam import methods as torchcam_methods
from torchcam.methods import CAM
from torchcam.utils import overlay_mask
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

# Local imports resolve against root_path, which is appended to sys.path above.
from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model


def load_model(model_configs, device="cpu"):
    """Loads the pre-trained model."""
    model_path = os.path.join(root_path, model_configs["model_path"])
    model_dict = torch.load(model_path, map_location=device)
    model = get_model(model_configs=model_configs)
    model.load_state_dict(model_dict)
    model = model.to(device).eval()
    return model
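
# Illustrative usage (the config mirrors the dicts built in setup() below;
# the weight path is hypothetical):
#   model = load_model(
#       {"model_path": "pre_trained_models/ResNet18/left_eye.pt",
#        "registered_model_name": "ResNet18",
#        "num_classes": 1},
#   )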


def extract_frames(video_path):
    """Extracts frames from a video file."""
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    success, image = vidcap.read()
    while success:
        # OpenCV decodes to BGR; convert to RGB for PIL/matplotlib/Streamlit.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        frames.append(image_rgb)
        success, image = vidcap.read()
    vidcap.release()
    return frames
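
# Note: extract_frames decodes the whole clip into memory, one RGB uint8 array
# per frame.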


def resize_frame(image, max_width=640, max_height=480):
    """Downscales a frame to fit within the given bounds, keeping aspect ratio."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    original_size = image.size
    # Resize the frame similarly to the image resizing logic
    if original_size[0] == original_size[1] and original_size[0] >= 256:
        max_size = (256, 256)
    else:
        max_size = list(original_size)
        if original_size[0] >= max_width:
            max_size[0] = max_width
        elif original_size[0] < 64:
            max_size[0] = 64
        if original_size[1] >= max_height:
            max_size[1] = max_height
        elif original_size[1] < 32:
            max_size[1] = 32
    # thumbnail() only ever shrinks, so the 64/32 lower bounds above do not
    # upscale small frames; they just leave them unchanged.
    image.thumbnail(max_size)
    return image
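
# Example: an 800x480 frame fits the 640x480 box at scale 0.8, giving 640x384;
# a square frame of at least 256x256 is reduced to 256x256.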


def is_image(file_extension):
    """Checks if the file is an image."""
    return file_extension.lower() in ["png", "jpeg", "jpg"]


def is_video(file_extension):
    """Checks if the file is a video."""
    return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]


def display_results(input_image, cam_frame, pupil_diameter, cols):
    """Displays the input image and overlayed CAM result."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(input_image)
    axs[0].axis("off")
    axs[0].set_title("Input Image")
    axs[1].imshow(cam_frame)
    axs[1].axis("off")
    axs[1].set_title("Overlayed CAM")
    cols[-1].pyplot(fig)
    cols[-1].text(f"Pupil Diameter: {pupil_diameter:.2f} mm")


def preprocess_image(input_img, max_size=(256, 256)):
    """Resizes and preprocesses an image."""
    input_img.thumbnail(max_size)
    preprocess_steps = [
        transforms.ToTensor(),
        transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    ]
    return transforms.Compose(preprocess_steps)(input_img).unsqueeze(0)
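
# For an RGB input this yields a tensor of shape (1, 3, 32, 64):
# ToTensor -> (3, H, W), Resize([32, 64]) -> (3, 32, 64), and unsqueeze(0)
# adds the batch dimension.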


def overlay_text_on_frame(frame, text, position=(16, 20)):
    """Write text on the image frame using OpenCV (draws in place and returns the frame)."""
    return cv2.putText(frame, text, position, cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, cv2.LINE_AA)


def get_configs(blink_detection=False):
    # Super-resolution is currently disabled: upscale is hard-coded to "-",
    # so sr_configs is always None and the else branch below is unreachable.
    upscale = "-"
    upscale_method_or_model = "-"
    if upscale == "-":
        sr_configs = None
    else:
        sr_configs = {
            "method": upscale_method_or_model,
            "params": {"upscale": upscale},
        }
    config_file = {
        "sr_configs": sr_configs,
        "feature_extraction_configs": {
            "blink_detection": blink_detection,
            "upscale": upscale,
            "extraction_library": "mediapipe",
        },
    }
    return config_file
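
# With the hard-coded defaults, get_configs() returns:
#   {"sr_configs": None,
#    "feature_extraction_configs": {"blink_detection": False, "upscale": "-",
#                                   "extraction_library": "mediapipe"}}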


def setup(cols, pupil_selection, tv_model, output_path):
    left_pupil_model = None
    left_pupil_cam_extractor = None
    right_pupil_model = None
    right_pupil_cam_extractor = None
    output_frames = {}
    input_frames = {}
    predicted_diameters = {}
    if pupil_selection == "both":
        selected_eyes = ["left_eye", "right_eye"]
    elif pupil_selection == "left_pupil":
        selected_eyes = ["left_eye"]
    elif pupil_selection == "right_pupil":
        selected_eyes = ["right_eye"]
    else:
        raise ValueError(f"Unknown pupil selection: {pupil_selection}")
    for eye_type in selected_eyes:
        model_configs = {
            # Relative to root_path; load_model() joins it with root_path.
            "model_path": f"pre_trained_models/{tv_model}/{eye_type}.pt",
            "registered_model_name": tv_model,
            "num_classes": 1,
        }
        if eye_type == "left_eye":
            left_pupil_model = load_model(model_configs)
        else:
            right_pupil_model = load_model(model_configs)
        output_frames[eye_type] = []
        input_frames[eye_type] = []
        predicted_diameters[eye_type] = []
    video_input_placeholders = {}
    video_output_placeholders = {}
    video_predictions_placeholders = {}
    if output_path:
        video_cols = cols[1].columns(len(input_frames.keys()))
        # One column per eye, stacking input, CAM overlay, and prediction
        # placeholders vertically.
        for i, eye_type in enumerate(input_frames.keys()):
            video_input_placeholders[eye_type] = video_cols[i].empty()
            video_output_placeholders[eye_type] = video_cols[i].empty()
            video_predictions_placeholders[eye_type] = video_cols[i].empty()
    return (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        video_input_placeholders,
        video_output_placeholders,
        video_predictions_placeholders,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    )


def process_frames(
    cols, input_imgs, tv_model, pupil_selection, cam_method, output_path=None, codec=None, blink_detection=False
):
    config_file = get_configs(blink_detection)
    face_frames = []
    (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        video_input_placeholders,
        video_output_placeholders,
        video_predictions_placeholders,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    ) = setup(cols, pupil_selection, tv_model, output_path)
    ds_creation = EyeDentityDatasetCreation(
        feature_extraction_configs=config_file["feature_extraction_configs"],
        sr_configs=config_file["sr_configs"],
    )
    preprocess_steps = [
        transforms.ToTensor(),
        transforms.Resize(
            [32, 64],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
    ]
    preprocess_function = transforms.Compose(preprocess_steps)
    eyes_ratios = []
    for idx, input_img in enumerate(input_imgs):
        img = np.array(input_img)
        ds_results = ds_creation(img)
        left_eye = None
        right_eye = None
        blinked = False
        eyes_ratio = None
        if ds_results is not None and "face" in ds_results:
            face_img = to_pil_image(ds_results["face"])
            has_face = True
        else:
            face_img = to_pil_image(np.zeros((256, 256, 3), dtype=np.uint8))
            has_face = False
        face_frames.append({"has_face": has_face, "img": face_img})
        if ds_results is not None and "eyes" in ds_results:
            blinked = ds_results["eyes"]["blinked"]
            eyes_ratio = ds_results["eyes"]["eyes_ratio"]
            if eyes_ratio is not None:
                eyes_ratios.append(eyes_ratio)
            if ds_results["eyes"].get("left_eye") is not None:
                left_eye = to_pil_image(ds_results["eyes"]["left_eye"]).convert("RGB")
                left_eye = preprocess_function(left_eye).unsqueeze(0)
            if ds_results["eyes"].get("right_eye") is not None:
                right_eye = to_pil_image(ds_results["eyes"]["right_eye"]).convert("RGB")
                right_eye = preprocess_function(right_eye).unsqueeze(0)
        else:
            # No eye crops detected: fall back to feeding the whole frame.
            input_img = preprocess_function(input_img).unsqueeze(0)
            if pupil_selection == "left_pupil":
                left_eye = input_img
            elif pupil_selection == "right_pupil":
                right_eye = input_img
            else:
                left_eye = input_img
                right_eye = input_img
        for eye_type in selected_eyes:
            if blinked:
                # On a blink, skip prediction and emit a "blink" placeholder frame.
                eye_tensor = left_eye if eye_type == "left_eye" else right_eye
                if eye_tensor is None:
                    continue  # no crop available for this eye on this frame
                _, height, width = eye_tensor.squeeze(0).shape
                input_image_pil = to_pil_image(eye_tensor.squeeze(0))
                input_img_np = np.array(input_image_pil)
                zeros_img = to_pil_image(np.zeros((height, width, 3), dtype=np.uint8))
                output_img_np = overlay_text_on_frame(np.array(zeros_img), "blink")
                predicted_diameter = "blink"
            else:
                if left_eye is not None and eye_type == "left_eye":
                    # Build the CAM extractor lazily, once per eye.
                    if left_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        left_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                            left_pupil_model,
                            target_layer=target_layer,
                            fc_layer=left_pupil_model.resnet.fc,
                            input_shape=left_eye.shape,
                        )
                    output = left_pupil_model(left_eye)
                    predicted_diameter = output[0].item()
                    act_maps = left_pupil_cam_extractor(0, output)
                    activation_map = (
                        act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
                    )
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    if right_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        right_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                            right_pupil_model,
                            target_layer=target_layer,
                            fc_layer=right_pupil_model.resnet.fc,
                            input_shape=right_eye.shape,
                        )
                    output = right_pupil_model(right_eye)
                    predicted_diameter = output[0].item()
                    act_maps = right_pupil_cam_extractor(0, output)
                    activation_map = (
                        act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
                    )
                    input_image_pil = to_pil_image(right_eye.squeeze(0))
                else:
                    continue  # the crop needed for this eye is missing; skip the frame
                # Create CAM overlay
                activation_map_pil = to_pil_image(activation_map, mode="F")
                result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
                input_img_np = np.array(input_image_pil)
                output_img_np = np.array(result)
            # Add frame and predicted diameter to lists
            input_frames[eye_type].append(input_img_np)
            output_frames[eye_type].append(output_img_np)
            predicted_diameters[eye_type].append(predicted_diameter)
            if output_path:
                height, width, _ = output_img_np.shape
                frame = np.zeros((height, width, 3), dtype=np.uint8)
                if not isinstance(predicted_diameter, str):
                    text = f"{predicted_diameter:.2f}"
                else:
                    text = predicted_diameter
                frame = overlay_text_on_frame(frame, text)
                video_input_placeholders[eye_type].image(input_img_np, use_column_width=True)
                video_output_placeholders[eye_type].image(output_img_np, use_column_width=True)
                video_predictions_placeholders[eye_type].image(frame, use_column_width=True)
        st.session_state.current_frame = idx + 1
        txt = (
            "<p style='font-size:20px;'> Number of Frames Processed: "
            f"<strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
        )
        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
    if output_path:
        show_input_frames(input_frames, output_path, codec, video_input_placeholders)
        show_cam_frames(output_frames, output_path, codec, video_output_placeholders)
        show_pred_text_frames(output_frames, output_path, predicted_diameters, codec, video_predictions_placeholders)
    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios
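
# process_frames returns five values: per-eye dicts (keyed "left_eye"/"right_eye")
# of raw input crops, CAM overlays, and predicted diameters (floats, or the
# string "blink"), plus the per-frame face crops and eye-openness ratios.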


def display_video_with_autoplay(video_col, video_base64):
    """Embed a base64-encoded MP4 in the given column, with autoplay and loop."""
    video_html = f"""
        <video width="100%" height="auto" autoplay loop muted>
            <source src="data:video/mp4;base64,{video_base64}" type="video/mp4">
        </video>
    """
    video_col.markdown(video_html, unsafe_allow_html=True)


def get_codec_and_extension(file_format):
    """Return codec and file extension based on the format."""
    if file_format == "mp4":
        # H264 in an .mp4 container requires an OpenCV build with H.264 support
        # (e.g. via openh264); otherwise VideoWriter may fail to encode.
        return "H264", ".mp4"
    elif file_format == "avi":
        return "MJPG", ".avi"
    elif file_format == "webm":
        return "VP80", ".webm"
    else:
        return "MJPG", ".avi"


def show_input_frames(input_frames, output_path, codec, video_cols):
    """Write the raw input frames to a temp video and embed it per eye."""
    for eye_type, in_frames in input_frames.items():
        height, width, _ = in_frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*codec)
        fps = 10.0
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        for frame in in_frames:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        out.release()
        with open(output_path, "rb") as video_file:
            video_bytes = video_file.read()
        video_base64 = base64.b64encode(video_bytes).decode("utf-8")
        display_video_with_autoplay(video_cols[eye_type], video_base64)
        os.remove(output_path)


def show_cam_frames(output_frames, output_path, codec, video_cols):
    """Write the CAM-overlay frames to a temp video and embed it per eye."""
    for eye_type, out_frames in output_frames.items():
        height, width, _ = out_frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*codec)
        fps = 10.0
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        for frame in out_frames:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        out.release()
        with open(output_path, "rb") as video_file:
            video_bytes = video_file.read()
        video_base64 = base64.b64encode(video_bytes).decode("utf-8")
        display_video_with_autoplay(video_cols[eye_type], video_base64)
        os.remove(output_path)


def show_pred_text_frames(output_frames, output_path, predicted_diameters, codec, video_cols):
    """Render each predicted diameter as a text frame, write the video, and embed it."""
    for eye_type, out_frames in output_frames.items():
        height, width, _ = out_frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*codec)
        fps = 10.0
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        for diameter in predicted_diameters[eye_type]:
            frame = np.zeros((height, width, 3), dtype=np.uint8)
            text = diameter if isinstance(diameter, str) else f"{diameter:.2f}"
            frame = overlay_text_on_frame(frame, text)
            out.write(frame)
        out.release()
        with open(output_path, "rb") as video_file:
            video_bytes = video_file.read()
        video_base64 = base64.b64encode(video_bytes).decode("utf-8")
        display_video_with_autoplay(video_cols[eye_type], video_base64)
        os.remove(output_path)


def process_video(cols, video_frames, tv_model, pupil_selection, output_path, cam_method, blink_detection=False):
    """Resize the video frames and run the full per-frame processing pipeline."""
    resized_frames = [resize_frame(frame, max_width=640, max_height=480) for frame in video_frames]
    file_format = output_path.split(".")[-1]
    codec, _extension = get_codec_and_extension(file_format)
    return process_frames(
        cols, resized_frames, tv_model, pupil_selection, cam_method, output_path, codec, blink_detection
    )
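
# Illustrative call (column objects come from the Streamlit app; the input
# path is hypothetical):
#   frames = extract_frames("recording.mp4")
#   results = process_video(cols, frames, "ResNet18", "both", "/tmp/out.mp4", "CAM")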


def convert_diameter(value):
    """Convert a string value to float, or None if the conversion fails."""
    try:
        return float(value)
    except (ValueError, TypeError):
        return None
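
# Examples: convert_diameter("3.14") -> 3.14; convert_diameter("blink") -> None;
# convert_diameter(None) -> None.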