| """A2S model definition. | |
| Copyright PolyAI Limited. | |
| """ | |
| from typing import Union | |
| import pytorch_lightning as pl | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from einops import rearrange | |
| import constants as c | |
| from modules import masking_logic | |
| from modules.conformer import Conformer | |
| from modules.masking_logic import (State, mask_by_random_topk, | |
| sample_from_logits, state_init) | |
| from utils import load_checkpoint | |


class Pheme(pl.LightningModule):
    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        self.model = TTSConformer(hp)
        self.cross_entropy = nn.CrossEntropyLoss(
            label_smoothing=self.hp.label_smoothing,
            ignore_index=self.hp.n_codes
        )

        if self.hp.pretrained_path:
            self.load()
        else:
            self.apply(self.init_weights)

        if self.hp.only_inference:
            self.model.eval()

        self.save_hyperparameters()

    def load(self):
        state_dict = load_checkpoint(self.hp.pretrained_path)
        print(f"Parameters loaded from {self.hp.pretrained_path}")
        self.load_state_dict(state_dict, strict=True)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            module._fill_padding_idx_with_zero()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def configure_optimizers(self):
        optimizer_adam = optim.AdamW(
            self.parameters(), lr=self.hp.lr,
            betas=(self.hp.adam_beta1, self.hp.adam_beta2))

        # Learning rate scheduler
        num_training_steps = self.hp.training_step
        num_warmup_steps = self.hp.warmup_step
        num_flat_steps = int(self.hp.optim_flat_percent * num_training_steps)

        def lambda_lr(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            elif current_step < (num_warmup_steps + num_flat_steps):
                return 1.0
            return max(
                0.0,
                float(num_training_steps - current_step)
                / float(
                    max(1, num_training_steps - (num_warmup_steps + num_flat_steps))  # noqa
                ),
            )

        scheduler_adam = {
            "scheduler": optim.lr_scheduler.LambdaLR(
                optimizer_adam, lambda_lr),
            "interval": "step",
        }
        return [optimizer_adam], [scheduler_adam]
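
    # A small worked example of the schedule defined above (illustrative
    # hyper-parameter values, not ones taken from this repo): with
    # warmup_step=1000, training_step=10000 and optim_flat_percent=0.1
    # (i.e. 1000 flat steps), lambda_lr returns
    #   step 500    -> 0.5                                     (linear warmup)
    #   step 1500   -> 1.0                                     (flat phase)
    #   step 6000   -> (10000 - 6000) / (10000 - 2000) = 0.5   (linear decay)
    #   step 10000  -> 0.0
    # so the learning rate ramps up, holds, then decays linearly to zero.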
| def top_k_accuracy(self, y_true, y_pred_probabilities, k): | |
| _, sorted_indices = torch.sort(y_pred_probabilities, descending=True) | |
| # Get the top-k predictions | |
| top_k_indices = sorted_indices[:, :k] | |
| expanded_y_true = y_true.unsqueeze(1).expand_as(top_k_indices) | |
| # Check if true labels exist in top-k predictions | |
| hits = torch.sum(torch.eq(top_k_indices, expanded_y_true)) | |
| accuracy = hits.item() / (len(y_true) + 1e-7) | |
| return accuracy | |
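
    # Example of the metric above with illustrative values only: for k=2,
    # y_true = tensor([2, 0]) and
    # y_pred_probabilities = tensor([[0.1, 0.3, 0.6], [0.2, 0.5, 0.3]]),
    # the top-2 indices are [[2, 1], [1, 2]]; only the first label appears
    # in its top-2 list, so the returned accuracy is ~0.5.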

    def training_step(self, batch, batch_idx):
        # Sample training level
        rvq_level = torch.randint(
            0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)).item()

        target, chosen_tokens, _, _ = self.model(
            batch["tts_quantize_input"], rvq_level, batch["semantic_tokens"],
            batch["quantization_lengths"],
            speaker_emb=batch["speaker"],
            min_seq_length=batch["quantization_lengths"].min().item())

        # Mask targets and labels
        mask = chosen_tokens
        target = target[mask]
        labels = batch["tts_quantize_input"][:, :, rvq_level]
        labels = labels[mask]

        loss = self.cross_entropy(target, labels)
        acc = (target.argmax(-1) == labels).float().mean()

        self.log("train/loss", loss, on_step=True, prog_bar=True)
        self.log("train/acc", acc, on_step=True, prog_bar=True)
        self.log(
            f"train/acc_lvl_{rvq_level}", acc, on_step=True, prog_bar=False)

        return loss
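
    # In short, each training step samples one RVQ level, masks a random
    # subset of positions after a random prompt prefix (see
    # TTSConformer.create_mask), and computes the loss only on those masked
    # positions at the sampled level.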

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        speaker_emb = batch["speaker"]
        acoustic_tokens = batch["tts_quantize_input"]
        semantic_tokens = batch["semantic_tokens"]

        if self.hp.only_inference:
            self.inference(
                acoustic_tokens, semantic_tokens, self.hp.first_n_lvls)
        else:
            rvq_level = torch.randint(
                0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)
            ).item()

            # FIXME: edge case
            if len(semantic_tokens.shape) == 3:
                semantic_tokens = rearrange(semantic_tokens, "B 1 T -> B T")

            target, chosen_tokens, _, _ = self.model(
                acoustic_tokens, rvq_level, semantic_tokens,
                torch.tensor([acoustic_tokens.shape[1]]).to(self.device),
                speaker_emb=speaker_emb,
                min_seq_length=acoustic_tokens.shape[1]
            )
            target = target[chosen_tokens]
            labels = acoustic_tokens[:, :, rvq_level][chosen_tokens]

            loss = self.cross_entropy(target, labels)
            acc = (target.argmax(-1) == labels).float().mean()
            acc_5 = self.top_k_accuracy(labels, target, 5)

            self.log(
                f"val/dataset_{dataloader_idx}/loss",
                loss,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )
            self.log(
                f"val/dataset_{dataloader_idx}/acc_lvl",
                acc,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )
            self.log(
                f"val/dataset_{dataloader_idx}/acc_lvl_{rvq_level}",
                acc,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )
            self.log(
                f"val/dataset_{dataloader_idx}/acc_top_5",
                acc_5,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )
            self.log(
                f"val/dataset_{dataloader_idx}/acc_top_5_lvl_{rvq_level}",
                acc_5,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )

    def compute_stats(self, logits, labels, mask_ratio=0, rvq_level=0):
        acc = (logits.argmax(-1) == labels).float().mean()
        acc_5 = self.top_k_accuracy(labels, logits, 5)
        acc_10 = self.top_k_accuracy(labels, logits, 10)

        idx = torch.randperm(logits.shape[0])
        logits_shuffled = logits[idx]
        random = self.top_k_accuracy(labels, logits_shuffled, 10)
        print(f"Mask ratio: {mask_ratio}, Level {rvq_level}: acc {acc}, "
              f"acc 5 {acc_5}, acc 10 {acc_10}, quasi random {random}")


class TTSConformer(pl.LightningModule):
    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        self.padding_id = self.hp.n_codes
        additional_codes = [c.PAD, c.SPKR_1, c.SPKR_2]

        self.embedding = nn.ModuleList(
            [
                nn.Embedding(
                    self.hp.n_codes + len(additional_codes),
                    self.hp.hidden_size,
                    padding_idx=self.padding_id)
                for _ in range(self.hp.n_cluster_groups)
            ]
        )

        # Additional modules
        self.semantic_embedding = nn.Embedding(
            self.hp.n_semantic_codes + len(additional_codes),
            self.hp.hidden_size,
            padding_idx=self.padding_id)

        if self.hp.use_spkr_emb:
            self.spkr_linear = nn.Linear(c.SPKR_EMB_SIZE, self.hp.hidden_size)

        self.conformer = Conformer(
            dim=self.hp.hidden_size,
            num_layers=self.hp.enc_nlayers,
            heads=self.hp.nheads,
            dim_head=64,
            ff_mult=4,  # 512 * 4 = 2048
            conv_expansion_factor=2,
            conv_kernel_size=self.hp.depthwise_conv_kernel_size,
            attn_dropout=self.hp.dropout,
            ff_dropout=self.hp.dropout,
            conv_dropout=self.hp.dropout,
            attn_flash=True,
            t5_rel_pos_bias=False
        )

        self.heads = nn.ModuleList(
            [
                nn.Linear(
                    self.hp.hidden_size,
                    self.hp.n_codes + len(additional_codes)
                )
                for _ in range(self.hp.n_cluster_groups)
            ]
        )

    def build_mask_from_lengths(self, length, max_len=None):
        max_len = max_len or length.max().item()
        mask = torch.arange(
            max_len, device=length.device)[None, :] >= length[:, None]
        return mask.bool()

    def create_mask(
        self, B, T, lengths, mask_ratio=None, start_t=None,
        min_seq_length=None
    ):
        # 1. Define the random length of condition tokens given the shortest
        # audio in the batch
        if start_t is None:
            start_t = torch.randint(1, min_seq_length - 1, (1,)).item()

        # 2. Mask other tokens - sample a masking level for the batch
        if mask_ratio is None:
            ratio = torch.rand(1).item()
            mask_ratio = masking_logic.schedule(ratio)

        # Create a random tensor with values between 0 and 1
        random_tensor = torch.rand(
            (B, T - start_t), dtype=torch.float).to(self.device)

        # Create a mask where values less than p are set to True
        initial_mask = random_tensor < mask_ratio

        length_mask = self.build_mask_from_lengths(
            lengths - start_t, T - start_t)
        # We can't pick up tokens past token lengths
        initial_mask = torch.logical_and(initial_mask, ~length_mask)

        # Constrain ratio to always include some samples.
        # If all are False, pick at least one:
        if torch.sum(initial_mask) == 0:
            choose_steps = torch.randint(low=0, high=(T - start_t), size=(B,))
            initial_mask[torch.arange(B), choose_steps] = torch.tensor(
                True, device=self.device)

        # 3. Add condition tokens containing information
        acoustic_token_mask = torch.cat(
            (torch.full((B, start_t), False, device=self.device), initial_mask),  # noqa
            1
        )

        return acoustic_token_mask, start_t, mask_ratio
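
    # Illustrative example of the mask produced above (shapes assumed for the
    # example only): with B=1, T=10, start_t=4 and mask_ratio=0.5, the first
    # 4 positions stay False (prompt/conditioning tokens) and roughly half of
    # the remaining in-length positions are True, i.e. selected to be masked
    # and later predicted.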

    def process_input(
        self, data, lengths, rvq_level, min_seq_length=None,
        mask_ratio=None, start_t=None, acoustic_token_mask=None
    ):
        """
        data: (B, T, code_level, D)
        rvq_level: int
        """
        B = data.size(0)
        T = data.size(1)
        level_data = data[:, :, rvq_level, :]  # [B, T, C, D] -> [B, T, D]

        # Choose acoustic tokens to mask
        if acoustic_token_mask is None:
            acoustic_token_mask, start_t, mask_ratio = self.create_mask(
                B, T, lengths, mask_ratio=mask_ratio, start_t=start_t,
                min_seq_length=min_seq_length)

        # Remove code information from chosen tokens
        level_data[acoustic_token_mask, :] = 0

        # Embed only lower rvq_level
        lower_code_data = data[:, :, :rvq_level, :].sum(dim=2)

        # Combine with chosen tokens at rvq_level.
        # Note: all tokens at rvq_level+1: will be discarded.
        summed_data = torch.add(lower_code_data, level_data)

        return summed_data, acoustic_token_mask, mask_ratio, start_t

    def forward(
        self, x, code_level, semantic_tokens, lengths,
        speaker_emb=None, min_seq_length=10, mask_ratio=None, start_t=None,
        acoustic_token_mask=None
    ):
        # FIXME: parallelize this
        batch = []
        for lvl, embed in enumerate(self.embedding[:(code_level + 1)]):
            batch.append(embed(x[:, :, lvl]))  # [B T D]
        x = torch.stack(batch, dim=2)  # [B T C D]

        x, acoustic_token_mask, mask_ratio, start_t = self.process_input(
            x, lengths, code_level, min_seq_length=min_seq_length,
            mask_ratio=mask_ratio, start_t=start_t,
            acoustic_token_mask=acoustic_token_mask
        )

        # Add semantic token embeddings
        semantic_emb = self.semantic_embedding(semantic_tokens)
        x = torch.add(x, semantic_emb)

        # Merge different modalities: add the speaker embedding if enabled
        if self.hp.use_spkr_emb:
            spkr_emb = F.normalize(speaker_emb, dim=-1)
            spkr_emb = self.spkr_linear(
                F.dropout(spkr_emb, self.hp.speaker_embed_dropout)
            )
            x = torch.add(x, spkr_emb)

        output_frames = self.conformer(x, None)
        x = self.heads[code_level](output_frames)

        return x, acoustic_token_mask, mask_ratio, start_t
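
    # Shape summary for forward(): x enters as token ids [B, T, C], each code
    # level up to `code_level` is embedded and summed to [B, T, D], semantic
    # and (optionally) speaker embeddings are added, the Conformer maps
    # [B, T, D] -> [B, T, D], and the level-specific head projects to logits
    # of shape [B, T, n_codes + len(additional_codes)].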

    def inference(
        self, codes, semantic_tokens,
        length: torch.LongTensor, rvq_levels=7,
        mask_ratio=0.99, maskgit_inference=True,
        start_t: Union[torch.LongTensor, None] = None,
        speaker_emb=None, steps=16
    ):
        # Use half of the recording for the conditioning
        if start_t is None:
            start_t = torch.tensor(int((codes.shape[1]) / 2)).long()
        start_t = start_t.item()

        for rvq_level in range(rvq_levels):
            original_codes = torch.clone(codes)
            if rvq_level == 0 and maskgit_inference:
                codes = self.multi_step_inference(
                    original_codes, semantic_tokens, length,
                    start_t=start_t, vamp_filtering=False,
                    speaker_emb=speaker_emb, steps=steps
                )
            else:
                codes = self.one_step_inference(
                    original_codes, semantic_tokens, length,
                    code_level=rvq_level,
                    mask_ratio=mask_ratio, start_t=start_t,
                    speaker_emb=speaker_emb
                )
            codes = rearrange(codes, 'T C -> 1 T C')

        # Remove any padding left
        codes = rearrange(codes, '1 T C -> 1 C T')
        codes = torch.where(codes >= self.hp.n_codes, 0, codes)
        acoustic_tokens = codes

        semantic_tokens = rearrange(semantic_tokens, 'b c -> b 1 c')
        semantic_tokens = torch.where(
            semantic_tokens >= self.hp.n_codes, 0, semantic_tokens)
        codes = torch.cat([semantic_tokens, acoustic_tokens], dim=1)

        return codes

    def one_step_inference(
        self, original_codes, semantic_tokens, lengths, code_level=0,
        mask_ratio=0.99, start_t=0, inference_setup="argmax", speaker_emb=None
    ):
        codes = torch.clone(original_codes)
        logits, _, _, _ = self.forward(
            codes, code_level, semantic_tokens, lengths,
            mask_ratio=mask_ratio, start_t=start_t,
            speaker_emb=speaker_emb, acoustic_token_mask=False)

        if inference_setup == "argmax":
            probs = torch.nn.functional.softmax(logits, dim=-1)
            top_indices = torch.argmax(probs, dim=-1)

        if inference_setup == "sampling":
            top_indices = torch.distributions.Categorical(
                logits=logits).sample()

        codes = rearrange(codes, '1 T C -> T C')
        codes[start_t:, code_level] = top_indices[0, start_t:]

        return codes

    def multi_step_inference(
        self, original_codes, semantic_tokens, lengths,
        start_t: torch.LongTensor = None,
        choice_temperature=1.0, start_iter=0,
        steps=16, vamp_filtering=False, speaker_emb=None
    ):
        codes = torch.clone(original_codes)
        code_level = 0
        _, seq_len, _ = original_codes.shape
        mask_token_id = self.padding_id

        # Get true codes for the prompt
        prompt_mask = codes[:, :start_t, code_level]

        # Fill up rest with masks
        mask = torch.full(
            (1, seq_len - start_t), mask_token_id, device=self.device)
        inputs = torch.cat((prompt_mask, mask), 1)
        num_mask_tokens_at_start = torch.sum(inputs == mask_token_id, dim=-1)

        # Initializes state
        state = state_init(inputs, steps, start_iter=start_iter)

        def loop_cond_fn(state):
            """Decoding loop termination condition."""
            not_at_end = (state.cur_index < steps)
            return not_at_end

        while loop_cond_fn(state):
            # Decoding loop state update.
            step = state.cur_index
            # Current input ids: [batch_size, seq_length].
            cur_ids = state.cur_seqs

            # Calls model on current seqs to get next-iteration seqs.
            with torch.no_grad():
                logits, _, _, _ = self.forward(
                    rearrange(inputs, 'B T -> B T 1'),
                    code_level,
                    semantic_tokens, lengths,
                    acoustic_token_mask=False,
                    speaker_emb=speaker_emb)

            # Samples the ids using categorical sampling:
            if vamp_filtering:
                typical_mass = 0.2
                typical_min_tokens = 1
                top_p = None
                sample_cutoff = 0.5
                typical_filtering = False
                sampled_ids, selected_probs = sample_from_logits(
                    logits, sample=((step / steps) <= sample_cutoff),
                    temperature=choice_temperature,
                    typical_filtering=typical_filtering,
                    typical_mass=typical_mass,
                    typical_min_tokens=typical_min_tokens,
                    top_k=None, top_p=top_p, return_probs=True,
                )
            else:
                sampled_ids = torch.distributions.Categorical(
                    logits=logits).sample()

            # Just updates the masked tokens.
            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

            # Defines the mask ratio for the next round. The number to mask
            # out is determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1. * (step + 1) / steps
            mask_ratio = masking_logic.schedule(ratio)

            # Updates final seqs with the current sampled_ids.
            final_seqs = torch.clone(state.final_seqs)
            final_seqs[:, step, :] = sampled_ids

            # Computes the probabilities of each selected token.
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # Extract the probabilities of sampled ids
            selected_probs = torch.squeeze(
                torch.take_along_dim(
                    probs, torch.unsqueeze(sampled_ids, -1), -1),
                -1
            )

            # Ignores the tokens given in the input
            # by overwriting their confidence.
            selected_probs = torch.where(
                unknown_map, selected_probs, torch.inf)

            # Gets mask lens for each sample in the
            # batch according to the mask ratio.
            num_to_mask = torch.unsqueeze(
                torch.floor(num_mask_tokens_at_start * mask_ratio), 1)

            # Keeps at least one prediction in this round and also masks out
            # at least one token for the next iteration.
            num_to_mask = torch.maximum(
                torch.tensor(1),
                torch.minimum(
                    torch.sum(unknown_map, dim=-1, keepdim=True) - 1,
                    num_to_mask)
            )

            # Adds noise for randomness
            masking = mask_by_random_topk(
                num_to_mask, selected_probs, choice_temperature * (1. - ratio))

            # Masks tokens with lower confidence.
            sampled_ids = torch.where(masking, mask_token_id, sampled_ids)

            state = State(
                cur_index=state.cur_index + 1,
                cur_seqs=sampled_ids,
                final_seqs=final_seqs
            )

        codes = torch.clone(original_codes)
        codes = rearrange(codes, '1 T C -> T C')
        codes[:, 0] = state.final_seqs[0][-1]
        return codes
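

# Minimal usage sketch (an assumption-based illustration, not part of the
# original module): how the classes above could be exercised, assuming a
# hyper-parameter namespace `hp` with the fields referenced throughout this
# file (n_codes, n_cluster_groups, hidden_size, first_n_lvls, ...) and
# pre-extracted token tensors. The variable names and shapes below are
# hypothetical, not a confirmed API.
#
#   model = Pheme(hp)
#   acoustic_tokens = torch.zeros(1, 200, hp.n_cluster_groups, dtype=torch.long)
#   semantic_tokens = torch.zeros(1, 200, dtype=torch.long)
#   lengths = torch.tensor([200])
#   with torch.no_grad():
#       codes = model.model.inference(
#           acoustic_tokens, semantic_tokens, lengths,
#           rvq_levels=hp.first_n_lvls, speaker_emb=None)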