Spaces:
Sleeping
Sleeping
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # This code is modified from https://github.com/ming024/FastSpeech2/blob/master/model/fastspeech2.py | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from modules.transformer.Models import Encoder, Decoder | |
| from modules.transformer.Layers import PostNet | |
| from collections import OrderedDict | |
| import os | |
| import json | |
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: 1-D tensor with one length per batch item.
        max_len: width of the mask; when None, the largest entry of
            ``lengths`` is used.

    Returns:
        Bool tensor of shape (batch, max_len) on ``lengths.device`` that is
        True at padding positions (index >= length) and False elsewhere.
    """
    if max_len is None:
        max_len = torch.max(lengths).item()
    # A single row of positions, compared against each length via broadcasting.
    positions = torch.arange(0, max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)
def pad(input_ele, mel_max_length=None):
    """Right-pad a list of tensors along dim 0 with zeros and stack them.

    Args:
        input_ele: sequence of tensors that may differ in their first
            dimension; all remaining dimensions must match.
        mel_max_length: target length of dim 0. When falsy (None or 0), the
            longest first dimension among ``input_ele`` is used. A target
            shorter than an item makes F.pad crop that item (negative pad).

    Returns:
        Tensor of shape (len(input_ele), max_len, *rest).

    Raises:
        ValueError: if an element is a 0-d tensor (it has no dim 0 to pad).
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(ele.size(0) for ele in input_ele)

    padded = []
    for ele in input_ele:
        if ele.dim() == 0:
            raise ValueError("pad() expects tensors with at least one dimension")
        # F.pad takes (left, right) pairs starting from the LAST dimension, so
        # the trailing dims get (0, 0) and the dim-0 pair comes last. This
        # reproduces the old 1-D/2-D specs and also handles higher ranks,
        # which previously left `one_batch_padded` unassigned or stale.
        pad_spec = (0, 0) * (ele.dim() - 1) + (0, max_len - ele.size(0))
        padded.append(F.pad(ele, pad_spec, "constant", 0.0))
    return torch.stack(padded)
class VarianceAdaptor(nn.Module):
    """Variance Adaptor.

    Predicts log-durations, pitch and energy from the encoder hidden sequence,
    adds quantized pitch/energy embeddings to it, and expands it from phoneme
    to frame resolution with a LengthRegulator.
    """

    def __init__(self, cfg):
        super(VarianceAdaptor, self).__init__()
        self.duration_predictor = VariancePredictor(cfg)
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor(cfg)
        self.energy_predictor = VariancePredictor(cfg)

        # Pitch/energy may be modeled per frame or per phoneme; the level also
        # selects the preprocessed feature directory holding its statistics.
        if cfg.preprocess.use_frame_pitch:
            self.pitch_feature_level = "frame_level"
            self.pitch_dir = cfg.preprocess.pitch_dir
        else:
            self.pitch_feature_level = "phoneme_level"
            self.pitch_dir = cfg.preprocess.phone_pitch_dir
        if cfg.preprocess.use_frame_energy:
            self.energy_feature_level = "frame_level"
            self.energy_dir = cfg.preprocess.energy_dir
        else:
            self.energy_feature_level = "phoneme_level"
            self.energy_dir = cfg.preprocess.phone_energy_dir
        assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
        assert self.energy_feature_level in ["phoneme_level", "frame_level"]

        pitch_quantization = cfg.model.variance_embedding.pitch_quantization
        energy_quantization = cfg.model.variance_embedding.energy_quantization
        n_bins = cfg.model.variance_embedding.n_bins
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]

        energy_min, energy_max = self._load_normalized_range(cfg, self.energy_dir)
        pitch_min, pitch_max = self._load_normalized_range(cfg, self.pitch_dir)

        self.pitch_bins = self._make_bins(
            pitch_min, pitch_max, n_bins, pitch_quantization
        )
        self.energy_bins = self._make_bins(
            energy_min, energy_max, n_bins, energy_quantization
        )

        self.pitch_embedding = nn.Embedding(
            n_bins, cfg.model.transformer.encoder_hidden
        )
        self.energy_embedding = nn.Embedding(
            n_bins, cfg.model.transformer.encoder_hidden
        )

    @staticmethod
    def _load_normalized_range(cfg, feature_dir):
        """Read <processed_dir>/<dataset>/<feature_dir>/statistics.json and
        return the feature's (min, max) after z-normalizing with the mean/std
        of its "voiced_positions" entry."""
        stats_path = os.path.join(
            cfg.preprocess.processed_dir,
            cfg.dataset[0],
            feature_dir,
            "statistics.json",
        )
        with open(stats_path) as f:
            stats = json.load(f)
        stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
        mean = stats["voiced_positions"]["mean"]
        std = stats["voiced_positions"]["std"]
        return (
            (stats["total_positions"]["min"] - mean) / std,
            (stats["total_positions"]["max"] - mean) / std,
        )

    @staticmethod
    def _make_bins(low, high, n_bins, quantization):
        """Return n_bins - 1 frozen bucket boundaries spanning [low, high].

        torch.bucketize over n_bins - 1 boundaries yields ids in [0, n_bins - 1],
        which matches the nn.Embedding(n_bins, ...) tables built in __init__.
        """
        if quantization == "log":
            # NOTE(review): low/high are z-normalized and may be <= 0, where
            # np.log is undefined (nan/-inf) — confirm "log" is only used with
            # strictly positive ranges.
            boundaries = torch.exp(
                torch.linspace(np.log(low), np.log(high), n_bins - 1)
            )
        else:
            boundaries = torch.linspace(low, high, n_bins - 1)
        return nn.Parameter(boundaries, requires_grad=False)

    @staticmethod
    def _predict_and_embed(predictor, bins, embedding_table, x, target, mask, control):
        """Run `predictor` and embed either `target` (if given) or the
        control-scaled prediction, quantized with `bins`."""
        prediction = predictor(x, mask)
        if target is not None:
            embedding = embedding_table(torch.bucketize(target, bins))
        else:
            prediction = prediction * control
            embedding = embedding_table(torch.bucketize(prediction, bins))
        return prediction, embedding

    def get_pitch_embedding(self, x, target, mask, control):
        """Return (pitch prediction, pitch embedding); `control` scales the
        prediction only when `target` is None."""
        return self._predict_and_embed(
            self.pitch_predictor,
            self.pitch_bins,
            self.pitch_embedding,
            x,
            target,
            mask,
            control,
        )

    def get_energy_embedding(self, x, target, mask, control):
        """Return (energy prediction, energy embedding); `control` scales the
        prediction only when `target` is None."""
        return self._predict_and_embed(
            self.energy_predictor,
            self.energy_bins,
            self.energy_embedding,
            x,
            target,
            mask,
            control,
        )

    def forward(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Apply duration/pitch/energy modeling and length regulation.

        Args:
            x: encoder hidden sequence.
            src_mask: phoneme-level mask (True marks padding, as used by
                masked_fill in VariancePredictor).
            mel_mask: frame-level mask; returned unchanged when
                `duration_target` is given, otherwise rebuilt from the
                predicted lengths.
            max_len: optional frame length to pad the regulated output to.
            pitch_target / energy_target / duration_target: ground-truth
                values; when None the model's own predictions are used.
            p_control / e_control / d_control: scaling factors applied to the
                pitch / energy / duration predictions at inference.

        Returns:
            (x, pitch_prediction, energy_prediction, log_duration_prediction,
             duration_rounded, mel_len, mel_mask)
        """
        log_duration_prediction = self.duration_predictor(x, src_mask)
        if self.pitch_feature_level == "phoneme_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                x, pitch_target, src_mask, p_control
            )
            x = x + pitch_embedding
        if self.energy_feature_level == "phoneme_level":
            # Energy is scaled by e_control (it was wrongly wired to p_control).
            energy_prediction, energy_embedding = self.get_energy_embedding(
                x, energy_target, src_mask, e_control
            )
            x = x + energy_embedding

        if duration_target is not None:
            x, mel_len = self.length_regulator(x, duration_target, max_len)
            duration_rounded = duration_target
        else:
            # Durations are predicted in log(d + 1) space; invert, round, scale.
            duration_rounded = torch.clamp(
                (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
                min=0,
            )
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = get_mask_from_lengths(mel_len)

        if self.pitch_feature_level == "frame_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                x, pitch_target, mel_mask, p_control
            )
            x = x + pitch_embedding
        if self.energy_feature_level == "frame_level":
            energy_prediction, energy_embedding = self.get_energy_embedding(
                x, energy_target, mel_mask, e_control
            )
            x = x + energy_embedding

        return (
            x,
            pitch_prediction,
            energy_prediction,
            log_duration_prediction,
            duration_rounded,
            mel_len,
            mel_mask,
        )
class LengthRegulator(nn.Module):
    """Length Regulator.

    Repeats each phoneme-level vector according to its duration so the
    sequence reaches frame resolution, then pads the batch to a common length.
    """

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        """Expand every item of `x` by `duration` and pad the results.

        Returns:
            (padded expanded batch, LongTensor of expanded lengths on x.device)
        """
        device = x.device
        output = []
        mel_len = []
        for batch, expand_target in zip(x, duration):
            expanded = self.expand(batch, expand_target)
            output.append(expanded)
            mel_len.append(expanded.shape[0])
        # pad() falls back to the longest item when max_len is None.
        output = pad(output, max_len)
        return output, torch.tensor(mel_len, dtype=torch.long, device=device)

    def expand(self, batch, predicted):
        """Repeat row i of `batch` `predicted[i]` times along dim 0.

        Durations are truncated toward zero and negatives count as 0, matching
        the previous `max(int(d), 0)` loop. The work is a single vectorized
        repeat_interleave instead of one `.item()` (device sync) per row.
        """
        repeats = predicted[: batch.size(0)].long().clamp(min=0)
        return torch.repeat_interleave(batch, repeats.to(batch.device), dim=0)

    def forward(self, x, duration, max_len):
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len
class VariancePredictor(nn.Module):
    """Duration, Pitch and Energy Predictor.

    Two (Conv -> ReLU -> LayerNorm -> Dropout) stages on a channel-last
    sequence, followed by a linear projection to one scalar per position.
    """

    def __init__(self, cfg):
        super(VariancePredictor, self).__init__()
        self.input_size = cfg.model.transformer.encoder_hidden
        self.filter_size = cfg.model.variance_predictor.filter_size
        self.kernel = cfg.model.variance_predictor.kernel_size
        self.conv_output_size = cfg.model.variance_predictor.filter_size
        self.dropout = cfg.model.variance_predictor.dropout

        # Both convolutions share the same padding so the sequence length is
        # preserved (exactly, for odd kernel sizes) and stays aligned with the
        # mask applied in forward(). A hard-coded padding=1 on the second conv
        # is only length-preserving for kernel_size == 3.
        same_padding = (self.kernel - 1) // 2

        # Submodule names are kept stable: they are the state_dict keys.
        self.conv_layer = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1d_1",
                        Conv(
                            self.input_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=same_padding,
                        ),
                    ),
                    ("relu_1", nn.ReLU()),
                    ("layer_norm_1", nn.LayerNorm(self.filter_size)),
                    ("dropout_1", nn.Dropout(self.dropout)),
                    (
                        "conv1d_2",
                        Conv(
                            self.filter_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=same_padding,
                        ),
                    ),
                    ("relu_2", nn.ReLU()),
                    ("layer_norm_2", nn.LayerNorm(self.filter_size)),
                    ("dropout_2", nn.Dropout(self.dropout)),
                ]
            )
        )
        self.linear_layer = nn.Linear(self.conv_output_size, 1)

    def forward(self, encoder_output, mask):
        """Predict one value per position.

        Args:
            encoder_output: channel-last sequence whose last dim is
                `input_size`.
            mask: optional bool tensor; positions where it is True are set
                to 0 in the output.

        Returns:
            The predictions with the trailing size-1 dim squeezed away.
        """
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = out.squeeze(-1)
        if mask is not None:
            out = out.masked_fill(mask, 0.0)
        return out
class Conv(nn.Module):
    """1-D convolution over channel-last sequences.

    Wraps nn.Conv1d so callers can pass tensors laid out as
    (batch, length, channels): the input is transposed to channels-first for
    the convolution and the result is transposed back.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        w_init="linear",
    ):
        """
        :param in_channels: number of input channels (last dim of the input)
        :param out_channels: number of output channels
        :param kernel_size: convolution kernel size
        :param stride: convolution stride
        :param padding: zero padding added on both sides of the length dim
        :param dilation: kernel dilation
        :param bias: whether the convolution has a bias term
        :param w_init: accepted for interface compatibility; this module
            does not use it (nn.Conv1d keeps its default initialization).
        """
        super(Conv, self).__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        # The attribute name `conv` is part of the state_dict layout.
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        channels_first = x.contiguous().transpose(1, 2)
        convolved = self.conv(channels_first)
        return convolved.contiguous().transpose(1, 2)
class FastSpeech2(nn.Module):
    """FastSpeech2 acoustic model.

    Encoder -> (optional speaker embedding) -> VarianceAdaptor -> Decoder ->
    linear mel projection, plus a residual PostNet refinement.
    """

    def __init__(self, cfg) -> None:
        super(FastSpeech2, self).__init__()
        self.cfg = cfg
        self.encoder = Encoder(cfg.model)
        self.variance_adaptor = VarianceAdaptor(cfg)
        self.decoder = Decoder(cfg.model)
        self.mel_linear = nn.Linear(
            cfg.model.transformer.decoder_hidden,
            cfg.preprocess.n_mel,
        )
        self.postnet = PostNet(n_mel_channels=cfg.preprocess.n_mel)
        self.speaker_emb = None
        if cfg.train.multi_speaker_training:
            # One embedding row per entry in the dataset's spk2id.json.
            spk2id_path = os.path.join(
                cfg.preprocess.processed_dir, cfg.dataset[0], "spk2id.json"
            )
            with open(spk2id_path, "r") as f:
                n_speaker = len(json.load(f))
            self.speaker_emb = nn.Embedding(
                n_speaker,
                cfg.model.transformer.encoder_hidden,
            )

    def forward(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
        """Run the model on a batch dict.

        Required keys: "spk_id", "texts", "text_len", "text_mask". Optional
        keys "target_len", "pitch", "energy", "durations" enable teacher
        forcing of the corresponding quantities when present.

        Returns:
            dict with the mel outputs, variance predictions, masks and lengths.
        """
        speakers = data["spk_id"]
        texts = data["texts"]
        src_lens = data["text_len"]
        mel_lens = data.get("target_len")
        # Tensor reduction instead of Python max() iterating element by element.
        max_mel_len = torch.max(mel_lens) if mel_lens is not None else None
        p_targets = data.get("pitch")
        e_targets = data.get("energy")
        d_targets = data.get("durations")

        # text_mask is > 0 on valid tokens; the rest of the model expects
        # True on padding, hence the inversion.
        src_masks = ~(data["text_mask"].squeeze(-1) > 0)
        mel_masks = (
            get_mask_from_lengths(mel_lens, max_mel_len)
            if mel_lens is not None
            else None
        )

        output = self.encoder(texts, src_masks)
        if self.speaker_emb is not None:
            # Broadcast one speaker vector over every text position.
            output = output + self.speaker_emb(speakers).unsqueeze(1)

        (
            output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            mel_lens,
            mel_masks,
        ) = self.variance_adaptor(
            output,
            src_masks,
            mel_masks,
            max_mel_len,
            p_targets,
            e_targets,
            d_targets,
            p_control,
            e_control,
            d_control,
        )

        output, mel_masks = self.decoder(output, mel_masks)
        output = self.mel_linear(output)
        postnet_output = self.postnet(output) + output

        return {
            "output": output,
            "postnet_output": postnet_output,
            "p_predictions": p_predictions,
            "e_predictions": e_predictions,
            "log_d_predictions": log_d_predictions,
            "d_rounded": d_rounded,
            "src_masks": src_masks,
            "mel_masks": mel_masks,
            "src_lens": src_lens,
            "mel_lens": mel_lens,
        }
class FastSpeech2Loss(nn.Module):
    """FastSpeech2 training loss.

    Sums five terms: L1 on the decoder mel output, L1 on the post-net mel
    output, and MSE on pitch, energy and log(duration + 1). Every term only
    covers non-padded positions.
    """

    def __init__(self, cfg):
        super(FastSpeech2Loss, self).__init__()
        # The feature level decides whether the text mask or the mel mask
        # selects the valid pitch/energy positions.
        self.pitch_feature_level = (
            "frame_level" if cfg.preprocess.use_frame_pitch else "phoneme_level"
        )
        self.energy_feature_level = (
            "frame_level" if cfg.preprocess.use_frame_energy else "phoneme_level"
        )
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    def forward(self, data, predictions):
        """Compute the loss terms.

        Args:
            data: batch dict with "mel", "pitch", "energy", "durations".
            predictions: dict produced by FastSpeech2.forward.

        Returns:
            dict with "loss" (the sum) and each individual term.
        """
        pitch_targets = data["pitch"].float()
        energy_targets = data["energy"].float()
        log_duration_targets = torch.log(data["durations"].float() + 1)

        # The model's masks are True on padding; invert them to mark valid slots.
        src_valid = ~predictions["src_masks"]
        mel_valid = ~predictions["mel_masks"]

        # Target mels may be longer than the predicted sequence; trim them.
        mel_targets = data["mel"][:, : mel_valid.shape[1], :]

        for target in (log_duration_targets, pitch_targets, energy_targets, mel_targets):
            target.requires_grad = False

        valid_by_level = {"phoneme_level": src_valid, "frame_level": mel_valid}

        pitch_predictions = predictions["p_predictions"]
        pitch_valid = valid_by_level.get(self.pitch_feature_level)
        if pitch_valid is not None:
            pitch_predictions = pitch_predictions.masked_select(pitch_valid)
            pitch_targets = pitch_targets.masked_select(pitch_valid)

        energy_predictions = predictions["e_predictions"]
        energy_valid = valid_by_level.get(self.energy_feature_level)
        if energy_valid is not None:
            energy_predictions = energy_predictions.masked_select(energy_valid)
            energy_targets = energy_targets.masked_select(energy_valid)

        log_duration_predictions = predictions["log_d_predictions"].masked_select(
            src_valid
        )
        log_duration_targets = log_duration_targets.masked_select(src_valid)

        # Mel tensors carry a trailing channel dim, so the mask is broadcast over it.
        mel_valid_3d = mel_valid.unsqueeze(-1)
        mel_predictions = predictions["output"].masked_select(mel_valid_3d)
        postnet_mel_predictions = predictions["postnet_output"].masked_select(
            mel_valid_3d
        )
        mel_targets = mel_targets.masked_select(mel_valid_3d)

        mel_loss = self.mae_loss(mel_predictions, mel_targets)
        postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
        pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
        energy_loss = self.mse_loss(energy_predictions, energy_targets)
        duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)

        total_loss = (
            mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
        )

        return {
            "loss": total_loss,
            "mel_loss": mel_loss,
            "postnet_mel_loss": postnet_mel_loss,
            "pitch_loss": pitch_loss,
            "energy_loss": energy_loss,
            "duration_loss": duration_loss,
        }