Spaces:
Runtime error
Runtime error
| import math | |
| from typing import Dict, Optional, Tuple | |
| import torch | |
| from wenet.ssl.bestrq.mask import compute_mask_indices_v2 | |
| from wenet.utils.mask import make_non_pad_mask, make_pad_mask | |
| from wenet.transformer.attention import RelPositionMultiHeadedAttention | |
| from wenet.transformer.encoder_layer import ConformerEncoderLayer | |
def quantize_vector(latent: torch.Tensor, codebook: torch.Tensor):
    """Assign every latent vector to its nearest codebook entry, per group.

    The latent dimension is split into G equal groups; each group is matched
    independently against its own set of C code vectors using squared
    Euclidean distance.

    Symbols in comments:
        B: batch_size.
        D: latent_dim.
        C: num_latent_classes per group
        G: num of codebook groups.

    Args:
        latent: [B, D]
        codebook: [C, G, D // G]

    Returns:
        (quantized, codes, onehot).
        - quantized: [B, D], the selected code vectors, groups concatenated.
        - codes: [B, G], index of the selected code in every group.
        - onehot: [B, G, C], one-hot encoding of ``codes`` in the codebook
          dtype.
    """
    assert len(codebook.size()) == 3
    batch, dim = latent.size()
    num_classes, num_groups, _ = codebook.size()
    assert dim % num_groups == 0

    # [B, G, D // G]
    grouped = latent.reshape(batch, num_groups, dim // num_groups)

    # ||x - e||^2 = ||x||^2 - 2 <x, e> + ||e||^2, evaluated term by term.
    latent_sq = grouped.pow(2).sum(-1, keepdim=True)  # [B, G, 1]
    cross = torch.einsum('bgd,cgd->bgc', grouped, codebook)  # [B, G, C]
    codebook_sq = codebook.permute(2, 1, 0).pow(2).sum(
        0, keepdim=True)  # [1, G, C]
    distance = latent_sq - 2 * cross + codebook_sq  # [B, G, C]

    # [B, G]
    codes = distance.argmin(dim=-1)
    # [B, G, C]
    one_hot = torch.nn.functional.one_hot(codes,
                                          num_classes).type(codebook.dtype)
    # Selecting code vectors through the one-hot matrix: [B, G, D // G].
    selected = torch.einsum('bgc,cgd->bgd', one_hot, codebook)
    quantized = selected.reshape(batch, dim)
    return quantized, codes, one_hot
class BestRQModel(torch.nn.Module):
    """Masked-prediction pre-training wrapper around an encoder.

    The input features are partially replaced by random noise, the encoder
    consumes the corrupted features, and per-codebook softmax layers predict
    discrete target ids at the corrupted positions.  The targets come from
    the *uncorrupted* features: frames are stacked, normalised, projected by
    a frozen random matrix and matched against frozen random codebooks
    (see ``quantize_vector``).
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        num_mel_bins: int = 80,
        embedding_dim: int = 16,
        num_embeddings: int = 8192,
        num_codebooks: int = 1,
        mask_prob: float = 0.01,
        mask_length: int = 10,
        min_masks: int = 2,
        norm_epsilon: float = 1e-5,
        out_bias: bool = False,
        features_regularization_weight: float = 0.01,
    ) -> None:
        """Build the frozen quantizer and the trainable output layers.

        Args:
            encoder: module exposing ``output_size()``, ``embed`` (with
                ``right_context`` and ``subsampling_rate``) and ``encoders``.
            num_mel_bins: feature dimension of one input frame.
            embedding_dim: dimension of one code vector (per codebook).
            num_embeddings: number of code vectors per codebook.
            num_codebooks: number of independent codebooks.
            mask_prob: forwarded to ``compute_mask_indices_v2``; must be > 0.
            mask_length: forwarded to ``compute_mask_indices_v2``.
            min_masks: forwarded to ``compute_mask_indices_v2``.
            norm_epsilon: accepted for interface compatibility; not used
                inside this class.
            out_bias: add a learnable per-codebook bias to the logits.
            features_regularization_weight: weight of the mean-squared input
                penalty added to the loss (0 disables it).
        """
        super().__init__()
        assert mask_prob > 0.0
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.min_masks = min_masks

        self.num_codebooks = num_codebooks
        self.num_embeddings = num_embeddings
        self.features_regularization_weight = features_regularization_weight

        # encoder
        self.encoder = encoder

        # One softmax projection per codebook:
        # [num_codebooks, encoder_dim, num_embeddings]
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(self.num_codebooks, self.encoder.output_size(),
                        num_embeddings))
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
        self.out_bias = out_bias
        if self.out_bias:
            self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
                torch.empty(self.num_codebooks, num_embeddings))
            torch.nn.init.zeros_(self.encoder_top_n_out_bias)

        # Stacked input (e.g. fbank): `stride` consecutive frames form one
        # quantizer input; `stack_frames` is the window used for masking.
        self.stack_frames = self.encoder.embed.right_context + 1
        self.stride = self.encoder.embed.subsampling_rate
        input_dim = num_mel_bins * self.stride

        # Frozen random projection: [input_dim, embedding_dim * num_codebooks]
        self.projection = torch.nn.parameter.Parameter(
            torch.empty(input_dim, embedding_dim * self.num_codebooks),
            requires_grad=False,
        )
        torch.nn.init.xavier_uniform_(self.projection)

        # Frozen codebooks:
        # [num_embeddings, num_codebooks, embedding_dim] means
        # [C, G, D // G], see quantize_vector
        self.embeddings = torch.nn.parameter.Parameter(
            torch.empty(num_embeddings, self.num_codebooks, embedding_dim),
            requires_grad=False,
        )
        torch.nn.init.normal_(self.embeddings)
        # L2-normalise every code vector (epsilon avoids division by zero).
        self.embeddings /= (self.embeddings.norm(dim=-1, p=2, keepdim=True) +
                            1e-8)

        # force reset encoder parameter
        self.reset_encoder_parameter()

    def reset_encoder_parameter(self):
        """Re-initialise attention (and Conformer conv) weights of the encoder.

        Linear layers get a truncated normal (std 0.02) weight and zero bias,
        Conv1d layers get Kaiming-normal weights and a uniform bias, and raw
        tensors (relative-position biases) get a truncated normal.
        """

        def _reset_parameter(module: torch.nn.Module):
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.trunc_normal_(module.weight.data,
                                            mean=0.0,
                                            std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, torch.nn.Conv1d):
                torch.nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    k = math.sqrt(module.groups /
                                  (module.in_channels * module.kernel_size[0]))
                    torch.nn.init.uniform_(module.bias, a=-k, b=k)
            elif isinstance(module, torch.Tensor):
                torch.nn.init.trunc_normal_(module)
            else:
                raise NotImplementedError("other module not support now")

        for layer in self.encoder.encoders:
            self_attn = layer.self_attn
            _reset_parameter(self_attn.linear_q)
            _reset_parameter(self_attn.linear_k)
            _reset_parameter(self_attn.linear_v)
            _reset_parameter(self_attn.linear_out)
            if isinstance(self_attn, RelPositionMultiHeadedAttention):
                _reset_parameter(self_attn.pos_bias_u)
                _reset_parameter(self_attn.pos_bias_v)
            if isinstance(layer, ConformerEncoderLayer):
                conv1, conv2 = (layer.conv_module.pointwise_conv1,
                                layer.conv_module.depthwise_conv)
                _reset_parameter(conv1)
                _reset_parameter(conv2)

    def forward(
        self,
        batch: Dict,
        device: torch.device,
    ):
        """Run one pre-training step and return the loss plus statistics.

        Args:
            batch: dict holding ``'feats'`` (3-D features, [B, T, F]) and
                ``'feats_lengths'``.
            device: device the batch tensors are moved to.

        Returns:
            Dict with ``loss``, ``codes_acc`` / ``th_accuracy`` (same value),
            ``features_l2`` (``None`` when the regularisation weight is 0),
            ``num_codes`` and ``uniq_num_codes``.
        """
        xs = batch['feats'].to(device)
        xs_lens = batch['feats_lengths'].to(device)
        raw_feats = xs
        features_pen: Optional[torch.Tensor] = None
        if self.features_regularization_weight != 0.0:
            features_pen = raw_feats.pow(2).mean()

        # 1 mask input
        masked_xs, code_ids_mask = self._apply_mask_signal(xs, xs_lens)

        # 2.0 stack fbank (targets are computed from the unmasked input)
        unmasked_xs = self._stack_features(raw_feats, xs_lens)

        # 2.1 get nearest embedding
        target_ids = self._nearest_embedding_idx(unmasked_xs)
        # Stacking and masking use different window sizes, so the two may
        # yield a different number of steps; align targets to the mask.
        target_ids = target_ids[:, :code_ids_mask.size(1), :]

        # 3 forward xxx-former block and its subsampling layer
        out, out_mask = self.encoder(masked_xs, xs_lens)

        # 4 get logits
        out = out.unsqueeze(1)  # [B, 1, T', dim]
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0)  # [1, num_codebooks, dim, num_embeddings]
        out = torch.matmul(out,
                           top_n_out)  # [B, num_codebooks, T', num_embeddings]
        if self.out_bias:
            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)

        # 5 compute loss (only valid AND masked positions contribute)
        masks = out_mask.squeeze(1) * code_ids_mask  # [B, T']
        loss = self._compute_loss(out, target_ids, mask=masks)
        if self.features_regularization_weight != 0.0:
            loss = loss + self.features_regularization_weight * features_pen

        # 6 other info: num codes used in batch, unique num codes used in batch
        num_codes = masks.sum() * self.num_codebooks
        # Count ids at selected positions only.  Multiplying ids by the mask
        # would turn every unselected position into a spurious code id 0.
        uniq_num_codes = torch.tensor(
            torch.unique(target_ids[masks.bool()]).numel())
        ids_corr = out.argmax(dim=-1, keepdim=False).transpose(1,
                                                               2) == target_ids
        codes_acc = (ids_corr * masks.unsqueeze(2)).sum() / num_codes
        return {
            "codes_acc": codes_acc,
            "features_l2": features_pen,
            "loss": loss,
            "num_codes": num_codes,
            "uniq_num_codes": uniq_num_codes,
            "th_accuracy": codes_acc,
        }

    def _apply_mask_signal(
            self, input: torch.Tensor,
            input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Replace selected spans of the input with a random noise vector.

        Args:
            input: features, [B, T, F].
            input_lens: per-utterance lengths used to build the padding mask.

        Returns:
            (xs, subsampling_mask):
            - xs: [B, T, F], input with masked frames overwritten by noise.
            - subsampling_mask: mask returned by ``compute_mask_indices_v2``
              at the strided (window) resolution.
        """
        device = input.device
        B, T, _ = input.size()
        # presumably True at padded frames -- verify against make_pad_mask
        padding_mask = make_pad_mask(input_lens)

        # calc subsampling masks: one entry per window of `stack_frames`
        # frames taken every `stride` frames; a window counts as padding as
        # soon as any frame in it is padding.
        padding_mask_stride = padding_mask.unfold(
            1,
            size=self.stack_frames,
            step=self.stride,
        )
        padding_mask, _ = torch.max(padding_mask_stride, dim=-1)
        masks = compute_mask_indices_v2(padding_mask.size(),
                                        padding_mask,
                                        self.mask_prob,
                                        self.mask_length,
                                        min_masks=self.min_masks,
                                        device=device)

        # calc signal mask: expand each window decision to its frames
        subsampling_mask = masks
        bool_stride_mask = torch.ones_like(padding_mask_stride, device=device)
        mask_stride = torch.where(masks.unsqueeze(-1), bool_stride_mask, False)

        # recover origin seq masks: keep the first `stride` frames of every
        # window so consecutive windows tile the time axis without overlap
        masks = mask_stride[:, :, :self.stride].flatten(start_dim=1)
        masks_padding = torch.zeros(
            B,
            T,
            device=device,
            dtype=padding_mask.dtype,
        )
        # Trailing frames not covered by any window stay unmasked.
        masks_padding[:, :masks.size(-1)] = masks
        masks = masks_padding
        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]

        # NOTE(Mddct): you can use size (b,t,d) for torch.normal
        # A single noise vector is shared by all masked frames of the batch.
        mask_emb = torch.normal(mean=0, std=0.1,
                                size=(1, 1, input.size(2))).to(input.device)
        xs = torch.where(masks_expand, mask_emb, input)
        return xs, subsampling_mask

    def _stack_features(self, input: torch.Tensor,
                        input_lens: torch.Tensor) -> torch.Tensor:
        """Stack `stride` consecutive frames and normalise over time.

        Args:
            input: features, [B, T, F].
            input_lens: per-utterance lengths used to build the padding mask.

        Returns:
            Normalised stacked features, [B, T // stride, stride * F].
        """
        # [B, N, F, stride] -> [B, N, stride, F] -> [B, N, stride * F]
        stack_input = input.unfold(1, size=self.stride, step=self.stride)
        stack_input = stack_input.transpose(-1, -2)
        b, n, f, d = stack_input.size()
        stack_input = stack_input.reshape(b, n, f * d)

        # NOTE(Mddct): important!!!
        # norm stack features
        mask = make_non_pad_mask(input_lens)
        stack_mask = mask.unfold(1, size=self.stride, step=self.stride)
        # A stacked step is kept only if all of its frames are kept.
        stack_mask, _ = torch.min(stack_mask, dim=-1)

        stack_input = stack_input * stack_mask.unsqueeze(2)
        num_valid = stack_mask.sum(dim=1, keepdim=True).unsqueeze(1)
        mean = stack_input.sum(1, keepdim=True) / num_valid
        # NOTE(review): the squared deviation is summed over *all* steps,
        # including zeroed-out padded ones, so the std of a padded utterance
        # depends on how much padding the batch adds.  Left unchanged because
        # altering it changes the quantization targets -- confirm intent.
        std = torch.sqrt(((stack_input - mean)**2).sum(dim=1, keepdim=True) /
                         num_valid)
        norm_stack_input = (stack_input - mean) / (std + 1e-5)
        return norm_stack_input

    def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
                      mask: torch.Tensor) -> torch.Tensor:
        """Masked cross entropy averaged over all selected codes.

        Args:
            input: logits, [B, num_codebooks, T', num_embeddings].
            target: code ids, [B, T', num_codebooks].
            mask: [B, T'], non-zero / True where the loss is evaluated.

        Returns:
            Scalar loss: sum of the per-code losses at selected positions
            divided by ``mask.sum() * num_codebooks``.
        """
        # [B, G, T', E] -> [B, T', G, E] so logits line up with `target`.
        logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1))
        loss = torch.nn.functional.cross_entropy(
            logits,
            target.contiguous().view(-1),
            reduction='none',
        )
        # The flat loss has B * T' * G entries while the mask only has
        # B * T'.  Restore the codebook axis and broadcast the mask over it;
        # multiplying the flat tensors only works for a single codebook.
        loss = loss.view(mask.size(0), mask.size(1), -1)  # [B, T', G]
        num_codebooks = loss.size(-1)
        loss = (loss * mask.unsqueeze(-1)).sum() / (mask.sum() *
                                                    num_codebooks)
        return loss

    def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
        """Quantize stacked features into code ids.

        Args:
            xs: stacked features, [B, T, input_dim].

        Returns:
            Code ids, [B, T, num_codebooks].
        """
        # Random projection followed by L2 normalisation.
        xs = torch.matmul(xs, self.projection.to(xs.device))
        xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
        codebooks = self.embeddings
        B, T, C = xs.size()
        xs_flatten = xs.view(B * T, C)
        _, codes, _ = quantize_vector(xs_flatten, codebooks)

        return codes.reshape(B, T, -1)  # [B, T, num_codebooks]