|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Utility functions. |
|
|
""" |
|
|
|
|
|
from typing import Callable |
|
|
import torch |
|
|
|
|
|
|
|
|
def expand_dims(tensor: torch.Tensor, ndim: int):
    """
    Append trailing singleton dimensions until ``tensor`` has ``ndim`` dims.

    Example: a tensor of shape (8,) with ``ndim=4`` becomes shape (8, 1, 1, 1).
    If ``tensor`` already has ``ndim`` or more dimensions, it is reshaped to
    its own shape (i.e. the shape is left unchanged).
    """
    missing = ndim - tensor.ndim
    # A non-positive ``missing`` yields an empty tuple, so nothing is appended.
    target_shape = tuple(tensor.shape) + (1,) * missing
    return tensor.reshape(target_shape)
|
|
|
|
|
|
|
|
def assert_schedule_timesteps_compatible(schedule, timesteps):
    """
    Validate that ``schedule`` and ``timesteps`` can be used together.

    Both objects must agree on their ``T`` attribute and on the result of
    ``is_continuous()``; the ``T`` check runs first.

    Raises:
        ValueError: if ``T`` differs, or if the continuity flags differ.
    """
    if schedule.T != timesteps.T:
        raise ValueError("Schedule and timesteps must have the same T.")

    schedule_continuous = schedule.is_continuous()
    timesteps_continuous = timesteps.is_continuous()
    if schedule_continuous != timesteps_continuous:
        raise ValueError("Schedule and timesteps must have the same continuity.")
|
|
|
|
|
|
|
|
def classifier_free_guidance(
    pos: torch.Tensor,
    neg: torch.Tensor,
    scale: float,
    rescale: float = 0.0,
):
    """
    Apply classifier-free guidance.

    Computes ``cfg = neg + scale * (pos - neg)``. When ``rescale`` is
    non-zero, ``cfg`` is then multiplied by
    ``rescale * std(pos) / std(cfg) + (1 - rescale)``. The standard
    deviations are taken over every dimension except dim 0, i.e. per slice
    along dim 0 (presumably the batch dimension — confirm with callers).

    Slices of ``cfg`` whose standard deviation is zero (or NaN) cannot be
    rescaled meaningfully. For those slices the std ratio is taken to be 1,
    which leaves them unchanged, instead of producing inf/NaN.

    Args:
        pos: Positive (conditional) prediction.
        neg: Negative (unconditional) prediction, combined elementwise with
            ``pos``.
        scale: Guidance scale.
        rescale: Blend weight between the std-matched result (1.0) and the
            plain guided result (0.0). ``0.0`` disables rescaling.

    Returns:
        The guided tensor.
    """
    cfg = neg + scale * (pos - neg)

    if rescale != 0.0:
        pos_std = pos.std(dim=list(range(1, pos.ndim)), keepdim=True)
        cfg_std = cfg.std(dim=list(range(1, cfg.ndim)), keepdim=True)

        # Guard against dividing by a zero/NaN std. The denominator is made
        # safe *before* dividing so the unselected branch of torch.where never
        # holds inf/NaN, which would otherwise leak NaN into gradients.
        valid = cfg_std > 0
        ones = torch.ones_like(cfg_std)
        safe_cfg_std = torch.where(valid, cfg_std, ones)
        ratio = torch.where(valid, pos_std / safe_cfg_std, ones)

        factor = rescale * ratio + (1 - rescale)
        # Out-of-place on purpose: `cfg` was saved by `cfg.std` for backward,
        # so an in-place `cfg *= factor` would break autograd.
        cfg = cfg * factor

    return cfg
|
|
|
|
|
|
|
|
def classifier_free_guidance_dispatcher(
    pos: Callable,
    neg: Callable,
    scale: float,
    rescale: float = 0.0,
):
    """
    Call the guidance models only as needed for the given guidance scale.

    ``pos`` is always called. When ``scale == 1.0``, its result is returned
    directly and ``neg`` is never invoked. Otherwise ``neg`` is called after
    ``pos``, and both results are combined via ``classifier_free_guidance``.

    Args:
        pos: Zero-argument callable producing the positive prediction.
        neg: Zero-argument callable producing the negative prediction.
        scale: Guidance scale.
        rescale: Forwarded unchanged to ``classifier_free_guidance``.
    """
    # Decide before evaluating any model, matching the original ordering.
    positive_only = bool(scale == 1.0)

    positive = pos()
    if positive_only:
        return positive

    negative = neg()
    return classifier_free_guidance(
        pos=positive,
        neg=negative,
        scale=scale,
        rescale=rescale,
    )
|
|
|