|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Euler ODE solver. |
|
|
""" |
|
|
|
|
|
from typing import Callable |
|
|
import torch |
|
|
from einops import rearrange |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from models.dit_v2 import na |
|
|
|
|
|
from ..types import PredictionType |
|
|
from ..utils import expand_dims |
|
|
from .base import Sampler, SamplerModelArgs |
|
|
|
|
|
|
|
|
class EulerSampler(Sampler): |
|
|
""" |
|
|
The Euler method is the simplest ODE solver. |
|
|
<https://en.wikipedia.org/wiki/Euler_method> |
|
|
""" |
|
|
|
|
|
def sample( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
f: Callable[[SamplerModelArgs], torch.Tensor], |
|
|
) -> torch.Tensor: |
|
|
timesteps = self.timesteps.timesteps |
|
|
progress = self.get_progress_bar() |
|
|
i = 0 |
|
|
for t, s in zip(timesteps[:-1], timesteps[1:]): |
|
|
pred = f(SamplerModelArgs(x, t, i)) |
|
|
x = self.step_to(pred, x, t, s) |
|
|
i += 1 |
|
|
progress.update() |
|
|
|
|
|
if self.return_endpoint: |
|
|
t = timesteps[-1] |
|
|
pred = f(SamplerModelArgs(x, t, i)) |
|
|
x = self.get_endpoint(pred, x, t) |
|
|
progress.update() |
|
|
return x |
|
|
|
|
|
def step( |
|
|
self, |
|
|
pred: torch.Tensor, |
|
|
x_t: torch.Tensor, |
|
|
t: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Step to the next timestep. |
|
|
""" |
|
|
return self.step_to(pred, x_t, t, self.get_next_timestep(t)) |
|
|
|
|
|
def step_to( |
|
|
self, |
|
|
pred: torch.Tensor, |
|
|
x_t: torch.Tensor, |
|
|
t: torch.Tensor, |
|
|
s: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Steps from x_t at timestep t to x_s at timestep s. Returns x_s. |
|
|
""" |
|
|
t = expand_dims(t, x_t.ndim) |
|
|
s = expand_dims(s, x_t.ndim) |
|
|
T = self.schedule.T |
|
|
|
|
|
pred_x_0, pred_x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t) |
|
|
pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T)) |
|
|
|
|
|
pred_x_s = pred_x_s.where(s >= 0, pred_x_0) |
|
|
pred_x_s = pred_x_s.where(s <= T, pred_x_T) |
|
|
return pred_x_s |
|
|
|