import logging
from dataclasses import dataclass
from typing import Union

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils.parametrizations import weight_norm

from ...common import Normalizer

logger = logging.getLogger(__name__)
@dataclass
class IRMAEOutput:
    latent: Tensor  # latent vector
    decoded: Union[Tensor, None]  # decoder output, includes extra dim
class ResBlock(nn.Sequential):
    def __init__(self, channels, dilations=(1, 2, 4, 8)):
        wn = weight_norm
        super().__init__(
            nn.GroupNorm(32, channels),
            nn.GELU(),
            wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
            nn.GroupNorm(32, channels),
            nn.GELU(),
            wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
            nn.GroupNorm(32, channels),
            nn.GELU(),
            wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
            nn.GroupNorm(32, channels),
            nn.GELU(),
            wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
        )

    def forward(self, x: Tensor):
        # Residual connection around the dilated convolution stack.
        return x + super().forward(x)
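
# Shape note (added commentary, not in the original file): every Conv1d above
# uses padding="same" and stride 1, so a ResBlock is shape-preserving,
# e.g. ResBlock(1024) maps a (B, 1024, T) tensor to a (B, 1024, T) tensor.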
class IRMAE(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        hidden_dim=1024,
        num_irms=4,
    ):
        """
        Args:
            input_dim: input dimension
            output_dim: output dimension
            latent_dim: latent dimension
            hidden_dim: hidden layer dimension
            num_irms: number of implicit rank minimization matrices
        """
        super().__init__()
        self.input_dim = input_dim
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            # Try to obtain a compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
            *[
                nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False)
                for i in range(num_irms)
            ],
            nn.Tanh(),
        )
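
        # Note (added commentary, based on the paper linked above): the
        # bias-free 1x1 convolutions form a chain of linear maps whose product
        # is still a single linear map, but training the over-parameterized
        # factorization with gradient descent implicitly biases the latent
        # code toward low rank.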
        self.decoder = nn.Sequential(
            nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            nn.Conv1d(hidden_dim, output_dim, 1),
        )

        self.head = nn.Sequential(
            nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
            nn.GELU(),
            nn.Conv1d(hidden_dim, input_dim, 1),
        )

        self.estimator = Normalizer()
    def encode(self, x):
        """
        Args:
            x: (b c t) tensor
        """
        z = self.encoder(x)  # (b c t)
        _ = self.estimator(z)  # Estimate the global mean and std of z
        self.stats = {}
        self.stats["z_mean"] = z.mean().item()
        self.stats["z_std"] = z.std().item()
        z_float = z.float()
        # Quantiles at Gaussian 1σ/2σ/3σ coverage (68.27%, 95.45%, 99.73%)
        self.stats["z_abs_68"] = z_float.abs().quantile(0.6827).item()
        self.stats["z_abs_95"] = z_float.abs().quantile(0.9545).item()
        self.stats["z_abs_99"] = z_float.abs().quantile(0.9973).item()
        return z
    def decode(self, z):
        """
        Args:
            z: (b c t) tensor
        """
        return self.decoder(z)
    def forward(self, x, skip_decoding=False):
        """
        Args:
            x: (b c t) tensor
            skip_decoding: if True, skip the decoding step
        """
        z = self.encode(x)  # q(z|x)

        if skip_decoding:
            # This speeds up training in CFM-only mode
            decoded = None
        else:
            # The reconstruction head and loss only apply when decoding runs,
            # since decoded is None otherwise.
            decoded = self.decode(z)  # p(x|z)
            predicted = self.head(decoded)
            self.losses = dict(mse=F.mse_loss(predicted, x))

        return IRMAEOutput(latent=z, decoded=decoded)
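

# Minimal usage sketch (not part of the original module; the dimensions are
# assumptions chosen for illustration, guided by the (b c t) docstrings above).
# Running this file directly would also require the relative `...common`
# import to resolve, so treat it as illustrative rather than a test.
if __name__ == "__main__":
    import torch

    model = IRMAE(input_dim=80, output_dim=80, latent_dim=16)
    x = torch.randn(2, 80, 100)  # (b c t)
    out = model(x)
    print(out.latent.shape)   # torch.Size([2, 16, 100])
    print(out.decoded.shape)  # torch.Size([2, 80, 100])
    print(model.losses["mse"].item())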