import os
import tempfile

import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as nnf
from PIL import Image

def convert_to_rgb(frame):
    """Convert frame to RGB format."""
    if frame.shape[2] == 4:  # RGBA
        # Convert RGBA to RGB using alpha compositing with white background
        alpha = frame[:, :, 3:4] / 255.0
        rgb = frame[:, :, :3]
        return (rgb * alpha + (1 - alpha) * 255).astype(np.uint8)
    return frame
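
# Sanity-check sketch (illustrative values): a fully transparent RGBA pixel should
# composite to pure white, while a fully opaque pixel should keep its color.
#
#   transparent = np.zeros((1, 1, 4), dtype=np.uint8)
#   assert (convert_to_rgb(transparent) == 255).all()
#   opaque = np.array([[[10, 20, 30, 255]]], dtype=np.uint8)
#   assert (convert_to_rgb(opaque) == [10, 20, 30]).all()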

def process_frames_batch(frames, target_size, device):
    """Process a batch of frames efficiently."""
    # Stack frames and move to GPU
    frames = torch.stack(frames).to(device)
    # Batch resize
    frames = nnf.interpolate(frames, size=target_size,
                             mode='bilinear', align_corners=False)
    return frames
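
# Usage sketch (shapes are illustrative): a list of same-sized CHW float tensors
# is stacked and resized to one grid cell in a single batched call.
#
#   cell = [torch.rand(3, 120, 90) for _ in range(4)]
#   out = process_frames_batch(cell, (256, 256), torch.device("cpu"))
#   # out.shape == torch.Size([4, 3, 256, 256])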

def combine_video(obj_dir, output_path, input_frames=None, displayed_preds=3):
    """Combine multiple GIFs into a grid layout using batched torch resizing."""
    print("Starting video combination process...")

    # Set device for GPU acceleration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Get all GIF files from the shadow_gif directory
    shadow_gif_dir = os.path.join(obj_dir, 'shadow_gif')
    gif_files = [f for f in os.listdir(shadow_gif_dir) if f.endswith('_tranp.gif') and not f.startswith('obs')]
    gif_files = sorted(gif_files)

    # Limit number of GIFs based on displayed_preds
    gif_files = gif_files[:displayed_preds]
    print(f"Using {len(gif_files)} GIFs for {displayed_preds} predictions")

    # Calculate grid dimensions
    grid_cols = min(displayed_preds, 3)  # Maximum 3 columns
    grid_rows = (displayed_preds + grid_cols - 1) // grid_cols
    print(f"Grid layout: {grid_rows}x{grid_cols}")

    # Load and process all GIFs
    gif_frames = []
    durations = []
    for gif_file in gif_files:
        gif_path = os.path.join(shadow_gif_dir, gif_file)
        print(f"Loading {gif_file}...")

        # Read GIF frames efficiently
        with imageio.get_reader(gif_path) as reader:
            frames = []
            for frame in reader:
                # Convert to RGB if needed
                frame = convert_to_rgb(frame)
                frame = cv2.resize(frame, (frame.shape[1] // 4, frame.shape[0] // 4), interpolation=cv2.INTER_AREA)
                # Convert to tensor and normalize
                frame = torch.from_numpy(frame).permute(2, 0, 1).float().to(device) / 255.0
                frames.append(frame)

        # Get duration from the first frame
        with Image.open(gif_path) as img:
            duration = img.info.get('duration', 100) / 1000.0  # Convert to seconds

        gif_frames.append(frames)
        durations.append(duration)

    if not gif_frames:
        raise ValueError("No GIF files found!")

    # Get common duration
    common_duration = min(durations)
    print(f"Common duration: {common_duration}")

    # Process input frames if provided
    if input_frames is not None:
        # Convert BGR to RGB and resize input frames
        input_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in input_frames]
        input_frames = [cv2.resize(frame, (frame.shape[1] // 8, frame.shape[0] // 8),
                                   interpolation=cv2.INTER_NEAREST) for frame in input_frames]
        input_frames = [torch.from_numpy(frame).permute(2, 0, 1).float().to(device) / 255.0
                        for frame in input_frames]

    # Calculate target size for each GIF in the grid
    first_frame = gif_frames[0][0]
    target_height = first_frame.shape[1]
    target_width = first_frame.shape[2]
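
    # Size sketch (illustrative numbers): frames are CHW after the 1/4 downscale
    # above, so a 1920x1080 source GIF yields target_width=480 and
    # target_height=270 per grid cell.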

    # Create grid frames
    num_frames = max(len(frames) for frames in gif_frames)
    grid_frames = []

    # Process frames in batches
    batch_size = 4  # Adjust based on GPU memory
    for frame_idx in range(0, num_frames, batch_size):
        batch_end = min(frame_idx + batch_size, num_frames)

        # Create empty grid for the batch
        grid = torch.ones((batch_end - frame_idx, 3, target_height * grid_rows, target_width * grid_cols),
                          device=device)

        # Process each GIF in the batch
        for i, frames in enumerate(gif_frames):
            row = i // grid_cols
            col = i % grid_cols

            # Get frames for this batch
            batch_frames = frames[frame_idx:batch_end]
            if batch_frames:
                # Process frames in batch
                resized_frames = process_frames_batch(batch_frames, (target_height, target_width), device)
                # Add to grid
                for j, frame in enumerate(resized_frames):
                    grid[j, :, row * target_height:(row + 1) * target_height,
                         col * target_width:(col + 1) * target_width] = frame

        # Add input frames if provided
        if input_frames is not None:
            for i in range(len(gif_frames)):
                row = i // grid_cols
                col = i % grid_cols

                # Get input frames for this batch
                batch_input_frames = input_frames[frame_idx:batch_end]
                if batch_input_frames:
                    orig_h, orig_w = batch_input_frames[0].shape[1:3]  # (C, H, W)
                    pip_max_width = target_width // 2
                    pip_max_height = target_height // 2
                    aspect = orig_w / orig_h
                    if pip_max_width / aspect <= pip_max_height:
                        pip_w = pip_max_width
                        pip_h = int(pip_max_width / aspect)
                    else:
                        pip_h = pip_max_height
                        pip_w = int(pip_max_height * aspect)
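                    # Fit sketch (assumed numbers): with a 480x270 cell the overlay
                    # is capped at 240x135; a 4:3 input fails the width-first test
                    # (240 / 1.33 ≈ 180 > 135), so it is scaled to fit the height
                    # instead, ending up at roughly 180x135.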
                    # resize
                    resized_input_frames = process_frames_batch(batch_input_frames, (pip_h, pip_w), device)
                    # Add to grid
                    for j, frame in enumerate(resized_input_frames):
                        x_pos = col * target_width + target_width - frame.shape[2] - 10
                        y_pos = row * target_height + 10
                        grid[j, :, y_pos:y_pos + frame.shape[1], x_pos:x_pos + frame.shape[2]] = frame

        # Add batch to grid_frames
        grid_frames.extend([frame for frame in grid])

    # Convert frames to numpy and save as GIF
    print(f"Saving to {output_path}")
    frames_np = [(frame.cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
                 for frame in grid_frames]

    # Save as GIF with optimization
    imageio.mimsave(output_path, frames_np, fps=30, optimize=True, quantizer=0, loop=0)
    print("Video combination completed!")

    return output_path


if __name__ == "__main__":
    combine_video("./9622_GRAB/", tempfile.NamedTemporaryFile(suffix=".gif", delete=False).name)