Spaces:
Runtime error
Runtime error
| import time | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from torch import autocast | |
| from torch.cuda.amp import GradScaler | |
| from diffusion.logger import utils | |
| from diffusion.logger.saver import Saver | |
| def test(args, model, vocoder, loader_test, saver): | |
| print(' [*] testing...') | |
| model.eval() | |
| # losses | |
| test_loss = 0. | |
| # intialization | |
| num_batches = len(loader_test) | |
| rtf_all = [] | |
| # run | |
| with torch.no_grad(): | |
| for bidx, data in enumerate(loader_test): | |
| fn = data['name'][0].split("/")[-1] | |
| speaker = data['name'][0].split("/")[-2] | |
| print('--------') | |
| print('{}/{} - {}'.format(bidx, num_batches, fn)) | |
| # unpack data | |
| for k in data.keys(): | |
| if not k.startswith('name'): | |
| data[k] = data[k].to(args.device) | |
| print('>>', data['name'][0]) | |
| # forward | |
| st_time = time.time() | |
| mel = model( | |
| data['units'], | |
| data['f0'], | |
| data['volume'], | |
| data['spk_id'], | |
| gt_spec=None if model.k_step_max == model.timesteps else data['mel'], | |
| infer=True, | |
| infer_speedup=args.infer.speedup, | |
| method=args.infer.method, | |
| k_step=model.k_step_max | |
| ) | |
| signal = vocoder.infer(mel, data['f0']) | |
| ed_time = time.time() | |
| # RTF | |
| run_time = ed_time - st_time | |
| song_time = signal.shape[-1] / args.data.sampling_rate | |
| rtf = run_time / song_time | |
| print('RTF: {} | {} / {}'.format(rtf, run_time, song_time)) | |
| rtf_all.append(rtf) | |
| # loss | |
| for i in range(args.train.batch_size): | |
| loss = model( | |
| data['units'], | |
| data['f0'], | |
| data['volume'], | |
| data['spk_id'], | |
| gt_spec=data['mel'], | |
| infer=False, | |
| k_step=model.k_step_max) | |
| test_loss += loss.item() | |
| # log mel | |
| saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel) | |
| # log audi | |
| path_audio = data['name_ext'][0] | |
| audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate) | |
| if len(audio.shape) > 1: | |
| audio = librosa.to_mono(audio) | |
| audio = torch.from_numpy(audio).unsqueeze(0).to(signal) | |
| saver.log_audio({f"{speaker}_{fn}_gt.wav": audio,f"{speaker}_{fn}_pred.wav": signal}) | |
| # report | |
| test_loss /= args.train.batch_size | |
| test_loss /= num_batches | |
| # check | |
| print(' [test_loss] test_loss:', test_loss) | |
| print(' Real Time Factor', np.mean(rtf_all)) | |
| return test_loss | |
| def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test): | |
| # saver | |
| saver = Saver(args, initial_global_step=initial_global_step) | |
| # model size | |
| params_count = utils.get_network_paras_amount({'model': model}) | |
| saver.log_info('--- model size ---') | |
| saver.log_info(params_count) | |
| # run | |
| num_batches = len(loader_train) | |
| model.train() | |
| saver.log_info('======= start training =======') | |
| scaler = GradScaler() | |
| if args.train.amp_dtype == 'fp32': | |
| dtype = torch.float32 | |
| elif args.train.amp_dtype == 'fp16': | |
| dtype = torch.float16 | |
| elif args.train.amp_dtype == 'bf16': | |
| dtype = torch.bfloat16 | |
| else: | |
| raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype) | |
| saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step") | |
| for epoch in range(args.train.epochs): | |
| for batch_idx, data in enumerate(loader_train): | |
| saver.global_step_increment() | |
| optimizer.zero_grad() | |
| # unpack data | |
| for k in data.keys(): | |
| if not k.startswith('name'): | |
| data[k] = data[k].to(args.device) | |
| # forward | |
| if dtype == torch.float32: | |
| loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], | |
| aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=model.k_step_max) | |
| else: | |
| with autocast(device_type=args.device, dtype=dtype): | |
| loss = model(data['units'], data['f0'], data['volume'], data['spk_id'], | |
| aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=model.k_step_max) | |
| # handle nan loss | |
| if torch.isnan(loss): | |
| raise ValueError(' [x] nan loss ') | |
| else: | |
| # backpropagate | |
| if dtype == torch.float32: | |
| loss.backward() | |
| optimizer.step() | |
| else: | |
| scaler.scale(loss).backward() | |
| scaler.step(optimizer) | |
| scaler.update() | |
| scheduler.step() | |
| # log loss | |
| if saver.global_step % args.train.interval_log == 0: | |
| current_lr = optimizer.param_groups[0]['lr'] | |
| saver.log_info( | |
| 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format( | |
| epoch, | |
| batch_idx, | |
| num_batches, | |
| args.env.expdir, | |
| args.train.interval_log/saver.get_interval_time(), | |
| current_lr, | |
| loss.item(), | |
| saver.get_total_time(), | |
| saver.global_step | |
| ) | |
| ) | |
| saver.log_value({ | |
| 'train/loss': loss.item() | |
| }) | |
| saver.log_value({ | |
| 'train/lr': current_lr | |
| }) | |
| # validation | |
| if saver.global_step % args.train.interval_val == 0: | |
| optimizer_save = optimizer if args.train.save_opt else None | |
| # save latest | |
| saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}') | |
| last_val_step = saver.global_step - args.train.interval_val | |
| if last_val_step % args.train.interval_force_save != 0: | |
| saver.delete_model(postfix=f'{last_val_step}') | |
| # run testing set | |
| test_loss = test(args, model, vocoder, loader_test, saver) | |
| # log loss | |
| saver.log_info( | |
| ' --- <validation> --- \nloss: {:.3f}. '.format( | |
| test_loss, | |
| ) | |
| ) | |
| saver.log_value({ | |
| 'validation/loss': test_loss | |
| }) | |
| model.train() | |