Spaces:
Runtime error
Runtime error
| import torch | |
| from typing import Union | |
def deterministic_ddpm_step(
    model_output: torch.FloatTensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.FloatTensor,
    scheduler,
) -> tuple:
    """
    Take one noise-free (deterministic) reverse diffusion step.

    The previous sample is built as
    ``sqrt(alpha_prod_prev) * x0_pred + coef_D * eps_pred`` and no random
    noise is added, so the same inputs always give the same output.

    Args:
        model_output: Raw output of the diffusion model. If it has twice as
            many channels (dim 1) as ``sample`` and the scheduler's
            ``variance_type`` is ``"learned"`` or ``"learned_range"``, only the
            first half is used; the predicted-variance half is discarded
            because this step adds no noise.
        timestep: Current discrete timestep. It is used to index
            ``scheduler.alphas_cumprod``, so it must be an integer or an
            integer tensor.
        sample: Current noisy sample ``x_t``.
        scheduler: Object exposing ``previous_timestep``, ``alphas_cumprod``,
            ``one``, ``variance_type``, ``_threshold_sample`` and a ``config``
            with ``prediction_type``, ``thresholding``, ``clip_sample`` and
            ``clip_sample_range`` (presumably a diffusers ``DDPMScheduler``;
            verify against the caller).

    Returns:
        A ``(pred_prev_sample, coef_D)`` tuple: the predicted sample at the
        previous timestep, and the coefficient that multiplies the predicted
        noise in that sample.

    Raises:
        ValueError: If ``scheduler.config.prediction_type`` is not one of
            ``epsilon``, ``sample`` or ``v_prediction``.
    """
    t = timestep
    prev_t = scheduler.previous_timestep(t)

    # Keep only the mean-related channels; a learned variance has no role in
    # a step that adds no noise.
    if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
        "learned",
        "learned_range",
    ]:
        model_output, _ = torch.split(model_output, sample.shape[1], dim=1)

    # 1. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[t]
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    )
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev

    # 2. compute the predicted original sample ("predicted x_0", formula (15)
    # of https://arxiv.org/pdf/2006.11239.pdf) AND the matching noise estimate.
    # coef_D below multiplies a noise estimate, so for the non-epsilon
    # prediction types the noise has to be recovered from the model output
    # rather than using the raw output directly.
    prediction_type = scheduler.config.prediction_type
    if prediction_type == "epsilon":
        pred_epsilon = model_output
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * pred_epsilon
        ) / alpha_prod_t ** (0.5)
    elif prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
            beta_prod_t**0.5
        ) * sample
    else:
        raise ValueError(
            f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # 3. Clip or threshold "predicted x_0". The noise estimate above is left
    # as derived from the unclipped prediction.
    if scheduler.config.thresholding:
        pred_original_sample = scheduler._threshold_sample(pred_original_sample)
    elif scheduler.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
        )

    # 4. Combine. coef_D is the x_t coefficient of the DDPM posterior mean
    # multiplied by sqrt(1 - alpha_prod_t), i.e. the weight of the noise term
    # once x_t is expressed through x_0 and the noise.
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
    coef_D = current_sample_coeff * (beta_prod_t ** (0.5))
    pred_prev_sample = (alpha_prod_t_prev ** (0.5) * pred_original_sample) + (
        coef_D * pred_epsilon
    )
    return pred_prev_sample, coef_D