from typing import Optional, Union

import torch
from torch import nn

from diffusers.models.embeddings import get_timestep_embedding


def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]) -> torch.Tensor:
    """Return emb1 if emb2 is None, otherwise the elementwise sum."""
    return emb1 if emb2 is None else emb1 + emb2
|
|
class TimeEmbedding(nn.Module):
    """Maps a diffusion timestep to a conditioning vector: a sinusoidal
    encoding followed by a two-layer SiLU MLP."""

    def __init__(
        self,
        sinusoidal_dim: int,
        hidden_dim: int,
        output_dim: int,
    ):
        super().__init__()
        self.sinusoidal_dim = sinusoidal_dim
        self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
        self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.act = nn.SiLU()

    def forward(
        self,
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.FloatTensor:
        # get_timestep_embedding expects a 1-D tensor of timesteps, so wrap
        # Python scalars and promote 0-d tensors to a batch of one.
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], device=device, dtype=dtype)
        if timestep.ndim == 0:
            timestep = timestep[None]
        # Tensor inputs may arrive on a different device; move them so the
        # embedding is produced where the projection weights live.
        timestep = timestep.to(device)

        # Sinusoidal encoding with sin-first ordering and no frequency shift.
        emb = get_timestep_embedding(
            timesteps=timestep,
            embedding_dim=self.sinusoidal_dim,
            flip_sin_to_cos=False,
            downscale_freq_shift=0,
        )
        # diffusers computes the sinusoidal table in float32; cast to the
        # requested dtype before the linear projections.
        emb = emb.to(dtype)
        emb = self.proj_in(emb)
        emb = self.act(emb)
        emb = self.proj_hid(emb)
        emb = self.act(emb)
        emb = self.proj_out(emb)
        return emb
|
|
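# Minimal usage sketch. The dimensions below (320 sinusoidal, 1280 hidden and
# output) are illustrative assumptions, not values taken from this module.
if __name__ == "__main__":
    time_embedding = TimeEmbedding(sinusoidal_dim=320, hidden_dim=1280, output_dim=1280)
    timesteps = torch.tensor([0, 250, 999])  # three diffusion timesteps
    emb = time_embedding(timesteps, device=torch.device("cpu"), dtype=torch.float32)
    print(emb.shape)  # torch.Size([3, 1280])

    # emb_add merges an optional extra embedding (e.g. class conditioning)
    # without branching at the call site; None acts as the identity.
    combined = emb_add(emb, None)
    assert torch.equal(combined, emb)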