import logging
from enum import Enum
from typing import Union

import matplotlib.pyplot as plt
import torch
from torch import Tensor, nn

from .cfm import CFM
from .irmae import IRMAE, IRMAEOutput

logger = logging.getLogger(__name__)
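
# LCFM pairs an IRMAE autoencoder (`ae`) with a CFM (conditional flow-matching)
# prior (`cfm`) over the autoencoder's latent space. Training is two-stage and
# selected via `set_mode_`: Mode.AE trains the autoencoder while the CFM is
# frozen; Mode.CFM trains the CFM on detached, scaled latents while the
# autoencoder is frozen. At inference (y is None) a latent is sampled from the
# CFM conditioned on x (or, in AE mode, encoded from x) and decoded to a mel.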


def freeze_(module):
    """Disable gradients for every parameter of `module`, in place."""
    for p in module.parameters():
        p.requires_grad_(False)


class LCFM(nn.Module):
    class Mode(Enum):
        AE = "ae"
        CFM = "cfm"

    def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
        super().__init__()
        self.ae = ae
        self.cfm = cfm
        self.z_scale = z_scale
        self._mode = None
        self._eval_tau = 0.5

    @property
    def mode(self):
        return self._mode

    def set_mode_(self, mode):
        mode = self.Mode(mode)
        self._mode = mode

        if mode == mode.AE:
            freeze_(self.cfm)
            logger.info("Freeze cfm")
        elif mode == mode.CFM:
            freeze_(self.ae)
            logger.info("Freeze ae (encoder and decoder)")
        else:
            raise ValueError(f"Unknown training mode: {mode}")

    def get_running_train_loop(self):
        try:
            # Lazy import
            from ...utils.train_loop import TrainLoop

            return TrainLoop.get_running_loop()
        except ImportError:
            return None

    @property
    def global_step(self):
        loop = self.get_running_train_loop()
        if loop is None:
            return None
        return loop.global_step

    @torch.no_grad()  # visualization only; no gradients are needed
    def _visualize(self, x, y, y_):
        """Save a 2x2 mel panel: GT, posterior (AE reconstruction),
        C-Prior (CFM sample), and Prior (random latent)."""
        loop = self.get_running_train_loop()
        if loop is None:
            return

        plt.subplot(221)
        plt.imshow(
            y[0].detach().cpu().numpy(),
            aspect="auto",
            origin="lower",
            interpolation="none",
        )
        plt.title("GT")

        plt.subplot(222)
        y_ = y_[:, : y.shape[1]]
        plt.imshow(
            y_[0].detach().cpu().numpy(),
            aspect="auto",
            origin="lower",
            interpolation="none",
        )
        plt.title("Posterior")

        plt.subplot(223)
        z_ = self.cfm(x)
        y__ = self.ae.decode(z_)
        y__ = y__[:, : y.shape[1]]
        plt.imshow(
            y__[0].detach().cpu().numpy(),
            aspect="auto",
            origin="lower",
            interpolation="none",
        )
        plt.title("C-Prior")
        del y__

        plt.subplot(224)
        z_ = torch.randn_like(z_)
        y__ = self.ae.decode(z_)
        y__ = y__[:, : y.shape[1]]
        plt.imshow(
            y__[0].detach().cpu().numpy(),
            aspect="auto",
            origin="lower",
            interpolation="none",
        )
        plt.title("Prior")
        del z_, y__

        path = loop.make_current_step_viz_path("recon", ".png")
        path.parent.mkdir(exist_ok=True, parents=True)
        plt.tight_layout()
        plt.savefig(path, dpi=500)
        plt.close()

    def _scale(self, z: Tensor):
        return z * self.z_scale

    def _unscale(self, z: Tensor):
        return z / self.z_scale

    def eval_tau_(self, tau):
        """Set the noise level used for ψ0 at evaluation time (0 keeps ψ0, 1 is pure noise)."""
        self._eval_tau = tau

    def forward(self, x, y: Union[Tensor, None] = None, ψ0: Union[Tensor, None] = None):
        """
        Args:
            x: (b d t), condition mel
            y: (b d t), target mel
            ψ0: (b d t), starting mel
        """
        if self.mode == self.Mode.CFM:
            self.ae.eval()  # Always set to eval when training cfm

        if ψ0 is not None:
            ψ0 = self._scale(self.ae.encode(ψ0))
            if self.training:
                tau = torch.rand_like(ψ0[:, :1, :1])
            else:
                tau = self._eval_tau
            # Blend the encoded starting mel with Gaussian noise (tau = 1 is pure noise)
            ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0

        if y is None:
            # No target: decode either the encoded condition (AE mode) or a CFM sample
            if self.mode == self.Mode.AE:
                with torch.no_grad():
                    training = self.ae.training
                    self.ae.eval()
                    z = self.ae.encode(x)
                    self.ae.train(training)
            else:
                z = self._unscale(self.cfm(x, ψ0=ψ0))

            h = self.ae.decode(z)
        else:
            ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)

            if self.mode == self.Mode.CFM:
                # Feed the CFM the detached, scaled target latents (return value unused)
                _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)

            h = ae_output.decoded

            if h is not None and self.global_step is not None and self.global_step % 100 == 0:
                self._visualize(x[:1], y[:1], h[:1])

        return h
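

# --- Usage sketch (illustrative) ---------------------------------------------
# A minimal sketch of how this class appears to be driven, assuming `ae` and
# `cfm` are already-constructed IRMAE / CFM instances and that mels are
# (b, d, t) tensors as documented in `forward`. All names and values below are
# placeholders, not part of this module.
#
#   model = LCFM(ae, cfm, z_scale=1.0)
#
#   # Stage 1: train the autoencoder (the CFM is frozen).
#   model.set_mode_(LCFM.Mode.AE)
#   recon = model(x, y)            # returns the AE reconstruction of y
#
#   # Stage 2: train the CFM prior on frozen, detached AE latents.
#   model.set_mode_(LCFM.Mode.CFM)
#   _ = model(x, y)                # returns None; the CFM loss is presumably
#                                  # computed inside `self.cfm`
#
#   # Inference: sample a latent from the CFM conditioned on x, then decode.
#   model.eval()
#   with torch.no_grad():
#       out = model(x)             # y=None -> z = cfm(x), h = ae.decode(z)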