Spaces:
Running
on
Zero
Running
on
Zero
| import math | |
| import torch | |
| from modules.cond import cond, cond_util | |
def cfg_function(
    model: torch.nn.Module,
    cond_pred: torch.Tensor,
    uncond_pred: torch.Tensor,
    cond_scale: float,
    x: torch.Tensor,
    timestep: int,
    model_options: dict = None,
    cond: torch.Tensor = None,
    uncond: torch.Tensor = None,
) -> torch.Tensor:
    """#### Apply classifier-free guidance (CFG) to the model predictions.

    #### Args:
        - `model` (torch.nn.Module): The model.
        - `cond_pred` (torch.Tensor): The conditioned prediction.
        - `uncond_pred` (torch.Tensor): The unconditioned prediction.
        - `cond_scale` (float): The CFG scale.
        - `x` (torch.Tensor): The input tensor.
        - `timestep` (int): The current timestep (also forwarded as `"sigma"`).
        - `model_options` (dict, optional): Additional model options. Defaults to None
          (treated as an empty dict). NOTE: a literal `{}` default was a shared
          mutable default — fixed to `None`.
        - `cond` (torch.Tensor, optional): The conditioned tensor. Defaults to None.
          NOTE(review): this parameter shadows the `cond` module imported at the
          top of the file; do not call module functions from inside this body.
        - `uncond` (torch.Tensor, optional): The unconditioned tensor. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The CFG result.
    """
    # Avoid the mutable-default-argument pitfall: `dict = {}` would be one
    # shared dict across every call of this function.
    if model_options is None:
        model_options = {}

    # Check for a custom sampler CFG function first.
    if "sampler_cfg_function" in model_options:
        # The callback operates in epsilon space (x - denoised); precompute
        # both differences once to avoid redundant tensor ops.
        args = {
            "cond": x - cond_pred,
            "uncond": x - uncond_pred,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "cond_denoised": cond_pred,
            "uncond_denoised": uncond_pred,
            "model": model,
            "model_options": model_options,
        }
        cfg_result = x - model_options["sampler_cfg_function"](args)
    elif math.isclose(cond_scale, 1.0):
        # Standard CFG degenerates to the conditioned prediction at scale 1 —
        # skip the arithmetic entirely.
        cfg_result = cond_pred
    else:
        # Fused operation: uncond_pred + (cond_pred - uncond_pred) * cond_scale,
        # equivalent to uncond_pred * (1 - cond_scale) + cond_pred * cond_scale.
        cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)

    # Apply post-CFG functions, if any, in sequence; each sees the previous
    # function's output via args["denoised"].
    post_cfg_functions = model_options.get("sampler_post_cfg_function", [])
    if post_cfg_functions:
        args = {
            "denoised": cfg_result,
            "cond": cond,
            "uncond": uncond,
            "model": model,
            "uncond_denoised": uncond_pred,
            "cond_denoised": cond_pred,
            "sigma": timestep,
            "model_options": model_options,
            "input": x,
        }
        for fn in post_cfg_functions:
            cfg_result = fn(args)
            # Update the denoised result for the next function in the chain.
            args["denoised"] = cfg_result
    return cfg_result
def sampling_function(
    model: torch.nn.Module,
    x: torch.Tensor,
    timestep: int,
    uncond: torch.Tensor,
    condo: torch.Tensor,
    cond_scale: float,
    model_options: dict = None,
    seed: int = None,
) -> torch.Tensor:
    """#### Perform sampling with CFG.

    #### Args:
        - `model` (torch.nn.Module): The model.
        - `x` (torch.Tensor): The input tensor.
        - `timestep` (int): The current timestep.
        - `uncond` (torch.Tensor): The unconditioned tensor.
        - `condo` (torch.Tensor): The conditioned tensor (named `condo` to avoid
          shadowing the `cond` module imported at the top of the file).
        - `cond_scale` (float): The CFG scale.
        - `model_options` (dict, optional): Additional model options. Defaults to None
          (treated as an empty dict).
        - `seed` (int, optional): The random seed. Defaults to None.
          NOTE(review): currently unused in this body; kept for interface
          compatibility with callers such as `CFGGuider.predict_noise`.

    #### Returns:
        - `torch.Tensor`: The sampled tensor.
    """
    # Avoid the mutable-default-argument pitfall (`{}` is shared across calls).
    if model_options is None:
        model_options = {}

    # Skip the unconditional pass entirely when CFG has no effect (scale == 1),
    # unless that optimization is explicitly disabled in model_options.
    skip_uncond = math.isclose(cond_scale, 1.0) and not model_options.get(
        "disable_cfg1_optimization", False
    )
    uncond_ = None if skip_uncond else uncond

    # Batch both conditions through the model in one call.
    conds = [condo, uncond_]
    cond_outputs = cond.calc_cond_batch(model, conds, x, timestep, model_options)

    # Apply pre-CFG functions, if any; each sees the previous outputs via
    # args["conds_out"].
    pre_cfg_functions = model_options.get("sampler_pre_cfg_function", [])
    if pre_cfg_functions:
        args = {
            "conds": conds,
            "conds_out": cond_outputs,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "model": model,
            "model_options": model_options,
        }
        for fn in pre_cfg_functions:
            cond_outputs = fn(args)
            args["conds_out"] = cond_outputs

    # Extract conditional and unconditional outputs explicitly for clarity.
    cond_pred, uncond_pred = cond_outputs[0], cond_outputs[1]

    # Combine the two predictions with classifier-free guidance.
    return cfg_function(
        model,
        cond_pred,
        uncond_pred,
        cond_scale,
        x,
        timestep,
        model_options=model_options,
        cond=condo,
        uncond=uncond_,
    )
class CFGGuider:
    """#### Class for guiding the sampling process with CFG.

    Holds the original conditioning, the CFG scale, and (during `sample`) the
    transient attributes `inner_model`, `conds` and `loaded_models`.
    """

    def __init__(self, model_patcher, flux=False):
        """#### Initialize the CFGGuider.

        #### Args:
            - `model_patcher` (object): The model patcher; must expose
              `model_options` and (for `sample`) `load_device`.
            - `flux` (bool, optional): Whether flux mode is enabled; forwarded
              to `cond_util.prepare_sampling`. Defaults to False.
        """
        self.model_patcher = model_patcher
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0  # CFG scale; 1.0 means guidance is effectively off
        self.flux = flux

    def set_conds(self, positive, negative):
        """#### Set the positive/negative conditions for CFG.

        #### Args:
            - `positive` (torch.Tensor): The positive condition.
            - `negative` (torch.Tensor): The negative condition.
        """
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """#### Set the CFG scale.

        #### Args:
            - `cfg` (float): The CFG scale.
        """
        self.cfg = cfg

    def inner_set_conds(self, conds):
        """#### Convert and store the given conditions.

        #### Args:
            - `conds` (dict): Mapping of condition name to raw condition.
        """
        for name, raw_cond in conds.items():
            self.original_conds[name] = cond.convert_cond(raw_cond)

    def __call__(self, *args, **kwargs):
        """#### Call the CFGGuider to predict noise (delegates to `predict_noise`).

        #### Returns:
            - `torch.Tensor`: The predicted noise.
        """
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options=None, seed=None):
        """#### Predict noise using CFG.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `timestep` (int): The current timestep.
            - `model_options` (dict, optional): Additional model options.
              Defaults to None (treated as an empty dict; a literal `{}`
              default was a shared mutable default — fixed to `None`).
            - `seed` (int, optional): The random seed. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The predicted noise.
        """
        if model_options is None:
            model_options = {}
        return sampling_function(
            self.inner_model,
            x,
            timestep,
            self.conds.get("negative", None),
            self.conds.get("positive", None),
            self.cfg,
            model_options=model_options,
            seed=seed,
        )

    def inner_sample(
        self,
        noise,
        latent_image,
        device,
        sampler,
        sigmas,
        denoise_mask,
        callback,
        disable_pbar,
        seed,
        pipeline=False,
    ):
        """#### Perform the inner sampling process.

        #### Args:
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `device` (torch.device): The device to use.
            - `sampler` (object): The sampler object.
            - `sigmas` (torch.Tensor): The sigmas tensor.
            - `denoise_mask` (torch.Tensor): The denoise mask tensor.
            - `callback` (callable): The callback function.
            - `disable_pbar` (bool): Whether to disable the progress bar.
            - `seed` (int): The random seed.
            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The sampled tensor.
        """
        # Don't shift the empty latent image (all-zero tensors skip process_latent_in).
        if latent_image is not None and torch.count_nonzero(latent_image) > 0:
            latent_image = self.inner_model.process_latent_in(latent_image)

        self.conds = cond.process_conds(
            self.inner_model,
            noise,
            self.conds,
            device,
            latent_image,
            denoise_mask,
            seed,
        )

        extra_args = {"model_options": self.model_options, "seed": seed}
        samples = sampler.sample(
            self,
            sigmas,
            extra_args,
            callback,
            noise,
            latent_image,
            denoise_mask,
            disable_pbar,
            pipeline=pipeline,
        )
        # Convert back to latent space in float32 for numerical stability.
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(
        self,
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask=None,
        callback=None,
        disable_pbar=False,
        seed=None,
        pipeline=False,
    ):
        """#### Perform the sampling process with CFG.

        #### Args:
            - `noise` (torch.Tensor): The noise tensor.
            - `latent_image` (torch.Tensor): The latent image tensor.
            - `sampler` (object): The sampler object.
            - `sigmas` (torch.Tensor): The sigmas tensor.
            - `denoise_mask` (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
            - `callback` (callable, optional): The callback function. Defaults to None.
            - `disable_pbar` (bool, optional): Whether to disable the progress bar. Defaults to False.
            - `seed` (int, optional): The random seed. Defaults to None.
            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.

        #### Returns:
            - `torch.Tensor`: The sampled tensor.
        """
        # Work on copies so repeated sample() calls never mutate original_conds.
        self.conds = {
            k: [c.copy() for c in v] for k, v in self.original_conds.items()
        }

        self.inner_model, self.conds, self.loaded_models = cond_util.prepare_sampling(
            self.model_patcher, noise.shape, self.conds, flux_enabled=self.flux
        )

        # Move all sampling inputs onto the model's load device.
        device = self.model_patcher.load_device
        noise = noise.to(device)
        latent_image = latent_image.to(device)
        sigmas = sigmas.to(device)

        output = self.inner_sample(
            noise,
            latent_image,
            device,
            sampler,
            sigmas,
            denoise_mask,
            callback,
            disable_pbar,
            seed,
            pipeline=pipeline,
        )

        # Release transient sampling state so models can be unloaded.
        cond_util.cleanup_models(self.conds, self.loaded_models)
        del self.inner_model
        del self.conds
        del self.loaded_models
        return output