# -*- coding: utf-8 -*-

import random

import torch
import torch.nn as nn

from .base_model import CaptionModel
from .utils import repeat_tensor
import audio_to_text.captioning.models.decoder
import audio_to_text.captioning.models.encoder


class TransformerModel(CaptionModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.TransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        cap = input_dict["cap"]
        cap_padding_mask = (cap == self.pad_idx).to(cap.device)
        cap_padding_mask = cap_padding_mask[:, :-1]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_len": input_dict["attn_emb_len"],
                "cap_padding_mask": cap_padding_mask
            })
        return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = {
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"]
        }
        t = input_dict["t"]

        ###############
        # determine input word
        ###############
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:
            # training with scheduled sampling: feed the ground truth prefix
            word = input_dict["cap"][:, :t + 1]
        else:
            start_word = torch.tensor(
                [self.start_idx] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
            if t == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
        # word: [N, T]
        decoder_input["word"] = word

        cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
        decoder_input["cap_padding_mask"] = cap_padding_mask
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        ###############
        # prepare attn embeds
        ###############
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_len"] = attn_emb_len
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_len"] = output_i["attn_emb_len"]

        ###############
        # determine input word
        ###############
        start_word = torch.tensor([self.start_idx] * beam_size).unsqueeze(1).long()
        if t == 0:
            word = start_word
        else:
            word = torch.cat((start_word, output_i["seq"]), dim=-1)
        decoder_input["word"] = word

        cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
        decoder_input["cap_padding_mask"] = cap_padding_mask
        return decoder_input


class M2TransformerModel(CaptionModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.M2TransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.check_encoder_compatibility()

    def check_encoder_compatibility(self):
        assert isinstance(self.encoder, audio_to_text.captioning.models.encoder.M2TransformerEncoder), \
            f"only M2TransformerEncoder is compatible with {self.__class__.__name__}"

    def seq_forward(self, input_dict):
        cap = input_dict["cap"]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_mask": input_dict["attn_emb_mask"],
            })
        return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = {
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_mask": input_dict["attn_emb_mask"]
        }
        t = input_dict["t"]

        ###############
        # determine input word
        ###############
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:
            # training with scheduled sampling: feed the ground truth prefix
            word = input_dict["cap"][:, :t + 1]
        else:
            start_word = torch.tensor(
                [self.start_idx] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
            if t == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
        # word: [N, T]
        decoder_input["word"] = word
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        ###############
        # prepare attn embeds
        ###############
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_mask"] = attn_emb_mask
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]

        ###############
        # determine input word
        ###############
        start_word = torch.tensor([self.start_idx] * beam_size).unsqueeze(1).long()
        if t == 0:
            word = start_word
        else:
            word = torch.cat((start_word, output_i["seq"]), dim=-1)
        decoder_input["word"] = word
        return decoder_input


class EventEncoder(nn.Module):
    """
    Encode the label information in AudioCaps and AudioSet.
    """

    def __init__(self, emb_dim, vocab_size=527):
        super(EventEncoder, self).__init__()
        self.label_embedding = nn.Parameter(
            torch.randn((vocab_size, emb_dim)), requires_grad=True)

    def forward(self, word_idxs):
        # normalize the multi-hot event vector so its weights sum to one,
        # then take the weighted average of the learned label embeddings
        indices = word_idxs / word_idxs.sum(dim=1, keepdim=True)
        embeddings = indices @ self.label_embedding
        return embeddings


class EventCondTransformerModel(TransformerModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.EventTransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.label_encoder = EventEncoder(decoder.emb_dim, 527)
        self.train_forward_keys += ["events"]
        self.inference_forward_keys += ["events"]

    # def seq_forward(self, input_dict):
    #     cap = input_dict["cap"]
    #     cap_padding_mask = (cap == self.pad_idx).to(cap.device)
    #     cap_padding_mask = cap_padding_mask[:, :-1]
    #     output = self.decoder(
    #         {
    #             "word": cap[:, :-1],
    #             "attn_emb": input_dict["attn_emb"],
    #             "attn_emb_len": input_dict["attn_emb_len"],
    #             "cap_padding_mask": cap_padding_mask
    #         })
    #     return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["events"] = self.label_encoder(input_dict["events"])
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        if t == 0:
            output_i["events"] = repeat_tensor(
                self.label_encoder(input_dict["events"])[i], beam_size)
        decoder_input["events"] = output_i["events"]
        return decoder_input


class KeywordCondTransformerModel(TransformerModel):

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.KeywordProbTransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.train_forward_keys += ["keyword"]
        self.inference_forward_keys += ["keyword"]

    def seq_forward(self, input_dict):
        cap = input_dict["cap"]
        cap_padding_mask = (cap == self.pad_idx).to(cap.device)
        cap_padding_mask = cap_padding_mask[:, :-1]
        keyword = input_dict["keyword"]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_len": input_dict["attn_emb_len"],
                "keyword": keyword,
                "cap_padding_mask": cap_padding_mask
            })
        return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["keyword"] = input_dict["keyword"]
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        if t == 0:
            output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
                                                beam_size)
        decoder_input["keyword"] = output_i["keyword"]
        return decoder_input
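

# A minimal usage sketch of EventEncoder, illustrative only and not part of the
# original training/inference pipeline: a multi-hot AudioSet event vector is
# normalized to sum to one and mapped to the weighted average of the learned
# label embeddings. The batch size and embedding dimension are arbitrary
# assumptions for the example; run with the package importable (e.g. via
# `python -m`), since this module uses relative imports.
if __name__ == "__main__":
    label_encoder = EventEncoder(emb_dim=256, vocab_size=527)
    # batch of 2 clips, each tagged with a few active AudioSet classes
    events = torch.zeros(2, 527)
    events[0, [0, 72]] = 1.0
    events[1, [137]] = 1.0
    event_emb = label_encoder(events)
    print(event_emb.shape)  # expected: torch.Size([2, 256])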