Moayed committed
Commit 51a7ef6 (parent: 90006ce)

improve memory usage
GenAU/src/models/genau_ddpm.py CHANGED
@@ -2318,77 +2318,85 @@ class GenAu(DDPM):
 
         use_ddim = ddim_steps is not None
 
-        with self.ema_scope("Plotting", use_ema=use_ema):
-            fnames = list(batch["fname"])
-            _, c = self.get_input(
-                batch,
-                self.first_stage_key,  # fbank
-                unconditional_prob_cfg=0.0,  # Do not include unconditional information in c
-                return_first_stage_encode=False,
-            )
-
-            c = self.filter_useful_cond_dict(c)
-            text = batch['text']
-
-            # Generate multiple samples
-            num_samples = len(batch['text'])
-            batch_size = len(batch['text']) * n_gen
-
-            # Generate multiple samples at a time and filter out the best
-            # The condition to the diffusion wrapper can have many formats
-            for cond_key in c.keys():
-                if isinstance(c[cond_key], list):
-                    for i in range(len(c[cond_key])):
-                        c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
-                elif isinstance(c[cond_key], dict):
-                    for k in c[cond_key].keys():
-                        c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
-                else:
-                    c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
-
-            text = text * n_gen
-            if unconditional_guidance_scale != 1.0:
-                unconditional_conditioning = {}
-                for key in self.cond_stage_model_metadata:
-                    model_idx = self.cond_stage_model_metadata[key]["model_idx"]
-                    unconditional_conditioning[key] = self.cond_stage_models[
-                        model_idx
-                    ].get_unconditional_condition(batch_size)
-
-            # Prepare x_T
-            # shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
-            x_T = self.generate_noise_for_batch(batch, self.channels, self.latent_t_size, self.latent_f_size, n_gen=n_gen).to(self.device)
-
-            samples, _ = self.sample_log(
-                cond=c,
-                batch_size=batch_size,
-                x_T=x_T,
-                ddim=use_ddim,
-                ddim_steps=ddim_steps,
-                eta=ddim_eta,
-                unconditional_guidance_scale=unconditional_guidance_scale,
-                unconditional_conditioning=unconditional_conditioning,
-                use_plms=use_plms,
-            )
-
-            mel = self.decode_first_stage(samples)
-
-            waveform = self.mel_spectrogram_to_waveform(
-                mel, savepath=waveform_save_dir, bs=None, name=fnames, save=False
-            )
-            if n_gen > 1:
-                best_index = []
-                similarity = self.clap.cos_similarity(
-                    torch.FloatTensor(waveform).squeeze(1), text
-                )
-                for i in range(num_samples):
-                    candidates = similarity[i::num_samples]
-                    max_index = torch.argmax(candidates).item()
-                    best_index.append(i + max_index * num_samples)
-
-                waveform = waveform[best_index]
-
-            waveform_save_paths = self.save_waveform(waveform, waveform_save_dir, name=fnames)
+        with torch.no_grad():
+            with self.ema_scope("Plotting", use_ema=use_ema):
+                # Offload the first-stage model to the CPU during sampling
+                print("Offloading first stage model to CPU for inference...")
+                self.first_stage_model.to("cpu")
+                fnames = list(batch["fname"])
+                _, c = self.get_input(
+                    batch,
+                    self.first_stage_key,  # fbank
+                    unconditional_prob_cfg=0.0,  # Do not include unconditional information in c
+                    return_first_stage_encode=False,
+                )
+
+                c = self.filter_useful_cond_dict(c)
+                text = batch['text']
+
+                # Generate multiple samples
+                num_samples = len(batch['text'])
+                batch_size = len(batch['text']) * n_gen
+
+                # Generate multiple samples at a time and filter out the best
+                # The condition to the diffusion wrapper can have many formats
+                for cond_key in c.keys():
+                    if isinstance(c[cond_key], list):
+                        for i in range(len(c[cond_key])):
+                            c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
+                    elif isinstance(c[cond_key], dict):
+                        for k in c[cond_key].keys():
+                            c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
+                    else:
+                        c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
+
+                text = text * n_gen
+                if unconditional_guidance_scale != 1.0:
+                    unconditional_conditioning = {}
+                    for key in self.cond_stage_model_metadata:
+                        model_idx = self.cond_stage_model_metadata[key]["model_idx"]
+                        unconditional_conditioning[key] = self.cond_stage_models[
+                            model_idx
+                        ].get_unconditional_condition(batch_size)
+
+                # Prepare x_T
+                # shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
+                x_T = self.generate_noise_for_batch(batch, self.channels, self.latent_t_size, self.latent_f_size, n_gen=n_gen).to(self.device)
+
+                samples, _ = self.sample_log(
+                    cond=c,
+                    batch_size=batch_size,
+                    x_T=x_T,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=unconditional_conditioning,
+                    use_plms=use_plms,
+                )
+                # Bring the first-stage model back for decoding
+                print("Moving first stage model back to GPU for decoding...")
+                self.first_stage_model.to("cuda")
+                mel = self.decode_first_stage(samples)
+
+                waveform = self.mel_spectrogram_to_waveform(
+                    mel, savepath=waveform_save_dir, bs=None, name=fnames, save=False
+                )
+                if n_gen > 1:
+                    best_index = []
+                    similarity = self.clap.cos_similarity(
+                        torch.FloatTensor(waveform).squeeze(1), text
+                    )
+                    for i in range(num_samples):
+                        candidates = similarity[i::num_samples]
+                        max_index = torch.argmax(candidates).item()
+                        best_index.append(i + max_index * num_samples)
+
+                    waveform = waveform[best_index]
+
+                waveform_save_paths = self.save_waveform(waveform, waveform_save_dir, name=fnames)
+
+                # Offload again so the smaller footprint persists between calls
+                print("Offloading first stage model to CPU for inference...")
+                self.first_stage_model.to("cpu")
 
         return waveform_save_paths
 
     @torch.no_grad()
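The substance of the change: sampling now runs under torch.no_grad(), and first_stage_model (the latent decoder) is parked on the CPU while the diffusion network occupies the GPU, returning only for decode_first_stage. A minimal sketch of the offloading pattern, with hypothetical diffusion/first_stage objects and sample/decode methods rather than GenAU's actual API:

import torch

@torch.no_grad()
def sample_with_offload(diffusion, first_stage, cond, x_T, device="cuda"):
    # Park the first-stage decoder on the CPU so only the sampler's
    # weights and activations occupy GPU memory during denoising.
    first_stage.to("cpu")
    latents = diffusion.sample(cond=cond, x_T=x_T)  # hypothetical sampler call

    # Bring the decoder back just for the latent -> mel decode, then
    # offload again so later batches start from the smaller footprint.
    first_stage.to(device)
    mel = first_stage.decode(latents)
    first_stage.to("cpu")
    return mel

The cost is one host-to-device copy of the decoder weights per batch, which is typically small next to the GPU memory reclaimed across the many denoising steps.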
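The best-of-n selection that survives unchanged here depends on the candidate layout created by torch.cat([x] * n_gen, dim=0): the batch of num_samples prompts is repeated n_gen times back to back, so the k-th generation for prompt i sits at flat index i + k * num_samples, and similarity[i::num_samples] strides across exactly that prompt's candidates. A self-contained illustration of the index arithmetic, with toy scores in place of CLAP similarities:

import torch

num_samples, n_gen = 2, 3                    # 2 prompts, 3 candidates each
# layout after concatenation: [p0_g0, p1_g0, p0_g1, p1_g1, p0_g2, p1_g2]
similarity = torch.tensor([0.1, 0.9, 0.8, 0.2, 0.3, 0.4])

best_index = []
for i in range(num_samples):
    candidates = similarity[i::num_samples]  # scores for prompt i across gens
    k = torch.argmax(candidates).item()      # winning generation index
    best_index.append(i + k * num_samples)   # map back to the flat batch index

print(best_index)                            # [2, 1]: p0's 2nd gen, p1's 1st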
GenAU/src/modules/conditional/conditional_models.py CHANGED
@@ -1507,7 +1507,7 @@ class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
         audio_dict = get_audio_features(
             audio_data,
             mel,
-            480000,
+            460000,
             data_truncating="fusion",
             data_filling="repeatpad",
             audio_cfg=self.model_cfg["audio_cfg"],
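For context on this second change: the edited positional argument is the maximum number of audio samples get_audio_features keeps before computing the CLAP embedding. At CLAP's 48 kHz input rate, the old cap of 480000 samples is exactly 10 s and the new 460000-sample cap is roughly 9.58 s, presumably trimming peak memory during feature extraction. A simplified sketch of the repeatpad filling named in the call, for clips shorter than the cap (illustrative only, not the laion-clap implementation):

import torch
import torch.nn.functional as F

def repeatpad_to(wav: torch.Tensor, max_len: int) -> torch.Tensor:
    # Return exactly max_len samples from a 1-D waveform.
    if wav.numel() >= max_len:
        # over the cap: truncate (the "fusion" setting instead crops
        # several windows and fuses them rather than taking a prefix)
        return wav[:max_len]
    n_repeat = max_len // wav.numel()              # under the cap: tile the clip,
    wav = wav.repeat(n_repeat)
    return F.pad(wav, (0, max_len - wav.numel()))  # then zero-pad the remainder

clip = torch.randn(48000 * 4)             # a 4 s clip at 48 kHz
print(repeatpad_to(clip, 460000).shape)   # torch.Size([460000])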