Spaces commit: Update pico_model.py

pico_model.py  CHANGED  (+5 -57)
@@ -8,40 +8,6 @@ import torch.nn.functional as F
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers import DDPMScheduler, UNet2DConditionModel
 
-from audioldm.audio.stft import TacotronSTFT
-from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
-from audioldm.utils import default_audioldm_config, get_metadata
-
-
-
-def build_pretrained_models(name):
-    checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
-    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
-
-    vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
-
-    config = default_audioldm_config(name)
-    vae_config = config["model"]["params"]["first_stage_config"]["params"]
-    vae_config["scale_factor"] = scale_factor
-
-    vae = AutoencoderKL(**vae_config)
-    vae.load_state_dict(vae_state_dict)
-
-    fn_STFT = TacotronSTFT(
-        config["preprocessing"]["stft"]["filter_length"],
-        config["preprocessing"]["stft"]["hop_length"],
-        config["preprocessing"]["stft"]["win_length"],
-        config["preprocessing"]["mel"]["n_mel_channels"],
-        config["preprocessing"]["audio"]["sampling_rate"],
-        config["preprocessing"]["mel"]["mel_fmin"],
-        config["preprocessing"]["mel"]["mel_fmax"],
-    )
-
-    vae.eval()
-    fn_STFT.eval()
-
-    return vae, fn_STFT
-
 def _init_layer(layer):
     """Initialize a Linear or Convolutional layer. """
     nn.init.xavier_uniform_(layer.weight)
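Aside (not part of the diff): the removed loader's k[18:] slice works because "first_stage_model." is exactly 18 characters, so the dict comprehension strips that prefix from the checkpoint keys. A standalone check, using a hypothetical key name:

    # Illustration: why k[18:] strips the "first_stage_model." prefix.
    prefix = "first_stage_model."
    assert len(prefix) == 18
    key = "first_stage_model.decoder.conv_in.weight"  # hypothetical example key
    print(key[len(prefix):])  # -> decoder.conv_in.weight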
@@ -243,7 +209,7 @@ class ClapText_Onset_2_Audio_Diffusion(nn.Module):
 from sklearn.metrics.pairwise import cosine_similarity
 import laion_clap
 from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
-
+from llm_preprocess import get_event
 class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
     def __init__(self,
         scheduler_name,
@@ -260,31 +226,12 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
         ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
         del_parameter_key = ["text_branch.embeddings.position_ids"]
         ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
-        diffusion_ckpt = torch.load(diffusion_pt)
+        diffusion_ckpt = torch.load(diffusion_pt, map_location=self.device)
         del diffusion_ckpt["class_emb.weight"]
         ckpt.update(diffusion_ckpt)
         self.load_state_dict(ckpt)
 
-        self.event_list = [
-            "burping_belching",          # 0
-            "car_horn_honking",          # 1
-            "cat_meowing",               # 2
-            "cow_mooing",                # 3
-            "dog_barking",               # 4
-            "door_knocking",             # 5
-            "door_slamming",             # 6
-            "explosion",                 # 7
-            "gunshot",                   # 8
-            "sheep_goat_bleating",       # 9
-            "sneeze",                    # 10
-            "spraying",                  # 11
-            "thump_thud",                # 12
-            "train_horn",                # 13
-            "tapping_clicking_clanking", # 14
-            "woman_laughing",            # 15
-            "duck_quacking",             # 16
-            "whistling",                 # 17
-        ]
+        self.event_list = get_event()
         self.events_emb = self.freeze_text_encoder.get_text_embedding(self.event_list, use_tensor=False)
 
 
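Note on this hunk: the hard-coded 18-label list moves out of the model into llm_preprocess.get_event(), and the diffusion checkpoint now loads onto self.device rather than the torch.load default. Because event_id later indexes events_emb by position, get_event() presumably returns the same labels in the same order; a sketch of that assumption:

    # Assumed shape of llm_preprocess.get_event (mirrors the removed list;
    # order matters because event ids index into events_emb row-wise).
    def get_event():
        return [
            "burping_belching", "car_horn_honking", "cat_meowing", "cow_mooing",
            "dog_barking", "door_knocking", "door_slamming", "explosion",
            "gunshot", "sheep_goat_bleating", "sneeze", "spraying", "thump_thud",
            "train_horn", "tapping_clicking_clanking", "woman_laughing",
            "duck_quacking", "whistling",
        ]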
@@ -300,10 +247,11 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
         for event_timestamp in timestampCaption.split(' and '):
             # event_timestamp : event1__onset1-offset1_onset2-offset2
             (event, instance) = event_timestamp.split(' at ')
-
+
             # instance : onset1-offset1_onset2-offset2
             event_emb = self.freeze_text_encoder.get_text_embedding([event, ""], use_tensor=False)[0]
             event_id = np.argmax(cosine_similarity(event_emb.reshape(1, -1), self.events_emb))
+            events.append(self.event_list[event_id])
             for start_end in instance.split('_'):
                 (start, end) = start_end.split('-')
                 start, end = int(float(start)*250/10), int(float(end)*250/10)
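For reference, the caption grammar parsed above is "event at onset-offset[_onset-offset ...] [and ...]", with times in seconds mapped to latent frames at 250 frames per 10 s of audio. A self-contained walk-through with invented timestamps:

    # Illustrative parse of a timestamp caption (values invented).
    caption = "dog_barking at 0.5-2.0_4.0-5.5 and cat_meowing at 6.0-8.0"
    for event_timestamp in caption.split(' and '):
        event, instance = event_timestamp.split(' at ')
        for start_end in instance.split('_'):
            start, end = start_end.split('-')
            # 250 latent frames cover 10 s of audio -> 25 frames per second
            print(event, int(float(start) * 250 / 10), int(float(end) * 250 / 10))
    # dog_barking 12 50
    # dog_barking 100 137
    # cat_meowing 150 200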
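The free-text event in each caption is then snapped to the closest known label via CLAP text embeddings. A minimal sketch of that lookup, factored out of the code above (the helper name is mine; the extra "" mirrors the two-element batch in the source):

    # Sketch: map an arbitrary event phrase to the closest known label.
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    def nearest_event(clap_model, phrase, event_list, events_emb):
        # events_emb: (len(event_list), dim) CLAP text embeddings of the labels
        emb = clap_model.get_text_embedding([phrase, ""], use_tensor=False)[0]
        return event_list[int(np.argmax(cosine_similarity(emb.reshape(1, -1), events_emb)))]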