Spaces:
Running
Running
| """Record3D visualizer | |
| Parse and stream record3d captures. To get the demo data, see `./assets/download_record3d_dance.sh`. | |
| """ | |
| import time | |
| from pathlib import Path | |
| import numpy as onp | |
| import tyro | |
| import cv2 | |
| from tqdm.auto import tqdm | |
| import viser | |
| import viser.extras | |
| import viser.transforms as tf | |
| from glob import glob | |
| import numpy as np | |
| import imageio.v3 as iio | |
| import matplotlib.pyplot as plt | |
| import psutil | |
def log_memory_usage(message=""):
    """Print the resident set size (RSS) of the current process in megabytes.

    Args:
        message: Optional label inserted after "Memory usage" in the printed line.
    """
    bytes_per_mb = 1024 * 1024
    rss_bytes = psutil.Process().memory_info().rss
    print(f"Memory usage {message}: {rss_bytes / bytes_per_mb:.2f} MB")
def _sorted_part_files(traj_dir, prefix, limit=None):
    """Return up to `limit` files `<traj_dir>/<prefix>_p<N>.npy`, ordered by the integer N."""
    paths = sorted(
        glob(f"{traj_dir}/{prefix}_p*.npy"),
        key=lambda p: int(p.split('_p')[-1].split('.')[0]),
    )
    # Slicing with limit=None keeps the whole list.
    return paths[:limit]


def _load_stack(paths, dtype=None):
    """Load every .npy file in `paths` and stack the arrays along a new leading axis."""
    if dtype is None:
        return onp.stack([onp.load(p) for p in paths], axis=0)
    # Convert file by file so only one full-precision array is alive at a time.
    return onp.stack([onp.load(p).astype(dtype) for p in paths], axis=0)


def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train', conf_thre_percentile=10):
    """Load per-frame point maps, confidence masks and segmentation masks from disk.

    Files are looked up in `traj_path` under the names `pts3d1_p<N>.npy`,
    `conf1_p<N>.npy`, `pts3d2_p<N>.npy` and `conf2_p<N>.npy` and ordered by N.

    Args:
        traj_path: Directory containing the trajectory `.npy` files. If it is
            not a directory, only the masks are loaded and 'loaded' stays False.
        use_float16: Whether to convert the point maps to float16 to save memory
            (confidence maps are always converted to float16).
        max_frames: Maximum number of frames to load per file family (None for all).
        mask_folder: Directory containing `*.jpg` mask images, or None for no masks.
        conf_thre_percentile: Percentile of the confidence values used as the
            keep-threshold. For head1 it is taken over the frame-averaged
            confidence; for head2 it is taken separately for every frame.

    Returns:
        A dictionary with the keys 'traj_3d_head1', 'traj_3d_head2' (arrays
        reshaped to (frames, -1, 6)), 'conf_mask_head1', 'conf_mask_head2'
        (boolean, (frames, -1)), 'masks' (boolean), 'raw_video' (channels 3:6 of
        the head2 array before flattening) and 'loaded'. Entries whose files
        were not found stay None.
    """
    log_memory_usage("before loading data")

    data_cache = {
        'traj_3d_head1': None,
        'traj_3d_head2': None,
        'conf_mask_head1': None,
        'conf_mask_head2': None,
        'masks': None,
        'raw_video': None,
        'loaded': False
    }

    # Load masks. A mask_folder of None is treated like an empty folder instead
    # of crashing on `None + '/*.jpg'`.
    masks_paths = []
    if mask_folder is not None:
        masks_paths = sorted(glob(f"{mask_folder}/*.jpg"))

    masks = None
    if masks_paths:
        masks = np.stack([iio.imread(p) for p in masks_paths], axis=0)
        # Binarize: a pixel is True only when at least three of its channels are
        # below 1, i.e. for an RGB image when the pixel is fully black.
        masks = (masks < 1).astype(np.float32)
        masks = masks.sum(axis=-1) > 2
        print(f"Original masks shape: {masks.shape}")
    else:
        print("No masks found. Will create default masks when needed.")
    data_cache['masks'] = masks

    if Path(traj_path).is_dir():
        # Find all trajectory files, limited to max_frames when it is given.
        traj_3d_paths_head1 = _sorted_part_files(traj_path, 'pts3d1', max_frames)
        conf_paths_head1 = _sorted_part_files(traj_path, 'conf1', max_frames)
        traj_3d_paths_head2 = _sorted_part_files(traj_path, 'pts3d2', max_frames)
        conf_paths_head2 = _sorted_part_files(traj_path, 'conf2', max_frames)

        traj_dtype = onp.float16 if use_float16 else None

        # Process head1
        if traj_3d_paths_head1:
            traj_3d_head1 = _load_stack(traj_3d_paths_head1, traj_dtype)
            log_memory_usage("after loading head1 data")

            num_frames, h, w = traj_3d_head1.shape[:3]

            if masks is None:
                # No mask images: every pixel of every frame counts as masked-in.
                masks = np.ones((num_frames, h, w), dtype=bool)
                print(f"Created default masks with shape: {masks.shape}")
                data_cache['masks'] = masks
            else:
                # Resize masks to the point-map resolution; nearest-neighbour
                # interpolation keeps them strictly binary.
                masks_resized = np.zeros((masks.shape[0], h, w), dtype=bool)
                for i, mask in enumerate(masks):
                    masks_resized[i] = cv2.resize(
                        mask.astype(np.uint8),
                        (w, h),
                        interpolation=cv2.INTER_NEAREST
                    ).astype(bool)
                print(f"Resized masks shape: {masks_resized.shape}")
                data_cache['masks'] = masks_resized

            # Flatten the spatial dimensions: (frames, h, w, 6) -> (frames, h*w, 6).
            data_cache['traj_3d_head1'] = traj_3d_head1.reshape(num_frames, -1, 6)

            if conf_paths_head1:
                conf_head1 = _load_stack(conf_paths_head1, onp.float16)
                conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1)
                # Head1 uses one confidence map (the mean over frames), repeated
                # for every frame so it can be indexed per frame like head2.
                conf_head1 = np.tile(conf_head1.mean(axis=0), (num_frames, 1))
                # Convert to float32 before calculating percentile to avoid overflow
                conf_thre = np.percentile(conf_head1.astype(np.float32), conf_thre_percentile)
                data_cache['conf_mask_head1'] = conf_head1 > conf_thre

        # Process head2
        if traj_3d_paths_head2:
            traj_3d_head2 = _load_stack(traj_3d_paths_head2, traj_dtype)
            log_memory_usage("after loading head2 data")

            # Channels 3:6 of the un-flattened array serve as the raw video.
            # This is a basic slice, so it shares memory with traj_3d_head2.
            data_cache['raw_video'] = traj_3d_head2[:, :, :, 3:6]  # [num_frames, h, w, 3]
            data_cache['traj_3d_head2'] = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)

            if conf_paths_head2:
                conf_head2 = _load_stack(conf_paths_head2, onp.float16)
                conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
                # One threshold per frame: the requested percentile of that
                # frame's own confidence values.
                conf_thre = np.percentile(conf_head2.astype(np.float32), conf_thre_percentile, axis=1)
                data_cache['conf_mask_head2'] = conf_head2 > conf_thre[:, None]

        data_cache['loaded'] = True

    log_memory_usage("after loading all data")
    return data_cache
def visualize_st4rtrack(
    traj_path: str = "results",
    up_dir: str = "-z",  # should be +z or -z
    max_frames: int = 100,
    share: bool = False,
    point_size: float = 0.005,
    downsample_factor: int = 3,
    num_traj_points: int = 100,
    conf_thre_percentile: float = 1,
    traj_end_frame: int = 100,
    traj_start_frame: int = 0,
    traj_line_width: float = 3.,
    fixed_length_traj: int = 20,
    server: viser.ViserServer = None,
    use_float16: bool = True,
    preloaded_data: dict = None,  # Add this parameter to accept preloaded data
    color_code: str = "jet",
    # Tint colors blended into the two point clouds.
    blue_rgb: tuple[float, float, float] = (0.0, 0.149, 0.463),  # #002676
    red_rgb: tuple[float, float, float] = (0.769, 0.510, 0.055),  # gold, #C4820E
    blend_ratio: float = 0.7,
    mask_folder: str = None,
    mid_anchor: bool = False,
    video_width: int = 320,  # Video display width
    video_height: int = 180,  # Video display height
    camera_position: tuple[float, float, float] = (1e-3, 1.5, -0.2),
) -> None:
    """Serve an interactive viser viewer for the two point-cloud heads and head1 trajectories.

    The function builds the GUI, uploads one point cloud per frame and head
    ("Tracking Points" = head1, "Recon Points" = head2), draws trajectory line
    segments for a random subset of head1 points and then enters an endless
    playback loop; it does not return.

    Note: the loaded arrays are centered and color-normalized in place, so
    arrays passed in through `preloaded_data` are modified.

    Args:
        traj_path: Directory with the trajectory files (used when nothing is preloaded).
        up_dir: Scene up direction passed to viser, e.g. "+z" or "-z".
        max_frames: Maximum number of frames to show (None for all).
        share: Whether to request a viser share URL.
        point_size: Point size of the point clouds.
        downsample_factor: Keep every n-th point of each point cloud.
        num_traj_points: Maximum number of head1 points that get a trajectory.
        conf_thre_percentile: Confidence percentile forwarded to the loader.
        traj_end_frame: Exclusive end frame of the trajectory window.
        traj_start_frame: First frame of the trajectory window.
        traj_line_width: Line width of the trajectory segments.
        fixed_length_traj: Maximum number of frames a trajectory tail spans.
        server: Existing viser server to draw into; a new one is created if None.
        use_float16: Forwarded to the loader.
        preloaded_data: Dictionary as returned by `load_trajectory_data`; used
            instead of loading from disk when its 'loaded' entry is true.
        color_code: Name of the matplotlib colormap used for trajectory colors.
        blue_rgb: Tint (0-1 RGB) blended into head1 colors.
        red_rgb: Tint (0-1 RGB) blended into head2 colors.
        blend_ratio: Weight of the original color when tinting.
        mask_folder: Mask image folder forwarded to the loader. When it is set,
            trajectory points are not filtered by trajectory length.
        mid_anchor: Pick trajectory points from the middle frame's mask instead
            of the first frame's.
        video_width: Width of the video preview image.
        video_height: Height of the video preview image.
        camera_position: Initial camera position for every connecting client.

    Raises:
        ValueError: If no trajectory data is available or no frame would be shown.
    """
    log_memory_usage("at start of visualization")

    if server is None:
        server = viser.ViserServer()
    if share:
        server.request_share_url()

    @server.on_client_connect
    def _on_client_connect(client: viser.ClientHandle) -> None:
        client.camera.position = camera_position
        client.camera.look_at = (0, 0, 0)

    # Configure the GUI panel size and layout
    server.gui.configure_theme(
        control_layout="collapsible",
        control_width="small",
        dark_mode=False,
        show_logo=False,
        show_share_button=True,
    )

    # Video preview, placed at the top of the GUI panel (starts as a blank image).
    video_preview = server.gui.add_image(
        np.zeros((video_height, video_width, 3), dtype=np.uint8),
        format="jpeg",
    )

    # Use preloaded data if available, otherwise load from files.
    if preloaded_data and preloaded_data.get('loaded', False):
        data = preloaded_data
        print("Using preloaded data!")
    else:
        print("No preloaded data available, loading from files...")
        data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder, conf_thre_percentile)
    traj_3d_head1 = data.get('traj_3d_head1')
    traj_3d_head2 = data.get('traj_3d_head2')
    conf_mask_head1 = data.get('conf_mask_head1')
    conf_mask_head2 = data.get('conf_mask_head2')
    masks = data.get('masks')
    raw_video = data.get('raw_video')

    def process_video_frame(frame_idx):
        """Return frame `frame_idx` centered on a black (video_height, video_width, 3) uint8 canvas."""
        canvas = np.zeros((video_height, video_width, 3), dtype=np.uint8)
        if raw_video is None:
            return canvas
        raw_frame = raw_video[frame_idx]
        # Bring the value range to 0-255.
        if raw_frame.max() <= 1.0:
            frame = (raw_frame * 255).astype(np.uint8)
        else:
            frame = raw_frame.astype(np.uint8)
        # Fit into the preview window while keeping the aspect ratio.
        h, w = frame.shape[:2]
        if h / w > video_height / video_width:  # Height limited
            new_h = video_height
            new_w = max(1, int(w * (new_h / h)))
        else:  # Width limited
            new_w = video_width
            new_h = max(1, int(h * (new_w / w)))
        resized = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
        y_offset = (video_height - new_h) // 2
        x_offset = (video_width - new_w) // 2
        canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized
        return canvas

    last_preview = -1

    def _show_preview(frame_idx) -> None:
        """Send frame `frame_idx` to the preview unless it is already displayed."""
        nonlocal last_preview
        if raw_video is None or frame_idx == last_preview:
            return
        video_preview.image = process_video_frame(frame_idx)
        last_preview = frame_idx

    server.scene.set_up_direction(up_dir)
    print("Setting up visualization!")

    # Add visualization controls
    with server.gui.add_folder("Visualization"):
        gui_show_head1 = server.gui.add_checkbox("Tracking Points", True)
        gui_show_head2 = server.gui.add_checkbox("Recon Points", True)
        gui_show_trajectories = server.gui.add_checkbox("Trajectories", True)
        gui_use_color_tint = server.gui.add_checkbox("Use Color Tint", True)

    def _split_and_center(traj, center):
        """Split (frames, points, 6) into xyz/rgb views; center xyz and scale rgb to 0-1 in place."""
        xyz = traj[:, :, :3]
        rgb = traj[:, :, 3:6]
        if center is None:
            center = onp.mean(xyz, axis=(0, 1), keepdims=True)
        xyz -= center
        # A channel sum above 125 can only occur for 0-255 colors.
        if rgb.sum(axis=-1).max() > 125:
            rgb /= 255.0
        return xyz, rgb, center

    # Process and center point clouds; both heads share the first available center.
    center_point = None
    xyz_head1 = rgb_head1 = xyz_head2 = rgb_head2 = None
    if traj_3d_head1 is not None:
        xyz_head1, rgb_head1, center_point = _split_and_center(traj_3d_head1, center_point)
    if traj_3d_head2 is not None:
        xyz_head2, rgb_head2, center_point = _split_and_center(traj_3d_head2, center_point)

    # Determine number of frames. Taking the shorter head keeps every frame
    # index valid for both heads.
    frame_counts = [t.shape[0] for t in (traj_3d_head1, traj_3d_head2) if t is not None]
    if not frame_counts:
        raise ValueError(f"No trajectory data available (traj_path={traj_path!r})")
    num_frames = min(frame_counts)
    if max_frames is not None:
        num_frames = min(max_frames, num_frames)
    if num_frames < 1:
        raise ValueError("No frames to visualize")
    traj_end_frame = min(traj_end_frame, num_frames)
    print(f"Number of frames: {num_frames}")

    if xyz_head1 is not None:
        xyz_head1 = xyz_head1[:num_frames]
        rgb_head1 = rgb_head1[:num_frames]
    if xyz_head2 is not None:
        xyz_head2 = xyz_head2[:num_frames]
        rgb_head2 = rgb_head2[:num_frames]

    # Add playback UI.
    with server.gui.add_folder("Playback"):
        gui_timestep = server.gui.add_slider(
            "Timestep",
            min=0,
            max=num_frames - 1,
            step=1,
            initial_value=0,
            disabled=True,
        )
        gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
        gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
        gui_playing = server.gui.add_checkbox("Playing", True)
        gui_framerate = server.gui.add_slider(
            "FPS", min=1, max=60, step=0.1, initial_value=20
        )
        gui_framerate_options = server.gui.add_button_group(
            "FPS options", ("10", "20", "30")
        )
        gui_show_all_frames = server.gui.add_checkbox("Show all frames", False)
        gui_stride = server.gui.add_slider(
            "Stride",
            min=1,
            max=num_frames,
            step=1,
            initial_value=min(5, num_frames),  # must not exceed the slider maximum
            disabled=True,  # Only enabled while "Show all frames" is checked
        )

    # One scene frame node per timestep and head; filled further below.
    frame_nodes_head1: list[viser.FrameHandle] = []
    frame_nodes_head2: list[viser.FrameHandle] = []
    head_layers = (
        (frame_nodes_head1, gui_show_head1),
        (frame_nodes_head2, gui_show_head2),
    )

    def _refresh_visibility() -> None:
        """Show exactly the frame nodes selected by the head, show-all, stride and timestep controls."""
        show_all = gui_show_all_frames.value
        stride = gui_stride.value
        current = gui_timestep.value
        with server.atomic():
            # Each head is handled on its own so a missing head does not
            # prevent the other one from being updated.
            for nodes, toggle in head_layers:
                head_shown = toggle.value
                for i, node in enumerate(nodes):
                    active = (i % stride == 0) if show_all else (i == current)
                    node.visible = bool(head_shown and active)

    def _sync_playback_controls() -> None:
        """Disable manual frame controls while playing or while all frames are shown."""
        show_all = gui_show_all_frames.value
        locked = gui_playing.value or show_all
        gui_playing.disabled = show_all
        gui_timestep.disabled = locked
        gui_next_frame.disabled = locked
        gui_prev_frame.disabled = locked

    # Frame step buttons.
    @gui_next_frame.on_click
    def _on_next_frame(_) -> None:
        gui_timestep.value = (gui_timestep.value + 1) % num_frames

    @gui_prev_frame.on_click
    def _on_prev_frame(_) -> None:
        gui_timestep.value = (gui_timestep.value - 1) % num_frames

    # Disable frame controls when we're playing.
    @gui_playing.on_update
    def _on_playing_changed(_) -> None:
        _sync_playback_controls()

    # Set the framerate when we click one of the options.
    @gui_framerate_options.on_click
    def _on_fps_option(_) -> None:
        gui_framerate.value = int(gui_framerate_options.value)

    prev_timestep = gui_timestep.value

    # Toggle frame visibility and refresh the preview when the timestep changes.
    @gui_timestep.on_update
    def _on_timestep_changed(_) -> None:
        nonlocal prev_timestep
        current_timestep = gui_timestep.value
        if not gui_show_all_frames.value:
            with server.atomic():
                for nodes, toggle in head_layers:
                    if nodes and toggle.value:
                        # Hide first, then show: if both indices are equal the
                        # frame must end up visible.
                        nodes[prev_timestep].visible = False
                        nodes[current_timestep].visible = True
        prev_timestep = current_timestep
        _show_preview(current_timestep)
        server.flush()  # Optional!

    # Show or hide all frames based on the checkbox.
    @gui_show_all_frames.on_update
    def _on_show_all_changed(_) -> None:
        gui_stride.disabled = not gui_show_all_frames.value
        _refresh_visibility()
        _sync_playback_controls()

    # Update frame visibility when the stride changes.
    @gui_stride.on_update
    def _on_stride_changed(_) -> None:
        if gui_show_all_frames.value:
            _refresh_visibility()

    # Update visibility when a head checkbox changes, respecting the current
    # timestep / stride instead of revealing every frame.
    @gui_show_head1.on_update
    def _on_show_head1_changed(_) -> None:
        _refresh_visibility()

    @gui_show_head2.on_update
    def _on_show_head2_changed(_) -> None:
        _refresh_visibility()

    # Load in frames.
    server.scene.add_frame(
        "/frames",
        wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz,
        position=(0, 0, 0),
        show_axes=False,
    )

    # (scene name, xyz, rgb, confidence mask, tint, frame node list) per available head.
    clouds = []
    if traj_3d_head1 is not None:
        clouds.append(("head1", xyz_head1, rgb_head1, conf_mask_head1, blue_rgb, frame_nodes_head1))
    if traj_3d_head2 is not None:
        clouds.append(("head2", xyz_head2, rgb_head2, conf_mask_head2, red_rgb, frame_nodes_head2))

    def _add_point_cloud(i, head, xyz, rgb, conf_mask, tint_rgb) -> None:
        """Add (or overwrite) the point cloud of frame `i` for `head`, tinted if the checkbox is set."""
        points = xyz[i]
        colors = rgb[i]
        if conf_mask is not None:
            keep = conf_mask[i]
            points = points[keep]
            colors = colors[keep]
        colors = colors.copy()  # never tint the stored colors themselves
        if gui_use_color_tint.value:
            colors *= blend_ratio
            for channel, tint in enumerate(tint_rgb):
                colors[:, channel] = onp.clip(colors[:, channel] + tint * (1 - blend_ratio), 0, 1)
        server.scene.add_point_cloud(
            name=f"/frames/t{i}/{head}/point_cloud",
            points=points[::downsample_factor],
            colors=colors[::downsample_factor],
            point_size=point_size,
            point_shape="rounded",
        )

    # Create a frame node and a point cloud for each timestep and head.
    for i in tqdm(range(num_frames)):
        for head, xyz, rgb, conf_mask, tint_rgb, nodes in clouds:
            nodes.append(server.scene.add_frame(f"/frames/t{i}/{head}", show_axes=False))
            _add_point_cloud(i, head, xyz, rgb, conf_mask, tint_rgb)

    # Initial visibility
    _refresh_visibility()

    def _mask_at(frame_idx):
        """Return the flattened mask of `frame_idx`, or all-True when no masks are available."""
        if masks is None:
            return np.ones(xyz_head1.shape[1], dtype=bool)
        return masks[frame_idx].reshape(-1)

    def _add_trajectories(anchor_mask, filter_by_length, visible) -> None:
        """Draw trajectory tails for a random subset of head1 points selected by `anchor_mask`.

        With `filter_by_length`, only points whose path length over the
        trajectory window is at least the mean of the masked points are eligible.
        """
        window = xyz_head1[traj_start_frame:traj_end_frame]  # (frames, points, 3)
        step_lengths = np.sqrt(np.sum(np.diff(window, axis=0) ** 2, axis=-1))
        traj_lengths = np.sum(step_lengths, axis=0)  # path length per point

        candidates = np.where(anchor_mask)[0]
        if filter_by_length and len(candidates) > 0:
            candidate_lengths = traj_lengths[candidates]
            candidates = candidates[candidate_lengths >= np.mean(candidate_lengths)]
        if len(candidates) == 0:
            return

        # Random sampling without replacement, kept in original index order.
        points_to_visualize = min(xyz_head1.shape[1], num_traj_points)
        picked = np.random.choice(
            len(candidates),
            min(points_to_visualize, len(candidates)),
            replace=False,
        )
        point_ids = candidates[np.sort(picked)]

        tracks = window[:, point_ids]
        # Any matplotlib colormap name works; an unknown name raises ValueError.
        point_colors = plt.get_cmap(color_code)(np.linspace(0, 1, len(point_ids)))[:, :3]

        for i in range(tracks.shape[0]):
            # The tail covers at most fixed_length_traj frames and needs at
            # least two of them to form a segment.
            num_segments = min(fixed_length_traj, i + 1) - 1
            if num_segments < 1:
                continue
            tail = tracks[i - num_segments:i + 1]
            # Pair consecutive positions: (segments, points, 2, 3) -> (-1, 2, 3).
            line_points = np.stack([tail[:-1], tail[1:]], axis=2).reshape(-1, 2, 3)
            line_colors = np.tile(point_colors, (num_segments, 1))
            line_colors = np.stack([line_colors, line_colors], axis=1)
            server.scene.add_line_segments(
                name=f"/frames/t{i + traj_start_frame}/head1/trajectory",
                points=line_points,
                colors=line_colors,
                line_width=traj_line_width,
                visible=visible,
            )

    # Process and visualize trajectories for head1
    if traj_3d_head1 is not None:
        anchor_frame = num_frames // 2 if mid_anchor else 0
        _add_trajectories(
            _mask_at(anchor_frame),
            filter_by_length=mask_folder is None,
            visible=gui_show_trajectories.value,
        )

    # Rebuild the trajectories when the "Trajectories" checkbox changes.
    @gui_show_trajectories.on_update
    def _on_trajectories_toggled(_) -> None:
        with server.atomic():
            # Remove all existing trajectories; not every frame has one.
            for i in range(num_frames):
                try:
                    server.scene.remove_by_name(f"/frames/t{i}/head1/trajectory")
                except KeyError:
                    pass
            if gui_show_trajectories.value and traj_3d_head1 is not None:
                # NOTE(review): the rebuild anchors on the mask of the last
                # trajectory frame and always filters by length, unlike the
                # initial build above (first/middle frame, filter only without
                # mask_folder) — confirm whether this difference is intended.
                _add_trajectories(
                    _mask_at(traj_end_frame - 1),
                    filter_by_length=True,
                    visible=True,
                )

    # Re-upload every point cloud when the tint checkbox changes.
    @gui_use_color_tint.on_update
    def _on_tint_toggled(_) -> None:
        with server.atomic():
            for i in range(num_frames):
                for head, xyz, rgb, conf_mask, tint_rgb, _nodes in clouds:
                    server.scene.remove_by_name(f"/frames/t{i}/{head}/point_cloud")
                    _add_point_cloud(i, head, xyz, rgb, conf_mask, tint_rgb)

    # Initialize video preview
    _show_preview(0)

    # Playback update loop (runs until the process is stopped).
    log_memory_usage("before starting playback loop")
    while True:
        if gui_playing.value and not gui_show_all_frames.value:
            gui_timestep.value = (gui_timestep.value + 1) % num_frames
        # No-op when the timestep callback has already shown this frame.
        _show_preview(gui_timestep.value)
        time.sleep(1.0 / gui_framerate.value)
# Script entry point: tyro.cli builds the command-line interface from the
# signature of visualize_st4rtrack and calls it with the parsed arguments.
if __name__ == "__main__":
    tyro.cli(visualize_st4rtrack)