import torch
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from einops import rearrange, repeat
import math


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def get_batch(keys, value_dict, N, T, device):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc
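
# Usage sketch (hypothetical values; in this module the key list normally comes from
# get_unique_embedder_keys_from_conditioner(model.conditioner), and `frame` stands
# for a (1, 3, H, W) conditioning tensor):
#
#   keys = ["fps_id", "cond_aug", "cond_frames", "cond_frames_without_noise"]
#   value_dict = {"fps_id": 24, "cond_aug": 0.02,
#                 "cond_frames": frame, "cond_frames_without_noise": frame}
#   batch, batch_uc = get_batch(keys, value_dict, [1, 1], T=14, device="cuda")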


def merge_overlapping_segments(segments: torch.Tensor, overlap: int) -> torch.Tensor:
    """
    Merges overlapping segments by averaging overlapping frames.
    Segments have shape (b, t, ...), where 'b' is the number of segments,
    't' is frames per segment, and '...' are other dimensions.

    Args:
        segments: Tensor of shape (b, t, ...)
        overlap: Integer, number of frames that overlap between consecutive segments

    Returns:
        Tensor of the merged video
    """
    # Get the shape details
    b, t, *other_dims = segments.shape
    # Calculate the total number of frames in the merged video
    num_frames = (b - 1) * (t - overlap) + t

    # Initialize the output tensor and a count tensor to keep track of contributions for averaging
    output_shape = [num_frames] + other_dims
    output = torch.zeros(output_shape, dtype=segments.dtype, device=segments.device)
    count = torch.zeros(output_shape, dtype=torch.float32, device=segments.device)

    current_index = 0
    for i in range(b):
        end_index = current_index + t
        # Add the segment to the output tensor
        output[current_index:end_index] += segments[i]
        # Increment the count tensor for each frame that's added
        count[current_index:end_index] += 1
        # Update the starting index for the next segment
        current_index += t - overlap

    # Avoid division by zero
    count[count == 0] = 1
    # Average the frames where there's overlap
    output /= count
    return output
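
# A minimal, self-contained check of the merging logic (assumed shapes only):
# two 4-frame segments with overlap=1 merge into (2 - 1) * (4 - 1) + 4 = 7 frames,
# and the single shared frame is the average of its two contributions.
#
#   segs = torch.randn(2, 4, 3, 64, 64)
#   merged = merge_overlapping_segments(segs, overlap=1)
#   assert merged.shape == (7, 3, 64, 64)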


def get_batch_overlap(
    keys: List[str],
    value_dict: Dict[str, Any],
    N: Tuple[int, ...],
    T: Optional[int],
    device: str,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Create a batch dictionary with overlapping frames for model input.

    Args:
        keys: List of keys to include in the batch
        value_dict: Dictionary containing values for each key
        N: Batch dimensions
        T: Number of frames (optional)
        device: Device to place tensors on

    Returns:
        Tuple of (batch dictionary, unconditional batch dictionary)
    """
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "b ... -> (b t) ...", t=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "b ... -> (b t) ...", t=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc
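
# Usage sketch (hypothetical values): the call mirrors `get_batch`, but the
# `cond_frames` / `cond_frames_without_noise` entries are tiled per frame with
# "b ... -> (b t) ..." so several overlapping segments can be conditioned at once.
#
#   batch, batch_uc = get_batch_overlap(keys, value_dict, [1, num_frames],
#                                       T=num_frames, device="cuda")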


def sample_keyframes(
    model_keyframes: Any,
    audio_list: torch.Tensor,
    gt_list: torch.Tensor,
    masks_list: torch.Tensor,
    condition: torch.Tensor,
    num_frames: int,
    fps_id: int,
    cond_aug: float,
    device: str,
    embbedings: Optional[torch.Tensor],
    force_uc_zero_embeddings: List[str],
    n_batch_keyframes: int,
    added_frames: int,
    strength: float,
    scale: Optional[Union[float, List[float]]],
    gt_as_cond: bool = False,
) -> torch.Tensor:
    """
    Sample keyframes using the keyframe generation model.

    Args:
        model_keyframes: The keyframe generation model
        audio_list: List of audio embeddings
        gt_list: List of ground truth frames
        masks_list: List of masks
        condition: Conditioning tensor
        num_frames: Number of frames to generate
        fps_id: FPS ID
        cond_aug: Conditioning augmentation factor
        device: Device to use for computation
        embbedings: Optional embeddings
        force_uc_zero_embeddings: List of embeddings to force to zero in the unconditional case
        n_batch_keyframes: Batch size for keyframe generation
        added_frames: Number of additional frames
        strength: Strength parameter for sampling
        scale: Scale parameter for guidance
        gt_as_cond: Whether to use ground truth as conditioning

    Returns:
        Latent keyframe samples
    """
    if scale is not None:
        model_keyframes.sampler.guider.set_scale(scale)

    # samples_list = []
    samples_z_list = []
    # samples_x_list = []

    for i in range(audio_list.shape[0]):
        H, W = condition.shape[-2:]
        assert condition.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)

        audio_cond = audio_list[i].unsqueeze(0)

        value_dict: Dict[str, Any] = {}
        value_dict["fps_id"] = fps_id
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = condition
        if embbedings is not None:
            value_dict["cond_frames"] = embbedings + cond_aug * torch.randn_like(
                embbedings
            )
        else:
            value_dict["cond_frames"] = condition + cond_aug * torch.randn_like(
                condition
            )
        gt = rearrange(gt_list[i].unsqueeze(0), "b t c h w -> b c t h w").to(device)
        if gt_as_cond:
            value_dict["cond_frames"] = gt[:, :, 0]

        value_dict["cond_aug"] = cond_aug
        value_dict["audio_emb"] = audio_cond
        value_dict["gt"] = gt
        value_dict["masks"] = masks_list[i].unsqueeze(0).transpose(1, 2).to(device)

        with torch.no_grad():
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model_keyframes.conditioner),
                value_dict,
                [1, 1],
                T=num_frames,
                device=device,
            )

            c, uc = model_keyframes.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in ["crossattn"]:
                if c[k].shape[1] != num_frames:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            video = torch.randn(shape, device=device)

            additional_model_inputs: Dict[str, torch.Tensor] = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                n_batch_keyframes, num_frames
            ).to(device)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            def denoiser(
                input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
            ) -> torch.Tensor:
                return model_keyframes.denoiser(
                    model_keyframes.model,
                    input,
                    sigma,
                    c,
                    **additional_model_inputs,
                )

            samples_z = model_keyframes.sampler(
                denoiser, video, cond=c, uc=uc, strength=strength
            )
            samples_z_list.append(samples_z)
            # samples_x_list.append(samples_x)
            # samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            # samples_list.append(samples)

            # Free up memory
            video = None

    # samples = (
    #     torch.concat(samples_list)[:-added_frames]
    #     if added_frames > 0
    #     else torch.concat(samples_list)
    # )

    # Drop any padding keyframes that were added to fill the last batch
    samples_z = (
        torch.concat(samples_z_list)[:-added_frames]
        if added_frames > 0
        else torch.concat(samples_z_list)
    )
    # samples_x = (
    #     torch.concat(samples_x_list)[:-added_frames]
    #     if added_frames > 0
    #     else torch.concat(samples_x_list)
    # )

    return samples_z
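
# Usage sketch (hypothetical tensors, model object, and settings; not part of this
# module): sample one keyframe latent batch per audio batch, with no padding frames.
#
#   keyframes_z = sample_keyframes(
#       model_keyframes, audio_emb_batches, gt_batches, mask_batches, first_frame,
#       num_frames=14, fps_id=24, cond_aug=0.0, device="cuda", embbedings=None,
#       force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
#       n_batch_keyframes=1, added_frames=0, strength=1.0, scale=None,
#   )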


def sample_interpolation(
    model: Any,
    samples_z: torch.Tensor,
    # samples_x: torch.Tensor,
    audio_interpolation_list: List[torch.Tensor],
    gt_chunks: List[torch.Tensor],
    masks_chunks: List[torch.Tensor],
    condition: torch.Tensor,
    num_frames: int,
    device: str,
    overlap: int,
    fps_id: int,
    cond_aug: float,
    force_uc_zero_embeddings: List[str],
    n_batch: int,
    chunk_size: Optional[int],
    strength: float,
    scale: Optional[float] = None,
    cut_audio: bool = False,
    to_remove: List[bool] = [],
) -> np.ndarray:
    """
    Sample interpolation frames between keyframes.

    Args:
        model: The interpolation model
        samples_z: Latent samples from keyframe generation
        audio_interpolation_list: List of audio embeddings for interpolation
        gt_chunks: Ground truth video chunks
        masks_chunks: Mask chunks for conditional generation
        condition: Visual conditioning
        num_frames: Number of frames to generate
        device: Device to run inference on
        overlap: Number of frames to overlap between segments
        fps_id: FPS ID for conditioning
        cond_aug: Conditioning augmentation strength
        force_uc_zero_embeddings: Keys to zero out in unconditional embeddings
        n_batch: Batch size for generation
        chunk_size: Size of chunks for processing (to manage memory)
        strength: Strength of the conditioning
        scale: Optional scale for classifier-free guidance
        cut_audio: Whether to cut audio embeddings
        to_remove: List of flags indicating which keyframes to remove

    Returns:
        Generated video frames as a numpy array
    """
    if scale is not None:
        model.sampler.guider.set_scale(scale)

    # Create conditions for the interpolation model: each entry is a [first, last]
    # pair of keyframe latents bounding one interpolation segment.
    # interpolation_cond_list = []
    interpolation_cond_list_emb = []

    # Drop keyframes flagged for removal (callers pass one flag per keyframe).
    # samples_x = [sample for i, sample in zip(to_remove, samples_x) if not i]
    samples_z = [sample for i, sample in zip(to_remove, samples_z) if not i]

    for i in range(0, len(samples_z) - 1, overlap if overlap > 0 else 2):
        # interpolation_cond_list.append(
        #     torch.stack([samples_x[i], samples_x[i + 1]], dim=1)
        # )
        interpolation_cond_list_emb.append(
            torch.stack([samples_z[i], samples_z[i + 1]], dim=1)
        )

    # condition = torch.stack(interpolation_cond_list).to(device)
    audio_cond = torch.stack(audio_interpolation_list).to(device)
    embbedings = torch.stack(interpolation_cond_list_emb).to(device)
    gt_chunks = torch.stack(gt_chunks).to(device)
    masks_chunks = torch.stack(masks_chunks).to(device)

    H, W = 512, 512
    F = 8
    C = 4
    shape = (num_frames * audio_cond.shape[0], C, H // F, W // F)

    value_dict: Dict[str, Any] = {}
    value_dict["fps_id"] = fps_id
    value_dict["cond_aug"] = cond_aug
    # value_dict["cond_frames_without_noise"] = condition
    value_dict["cond_frames"] = embbedings
    value_dict["cond_aug"] = cond_aug
    if cut_audio:
        value_dict["audio_emb"] = audio_cond[:, :, :, :768]
    else:
        value_dict["audio_emb"] = audio_cond
    value_dict["gt"] = rearrange(gt_chunks, "b t c h w -> b c t h w").to(device)
    value_dict["masks"] = masks_chunks.transpose(1, 2).to(device)

    with torch.no_grad():
        with torch.autocast(device):
            batch, batch_uc = get_batch_overlap(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1, num_frames],
                T=num_frames,
                device=device,
            )

            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in ["crossattn"]:
                if c[k].shape[1] != num_frames:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            video = torch.randn(shape, device=device)

            additional_model_inputs: Dict[str, torch.Tensor] = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                n_batch, num_frames
            ).to(device)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            # Debug information
            print(
                f"Shapes - Embeddings: {embbedings.shape}, "
                f"Audio: {audio_cond.shape}, Video: {shape}, "
                f"Additional inputs: {additional_model_inputs}"
            )

            if chunk_size is not None:
                chunk_size = chunk_size * num_frames

            def denoiser(
                input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
            ) -> torch.Tensor:
                return model.denoiser(
                    model.model,
                    input,
                    sigma,
                    c,
                    num_overlap_frames=overlap,
                    num_frames=num_frames,
                    n_skips=n_batch,
                    chunk_size=chunk_size,
                    **additional_model_inputs,
                )

            samples_z = model.sampler(denoiser, video, cond=c, uc=uc, strength=strength)
            samples_z = rearrange(samples_z, "(b t) c h w -> b t c h w", t=num_frames)
            # Pin the first and last frame of each segment to the keyframe latents
            samples_z[:, 0] = embbedings[:, :, 0]
            samples_z[:, -1] = embbedings[:, :, 1]
            samples_z = rearrange(samples_z, "b t c h w -> (b t) c h w")

            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

            # Free up memory
            video = None

    samples = rearrange(samples, "(b t) c h w -> b t c h w", t=num_frames)
    samples = merge_overlapping_segments(samples, overlap)

    vid = (samples * 255).cpu().numpy().astype(np.uint8)

    return vid
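
# End-to-end sketch: keyframes are sampled first, then interpolated into a full video
# and merged over the overlapping frames (all arguments below are hypothetical):
#
#   keyframes_z = sample_keyframes(model_keyframes, audio_list, gt_list, masks_list,
#                                  condition, num_frames, fps_id, cond_aug, device,
#                                  None, [], n_batch_keyframes, added_frames,
#                                  strength, scale)
#   video = sample_interpolation(model, keyframes_z, audio_interpolation_list,
#                                gt_chunks, masks_chunks, condition, num_frames,
#                                device, overlap, fps_id, cond_aug, [], n_batch,
#                                chunk_size, strength,
#                                to_remove=[False] * keyframes_z.shape[0])
#   # `video` is a uint8 array of shape (frames, channels, height, width).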