import torch


class family_handler():
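    # Family handler for the "sky_df" ("df", presumably diffusion forcing) model
    # variants of the "wan" family: it exposes model metadata, cache coefficients,
    # default UI settings and loading logic to the surrounding pipeline.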
    @staticmethod
    def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
        # Polynomial coefficients for the step-skipping (tea cache) heuristic,
        # chosen per model size.
        if base_model_type == "sky_df_1.3B":
            coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
        else:
            coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
        skip_steps_cache.coefficients = coefficients

    @staticmethod
    def query_model_def(base_model_type, model_def):
        extra_model_def = {}
        # The 14B variant runs at 24 fps, the smaller one at 16 fps.
        if base_model_type in ["sky_df_14B"]:
            fps = 24
        else:
            fps = 16
        extra_model_def["fps"] = fps
        extra_model_def["frames_minimum"] = 17
        extra_model_def["frames_steps"] = 20
        extra_model_def["latent_size"] = 4
        extra_model_def["sliding_window"] = True
        extra_model_def["skip_layer_guidance"] = True
        extra_model_def["tea_cache"] = True
        extra_model_def["guidance_max_phases"] = 1
        extra_model_def["model_modes"] = {
            "choices": [
                ("Synchronous", 0),
                ("Asynchronous (better quality but around 50% extra steps added)", 5),
            ],
            "default": 0,
            "label": "Generation Type",
        }
        extra_model_def["image_prompt_types_allowed"] = "TSV"
        return extra_model_def

    @staticmethod
    def query_supported_types():
        return ["sky_df_1.3B", "sky_df_14B"]

    @staticmethod
    def query_family_maps():
        # Equivalence and compatibility maps between the two supported model sizes.
        models_eqv_map = {
            "sky_df_1.3B": "sky_df_14B",
        }
        models_comp_map = {
            "sky_df_14B": ["sky_df_1.3B"],
        }
        return models_eqv_map, models_comp_map

    @staticmethod
    def query_model_family():
        return "wan"

    @staticmethod
    def query_family_infos():
        return {}

    @staticmethod
    def get_rgb_factors(base_model_type):
        from shared.RGB_factors import get_rgb_factors
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
        return latent_rgb_factors, latent_rgb_factors_bias

    @staticmethod
    def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
        # Model file resolution is delegated to the generic wan handler.
        from .wan_handler import family_handler
        return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization)

    @staticmethod
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer=False, text_encoder_quantization=None, dtype=torch.bfloat16, VAE_dtype=torch.float32, mixed_precision_transformer=False, save_quantized=False, submodel_no_list=None):
        from .configs import WAN_CONFIGS
        from .wan_handler import family_handler

        # Both sky_df variants reuse the t2v-14B backbone configuration.
        cfg = WAN_CONFIGS['t2v-14B']
        from . import DTT2V
        wan_model = DTT2V(
            config=cfg,
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            text_encoder_filename=family_handler.get_wan_text_encoder_filename(text_encoder_quantization),
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
            save_quantized=save_quantized,
        )
        pipe = {"transformer": wan_model.model, "text_encoder": wan_model.text_encoder.model, "vae": wan_model.vae.model}
        return wan_model, pipe

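    # Minimal usage sketch for load_model above (illustrative only: the checkpoint
    # file name and empty model_def are placeholders, not values from this repo):
    #   wan_model, pipe = family_handler.load_model(
    #       model_filename="ckpts/sky_df_14B.safetensors",
    #       model_type="sky_df_14B",
    #       base_model_type="sky_df_14B",
    #       model_def={},
    #   )
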
    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        # 720p variants default to a higher resolution and a longer sliding window.
        ui_defaults.update({
            "guidance_scale": 6.0,
            "flow_shift": 8,
            "sliding_window_discard_last_frames": 0,
            "resolution": "1280x720" if "720" in base_model_type else "960x544",
            "sliding_window_size": 121 if "720" in base_model_type else 97,
            "RIFLEx_setting": 2,
        })