Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # Copyright 2024 Wen-Chin Huang | |
| # MIT License (https://opensource.org/licenses/MIT) | |
| # SSLMOS model | |
| # modified from: https://github.com/nii-yamagishilab/mos-finetune-ssl/blob/main/mos_fairseq.py (written by Erica Cooper) | |
| import torch | |
| import torch.nn as nn | |
| from .modules import Projection | |
class SSLMOS(torch.nn.Module):
    """MOS predictor on top of a self-supervised (SSL) speech upstream.

    The SSL upstream (loaded via s3prl) produces per-layer hidden states; one
    layer is selected and fed to a ``Projection`` head (the "mean net") that
    outputs frame-level scores. When listener modeling is enabled, a listener
    embedding is concatenated to every frame and a second ``Projection`` head
    produces frame-level listener-dependent scores (``ld_scores``).
    """

    def __init__(
        self,
        # dummy, for signature need
        model_input: str,
        # model related
        ssl_module: str,
        s3prl_name: str,
        ssl_model_output_dim: int,
        ssl_model_layer_idx: int,
        # mean net related
        mean_net_dnn_dim: int = 64,
        mean_net_output_type: str = "scalar",
        mean_net_output_dim: int = 5,
        mean_net_output_step: float = 0.25,
        mean_net_range_clipping: bool = True,
        # listener related
        use_listener_modeling: bool = False,
        num_listeners: int = None,
        listener_emb_dim: int = None,
        use_mean_listener: bool = True,
        # decoder related
        decoder_type: str = "ffn",
        decoder_dnn_dim: int = 64,
        output_type: str = "scalar",
        range_clipping: bool = True,
        # dummy
        num_domains: int = None,
        # decoder output head (appended last to keep positional compatibility)
        decoder_output_dim: int = None,
        decoder_output_step: float = None,
    ):
        """Build the SSL upstream, the mean net and (optionally) the decoder.

        Args:
            model_input: Unused; accepted so all models share one signature.
            ssl_module: Backend used to load the SSL model; only ``"s3prl"``.
            s3prl_name: Upstream name, must be in
                ``S3PRLUpstream.available_names()``.
            ssl_model_output_dim: Feature size of the selected SSL layer.
            ssl_model_layer_idx: Index of the hidden-state layer to use.
            mean_net_dnn_dim, mean_net_output_type, mean_net_output_dim,
            mean_net_output_step, mean_net_range_clipping: Forwarded to the
                mean-net ``Projection``.
            use_listener_modeling: Build listener embeddings and the decoder.
            num_listeners, listener_emb_dim: Listener embedding table size;
                needed when ``use_listener_modeling`` is True.
            use_mean_listener: Stored on the instance; not used in this class.
            decoder_type: Only ``"ffn"`` is supported.
            decoder_dnn_dim, output_type, range_clipping: Forwarded to the
                decoder ``Projection``.
            num_domains: Unused.
            decoder_output_dim, decoder_output_step: Decoder ``Projection``
                output size / step; default to the mean-net values.

        Raises:
            NotImplementedError: unsupported ``ssl_module``, unknown
                ``s3prl_name`` or unsupported ``decoder_type``.
        """
        super().__init__()  # this is needed! or else there will be an error.
        self.use_mean_listener = use_mean_listener
        self.output_type = output_type
        self.use_listener_modeling = use_listener_modeling
        self.decoder_type = decoder_type

        # define ssl model
        if ssl_module != "s3prl":
            raise NotImplementedError(f"Unsupported ssl_module: {ssl_module}")
        from s3prl.nn import S3PRLUpstream

        if s3prl_name not in S3PRLUpstream.available_names():
            # Fail fast: otherwise ``self.ssl_model`` would be missing and the
            # error would only surface on the first forward call.
            raise NotImplementedError(f"Unknown s3prl upstream: {s3prl_name}")
        self.ssl_model = S3PRLUpstream(s3prl_name)
        self.ssl_model_layer_idx = ssl_model_layer_idx

        # default uses ffn type mean net; it consumes SSL features only
        self.mean_net_dnn = Projection(
            ssl_model_output_dim,
            mean_net_dnn_dim,
            nn.ReLU,
            mean_net_output_type,
            mean_net_output_dim,
            mean_net_output_step,
            mean_net_range_clipping,
        )

        # listener modeling related
        if use_listener_modeling:
            self.num_listeners = num_listeners
            self.listener_embeddings = nn.Embedding(
                num_embeddings=num_listeners, embedding_dim=listener_emb_dim
            )
            # define decoder
            if decoder_type != "ffn":
                raise NotImplementedError(
                    f"Unsupported decoder_type: {decoder_type}"
                )
            # decoder input = SSL features concatenated with listener embedding
            decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim
            if decoder_output_dim is None:
                decoder_output_dim = mean_net_output_dim
            if decoder_output_step is None:
                decoder_output_step = mean_net_output_step
            # Positional argument layout mirrors the mean-net Projection above.
            self.decoder_dnn = Projection(
                decoder_dnn_input_dim,
                decoder_dnn_dim,
                nn.ReLU,
                output_type,
                decoder_output_dim,
                decoder_output_step,
                range_clipping,
            )

    def get_num_params(self):
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

    @staticmethod
    def _masked_time_mean(frame_scores, lengths=None):
        """Average (batch, time) frame scores over time, ignoring padding.

        Args:
            frame_scores: (batch, time) tensor.
            lengths: optional (batch,) number of valid frames per row. When
                None, all frames are averaged.

        Returns:
            (batch,) tensor of averaged scores.
        """
        if lengths is None:
            return torch.mean(frame_scores, dim=1)
        num_frames = frame_scores.size(1)
        lengths = torch.as_tensor(lengths, device=frame_scores.device)
        # guard against zero-length rows and lengths longer than the tensor
        lengths = lengths.clamp(min=1, max=num_frames)
        frame_idx = torch.arange(num_frames, device=frame_scores.device)
        mask = frame_idx.unsqueeze(0) < lengths.unsqueeze(1)
        summed = (frame_scores * mask).sum(dim=1)
        return summed / lengths.to(frame_scores.dtype)

    def forward(self, inputs):
        """Calculate forward propagation.

        Args:
            inputs: dict with
                "waveform": (batch, time) tensor,
                "waveform_lengths": (batch,) tensor,
                "listener_idxs": (batch,) tensor, read only when listener
                    modeling is enabled.

        Returns:
            dict with "waveform_lengths", "frame_lengths" (lengths reported by
            the SSL upstream for the selected layer, for masked losses),
            "mean_scores" (mean-net output per frame) and "ld_scores" (decoder
            output per frame, or None without listener modeling).
        """
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]

        # ssl model forward
        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        # assumed (batch, frames, feat_dim) -- TODO confirm against s3prl
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx]

        # mean net is built for ssl_model_output_dim inputs, so it must see the
        # raw SSL features, never the listener-augmented decoder inputs
        mean_net_outputs = self.mean_net_dnn(encoder_outputs)

        ld_scores = None
        if self.use_listener_modeling:
            # NOTE(unlight): not tested yet
            listener_embs = self.listener_embeddings(
                inputs["listener_idxs"]
            )  # (batch, emb_dim)
            # repeat over SSL frames (not waveform samples) so shapes match
            listener_embs = listener_embs.unsqueeze(1).expand(
                -1, encoder_outputs.size(1), -1
            )
            decoder_inputs = torch.cat(
                [encoder_outputs, listener_embs], dim=-1
            )  # concat along feature dimension
            ld_scores = self.decoder_dnn(decoder_inputs)

        return {
            "waveform_lengths": waveform_lengths,
            "frame_lengths": encoder_outputs_lens,
            "mean_scores": mean_net_outputs,
            "ld_scores": ld_scores,
        }

    def mean_net_inference(self, inputs):
        """Predict one utterance-level score per waveform with the mean net.

        Args:
            inputs: dict with "waveform" and "waveform_lengths".

        Returns:
            dict with "ssl_embeddings" (selected SSL layer output) and
            "scores" (batch,): frame scores averaged over each utterance's
            valid frames only, so batch padding does not bias the result.
        """
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]
        # ssl model forward
        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx]
        encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx]
        # mean net
        mean_net_outputs = self.mean_net_dnn(
            encoder_outputs, inference=True
        )  # [batch, time, 1 (scalar) / 5 (categorical)]
        scores = self._masked_time_mean(
            mean_net_outputs.squeeze(-1), encoder_outputs_lens
        )  # [batch]
        return {
            "ssl_embeddings": encoder_outputs,
            "scores": scores,
        }

    def mean_net_inference_p1(self, waveform, waveform_lengths):
        """First half of split inference: return the selected SSL layer output."""
        all_encoder_outputs, _ = self.ssl_model(waveform, waveform_lengths)
        return all_encoder_outputs[self.ssl_model_layer_idx]

    def mean_net_inference_p2(self, encoder_outputs, encoder_outputs_lens=None):
        """Second half of split inference: mean net + time averaging.

        Args:
            encoder_outputs: SSL features as returned by
                ``mean_net_inference_p1``.
            encoder_outputs_lens: optional (batch,) valid frame counts; when
                given, padded frames are excluded from the average.

        Returns:
            (batch,) tensor of scores.
        """
        # NOTE(review): unlike mean_net_inference this does not pass
        # inference=True to the mean net -- confirm whether intended.
        mean_net_outputs = self.mean_net_dnn(
            encoder_outputs
        )  # [batch, time, 1 (scalar) / 5 (categorical)]
        return self._masked_time_mean(
            mean_net_outputs.squeeze(-1), encoder_outputs_lens
        )

    def get_ssl_embeddings(self, inputs):
        """Return the selected SSL layer output for ``inputs["waveform"]``."""
        all_encoder_outputs, _ = self.ssl_model(
            inputs["waveform"], inputs["waveform_lengths"]
        )
        return all_encoder_outputs[self.ssl_model_layer_idx]