from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torchaudio
from coqpit import Coqpit
from torch.nn import functional as F
from torch.utils.data import DataLoader
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec


@dataclass
class GPTTrainerConfig(XttsConfig):
    lr: float = 5e-06
    training_seed: int = 1
    optimizer_wd_only_on_weights: bool = False
    weighted_loss_attrs: dict = field(default_factory=lambda: {})
    weighted_loss_multipliers: dict = field(default_factory=lambda: {})
    test_sentences: List[dict] = field(default_factory=lambda: [])
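
# Example (illustrative sketch, not recommended values): the config is a Coqpit
# dataclass, so fields can be set directly or loaded from a JSON config file.
# config = GPTTrainerConfig(lr=5e-6, training_seed=1, optimizer_wd_only_on_weights=True)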


@dataclass
class XttsAudioConfig(XttsAudioConfig):
    # extends the imported XttsAudioConfig (the new class shadows the imported name)
    dvae_sample_rate: int = 22050


@dataclass
class GPTArgs(XttsArgs):
    min_conditioning_length: int = 66150
    max_conditioning_length: int = 132300
    gpt_loss_text_ce_weight: float = 0.01
    gpt_loss_mel_ce_weight: float = 1.0
    gpt_num_audio_tokens: int = 8194
    debug_loading_failures: bool = False
    max_wav_length: int = 255995  # ~11.6 seconds
    max_text_length: int = 200
    tokenizer_file: str = ""
    mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
    dvae_checkpoint: str = ""
    xtts_checkpoint: str = ""
    gpt_checkpoint: str = ""  # if set, replaces the GPT weights of the XTTS model
    vocoder: str = ""  # override the vocoder key in the config to avoid JSON serialization issues
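
# Note: the conditioning lengths are measured in samples. Assuming a 22050 Hz
# sample rate (the dvae_sample_rate default above), the defaults correspond to
# 3 s (66150 / 22050) and 6 s (132300 / 22050) of reference audio, and
# max_wav_length (255995) to the ~11.6 s noted in its comment.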


def callback_clearml_load_save(operation_type, model_info):
    # Returning None skips the file upload/log; returning model_info continues with it.
    # You can also change the upload destination file name via model_info.upload_filename,
    # or check the local file size with Path(model_info.local_model_path).stat().st_size.
    assert operation_type in ("load", "save")
    # print(operation_type, model_info.__dict__)
    if "similarities.pth" in model_info.__dict__["local_model_path"]:
        return None
    return model_info


class GPTTrainer(BaseTTS):
    def __init__(self, config: Coqpit):
        """XTTS GPT trainer, based on the Tortoise GPT training setup."""
        super().__init__(config, ap=None, tokenizer=None)
        self.config = config
        # init XTTS model
        self.xtts = Xtts(self.config)
        # create the tokenizer with the target vocabulary
        self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
        # init gpt encoder and hifigan decoder
        self.xtts.init_models()

        if self.args.xtts_checkpoint:
            self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)

        # set mel stats
        if self.args.mel_norm_file:
            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
        # load GPT if available
        if self.args.gpt_checkpoint:
            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
            # deal with a coqui Trainer exported model
            if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                print("Coqui Trainer checkpoint detected! Converting it!")
                gpt_checkpoint = gpt_checkpoint["model"]
                states_keys = list(gpt_checkpoint.keys())
                for key in states_keys:
                    if "gpt." in key:
                        new_key = key.replace("gpt.", "")
                        gpt_checkpoint[new_key] = gpt_checkpoint[key]
                        del gpt_checkpoint[key]
                    else:
                        del gpt_checkpoint[key]

            # edit the checkpoint if the number of text tokens changed, to ensure the best possible transfer learning
            if (
                "text_embedding.weight" in gpt_checkpoint
                and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
            ):
                num_new_tokens = (
                    self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
                )
                print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

                # add new rows to the embedding layer (text_embedding)
                emb_g = gpt_checkpoint["text_embedding.weight"]
                new_row = torch.randn(num_new_tokens, emb_g.shape[1])
                start_token_row = emb_g[-1, :]
                emb_g = torch.cat([emb_g, new_row], dim=0)
                emb_g[-1, :] = start_token_row
                gpt_checkpoint["text_embedding.weight"] = emb_g

                # add new weights to the linear layer (text_head)
                text_head_weight = gpt_checkpoint["text_head.weight"]
                start_token_row = text_head_weight[-1, :]
                new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
                text_head_weight = torch.cat([text_head_weight, new_entry], dim=0)
                text_head_weight[-1, :] = start_token_row
                gpt_checkpoint["text_head.weight"] = text_head_weight

                # add new biases to the linear layer (text_head)
                text_head_bias = gpt_checkpoint["text_head.bias"]
                start_token_row = text_head_bias[-1]
                new_bias_entry = torch.zeros(num_new_tokens)
                text_head_bias = torch.cat([text_head_bias, new_bias_entry], dim=0)
                text_head_bias[-1] = start_token_row
                gpt_checkpoint["text_head.bias"] = text_head_bias

            self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
            print(">> GPT weights restored from:", self.args.gpt_checkpoint)

        # Mel spectrogram extractor for conditioning
        if self.args.gpt_use_perceiver_resampler:
            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
                filter_length=2048,
                hop_length=256,
                win_length=1024,
                normalize=False,
                sampling_rate=config.audio.sample_rate,
                mel_fmin=0,
                mel_fmax=8000,
                n_mel_channels=80,
                mel_norm_file=self.args.mel_norm_file,
            )
        else:
            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
                filter_length=4096,
                hop_length=1024,
                win_length=4096,
                normalize=False,
                sampling_rate=config.audio.sample_rate,
                mel_fmin=0,
                mel_fmax=8000,
                n_mel_channels=80,
                mel_norm_file=self.args.mel_norm_file,
            )

        # Load DVAE
        self.dvae = DiscreteVAE(
            channels=80,
            normalization=None,
            positional_dims=1,
            num_tokens=self.args.gpt_num_audio_tokens - 2,
            codebook_dim=512,
            hidden_dim=512,
            num_resnet_blocks=3,
            kernel_size=3,
            num_layers=2,
            use_transposed_convs=False,
        )
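        # Note (assumption): num_tokens is gpt_num_audio_tokens - 2 because, with
        # the default of 8194, two ids appear to be reserved for the GPT's special
        # start/stop audio tokens on top of the 8192 DVAE codebook entries.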
        self.dvae.eval()
        if self.args.dvae_checkpoint:
            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
            self.dvae.load_state_dict(dvae_checkpoint, strict=False)
            print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
        else:
            raise RuntimeError(
                "You need to specify the config.model_args.dvae_checkpoint path to be able to train the GPT decoder!"
            )

        # Mel spectrogram extractor for the DVAE
        self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
            mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens):
        """
        Forward pass through the GPT model, returning the text and mel token losses.

        text_inputs: long tensor, (b, t)
        text_lengths: long tensor, (b,)
        audio_codes: long tensor, (b, m)
        wav_lengths: long tensor, (b,)
        cond_mels: MEL float tensor, (b, num_samples, 80, t_m)
        cond_idxs: conditioning start and end indices, (b, 2)
        cond_lens: long tensor, (b,)
        """
        losses = self.xtts.gpt(
            text_inputs,
            text_lengths,
            audio_codes,
            wav_lengths,
            cond_mels=cond_mels,
            cond_idxs=cond_idxs,
            cond_lens=cond_lens,
        )
        return losses
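    # Shape sketch (illustrative): with a batch of 2 utterances and 3 conditioning
    # clips each, text_inputs is (2, t), audio_codes is (2, m) and cond_mels is
    # (2, 3, 80, t_m). The GPT call returns a 3-tuple of (text loss, mel loss,
    # and a third output that train_step below discards).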

    def test_run(self, assets) -> Dict:  # pylint: disable=W0613
        test_audios = {}  # initialized here so the return below never hits an unbound name
        if self.config.test_sentences:
            # init gpt for inference mode
            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
            self.xtts.gpt.eval()
            print(" | > Synthesizing test sentences.")
            for idx, s_info in enumerate(self.config.test_sentences):
                wav = self.xtts.synthesize(
                    s_info["text"],
                    self.config,
                    s_info["speaker_wav"],
                    s_info["language"],
                    gpt_cond_len=3,
                )["wav"]
                test_audios["{}-audio".format(idx)] = wav

            # delete inference layers
            del self.xtts.gpt.gpt_inference
            del self.xtts.gpt.gpt.wte
        return {"audios": test_audios}

    def test_log(
        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)

    def format_batch(self, batch: Dict) -> Dict:
        return batch

    @torch.no_grad()  # no grad: avoid gradients from the pre-processing and the DVAE code extraction
    def format_batch_on_device(self, batch):
        """Compute spectrograms on the device."""
        batch["text_lengths"] = batch["text_lengths"]
        batch["wav_lengths"] = batch["wav_lengths"]
        batch["text_inputs"] = batch["padded_text"]
        batch["cond_idxs"] = batch["cond_idxs"]

        # compute conditioning mel specs
        # reshape waves from [B, num_cond_samples, 1, T] to [B * num_cond_samples, 1, T],
        # because that is faster than iterating over the tensor
        B, num_cond_samples, C, T = batch["conditioning"].size()
        conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T)
        paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped)
        # reshape [B * num_cond_samples, n_mel, T_mel] back to [B, num_cond_samples, n_mel, T_mel]
        n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels  # paired_conditioning_mel.size(1)
        T_mel = paired_conditioning_mel.size(2)
        paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
        # store the conditioning mels
        batch["cond_mels"] = paired_conditioning_mel

        # compute codes using the DVAE
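        # Note (assumption): the resampling settings below appear to match the
        # "kaiser_best"-equivalent parameters documented for torchaudio's resample
        # (lowpass_filter_width=64, rolloff~0.9476, beta~14.7697).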
        if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate:
            dvae_wav = torchaudio.functional.resample(
                batch["wav"],
                orig_freq=self.config.audio.sample_rate,
                new_freq=self.config.audio.dvae_sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method="kaiser_window",
                beta=14.769656459379492,
            )
        else:
            dvae_wav = batch["wav"]
        dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
        codes = self.dvae.get_codebook_indices(dvae_mel_spec)
        batch["audio_codes"] = codes

        # delete useless batch tensors
        del batch["padded_text"]
        del batch["wav"]
        del batch["conditioning"]
        return batch

    def train_step(self, batch, criterion):
        loss_dict = {}
        cond_mels = batch["cond_mels"]
        text_inputs = batch["text_inputs"]
        text_lengths = batch["text_lengths"]
        audio_codes = batch["audio_codes"]
        wav_lengths = batch["wav_lengths"]
        cond_idxs = batch["cond_idxs"]
        cond_lens = batch["cond_lens"]

        loss_text, loss_mel, _ = self.forward(
            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens
        )
        loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
        loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
        loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
        return {"model_outputs": None}, loss_dict

    def eval_step(self, batch, criterion):
        # ignore masking for more consistent evaluation
        batch["cond_idxs"] = None
        return self.train_step(batch, criterion)

    def on_train_epoch_start(self, trainer):
        trainer.model.eval()  # put the whole model in eval mode
        # then switch only the GPT module back to training mode
        trainer.model.xtts.gpt.train()

    def on_init_end(self, trainer):  # pylint: disable=W0613
        # ignore similarities.pth on clearml save/upload
        if self.config.dashboard_logger.lower() == "clearml":
            from clearml.binding.frameworks import WeightsFileHandler

            WeightsFileHandler.add_pre_callback(callback_clearml_load_save)

    def inference(
        self,
        x,
        aux_input=None,
    ):  # pylint: disable=dangerous-default-value
        return None

    @staticmethod
    def get_criterion():
        return None

    def get_sampler(self, dataset: TTSDataset, num_gpus=1):
        # sampler for DDP
        batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        return batch_sampler

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":  # pylint: disable=W0613
        if is_eval and not config.run_eval:
            loader = None
        else:
            # init dataloader
            dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)

            # wait for all DDP processes to be ready
            if num_gpus > 1:
                torch.distributed.barrier()

            # sort input sequences from short to long
            # dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(dataset, num_gpus)

            # ignore the sampler during eval: changing sampler parameters would make runs incomparable
            if sampler is None or is_eval:
                loader = DataLoader(
                    dataset,
                    batch_size=config.eval_batch_size if is_eval else config.batch_size,
                    shuffle=False,
                    drop_last=False,
                    collate_fn=dataset.collate_fn,
                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                    pin_memory=False,
                )
            else:
                loader = DataLoader(
                    dataset,
                    batch_sampler=sampler,
                    collate_fn=dataset.collate_fn,
                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                    pin_memory=False,
                )
        return loader

    def get_optimizer(self) -> List:
        """Initialize and return the optimizer based on the config parameters."""
        # ToDo: deal with multi GPU training
        if self.config.optimizer_wd_only_on_weights:
            # restrict the parameters to the GPT model only
            net = self.xtts.gpt

            # normalizations
            norm_modules = (
                nn.BatchNorm2d,
                nn.InstanceNorm2d,
                nn.BatchNorm1d,
                nn.InstanceNorm1d,
                nn.BatchNorm3d,
                nn.InstanceNorm3d,
                nn.GroupNorm,
                nn.LayerNorm,
            )
            # nn.Embedding
            emb_modules = (nn.Embedding, nn.EmbeddingBag)

            param_names_notweights = set()
            all_param_names = set()
            param_map = {}
            for mn, m in net.named_modules():
                for k, v in m.named_parameters():
                    v.is_bias = k.endswith(".bias")
                    v.is_weight = k.endswith(".weight")
                    v.is_norm = isinstance(m, norm_modules)
                    v.is_emb = isinstance(m, emb_modules)

                    fpn = "%s.%s" % (mn, k) if mn else k  # full param name
                    all_param_names.add(fpn)
                    param_map[fpn] = v
                    if v.is_bias or v.is_norm or v.is_emb:
                        param_names_notweights.add(fpn)

            params_names_notweights = sorted(list(param_names_notweights))
            params_notweights = [param_map[k] for k in params_names_notweights]
            params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
            params_weights = [param_map[k] for k in params_names_weights]

            groups = [
                {"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]},
                {"params": params_notweights, "weight_decay": 0},
            ]
            # torch.optim.AdamW
            opt = get_optimizer(
                self.config.optimizer,
                self.config.optimizer_params,
                self.config.lr,
                parameters=groups,
            )
            opt._group_names = [params_names_weights, params_names_notweights]
            return opt

        return get_optimizer(
            self.config.optimizer,
            self.config.optimizer_params,
            self.config.lr,
            # optimize only the GPT model
            parameters=self.xtts.gpt.parameters(),
        )
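    # Example (sketch, illustrative values): with config.optimizer = "AdamW" and
    # config.optimizer_params = {"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
    # the weight group above decays at 1e-2 while biases, norms and embeddings get
    # no weight decay at all.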

    def get_scheduler(self, optimizer) -> List:
        """Set the scheduler for the optimizer.

        Args:
            optimizer: `torch.optim.Optimizer`.
        """
        return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

    def load_checkpoint(
        self,
        config,
        checkpoint_path,
        eval=False,
        strict=True,
        cache_storage="/tmp/tts_cache",
        target_protocol="s3",
        target_options={"anon": True},
    ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
        """Load the model checkpoint and set it up for training or inference."""
        state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)

        # load the model weights
        self.xtts.load_state_dict(state, strict=strict)

        if eval:
            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
            self.eval()
            assert not self.training

    @staticmethod
    def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
        """Initialize the model from the config.

        Args:
            config (GPTTrainerConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        """
        return GPTTrainer(config)
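

# Example usage (sketch; Trainer and TrainerArgs come from the `trainer` package
# this module already imports from, but the paths and values below are
# placeholders, not a tested recipe):
#
# from trainer import Trainer, TrainerArgs
#
# config = GPTTrainerConfig(
#     model_args=GPTArgs(tokenizer_file="/path/to/vocab.json", dvae_checkpoint="/path/to/dvae.pth"),
#     audio=XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000),
#     batch_size=8,
# )
# model = GPTTrainer.init_from_config(config)
# trainer = Trainer(
#     TrainerArgs(),
#     config,
#     output_path="/path/to/output",
#     model=model,
#     train_samples=train_samples,
#     eval_samples=eval_samples,
# )
# trainer.fit()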