|
|
from abc import ABC, abstractmethod |
|
|
from typing import Sequence, Union |
|
|
import torch |
|
|
|
|
|
from ..types import SamplingDirection |
|
|
|
|
|
|
|
|
class Timesteps(ABC):
    """
    Base class for timestep schedules.

    Holds the maximum timestep ``T``. Whether the schedule is discrete or
    continuous is inferred from the Python type of ``T``.
    """

    def __init__(self, T: Union[int, float]):
        # Reject non-positive horizons up front.
        assert 0 < T
        self._T = T

    @property
    def T(self) -> Union[int, float]:
        """
        Largest timestep, inclusive.

        An ``int`` for discrete schedules, a ``float`` for continuous ones.
        """
        return self._T

    def is_continuous(self) -> bool:
        """
        Report whether this schedule is continuous, i.e. ``T`` is a float.
        """
        horizon = self.T
        return isinstance(horizon, float)
|
|
|
|
|
|
|
|
class SamplingTimesteps(Timesteps):
    """
    Sampling timesteps.

    It defines the discretization of sampling steps: a 1-D tensor of timestep
    values, one per sampling step, plus a sampling direction that is stored
    as-is (it is not interpreted by this class).
    """

    def __init__(
        self,
        T: Union[int, float],
        timesteps: torch.Tensor,
        direction: SamplingDirection,
    ):
        # One timestep value per sampling step, so the tensor must be 1-D.
        assert timesteps.ndim == 1, (
            f"timesteps must be 1-D, got shape {tuple(timesteps.shape)}"
        )
        super().__init__(T)
        self.timesteps = timesteps
        self.direction = direction

    def __len__(self) -> int:
        """
        Number of sampling steps.
        """
        return len(self.timesteps)

    def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor:
        """
        The timestep at the sampling step.

        Returns a scalar tensor if idx is int,
        or tensor of the same size if idx is a tensor.
        """
        return self.timesteps[idx]

    def index(self, t: torch.Tensor) -> torch.Tensor:
        """
        Find index by t.

        Returns an int32 tensor of the same shape as t, on t's device.
        Index is -1 where t is not found in timesteps. If a value occurs more
        than once in timesteps, the position of its first occurrence is
        returned.
        """
        flat = t.reshape(-1, 1)
        # (numel(t), len(timesteps)) exact-equality table.
        matches = flat.eq(self.timesteps)
        # Allocate a fresh contiguous result. ``full_like`` would copy t's
        # strides, and ``.view(-1)`` raises on permuted layouts (e.g. a
        # transposed t).
        idx = torch.full((flat.shape[0],), -1, dtype=torch.int, device=t.device)
        # argmax cannot reduce over an empty dimension; when either t or
        # timesteps is empty, every index simply stays -1.
        if matches.numel() > 0:
            found = matches.any(dim=1)
            # argmax returns the first maximal position, so duplicate entries
            # in timesteps resolve deterministically to the earliest one.
            first = matches.int().argmax(dim=1).int()
            idx = torch.where(found, first, idx)
        return idx.reshape(t.shape)
|
|
|