import torch
import torch.nn.functional as F
def add_first_frame_conditioning(
latent_model_input,
first_frame,
vae
):
"""
Adds first frame conditioning to a video diffusion model input.
Args:
latent_model_input: Original latent input (bs, channels, num_frames, height, width)
first_frame: Tensor of first frame to condition on (bs, channels, height, width)
vae: VAE model for encoding the conditioning
Returns:
conditioned_latent: The complete conditioned latent input (bs, 36, num_frames, height, width)
"""
device = latent_model_input.device
dtype = latent_model_input.dtype
    # note: "temperal_downsample" is the attribute's actual spelling in diffusers' AutoencoderKLWan
    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)
# Get number of frames from latent model input
_, _, num_latent_frames, _, _ = latent_model_input.shape
# Calculate original number of frames
# For n original frames, there are (n-1)//4 + 1 latent frames
# So to get n: n = (num_latent_frames-1)*4 + 1
num_frames = (num_latent_frames - 1) * 4 + 1
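    # e.g. 21 latent frames -> (21 - 1) * 4 + 1 = 81 pixel frames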
if len(first_frame.shape) == 3:
# we have a single image
first_frame = first_frame.unsqueeze(0)
    # if it doesn't match the batch size, expand it along the batch dimension
if first_frame.shape[0] != latent_model_input.shape[0]:
first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
# resize first frame to match the latent model input
vae_scale_factor = vae.config.scale_factor_spatial
first_frame = F.interpolate(
first_frame,
size=(latent_model_input.shape[3] * vae_scale_factor, latent_model_input.shape[4] * vae_scale_factor),
mode='bilinear',
align_corners=False
)
# Add temporal dimension to first frame
first_frame = first_frame.unsqueeze(2)
# Create video condition with first frame and zeros for remaining frames
zero_frame = torch.zeros_like(first_frame)
video_condition = torch.cat([
first_frame,
*[zero_frame for _ in range(num_frames - 1)]
], dim=2)
    # video_condition is already (bs, channels, num_frames, height, width), the layout the VAE expects
# Encode with VAE
latent_condition = vae.encode(
video_condition.to(device, dtype)
).latent_dist.sample()
latent_condition = latent_condition.to(device, dtype)
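    # Wan-style VAEs expose per-channel latents_mean / latents_std (length z_dim) instead of a single
    # scaling_factor, so the condition latents are normalized channel-wise below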
latents_mean = (
torch.tensor(vae.config.latents_mean)
.view(1, vae.config.z_dim, 1, 1, 1)
.to(device, dtype)
)
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
device, dtype
)
latent_condition = (latent_condition - latents_mean) * latents_std
# Create mask: 1 for conditioning frames, 0 for frames to generate
batch_size = first_frame.shape[0]
latent_height = latent_condition.shape[3]
latent_width = latent_condition.shape[4]
# Initialize mask for all frames
mask_lat_size = torch.ones(
batch_size, 1, num_frames, latent_height, latent_width)
    # Set all non-first frames to 0
    mask_lat_size[:, :, 1:] = 0
    # Repeat the first-frame mask entry so the total mask length becomes a multiple of vae_scale_factor_temporal
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)
# Combine first frame mask with rest
mask_lat_size = torch.concat(
[first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
    # Fold pixel-frame masks into latent frames:
    # (bs, 1, num_latent_frames * scale_t, H, W) -> (bs, scale_t, num_latent_frames, H, W)
mask_lat_size = mask_lat_size.view(
batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(device, dtype)
# Combine conditioning with latent input
first_frame_condition = torch.concat(
[mask_lat_size, latent_condition], dim=1)
conditioned_latent = torch.cat(
[latent_model_input, first_frame_condition], dim=1)
return conditioned_latent
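# Hedged usage sketch (not part of the original file): how this helper might be called with a
# Wan 2.1-style I2V setup from diffusers. The checkpoint id and tensor shapes are illustrative
# assumptions; adapt them to the actual pipeline and resolution.
#
#   from diffusers import AutoencoderKLWan
#   vae = AutoencoderKLWan.from_pretrained(
#       "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", subfolder="vae", torch_dtype=torch.float32
#   )
#   latents = torch.randn(1, 16, 21, 60, 104)   # (bs, z_dim, num_latent_frames, 480/8, 832/8)
#   first_frame = torch.rand(1, 3, 480, 832)    # pixel-space conditioning image
#   model_input = add_first_frame_conditioning(latents, first_frame, vae)
#   # model_input: (1, 36, 21, 60, 104) = 16 latent + 4 mask + 16 condition channels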
def add_first_frame_conditioning_v22(
latent_model_input,
first_frame,
vae,
last_frame=None
):
"""
Overwrites first few time steps in latent_model_input with VAE-encoded first_frame,
and returns the modified latent + binary mask (0=conditioned, 1=noise).
Args:
latent_model_input: torch.Tensor of shape (bs, 48, T, H, W)
first_frame: torch.Tensor of shape (bs, 3, H*scale, W*scale)
vae: VAE model with .encode() and .config.latents_mean/std
Returns:
latent: (bs, 48, T, H, W) - modified input latent
mask: (bs, 1, T, H, W) - binary mask
"""
device = latent_model_input.device
dtype = latent_model_input.dtype
bs, _, T, H, W = latent_model_input.shape
scale = vae.config.scale_factor_spatial
target_h = H * scale
target_w = W * scale
    # Ensure a batch dimension and a matching batch size
if first_frame.ndim == 3:
first_frame = first_frame.unsqueeze(0)
if first_frame.shape[0] != bs:
first_frame = first_frame.expand(bs, -1, -1, -1)
# Resize and encode
first_frame_up = F.interpolate(first_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
first_frame_up = first_frame_up.unsqueeze(2) # (bs, 3, 1, H, W)
    encoded = vae.encode(first_frame_up.to(device, dtype)).latent_dist.sample().to(device, dtype)
# Normalize
mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
encoded = (encoded - mean) * std
# Replace in latent
latent = latent_model_input.clone()
latent[:, :, :encoded.shape[2]] = encoded # typically first frame: [:, :, 0]
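    # (a Wan-style causal VAE maps a single input frame to a single latent frame, so encoded.shape[2] is typically 1)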
# Mask: 0 where conditioned, 1 otherwise
mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype)
mask[:, :, :encoded.shape[2]] = 0.0
    if last_frame is not None:
        # If last_frame is provided, encode it the same way and condition the trailing frames
        if last_frame.ndim == 3:
            last_frame = last_frame.unsqueeze(0)
        if last_frame.shape[0] != bs:
            last_frame = last_frame.expand(bs, -1, -1, -1)
        last_frame_up = F.interpolate(last_frame, size=(target_h, target_w), mode="bilinear", align_corners=False)
        last_frame_up = last_frame_up.unsqueeze(2)  # (bs, 3, 1, target_h, target_w)
        last_encoded = vae.encode(last_frame_up.to(device, dtype)).latent_dist.sample().to(device, dtype)
        last_encoded = (last_encoded - mean) * std
        latent[:, :, -last_encoded.shape[2]:] = last_encoded  # replace the last latent frame(s)
        mask[:, :, -last_encoded.shape[2]:] = 0.0
# Ensure mask is still binary
mask = mask.clamp(0.0, 1.0)
    return latent, mask
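# Hedged usage sketch (not part of the original file): the v22 variant targets a Wan 2.2-style
# TI2V latent with 48 channels, where conditioning frames are written directly into the latent
# instead of being concatenated as extra channels. The checkpoint id and shapes below are
# illustrative assumptions.
#
#   from diffusers import AutoencoderKLWan
#   vae = AutoencoderKLWan.from_pretrained(
#       "Wan-AI/Wan2.2-TI2V-5B-Diffusers", subfolder="vae", torch_dtype=torch.float32
#   )
#   latents = torch.randn(1, 48, 31, 44, 80)    # (bs, z_dim, T, 704/16, 1280/16)
#   first = torch.rand(1, 3, 704, 1280)         # pixel-space first frame
#   last = torch.rand(1, 3, 704, 1280)          # optional pixel-space last frame
#   latent, mask = add_first_frame_conditioning_v22(latents, first, vae, last_frame=last)
#   # mask is 0 on the conditioned first/last latent frames and 1 elsewhere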