| """ | |
| File: app_utils.py | |
| Author: Elena Ryumina and Dmitry Ryumin (modified by Assistant) | |
| Description: This module contains utility functions for facial expression recognition application, including FACS Analysis for SAD. | |
| License: MIT License | |
| """ | |
import torch
import numpy as np
import mediapipe as mp
from PIL import Image
import cv2
from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib.pyplot as plt

# Components shared with the Gradio app
from app.model import pth_model_static, pth_model_dynamic, cam, pth_processing
from app.face_utils import get_box, display_info
from app.config import DICT_EMO, config_data
from app.plot import statistics_plot

mp_face_mesh = mp.solutions.face_mesh

def preprocess_image_and_predict(inp):
    # Check for None before converting: np.array(None) is a valid (object)
    # array, so the original post-conversion check never fired.
    if inp is None:
        return None, None, None
    inp = np.array(inp)

    try:
        h, w = inp.shape[:2]
    except Exception:
        return None, None, None

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(inp)
        if results.multi_face_landmarks:
            for fl in results.multi_face_landmarks:
                # Crop the detected face and run the static emotion model.
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = inp[startY:endY, startX:endX]
                cur_face_n = pth_processing(Image.fromarray(cur_face))
                with torch.no_grad():
                    prediction = (
                        torch.nn.functional.softmax(pth_model_static(cur_face_n), dim=1)
                        .detach()
                        .numpy()[0]
                    )
                confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}

                # Grad-CAM heatmap over the 224x224 face crop.
                grayscale_cam = cam(input_tensor=cur_face_n)
                grayscale_cam = grayscale_cam[0, :]
                cur_face_hm = cv2.resize(cur_face, (224, 224))
                cur_face_hm = np.float32(cur_face_hm) / 255
                heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)

                return cur_face, heatmap, confidences

    # No face detected.
    return None, None, None
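
# --- Hypothetical usage sketch (not part of the original app wiring) ---
# A minimal smoke test for preprocess_image_and_predict; the image path is an
# assumption, and the app.* modules must be importable for it to run.
def _demo_image_prediction(path="test_face.jpg"):
    face, heatmap, confidences = preprocess_image_and_predict(Image.open(path))
    if confidences is None:
        print("No face detected.")
        return
    top_emotion = max(confidences, key=confidences.get)
    print(f"Predicted emotion: {top_emotion} ({confidences[top_emotion]:.3f})")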

def preprocess_frame_and_predict_aus(frame):
    if frame is None:
        return None, None, None

    # Normalize the channel layout to 3-channel RGB.
    if len(frame.shape) == 2:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.shape[2] == 4:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(frame)
        if results.multi_face_landmarks:
            h, w = frame.shape[:2]
            for fl in results.multi_face_landmarks:
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = frame[startY:endY, startX:endX]
                cur_face_n = pth_processing(Image.fromarray(cur_face))
                with torch.no_grad():
                    features = pth_model_static(cur_face_n)
                    au_intensities = features_to_au_intensities(features)

                # Grad-CAM heatmap over the 224x224 face crop.
                grayscale_cam = cam(input_tensor=cur_face_n)
                grayscale_cam = grayscale_cam[0, :]
                cur_face_hm = cv2.resize(cur_face, (224, 224))
                cur_face_hm = np.float32(cur_face_hm) / 255
                heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)

                return cur_face, au_intensities, heatmap

    return None, None, None
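
# --- Hypothetical usage sketch ---
# Single-frame AU extraction, e.g. from the default webcam; the device index
# and BGR->RGB conversion are assumptions about the caller's capture setup.
def _demo_frame_aus(device_index=0):
    cap = cv2.VideoCapture(device_index)
    ok, frame = cap.read()
    cap.release()
    if not ok:
        return None
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    _, au_intensities, _ = preprocess_frame_and_predict_aus(frame)
    return au_intensities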

def features_to_au_intensities(features):
    # Heuristic: min-max scale raw model outputs into [0, 1] as a stand-in
    # for AU intensities; this is a proxy, not a trained AU regressor.
    features_np = features.detach().cpu().numpy()[0]
    rng = features_np.max() - features_np.min()
    if rng == 0:
        # Constant features would otherwise divide by zero.
        return np.zeros(24, dtype=np.float32)
    au_intensities = (features_np - features_np.min()) / rng
    return au_intensities[:24]  # assuming the first 24 values stand in for 24 AUs
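
# --- Illustration of the scaling above (hypothetical values) ---
# Min-max scaling maps any feature vector onto [0, 1]: features spread evenly
# over [2, 6] land at 0, 1/29, 2/29, ... before truncation to the first 24.
def _demo_au_scaling():
    fake_features = torch.linspace(2.0, 6.0, steps=30).unsqueeze(0)  # assumed shape (1, 30)
    aus = features_to_au_intensities(fake_features)
    assert aus.shape == (24,) and aus.min() >= 0.0 and aus.max() <= 1.0
    return aus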

def preprocess_video_and_predict(video):
    cap = cv2.VideoCapture(video)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = np.round(cap.get(cv2.CAP_PROP_FPS))

    path_save_video_face = 'result_face.mp4'
    vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    path_save_video_hm = 'result_hm.mp4'
    vid_writer_hm = cv2.VideoWriter(path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    lstm_features = []
    count_frame = 1
    count_face = 0
    probs = []
    frames = []
    au_intensities_list = []
    last_output = None
    last_heatmap = None
    last_au_intensities = None
    cur_face = None

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5) as face_mesh:

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or frame is None:
                break

            frame_copy = frame.copy()
            frame_copy.flags.writeable = False
            frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_copy)
            frame_copy.flags.writeable = True

            if results.multi_face_landmarks:
                for fl in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(fl, w, h)
                    cur_face = frame_copy[startY:endY, startX:endX]

                    # Run the models only every FRAME_DOWNSAMPLING-th face frame;
                    # intermediate frames reuse the last computed outputs.
                    if count_face % config_data.FRAME_DOWNSAMPLING == 0:
                        cur_face_copy = pth_processing(Image.fromarray(cur_face))
                        with torch.no_grad():
                            features = torch.nn.functional.relu(
                                pth_model_static.extract_features(cur_face_copy)
                            ).detach().numpy()
                            au_intensities = features_to_au_intensities(pth_model_static(cur_face_copy))

                        grayscale_cam = cam(input_tensor=cur_face_copy)
                        grayscale_cam = grayscale_cam[0, :]
                        cur_face_hm = cv2.resize(cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                        cur_face_hm = np.float32(cur_face_hm) / 255
                        heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=False)
                        last_heatmap = heatmap
                        last_au_intensities = au_intensities

                        # Maintain a 10-step sliding window of per-frame features
                        # for the dynamic (temporal) model.
                        if len(lstm_features) == 0:
                            lstm_features = [features] * 10
                        else:
                            lstm_features = lstm_features[1:] + [features]

                        lstm_f = torch.from_numpy(np.vstack(lstm_features))
                        lstm_f = torch.unsqueeze(lstm_f, 0)
                        with torch.no_grad():
                            output = pth_model_dynamic(lstm_f).detach().numpy()
                        last_output = output

                        if count_face == 0:
                            count_face += 1
                    else:
                        if last_output is not None:
                            output = last_output
                            heatmap = last_heatmap
                            au_intensities = last_au_intensities
                        else:
                            output = np.full((1, 7), np.nan)
                            au_intensities = np.full(24, np.nan)

                    probs.append(output[0])
                    frames.append(count_frame)
                    au_intensities_list.append(au_intensities)
            else:
                # No face in this frame: reset the temporal window and log NaNs.
                if last_output is not None:
                    lstm_features = []
                    probs.append(np.full(7, np.nan))
                    frames.append(count_frame)
                    au_intensities_list.append(np.full(24, np.nan))

            if cur_face is not None:
                heatmap_f = display_info(heatmap, 'Frame: {}'.format(count_frame), box_scale=.3)
                cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
                cur_face = cv2.resize(cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
                vid_writer_face.write(cur_face)
                vid_writer_hm.write(heatmap_f)

            count_frame += 1
            if count_face != 0:
                count_face += 1

    cap.release()
    vid_writer_face.release()
    vid_writer_hm.release()

    stat = statistics_plot(frames, probs)
    au_stat = au_statistics_plot(frames, au_intensities_list)
    # Explicit None checks: matplotlib Figures are always truthy, so the
    # original `if not stat` test could never fire.
    if stat is None or au_stat is None:
        return None, None, None, None, None

    return video, path_save_video_face, path_save_video_hm, stat, au_stat
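
# --- Sketch of the temporal buffering used above ---
# The dynamic model consumes a sliding window of 10 per-frame feature vectors.
# This isolated helper mirrors that logic; the feature width D is whatever
# pth_model_static.extract_features emits (not fixed here).
def _demo_sliding_window(new_feature, window, window_size=10):
    if len(window) == 0:
        window = [new_feature] * window_size       # seed with the first frame
    else:
        window = window[1:] + [new_feature]        # drop oldest, append newest
    batch = torch.unsqueeze(torch.from_numpy(np.vstack(window)), 0)  # (1, 10, D)
    return window, batch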

def au_statistics_plot(frames, au_intensities_list):
    # Return None on empty input so the caller's None check can short-circuit.
    if not au_intensities_list:
        return None

    fig, ax = plt.subplots(figsize=(12, 6))
    au_intensities_array = np.array(au_intensities_list)
    for i in range(au_intensities_array.shape[1]):
        ax.plot(frames, au_intensities_array[:, i], label=f'AU{i+1}')
    ax.set_xlabel('Frame')
    ax.set_ylabel('AU Intensity')
    ax.set_title('Action Unit Intensities Over Time')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    return fig

def preprocess_video_and_predict_sleep_quality(video):
    cap = cv2.VideoCapture(video)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = np.round(cap.get(cv2.CAP_PROP_FPS))

    path_save_video_original = 'result_original.mp4'
    path_save_video_face = 'result_face.mp4'
    path_save_video_sleep = 'result_sleep.mp4'

    vid_writer_original = cv2.VideoWriter(path_save_video_original, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
    vid_writer_sleep = cv2.VideoWriter(path_save_video_sleep, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    frames = []
    sleep_quality_scores = []
    eye_bags_images = []

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5) as face_mesh:

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_rgb)

            if results.multi_face_landmarks:
                for fl in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(fl, w, h)
                    cur_face = frame_rgb[startY:endY, startX:endX]

                    sleep_quality_score, eye_bags_image = analyze_sleep_quality(cur_face)
                    sleep_quality_scores.append(sleep_quality_score)
                    eye_bags_images.append(cv2.resize(eye_bags_image, (224, 224)))

                    sleep_quality_viz = create_sleep_quality_visualization(cur_face, sleep_quality_score)

                    cur_face = cv2.resize(cur_face, (224, 224))
                    vid_writer_face.write(cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR))
                    vid_writer_sleep.write(sleep_quality_viz)
            else:
                # Pad with NaN so scores stay aligned with frame indices;
                # otherwise the statistics plot gets mismatched lengths.
                sleep_quality_scores.append(np.nan)

            vid_writer_original.write(frame)
            frames.append(len(frames) + 1)

    cap.release()
    vid_writer_original.release()
    vid_writer_face.release()
    vid_writer_sleep.release()

    sleep_stat = sleep_quality_statistics_plot(frames, sleep_quality_scores)

    if eye_bags_images:
        average_eye_bags_image = np.mean(np.array(eye_bags_images), axis=0).astype(np.uint8)
    else:
        average_eye_bags_image = np.zeros((224, 224, 3), dtype=np.uint8)

    return (path_save_video_original, path_save_video_face, path_save_video_sleep,
            average_eye_bags_image, sleep_stat)

def analyze_sleep_quality(face_image):
    # Placeholder: returns a random score and the resized face crop.
    # Replace with a real sleep-quality model before drawing any conclusions.
    sleep_quality_score = np.random.random()
    eye_bags_image = cv2.resize(face_image, (224, 224))
    return sleep_quality_score, eye_bags_image

def create_sleep_quality_visualization(face_image, sleep_quality_score):
    # Resize to the writer's expected 224x224 before annotating; otherwise
    # cv2.VideoWriter silently drops mismatched frames. The smaller font
    # scale keeps the label inside the 224-pixel width.
    viz = cv2.resize(face_image.copy(), (224, 224))
    cv2.putText(viz, f"Sleep Quality: {sleep_quality_score:.2f}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    return cv2.cvtColor(viz, cv2.COLOR_RGB2BGR)

def sleep_quality_statistics_plot(frames, sleep_quality_scores):
    # Placeholder function - implement your statistics plotting here
    fig, ax = plt.subplots()
    ax.plot(frames, sleep_quality_scores)
    ax.set_xlabel('Frame')
    ax.set_ylabel('Sleep Quality Score')
    ax.set_title('Sleep Quality Over Time')
    return fig
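
# --- Hypothetical end-to-end run (assumed file name; requires the app.* deps) ---
if __name__ == "__main__":
    results = preprocess_video_and_predict("sample_video.mp4")
    if results[0] is None:
        print("Processing failed: no statistics could be produced.")
    else:
        video, face_path, hm_path, stat, au_stat = results
        print(f"Face video: {face_path}, heatmap video: {hm_path}")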