import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from huggingface_hub import PyTorchModelHubMixin
from peft import LoraConfig
from transformers import BartConfig, BartModel
class MLP(nn.Module):
    """Feed-forward network with optional attention re-weighting (ARL) of the input features."""

    def __init__(self, layer_sizes=[64, 64, 64, 1], arl=False, dropout=0.0):
        super().__init__()
        if len(layer_sizes) < 2:
            raise ValueError("layer_sizes must contain at least an input and an output dimension")
        self.layer_sizes = layer_sizes
        self.arl = arl
        # Feature-wise gate, used only when arl=True.
        self.attention = nn.Sequential(
            nn.Linear(layer_sizes[0], layer_sizes[0]),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(layer_sizes[0], layer_sizes[0])
        )
        self.act = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()
        for i in range(len(layer_sizes) - 1):
            self.layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))

    def forward(self, x):
        if self.arl:
            x = x * self.attention(x)
        # Activation and dropout on every layer except the output layer.
        for layer in self.layers[:-1]:
            x = self.dropout(self.act(layer(x)))
        x = self.layers[-1](x)
        return x
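
# Illustrative example (not from the original file): with the defaults
# layer_sizes=[64, 64, 64, 1] and arl=True, MLP(arl=True)(torch.randn(8, 64))
# first gates the 64 input features with the learned attention branch and then
# maps (8, 64) -> (8, 64) -> (8, 64) -> (8, 1).
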
class BART(nn.Module):
    """BART encoder-decoder that takes pre-computed encoder embeddings and a discrete decoder vocabulary."""

    def __init__(self, bartconfig, class_num=100):
        super().__init__()
        d_model = bartconfig.d_model
        self.decoder_emb = nn.Embedding(class_num, d_model)
        self.bart = BartModel(bartconfig)

    def forward(self, x_encoder, x_decoder, attn_mask_encoder=None, attn_mask_decoder=None):
        # Encoder input is already an embedding; decoder input is a sequence of token ids.
        emb_encoder = x_encoder
        emb_decoder = self.decoder_emb(x_decoder)
        y = self.bart(inputs_embeds=emb_encoder, decoder_inputs_embeds=emb_decoder,
                      attention_mask=attn_mask_encoder, decoder_attention_mask=attn_mask_decoder,
                      output_hidden_states=False)
        return y.last_hidden_state

    def encode(self, x_encoder, attn_mask_encoder=None):
        # Run the encoder alone, e.g. to cache encoder states.
        y = self.bart.encoder(inputs_embeds=x_encoder, attention_mask=attn_mask_encoder,
                              output_hidden_states=False)
        return y.last_hidden_state
class ML_BART(nn.Module):
    """BART backbone whose decoder input mixes two discrete token streams with per-step encoder features."""

    def __init__(self, bartconfig, class_num=[180, 256], pretrain=False, music_dim=512):
        super().__init__()
        d_model = bartconfig.d_model
        # One embedding table per token stream, each covering d_model // 4 dimensions.
        self.decoder_emb2 = nn.ModuleList([
            nn.Embedding(class_num[0] + 1, d_model // 4),
            nn.Embedding(class_num[1] + 1, d_model // 4)
        ])
        # Projects the continuous encoder features to the remaining d_model // 2 decoder dimensions.
        self.decoder = MLP([music_dim, d_model // 2])
        self.encoder = MLP([music_dim, d_model])
        self.bart = BartModel(bartconfig)
        self.pretrain = pretrain
        self.lora_config = LoraConfig(
            r=4,
            lora_alpha=16,
            lora_dropout=0.1
        )

    def forward(self, x_encoder, x_decoder, attn_mask_encoder=None, attn_mask_decoder=None):
        emb_encoder = self.encoder(x_encoder)
        if self.pretrain:
            # Pre-training: the decoder also receives projected continuous features.
            emb_decoder = self.encoder(x_decoder)
        else:
            # Fine-tuning: concatenate the two token embeddings with the projected encoder
            # features, so encoder and decoder sequence lengths must match.
            emb_decoder = torch.cat(
                [self.decoder_emb2[0](x_decoder[..., 0]),
                 self.decoder_emb2[1](x_decoder[..., 1]),
                 self.decoder(x_encoder)], dim=-1)
        y = self.bart(inputs_embeds=emb_encoder, decoder_inputs_embeds=emb_decoder,
                      attention_mask=attn_mask_encoder, decoder_attention_mask=attn_mask_decoder,
                      output_hidden_states=False)
        return y.last_hidden_state

    def encode(self, x_encoder, attn_mask_encoder=None):
        emb_encoder = self.encoder(x_encoder)
        y = self.bart.encoder(inputs_embeds=emb_encoder, attention_mask=attn_mask_encoder,
                              output_hidden_states=False)
        return y.last_hidden_state

    def reset_decoder(self):
        # Re-initialize all decoder parameters: Xavier uniform for weights, zeros for biases.
        for name, param in self.bart.decoder.named_parameters():
            if param.dim() >= 2:
                init.xavier_uniform_(param)
            elif param.dim() == 1:
                init.zeros_(param)
class ML_Classifier(nn.Module):
    def __init__(self, hidden_dim=512, class_num=[180, 256]):
        super().__init__()
        self.classifier = nn.ModuleList([
            MLP([hidden_dim, hidden_dim, class_num[0] + 1]),
            MLP([hidden_dim, hidden_dim, class_num[1] + 1])
        ])

    def forward(self, x):
        h = self.classifier[0](x)
        v = self.classifier[1](x)
        return h, v
class SelfAttention(nn.Module):
    """Structured self-attention producing r attention distributions over the sequence dimension."""

    def __init__(self, input_dim, da, r):
        super().__init__()
        self.ws1 = nn.Linear(input_dim, da, bias=False)
        self.ws2 = nn.Linear(da, r, bias=False)

    def forward(self, h):
        # h: (batch, seq_len, input_dim) -> attn_mat: (batch, r, seq_len)
        attn_mat = F.softmax(self.ws2(torch.tanh(self.ws1(h))), dim=1)
        attn_mat = attn_mat.permute(0, 2, 1)
        return attn_mat
class Sequence_Classifier(nn.Module):
    def __init__(self, class_num=1, hs=512, da=512, r=8):
        super().__init__()
        self.attention = SelfAttention(hs, da, r)
        self.classifier = MLP([hs * r, (hs * r + class_num) // 2, class_num])

    def forward(self, x):
        attn_mat = self.attention(x)
        m = torch.bmm(attn_mat, x)
        flatten = m.view(m.size(0), -1)
        res = self.classifier(flatten)
        return res
class Token_Predictor(nn.Module):
    def __init__(self, hidden_dim=512, class_num=1):
        super().__init__()
        self.classifier = MLP([hidden_dim, (hidden_dim + class_num) // 2, class_num])

    def forward(self, x):
        return self.classifier(x)
class Skip_BART(nn.Module, PyTorchModelHubMixin):
    """Hub-loadable wrapper that builds a BartConfig and an ML_BART model from a few hyperparameters."""

    def __init__(self, class_num=[180, 256], max_position_embeddings=1024, hidden_size=1024,
                 layers=8, heads=8, ffn_dims=2048, pretrain=False):
        super().__init__()
        self.config = BartConfig(max_position_embeddings=max_position_embeddings,
                                 d_model=hidden_size,
                                 encoder_layers=layers,
                                 encoder_ffn_dim=ffn_dims,
                                 encoder_attention_heads=heads,
                                 decoder_layers=layers,
                                 decoder_ffn_dim=ffn_dims,
                                 decoder_attention_heads=heads)
        self.model = ML_BART(self.config, class_num=class_num, pretrain=pretrain)

    def forward(self, x_encoder, x_decoder, attn_mask_encoder=None, attn_mask_decoder=None):
        return self.model(x_encoder, x_decoder, attn_mask_encoder, attn_mask_decoder)
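

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module).
    # Shapes are assumptions inferred from the defaults above: continuous encoder
    # features of width music_dim=512, and two decoder token streams whose length
    # must equal the encoder length, because ML_BART concatenates per-step encoder
    # features into the decoder embeddings.
    model = Skip_BART(class_num=[180, 256], hidden_size=1024, layers=2, heads=8, ffn_dims=2048)
    batch, seq_len = 2, 16
    x_encoder = torch.randn(batch, seq_len, 512)                       # continuous features
    x_decoder = torch.stack([torch.randint(0, 181, (batch, seq_len)),  # stream 0: ids in [0, class_num[0]]
                             torch.randint(0, 257, (batch, seq_len))], # stream 1: ids in [0, class_num[1]]
                            dim=-1)
    hidden = model(x_encoder, x_decoder)                               # (batch, seq_len, hidden_size)
    h, v = ML_Classifier(hidden_dim=1024, class_num=[180, 256])(hidden)
    print(hidden.shape, h.shape, v.shape)                              # (2, 16, 1024), (2, 16, 181), (2, 16, 257)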