import types
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import WhisperFeatureExtractor
import whisper

from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs
class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def replace_layer_norm(module):
            # Swap whisper's custom LayerNorm (which upcasts its input to float32
            # on every forward) for the standard nn.LayerNorm, copying the learned
            # parameters so the weights are unchanged.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder
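
# Usage sketch (not part of the original file; the config field is inferred from the
# attribute accessed above): `speech_encoder` names a Whisper checkpoint accepted by
# whisper.load_model, e.g. "large-v3".
#
#   cfg = types.SimpleNamespace(speech_encoder="large-v3")
#   encoder = WhisperWrappedEncoder.load(cfg)
#   mel = torch.zeros(1, 128, 3000)   # log-mel features, e.g. from WhisperFeatureExtractor
#   feats = encoder(mel)              # (1, 1500, d_model)
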
class DualWrappedEncoder(nn.Module):
    """Runs a Whisper encoder and a BEATs encoder side by side and fuses their features."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):

        def replace_layer_norm(module):
            # Same LayerNorm swap as in WhisperWrappedEncoder.load above.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_ckpt['model'])
        return beats
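
    # Checkpoint layout assumed by load_beats (inferred from the keys accessed above):
    #   {'cfg': <dict of BEATs hyperparameters>, 'model': <state_dict of weights>}
    # i.e. the layout used by the released BEATs checkpoints.
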
    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        # Both encoders are frozen here, so the whole pass runs without gradients.
        with torch.no_grad():
            self.beats_model = self.beats_model.float()
            speech_embeds = self.whisper_model(x)
            audio_embeds, _ = self.beats_model.extract_features(raw_wav.float(), padding_mask=audio_padding_mask, feature_only=True)
            # Zero-pad the shorter stream along the time axis so the two align.
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
            # Fuse Whisper and BEATs features along the channel dimension.
            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds
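
# Shape-alignment sketch (illustration only, not part of the original file). It reproduces
# the padding/concatenation step of DualWrappedEncoder.forward on dummy tensors: the shorter
# stream is zero-padded along the time axis, then the two streams are concatenated along the
# channel axis. The sizes below are placeholders, not the real encoder output shapes.
if __name__ == "__main__":
    speech_embeds = torch.randn(1, 1500, 1280)  # stand-in for Whisper encoder output (B, T, C)
    audio_embeds = torch.randn(1, 1496, 768)    # stand-in for BEATs features (B, T', C')
    if audio_embeds.size(1) < speech_embeds.size(1):
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
    elif audio_embeds.size(1) > speech_embeds.size(1):
        speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
    fused = torch.cat((speech_embeds, audio_embeds), dim=-1)
    print(fused.shape)  # torch.Size([1, 1500, 2048])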