update
Files changed:

- checkpoints/diffsinger/config.yaml +393 -0
- checkpoints/diffsinger/model_ckpt_steps_160000.ckpt +3 -0
- docs/diffspeech.md +62 -0
- docs/prepare_vocoder.md +1 -1
- egs/datasets/audio/lj/ds.yaml +29 -0
- egs/egs_bases/tts/ds.yaml +32 -0
- inference/tts/ds.py +30 -0
- modules/tts/commons/align_ops.py +2 -3
- modules/tts/diffspeech/net.py +110 -0
- modules/tts/diffspeech/shallow_diffusion_tts.py +281 -0
- tasks/tts/diffspeech.py +111 -0
checkpoints/diffsinger/config.yaml
ADDED
```yaml
K_step: 71
accumulate_grad_batches: 1
amp: false
audio_num_mel_bins: 80
audio_sample_rate: 22050
base_config:
- egs/egs_bases/tts/ds.yaml
- ./fs2_orig.yaml
binarization_args:
  min_sil_duration: 0.1
  shuffle: false
  test_range:
  - 0
  - 523
  train_range:
  - 871
  - -1
  trim_eos_bos: false
  valid_range:
  - 523
  - 871
  with_align: true
  with_f0: true
  with_f0cwt: true
  with_linear: false
  with_spk_embed: false
  with_wav: false
binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
binary_data_dir: data/binary/ljspeech_cwt
check_val_every_n_epoch: 10
clip_grad_norm: 1
clip_grad_value: 0
conv_use_pos: false
cwt_std_scale: 1.0
debug: false
dec_dilations:
- 1
- 1
- 1
- 1
dec_ffn_kernel_size: 9
dec_inp_add_noise: false
dec_kernel_size: 5
dec_layers: 4
dec_post_net_kernel: 3
decay_steps: 50000
decoder_rnn_dim: 0
decoder_type: fft
diff_decoder_type: wavenet
diff_loss_type: l1
dilation_cycle_length: 1
dropout: 0.0
ds_workers: 2
dur_predictor_kernel: 3
dur_predictor_layers: 2
enc_dec_norm: ln
enc_dilations:
- 1
- 1
- 1
- 1
enc_ffn_kernel_size: 9
enc_kernel_size: 5
enc_layers: 4
enc_post_net_kernel: 3
enc_pre_ln: true
enc_prenet: true
encoder_K: 8
encoder_type: fft
endless_ds: true
eval_max_batches: -1
f0_max: 600
f0_min: 80
ffn_act: gelu
ffn_hidden_size: 1024
fft_size: 1024
fmax: 7600
fmin: 80
frames_multiple: 1
fs2_ckpt: checkpoints/fs2_exp/model_ckpt_steps_160000.ckpt
gen_dir_name: ''
griffin_lim_iters: 30
hidden_size: 256
hop_size: 256
infer: false
keep_bins: 80
lambda_commit: 0.25
lambda_energy: 0.1
lambda_f0: 1.0
lambda_ph_dur: 0.1
lambda_sent_dur: 1.0
lambda_uv: 1.0
lambda_word_dur: 1.0
layers_in_block: 2
load_ckpt: ''
loud_norm: false
lr: 0.001
max_beta: 0.06
max_epochs: 1000
max_frames: 1548
max_input_tokens: 1550
max_sentences: 128
max_tokens: 30000
max_updates: 160000
max_valid_sentences: 1
max_valid_tokens: 60000
mel_losses: l1:0.5|ssim:0.5
mel_vmax: 1.5
mel_vmin: -6
min_frames: 0
num_ckpt_keep: 3
num_heads: 2
num_sanity_val_steps: 5
num_spk: 1
num_valid_plots: 10
optimizer_adam_beta1: 0.9
optimizer_adam_beta2: 0.98
out_wav_norm: false
pitch_extractor: parselmouth
pitch_key: pitch
pitch_type: cwt
predictor_dropout: 0.5
predictor_grad: 0.1
predictor_hidden: -1
predictor_kernel: 5
predictor_layers: 2
preprocess_args:
  add_eos_bos: true
  mfa_group_shuffle: false
  mfa_offset: 0.02
  nsample_per_mfa_group: 1000
  reset_phone_dict: true
  reset_word_dict: true
  save_sil_mask: true
  txt_processor: en
  use_mfa: true
  vad_max_silence_length: 12
  wav_processors: []
  with_phsep: true
preprocess_cls: egs.datasets.audio.lj.preprocess.LJPreprocess
print_nan_grads: false
processed_data_dir: data/processed/ljspeech
profile_infer: false
raw_data_dir: data/raw/LJSpeech-1.1
ref_norm_layer: bn
rename_tmux: true
residual_channels: 256
residual_layers: 20
resume_from_checkpoint: 0
save_best: false
save_codes:
- tasks
- modules
- egs
save_f0: false
save_gt: true
schedule_type: linear
scheduler: warmup
seed: 1234
sort_by_len: true
spec_max: [ -0.5982, -0.0778, 0.1205, 0.2747, 0.4657, 0.5123, 0.583, 0.7093,
            0.6461, 0.6101, 0.7316, 0.7715, 0.7681, 0.8349, 0.7815, 0.7591,
            0.791, 0.7433, 0.7352, 0.6869, 0.6854, 0.6623, 0.5353, 0.6492,
            0.6909, 0.6106, 0.5761, 0.5236, 0.5638, 0.4054, 0.4545, 0.3407,
            0.3037, 0.338, 0.1599, 0.1603, 0.2741, 0.213, 0.1569, 0.1911,
            0.2324, 0.1586, 0.1221, 0.0341, -0.0558, 0.0553, -0.1153, -0.0933,
            -0.1171, -0.005, -0.1519, -0.1629, -0.0522, -0.0739, -0.2069, -0.2405,
            -0.1244, -0.2582, -0.1361, -0.1575, -0.1442, 0.0513, -0.1567, -0.2,
            0.0086, -0.0698, 0.1385, 0.0941, 0.1864, 0.1225, 0.1389, 0.1382,
            0.167, 0.1007, 0.1444, 0.0888, 0.1998, 0.228, 0.2932, 0.3047 ]
spec_min: [ -4.7574, -4.6783, -4.6431, -4.5832, -4.539, -4.6771, -4.8089, -4.7672,
            -4.5784, -4.7755, -4.715, -4.8919, -4.8271, -4.7389, -4.6047, -4.7759,
            -4.6799, -4.8201, -4.7823, -4.8262, -4.7857, -4.7545, -4.9358, -4.9733,
            -5.1134, -5.1395, -4.9016, -4.8434, -5.0189, -4.846, -5.0529, -4.951,
            -5.0217, -5.0049, -5.1831, -5.1445, -5.1015, -5.0281, -4.9887, -4.9916,
            -4.9785, -4.9071, -4.9488, -5.0342, -4.9332, -5.065, -4.8924, -5.0875,
            -5.0483, -5.0848, -5.0655, -5.0279, -5.0015, -5.0792, -5.0636, -5.2413,
            -5.1421, -5.171, -5.3256, -5.0511, -5.1186, -5.0057, -5.0446, -5.1173,
            -5.0325, -5.1085, -5.0053, -5.0755, -5.1176, -5.1004, -5.2153, -5.2757,
            -5.3025, -5.2867, -5.2918, -5.3328, -5.2731, -5.2985, -5.24, -5.2211 ]
task_cls: tasks.tts.diffspeech.DiffSpeechTask
tb_log_interval: 100
test_ids: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
            68, 70, 74, 87, 110, 172, 190, 215, 231, 294, 316, 324, 402, 422,
            485, 500, 505, 508, 509, 519 ]
test_input_yaml: ''
test_num: 100
test_set_name: test
timesteps: 100
train_set_name: train
train_sets: ''
use_energy_embed: true
use_gt_dur: false
use_gt_energy: false
use_gt_f0: false
use_pitch_embed: true
use_pos_embed: true
use_spk_embed: false
use_spk_id: false
use_uv: true
use_word_input: false
val_check_interval: 2000
valid_infer_interval: 10000
valid_monitor_key: val_loss
valid_monitor_mode: min
valid_set_name: valid
vocoder: HifiGAN
vocoder_ckpt: checkpoints/hifi_lj
warmup_updates: 4000
weight_decay: 0
win_size: 1024
word_dict_size: 10000
work_dir: checkpoints/0209_ds_1
```
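The `base_config` chain shows how this dump was assembled: the experiment config inherits from `egs/egs_bases/tts/ds.yaml` and `./fs2_orig.yaml`, and its own keys override theirs. Below is a minimal sketch of that merge order, assuming plain PyYAML; the repo's actual loader (`utils/commons/hparams.py`, imported throughout this commit) may resolve paths and overrides differently:

```python
import os
import yaml

def load_config(path):
    """Depth-first merge of `base_config` chains; the child's keys win (sketch only)."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    bases = cfg.pop('base_config', [])
    if isinstance(bases, str):  # base_config may be a single path or a list
        bases = [bases]
    merged = {}
    for base in bases:
        # './fs2_orig.yaml' resolves relative to the file that names it
        if base.startswith('.'):
            base = os.path.join(os.path.dirname(path), base)
        merged.update(load_config(base))
    merged.update(cfg)  # the child overrides its bases
    return merged

hparams = load_config('egs/datasets/audio/lj/ds.yaml')
```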
checkpoints/diffsinger/model_ckpt_steps_160000.ckpt
ADDED
```
version https://git-lfs.github.com/spec/v1
oid sha256:503f81009a75c02d868253b6fb4f1411aeaa32308b101d7804447bc583636b83
size 168816223
```
docs/diffspeech.md
ADDED
````markdown
# Run DiffSpeech

## Quick Start

### Install Dependencies

Install dependencies following [readme.md](../readme.md)

### Set Config Path and Experiment Name

```bash
export CONFIG_NAME=egs/datasets/audio/lj/ds.yaml
export MY_EXP_NAME=ds_exp
```

### Preprocess and Binarize the Dataset

Prepare the dataset following [prepare_data.md](./prepare_data.md)

### Prepare Vocoder

Prepare the vocoder following [prepare_vocoder.md](./prepare_vocoder.md)

## Training

First, you need a pre-trained FastSpeech 2 checkpoint at `checkpoints/fs2_exp/model_ckpt_steps_160000.ckpt`. To train a FastSpeech 2 model, run:

```bash
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config egs/datasets/audio/lj/fs2_orig.yaml --exp_name fs2_exp --reset
```

Then, run:

```bash
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config $CONFIG_NAME --exp_name $MY_EXP_NAME --reset
```

You can check the training and validation curves by opening TensorBoard:

```bash
tensorboard --logdir checkpoints/$MY_EXP_NAME
```

## Inference (Testing)

```bash
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config $CONFIG_NAME --exp_name $MY_EXP_NAME --infer
```

## Citation

If you find this useful for your research, please cite:

```bib
@article{liu2021diffsinger,
  title={Diffsinger: Singing voice synthesis via shallow diffusion mechanism},
  author={Liu, Jinglin and Li, Chengxi and Ren, Yi and Chen, Feiyang and Liu, Peng and Zhao, Zhou},
  journal={arXiv preprint arXiv:2105.02446},
  volume={2},
  year={2021}
}
```
````
docs/prepare_vocoder.md
CHANGED
````diff
@@ -26,7 +26,7 @@ export MY_EXP_NAME=my_hifigan_exp
 Prepare dataset following [prepare_data.md](./prepare_data.md).
 
 If you have run the `prepare_data` step of the acoustic
-model (e.g.,
+model (e.g., PortaSpeech and DiffSpeech), you only need to binarize the dataset for the vocoder training:
 
 ```bash
 python data_gen/tts/runs/binarize.py --config $CONFIG_NAME
````
egs/datasets/audio/lj/ds.yaml
ADDED
```yaml
base_config:
- egs/egs_bases/tts/ds.yaml
- ./fs2_orig.yaml

fs2_ckpt: checkpoints/fs2_exp/model_ckpt_steps_160000.ckpt

# spec_min and spec_max are calculated on the training set.
spec_min: [ -4.7574, -4.6783, -4.6431, -4.5832, -4.5390, -4.6771, -4.8089, -4.7672,
            -4.5784, -4.7755, -4.7150, -4.8919, -4.8271, -4.7389, -4.6047, -4.7759,
            -4.6799, -4.8201, -4.7823, -4.8262, -4.7857, -4.7545, -4.9358, -4.9733,
            -5.1134, -5.1395, -4.9016, -4.8434, -5.0189, -4.8460, -5.0529, -4.9510,
            -5.0217, -5.0049, -5.1831, -5.1445, -5.1015, -5.0281, -4.9887, -4.9916,
            -4.9785, -4.9071, -4.9488, -5.0342, -4.9332, -5.0650, -4.8924, -5.0875,
            -5.0483, -5.0848, -5.0655, -5.0279, -5.0015, -5.0792, -5.0636, -5.2413,
            -5.1421, -5.1710, -5.3256, -5.0511, -5.1186, -5.0057, -5.0446, -5.1173,
            -5.0325, -5.1085, -5.0053, -5.0755, -5.1176, -5.1004, -5.2153, -5.2757,
            -5.3025, -5.2867, -5.2918, -5.3328, -5.2731, -5.2985, -5.2400, -5.2211 ]
spec_max: [ -0.5982, -0.0778, 0.1205, 0.2747, 0.4657, 0.5123, 0.5830, 0.7093,
            0.6461, 0.6101, 0.7316, 0.7715, 0.7681, 0.8349, 0.7815, 0.7591,
            0.7910, 0.7433, 0.7352, 0.6869, 0.6854, 0.6623, 0.5353, 0.6492,
            0.6909, 0.6106, 0.5761, 0.5236, 0.5638, 0.4054, 0.4545, 0.3407,
            0.3037, 0.3380, 0.1599, 0.1603, 0.2741, 0.2130, 0.1569, 0.1911,
            0.2324, 0.1586, 0.1221, 0.0341, -0.0558, 0.0553, -0.1153, -0.0933,
            -0.1171, -0.0050, -0.1519, -0.1629, -0.0522, -0.0739, -0.2069, -0.2405,
            -0.1244, -0.2582, -0.1361, -0.1575, -0.1442, 0.0513, -0.1567, -0.2000,
            0.0086, -0.0698, 0.1385, 0.0941, 0.1864, 0.1225, 0.1389, 0.1382,
            0.1670, 0.1007, 0.1444, 0.0888, 0.1998, 0.2280, 0.2932, 0.3047 ]

max_tokens: 30000
```
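These per-bin bounds feed `GaussianDiffusion.norm_spec` / `denorm_spec` (in `shallow_diffusion_tts.py` below), which map each mel bin into $[-1, 1]$ for diffusion and back:

$$x_{\mathrm{norm}} = 2\,\frac{x - x_{\min}}{x_{\max} - x_{\min}} - 1, \qquad x = \frac{x_{\mathrm{norm}} + 1}{2}\,\bigl(x_{\max} - x_{\min}\bigr) + x_{\min}.$$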
egs/egs_bases/tts/ds.yaml
ADDED
```yaml
base_config: ./fs2_orig.yaml

# special configs for diffspeech
task_cls: tasks.tts.diffspeech.DiffSpeechTask
lr: 0.001
timesteps: 100
K_step: 71
diff_loss_type: l1
diff_decoder_type: 'wavenet'
schedule_type: 'linear'
max_beta: 0.06

## model configs for diffspeech
dilation_cycle_length: 1
residual_layers: 20
residual_channels: 256
decay_steps: 50000
keep_bins: 80
#content_cond_steps: [ ]  # [ 0, 10000 ]
#spk_cond_steps: [ ]  # [ 0, 10000 ]
#gen_tgt_spk_id: -1

# training configs for diffspeech
#max_sentences: 48
#num_sanity_val_steps: 1
num_valid_plots: 10
use_gt_dur: false
use_gt_f0: false
#pitch_type: cwt
max_updates: 160000
```
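With `timesteps: 100` but `K_step: 71`, this base config sets up the shallow-diffusion mechanism: training draws $t$ uniformly from $\{0, \dots, K-1\}$, and inference starts not from pure noise but from the auxiliary FastSpeech 2 mel $\tilde{x}$ diffused to step $K$ (see `modules/tts/diffspeech/shallow_diffusion_tts.py` below). Under the linear schedule with `max_beta: 0.06`:

$$
\beta_t = 10^{-4} + \frac{t-1}{T-1}\bigl(0.06 - 10^{-4}\bigr),\ \ T = 100; \qquad
\bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s); \qquad
x_K = \sqrt{\bar{\alpha}_K}\,\tilde{x} + \sqrt{1 - \bar{\alpha}_K}\,\epsilon,\ \ \epsilon \sim \mathcal{N}(0, I).
$$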
inference/tts/ds.py
ADDED
```python
import torch
# from inference.tts.fs import FastSpeechInfer
# from modules.tts.fs2_orig import FastSpeech2Orig
from inference.tts.base_tts_infer import BaseTTSInfer
from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams


class DiffSpeechInfer(BaseTTSInfer):
    def build_model(self):
        dict_size = len(self.ph_encoder)
        model = GaussianDiffusion(dict_size, self.hparams)
        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')
        return model

    def forward_model(self, inp):
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_id = sample.get('spk_ids')
        with torch.no_grad():
            output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True)
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        wav_out = wav_out.cpu().numpy()
        return wav_out[0]


if __name__ == '__main__':
    DiffSpeechInfer.example_run()
```
modules/tts/commons/align_ops.py
CHANGED
The old `clip_mel2token_to_multiple` used `max_frames` without defining it; the update computes it from `frames_multiple` before clipping:

```diff
@@ -13,9 +13,8 @@ def mel2ph_to_mel2word(mel2ph, ph2word):
 
 
 def clip_mel2token_to_multiple(mel2token, frames_multiple):
-
-
-    mel2token = mel2token[:, :max_frames]
+    max_frames = mel2token.shape[1] // frames_multiple * frames_multiple
+    mel2token = mel2token[:, :max_frames]
     return mel2token
 
 
```
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from math import sqrt
|
| 8 |
+
|
| 9 |
+
Linear = nn.Linear
|
| 10 |
+
ConvTranspose2d = nn.ConvTranspose2d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Mish(nn.Module):
|
| 14 |
+
def forward(self, x):
|
| 15 |
+
return x * torch.tanh(F.softplus(x))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SinusoidalPosEmb(nn.Module):
|
| 19 |
+
def __init__(self, dim):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.dim = dim
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
device = x.device
|
| 25 |
+
half_dim = self.dim // 2
|
| 26 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 27 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
| 28 |
+
emb = x[:, None] * emb[None, :]
|
| 29 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
| 30 |
+
return emb
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def Conv1d(*args, **kwargs):
|
| 34 |
+
layer = nn.Conv1d(*args, **kwargs)
|
| 35 |
+
nn.init.kaiming_normal_(layer.weight)
|
| 36 |
+
return layer
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ResidualBlock(nn.Module):
|
| 40 |
+
def __init__(self, encoder_hidden, residual_channels, dilation):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
|
| 43 |
+
self.diffusion_projection = Linear(residual_channels, residual_channels)
|
| 44 |
+
self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1)
|
| 45 |
+
self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
|
| 46 |
+
|
| 47 |
+
def forward(self, x, conditioner, diffusion_step):
|
| 48 |
+
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
|
| 49 |
+
conditioner = self.conditioner_projection(conditioner)
|
| 50 |
+
y = x + diffusion_step
|
| 51 |
+
|
| 52 |
+
y = self.dilated_conv(y) + conditioner
|
| 53 |
+
|
| 54 |
+
gate, filter = torch.chunk(y, 2, dim=1)
|
| 55 |
+
y = torch.sigmoid(gate) * torch.tanh(filter)
|
| 56 |
+
|
| 57 |
+
y = self.output_projection(y)
|
| 58 |
+
residual, skip = torch.chunk(y, 2, dim=1)
|
| 59 |
+
return (x + residual) / sqrt(2.0), skip
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class DiffNet(nn.Module):
|
| 63 |
+
def __init__(self, hparams):
|
| 64 |
+
super().__init__()
|
| 65 |
+
in_dims = hparams['audio_num_mel_bins']
|
| 66 |
+
self.encoder_hidden = hparams['hidden_size']
|
| 67 |
+
self.residual_layers = hparams['residual_layers']
|
| 68 |
+
self.residual_channels = hparams['residual_channels']
|
| 69 |
+
self.dilation_cycle_length = hparams['dilation_cycle_length']
|
| 70 |
+
|
| 71 |
+
self.input_projection = Conv1d(in_dims, self.residual_channels, 1)
|
| 72 |
+
self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
|
| 73 |
+
dim = self.residual_channels
|
| 74 |
+
self.mlp = nn.Sequential(
|
| 75 |
+
nn.Linear(dim, dim * 4),
|
| 76 |
+
Mish(),
|
| 77 |
+
nn.Linear(dim * 4, dim)
|
| 78 |
+
)
|
| 79 |
+
self.residual_layers = nn.ModuleList([
|
| 80 |
+
ResidualBlock(self.encoder_hidden, self.residual_channels, 2 ** (i % self.dilation_cycle_length))
|
| 81 |
+
for i in range(self.residual_layers)
|
| 82 |
+
])
|
| 83 |
+
self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1)
|
| 84 |
+
self.output_projection = Conv1d(self.residual_channels, in_dims, 1)
|
| 85 |
+
nn.init.zeros_(self.output_projection.weight)
|
| 86 |
+
|
| 87 |
+
def forward(self, spec, diffusion_step, cond):
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
:param spec: [B, 1, M, T]
|
| 91 |
+
:param diffusion_step: [B, 1]
|
| 92 |
+
:param cond: [B, M, T]
|
| 93 |
+
:return:
|
| 94 |
+
"""
|
| 95 |
+
x = spec[:, 0]
|
| 96 |
+
x = self.input_projection(x) # x [B, residual_channel, T]
|
| 97 |
+
|
| 98 |
+
x = F.relu(x)
|
| 99 |
+
diffusion_step = self.diffusion_embedding(diffusion_step)
|
| 100 |
+
diffusion_step = self.mlp(diffusion_step)
|
| 101 |
+
skip = []
|
| 102 |
+
for layer_id, layer in enumerate(self.residual_layers):
|
| 103 |
+
x, skip_connection = layer(x, cond, diffusion_step)
|
| 104 |
+
skip.append(skip_connection)
|
| 105 |
+
|
| 106 |
+
x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
|
| 107 |
+
x = self.skip_projection(x)
|
| 108 |
+
x = F.relu(x)
|
| 109 |
+
x = self.output_projection(x) # [B, 80, T]
|
| 110 |
+
return x[:, None, :, :]
|
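A shape sanity check for `DiffNet`, runnable from the repo root; the hparams values mirror `checkpoints/diffsinger/config.yaml` above:

```python
import torch
from modules.tts.diffspeech.net import DiffNet

hparams = {
    'audio_num_mel_bins': 80,      # M: mel bins
    'hidden_size': 256,            # H: conditioner channels
    'residual_layers': 20,
    'residual_channels': 256,
    'dilation_cycle_length': 1,
}
net = DiffNet(hparams)
B, M, T = 2, 80, 100
spec = torch.randn(B, 1, M, T)                     # noisy mel x_t
step = torch.randint(0, 100, (B,))                 # diffusion step t, one per batch item
cond = torch.randn(B, hparams['hidden_size'], T)   # ret['decoder_inp'].transpose(1, 2)
out = net(spec, step, cond)
print(out.shape)  # torch.Size([2, 1, 80, 100])
```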
modules/tts/diffspeech/shallow_diffusion_tts.py
ADDED
```python
import math
import random
from functools import partial
from inspect import isfunction
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm

from modules.tts.fs2_orig import FastSpeech2Orig
from modules.tts.diffspeech.net import DiffNet
from modules.tts.commons.align_ops import expand_states


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# gaussian diffusion trainer class

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps, max_beta=0.01):
    """
    linear schedule
    """
    betas = np.linspace(1e-4, max_beta, timesteps)
    return betas


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, a_min=0, a_max=0.999)


beta_schedule = {
    "cosine": cosine_beta_schedule,
    "linear": linear_beta_schedule,
}


DIFF_DECODERS = {
    'wavenet': lambda hp: DiffNet(hp),
}


class AuxModel(FastSpeech2Orig):
    def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
                f0=None, uv=None, energy=None, infer=False, **kwargs):
        ret = {}
        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]
        style_embed = self.forward_style_embed(spk_embed, spk_id)

        # add dur
        dur_inp = (encoder_out + style_embed) * src_nonpadding
        mel2ph = self.forward_dur(dur_inp, mel2ph, txt_tokens, ret)
        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
        decoder_inp = decoder_inp_ = expand_states(encoder_out, mel2ph)

        # add pitch embed
        if self.hparams['use_pitch_embed']:
            pitch_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
            decoder_inp = decoder_inp + self.forward_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out)

        # add energy embed
        if self.hparams['use_energy_embed']:
            energy_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
            decoder_inp = decoder_inp + self.forward_energy(energy_inp, energy, ret)

        # decoder input
        ret['decoder_inp'] = decoder_inp = (decoder_inp + style_embed) * tgt_nonpadding
        if self.hparams['dec_inp_add_noise']:
            B, T, _ = decoder_inp.shape
            z = kwargs.get('adv_z', torch.randn([B, T, self.z_channels])).to(decoder_inp.device)
            ret['adv_z'] = z
            decoder_inp = torch.cat([decoder_inp, z], -1)
            decoder_inp = self.dec_inp_noise_proj(decoder_inp) * tgt_nonpadding
        if kwargs['skip_decoder']:
            return ret
        ret['mel_out'] = self.forward_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
        return ret


class GaussianDiffusion(nn.Module):
    def __init__(self, dict_size, hparams, out_dims=None):
        super().__init__()
        self.hparams = hparams
        out_dims = hparams['audio_num_mel_bins']
        denoise_fn = DIFF_DECODERS[hparams['diff_decoder_type']](hparams)
        timesteps = hparams['timesteps']
        K_step = hparams['K_step']
        loss_type = hparams['diff_loss_type']
        spec_min = hparams['spec_min']
        spec_max = hparams['spec_max']

        self.denoise_fn = denoise_fn
        self.fs2 = AuxModel(dict_size, hparams)
        self.mel_bins = out_dims

        if hparams['schedule_type'] == 'linear':
            betas = linear_beta_schedule(timesteps, hparams['max_beta'])
        else:
            betas = cosine_beta_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.K_step = K_step
        self.loss_type = loss_type

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
        self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, cond, clip_denoised: bool):
        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)

        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, cond)

        if self.loss_type == 'l1':
            if nonpadding is not None:
                loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
            else:
                # print('are you sure w/o nonpadding?')
                loss = (noise - x_recon).abs().mean()

        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
        b, *_, device = *txt_tokens.shape, txt_tokens.device
        ret = self.fs2(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                       f0=f0, uv=uv, energy=energy, infer=infer, skip_decoder=(not infer), **kwargs)
        # (txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
        #  skip_decoder=(not infer), infer=infer, **kwargs)
        cond = ret['decoder_inp'].transpose(1, 2)

        if not infer:
            # training: diffuse the GT mel to a random step t < K_step and predict the added noise
            t = torch.randint(0, self.K_step, (b,), device=device).long()
            x = ref_mels
            x = self.norm_spec(x)
            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
            ret['diff_loss'] = self.p_losses(x, t, cond)
            # nonpadding = (mel2ph != 0).float()
            # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
            ret['mel_out'] = None
        else:
            # inference (shallow diffusion): start from the aux decoder's mel diffused
            # to step K_step, then run the reverse process down to t = 0
            ret['fs2_mel'] = ret['mel_out']
            fs2_mels = ret['mel_out']
            t = self.K_step
            fs2_mels = self.norm_spec(fs2_mels)
            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]

            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
            if self.hparams.get('gaussian_start') is not None and self.hparams['gaussian_start']:
                print('===> gaussian start.')
                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
                x = torch.randn(shape, device=device)
            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
            x = x[:, 0].transpose(1, 2)
            ret['mel_out'] = self.denorm_spec(x)

        return ret

    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)

    def out2mel(self, x):
        return x
```
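For reference, the closed forms these buffers implement: `predict_start_from_noise` recovers the clean-mel estimate $\hat{x}_0$ from the predicted noise $\epsilon_\theta$, and `q_posterior` yields the mean and variance of $q(x_{t-1} \mid x_t, \hat{x}_0)$ that `p_sample` draws from:

$$
\hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\,x_t - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\;\epsilon_\theta(x_t, t, c),
$$

$$
\mu_t(x_t, \hat{x}_0) = \frac{\beta_t \sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_t}\,\hat{x}_0
+ \frac{(1 - \bar{\alpha}_{t-1})\sqrt{\alpha_t}}{1 - \bar{\alpha}_t}\,x_t,
\qquad
\tilde{\beta}_t = \frac{(1 - \bar{\alpha}_{t-1})\,\beta_t}{1 - \bar{\alpha}_t}.
$$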
tasks/tts/diffspeech.py
ADDED
```python
import torch

from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
from tasks.tts.fs2_orig import FastSpeech2OrigTask

import utils
from utils.commons.hparams import hparams
from utils.commons.ckpt_utils import load_ckpt
from utils.audio.pitch.utils import denorm_f0


class DiffSpeechTask(FastSpeech2OrigTask):
    def build_tts_model(self):
        # get min and max
        # import torch
        # from tqdm import tqdm
        # v_min = torch.ones([80]) * 100
        # v_max = torch.ones([80]) * -100
        # for i, ds in enumerate(tqdm(self.dataset_cls('train'))):
        #     v_max = torch.max(torch.max(ds['mel'].reshape(-1, 80), 0)[0], v_max)
        #     v_min = torch.min(torch.min(ds['mel'].reshape(-1, 80), 0)[0], v_min)
        #     if i % 100 == 0:
        #         print(i, v_min, v_max)
        # print('final', v_min, v_max)
        dict_size = len(self.token_encoder)
        self.model = GaussianDiffusion(dict_size, hparams)
        if hparams['fs2_ckpt'] != '':
            load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
            # freeze the pre-trained aux model; only its variance predictors stay trainable
            for k, v in self.model.fs2.named_parameters():
                if 'predictor' not in k:
                    v.requires_grad = False
            # or freeze everything:
            # for k, v in self.model.fs2.named_parameters():
            #     v.requires_grad = False

    def build_optimizer(self, model):
        self.optimizer = optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        return optimizer

    def build_scheduler(self, optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

    def run_model(self, sample, infer=False, *args, **kwargs):
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')
        if not infer:
            target = sample['mels']  # [B, T_s, 80]
            mel2ph = sample['mel2ph']  # [B, T_s]
            f0 = sample.get('f0')
            uv = sample.get('uv')
            output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                                ref_mels=target, f0=f0, uv=uv, infer=False)
            losses = {}
            if 'diff_loss' in output:
                losses['mel'] = output['diff_loss']
            self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
            if hparams['use_pitch_embed']:
                self.add_pitch_loss(output, sample, losses)
            return losses, output
        else:
            use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
            use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
            mel2ph, uv, f0 = None, None, None
            if use_gt_dur:
                mel2ph = sample['mel2ph']
            if use_gt_f0:
                f0 = sample['f0']
                uv = sample['uv']
            output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                                ref_mels=None, f0=f0, uv=uv, infer=True)
            return output

    def save_valid_result(self, sample, batch_idx, model_out):
        sr = hparams['audio_sample_rate']
        f0_gt = None
        # mel_out = model_out['mel_out']
        if sample.get('f0') is not None:
            f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
        # self.plot_mel(batch_idx, sample['mels'], mel_out, f0s=f0_gt)
        if self.global_step > 0:
            # wav_pred = self.vocoder.spec2wav(mel_out[0].cpu(), f0=f0_gt)
            # self.logger.add_audio(f'wav_val_{batch_idx}', wav_pred, self.global_step, sr)
            # with gt duration
            model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
            dur_info = self.get_plot_dur_info(sample, model_out)
            del dur_info['dur_pred']
            wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
            self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0], f'diffmel_gdur_{batch_idx}',
                          dur_info=dur_info, f0s=f0_gt)
            self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'][0], f'fs2mel_gdur_{batch_idx}',
                          dur_info=dur_info, f0s=f0_gt)  # gt mel vs. fs2 mel

            # with pred duration
            if not hparams['use_gt_dur']:
                model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
                dur_info = self.get_plot_dur_info(sample, model_out)
                self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0], f'mel_pdur_{batch_idx}',
                              dur_info=dur_info, f0s=f0_gt)
                wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
                self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)
        # gt wav
        if self.global_step <= hparams['valid_infer_interval']:
            mel_gt = sample['mels'][0].cpu()
            wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
            self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)
```
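Because `build_tts_model` freezes every `fs2` parameter whose name lacks `'predictor'`, only the variance predictors and the diffusion denoiser receive gradients, and `build_optimizer` filters on `requires_grad` accordingly. A small sketch to verify the split; `summarize_trainable` is a hypothetical helper, not part of the repo:

```python
def summarize_trainable(model):
    """Print trainable vs. frozen parameter counts (illustrative helper)."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f'trainable: {trainable:,} | frozen: {frozen:,}')

# e.g. summarize_trainable(task.model) after build_tts_model() has run
```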