Spaces:
Runtime error
Runtime error
| import warnings | |
| from typing import Tuple, Union | |
| import torch | |
| from diffusers.schedulers.scheduling_lms_discrete import \ | |
| LMSDiscreteScheduler as _LMSDiscreteScheduler | |
| from diffusers.schedulers.scheduling_lms_discrete import \ | |
| LMSDiscreteSchedulerOutput | |
class LMSDiscreteScheduler(_LMSDiscreteScheduler):
    """K-LMS scheduler whose `step` / `scale_model_input` take an integer step index.

    Both overridden methods index `self.sigmas` directly with the caller-supplied
    integer position instead of a timestep value.
    NOTE(review): presumably the parent class takes a timestep here and looks the
    sigma up itself — confirm against the installed diffusers version.
    """

    def step(
        self,
        model_output: torch.FloatTensor,
        step_index: int,
        sample: torch.FloatTensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            step_index (`int`): position of the current step; used to index `self.sigmas`.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            order (`int`): maximum number of stored derivatives combined by the linear multistep update.
                Must be >= 1.
            return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

        Returns:
            [`LMSDiscreteSchedulerOutput`] or `tuple`:
            [`LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `order` < 1, or if `self.config.prediction_type` is neither
                `epsilon` nor `v_prediction`.
        """
        # With order == 0 no multistep coefficients would be produced and the sample would be
        # returned unchanged, which silently hides a caller error.
        if order < 1:
            raise ValueError(f"order must be >= 1, got {order}")

        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # x_0 = c_out * model_output + c_skip * sample, with
            # c_out = -sigma / sqrt(sigma**2 + 1) and c_skip = 1 / (sigma**2 + 1)
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        # Keep only the newest `order` derivatives. A slice delete (instead of one pop) keeps the
        # history bounded even if `order` is lowered between calls.
        if len(self.derivatives) > order:
            del self.derivatives[:-order]

        # 3. Compute linear multistep coefficients
        # Early steps have fewer past derivatives available, so the effective order is capped.
        order = min(step_index + 1, order)
        lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        # reversed(): newest derivative pairs with coefficient 0; zip stops at the shorter sequence.
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        if not return_dict:
            return (prev_sample,)

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def scale_model_input(
        self,
        sample: torch.FloatTensor,
        iteration: int
    ) -> torch.FloatTensor:
        """
        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.

        Also sets `self.is_scale_input_called`, which `step` checks before warning.

        Args:
            sample (`torch.FloatTensor`): input sample
            iteration (`int`): position of the current step; used to index `self.sigmas`.

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        sample = sample / ((self.sigmas[iteration]**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample