| from diffusers import DDPMScheduler | |
| import torch | |
class HookedNoiseScheduler:
    """Proxy around a noise scheduler that lets callers intercept ``step``.

    ``pre_hooks`` are called in order with
    ``(model_output, timestep, sample, generator)`` before the wrapped
    scheduler's ``step``; a hook may return a replacement 4-tuple, or ``None``
    to leave the arguments unchanged.  ``post_hooks`` are called in order with
    the predicted previous sample; a hook may return a replacement sample, or
    ``None`` to leave it unchanged.

    Every other attribute read or write is forwarded to the wrapped
    scheduler, so the wrapper can stand in wherever the scheduler is used.
    """

    # Quoted so the class body does not evaluate the name at definition time.
    scheduler: "DDPMScheduler"
    pre_hooks: list
    post_hooks: list

    # Attributes stored on the wrapper itself instead of being forwarded.
    _OWN_ATTRS = frozenset({"scheduler", "pre_hooks", "post_hooks"})

    def __init__(self, scheduler):
        # Write straight into the instance dict; these names must live on the
        # wrapper, never on the wrapped scheduler.
        object.__setattr__(self, "scheduler", scheduler)
        object.__setattr__(self, "pre_hooks", [])
        object.__setattr__(self, "post_hooks", [])

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        """Run the pre-hooks, the wrapped scheduler's ``step``, then the post-hooks.

        Args:
            model_output, timestep, sample, generator: forwarded to the wrapped
                scheduler's ``step``, after any replacement by a pre-hook.
            return_dict: only a falsy value is supported; the default mirrors
                the diffusers signature so callers get an explicit error.

        Returns:
            The wrapped scheduler's output tuple, with its first element (the
            previous sample) replaced by the post-hooked value.  Any further
            elements are passed through unchanged.

        Raises:
            NotImplementedError: if ``return_dict`` is truthy.
        """
        # Explicit raise rather than ``assert``: asserts vanish under ``-O``.
        if return_dict:
            raise NotImplementedError("return_dict == True is not implemented")

        for hook in self.pre_hooks:
            replacement = hook(model_output, timestep, sample, generator)
            if replacement is not None:
                model_output, timestep, sample, generator = replacement

        # Keyword arguments: schedulers do not all place ``generator`` fourth.
        outputs = self.scheduler.step(
            model_output, timestep, sample, generator=generator, return_dict=False
        )
        # The tuple may hold more than the previous sample; only the first
        # element is hooked, the rest is returned untouched.
        prev_sample, *extras = outputs

        for hook in self.post_hooks:
            replacement = hook(prev_sample)
            if replacement is not None:
                prev_sample = replacement
        return (prev_sample, *extras)

    def __getattr__(self, name):
        """Forward attribute reads that miss on the wrapper to the scheduler."""
        # Python only calls __getattr__ after normal lookup fails.  If one of
        # our own fields is missing (e.g. an instance made via __new__ by
        # copy/pickle before its state is restored), forwarding would read
        # self.scheduler, re-enter __getattr__, and recurse without bound.
        if name in self._OWN_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.scheduler, name)

    def __setattr__(self, name, value):
        """Store own fields locally; forward every other write to the scheduler."""
        if name in self._OWN_ATTRS:
            object.__setattr__(self, name, value)
        else:
            setattr(self.scheduler, name, value)