improve memory usage
GenAU/src/models/genau_ddpm.py
CHANGED
@@ -2318,77 +2318,85 @@ class GenAu(DDPM):
 
         use_ddim = ddim_steps is not None
 
-        with self.ema_scope("Plotting", use_ema=use_ema):
-            fnames = list(batch["fname"])
-            _, c = self.get_input(
-                batch,
-                self.first_stage_key,  # fbank
-                unconditional_prob_cfg=0.0,  # Do not output unconditional information in the c
-                return_first_stage_encode=False,
-            )
-
-            c = self.filter_useful_cond_dict(c)
-            text = batch['text']
-
-            # Generate multiple samples
-            num_samples = len(batch['text'])
-            batch_size = len(batch['text']) * n_gen
-
-            # Generate multiple samples at a time and filter out the best
-            # The condition to the diffusion wrapper can have many format
-            for cond_key in c.keys():
-                if isinstance(c[cond_key], list):
-                    for i in range(len(c[cond_key])):
-                        c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
-                elif isinstance(c[cond_key], dict):
-                    for k in c[cond_key].keys():
-                        c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
-                else:
-                    c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
-
-            text = text * n_gen
-            if unconditional_guidance_scale != 1.0:
-                unconditional_conditioning = {}
-                for key in self.cond_stage_model_metadata:
-                    model_idx = self.cond_stage_model_metadata[key]["model_idx"]
-                    unconditional_conditioning[key] = self.cond_stage_models[
-                        model_idx
-                    ].get_unconditional_condition(batch_size)
-
-            # Prepare X_T
-            # shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
-            x_T = self.generate_noise_for_batch(batch, self.channels, self.latent_t_size, self.latent_f_size, n_gen=n_gen).to(self.device)
-
-            samples, _ = self.sample_log(
-                cond=c,
-                batch_size=batch_size,
-                x_T=x_T,
-                ddim=use_ddim,
-                ddim_steps=ddim_steps,
-                eta=ddim_eta,
-                unconditional_guidance_scale=unconditional_guidance_scale,
-                unconditional_conditioning=unconditional_conditioning,
-                use_plms=use_plms,
-            )
-
-            mel = self.decode_first_stage(samples)
-
-            waveform = self.mel_spectrogram_to_waveform(
-                mel, savepath=waveform_save_dir, bs=None, name=fnames, save=False
-            )
-            if n_gen > 1:
-                best_index = []
-                similarity = self.clap.cos_similarity(
-                    torch.FloatTensor(waveform).squeeze(1), text
-                )
-                for i in range(num_samples):
-                    candidates = similarity[i :: num_samples]
-                    max_index = torch.argmax(candidates).item()
-                    best_index.append(i + max_index * num_samples)
-
-                waveform = waveform[best_index]
-
-            waveform_save_paths = self.save_waveform(waveform, waveform_save_dir, name=fnames)
+        with torch.no_grad():
+            with self.ema_scope("Plotting", use_ema=use_ema):
+                # offload first stage model to CPU
+                print("Offloading first stage model to CPU for inference...")
+                self.first_stage_model.to("cpu")
+                fnames = list(batch["fname"])
+                _, c = self.get_input(
+                    batch,
+                    self.first_stage_key,  # fbank
+                    unconditional_prob_cfg=0.0,  # Do not output unconditional information in the c
+                    return_first_stage_encode=False,
+                )
+
+                c = self.filter_useful_cond_dict(c)
+                text = batch['text']
+
+                # Generate multiple samples
+                num_samples = len(batch['text'])
+                batch_size = len(batch['text']) * n_gen
+
+                # Generate multiple samples at a time and filter out the best
+                # The condition to the diffusion wrapper can have many format
+                for cond_key in c.keys():
+                    if isinstance(c[cond_key], list):
+                        for i in range(len(c[cond_key])):
+                            c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
+                    elif isinstance(c[cond_key], dict):
+                        for k in c[cond_key].keys():
+                            c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
+                    else:
+                        c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
+
+                text = text * n_gen
+                if unconditional_guidance_scale != 1.0:
+                    unconditional_conditioning = {}
+                    for key in self.cond_stage_model_metadata:
+                        model_idx = self.cond_stage_model_metadata[key]["model_idx"]
+                        unconditional_conditioning[key] = self.cond_stage_models[
+                            model_idx
+                        ].get_unconditional_condition(batch_size)
+
+                # Prepare X_T
+                # shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
+                x_T = self.generate_noise_for_batch(batch, self.channels, self.latent_t_size, self.latent_f_size, n_gen=n_gen).to(self.device)
+
+                samples, _ = self.sample_log(
+                    cond=c,
+                    batch_size=batch_size,
+                    x_T=x_T,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=unconditional_conditioning,
+                    use_plms=use_plms,
+                )
+                print("Moving first stage model back to GPU for decoding...")
+                self.first_stage_model.to("cuda")
+                mel = self.decode_first_stage(samples)
+
+                waveform = self.mel_spectrogram_to_waveform(
+                    mel, savepath=waveform_save_dir, bs=None, name=fnames, save=False
+                )
+                if n_gen > 1:
+                    best_index = []
+                    similarity = self.clap.cos_similarity(
+                        torch.FloatTensor(waveform).squeeze(1), text
+                    )
+                    for i in range(num_samples):
+                        candidates = similarity[i :: num_samples]
+                        max_index = torch.argmax(candidates).item()
+                        best_index.append(i + max_index * num_samples)
+
+                    waveform = waveform[best_index]
+
+                waveform_save_paths = self.save_waveform(waveform, waveform_save_dir, name=fnames)
+
+                print("Offloading first stage model to CPU for inference...")
+                self.first_stage_model.to("cpu")
         return waveform_save_paths
 
     @torch.no_grad()
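The change above keeps only the sub-model that is actually needed in GPU memory: the first-stage autoencoder (self.first_stage_model) is parked on the CPU while the latent diffusion sampler runs, pulled back to CUDA only for decode_first_stage, and parked again once the waveforms are saved, all under torch.no_grad() so no autograd buffers are retained. Below is a minimal, self-contained sketch of that offloading pattern in plain PyTorch; the class and attribute names are illustrative placeholders rather than the repo's API, and torch.cuda.empty_cache() is an optional extra step the commit itself does not take.

    import torch
    import torch.nn as nn

    class TwoStagePipeline(nn.Module):
        """Toy stand-in for a latent-diffusion sampler plus first-stage decoder."""

        def __init__(self):
            super().__init__()
            self.sampler = nn.Linear(64, 64)       # stand-in for the diffusion model
            self.first_stage = nn.Linear(64, 128)  # stand-in for the VAE decoder

        @torch.no_grad()  # inference only: keep no autograd state
        def generate(self, z, device="cuda"):
            # Sampling phase: the decoder is not needed, so keep it on the CPU.
            self.first_stage.to("cpu")
            self.sampler.to(device)
            latents = self.sampler(z.to(device))

            # Decoding phase: optionally release cached blocks, then bring the decoder back.
            torch.cuda.empty_cache()
            self.first_stage.to(device)
            out = self.first_stage(latents)

            # Park the decoder again so the next call starts with that memory free.
            self.first_stage.to("cpu")
            return out.cpu()

    if torch.cuda.is_available():
        pipe = TwoStagePipeline()
        print(pipe.generate(torch.randn(2, 64)).shape)  # torch.Size([2, 128])

The trade-off is one host-to-device copy of the decoder weights per call; for a decoder that is hundreds of megabytes this is usually far cheaper than keeping both networks resident and risking an out-of-memory failure during sampling.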
GenAU/src/modules/conditional/conditional_models.py
CHANGED
@@ -1507,7 +1507,7 @@ class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
             audio_dict = get_audio_features(
                 audio_data,
                 mel,
-
+                460000,
                 data_truncating="fusion",
                 data_filling="repeatpad",
                 audio_cfg=self.model_cfg["audio_cfg"],
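The second change passes an explicit bound of 460000 to the CLAP feature helper. In this position the positional argument appears to be a maximum sample count (just under ten seconds at a 48 kHz input rate), so it limits how much audio get_audio_features consumes before truncating ("fusion") or filling ("repeatpad"), which also caps the memory used when scoring candidate generations with CLAP. The sketch below shows what clamping a waveform to such a fixed budget can look like in plain PyTorch; bound_audio is a hypothetical helper written for illustration, not the repo's function, and the repeat-then-zero-pad filling is only an assumption about what "repeatpad" means here.

    import torch

    def bound_audio(wav: torch.Tensor, max_len: int = 460000) -> torch.Tensor:
        """Clamp a 1-D waveform to exactly max_len samples (truncate, or repeat then zero-pad)."""
        n = wav.shape[-1]
        if n >= max_len:
            return wav[:max_len]                    # truncate long clips
        tiled = wav.repeat(max_len // n)            # repeat the clip as many whole times as fit
        pad = torch.zeros(max_len - tiled.shape[-1], dtype=wav.dtype)
        return torch.cat([tiled, pad])              # zero-pad the remainder

    wav = torch.randn(3 * 48000)                    # 3 s of audio at 48 kHz
    print(bound_audio(wav).shape)                   # torch.Size([460000])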