Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from modules.Utilities import util | |
| from modules.Device import Device | |
| from modules.cond import cond_util | |
| from modules.sample import ksampler_util | |
class CONDRegular:
    """#### Plain conditioning tensor that is batched by straight concatenation."""

    def __init__(self, cond: torch.Tensor):
        """#### Wrap a conditioning tensor.

        #### Args:
            - `cond` (torch.Tensor): The condition tensor.
        """
        self.cond = cond

    def _copy_with(self, cond: torch.Tensor) -> "CONDRegular":
        """#### Build a new object of the same class around another tensor.

        #### Args:
            - `cond` (torch.Tensor): The replacement condition tensor.

        #### Returns:
            - `CONDRegular`: A new instance of `type(self)`, so subclasses are preserved.
        """
        cls = type(self)
        return cls(cond)

    def process_cond(
        self, batch_size: int, device: torch.device, **kwargs
    ) -> "CONDRegular":
        """#### Batch the condition via `util.repeat_to_batch_size` and move it to `device`.

        #### Args:
            - `batch_size` (int): The batch size.
            - `device` (torch.device): The target device.
            - `**kwargs`: Accepted for interface compatibility; unused here.

        #### Returns:
            - `CONDRegular`: A new condition of the same class.
        """
        batched = util.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other: "CONDRegular") -> bool:
        """#### Report whether `other` can be concatenated with this condition.

        Only tensors of exactly the same shape are compatible.

        #### Args:
            - `other` (CONDRegular): The other condition.

        #### Returns:
            - `bool`: True when both tensors have identical shapes.
        """
        return self.cond.shape == other.cond.shape

    def concat(self, others: list) -> torch.Tensor:
        """#### Concatenate this tensor with every tensor in `others` along dim 0.

        #### Args:
            - `others` (list): Other condition objects exposing `.cond`.

        #### Returns:
            - `torch.Tensor`: The concatenated tensor.
        """
        tensors = [self.cond]
        tensors.extend(item.cond for item in others)
        return torch.cat(tensors)
class CONDCrossAttn(CONDRegular):
    """#### Class representing a cross-attention condition.

    Unlike `CONDRegular`, tensors whose dim 1 lengths differ can still be
    batched: each tensor is repeated along dim 1 up to the least common
    multiple of all dim 1 lengths before concatenation.
    """

    def can_concat(self, other: "CONDRegular") -> bool:
        """#### Check if conditions can be concatenated.

        Identical shapes are always compatible. Otherwise dims 0 and 2 must
        match, and the repetition needed to reach the LCM of the two dim 1
        lengths must not exceed 4.

        #### Args:
            - `other` (CONDRegular): The other condition.

        #### Returns:
            - `bool`: True if conditions can be concatenated, False otherwise.
        """
        s1 = self.cond.shape
        s2 = other.cond.shape
        if s1 == s2:
            return True
        if s1[0] != s2[0] or s1[2] != s2[2]:  # these 2 cases should not happen
            return False
        # Shape entries are plain Python ints. torch.lcm only accepts tensors
        # and raised TypeError here, so use the same LCM helper as concat().
        mult_min = util.lcm_of_list([s1[1], s2[1]])
        diff = mult_min // min(s1[1], s2[1])
        # Arbitrary limit on the padding: too much repetition is probably
        # going to impact performance negatively.
        return diff <= 4

    def concat(self, others: list) -> torch.Tensor:
        """#### Concatenate cross-attention conditions along dim 0.

        When dim 1 lengths differ, each tensor shorter than the LCM of all
        lengths is repeated along dim 1 up to that LCM first.

        #### Args:
            - `others` (list): Other condition objects exposing `.cond`.

        #### Returns:
            - `torch.Tensor`: The concatenated tensor.
        """
        conds = [self.cond] + [x.cond for x in others]
        lengths = [c.shape[1] for c in conds]

        # Fast path: all the same length, so plain concatenation suffices
        # and no LCM needs to be computed.
        if all(n == lengths[0] for n in lengths):
            return torch.cat(conds)

        crossattn_max_len = util.lcm_of_list(lengths)
        out = []
        for c in conds:
            length = c.shape[1]
            if length < crossattn_max_len:
                # Repeat whole blocks along dim 1 instead of zero-padding.
                c = c.repeat(1, crossattn_max_len // length, 1)
            out.append(c)
        return torch.cat(out)
def convert_cond(cond: list) -> list:
    """#### Convert conditions to cross-attention conditions.

    Each input entry is indexed as `c[0]` (a tensor, or None) and `c[1]`
    (a dict of options). For every entry a new dict is produced:
    - a shallow copy of `c[1]`,
    - with its own copy of `"model_conds"` (empty if absent),
    - plus `"cross_attn"` and `model_conds["c_crossattn"]` when `c[0]` is not None.

    The caller's dicts are never mutated.

    #### Args:
        - `cond` (list): The list of conditions.

    #### Returns:
        - `list`: The converted conditions, one dict per input entry.
    """
    out = []
    for c in cond:
        temp = c[1].copy()
        # c[1].copy() is shallow, so copy the nested dict too. Otherwise the
        # "c_crossattn" write below would leak back into the caller's input.
        model_conds = temp.get("model_conds", {}).copy()
        if c[0] is not None:
            model_conds["c_crossattn"] = CONDCrossAttn(c[0])
            temp["cross_attn"] = c[0]
        temp["model_conds"] = model_conds
        out.append(temp)
    return out
def calc_cond_batch(
    model: object,
    conds: list,
    x_in: torch.Tensor,
    timestep: torch.Tensor,
    model_options: dict,
) -> list:
    """#### Calculate the condition batch.

    Runs the model over every applicable condition entry, batching compatible
    entries together. Each model output is multiplied by its `mult` weight and
    added into an accumulator that belongs to the element of `conds` it came from.
    Finally each accumulator is divided by its summed weights.

    #### Args:
        - `model` (object): The model. Must provide `memory_required(shape)` and
          `apply_model(input, timestep, **kwargs)`.
        - `conds` (list): The list of conditions. Each element is either None
          or an iterable of condition entries.
        - `x_in` (torch.Tensor): The input tensor.
        - `timestep` (torch.Tensor): The timestep tensor.
        - `model_options` (dict): The model options. Optional keys read here:
          `"transformer_options"` and `"model_function_wrapper"`.

    #### Returns:
        - `list`: One tensor per element of `conds`, each shaped like `x_in`.
    """
    out_conds = []
    out_counts = []
    to_run = []
    for i in range(len(conds)):
        # Per-element accumulators. The tiny non-zero starting count keeps the
        # final division finite where no entry contributed.
        out_conds.append(torch.zeros_like(x_in))
        out_counts.append(torch.ones_like(x_in) * 1e-37)
        cond = conds[i]
        if cond is not None:
            for x in cond:
                p = ksampler_util.get_area_and_mult(x, x_in, timestep)
                # A None result is skipped (presumably the entry does not
                # apply at this timestep — confirm in ksampler_util).
                if p is None:
                    continue
                # Remember which element of `conds` this work item belongs to.
                to_run += [(p, i)]
    while len(to_run) > 0:
        first = to_run[0]
        # Shape of the first field of the first work item (presumably its
        # input_x slice); used to estimate the memory of a batch.
        first_shape = first[0][0].shape
        to_batch_temp = []
        for x in range(len(to_run)):
            if cond_util.can_concat_cond(to_run[x][0], first[0]):
                to_batch_temp += [x]
        # Descending indices: the pops below go from highest to lowest, so
        # earlier pops never shift the indices still to be popped.
        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]
        free_memory = Device.get_free_memory(x_in.device)
        # Try the whole compatible set, then 1/2, 1/3, ... of it, and keep the
        # first size whose estimated memory (with a 1.5x margin) fits. If none
        # fits, a single item is run.
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[: len(to_batch_temp) // i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) * 1.5 < free_memory:
                to_batch = batch_amount
                break
        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            # Only the last popped item's control/patches are kept for the
            # whole batch (presumably can_concat_cond guarantees they agree).
            control = p.control
            patches = p.patches
        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_util.cond_cat(c)
        # One copy of the timestep per batched item.
        timestep_ = torch.cat([timestep] * batch_chunks)
        if control is not None:
            c["control"] = control.get_control(
                input_x, timestep_, c, len(cond_or_uncond)
            )
        # Copy so the caller's transformer_options dict is not modified.
        transformer_options = {}
        if "transformer_options" in model_options:
            transformer_options = model_options["transformer_options"].copy()
        if patches is not None:
            if "patches" in transformer_options:
                # Merge per key: existing patches come first, then this
                # batch's patches are appended with `+`.
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches
        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep
        c["transformer_options"] = transformer_options
        if "model_function_wrapper" in model_options:
            output = model_options["model_function_wrapper"](
                model.apply_model,
                {
                    "input": input_x,
                    "timestep": timestep_,
                    "c": c,
                    "cond_or_uncond": cond_or_uncond,
                },
            ).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
        # Scatter each chunk back into the accumulator of its element of `conds`.
        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            if a is None:
                out_conds[cond_index] += output[o] * mult[o]
                out_counts[cond_index] += mult[o]
            else:
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                # `a` holds sizes in its first half and start offsets in its
                # second half. They apply to dims 2.. (presumably the spatial
                # dims after batch and channel).
                dims = len(a) // 2
                for i in range(dims):
                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                # narrow() returns views, so these in-place adds write into
                # the accumulator tensors themselves.
                out_c += output[o] * mult[o]
                out_cts += mult[o]
    # Normalize each accumulator by its summed weights
    for i in range(len(out_conds)):
        # In-place, so no new result tensors are allocated
        out_conds[i].div_(out_counts[i])  # Using .div_ instead of /= for clarity
    return out_conds
def encode_model_conds(
    model_function: callable,
    conds: list,
    noise: torch.Tensor,
    device: torch.device,
    prompt_type: str,
    **kwargs,
) -> list:
    """#### Run `model_function` on each condition and merge its output into "model_conds".

    For every entry, `model_function` is called with a copy of the entry plus:
    - `device` and `noise` (always overriding),
    - `width` / `height` / `prompt_type` (only when the entry lacks them),
    - any `kwargs` the entry does not already define.

    The keys it returns are merged into a copy of the entry's `"model_conds"`
    dict. The list is updated in place and also returned; the original entry
    dicts are not mutated.

    #### Args:
        - `model_function` (callable): Called with keyword arguments, must return a mapping.
        - `conds` (list): Condition dicts, each containing a `"model_conds"` dict.
        - `noise` (torch.Tensor): The noise tensor; its dims 2 and 3 (times 8)
          supply the default height and width.
        - `device` (torch.device): The device.
        - `prompt_type` (str): Default prompt type.
        - `**kwargs`: Extra defaults passed to `model_function`.

    #### Returns:
        - `list`: The same `conds` list, with encoded entries.
    """
    for index, entry in enumerate(conds):
        call_args = entry.copy()
        call_args["device"] = device
        call_args["noise"] = noise
        # TODO: 8 multiple should be set by the model
        width_default = noise.shape[3] * 8 if len(noise.shape) >= 4 else None
        call_args.setdefault("width", width_default)
        call_args.setdefault("height", noise.shape[2] * 8)
        call_args.setdefault("prompt_type", prompt_type)
        for name, value in kwargs.items():
            call_args.setdefault(name, value)

        encoded = model_function(**call_args)

        updated = entry.copy()
        merged = updated["model_conds"].copy()
        for key in encoded:
            merged[key] = encoded[key]
        updated["model_conds"] = merged
        conds[index] = updated
    return conds
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
    """#### Resolve percentage areas and prepare masks for every condition.

    Works in place on the `conditions` list. Any entry that needs a change is
    replaced by a shallow copy, so the original condition dicts are untouched.

    - An `"area"` of the form `("percentage", *sizes, *offsets)` (fractions of
      `dims`) becomes the integer tuple `(*sizes, *offsets)`. Sizes are
      clamped to at least 1.
    - A `"mask"` is moved to `device`. It gets a leading batch dim when its
      rank equals `len(dims)`, and is bilinearly resized to `dims` when its
      trailing shape differs.

    #### Args:
        - `conditions` (list): Condition dicts; entries are replaced in place.
        - `dims` (Sequence[int]): Target spatial size (process_conds passes
          `noise.shape[2:]`).
        - `device`: Device the masks are moved to.
    """
    target = tuple(dims)
    for i, c in enumerate(conditions):
        area = c.get("area") if "area" in c else None
        if area is not None and area[0] == "percentage":
            a = area[1:]
            a_len = len(a) // 2
            # Exact Python float arithmetic (rather than a float32 tensor
            # round-trip) keeps rounding identical to the fractions given.
            sizes = tuple(max(1, round(a[d] * dims[d])) for d in range(len(dims)))
            offsets = tuple(round(a[d + a_len] * dims[d]) for d in range(len(dims)))
            c = c.copy()
            c["area"] = sizes + offsets
            conditions[i] = c

        if "mask" in c:
            mask = c["mask"].to(device=device)
            if len(mask.shape) == len(dims):
                mask = mask.unsqueeze(0)  # add a batch dim
            if tuple(mask.shape[1:]) != target:
                # interpolate() needs an explicit channel dim. Reuse an existing
                # singleton channel of a 4-D+ mask, otherwise insert one.
                # NOTE: "bilinear" only supports 4-D input (two spatial dims).
                if not (mask.ndim > 3 and mask.shape[1] == 1):
                    mask = mask.unsqueeze(1)
                mask = torch.nn.functional.interpolate(
                    mask,
                    size=target,
                    mode="bilinear",
                    align_corners=False,
                ).squeeze(1)
            # Copy `c`, which may already be the area-resolved copy from
            # above, so the resolved area is not lost when a mask is present.
            c = c.copy()
            c["mask"] = mask
            conditions[i] = c
def process_conds(
    model: object,
    noise: torch.Tensor,
    conds: dict,
    device: torch.device,
    latent_image: torch.Tensor = None,
    denoise_mask: torch.Tensor = None,
    seed: int = None,
) -> dict:
    """#### Prepare every condition list in `conds` for sampling.

    Steps, in order:
    1. Replace each list with a shallow copy, then resolve areas and masks.
    2. Compute start/end timesteps.
    3. If the model has `extra_conds`, encode model conds.
    4. Give each cond area an opposite cond with the same area in every other list.
    5. Run control pre-processing.
    6. If a "positive" list exists, propagate its "control" entries (those
       flagged `control_apply_to_uncond`) and its "gligen" entries to the
       other lists.

    #### Args:
        - `model` (object): The model.
        - `noise` (torch.Tensor): The noise tensor.
        - `conds` (dict): Condition lists keyed by name; updated in place.
        - `device` (torch.device): The device.
        - `latent_image` (torch.Tensor, optional): The latent image tensor. Defaults to None.
        - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
        - `seed` (int, optional): The seed. Defaults to None.

    #### Returns:
        - `dict`: The same `conds` dict, processed.
    """

    def _select(cond_cnets, x):
        # Fill function for apply_empty_x_to_equal_area: pick item x.
        return cond_cnets[x]

    for name in conds:
        conds[name] = conds[name][:]
        resolve_areas_and_cond_masks_multidim(conds[name], noise.shape[2:], device)

    for entries in conds.values():
        ksampler_util.calculate_start_end_timesteps(model, entries)

    if hasattr(model, "extra_conds"):
        for name in conds:
            conds[name] = encode_model_conds(
                model.extra_conds,
                conds[name],
                noise,
                device,
                name,
                latent_image=latent_image,
                denoise_mask=denoise_mask,
                seed=seed,
            )

    # make sure each cond area has an opposite one with the same area
    for name, entries in conds.items():
        for entry in entries:
            for other_name, other_entries in conds.items():
                if other_name == name:
                    continue
                cond_util.create_cond_with_same_area_if_none(other_entries, entry)

    for entries in conds.values():
        ksampler_util.pre_run_control(model, entries)

    if "positive" not in conds:
        return conds

    positive = conds["positive"]
    for name in conds:
        if name == "positive":
            continue
        uncond_controls = [
            entry
            for entry in positive
            if entry.get("control_apply_to_uncond", False) is True
        ]
        ksampler_util.apply_empty_x_to_equal_area(
            uncond_controls, conds[name], "control", _select
        )
        ksampler_util.apply_empty_x_to_equal_area(
            positive, conds[name], "gligen", _select
        )
    return conds