|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Utility functions for creating schedules, samplers, and sampling timesteps from config.
"""
|
|
|
|
|
import torch |
|
|
from omegaconf import DictConfig |
|
|
|
|
|
from .samplers.base import Sampler |
|
|
from .samplers.euler import EulerSampler |
|
|
from .schedules.base import Schedule |
|
|
from .schedules.lerp import LinearInterpolationSchedule |
|
|
from .timesteps.base import SamplingTimesteps |
|
|
from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps |
|
|
|
|
|
|
|
|
def create_schedule_from_config(
    config: DictConfig,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> Schedule:
    """
    Create a schedule from configuration.

    Args:
        config: Schedule configuration. ``config.type`` selects the schedule.
            For ``"lerp"``, the optional ``config.T`` (default ``1.0``) is
            forwarded as ``T`` to ``LinearInterpolationSchedule``.
        device: Accepted for interface parity with the other factories;
            not used by this function.
        dtype: Accepted for interface parity with the other factories;
            not used by this function.

    Returns:
        The constructed schedule.

    Raises:
        NotImplementedError: If ``config.type`` is not a supported schedule type.
    """
    schedule_type = config.type
    if schedule_type == "lerp":
        return LinearInterpolationSchedule(T=config.get("T", 1.0))

    # Name the offending value so misconfigured runs are easy to diagnose.
    raise NotImplementedError(f"Unsupported schedule type: {schedule_type!r}")
|
|
|
|
|
|
|
|
def create_sampler_from_config(
    config: DictConfig,
    schedule: Schedule,
    timesteps: SamplingTimesteps,
) -> Sampler:
    """
    Create a sampler from configuration.

    Args:
        config: Sampler configuration. ``config.type`` selects the sampler.
            For ``"euler"``, ``config.prediction_type`` is required and is
            forwarded to ``EulerSampler``.
        schedule: Schedule passed through to the sampler.
        timesteps: Sampling timesteps passed through to the sampler.

    Returns:
        The constructed sampler.

    Raises:
        NotImplementedError: If ``config.type`` is not a supported sampler type.
    """
    sampler_type = config.type
    if sampler_type == "euler":
        return EulerSampler(
            schedule=schedule,
            timesteps=timesteps,
            prediction_type=config.prediction_type,
        )

    # Name the offending value so misconfigured runs are easy to diagnose.
    raise NotImplementedError(f"Unsupported sampler type: {sampler_type!r}")
|
|
|
|
|
|
|
|
def create_sampling_timesteps_from_config(
    config: DictConfig,
    schedule: Schedule,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> SamplingTimesteps:
    """
    Create sampling timesteps from configuration.

    Args:
        config: Timesteps configuration. ``config.type`` selects the
            implementation. For ``"uniform_trailing"``, ``config.steps`` is
            required and the optional ``config.shift`` defaults to ``1.0``.
        schedule: Schedule whose ``T`` is used as the timestep range bound.
        device: Device forwarded to the timesteps implementation.
        dtype: Accepted for interface parity with the other factories;
            not used by this function.

    Returns:
        The constructed sampling timesteps.

    Raises:
        NotImplementedError: If ``config.type`` is not a supported timesteps type.
    """
    timesteps_type = config.type
    if timesteps_type == "uniform_trailing":
        # NOTE(review): `dtype` is not forwarded here; confirm whether
        # UniformTrailingSamplingTimesteps accepts/needs it.
        return UniformTrailingSamplingTimesteps(
            T=schedule.T,
            steps=config.steps,
            shift=config.get("shift", 1.0),
            device=device,
        )

    # Name the offending value so misconfigured runs are easy to diagnose.
    raise NotImplementedError(
        f"Unsupported sampling timesteps type: {timesteps_type!r}"
    )