Spaces:
Runtime error
Runtime error
| import torch | |
| import sys | |
| import os.path as osp | |
| import os | |
| import argparse | |
| import cv2 | |
| import time | |
| import h5py | |
| from tqdm import tqdm | |
| import numpy as np | |
| import warnings | |
| import signal | |
# Silence every warning process-wide (all categories, all modules), including
# deprecation notices raised by third-party libraries.
warnings.simplefilter('ignore')
def signal_handler(sig, frame):
    """Handle SIGINT (Ctrl-C): do best-effort cleanup, then exit with status 0.

    Args:
        sig: Signal number delivered by the interpreter (unused).
        frame: Stack frame active when the signal arrived (unused).

    Raises:
        SystemExit: always, with exit code 0.
    """
    print("\nInterrupted by user, shutting down...")
    # `loader_thread` is not assigned anywhere in this module as shown; only
    # join it if some other code has published it as a module global.
    worker = globals().get('loader_thread')
    if worker is not None and worker.is_alive():
        worker.join(timeout=1.0)  # Give the thread 1 second to finish
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Free GPU memory immediately
    # There is no `os.exit`; `sys.exit` raises SystemExit in the main thread so
    # `finally` blocks and atexit hooks still run.
    sys.exit(0)
# Route Ctrl-C (SIGINT) through signal_handler as soon as the module is imported.
signal.signal(signal.SIGINT, signal_handler)
# Put this script's own directory first on sys.path so the project-local
# packages imported next (tools, model, common) resolve; index 0 is popped
# again after the lib.pose import further down.
sys.path.insert(0, osp.split(osp.realpath(__file__))[0])
| from tools.utils import get_path | |
| from model.gast_net import SpatioTemporalModel, SpatioTemporalModelOptimized1f | |
| from common.skeleton import Skeleton | |
| from common.graph_utils import adj_mx_from_skeleton | |
| from common.generators import * | |
| from tools.preprocess import load_kpts_json, h36m_coco_format, revise_kpts, revise_skes | |
| from tools.inference import gen_pose | |
| from tools.vis_kpts import plot_keypoint | |
# Resolve the project directories from this file's location (get_path is
# project-local; its exact return values are not visible here).
cur_dir, chk_root, data_root, lib_root, output_root = get_path(__file__)
# Plain string concatenation: assumes chk_root already ends with a path
# separator — TODO confirm against get_path.
model_dir = chk_root + 'gastnet/'
# Temporarily put lib_root on sys.path (index 1, right after this script's own
# directory) so the 2D keypoint generator can be imported as `lib.pose`.
sys.path.insert(1, lib_root)
from lib.pose import gen_video_kpts as hrnet_pose
# Undo both insertions: index 1 is lib_root (inserted just above) and index 0
# is this script's directory (inserted near the top of the module).
# NOTE(review): relies on no intervening import having shifted sys.path.
sys.path.pop(1)
sys.path.pop(0)
# Left/right joint indices of the 17-joint skeleton, kept in one place so the
# Skeleton definition and the module-level index lists cannot drift apart.
# Each consumer below gets its own fresh list (no shared mutable object).
_LEFT_SIDE = (4, 5, 6, 11, 12, 13)
_RIGHT_SIDE = (1, 2, 3, 14, 15, 16)

skeleton = Skeleton(
    parents=[-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15],
    joints_left=list(_LEFT_SIDE),
    joints_right=list(_RIGHT_SIDE),
)
# Adjacency matrix derived from the skeleton's parent links.
adj = adj_mx_from_skeleton(skeleton)
joints_left, joints_right = list(_LEFT_SIDE), list(_RIGHT_SIDE)
kps_left, kps_right = list(_LEFT_SIDE), list(_RIGHT_SIDE)
def load_model_layer():
    """Build the 81-frame GAST-Net model and load its pretrained weights.

    Returns:
        The ``SpatioTemporalModel`` in eval mode, placed on the GPU when CUDA is
        available and on the CPU otherwise.
    """
    chk = model_dir + '81_frame_model.bin'
    # 3 * 3 * 3 * 3 = 81, matching the '81_frame' checkpoint name (presumably
    # the model's temporal receptive field).
    filters_width = [3, 3, 3, 3]
    channels = 64
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Positional (17, 2, 17): presumably joints in, coordinates per input joint,
    # joints out — confirm against SpatioTemporalModel's signature.
    model_pos = SpatioTemporalModel(adj, 17, 2, 17, filter_widths=filters_width,
                                    channels=channels, dropout=0.05)
    # map_location lets a checkpoint that was saved with CUDA tensors load on a
    # CPU-only host; without it torch.load raises when CUDA is unavailable.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True; if the checkpoint
    # holds non-tensor objects this call may need weights_only=False (trusted file only).
    checkpoint = torch.load(chk, map_location=device)
    model_pos.load_state_dict(checkpoint['model_pos'])
    model_pos = model_pos.to(device)
    return model_pos.eval()
def _probe_video(video):
    """Return ``(width, height, frame_count)`` of *video*, always releasing the capture."""
    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            # Without this check every property reads as 0 and the failure
            # only surfaces later as a confusing error or garbage poses.
            raise IOError(f"Could not open video: {video}")
        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    return width, height, total_frames


def _report_3d_progress(total):
    """Print simulated ``PROGRESS`` lines rising from 100% to 200% over *total* steps."""
    # range() of a non-positive total is empty, so there is no division by zero.
    for i in range(total):
        progress = 100 + ((i + 1) / total * 100)  # 100-200% for 3D
        print(f"PROGRESS:{progress:.2f}")
        sys.stdout.flush()
        time.sleep(0.01)  # Placeholder; remove if gen_pose has its own loop


def _save_reconstruction(video, prediction):
    """Save *prediction* to ``../outputs/npz/<video stem>.npz`` and return that path."""
    # Resolved against the current working directory, not this file's location.
    output_dir = os.path.abspath('../outputs/')
    print(f"Creating output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    npz_dir = os.path.join(output_dir, 'npz')
    print(f"Creating NPZ directory: {npz_dir}")
    os.makedirs(npz_dir, exist_ok=True)
    # splitext strips only the final extension, so "clip.v2.mp4" -> "clip.v2"
    # (splitting on the first '.' would truncate it to "clip").
    stem = os.path.splitext(os.path.basename(video))[0]
    output_npz = os.path.join(npz_dir, stem + '.npz')
    print(f"Saving NPZ to: {output_npz}")
    np.savez_compressed(output_npz, reconstruction=prediction)
    print(f"NPZ saved successfully: {output_npz}")
    return output_npz


def generate_skeletons(video=''):
    """Run 2D keypoint detection and 3D pose lifting on *video*, then save the result.

    Prints ``PROGRESS:<pct>`` lines on stdout: ``100.00`` when the 3D stage
    starts, then simulated values up to ``200.00`` after ``gen_pose`` returns.

    Args:
        video: Path of the input video.

    Returns:
        Path of the written ``.npz`` file; its ``reconstruction`` array holds
        the output of ``gen_pose``.

    Raises:
        IOError: if OpenCV cannot open *video*.
    """
    def force_exit(sig, frame):
        print("\nForce terminating...")
        os._exit(1)

    # From here on Ctrl-C kills the process immediately with os._exit(1),
    # skipping cleanup handlers.
    try:
        signal.signal(signal.SIGINT, force_exit)
    except ValueError:
        # signal.signal only works in the main thread; when this function is
        # called from a worker thread, keep the existing handler instead of crashing.
        pass

    width, height, total_frames = _probe_video(video)

    # 2D Keypoint Generation (handled by gen_video_kpts)
    print('Generating 2D Keypoints:')
    sys.stdout.flush()
    keypoints, scores = hrnet_pose(video, det_dim=416, gen_output=True)
    keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
    re_kpts = revise_kpts(keypoints, scores, valid_frames)

    model_pos = load_model_layer()
    pad = (81 - 1) // 2  # half of the 81-frame window
    causal_shift = 0

    # 3D Pose Generation
    print('Recording 3D Pose:')
    print("PROGRESS:100.00")  # Start 3D at 100%
    sys.stdout.flush()
    # Explicit None/len checks also work if valid_frames is a NumPy array, where
    # plain truthiness raises ValueError.
    # NOTE(review): if valid_frames holds one entry per person, len() counts
    # people rather than frames — confirm against h36m_coco_format.
    has_valid = valid_frames is not None and len(valid_frames) > 0
    total_valid_frames = len(valid_frames) if has_valid else total_frames
    prediction = gen_pose(re_kpts, valid_frames, width, height, model_pos, pad, causal_shift)
    _report_3d_progress(total_valid_frames)

    return _save_reconstruction(video, prediction)
def arg_parse():
    """Parse the command-line options of the skeleton demo.

    Returns:
        argparse.Namespace with ``video``: the input video, either an absolute
        path or a file name looked up under ``<data_root>/video/``.
    """
    # The first positional argument of ArgumentParser is `prog`, so the text
    # must be passed as `description` to appear as the help summary.
    parser = argparse.ArgumentParser(description='Generating skeleton demo.')
    # Required: the entry point calls os.path.isabs(args.video), which raises
    # TypeError on None; argparse now reports a missing -v cleanly instead.
    parser.add_argument('-v', '--video', type=str, required=True,
                        help='input video: absolute path, or file name under <data_root>/video/')
    return parser.parse_args()
if __name__ == "__main__":
    args = arg_parse()
    # An absolute path is used as given; anything else is resolved under
    # <data_root>/video/.
    video_path = (args.video if os.path.isabs(args.video)
                  else os.path.join(data_root, 'video', args.video))
    generate_skeletons(video=video_path)