Spaces:
Paused
Paused
| from collections.abc import Iterable | |
| from typing import Union | |
| import torch | |
| from torch import Tensor | |
| from .utils_motion import linear_conversion, normalize_min_max, extend_to_batch_size | |
class ScaleType:
    """String constants naming the mask scaling modes offered by the nodes below."""

    # Mode names; these exact strings are what the "scaling" input carries.
    ABSOLUTE = "absolute"
    RELATIVE = "relative"

    # All modes in declaration order, used as the choice list for the input.
    LIST = [ABSOLUTE, RELATIVE]
class MultivalDynamicNode:
    """Node that turns a float (or a list of floats) and an optional mask into a MULTIVAL.

    The result is either the plain float (scalar input, no mask) or a tensor
    obtained by multiplying the mask by the float value(s).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: a required float and an optional mask."""
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001},),
            },
            "optional": {
                "mask_optional": ("MASK",)
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    # NOTE(review): the non-ASCII characters in this label look mis-encoded;
    # confirm the intended text before changing it.
    CATEGORY = "Animate Diff ππ π /multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=1.0, mask_optional: Tensor=None):
        """Combine ``float_val`` with ``mask_optional`` into a one-element tuple.

        Args:
            float_val: A single number, or any iterable of numbers (treated as
                one value per entry along the first mask dimension).
            mask_optional: Optional mask tensor; its first dimension is matched
                against the number of float values.

        Returns:
            ``(value,)`` where ``value`` is:
            - ``float_val`` itself when it is a scalar and no mask is given;
            - ``mask_optional * float_val`` when a mask is given;
            - a ``(len(float_val), 1, 1)`` tensor holding the values when
              ``float_val`` is iterable and no mask is given.
        """
        # If float_val is iterable, treat it as a list and assume entries are floats.
        float_is_iterable = isinstance(float_val, Iterable)
        if float_is_iterable:
            float_val = list(float_val)
            # With a mask present, make both sides agree on the first-dimension length.
            if mask_optional is not None:
                mask_len = mask_optional.shape[0]
                if len(float_val) < mask_len:
                    # Repeat the last entry enough times to match the mask length.
                    # The entry is wrapped in a list so this is list repetition and
                    # concatenation, not float arithmetic.
                    float_val = float_val + [float_val[-1]] * (mask_len - len(float_val))
                if mask_len < len(float_val):
                    # presumably grows the mask along dim 0 to the requested size -
                    # confirm against utils_motion.extend_to_batch_size
                    mask_optional = extend_to_batch_size(mask_optional, len(float_val))
                float_val = float_val[:mask_optional.shape[0]]
            # Shape (n, 1, 1) so the values broadcast over the two trailing mask dims.
            float_val = torch.tensor(float_val).unsqueeze(-1).unsqueeze(-1)

        if mask_optional is None:
            if not float_is_iterable:
                # Plain scalar with no mask: hand the float back untouched.
                return (float_val,)
            # Create a dummy single-pixel mask of b,h,w = float_len,1,1 so list
            # input flows through the same mask code without special cases.
            mask_optional = torch.ones((float_val.shape[0], 1, 1))
            scaled = mask_optional * float_val.to(mask_optional.dtype).to(mask_optional.device)
            return (scaled,)

        # Work on a copy so the caller's mask is never modified.
        mask_optional = mask_optional.clone()
        if float_is_iterable:
            float_val = float_val.to(mask_optional.dtype).to(mask_optional.device)
        return (mask_optional * float_val,)
class MultivalScaledMaskNode:
    """Node that rescales a mask into a [min_float_val, max_float_val] range as a MULTIVAL."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: min/max floats, a mask, and an optional scaling mode."""
        return {
            "required": {
                "min_float_val": ("FLOAT", {"default": 0.0, "min": 0.0, "step": 0.001}),
                "max_float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
                "mask": ("MASK",),
            },
            "optional": {
                "scaling": (ScaleType.LIST,),
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    # NOTE(review): the non-ASCII characters in this label look mis-encoded;
    # confirm the intended text before changing it.
    CATEGORY = "Animate Diff ππ π /multival"
    FUNCTION = "create_multival"

    def create_multival(self, min_float_val: float, max_float_val: float, mask: Tensor, scaling: str=ScaleType.ABSOLUTE):
        """Rescale a copy of ``mask`` and wrap it as a MULTIVAL tuple.

        Args:
            min_float_val: Lower bound passed to the scaling helper as ``new_min``.
            max_float_val: Upper bound passed to the scaling helper as ``new_max``.
            mask: Mask tensor to rescale; it is cloned before being handed on.
            scaling: ``ScaleType.ABSOLUTE`` selects ``linear_conversion``,
                ``ScaleType.RELATIVE`` selects ``normalize_min_max``.

        Returns:
            The one-element tuple produced by ``MultivalDynamicNode.create_multival``
            for the rescaled mask.

        Raises:
            ValueError: If either bound is an iterable, or ``scaling`` is not a
                recognized mode.
        """
        # TODO: allow min_float_val and max_float_val to be list[float]
        for name, value in (("min_float_val", min_float_val), ("max_float_val", max_float_val)):
            if isinstance(value, Iterable):
                raise ValueError(f"{name} must be type float (no lists allowed here), not {type(value).__name__}.")

        # NOTE(review): the exact difference between the two helpers lives in
        # utils_motion and is not visible here - confirm before documenting further.
        if scaling == ScaleType.ABSOLUTE:
            mask = linear_conversion(mask.clone(), new_min=min_float_val, new_max=max_float_val)
        elif scaling == ScaleType.RELATIVE:
            mask = normalize_min_max(mask.clone(), new_min=min_float_val, new_max=max_float_val)
        else:
            raise ValueError(f"scaling '{scaling}' not recognized.")
        # Reuse the dynamic node's logic with its default float value.
        return MultivalDynamicNode.create_multival(self, mask_optional=mask)
class MultivalDynamicFloatInputNode:
    """Variant of the dynamic multival node whose float value is declared with ``forceInput``."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: a required float (forceInput) and an optional mask."""
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "forceInput": True},),
            },
            "optional": {
                "mask_optional": ("MASK",)
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    # NOTE(review): the non-ASCII characters in this label look mis-encoded;
    # confirm the intended text before changing it.
    CATEGORY = "Animate Diff ππ π /multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=None, mask_optional: Tensor=None):
        """Delegate to ``MultivalDynamicNode.create_multival`` with the same arguments.

        Args:
            float_val: A float or list of floats, forwarded unchanged.
            mask_optional: Optional mask tensor, forwarded unchanged.

        Returns:
            Whatever ``MultivalDynamicNode.create_multival`` returns for these inputs.
        """
        return MultivalDynamicNode.create_multival(self, float_val=float_val, mask_optional=mask_optional)
class MultivalFloatNode:
    """Node that wraps a single float value as a MULTIVAL, with no mask input."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: one required float."""
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
            },
        }

    RETURN_TYPES = ("MULTIVAL",)
    # NOTE(review): the non-ASCII characters in this label look mis-encoded;
    # confirm the intended text before changing it.
    CATEGORY = "Animate Diff ππ π /multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=None):
        """Delegate to ``MultivalDynamicNode.create_multival`` without a mask.

        Args:
            float_val: A float or list of floats, forwarded unchanged.

        Returns:
            Whatever ``MultivalDynamicNode.create_multival`` returns for this value.
        """
        return MultivalDynamicNode.create_multival(self, float_val=float_val)