import inspect
from typing import Dict, List, Optional, Tuple, Union

import torch

from diffusers import AutoencoderKLWan
from diffusers.configuration_utils import FrozenDict
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


FRAME_MULTIPLE = 4
DEFAULT_SAMPLES_PER_ACTION = 4
DEFAULT_FRAMES_PER_ACTION = 12

DEFAULT_MOUSE_DIM = 2
DEFAULT_KEYBOARD_DIM = 4


CAMERA_MOVEMENT_VALUE = 0.1
CAMERA_VALUE_MAP = {
    "camera_up": [CAMERA_MOVEMENT_VALUE, 0],
    "camera_down": [-CAMERA_MOVEMENT_VALUE, 0],
    "camera_l": [0, -CAMERA_MOVEMENT_VALUE],
    "camera_r": [0, CAMERA_MOVEMENT_VALUE],
    "camera_ur": [CAMERA_MOVEMENT_VALUE, CAMERA_MOVEMENT_VALUE],
    "camera_ul": [CAMERA_MOVEMENT_VALUE, -CAMERA_MOVEMENT_VALUE],
    "camera_dr": [-CAMERA_MOVEMENT_VALUE, CAMERA_MOVEMENT_VALUE],
    "camera_dl": [-CAMERA_MOVEMENT_VALUE, -CAMERA_MOVEMENT_VALUE],
}


MOVEMENT_ACTIONS = ["forward", "left", "right"]
COMPOUND_MOVEMENTS = ["forward_left", "forward_right"]
CAMERA_ACTIONS = list(CAMERA_VALUE_MAP.keys())


KEYBOARD_ACTION_INDICES = {"forward": 0, "back": 1, "left": 2, "right": 3}


def sync_actions_to_frames(
    actions: List[str],
    num_frames: int,
    min_duration: int = DEFAULT_FRAMES_PER_ACTION,
) -> List[Dict[str, Union[str, int]]]:
    """
    Synchronize a list of actions so that they fit exactly within the given number of frames,
    distributing frames as evenly as possible across actions.

    Args:
        actions: List of action names to perform.
        num_frames: Total number of frames to fill.
        min_duration: Minimum number of frames per action (should be a multiple of `FRAME_MULTIPLE`).

    Returns:
        List of action dictionaries with `action_type`, `start_frame`, and `duration`.
    """

    if not actions:
        raise ValueError("No actions provided")

    # Drop actions that cannot be allocated at least `min_duration` frames.
    max_possible_actions = num_frames // min_duration
    if len(actions) > max_possible_actions:
        actions = actions[:max_possible_actions]

    num_actions = len(actions)

    # Split frames evenly, rounded down to a multiple of FRAME_MULTIPLE.
    frames_per_action = num_frames // num_actions
    frames_per_action = (frames_per_action // FRAME_MULTIPLE) * FRAME_MULTIPLE
    frames_per_action = max(min_duration, frames_per_action)

    output = []
    current_frame = 0

    for i, action in enumerate(actions):
        # The last action absorbs any remaining frames so the segments cover `num_frames` exactly.
        duration = frames_per_action if i != num_actions - 1 else num_frames - current_frame

        output.append({
            "action_type": action,
            "start_frame": current_frame,
            "duration": duration,
        })

        current_frame += duration

    return output
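

# Illustrative example (a quick sketch, not used by the pipeline itself): splitting 24 frames
# across two actions yields two 12-frame segments.
#
#   sync_actions_to_frames(["forward", "camera_r"], num_frames=24)
#   # -> [{"action_type": "forward", "start_frame": 0, "duration": 12},
#   #     {"action_type": "camera_r", "start_frame": 12, "duration": 12}]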


def actions_to_condition_tensors(actions, num_frames):
    """Convert synchronized action dictionaries into per-frame keyboard and mouse condition tensors."""
    keyboard_conditions = torch.zeros((num_frames, DEFAULT_KEYBOARD_DIM))
    mouse_conditions = torch.zeros((num_frames, DEFAULT_MOUSE_DIM))

    for action in actions:
        action_type = action["action_type"]
        start_frame = action["start_frame"]
        end_frame = start_frame + action["duration"]

        # Keyboard components: any part of the action name that maps to a key index.
        action_components = action_type.split("_")
        for component in action_components:
            if component in KEYBOARD_ACTION_INDICES:
                action_idx = KEYBOARD_ACTION_INDICES[component]
                keyboard_conditions[start_frame:end_frame, action_idx] = 1.0

        if "camera" not in action_type:
            continue

        # Mouse components: accumulate the (x, y) offsets of every camera sub-action.
        mouse_x = mouse_y = 0.0
        for idx, component in enumerate(action_components):
            if component != "camera":
                continue

            camera_action = f"camera_{action_components[idx + 1]}"
            if camera_action not in CAMERA_VALUE_MAP:
                continue

            camera_values = CAMERA_VALUE_MAP[camera_action]
            mouse_x += camera_values[0]
            mouse_y += camera_values[1]

        mouse_conditions[start_frame:end_frame, 0] = mouse_x
        mouse_conditions[start_frame:end_frame, 1] = mouse_y

    return keyboard_conditions, mouse_conditions
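

# Shape sketch (illustrative): for a 12-frame "forward_camera_r" segment starting at frame 0,
# the tensors returned above would contain
#
#   keyboard_conditions  # shape (num_frames, 4); rows 0-11 have 1.0 in the "forward" column
#   mouse_conditions     # shape (num_frames, 2); rows 0-11 hold (0.0, 0.1) from CAMERA_VALUE_MAP["camera_r"]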


def _build_test_actions(
    movement_actions: List[str],
    compound_movements: List[str],
    camera_actions: List[str],
) -> List[str]:
    """Build comprehensive list of test action combinations.

    Args:
        movement_actions: List of basic movement actions
        compound_movements: List of compound movement actions
        camera_actions: List of camera control actions

    Returns:
        List of all action combinations to test
    """

    # Repeat each basic action five times, then append every movement/camera combination.
    test_actions = compound_movements * 5 + camera_actions * 5 + movement_actions * 5

    for movement in movement_actions + compound_movements:
        for camera in camera_actions:
            combined_action = f"{movement}_{camera}"
            test_actions.append(combined_action)

    return test_actions
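

# Illustrative count: with the module defaults (3 movements, 2 compound movements, 8 camera
# actions), _build_test_actions returns 5 * (2 + 8 + 3) + (3 + 2) * 8 = 105 entries, e.g.
# "forward", "camera_r", and combined names such as "forward_left_camera_ur".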


def generate_random_condition_tensors(num_frames: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate benchmark action condition tensors for testing.

    Used when no explicit actions are provided: a fixed set of benchmark actions is built,
    synchronized to the requested frame count, and converted to condition tensors.

    Args:
        num_frames: Total number of frames to generate.

    Returns:
        Tuple of `(keyboard_conditions, mouse_conditions)` tensors for the benchmark actions.
    """

    actions = _build_test_actions(
        MOVEMENT_ACTIONS, COMPOUND_MOVEMENTS, CAMERA_ACTIONS
    )
    actions = sync_actions_to_frames(actions, num_frames)
    return actions_to_condition_tensors(actions, num_frames)
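

# Shape sketch (illustrative):
#
#   keyboard_conditions, mouse_conditions = generate_random_condition_tensors(num_frames=81)
#   # keyboard_conditions: (81, DEFAULT_KEYBOARD_DIM), mouse_conditions: (81, DEFAULT_MOUSE_DIM)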


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
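

# Illustrative usage (a sketch; assumes `scheduler` is the pipeline's configured scheduler component):
#
#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=4, device="cuda")
#   # `timesteps` is the scheduler's timestep schedule; `num_inference_steps` is the resolved step count.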


def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class MatrixGameWanActionInputStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def description(self) -> str:
        return "Action input step that converts keyboard/mouse actions into per-frame condition tensors"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("num_frames", type_hint=int, required=True), InputParam("actions", type_hint=List[str])]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "keyboard_conditions",
                type_hint=torch.Tensor,
                description="Per-frame keyboard conditions used to guide the video generation",
            ),
            OutputParam(
                "mouse_conditions",
                type_hint=torch.Tensor,
                description="Per-frame mouse/camera conditions used to guide the video generation",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device
        actions = block_state.actions

        if actions is not None:
            actions = sync_actions_to_frames(actions, block_state.num_frames)
            keyboard_conditions, mouse_conditions = actions_to_condition_tensors(actions, block_state.num_frames)
        else:
            # Fall back to the benchmark action set when no explicit actions are provided.
            keyboard_conditions, mouse_conditions = generate_random_condition_tensors(block_state.num_frames)

        block_state.keyboard_conditions = keyboard_conditions.to(block_state.device)
        block_state.mouse_conditions = mouse_conditions.to(block_state.device)

        self.set_block_state(state, block_state)
        return components, state


class MatrixGameWanSetTimestepsStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for inference"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=4),
            InputParam("timesteps"),
            InputParam("sigmas"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
            block_state.num_inference_steps,
            block_state.device,
            block_state.timesteps,
            block_state.sigmas,
        )

        self.set_block_state(state, block_state)
        return components, state


class MatrixGameWanPrepareLatentsStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
        ]

    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the initial noise latents for the video generation process"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("num_frames", type_hint=int),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
            )
        ]

    @staticmethod
    def check_inputs(components, block_state):
        if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
            block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
            )
        if block_state.num_frames is not None and (
            block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
        ):
            raise ValueError(
                f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
            )

    @staticmethod
    def prepare_latents(
        components,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 352,
        width: int = 640,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        num_latent_frames = (num_frames - 1) // components.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // components.vae_scale_factor_spatial,
            int(width) // components.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.num_frames = block_state.num_frames or components.default_num_frames
        block_state.device = components._execution_device
        block_state.dtype = torch.float32
        block_state.num_channels_latents = components.num_channels_latents

        self.check_inputs(components, block_state)

        block_state.latents = self.prepare_latents(
            components,
            1,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.num_frames,
            block_state.dtype,
            block_state.device,
            block_state.generator,
            block_state.latents,
        )

        self.set_block_state(state, block_state)

        return components, state


class MatrixGameWanPrepareImageMaskLatentsStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
            ComponentSpec("video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 8})),
        ]

    @property
    def description(self) -> str:
        return "Prepare image mask latents step that encodes the conditioning image and builds the first-frame mask latents for image-to-video generation"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("num_frames", type_hint=int),
            InputParam("image_mask_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_mask_latents",
                type_hint=torch.Tensor,
                description="The first-frame mask concatenated with the encoded image condition latents",
            )
        ]

    @staticmethod
    def check_inputs(components, block_state):
        if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
            block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
            )

    @staticmethod
    @torch.no_grad()
    def prepare_latents(
        components,
        image,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 352,
        width: int = 640,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        image = components.video_processor.preprocess(image, height, width).to(device, torch.float32)
        image = image.unsqueeze(2)

        # Build a video whose first frame is the conditioning image and whose remaining frames are zeros.
        video_condition = torch.cat(
            [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
        )
        video_condition = video_condition.to(device=device, dtype=components.vae.dtype)
        latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
        latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

        # Normalize the encoded condition with the VAE latent statistics.
        latents_mean = (
            torch.tensor(components.vae.config.latents_mean)
            .view(1, components.vae.config.z_dim, 1, 1, 1)
            .to(device, dtype)
        )
        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
            1, components.vae.config.z_dim, 1, 1, 1
        ).to(device, dtype)
        latent_condition = latent_condition.to(dtype)
        latent_condition = (latent_condition - latents_mean) * latents_std

        latent_height = height // components.vae_scale_factor_spatial
        latent_width = width // components.vae_scale_factor_spatial

        # The mask is 1 for the (known) first frame and 0 for the frames to be generated.
        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0

        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal)

        # Fold the frame dimension into groups of `vae_scale_factor_temporal` to match the latent frame count.
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
        mask_lat_size = mask_lat_size.transpose(1, 2).to(latent_condition.device)

        image_mask_latents = torch.concat([mask_lat_size, latent_condition], dim=1)
        return image_mask_latents

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.num_frames = block_state.num_frames or components.default_num_frames
        block_state.device = components._execution_device
        block_state.dtype = torch.float32
        block_state.num_channels_latents = components.num_channels_latents

        self.check_inputs(components, block_state)
        block_state.image_mask_latents = self.prepare_latents(
            components,
            block_state.image,
            1,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.num_frames,
            block_state.dtype,
            block_state.device,
            block_state.generator,
            block_state.image_mask_latents,
        )

        self.set_block_state(state, block_state)

        return components, state