from typing import Any, List, Tuple

import torch

from modules.Device import Device


def get_models_from_cond(cond: List[dict], model_type: str) -> List[object]:
    """#### Get models from a condition.

    #### Args:
        - `cond` (List[dict]): The list of condition dicts.
        - `model_type` (str): The model type key to collect (e.g. "control").

    #### Returns:
        - `List[object]`: The models found under that key.
    """
    models = []
    for c in cond:
        if model_type in c:
            models += [c[model_type]]
    return models
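
# Illustrative sketch (hypothetical data): every cond dict that carries the
# requested key contributes its value, in order.
#
#   cond = [{"control": "ctrl_a"}, {"gligen": ("pos", "g")}, {"control": "ctrl_b"}]
#   get_models_from_cond(cond, "control")  # -> ["ctrl_a", "ctrl_b"]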


def get_additional_models(conds: dict, dtype: torch.dtype) -> Tuple[List[object], int]:
    """#### Load additional models in conditioning.

    #### Args:
        - `conds` (dict): The conditions, keyed by name.
        - `dtype` (torch.dtype): The data type used to estimate inference memory.

    #### Returns:
        - `Tuple[List[object], int]`: The list of models and the inference memory required.
    """
    cnets = []
    gligen = []
    for k in conds:
        cnets += get_models_from_cond(conds[k], "control")
        gligen += get_models_from_cond(conds[k], "gligen")

    control_nets = set(cnets)

    inference_memory = 0
    control_models = []
    for m in control_nets:
        control_models += m.get_models()
        inference_memory += m.inference_memory_requirements(dtype)

    gligen = [x[1] for x in gligen]
    models = control_models + gligen
    return models, inference_memory
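
# Illustrative sketch (hypothetical objects): ControlNets are de-duplicated
# with `set()` before their memory requirements are summed, and each gligen
# entry is a tuple whose second element is the model.
#
#   conds = {"positive": [{"control": cnet}], "negative": [{"control": cnet}]}
#   models, mem = get_additional_models(conds, torch.float16)
#   # `cnet` is counted once even though both cond lists reference it.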


def prepare_sampling(
    model: object, noise_shape: Tuple[int, ...], conds: dict, flux_enabled: bool = False
) -> Tuple[object, dict, List[object]]:
    """#### Prepare the model for sampling.

    #### Args:
        - `model` (object): The model.
        - `noise_shape` (Tuple[int, ...]): The shape of the noise.
        - `conds` (dict): The conditions.
        - `flux_enabled` (bool, optional): Whether flux is enabled. Defaults to False.

    #### Returns:
        - `Tuple[object, dict, List[object]]`: The prepared model, conditions, and additional models.
    """
    models, inference_memory = get_additional_models(conds, model.model_dtype())

    # The worst case batches cond and uncond together, doubling the batch dimension;
    # the minimum assumes the batch is processed at its original size.
    memory_required = (
        model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]))
        + inference_memory
    )
    minimum_memory_required = (
        model.memory_required([noise_shape[0]] + list(noise_shape[1:]))
        + inference_memory
    )
    Device.load_models_gpu(
        [model] + models,
        memory_required=memory_required,
        minimum_memory_required=minimum_memory_required,
        flux_enabled=flux_enabled,
    )
    real_model = model.model
    return real_model, conds, models
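
# Illustrative sketch (hypothetical `patcher` object): a typical call site
# passes the latent noise shape and gets back the unwrapped inner model.
#
#   noise = torch.zeros((1, 4, 64, 64))
#   real_model, conds, extra = prepare_sampling(patcher, noise.shape, conds)
#   # ... run the sampler against `real_model` ...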


def cleanup_additional_models(models: List[object]) -> None:
    """#### Clean up additional models.

    #### Args:
        - `models` (List[object]): The models to clean up; only those exposing a `cleanup()` method are touched.
    """
    for m in models:
        if hasattr(m, "cleanup"):
            m.cleanup()


def cleanup_models(conds: dict, models: List[object]) -> None:
    """#### Clean up the models after sampling.

    #### Args:
        - `conds` (dict): The conditions.
        - `models` (List[object]): The list of models.
    """
    cleanup_additional_models(models)

    control_cleanup = []
    for k in conds:
        control_cleanup += get_models_from_cond(conds[k], "control")
    cleanup_additional_models(set(control_cleanup))
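
# Illustrative sketch: `prepare_sampling` and `cleanup_models` bracket a
# sampling run; the second pass over the conds also releases ControlNets that
# only appear inside the conditioning.
#
#   real_model, conds, extra = prepare_sampling(patcher, noise.shape, conds)
#   try:
#       ...  # sample with real_model
#   finally:
#       cleanup_models(conds, extra)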


def cond_equal_size(c1: Any, c2: Any) -> bool:
    """#### Check if two conditions have equal size.

    Two conditions count as equal-sized when they are the same object or
    carry the same set of keys.

    #### Args:
        - `c1` (Any): The first condition.
        - `c2` (Any): The second condition.

    #### Returns:
        - `bool`: Whether the conditions have equal size.
    """
    if c1 is c2:
        return True
    if c1.keys() != c2.keys():
        return False
    return True


def can_concat_cond(c1: Any, c2: Any) -> bool:
    """#### Check if two conditions can be concatenated.

    #### Args:
        - `c1` (Any): The first condition.
        - `c2` (Any): The second condition.

    #### Returns:
        - `bool`: Whether the conditions can be concatenated.
    """
    if c1.input_x.shape != c2.input_x.shape:
        return False

    def objects_concatable(obj1: Any, obj2: Any) -> bool:
        """#### Check if two objects can be concatenated."""
        if (obj1 is None) != (obj2 is None):
            return False
        if obj1 is not None:
            if obj1 is not obj2:
                return False
        return True

    if not objects_concatable(c1.control, c2.control):
        return False
    if not objects_concatable(c1.patches, c2.patches):
        return False
    return cond_equal_size(c1.conditioning, c2.conditioning)
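
# Illustrative sketch (hypothetical objects): two batch entries merge only if
# their latents share a shape and they reference the *same* control/patches
# objects (identity, not equality) with matching conditioning keys.
#
#   a.input_x.shape == b.input_x.shape   # e.g. both (1, 4, 64, 64)
#   a.control is b.control               # same ControlNet, or both None
#   # -> can_concat_cond(a, b) is True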


def cond_cat(c_list: List[dict]) -> dict:
    """#### Concatenate a list of conditions.

    #### Args:
        - `c_list` (List[dict]): The list of conditions.

    #### Returns:
        - `dict`: The concatenated conditions, one entry per key.
    """
    temp = {}
    for x in c_list:
        for k in x:
            cur = temp.get(k, [])
            cur.append(x[k])
            temp[k] = cur

    out = {}
    for k in temp:
        conds = temp[k]
        out[k] = conds[0].concat(conds[1:])
    return out
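
# Illustrative sketch (hypothetical cond objects `ca1`, `ca2`): values are
# grouped by key, then the first object's `concat` method folds in the rest.
#
#   cond_cat([{"c_crossattn": ca1}, {"c_crossattn": ca2}])
#   # -> {"c_crossattn": ca1.concat([ca2])}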


def create_cond_with_same_area_if_none(conds: List[dict], c: dict) -> None:
    """#### Create a condition with the same area if none exists.

    #### Args:
        - `conds` (List[dict]): The list of conditions, appended to in place.
        - `c` (dict): The condition whose area needs a counterpart.
    """
    if "area" not in c:
        return

    c_area = c["area"]
    smallest = None
    for x in conds:
        if "area" in x:
            a = x["area"]
            # Keep x only if its area fully contains c's area, preferring the
            # smallest such area.
            if c_area[2] >= a[2] and c_area[3] >= a[3]:
                if a[0] + a[2] >= c_area[0] + c_area[2]:
                    if a[1] + a[3] >= c_area[1] + c_area[3]:
                        if smallest is None:
                            smallest = x
                        elif "area" not in smallest:
                            smallest = x
                        elif smallest["area"][0] * smallest["area"][1] > a[0] * a[1]:
                            smallest = x
        else:
            if smallest is None:
                smallest = x

    if smallest is None:
        return
    if "area" in smallest:
        if smallest["area"] == c_area:
            return

    out = c.copy()
    out["model_conds"] = smallest["model_conds"].copy()
    conds += [out]
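
# Illustrative sketch (assuming ComfyUI-style area tuples of
# (height, width, y, x) in latent units): when `c` targets a region that an
# existing cond's area fully contains, a clone of `c` borrowing that cond's
# `model_conds` is appended so both sides cover the same area.
#
#   c = {"area": (32, 32, 0, 0), "model_conds": {...}}
#   other = {"area": (64, 64, 0, 0), "model_conds": {...}}
#   conds = [other]
#   create_cond_with_same_area_if_none(conds, c)
#   # conds now also holds a copy of c carrying other's model_conds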