import os
import tempfile

import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as nnf
import torchvision.io as io
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import functional as F


def convert_to_rgb(frame):
    """Return ``frame`` as a 3-channel image.

    Args:
        frame: Image array, either 2-D (H, W) or 3-D (H, W, C).

    Returns:
        * 2-D input: the single channel replicated three times, (H, W, 3).
        * 4-channel input: the first three channels alpha-composited over a
          white background using the fourth channel, as ``uint8``.
        * anything else: ``frame`` itself, unchanged.
    """
    if frame.ndim == 2:
        # Grayscale frames have no channel axis; indexing shape[2] would raise.
        return np.stack([frame, frame, frame], axis=-1)
    if frame.shape[2] == 4:  # RGBA
        # Blend onto white: alpha is scaled by 255, so 8-bit data is assumed.
        alpha = frame[:, :, 3:4] / 255.0
        rgb = frame[:, :, :3]
        return (rgb * alpha + (1 - alpha) * 255).astype(np.uint8)
    return frame


def process_frames_batch(frames, target_size, device):
    """Stack frames into one batch on ``device`` and resize them together.

    Args:
        frames: Non-empty sequence of tensors with identical shape
            (required by ``torch.stack``); callers in this file pass
            (C, H, W) float tensors.
        target_size: ``(height, width)`` passed to ``interpolate``.
        device: Torch device the stacked batch is moved to.

    Returns:
        Tensor of shape (N, C, target_height, target_width), bilinearly
        interpolated.
    """
    batch = torch.stack(frames).to(device)
    return nnf.interpolate(batch, size=target_size, mode='bilinear', align_corners=False)


def _load_gif(gif_path, device):
    """Read one GIF and return ``(frames, duration_seconds)``.

    Each frame is converted with ``convert_to_rgb``, downscaled to a quarter
    of its width and height, and turned into a (C, H, W) float tensor on
    ``device`` scaled by 1/255.
    """
    frames = []
    with imageio.get_reader(gif_path) as reader:
        for frame in reader:
            frame = convert_to_rgb(frame)
            quarter_size = (frame.shape[1] // 4, frame.shape[0] // 4)  # cv2 wants (w, h)
            frame = cv2.resize(frame, quarter_size, interpolation=cv2.INTER_AREA)
            frames.append(torch.from_numpy(frame).permute(2, 0, 1).float().to(device) / 255.0)

    # Pillow reports the frame duration in milliseconds; default to 100 ms.
    with Image.open(gif_path) as img:
        duration = img.info.get('duration', 100) / 1000.0
    return frames, duration


def _pip_size(orig_h, orig_w, max_h, max_w):
    """Return ``(height, width)`` fitting inside ``max_h`` x ``max_w`` at the original aspect ratio."""
    aspect = orig_w / orig_h
    if max_w / aspect <= max_h:
        # Width is the limiting dimension.
        return int(max_w / aspect), max_w
    return max_h, int(max_h * aspect)


def combine_video(obj_dir, output_path, input_frames=None, displayed_preds=3):
    """Combine several GIFs from ``<obj_dir>/shadow_gif`` into one grid GIF.

    Files ending in ``_tranp.gif`` whose names do not start with ``obs`` are
    sorted by name and the first ``displayed_preds`` are used. They are laid
    out in a grid of at most three columns on a white background. If
    ``input_frames`` is given, a downscaled copy is drawn in the top-right
    corner of every occupied grid cell.

    Args:
        obj_dir: Directory containing the ``shadow_gif`` sub-directory.
        output_path: Path the combined GIF is written to.
        input_frames: Optional sequence of images that are converted with
            ``cv2.COLOR_BGR2RGB`` (so they are treated as BGR arrays).
        displayed_preds: Number of GIFs to show; also determines the grid
            dimensions. Must be at least 1.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``displayed_preds`` is less than 1 or no matching GIF
            files are found.
    """
    print("Starting video combination process...")

    if displayed_preds < 1:
        # Zero would divide by zero in the grid maths; negatives give a bogus grid.
        raise ValueError(f"displayed_preds must be a positive integer, got {displayed_preds}")

    # Set device for GPU acceleration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Collect the prediction GIFs, capped at displayed_preds
    shadow_gif_dir = os.path.join(obj_dir, 'shadow_gif')
    gif_files = sorted(
        f for f in os.listdir(shadow_gif_dir)
        if f.endswith('_tranp.gif') and not f.startswith('obs')
    )[:displayed_preds]
    print(f"Using {len(gif_files)} GIFs for {displayed_preds} predictions")

    if not gif_files:
        raise ValueError("No GIF files found!")

    # Grid dimensions: at most 3 columns, rows rounded up
    grid_cols = min(displayed_preds, 3)
    grid_rows = (displayed_preds + grid_cols - 1) // grid_cols
    print(f"Grid layout: {grid_rows}x{grid_cols}")

    # Load and preprocess all GIFs
    gif_frames = []
    durations = []
    for gif_file in gif_files:
        print(f"Loading {gif_file}...")
        frames, duration = _load_gif(os.path.join(shadow_gif_dir, gif_file), device)
        gif_frames.append(frames)
        durations.append(duration)

    # NOTE(review): the common duration is only logged; the output below is
    # written at a fixed fps=30 regardless. Confirm whether that is intended.
    common_duration = min(durations)
    print(f"Common duration: {common_duration}")

    # Prepare the picture-in-picture source frames if provided
    pip_source = None
    if input_frames is not None:
        pip_source = []
        for frame in input_frames:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            eighth_size = (frame.shape[1] // 8, frame.shape[0] // 8)  # cv2 wants (w, h)
            frame = cv2.resize(frame, eighth_size, interpolation=cv2.INTER_NEAREST)
            pip_source.append(torch.from_numpy(frame).permute(2, 0, 1).float().to(device) / 255.0)

    # Every grid cell takes the size of the first GIF's first frame (C, H, W)
    first_frame = gif_frames[0][0]
    target_height = first_frame.shape[1]
    target_width = first_frame.shape[2]

    num_frames = max(len(frames) for frames in gif_frames)
    grid_frames = []

    batch_size = 4  # Adjust based on GPU memory
    for frame_idx in range(0, num_frames, batch_size):
        batch_end = min(frame_idx + batch_size, num_frames)

        # White canvas for this batch of output frames
        grid = torch.ones(
            (batch_end - frame_idx, 3, target_height * grid_rows, target_width * grid_cols),
            device=device,
        )

        # Paste each GIF's frames into its cell; GIFs that have run out of
        # frames leave their cell white.
        for i, frames in enumerate(gif_frames):
            batch_frames = frames[frame_idx:batch_end]
            if not batch_frames:
                continue
            row, col = divmod(i, grid_cols)
            resized = process_frames_batch(batch_frames, (target_height, target_width), device)
            top = row * target_height
            left = col * target_width
            grid[:resized.shape[0], :, top:top + target_height, left:left + target_width] = resized

        # Overlay the input frames in the top-right corner of each occupied cell
        if pip_source is not None:
            batch_input_frames = pip_source[frame_idx:batch_end]
            if batch_input_frames:
                orig_h, orig_w = batch_input_frames[0].shape[1:3]  # (C, H, W)
                pip_h, pip_w = _pip_size(orig_h, orig_w, target_height // 2, target_width // 2)
                # Resize once per batch; the result is identical for every cell.
                pip = process_frames_batch(batch_input_frames, (pip_h, pip_w), device)
                count = pip.shape[0]
                for i in range(len(gif_frames)):
                    row, col = divmod(i, grid_cols)
                    x_pos = col * target_width + target_width - pip_w - 10
                    y_pos = row * target_height + 10
                    grid[:count, :, y_pos:y_pos + pip_h, x_pos:x_pos + pip_w] = pip

        # Iterating a tensor yields its slices along the batch dimension
        grid_frames.extend(grid)

    # Convert frames to numpy (H, W, C) uint8 and save as GIF
    print(f"Saving to {output_path}")
    frames_np = [
        (frame.cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        for frame in grid_frames
    ]

    imageio.mimsave(output_path, frames_np, fps=30, optimize=True, quantizer=0, loop=0)

    print("Video combination completed!")
    return output_path


if __name__ == "__main__":
    # The temp file only reserves a unique path; close the handle before
    # imageio reopens the path for writing.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
        tmp_output_path = tmp_file.name
    combine_video("./9622_GRAB/", tmp_output_path)