|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from diffusers.loaders import WanLoraLoaderMixin |
|
|
from diffusers.utils import logging |
|
|
from diffusers.modular_pipelines import ModularPipeline |
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class MatrixGameWanModularPipeline(ModularPipeline, WanLoraLoaderMixin):
    """
    A ModularPipeline for MatrixGameWan.

    Besides what it inherits from its base classes, this class only exposes read-only properties: default
    sample (latent) sizes, the pixel/frame sizes derived from them, and scale factors / channel counts that
    are read from the `vae` and `transformer` attributes when those are set and fall back to constants
    otherwise.

    <Tip warning={true}>

        This is an experimental feature and is likely to change in the future.

    </Tip>
    """

    # ------------------------------------------------------------------
    # Default output sizes, derived from the sample sizes and VAE factors
    # ------------------------------------------------------------------

    @property
    def default_height(self):
        """Return `default_sample_height` multiplied by `vae_scale_factor_spatial`."""
        sample_height = self.default_sample_height
        spatial_factor = self.vae_scale_factor_spatial
        return sample_height * spatial_factor

    @property
    def default_width(self):
        """Return `default_sample_width` multiplied by `vae_scale_factor_spatial`."""
        sample_width = self.default_sample_width
        spatial_factor = self.vae_scale_factor_spatial
        return sample_width * spatial_factor

    @property
    def default_num_frames(self):
        """Return `(default_sample_num_frames - 1) * vae_scale_factor_temporal + 1`."""
        frames_after_first = self.default_sample_num_frames - 1
        temporal_factor = self.vae_scale_factor_temporal
        return frames_after_first * temporal_factor + 1

    # ------------------------------------------------------------------
    # Default sample sizes (fixed constants)
    # ------------------------------------------------------------------

    @property
    def default_sample_height(self):
        """Return the default sample height (44)."""
        return 44

    @property
    def default_sample_width(self):
        """Return the default sample width (80)."""
        return 80

    @property
    def default_sample_num_frames(self):
        """Return the default number of sample frames (21)."""
        return 21

    # ------------------------------------------------------------------
    # Values read from attached components, with constant fallbacks
    # ------------------------------------------------------------------

    @property
    def vae_scale_factor_spatial(self):
        """
        Return the spatial scale factor of the VAE.

        Computed as `2 ** len(vae.temperal_downsample)` when a `vae` attribute exists and is not `None`;
        otherwise 8 is returned.
        """
        vae = getattr(self, "vae", None)
        if vae is None:
            return 8
        # NOTE(review): `temperal_downsample` (sic) is the attribute name this code reads from the VAE;
        # confirm against the VAE class before "correcting" the spelling.
        return 2 ** len(vae.temperal_downsample)

    @property
    def vae_scale_factor_temporal(self):
        """
        Return the temporal scale factor of the VAE.

        Computed as `2 ** sum(vae.temperal_downsample)` when a `vae` attribute exists and is not `None`;
        otherwise 4 is returned.
        """
        vae = getattr(self, "vae", None)
        if vae is None:
            return 4
        # Presumably each entry flags whether a stage downsamples in time, so the sum counts those stages
        # -- TODO confirm against the VAE implementation.
        return 2 ** sum(vae.temperal_downsample)

    @property
    def num_channels_transformer(self):
        """
        Return `transformer.config.in_channels` when a `transformer` attribute exists and is not `None`;
        otherwise return 16.
        """
        transformer = getattr(self, "transformer", None)
        if transformer is None:
            return 16
        return transformer.config.in_channels

    @property
    def num_channels_latents(self):
        """
        Return `vae.config.z_dim` when a `vae` attribute exists and is not `None`; otherwise return 16.
        """
        vae = getattr(self, "vae", None)
        if vae is None:
            return 16
        return vae.config.z_dim
|
|
|