import torch
import torch.nn.functional as F
import transformers

from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
from modules.finetune.utils.output import ansi, get_ansi_len, output_iter

from .utils.dataset import AudioCollator, XzListTar
from .utils.logger import MetricLogger
from .utils.model import quantize

IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
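

# Trains one 768-dim embedding per speaker against the frozen ChatTTS
# GPT/DVAE/decoder stack; only the speaker embeddings receive gradients.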
def train_speaker_embeddings(
    chat,
    dataset,
    gpt,
    batch_size=16,
    epochs=10,
    train_text=True,
    speaker_embeds=None,
):
    tokenizer = chat.pretrain_models["tokenizer"]

    decoder_decoder = chat.pretrain_models["decoder"]
    decoder_decoder.eval().requires_grad_(False)
    decoder_encoder = DVAEEncoder(**get_encoder_config(decoder_decoder.decoder)).to(
        device=dataset.device
    )
    decoder_encoder.eval().requires_grad_(False)

    dvae_decoder = chat.pretrain_models["dvae"]
    dvae_decoder.eval().requires_grad_(False)
    dvae_encoder = DVAEEncoder(**get_encoder_config(dvae_decoder.decoder)).to(
        device=dataset.device
    )
    dvae_encoder.eval().requires_grad_(False)
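
    # If no embeddings are given, draw them at random and rescale them with the
    # pretrained speaker statistics (std/mean from "spk_stat").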
    if speaker_embeds is None:
        speaker_embeds = {
            speaker: torch.randn(
                768,
                device=dataset.device,
                requires_grad=True,
            )
            for speaker in dataset.speakers
        }
        for speaker_embed in speaker_embeds.values():
            std, mean = chat.pretrain_models["spk_stat"].chunk(2)
            speaker_embed.data = speaker_embed.data * std + mean
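
    # "[spk_emb]" marks where the speaker embedding is injected into the text;
    # audio EOS and PAD share code 0.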
    SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
    AUDIO_EOS_TOKEN_ID = 0
    AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID
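
    # Only the speaker embeddings are optimized; every pretrained module stays frozen.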
    optimizer = torch.optim.Adam(
        speaker_embeds.values(), lr=1e-2, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
    )
    logger = MetricLogger()
    logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)

    for _epoch in range(epochs):
        _epoch += 1
        logger.reset()
        header = "{blue_light}{0}: {1}{reset}".format(
            "Epoch", output_iter(_epoch, epochs), **ansi
        )
        header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
        iterator = logger.log_every(loader, header=header, tqdm_header="Batch")

        for batch in iterator:
            speakers = batch["speaker"]
            text_input_ids = batch["text_input_ids"]
            text_attention_mask = batch["text_attention_mask"]
            audio_mel_specs = batch["audio_mel_specs"]
            audio_attention_mask = batch["audio_attention_mask"]

            batch_size, text_len = text_attention_mask.size()
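
            # Encode the mel spectrograms into discrete DVAE audio codes.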
            dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
            _, dvae_audio_input_ids = quantize(
                dvae_decoder.vq_layer.quantizer, dvae_audio_latents
            )
            dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID
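
            # Extend mask and codes by one frame so an EOS token can be placed
            # right after the last valid audio frame of each sample.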
            extended_audio_attention_mask = torch.cat(
                [
                    audio_attention_mask,
                    torch.zeros(
                        (batch_size, 1),
                        dtype=audio_attention_mask.dtype,
                        device=audio_attention_mask.device,
                    ),
                ],
                dim=1,
            )
            extended_audio_input_ids = torch.cat(
                [
                    dvae_audio_input_ids,
                    AUDIO_PAD_TOKEN_ID
                    * torch.ones(
                        (batch_size, 1, gpt.num_vq),
                        dtype=dvae_audio_input_ids.dtype,
                        device=dvae_audio_input_ids.device,
                    ),
                ],
                dim=1,
            )
            indices = audio_attention_mask.int().sum(dim=1)
            for i in range(batch_size):
                extended_audio_attention_mask[i, indices[i]] = 1
                extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID
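
            # Concatenate text tokens (repeated across the num_vq codebooks) with the
            # audio codes into one sequence, and mask padded positions in the labels.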
            input_ids = torch.cat(
                [
                    text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
                    extended_audio_input_ids,
                ],
                dim=1,
            )
            attention_mask = torch.cat(
                [text_attention_mask, extended_audio_attention_mask], dim=1
            )
            text_mask = torch.cat(
                [
                    torch.ones_like(text_attention_mask, dtype=bool),
                    torch.zeros_like(extended_audio_attention_mask, dtype=bool),
                ],
                dim=1,
            )

            labels = input_ids.clone()
            labels[~attention_mask.bool()] = IGNORE_TOKEN_ID
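
            # Build the input embeddings, then overwrite every [spk_emb] position
            # with the (L2-normalized) trainable embedding of that sample's speaker.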
            inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)

            indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
            for i, speaker in enumerate(speakers):
                inputs_embeds[i, indices[i]] = F.normalize(
                    speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
                    p=2.0,
                    dim=-1,
                    eps=1e-12,
                ).unsqueeze(0)

            outputs = gpt.gpt.forward(
                inputs_embeds=inputs_embeds, attention_mask=attention_mask
            )
            hidden_states = outputs.last_hidden_state
            text_hidden_states = hidden_states[:, : text_len - 1]
            audio_hidden_states = hidden_states[:, text_len - 1 : -1]
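
            # Next-token prediction: one head per VQ codebook for the audio codes,
            # plus the text head for the text portion of the sequence.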
            audio_logits = torch.stack(
                [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
                dim=2,
            )
            audio_loss = loss_fn(
                audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
            )
            loss = audio_loss

            text_logits = gpt.head_text(text_hidden_states)
            text_loss = loss_fn(
                text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
            )
            loss += text_loss
            logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
            gpt_gen_mel_specs = decoder_decoder(
                audio_hidden_states[:, :-1].transpose(1, 2)
            ).transpose(1, 2)
            mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
            loss += 0.01 * mse_loss
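
            # Backpropagate into the speaker embeddings only; with train_text set,
            # just the text loss drives the update, otherwise the combined loss does.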
            optimizer.zero_grad()
            if train_text:
                # just for test
                text_loss.backward()
            else:
                loss.backward()
            torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
            optimizer.step()

            logger.meters["loss"].update(loss.item(), n=batch_size)
            logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size)
            logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size)

        lr_scheduler.step()

    optimizer.zero_grad()
    return speaker_embeds


if __name__ == "__main__":
    import argparse
    import os
    import pathlib

    import numpy as np

    from modules import config
    from modules.devices import devices
    from modules.models import load_chat_tts
    from modules.speaker import Speaker

    config.runtime_env_vars.no_half = True
    config.runtime_env_vars.use_cpu = []
    devices.reset_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_folder", type=str, default="./")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--train_text", action="store_true", help="train text loss")
    # initialize the speaker embeddings from an existing speaker file
| parser.add_argument("--init_speaker", type=str) | |
| parser.add_argument( | |
| "--data_path", | |
| type=str, | |
| default="datasets/data_speaker_a/speaker_a.list", | |
| help="the data_path to json/list file", | |
| ) | |
| parser.add_argument("--tar_path", type=str, help="the tarball path with wavs") | |
| parser.add_argument( | |
| "--tar_in_memory", action="store_true", help="load tarball in memory" | |
| ) | |
| args = parser.parse_args() | |

    data_path: str = args.data_path
    tar_path: str | None = args.tar_path
    tar_in_memory: bool = args.tar_in_memory
    train_text: bool = args.train_text
    # gpt_lora: bool = args.gpt_lora
    # gpt_kbit: int = args.gpt_kbit
    save_folder: str = args.save_folder
    batch_size: int = args.batch_size
    epochs: int = args.epochs
    init_speaker: str = args.init_speaker

    speaker_embeds_save_path = os.path.join(save_folder, "speaker_embeds.npz")

    chat = load_chat_tts()

    dataset = XzListTar(
        root=data_path,
        tokenizer=chat.pretrain_models["tokenizer"],
        vocos_model=chat.pretrain_models["vocos"],
        tar_path=tar_path,
        tar_in_memory=tar_in_memory,
        device=devices.get_device_for("trainer"),
        # speakers=None,  # set(['speaker_A', 'speaker_B'])
    )

    print("len(dataset)", len(dataset))
    speaker_embeds = None
    if init_speaker:
        spk: Speaker = Speaker.from_file(init_speaker)
        speaker_embeds = {
            speaker: torch.tensor(
                spk.emb,
                device=devices.get_device_for("trainer"),
                requires_grad=True,
            )
            for speaker in dataset.speakers
        }

    speaker_embeds = train_speaker_embeddings(
        chat,
        dataset,
        chat.pretrain_models["gpt"],
        batch_size=batch_size,
        epochs=epochs,
        train_text=train_text,
        speaker_embeds=speaker_embeds,
    )
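
    # Wrap the trained embeddings as Speaker objects and save one .pt file per speaker.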
    speaker_outs = {
        speaker: Speaker(speaker_embed.detach().cpu(), f"ep{epochs}_{speaker}")
        for speaker, speaker_embed in speaker_embeds.items()
    }
    time_str = np.datetime_as_string(np.datetime64("now", "s"))
    time_str = time_str.replace(":", "_").replace(" ", "_").replace("-", "_")
    for speaker, speaker_out in speaker_outs.items():
        torch.save(
            speaker_out,
            pathlib.Path(save_folder) / f"spk_{speaker}_{time_str}_ep{epochs}.pt",
        )

    # example usage
    """
    python -m modules.finetune.train_speaker \
        --data_path datasets/data_speaker_a/speaker_a.list \
        --save_folder ./data \
        --init_speaker ./data/speakers/Bob.pt \
        --epochs 100 \
        --batch_size 6 \
        --train_text
    """