import itertools
import math
from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn

from fish_speech.models.vqgan.modules.discriminator import Discriminator
from fish_speech.models.vqgan.modules.wavenet import WaveNet
from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask


class VQGAN(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        encoder: WaveNet,
        quantizer: nn.Module,
        decoder: WaveNet,
        discriminator: Discriminator,
        vocoder: nn.Module,
        encode_mel_transform: nn.Module,
        gt_mel_transform: nn.Module,
        weight_adv: float = 1.0,
        weight_vq: float = 1.0,
        weight_mel: float = 1.0,
        sampling_rate: int = 44100,
        freeze_encoder: bool = False,
    ):
        super().__init__()

        # Model parameters
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Modules
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.vocoder = vocoder
        self.discriminator = discriminator
        self.encode_mel_transform = encode_mel_transform
        self.gt_mel_transform = gt_mel_transform

        # A simple linear layer to project quality to condition channels
        self.quality_projection = nn.Linear(1, 768)

        # Freeze vocoder
        for param in self.vocoder.parameters():
            param.requires_grad = False

        # Loss weights
        self.weight_adv = weight_adv
        self.weight_vq = weight_vq
        self.weight_mel = weight_mel

        # Other parameters
        self.sampling_rate = sampling_rate

        # Disable strict loading
        self.strict_loading = False

        # If encoder is frozen
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

            for param in self.quantizer.parameters():
                param.requires_grad = False
        # Two optimizers (generator and discriminator) are stepped by hand in
        # training_step, so Lightning's automatic optimization is disabled.
        self.automatic_optimization = False
    def on_save_checkpoint(self, checkpoint):
        # Do not save vocoder
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "vocoder" in name:
                state_dict.pop(name)
    def configure_optimizers(self):
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.encoder.parameters(),
                self.quantizer.parameters(),
                self.decoder.parameters(),
                self.quality_projection.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )
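    # With manual optimization enabled, self.optimizers() and
    # self.lr_schedulers() in training_step return these two configs in the
    # order defined above (generator first, then discriminator).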
    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            encoded_mels = self.encode_mel_transform(audios)
            gt_mels = self.gt_mel_transform(audios)
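            # A rough quality proxy: count the mel bins whose time-averaged
            # log-magnitude exceeds -8 (i.e. how much of the spectrum carries
            # energy), then shift and scale the count into a small scalar
            # condition.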
            quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
            quality = quality.unsqueeze(-1)

        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv
        encoded_mels = encoded_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv

        # Quantize
        vq_result = self.quantizer(encoded_features)
        loss_vq = getattr(vq_result, "loss", 0.0)
        vq_recon_features = vq_result.z * mel_masks_float_conv
        vq_recon_features = (
            vq_recon_features + self.quality_projection(quality)[:, :, None]
        )
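        # The decoder acts as a conditional generator: it maps masked Gaussian
        # noise to a mel spectrogram, conditioned on the quantized (and
        # quality-conditioned) features.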
        # VQ Decode
        gen_mel = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )

        # Discriminator
        real_logits = self.discriminator(gt_mels)
        fake_logits = self.discriminator(gen_mel.detach())
        d_mask = F.interpolate(
            mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
        )
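        # Least-squares GAN (LSGAN) objective: real logits are pushed toward 1
        # and fake logits toward 0, averaged over the valid (unpadded) frames.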
        loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
        loss_fake = avg_with_mask(fake_logits**2, d_mask)

        loss_d = loss_real + loss_fake

        self.log(
            "train/discriminator/loss",
            loss_d,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )

        # Discriminator backward
        optim_d.zero_grad()
        self.manual_backward(loss_d)

        self.clip_gradients(
            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )

        optim_d.step()
        # L1 mel loss, combined as a weighted sum over frequency bands
        mel_distance = (
            gen_mel - gt_mels
        ).abs()  # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5

        loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
        loss_mel_mid_freq = avg_with_mask(
            mel_distance[:, 40:70, :], mel_masks_float_conv
        )
        loss_mel_high_freq = avg_with_mask(
            mel_distance[:, 70:, :], mel_masks_float_conv
        )
        loss_mel = (
            loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
        )
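        # The band weights (0.6 / 0.3 / 0.1 for bins <40, 40-70, >=70) bias the
        # loss toward low frequencies, presumably because most speech energy
        # and pitch information lives there.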
        # Adversarial Loss
        fake_logits = self.discriminator(gen_mel)
        loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)

        # Total loss
        loss = (
            self.weight_vq * loss_vq
            + self.weight_mel * loss_mel
            + self.weight_adv * loss_adv
        )

        # Log losses
        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )
        self.log(
            "train/generator/loss_adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )

        # Generator backward
        optim_g.zero_grad()
        self.manual_backward(loss)

        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )

        optim_g.step()

        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()
    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        encoded_mels = self.encode_mel_transform(audios)
        gt_mels = self.gt_mel_transform(audios)

        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv
        encoded_mels = encoded_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv

        # Quantize
        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
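        # At validation time the quality condition is hard-coded to 2, so all
        # samples are generated at the same target quality.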
        vq_recon_features = (
            vq_recon_features
            + self.quality_projection(
                torch.ones(
                    vq_recon_features.shape[0], 1, device=vq_recon_features.device
                )
                * 2
            )[:, :, None]
        )

        # VQ Decode
        gen_aux_mels = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )
        loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)

        self.log(
            "val/loss_mel",
            loss_mel,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        recon_audios = self.vocoder(gt_mels)
        gen_aux_audios = self.vocoder(gen_aux_mels)

        # Only log the first batch
        if batch_idx != 0:
            return

        for idx, (
            gt_mel,
            gen_aux_mel,
            audio,
            gen_aux_audio,
            recon_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                gen_aux_mels,
                audios.cpu().float(),
                gen_aux_audios.cpu().float(),
                recon_audios.cpu().float(),
                audio_lengths,
            )
        ):
            if idx > 4:
                break

            mel_len = audio_len // self.gt_mel_transform.hop_length

            image_mels = plot_mel(
                [
                    gt_mel[:, :mel_len],
                    gen_aux_mel[:, :mel_len],
                ],
                [
                    "Ground-Truth",
                    "Auxiliary",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_aux_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="aux",
                            ),
                            wandb.Audio(
                                recon_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="recon",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gen",
                    gen_aux_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/recon",
                    recon_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
    def encode(self, audios, audio_lengths):
        audios = audios.float()

        mels = self.encode_mel_transform(audios)
        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)

        return self.quantizer.encode(encoded_features), feature_lengths
    def decode(self, indices, feature_lengths, return_audios=False):
        factor = math.prod(self.quantizer.downsample_factor)
        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
        mel_masks_float_conv = mel_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        z = (
            z
            + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
                :, :, None
            ]
        )

        gen_mel = (
            self.decoder(
                torch.randn_like(z) * mel_masks_float_conv,
                condition=z,
            )
            * mel_masks_float_conv
        )

        if return_audios:
            return self.vocoder(gen_mel)

        return gen_mel
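

# A minimal inference sketch (assumptions: `model` is a fully configured VQGAN
# instance with weights loaded, and `audios` / `audio_lengths` are batched
# waveform tensors shaped like the training inputs):
#
#     model.eval()
#     with torch.no_grad():
#         indices, feature_lengths = model.encode(audios, audio_lengths)
#         wavs = model.decode(indices, feature_lengths, return_audios=True)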