import torch

class family_handler():
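    """Family handler for the 'sky_df' (SkyReels Diffusion Forcing) variants of
    the Wan video models. The host application discovers these models'
    capabilities, default settings, and loading logic via the static methods below."""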

    @staticmethod
    def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache):
        # Coefficients of the polynomial used by the TeaCache step-skipping
        # heuristic; each model size has its own fitted curve.
        if base_model_type == "sky_df_1.3B":
            coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
        else:
            coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]

        skip_steps_cache.coefficients = coefficients

    @staticmethod
    def query_model_def(base_model_type, model_def):
        extra_model_def = {}
        # The 14B model generates at 24 fps, the 1.3B model at 16 fps.
        if base_model_type in ["sky_df_14B"]:
            fps = 24
        else:
            fps = 16
        extra_model_def["fps"] = fps
        extra_model_def["frames_minimum"] = 17
        extra_model_def["frames_steps"] = 20
        extra_model_def["latent_size"] = 4
        extra_model_def["sliding_window"] = True
        extra_model_def["skip_layer_guidance"] = True
        extra_model_def["tea_cache"] = True
        extra_model_def["guidance_max_phases"] = 1

        extra_model_def["model_modes"] = {
                    "choices": [
                        ("Synchronous", 0),
                        ("Asynchronous (better quality but around 50% extra steps added)", 5),
                        ],
                    "default": 0,
                    "label" : "Generation Type"
        }

        extra_model_def["image_prompt_types_allowed"] = "TSV"


        return extra_model_def 

    @staticmethod
    def query_supported_types():
        return ["sky_df_1.3B", "sky_df_14B"]


    @staticmethod
    def query_family_maps():
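        # models_eqv_map maps a model type to the one it is treated as
        # equivalent to; models_comp_map lists compatible types per model.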
        models_eqv_map = {
            "sky_df_1.3B" : "sky_df_14B",
        }

        models_comp_map = {
            "sky_df_14B": ["sky_df_1.3B"],
        }
        return models_eqv_map, models_comp_map



    @staticmethod
    def query_model_family():
        return "wan"

    @staticmethod
    def query_family_infos():
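        # No standalone family entry: these models are surfaced under "wan".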
        return {}

    @staticmethod
    def get_rgb_factors(base_model_type):
        from shared.RGB_factors import get_rgb_factors
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
        return latent_rgb_factors, latent_rgb_factors_bias

    @staticmethod
    def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
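        # File resolution is shared with the base Wan family handler.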
        from .wan_handler import family_handler
        return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization)
    
    @staticmethod
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None):
        from .configs import WAN_CONFIGS
        from .wan_handler import family_handler
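        # Note: the t2v-14B architecture config is used for every supported
        # sky_df type, including the 1.3B model.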
        cfg = WAN_CONFIGS['t2v-14B']
        from . import DTT2V
        wan_model = DTT2V(
            config=cfg,
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            text_encoder_filename=family_handler.get_wan_text_encoder_filename(text_encoder_quantization),
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
            save_quantized=save_quantized,
        )

        pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
        return wan_model, pipe

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        ui_defaults.update({
            "guidance_scale": 6.0,
            "flow_shift": 8,
            "sliding_window_discard_last_frames": 0,
            "resolution": "1280x720" if "720" in base_model_type else "960x544",
            "sliding_window_size": 121 if "720" in base_model_type else 97,
            "RIFLEx_setting": 2,
        })
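

# Minimal usage sketch (hypothetical direct calls; in the application these
# methods are invoked by the model-management layer rather than run as a script).
if __name__ == "__main__":
    for model_type in family_handler.query_supported_types():
        extra = family_handler.query_model_def(model_type, model_def={})
        print(f"{model_type}: {extra['fps']} fps, min {extra['frames_minimum']} frames")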