import hashlib
from pathlib import Path
from typing import Callable, Union
from collections.abc import Iterable
from time import time
import copy
import torch
import numpy as np

import folder_paths
from comfy.model_base import SD21UNCLIP, SDXL, BaseModel, SDXLRefiner, SVD_img2vid, model_sampling, ModelType
from comfy.model_management import xformers_enabled
from comfy.model_patcher import ModelPatcher
import comfy.model_sampling
import comfy_extras.nodes_model_advanced


BIGMIN = -(2**53-1)
BIGMAX = (2**53-1)

class ModelSamplingConfig:
    def __init__(self, beta_schedule: str, linear_start: float=None, linear_end: float=None):
        self.sampling_settings = {"beta_schedule": beta_schedule}
        if linear_start is not None:
            self.sampling_settings["linear_start"] = linear_start
        if linear_end is not None:
            self.sampling_settings["linear_end"] = linear_end
        self.beta_schedule = beta_schedule  # keeping this for backwards compatibility

class ModelSamplingType:
    EPS = "eps"
    V_PREDICTION = "v_prediction"
    LCM = "lcm"

    _NON_LCM_LIST = [EPS, V_PREDICTION]
    _FULL_LIST = [EPS, V_PREDICTION, LCM]

    MAP = {
        EPS: ModelType.EPS,
        V_PREDICTION: ModelType.V_PREDICTION,
        LCM: comfy_extras.nodes_model_advanced.LCM,
    }

    @classmethod
    def from_alias(cls, alias: str):
        return cls.MAP[alias]

def factory_model_sampling_discrete_distilled(original_timesteps=50):
    class ModelSamplingDiscreteDistilledEvolved(comfy_extras.nodes_model_advanced.ModelSamplingDiscreteDistilled):
        def __init__(self, *args, **kwargs):
            self.original_timesteps = original_timesteps  # normal LCM has 50
            super().__init__(*args, **kwargs)
    return ModelSamplingDiscreteDistilledEvolved

# based on code in comfy_extras/nodes_model_advanced.py
def evolved_model_sampling(model_config: ModelSamplingConfig, model_type: ModelType, alias: str, original_timesteps: int=None):
    # if LCM, need to handle manually
    if BetaSchedules.is_lcm(alias) or original_timesteps is not None:
        sampling_type = comfy_extras.nodes_model_advanced.LCM
        if original_timesteps is not None:
            sampling_base = factory_model_sampling_discrete_distilled(original_timesteps=original_timesteps)
        elif alias == BetaSchedules.LCM_100:
            sampling_base = factory_model_sampling_discrete_distilled(original_timesteps=100)
        elif alias == BetaSchedules.LCM_25:
            sampling_base = factory_model_sampling_discrete_distilled(original_timesteps=25)
        else:
            sampling_base = comfy_extras.nodes_model_advanced.ModelSamplingDiscreteDistilled
        class ModelSamplingAdvancedEvolved(sampling_base, sampling_type):
            pass
        # NOTE: if I want to support zsnr, this is where I would add that code
        return ModelSamplingAdvancedEvolved(model_config)
    # otherwise, use vanilla model_sampling function
    return model_sampling(model_config, model_type)
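
# Example (hypothetical usage sketch, not part of the original module): building an LCM-style
# model_sampling object from an alias. The alias/config helpers come from BetaSchedules below;
# the chosen model_type value is illustrative.
# config = BetaSchedules.to_config(BetaSchedules.LCM)  # LCM maps to the "linear" beta schedule
# ms = evolved_model_sampling(config, model_type=ModelType.EPS, alias=BetaSchedules.LCM)
# ms.sigmas  # distilled sigma schedule (50 original timesteps by default)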

class BetaSchedules:
    AUTOSELECT = "autoselect"
    SQRT_LINEAR = "sqrt_linear (AnimateDiff)"
    LINEAR_ADXL = "linear (AnimateDiff-SDXL)"
    LINEAR = "linear (HotshotXL/default)"
    AVG_LINEAR_SQRT_LINEAR = "avg(sqrt_linear,linear)"
    LCM_AVG_LINEAR_SQRT_LINEAR = "lcm avg(sqrt_linear,linear)"
    LCM = "lcm"
    LCM_100 = "lcm[100_ots]"
    LCM_25 = "lcm[25_ots]"
    LCM_SQRT_LINEAR = "lcm >> sqrt_linear"
    USE_EXISTING = "use existing"
    SQRT = "sqrt"
    COSINE = "cosine"
    SQUAREDCOS_CAP_V2 = "squaredcos_cap_v2"

    RAW_LINEAR = "linear"
    RAW_SQRT_LINEAR = "sqrt_linear"

    RAW_BETA_SCHEDULE_LIST = [RAW_LINEAR, RAW_SQRT_LINEAR, SQRT, COSINE, SQUAREDCOS_CAP_V2]

    ALIAS_LCM_LIST = [LCM, LCM_100, LCM_25, LCM_SQRT_LINEAR]

    ALIAS_ACTIVE_LIST = [SQRT_LINEAR, LINEAR_ADXL, LINEAR, AVG_LINEAR_SQRT_LINEAR, LCM_AVG_LINEAR_SQRT_LINEAR, LCM, LCM_100, LCM_SQRT_LINEAR,  # LCM_25 is purposely omitted
                         SQRT, COSINE, SQUAREDCOS_CAP_V2]

    ALIAS_LIST = [AUTOSELECT, USE_EXISTING] + ALIAS_ACTIVE_LIST

    ALIAS_MAP = {
        SQRT_LINEAR: "sqrt_linear",
        LINEAR_ADXL: "linear",  # also linear, but has different linear_end (0.020)
        LINEAR: "linear",
        LCM_100: "linear",  # distilled, 100 original timesteps
        LCM_25: "linear",  # distilled, 25 original timesteps
        LCM: "linear",  # distilled
        LCM_SQRT_LINEAR: "sqrt_linear",  # distilled, sqrt_linear
        SQRT: "sqrt",
        COSINE: "cosine",
        SQUAREDCOS_CAP_V2: "squaredcos_cap_v2",
        RAW_LINEAR: "linear",
        RAW_SQRT_LINEAR: "sqrt_linear",
    }
    @classmethod
    def is_lcm(cls, alias: str):
        return alias in cls.ALIAS_LCM_LIST

    @classmethod
    def to_name(cls, alias: str):
        return cls.ALIAS_MAP[alias]

    @classmethod
    def to_config(cls, alias: str) -> ModelSamplingConfig:
        linear_start = None
        linear_end = None
        if alias == cls.LINEAR_ADXL:
            # uses linear_end=0.020
            linear_end = 0.020
        return ModelSamplingConfig(cls.to_name(alias), linear_start=linear_start, linear_end=linear_end)

    @classmethod
    def _to_model_sampling(cls, alias: str, model_type: ModelType, config_override: ModelSamplingConfig=None, original_timesteps: int=None):
        if alias == cls.USE_EXISTING:
            return None
        elif config_override is not None:
            return evolved_model_sampling(config_override, model_type=model_type, alias=alias, original_timesteps=original_timesteps)
        elif alias == cls.AVG_LINEAR_SQRT_LINEAR:
            # average out the sigmas of the linear and sqrt_linear schedules
            ms_linear = evolved_model_sampling(cls.to_config(cls.LINEAR), model_type=model_type, alias=cls.LINEAR)
            ms_sqrt_linear = evolved_model_sampling(cls.to_config(cls.SQRT_LINEAR), model_type=model_type, alias=cls.SQRT_LINEAR)
            avg_sigmas = (ms_linear.sigmas + ms_sqrt_linear.sigmas) / 2
            ms_linear.set_sigmas(avg_sigmas)
            return ms_linear
        elif alias == cls.LCM_AVG_LINEAR_SQRT_LINEAR:
            # same averaging, but with the distilled (LCM) variants
            ms_linear = evolved_model_sampling(cls.to_config(cls.LCM), model_type=model_type, alias=cls.LCM)
            ms_sqrt_linear = evolved_model_sampling(cls.to_config(cls.LCM_SQRT_LINEAR), model_type=model_type, alias=cls.LCM_SQRT_LINEAR)
            avg_sigmas = (ms_linear.sigmas + ms_sqrt_linear.sigmas) / 2
            ms_linear.set_sigmas(avg_sigmas)
            return ms_linear
        # otherwise, build model_sampling straight from the alias' config
        ms_obj = evolved_model_sampling(cls.to_config(alias), model_type=model_type, alias=alias, original_timesteps=original_timesteps)
        return ms_obj

    @classmethod
    def to_model_sampling(cls, alias: str, model: ModelPatcher):
        return cls._to_model_sampling(alias=alias, model_type=model.model.model_type)
    @staticmethod
    def get_alias_list_with_first_element(first_element: str):
        new_list = BetaSchedules.ALIAS_LIST.copy()
        element_index = new_list.index(first_element)
        new_list[0], new_list[element_index] = new_list[element_index], new_list[0]
        return new_list
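
# Example (hypothetical usage sketch, not part of the original module): putting a preferred
# schedule first in a node's dropdown list; the previous first element swaps into its old slot.
# alias_list = BetaSchedules.get_alias_list_with_first_element(BetaSchedules.AUTOSELECT)
# alias_list[0] == BetaSchedules.AUTOSELECT  # True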

class SigmaSchedule:
    def __init__(self, model_sampling: comfy.model_sampling.ModelSamplingDiscrete, model_type: ModelType):
        self.model_sampling = model_sampling
        #self.config = config
        self.model_type = model_type
        self.original_timesteps = getattr(self.model_sampling, "original_timesteps", None)

    def is_lcm(self):
        return self.original_timesteps is not None

    def total_sigmas(self):
        return len(self.model_sampling.sigmas)

    def clone(self) -> 'SigmaSchedule':
        new_model_sampling = copy.deepcopy(self.model_sampling)
        #new_config = copy.deepcopy(self.config)
        return SigmaSchedule(model_sampling=new_model_sampling, model_type=self.model_type)
    @staticmethod
    def apply_zsnr(new_model_sampling: comfy.model_sampling.ModelSamplingDiscrete):
        new_model_sampling.set_sigmas(comfy_extras.nodes_model_advanced.rescale_zero_terminal_snr_sigmas(new_model_sampling.sigmas))

    # def get_lcmified(self, original_timesteps=50, zsnr=False) -> 'SigmaSchedule':
    #     new_model_sampling = evolved_model_sampling(model_config=self.config, model_type=self.model_type, alias=None, original_timesteps=original_timesteps)
    #     if zsnr:
    #         new_model_sampling.set_sigmas(comfy_extras.nodes_model_advanced.rescale_zero_terminal_snr_sigmas(new_model_sampling.sigmas))
    #     return SigmaSchedule(model_sampling=new_model_sampling, config=self.config, model_type=self.model_type, is_lcm=True)

class InterpolationMethod:
    LINEAR = "linear"
    EASE_IN = "ease_in"
    EASE_OUT = "ease_out"
    EASE_IN_OUT = "ease_in_out"

    _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
    @classmethod
    def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
        diff = num_to - num_from
        if method == cls.LINEAR:
            weights = torch.linspace(num_from, num_to, length)
        elif method == cls.EASE_IN:
            index = torch.linspace(0, 1, length)
            weights = diff * np.power(index, 2) + num_from
        elif method == cls.EASE_OUT:
            index = torch.linspace(0, 1, length)
            weights = diff * (1 - np.power(1 - index, 2)) + num_from
        elif method == cls.EASE_IN_OUT:
            index = torch.linspace(0, 1, length)
            weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
        else:
            raise ValueError(f"Unrecognized interpolation method '{method}'.")
        if reverse:
            weights = weights.flip(dims=(0,))
        return weights
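
# Example (hypothetical usage sketch, not part of the original module): 5 weights easing in
# from 0.0 to 1.0, i.e. a quadratic curve over a 0..1 index.
# InterpolationMethod.get_weights(num_from=0.0, num_to=1.0, length=5, method=InterpolationMethod.EASE_IN)
# -> tensor([0.0000, 0.0625, 0.2500, 0.5625, 1.0000])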

class Folders:
    ANIMATEDIFF_MODELS = "animatediff_models"
    MOTION_LORA = "animatediff_motion_lora"
    VIDEO_FORMATS = "animatediff_video_formats"

def add_extension_to_folder_path(folder_name: str, extensions: Union[str, list[str]]):
    if folder_name in folder_paths.folder_names_and_paths:
        if isinstance(extensions, str):
            folder_paths.folder_names_and_paths[folder_name][1].add(extensions)
        elif isinstance(extensions, Iterable):
            for ext in extensions:
                folder_paths.folder_names_and_paths[folder_name][1].add(ext)


def try_mkdir(full_path: str):
    try:
        Path(full_path).mkdir()
    except Exception:
        pass

# register motion models folder(s)
folder_paths.add_model_folder_path(Folders.ANIMATEDIFF_MODELS, str(Path(__file__).parent.parent / "models"))
folder_paths.add_model_folder_path(Folders.ANIMATEDIFF_MODELS, str(Path(folder_paths.models_dir) / Folders.ANIMATEDIFF_MODELS))
add_extension_to_folder_path(Folders.ANIMATEDIFF_MODELS, folder_paths.supported_pt_extensions)
try_mkdir(str(Path(folder_paths.models_dir) / Folders.ANIMATEDIFF_MODELS))

# register motion LoRA folder(s)
folder_paths.add_model_folder_path(Folders.MOTION_LORA, str(Path(__file__).parent.parent / "motion_lora"))
folder_paths.add_model_folder_path(Folders.MOTION_LORA, str(Path(folder_paths.models_dir) / Folders.MOTION_LORA))
add_extension_to_folder_path(Folders.MOTION_LORA, folder_paths.supported_pt_extensions)
try_mkdir(str(Path(folder_paths.models_dir) / Folders.MOTION_LORA))

# register video_formats folder
folder_paths.add_model_folder_path(Folders.VIDEO_FORMATS, str(Path(__file__).parent.parent / "video_formats"))
add_extension_to_folder_path(Folders.VIDEO_FORMATS, ".json")

def get_available_motion_models():
    return folder_paths.get_filename_list(Folders.ANIMATEDIFF_MODELS)


def get_motion_model_path(model_name: str):
    return folder_paths.get_full_path(Folders.ANIMATEDIFF_MODELS, model_name)


def get_available_motion_loras():
    return folder_paths.get_filename_list(Folders.MOTION_LORA)


def get_motion_lora_path(lora_name: str):
    return folder_paths.get_full_path(Folders.MOTION_LORA, lora_name)

# modified from https://stackoverflow.com/questions/22058048/hashing-a-file-in-python
def calculate_file_hash(filename: str, hash_every_n: int = 50):
    h = hashlib.sha256()
    b = bytearray(1024*1024)
    mv = memoryview(b)
    with open(filename, 'rb', buffering=0) as f:
        i = 0
        # don't hash entire file, only portions of it
        while n := f.readinto(mv):
            if i % hash_every_n == 0:
                h.update(mv[:n])
            i += 1
    return h.hexdigest()
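
# Example (hypothetical usage sketch, not part of the original module): with hash_every_n=50 and a
# 1MiB read buffer, only every 50th chunk is fed into sha256, so roughly 1/50th of a large file
# contributes to the digest. The model filename below is illustrative.
# digest = calculate_file_hash(get_motion_model_path("mm_sd_v15_v2.ckpt"))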

def calculate_model_hash(model: ModelPatcher):
    # hash only the buffers of one UNet input block, not the whole model
    unet = model.model.diffusion_model
    t = unet.input_blocks[1]
    m = hashlib.sha256()
    for buf in t.buffers():
        m.update(buf.cpu().numpy().view(np.uint8))
    return m.hexdigest()

class ModelTypeSD:
    SD1_5 = "SD1.5"
    SD2_1 = "SD2.1"
    SDXL = "SDXL"
    SDXL_REFINER = "SDXL_Refiner"
    SVD = "SVD"


def get_sd_model_type(model: ModelPatcher) -> str:
    if model is None:
        return None
    elif type(model.model) == BaseModel:
        return ModelTypeSD.SD1_5
    elif type(model.model) == SDXL:
        return ModelTypeSD.SDXL
    elif type(model.model) == SD21UNCLIP:
        return ModelTypeSD.SD2_1
    elif type(model.model) == SDXLRefiner:
        return ModelTypeSD.SDXL_REFINER
    elif type(model.model) == SVD_img2vid:
        return ModelTypeSD.SVD
    else:
        return str(type(model.model).__name__)

def is_checkpoint_sd1_5(model: ModelPatcher):
    return False if model is None else type(model.model) == BaseModel


def is_checkpoint_sdxl(model: ModelPatcher):
    return False if model is None else type(model.model) == SDXL


def raise_if_not_checkpoint_sd1_5(model: ModelPatcher):
    if not is_checkpoint_sd1_5(model):
        raise ValueError(f"For AnimateDiff, SD Checkpoint (model) is expected to be SD1.5-based (BaseModel), but was: {type(model.model).__name__}")

# TODO: remove this filth when xformers bug gets fixed in future xformers version
def wrap_function_to_inject_xformers_bug_info(function_to_wrap: Callable) -> Callable:
    if not xformers_enabled():
        return function_to_wrap
    else:
        def wrapped_function(*args, **kwargs):
            try:
                return function_to_wrap(*args, **kwargs)
            except RuntimeError as e:
                if str(e).startswith("CUDA error: invalid configuration argument"):
                    raise RuntimeError("An xformers bug was encountered in AnimateDiff - this is unexpected, "
                                       "report this to Kosinkadink/ComfyUI-AnimateDiff-Evolved repo as an issue, "
                                       "and a workaround for now is to run ComfyUI with the --disable-xformers argument.")
                raise
        return wrapped_function
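
# Example (hypothetical usage sketch, not part of the original module): wrapping a module's
# forward call so the xformers failure surfaces with the workaround message; `some_block` is
# an illustrative placeholder, not a real symbol in this file.
# some_block.forward = wrap_function_to_inject_xformers_bug_info(some_block.forward)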

class Timer(object):
    __slots__ = ("start_time", "end_time")

    def __init__(self) -> None:
        self.start_time = 0.0
        self.end_time = 0.0

    def start(self) -> None:
        self.start_time = time()

    def update(self) -> None:
        self.start()

    def stop(self) -> float:
        self.end_time = time()
        return self.get_time_diff()

    def get_time_diff(self) -> float:
        return self.end_time - self.start_time

    def get_time_current(self) -> float:
        return time() - self.start_time
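
# Example (hypothetical usage sketch, not part of the original module): timing a block of work.
# timer = Timer()
# timer.start()
# ...do work...
# elapsed_seconds = timer.stop()  # same value that get_time_diff() returns afterwards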

# TODO: possibly add configuration file in future when needed?
# # Load config settings
# ADE_DIR = Path(__file__).parent.parent
# ADE_CONFIG_FILE = ADE_DIR / "ade_config.json"

# class ADE_Settings:
#     USE_XFORMERS_IN_VERSATILE_ATTENTION = "use_xformers_in_VersatileAttention"

# # Create ADE config if not present
# ABS_CONFIG = {
#     ADE_Settings.USE_XFORMERS_IN_VERSATILE_ATTENTION: True
# }
# if not ADE_CONFIG_FILE.exists():
#     with ADE_CONFIG_FILE.open("w") as f:
#         json.dump(ABS_CONFIG, f, indent=4)
# # otherwise, load it and use values
# else:
#     loaded_values: dict = None
#     with ADE_CONFIG_FILE.open("r") as f:
#         loaded_values = json.load(f)
#     if loaded_values is not None:
#         for key, value in loaded_values.items():
#             if key in ABS_CONFIG:
#                 ABS_CONFIG[key] = value