from abc import ABC, abstractmethod

import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")
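# Usage sketch (illustrative, not part of the module): assuming a diffusion
# object that exposes `num_timesteps`, the factory and sampler compose as:
#
#   sampler = create_named_schedule_sampler("uniform", diffusion)
#   t, weights = sampler.sample(batch_size=4, device=th.device("cpu"))
#   loss = (weights * per_timestep_loss(t)).mean()
#
# where `per_timestep_loss` stands in for whatever loss the training loop
# computes at the sampled timesteps (hypothetical name).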
class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """
    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """
    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps to sample.
        :param device: the torch device on which to place the tensors.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        # Reweight so the expected objective matches uniform sampling.
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights
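# Why the reweighting in sample() is unbiased: with p(t) = w_t / sum(w) and
# weight(t) = 1 / (T * p(t)) over T diffusion steps,
#
#   E_{t~p}[weight(t) * L_t] = sum_t p(t) * L_t / (T * p(t)) = (1/T) * sum_t L_t,
#
# i.e. the mean of the weighted loss equals the plain uniform average of the
# per-timestep losses, regardless of which positive weights a subclass uses.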
class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        # Accept the diffusion object to match the factory above, which
        # calls UniformSampler(diffusion); weight every timestep equally.
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights
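# Illustrative sketch of a non-uniform subclass (hypothetical, not part of
# the original module): a subclass only needs to supply positive per-step
# weights; ScheduleSampler.sample() handles the importance reweighting, so
# the objective's mean is unchanged even though some steps are drawn more
# often than others.


class SqrtSampler(ScheduleSampler):
    """Hypothetical sampler that draws later timesteps more often."""

    def __init__(self, diffusion):
        # Step t gets weight sqrt(t + 1): later steps are sampled more
        # frequently, but sample() down-weights their losses to compensate.
        self._weights = np.sqrt(np.arange(1, diffusion.num_timesteps + 1))

    def weights(self):
        return self._weights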