Commit 40e984c · pndm codes
ddd committed
1 parent: c4e83e4
Files changed:
- .gitattributes +1 -0
- docs/README-SVS-opencpop-cascade.md +3 -3
- docs/README-SVS-opencpop-e2e.md +2 -1
- docs/README-SVS-popcs.md +1 -1
- docs/README-SVS.md +41 -9
- docs/README-TTS.md +7 -1
- inference/svs/base_svs_infer.py +1 -1
- inference/svs/ds_cascade.py +2 -0
- inference/svs/ds_e2e.py +2 -2
- inference/svs/gradio/infer.py +1 -1
- modules/diffsinger_midi/fs2.py +110 -0
- modules/hifigan/hifigan.py +365 -365
- modules/hifigan/mel_utils.py +80 -80
- modules/parallel_wavegan/models/parallel_wavegan.py +434 -434
- usr/configs/midi/cascade/opencs/ds60_rel.yaml +2 -1
- usr/diff/shallow_diffusion_tts.py +324 -273
- utils/hparams.py +36 -44
.gitattributes
CHANGED
@@ -30,3 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 model_ckpt_steps* filter=lfs diff=lfs merge=lfs -text
+checkpoints/0831_opencpop_ds1000 filter=lfs diff=lfs merge=lfs -text
docs/README-SVS-opencpop-cascade.md
CHANGED
@@ -3,7 +3,7 @@
 [](https://github.com/MoonInTheRiver/DiffSinger)
 [](https://github.com/MoonInTheRiver/DiffSinger/releases)
 
-## DiffSinger (MIDI version)
+## DiffSinger (MIDI SVS | A version)
 ### 0. Data Acquirement
 For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you the access to Opencpop.
 
@@ -67,7 +67,7 @@ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/ope
 
 Remember to adjust the "fs2_ckpt" parameter in `usr/configs/midi/cascade/opencs/ds60_rel.yaml` to fit your path.
 
-### 3. Inference
+### 3. Inference from packed test set
 ```sh
 CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME --reset --infer
 ```
@@ -82,7 +82,7 @@ Remember to put the pre-trained models in `checkpoints` directory.
 
 ### 4. Inference from raw inputs
 ```sh
-python inference/svs/
+python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME
 ```
 Raw inputs:
 ```
docs/README-SVS-opencpop-e2e.md
CHANGED
@@ -2,13 +2,14 @@
 [](https://arxiv.org/abs/2105.02446)
 [](https://github.com/MoonInTheRiver/DiffSinger)
 [](https://github.com/MoonInTheRiver/DiffSinger/releases)
+ | [Interactive🤗 SVS](https://huggingface.co/spaces/Silentlin/DiffSinger)
 
 Substantial update: We 1) **abandon** the explicit prediction of the F0 curve; 2) increase the receptive field of the denoiser; 3) make the linguistic encoder more robust.
 **By doing so, 1) the synthesized recordings are more natural in terms of pitch; 2) the pipeline is simpler.**
 
 In short: we hand the dynamics of the F0 curve over to the generative model to capture, instead of constraining log-domain F0 with an MSE loss as before.
 
-## DiffSinger (MIDI version)
+## DiffSinger (MIDI SVS | B version)
 ### 0. Data Acquirement
 For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you the access to Opencpop.
 
docs/README-SVS-popcs.md
CHANGED
@@ -54,7 +54,7 @@ Remember to put the pre-trained models in `checkpoints` directory.
 *Note that:*
 
 - *the original PWG version vocoder in the paper we used has been put into commercial use, so we provide this HifiGAN version vocoder as a substitute.*
-- *we assume the ground-truth F0 to be given as the pitch information following [1][2][3]. If you want to conduct experiments on MIDI data, you need an external F0 predictor (like [MIDI-
+- *we assume the ground-truth F0 to be given as the pitch information following [1][2][3]. If you want to conduct experiments on MIDI data, you need an external F0 predictor (like [MIDI-A-version](README-SVS-opencpop-cascade.md)) or a joint prediction with spectrograms (like [MIDI-B-version](README-SVS-opencpop-e2e.md)).*
 
 [1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
 
docs/README-SVS.md
CHANGED
@@ -1,7 +1,13 @@
-
+# DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
+[](https://arxiv.org/abs/2105.02446)
+[](https://github.com/MoonInTheRiver/DiffSinger)
+[](https://github.com/MoonInTheRiver/DiffSinger/releases)
+ | [Interactive🤗 SVS](https://huggingface.co/spaces/Silentlin/DiffSinger)
+
+## DiffSinger (SVS)
 
 ### PART1. [Run DiffSinger on PopCS](README-SVS-popcs.md)
-In
+In PART1, we only focus on spectrum modeling (acoustic model) and assume the ground-truth (GT) F0 to be given as the pitch information following these papers [1][2][3]. If you want to conduct experiments with F0 prediction, please move to PART2.
 
 Thus, the pipeline of this part can be summarized as:
 
@@ -18,13 +24,16 @@ Thus, the pipeline of this part can be summarized as:
 
 [3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
 
+Click here for detailed instructions: [link](README-SVS-popcs.md).
+
+
 ### PART2. [Run DiffSinger on Opencpop](README-SVS-opencpop-cascade.md)
-Thanks [Opencpop team](https://wenet.org.cn/opencpop/) for releasing their SVS dataset with MIDI label, **Jan.20, 2022
+Thanks [Opencpop team](https://wenet.org.cn/opencpop/) for releasing their SVS dataset with MIDI label, **Jan.20, 2022** (after we published our paper).
 
 Since there are elaborately annotated MIDI labels, we are able to supplement the pipeline in PART 1 by adding a naive melody frontend.
 
-#### 2.
-Thus, the pipeline of [
+#### 2.A
+Thus, the pipeline of [2.A](README-SVS-opencpop-cascade.md) can be summarized as:
 
 ```
 [lyrics] + [MIDI] -> [linguistic representation (with MIDI information)] + [predicted F0] + [predicted phoneme duration] (Melody frontend)
@@ -32,13 +41,36 @@ Thus, the pipeline of [this part](README-SVS-opencpop-cascade.md) can be summari
 [mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
 ```
 
-
-
+Click here for detailed instructions: [link](README-SVS-opencpop-cascade.md).
+
+#### 2.B
+In 2.A, we find that if we predict F0 explicitly in the melody frontend, there will be many bad cases of uv/v (unvoiced/voiced) prediction. Thus, we abandon the explicit prediction of the F0 curve in the melody frontend and predict it jointly with the spectrograms.
 
-Thus, the pipeline of [
+Thus, the pipeline of [2.B](README-SVS-opencpop-e2e.md) can be summarized as:
 ```
 [lyrics] + [MIDI] -> [linguistic representation] + [predicted phoneme duration] (Melody frontend)
 [linguistic representation (with MIDI information)] + [predicted phoneme duration] -> [mel-spectrogram] (Acoustic model)
 [mel-spectrogram] -> [predicted F0] (Pitch extractor)
 [mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
-```
+```
+
+Click here for detailed instructions: [link](README-SVS-opencpop-e2e.md).
+
+### FAQ
+Q1: Why do I need F0 in vocoders?
+
+A1: See the vocoder parts in HiFiSinger, DiffSinger or SingGAN. This is a common practice now.
+
+Q2: Why not run the MIDI version of SVS on the PopCS dataset? Or, why not release MIDI labels for the PopCS dataset?
+
+A2: Our laboratory has no funds to label the PopCS dataset. But there are funds for labeling another singing dataset, which is coming soon.
+
+Q3: Why " 'HifiGAN' object has no attribute 'model' "?
+
+A3: Please put the pretrained vocoders in your `checkpoints` directory.
+
+Q4: How can I check whether I use GT information or predicted information during inference from the packed test set?
+
+A4: Please see the code [here](https://github.com/MoonInTheRiver/DiffSinger/blob/55e2f46068af6e69940a9f8f02d306c24a940cab/tasks/tts/fs2.py#L343).
+
+...
docs/README-TTS.md
CHANGED
@@ -1,4 +1,10 @@
-
+# DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
+[](https://arxiv.org/abs/2105.02446)
+[](https://github.com/MoonInTheRiver/DiffSinger)
+[](https://github.com/MoonInTheRiver/DiffSinger/releases)
+ | [Interactive🤗 TTS](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
+
+## DiffSpeech (TTS)
 ### 1. Preparation
 
 #### Data Preparation
inference/svs/base_svs_infer.py
CHANGED
@@ -142,7 +142,7 @@ class BaseSVSInfer:
         ph_seq = inp['ph_seq']
         note_lst = inp['note_seq'].split()
         midi_dur_lst = inp['note_dur_seq'].split()
-        is_slur = inp['is_slur_seq'].split()
+        is_slur = [float(x) for x in inp['is_slur_seq'].split()]
         print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
         if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
             print('Pass word-notes check.')
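
The `is_slur` change above swaps a list of strings for a list of floats. A minimal sketch of why that matters, assuming the space-separated '0'/'1' flag format this file parses (downstream code turns the flags into a tensor for the slur embedding, which a list of strings cannot become):

```python
import torch

inp = {'is_slur_seq': '0 0 1 0'}  # format assumed from this file's parsing

# Before this commit: a list of strings.
slur_strings = inp['is_slur_seq'].split()                 # ['0', '0', '1', '0']

# After this commit: numeric values that can be tensorized.
is_slur = [float(x) for x in inp['is_slur_seq'].split()]  # [0.0, 0.0, 1.0, 0.0]
slur_flags = torch.LongTensor([int(x) for x in is_slur])
print(slur_flags)  # tensor([0, 0, 1, 0])
```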
inference/svs/ds_cascade.py
CHANGED
@@ -52,3 +52,5 @@ if __name__ == '__main__':
         'input_type': 'phoneme'
     }  # input like Opencpop dataset.
     DiffSingerCascadeInfer.example_run(inp)
+
+# # CUDA_VISIBLE_DEVICES=1 python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
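
For reference, a hedged sketch of the phoneme-level input that `example_run` consumes here. The field names come from `base_svs_infer.py` in this same diff; the values are shortened placeholders ('...' marks truncation), not a complete working input from the repository:

```python
# Field names taken from base_svs_infer.py above; values are illustrative only.
inp = {
    'text': '小酒窝长睫毛AP是你最美的记号',
    'ph_seq': 'x iao j iu w o ...',             # phoneme sequence
    'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 ...',  # one note name per phoneme
    'note_dur_seq': '0.407140 0.407140 ...',    # note duration (s) per phoneme
    'is_slur_seq': '0 0 0 ...',                 # slur flag per phoneme
    'input_type': 'phoneme',
}
# DiffSingerCascadeInfer.example_run(inp)
```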
inference/svs/ds_e2e.py
CHANGED
@@ -53,7 +53,7 @@ if __name__ == '__main__':
         'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
         'input_type': 'word'
     }  # user input: Chinese characters
-
+    inp = {
         'text': '小酒窝长睫毛AP是你最美的记号',
         'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
         'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
@@ -64,4 +64,4 @@
     DiffSingerE2EInfer.example_run(inp)
 
 
-# python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
+# CUDA_VISIBLE_DEVICES=3 python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
inference/svs/gradio/infer.py
CHANGED
@@ -88,4 +88,4 @@ if __name__ == '__main__':
 
 # python inference/svs/gradio/infer.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
 # python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
-# CUDA_VISIBLE_DEVICES=3 python inference/svs/gradio/infer.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
+# CUDA_VISIBLE_DEVICES=3 python inference/svs/gradio/infer.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
modules/diffsinger_midi/fs2.py
CHANGED
@@ -116,3 +116,113 @@ class FastSpeech2MIDI(FastSpeech2):
         ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
 
         return ret
+
+    def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
+        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
+        pitch_padding = mel2ph == 0
+        if hparams['pitch_ar']:
+            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
+            if f0 is None:
+                f0 = pitch_pred[:, :, 0]
+        else:
+            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
+            if f0 is None:
+                f0 = pitch_pred[:, :, 0]
+            if hparams['use_uv'] and uv is None:
+                uv = pitch_pred[:, :, 1] > 0
+
+        # here f0_denorm for pitch prediction
+        ret['f0_denorm'] = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
+
+        # here f0_denorm for mel prediction
+        if self.training:
+            mask = torch.full(uv.shape, hparams.get('mask_uv_prob', 0.)).to(f0.device)
+            masked_uv = torch.bernoulli(mask).bool().to(f0.device)  # emit a random uv flag with probability mask_uv_prob
+            uv_masked = uv.bool() | masked_uv
+            # print((uv.float()-uv_masked.float()).mean(dim=1))
+            f0_denorm = denorm_f0(f0, uv_masked, hparams, pitch_padding=pitch_padding)
+        else:
+            f0_denorm = ret['f0_denorm']
+
+        if pitch_padding is not None:
+            f0[pitch_padding] = 0
+
+        pitch = f0_to_coarse(f0_denorm)  # start from 0
+        pitch_embed = self.pitch_embed(pitch)
+        return pitch_embed
+
+
+class FastSpeech2MIDIMasked(FastSpeech2MIDI):
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
+                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
+        ret = {}
+
+        midi_dur_embedding, slur_embedding = 0, 0
+        if kwargs.get('midi_dur') is not None:
+            midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None])  # [B, T, 1] -> [B, T, H]
+        if kwargs.get('is_slur') is not None:
+            slur_embedding = self.is_slur_embed(kwargs['is_slur'])
+        encoder_out = self.encoder(txt_tokens, 0, midi_dur_embedding, slur_embedding)  # [B, T, C]
+        src_nonpadding = (txt_tokens > 0).float()[:, :, None]
+
+        # add ref style embed
+        # Not implemented
+        # variance encoder
+        var_embed = 0
+
+        # encoder_out_dur denotes encoder outputs for duration predictor
+        # in speech adaptation, duration predictor use old speaker embedding
+        if hparams['use_spk_embed']:
+            spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
+        elif hparams['use_spk_id']:
+            spk_embed_id = spk_embed
+            if spk_embed_dur_id is None:
+                spk_embed_dur_id = spk_embed_id
+            if spk_embed_f0_id is None:
+                spk_embed_f0_id = spk_embed_id
+            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
+            spk_embed_dur = spk_embed_f0 = spk_embed
+            if hparams['use_split_spk_id']:
+                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
+                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
+        else:
+            spk_embed_dur = spk_embed_f0 = spk_embed = 0
+
+        # add dur
+        dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
+
+        mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
+
+        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
+
+        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
+        decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]
+
+        # expanded midi
+        midi_embedding = self.midi_embed(kwargs['pitch_midi'])
+        midi_embedding = F.pad(midi_embedding, [0, 0, 1, 0])
+        midi_embedding = torch.gather(midi_embedding, 1, mel2ph_)
+        print(midi_embedding.shape, decoder_inp.shape)
+        midi_mask = torch.full(midi_embedding.shape, hparams.get('mask_uv_prob', 0.)).to(midi_embedding.device)
+        midi_mask = 1 - torch.bernoulli(midi_mask).bool().to(midi_embedding.device)  # randomly drop MIDI info with probability mask_uv_prob
+
+        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
+
+        decoder_inp += midi_embedding
+        decoder_inp_origin = decoder_inp
+        # add pitch and energy embed
+        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
+        if hparams['use_pitch_embed']:
+            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
+            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
+        if hparams['use_energy_embed']:
+            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
+
+        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
+
+        if skip_decoder:
+            return ret
+        ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
+
+        return ret
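
The training-time uv masking in `add_pitch` is worth isolating. A self-contained sketch of the same trick (the shapes and the 0.25 probability here are illustrative; the diff reads the probability from `hparams['mask_uv_prob']`):

```python
import torch

torch.manual_seed(0)

# Toy unvoiced flags for 1 utterance x 8 frames (1 = unvoiced).
uv = torch.tensor([[0, 0, 1, 0, 0, 1, 0, 0]])

# Draw extra "unvoiced" flags with probability p, as add_pitch does with
# hparams['mask_uv_prob']; 0.25 is an illustrative value.
p = 0.25
masked_uv = torch.bernoulli(torch.full(uv.shape, p)).bool()

# The union randomly flips some voiced frames to unvoiced during training, so
# the mel decoder cannot depend on perfectly clean uv information.
uv_masked = uv.bool() | masked_uv
print(uv.bool())
print(uv_masked)
```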
modules/hifigan/hifigan.py
CHANGED
@@ -1,365 +1,365 @@ (every line rewritten; old and new contents are otherwise identical, so the file is shown once)
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
from modules.parallel_wavegan.models.source import SourceModuleHnNSF
import numpy as np

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels, out_channels, bias):
        """Initialize 1x1 Conv1d module."""
        super(Conv1d1x1, self).__init__(in_channels, out_channels,
                                        kernel_size=1, padding=0,
                                        dilation=1, bias=bias)


class HifiGanGenerator(torch.nn.Module):
    def __init__(self, h, c_out=1):
        super(HifiGanGenerator, self).__init__()
        self.h = h
        self.num_kernels = len(h['resblock_kernel_sizes'])
        self.num_upsamples = len(h['upsample_rates'])

        if h['use_pitch_embed']:
            self.harmonic_num = 8
            self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
            self.m_source = SourceModuleHnNSF(
                sampling_rate=h['audio_sample_rate'],
                harmonic_num=self.harmonic_num)
            self.noise_convs = nn.ModuleList()
        self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
        resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
            c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
            if h['use_pitch_embed']:
                if i + 1 < len(h['upsample_rates']):
                    stride_f0 = np.prod(h['upsample_rates'][i + 1:])
                    self.noise_convs.append(Conv1d(
                        1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
                else:
                    self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h['upsample_initial_channel'] // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, f0=None):
        if f0 is not None:
            # harmonic-source signal, noise-source signal, uv flag
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            if f0 is not None:
                x_source = self.noise_convs[i](har_source)
                x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
        super(DiscriminatorP, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            from utils.hparams import hparams
            t = hparams['hop_size']
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2

        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x, mel):
        fmap = []
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_cond=False, c_in=1):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
        ])

    def forward(self, y, y_hat, mel=None):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y, mel)
            y_d_g, fmap_g = d(y_hat, mel)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
        super(DiscriminatorS, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            t = np.prod(upsample_rates)
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x, mel):
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self, use_cond=False, c_in=1):
        super(MultiScaleDiscriminator, self).__init__()
        from utils.hparams import hparams
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 16],
                           c_in=c_in),
            DiscriminatorS(use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 32],
                           c_in=c_in),
            DiscriminatorS(use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 64],
                           c_in=c_in),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=1),
            AvgPool1d(4, 2, padding=1)
        ])

    def forward(self, y, y_hat, mel=None):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y, mel)
            y_d_g, fmap_g = d(y_hat, mel)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    r_losses = 0
    g_losses = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        r_losses += r_loss
        g_losses += g_loss
    r_losses = r_losses / len(disc_real_outputs)
    g_losses = g_losses / len(disc_real_outputs)
    return r_losses, g_losses


def cond_discriminator_loss(outputs):
    loss = 0
    for dg in outputs:
        g_loss = torch.mean(dg ** 2)
        loss += g_loss
    loss = loss / len(outputs)
    return loss


def generator_loss(disc_outputs):
    loss = 0
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        loss += l
    loss = loss / len(disc_outputs)
    return loss
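
For orientation, a minimal sketch of running this generator stand-alone. The config keys are the ones `HifiGanGenerator` reads above; the values are illustrative rather than the repository's shipped config, and `use_pitch_embed` is kept off so the NSF harmonic-source path stays inactive:

```python
import torch
from modules.hifigan.hifigan import HifiGanGenerator

h = {  # illustrative values; keys mirror what HifiGanGenerator reads
    'resblock': '1',
    'resblock_kernel_sizes': [3, 7, 11],
    'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    'upsample_rates': [8, 8, 2, 2],
    'upsample_kernel_sizes': [16, 16, 4, 4],
    'upsample_initial_channel': 128,
    'use_pitch_embed': False,  # keep the harmonic source path disabled
    'audio_sample_rate': 24000,
}

model = HifiGanGenerator(h)
mel = torch.randn(1, 80, 100)  # [B, n_mels, T] mel frames
with torch.no_grad():
    wav = model(mel)           # upsampled by prod(upsample_rates) = 256x
print(wav.shape)               # torch.Size([1, 1, 25600])
```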
modules/hifigan/mel_utils.py
CHANGED
|
@@ -1,80 +1,80 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import torch
|
| 3 |
-
import torch.utils.data
|
-from librosa.filters import mel as librosa_mel_fn
-from scipy.io.wavfile import read
-
-MAX_WAV_VALUE = 32768.0
-
-
-def load_wav(full_path):
-    sampling_rate, data = read(full_path)
-    return data, sampling_rate
-
-
-def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
-
-
-def dynamic_range_decompression(x, C=1):
-    return np.exp(x) / C
-
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
-    return torch.exp(x) / C
-
-
-def spectral_normalize_torch(magnitudes):
-    output = dynamic_range_compression_torch(magnitudes)
-    return output
-
-
-def spectral_de_normalize_torch(magnitudes):
-    output = dynamic_range_decompression_torch(magnitudes)
-    return output
-
-
-mel_basis = {}
-hann_window = {}
-
-
-def mel_spectrogram(y, hparams, center=False, complex=False):
-    # hop_size: 512  # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
-    # win_size: 2048  # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
-    # fmin: 55  # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
-    # fmax: 10000  # To be increased/reduced depending on data.
-    # fft_size: 2048  # Extra window size is filled with 0 paddings to match this parameter
-    # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
-    n_fft = hparams['fft_size']
-    num_mels = hparams['audio_num_mel_bins']
-    sampling_rate = hparams['audio_sample_rate']
-    hop_size = hparams['hop_size']
-    win_size = hparams['win_size']
-    fmin = hparams['fmin']
-    fmax = hparams['fmax']
-    y = y.clamp(min=-1., max=1.)
-    global mel_basis, hann_window
-    if fmax not in mel_basis:
-        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
-        mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
-        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
-
-    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
-                                mode='reflect')
-    y = y.squeeze(1)
-
-    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
-                      center=center, pad_mode='reflect', normalized=False, onesided=True)
-
-    if not complex:
-        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
-        spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
-        spec = spectral_normalize_torch(spec)
-    else:
-        B, C, T, _ = spec.shape
-        spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
-    return spec
+(lines 1-80 re-added verbatim: the new file is textually identical to the removed block above, so this is a whole-file rewrite with no content change)
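A usage sketch for mel_spectrogram (an illustration, not part of the commit). The hparams keys below mirror the defaults commented inside the function, and the call assumes the older librosa/PyTorch APIs this code targets, where librosa's mel filterbank takes positional arguments and torch.stft still returns stacked real/imaginary parts. One caveat worth knowing: mel_basis entries are stored under f'{fmax}_{device}' keys while the existence check tests fmax alone, so the filterbank is silently rebuilt on every call.

import torch
from modules.hifigan.mel_utils import mel_spectrogram

# Hypothetical hparams; real values come from the experiment's YAML config.
hparams = {
    'fft_size': 2048, 'audio_num_mel_bins': 80, 'audio_sample_rate': 22050,
    'hop_size': 512, 'win_size': 2048, 'fmin': 55, 'fmax': 10000,
}
wav = torch.randn(1, 22050)           # one second of fake audio, [B, T] in [-1, 1]
mel = mel_spectrogram(wav, hparams)   # -> [B, num_mels, frames]
print(mel.shape)                      # torch.Size([1, 80, 43])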
modules/parallel_wavegan/models/parallel_wavegan.py
CHANGED
@@ -1,434 +1,434 @@
(The file is deleted and re-added with identical content, another whole-file rewrite with no textual change, so the 434 lines are shown once below.)
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Parallel WaveGAN Modules."""

import logging
import math

import torch
from torch import nn

from modules.parallel_wavegan.layers import Conv1d
from modules.parallel_wavegan.layers import Conv1d1x1
from modules.parallel_wavegan.layers import ResidualBlock
from modules.parallel_wavegan.layers import upsample
from modules.parallel_wavegan import models


class ParallelWaveGANGenerator(torch.nn.Module):
    """Parallel WaveGAN Generator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 aux_context_window=2,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 upsample_conditional_features=True,
                 upsample_net="ConvInUpsampleNetwork",
                 upsample_params={"upsample_scales": [4, 4, 4, 4]},
                 use_pitch_embed=False,
                 ):
        """Initialize Parallel WaveGAN Generator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for auxiliary feature conv.
            aux_context_window (int): Context window size for auxiliary feature.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            upsample_conditional_features (bool): Whether to use upsampling network.
            upsample_net (str): Upsampling network architecture.
            upsample_params (dict): Upsampling network parameters.

        """
        super(ParallelWaveGANGenerator, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)

        # define conv + upsampling network
        if upsample_conditional_features:
            upsample_params.update({
                "use_causal_conv": use_causal_conv,
            })
            if upsample_net == "MelGANGenerator":
                assert aux_context_window == 0
                upsample_params.update({
                    "use_weight_norm": False,  # not to apply twice
                    "use_final_nonlinear_activation": False,
                })
                self.upsample_net = getattr(models, upsample_net)(**upsample_params)
            else:
                if upsample_net == "ConvInUpsampleNetwork":
                    upsample_params.update({
                        "aux_channels": aux_channels,
                        "aux_context_window": aux_context_window,
                    })
                self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
        else:
            self.upsample_net = None

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        self.use_pitch_embed = use_pitch_embed
        if use_pitch_embed:
            self.pitch_embed = nn.Embedding(300, aux_channels, 0)
            self.c_proj = nn.Linear(2 * aux_channels, aux_channels)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, c=None, pitch=None, **kwargs):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, C_in, T).
            c (Tensor): Local conditioning auxiliary features (B, C ,T').
            pitch (Tensor): Local conditioning pitch (B, T').

        Returns:
            Tensor: Output tensor (B, C_out, T)

        """
        # perform upsampling
        if c is not None and self.upsample_net is not None:
            if self.use_pitch_embed:
                p = self.pitch_embed(pitch)
                c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
            c = self.upsample_net(c)
            assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size,
                                  dilation=lambda x: 2 ** x):
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size."""
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)


class ParallelWaveGANDiscriminator(torch.nn.Module):
    """Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=10,
                 conv_channels=64,
                 dilation_factor=1,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 bias=True,
                 use_weight_norm=True,
                 ):
        """Initialize Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of conv layers (must be odd).
            layers (int): Number of conv layers.
            conv_channels (int): Number of channels in conv layers.
            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
                the dilation will be 2, 4, 8, ..., and so on.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super(ParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
        assert dilation_factor > 0, "Dilation factor must be > 0."
        self.conv_layers = torch.nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor ** i
                conv_in_channels = conv_channels
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                Conv1d(conv_in_channels, conv_channels,
                       kernel_size=kernel_size, padding=padding,
                       dilation=dilation, bias=bias),
                getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
            ]
            self.conv_layers += conv_layer
        padding = (kernel_size - 1) // 2
        last_conv_layer = Conv1d(
            conv_in_channels, out_channels,
            kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T)

        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)


class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
    """Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 ):
        """Initialize Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            nonlinear_activation_params (dict): Nonlinear function parameters

        """
        super(ResidualParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = torch.nn.Sequential(
            Conv1d1x1(in_channels, residual_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
        )

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T)

        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
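A minimal shape-check sketch for the three modules above (an illustration, not code from this commit). aux_context_window=0 is chosen here so that, with the default 4*4*4*4 = 256x upsampling, the conditioning length maps exactly onto the noise length and the assert in ParallelWaveGANGenerator.forward is satisfied; both discriminators are length-preserving, so they return one score per waveform sample.

import torch
from modules.parallel_wavegan.models.parallel_wavegan import (
    ParallelWaveGANGenerator, ParallelWaveGANDiscriminator,
    ResidualParallelWaveGANDiscriminator)

g = ParallelWaveGANGenerator(aux_context_window=0)  # 256x conditioning upsampling
c = torch.randn(1, 80, 40)         # 40 mel frames of conditioning, [B, C, T']
z = torch.randn(1, 1, 40 * 256)    # one noise sample per output waveform sample
wav = g(z, c)                      # -> [1, 1, 10240]
print(wav.shape, g.receptive_field_size)

d = ParallelWaveGANDiscriminator()           # plain dilated-conv stack
rd = ResidualParallelWaveGANDiscriminator()  # WaveNet-style residual stack
print(d(wav).shape, rd(wav).shape)           # both [1, 1, 10240]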
usr/configs/midi/cascade/opencs/ds60_rel.yaml
CHANGED
@@ -24,10 +24,11 @@ fs2_ckpt: 'checkpoints/0302_opencpop_fs_midi/model_ckpt_steps_160000.ckpt' #
 task_cls: usr.diffsinger_task.DiffSingerMIDITask

 K_step: 60
-max_tokens:
+max_tokens: 36000
 predictor_layers: 5
 dilation_cycle_length: 4 # *
 rel_pos: true
 dur_predictor_layers: 5 # *
 max_updates: 160000
 gaussian_start: false
+mask_uv_prob: 0.15
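For orientation, a sketch of how the changed keys are consumed downstream (simplified; the real call sites live in the task and trainer code). max_tokens is the per-batch frame budget used by the dynamic batch sampler; K_step: 60 restricts training and sampling to the first 60 of the 1000 diffusion steps, the shallow-diffusion setting realized by GaussianDiffusion in the next file; mask_uv_prob appears to be a training-time probability of masking the unvoiced (uv) flag, an assumption since its consumer is not part of this section.

from utils.hparams import set_hparams, hparams

# Sketch: load the config much as tasks/run.py does.
set_hparams(config='usr/configs/midi/cascade/opencs/ds60_rel.yaml',
            exp_name='my_ds_exp_name')
assert hparams['K_step'] == 60           # shallow-diffusion depth
assert hparams['max_tokens'] == 36000    # frames per batch (dynamic batching)
assert hparams['mask_uv_prob'] == 0.15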
usr/diff/shallow_diffusion_tts.py
CHANGED
|
@@ -1,273 +1,324 @@
|
|
| 1 |
-
import math
|
| 2 |
-
import random
|
| 3 |
-
from
|
| 4 |
-
from
|
| 5 |
-
from
|
| 6 |
-
|
| 7 |
-
import
|
| 8 |
-
import torch
|
| 9 |
-
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
from modules.
|
| 15 |
-
from
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
alphas_cumprod =
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
"
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
self.
|
| 96 |
-
self.
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
self.register_buffer('
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
self.register_buffer('
|
| 109 |
-
self.register_buffer('
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
self.register_buffer('
|
| 118 |
-
|
| 119 |
-
self.register_buffer('
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
self.register_buffer('
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
return
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
x = self.
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
from collections import deque
|
| 4 |
+
from functools import partial
|
| 5 |
+
from inspect import isfunction
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch import nn
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
from modules.fastspeech.fs2 import FastSpeech2
|
| 15 |
+
from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
|
| 16 |
+
from utils.hparams import hparams
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def exists(x):
|
| 21 |
+
return x is not None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def default(val, d):
|
| 25 |
+
if exists(val):
|
| 26 |
+
return val
|
| 27 |
+
return d() if isfunction(d) else d
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# gaussian diffusion trainer class
|
| 31 |
+
|
| 32 |
+
def extract(a, t, x_shape):
|
| 33 |
+
b, *_ = t.shape
|
| 34 |
+
out = a.gather(-1, t)
|
| 35 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def noise_like(shape, device, repeat=False):
|
| 39 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
| 40 |
+
noise = lambda: torch.randn(shape, device=device)
|
| 41 |
+
return repeat_noise() if repeat else noise()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
|
| 45 |
+
"""
|
| 46 |
+
linear schedule
|
| 47 |
+
"""
|
| 48 |
+
betas = np.linspace(1e-4, max_beta, timesteps)
|
| 49 |
+
return betas
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def cosine_beta_schedule(timesteps, s=0.008):
|
| 53 |
+
"""
|
| 54 |
+
cosine schedule
|
| 55 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
| 56 |
+
"""
|
| 57 |
+
steps = timesteps + 1
|
| 58 |
+
x = np.linspace(0, steps, steps)
|
| 59 |
+
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
| 60 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 61 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 62 |
+
return np.clip(betas, a_min=0, a_max=0.999)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
beta_schedule = {
|
| 66 |
+
"cosine": cosine_beta_schedule,
|
| 67 |
+
"linear": linear_beta_schedule,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class GaussianDiffusion(nn.Module):
|
| 72 |
+
def __init__(self, phone_encoder, out_dims, denoise_fn,
|
| 73 |
+
timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None, spec_max=None):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.denoise_fn = denoise_fn
|
| 76 |
+
if hparams.get('use_midi') is not None and hparams['use_midi']:
|
| 77 |
+
self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
|
| 78 |
+
else:
|
| 79 |
+
self.fs2 = FastSpeech2(phone_encoder, out_dims)
|
| 80 |
+
self.mel_bins = out_dims
|
| 81 |
+
|
| 82 |
+
if exists(betas):
|
| 83 |
+
betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
|
| 84 |
+
else:
|
| 85 |
+
if 'schedule_type' in hparams.keys():
|
| 86 |
+
betas = beta_schedule[hparams['schedule_type']](timesteps)
|
| 87 |
+
else:
|
| 88 |
+
betas = cosine_beta_schedule(timesteps)
|
| 89 |
+
|
| 90 |
+
alphas = 1. - betas
|
| 91 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
| 92 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
| 93 |
+
|
| 94 |
+
timesteps, = betas.shape
|
| 95 |
+
self.num_timesteps = int(timesteps)
|
| 96 |
+
self.K_step = K_step
|
| 97 |
+
self.loss_type = loss_type
|
| 98 |
+
|
| 99 |
+
self.noise_list = deque(maxlen=4)
|
| 100 |
+
|
| 101 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
| 102 |
+
|
| 103 |
+
self.register_buffer('betas', to_torch(betas))
|
| 104 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
| 105 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
| 106 |
+
|
| 107 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 108 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
| 109 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
| 110 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
| 111 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
| 112 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
| 113 |
+
|
| 114 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 115 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
| 116 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
| 117 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
| 118 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
| 119 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
| 120 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
| 121 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
| 122 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
| 123 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
| 124 |
+
|
| 125 |
+
self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
|
| 126 |
+
self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
|
| 127 |
+
|
| 128 |
+
def q_mean_variance(self, x_start, t):
|
| 129 |
+
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 130 |
+
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
| 131 |
+
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
| 132 |
+
return mean, variance, log_variance
|
| 133 |
+
|
| 134 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
| 135 |
+
return (
|
| 136 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
| 137 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def q_posterior(self, x_start, x_t, t):
|
| 141 |
+
posterior_mean = (
|
| 142 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
| 143 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 144 |
+
)
|
| 145 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
| 146 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 147 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 148 |
+
|
| 149 |
+
def p_mean_variance(self, x, t, cond, clip_denoised: bool):
|
| 150 |
+
noise_pred = self.denoise_fn(x, t, cond=cond)
|
| 151 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
|
| 152 |
+
|
| 153 |
+
if clip_denoised:
|
| 154 |
+
x_recon.clamp_(-1., 1.)
|
| 155 |
+
|
| 156 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
| 157 |
+
return model_mean, posterior_variance, posterior_log_variance
|
| 158 |
+
|
| 159 |
+
@torch.no_grad()
|
| 160 |
+
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
| 161 |
+
b, *_, device = *x.shape, x.device
|
| 162 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
|
| 163 |
+
noise = noise_like(x.shape, device, repeat_noise)
|
| 164 |
+
# no noise when t == 0
|
| 165 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
| 166 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
| 167 |
+
|
| 168 |
+
@torch.no_grad()
|
| 169 |
+
def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
|
| 170 |
+
"""
|
| 171 |
+
Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def get_x_pred(x, noise_t, t):
|
| 175 |
+
a_t = extract(self.alphas_cumprod, t, x.shape)
|
| 176 |
+
if t[0] < interval:
|
| 177 |
+
a_prev = torch.ones_like(a_t)
|
| 178 |
+
else:
|
| 179 |
+
a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
|
| 180 |
+
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
| 181 |
+
|
| 182 |
+
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
| 183 |
+
x_pred = x + x_delta
|
| 184 |
+
|
| 185 |
+
return x_pred
|
| 186 |
+
|
| 187 |
+
noise_list = self.noise_list
|
| 188 |
+
noise_pred = self.denoise_fn(x, t, cond=cond)
|
| 189 |
+
|
| 190 |
+
if len(noise_list) == 0:
|
| 191 |
+
x_pred = get_x_pred(x, noise_pred, t)
|
| 192 |
+
noise_pred_prev = self.denoise_fn(x_pred, max(t-interval, 0), cond=cond)
|
| 193 |
+
noise_pred_prime = (noise_pred + noise_pred_prev) / 2
|
| 194 |
+
elif len(noise_list) == 1:
|
| 195 |
+
noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
|
| 196 |
+
elif len(noise_list) == 2:
|
| 197 |
+
noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
|
| 198 |
+
elif len(noise_list) >= 3:
|
| 199 |
+
noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
|
| 200 |
+
|
| 201 |
+
x_prev = get_x_pred(x, noise_pred_prime, t)
|
| 202 |
+
noise_list.append(noise_pred)
|
| 203 |
+
|
| 204 |
+
return x_prev
|
| 205 |
+
|
| 206 |
+
def q_sample(self, x_start, t, noise=None):
|
| 207 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 208 |
+
return (
|
| 209 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
| 210 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
|
| 214 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
| 215 |
+
|
| 216 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
| 217 |
+
x_recon = self.denoise_fn(x_noisy, t, cond)
|
| 218 |
+
|
| 219 |
+
if self.loss_type == 'l1':
|
| 220 |
+
if nonpadding is not None:
|
| 221 |
+
loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
|
| 222 |
+
else:
|
| 223 |
+
# print('are you sure w/o nonpadding?')
|
| 224 |
+
loss = (noise - x_recon).abs().mean()
|
| 225 |
+
|
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+            raise NotImplementedError()
+
+        return loss
+
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=(not infer), infer=infer, **kwargs)
+        cond = ret['decoder_inp'].transpose(1, 2)
+
+        if not infer:
+            t = torch.randint(0, self.K_step, (b,), device=device).long()
+            x = ref_mels
+            x = self.norm_spec(x)
+            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+            ret['diff_loss'] = self.p_losses(x, t, cond)
+            # nonpadding = (mel2ph != 0).float()
+            # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
+        else:
+            ret['fs2_mel'] = ret['mel_out']
+            fs2_mels = ret['mel_out']
+            t = self.K_step
+            fs2_mels = self.norm_spec(fs2_mels)
+            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
+
+            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
+            if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
+                print('===> gaussian start.')
+                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+                x = torch.randn(shape, device=device)
+
+            if hparams.get('pndm_speedup'):
+                print('===> pndm speed:', hparams['pndm_speedup'])
+                self.noise_list = deque(maxlen=4)
+                iteration_interval = hparams['pndm_speedup']
+                for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
+                              total=t // iteration_interval):
+                    x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), iteration_interval,
+                                           cond)
+            else:
+                for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+                    x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
+            x = x[:, 0].transpose(1, 2)
+            if mel2ph is not None:  # for singing
+                ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
+            else:
+                ret['mel_out'] = self.denorm_spec(x)
+        return ret
+
+    def norm_spec(self, x):
+        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
+
+    def denorm_spec(self, x):
+        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+
+    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+
+    def out2mel(self, x):
+        return x
+
+
+class OfflineGaussianDiffusion(GaussianDiffusion):
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=True, infer=True, **kwargs)
+        cond = ret['decoder_inp'].transpose(1, 2)
+        fs2_mels = ref_mels[1]
+        ref_mels = ref_mels[0]
+
+        if not infer:
+            t = torch.randint(0, self.K_step, (b,), device=device).long()
+            x = ref_mels
+            x = self.norm_spec(x)
+            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+            ret['diff_loss'] = self.p_losses(x, t, cond)
+        else:
+            t = self.K_step
+            fs2_mels = self.norm_spec(fs2_mels)
+            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
+
+            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
+
+            if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
+                print('===> gaussian start.')
+                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+                x = torch.randn(shape, device=device)
+            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
+            x = x[:, 0].transpose(1, 2)
+            ret['mel_out'] = self.denorm_spec(x)
+        return ret
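The `pndm_speedup` branch above is the commit's headline change: when `gaussian_start` is set, sampling starts from pure noise instead of the FS2 mel diffused to step `K_step - 1`, and instead of denoising every timestep with `p_sample`, the loop walks the schedule in strides of `pndm_speedup` and calls `p_sample_plms`, which reuses a `deque` of the last four noise predictions (a pseudo linear multistep update, per the PNDM method). A minimal sketch of the schedule arithmetic, with illustrative values rather than anything from the shipped configs:

```python
from collections import deque

# Illustrative values; K_step and the speedup factor come from hparams.
K_step = 1000        # diffusion steps the model was trained with
pndm_speedup = 10    # stride used at inference time

# Plain DDPM sampling visits every timestep: 999, 998, ..., 0.
plain_schedule = list(reversed(range(0, K_step)))

# The PLMS loop visits every pndm_speedup-th timestep: 990, 980, ..., 0,
# so the denoiser runs K_step // pndm_speedup times instead of K_step.
fast_schedule = list(reversed(range(0, K_step, pndm_speedup)))

noise_list = deque(maxlen=4)  # p_sample_plms keeps the last 4 eps estimates

assert len(plain_schedule) == 1000
assert len(fast_schedule) == 100
```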
utils/hparams.py
CHANGED
@@ -21,35 +21,30 @@ def override_config(old_config: dict, new_config: dict):
 
 
 def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
-    if config == '' and exp_name == '':
-        parser = argparse.ArgumentParser(description='')
+    if config == '':
+        parser = argparse.ArgumentParser(description='neural music')
         parser.add_argument('--config', type=str, default='',
                             help='location of the data corpus')
         parser.add_argument('--exp_name', type=str, default='', help='exp_name')
-        parser.add_argument('-hp', '--hparams', type=str, default='',
+        parser.add_argument('--hparams', type=str, default='',
                             help='location of the data corpus')
         parser.add_argument('--infer', action='store_true', help='infer')
         parser.add_argument('--validate', action='store_true', help='validate')
         parser.add_argument('--reset', action='store_true', help='reset hparams')
-        parser.add_argument('--remove', action='store_true', help='remove old ckpt')
         parser.add_argument('--debug', action='store_true', help='debug')
         args, unknown = parser.parse_known_args()
-        print("| Unknow hparams: ", unknown)
     else:
         args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
-                    infer=False, validate=False, reset=False, debug=False,
-                    remove=False)
-
-    global hparams
-    assert args.config != '' or args.exp_name != ''
+                    infer=False, validate=False, reset=False, debug=False)
+    args_work_dir = ''
+    if args.exp_name != '':
+        args.work_dir = args.exp_name
+        args_work_dir = f'checkpoints/{args.work_dir}'
 
     config_chains = []
     loaded_config = set()
 
-    def load_config(config_fn):
-        # deep first inheritance and avoid the second visit of one node
-        if not os.path.exists(config_fn):
-            return {}
+    def load_config(config_fn):  # deep first
         with open(config_fn) as f:
             hparams_ = yaml.safe_load(f)
         loaded_config.add(config_fn)
@@ -58,10 +53,10 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
             if not isinstance(hparams_['base_config'], list):
                 hparams_['base_config'] = [hparams_['base_config']]
             for c in hparams_['base_config']:
-                if c.startswith('.'):
-                    c = f'{os.path.dirname(config_fn)}/{c}'
-                    c = os.path.normpath(c)
                 if c not in loaded_config:
+                    if c.startswith('.'):
+                        c = f'{os.path.dirname(config_fn)}/{c}'
+                        c = os.path.normpath(c)
                     override_config(ret_hparams, load_config(c))
             override_config(ret_hparams, hparams_)
         else:
@@ -69,43 +64,36 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
         config_chains.append(config_fn)
         return ret_hparams
 
+    global hparams
+    assert args.config != '' or args_work_dir != ''
     saved_hparams = {}
-    args_work_dir = ''
-    if args.exp_name != '':
-        args_work_dir = f'checkpoints/{args.exp_name}'
+    if args_work_dir != 'checkpoints/':
         ckpt_config_path = f'{args_work_dir}/config.yaml'
         if os.path.exists(ckpt_config_path):
-            with open(ckpt_config_path) as f:
-                saved_hparams_ = yaml.safe_load(f)
-                if saved_hparams_ is not None:
-                    saved_hparams.update(saved_hparams_)
+            try:
+                with open(ckpt_config_path) as f:
+                    saved_hparams.update(yaml.safe_load(f))
+            except:
+                pass
+        if args.config == '':
+            args.config = ckpt_config_path
+
     hparams_ = {}
-    if args.config != '':
-        hparams_.update(load_config(args.config))
+
+    hparams_.update(load_config(args.config))
+
     if not args.reset:
         hparams_.update(saved_hparams)
     hparams_['work_dir'] = args_work_dir
 
-    # Support config overriding in command line. Support list type config overriding.
-    # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
     if args.hparams != "":
        for new_hparam in args.hparams.split(","):
             k, v = new_hparam.split("=")
-            v = v.strip("\'\" ")
-            config_node = hparams_
-            for k_ in k.split(".")[:-1]:
-                config_node = config_node[k_]
-            k = k.split(".")[-1]
-            if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
-                if type(config_node[k]) == list:
-                    v = v.replace(" ", ",")
-                config_node[k] = eval(v)
+            if v in ['True', 'False'] or type(hparams_[k]) == bool:
+                hparams_[k] = eval(v)
             else:
-                config_node[k] = type(config_node[k])(v)
-    if args.remove:
-        answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
-        if answer.lower() == "y":
-            remove_file(args_work_dir)
+                hparams_[k] = type(hparams_[k])(v)
+
     if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
         os.makedirs(hparams_['work_dir'], exist_ok=True)
         with open(ckpt_config_path, 'w') as f:
@@ -114,11 +102,11 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
     hparams_['infer'] = args.infer
     hparams_['debug'] = args.debug
     hparams_['validate'] = args.validate
-    hparams_['exp_name'] = args.exp_name
     global global_print_hparams
     if global_hparams:
         hparams.clear()
         hparams.update(hparams_)
+
     if print_hparams and global_print_hparams and global_hparams:
         print('| Hparams chains: ', config_chains)
         print('| Hparams: ')
@@ -126,5 +114,9 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
             print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
         print("")
         global_print_hparams = False
+    # print(hparams_.keys())
+    if hparams.get('exp_name') is None:
+        hparams['exp_name'] = args.exp_name
+    if hparams_.get('exp_name') is None:
+        hparams_['exp_name'] = args.exp_name
     return hparams_
-
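For context, `load_config` above now resolves a relative `base_config` entry against the directory of the file that references it before recursing, and `override_config` (whose signature appears in the hunk headers) merges each parent into the child. A small self-contained sketch of those merge semantics, not a copy of the repo's exact function body:

```python
def override_config(old_config: dict, new_config: dict):
    # Recursive dict merge: nested dicts merge key-by-key,
    # everything else in new_config replaces the old value.
    for k, v in new_config.items():
        if isinstance(v, dict) and k in old_config:
            override_config(old_config[k], v)
        else:
            old_config[k] = v

base = {'lr': 0.001, 'fft': {'n_fft': 1024, 'hop_size': 256}}
child = {'lr': 0.0005, 'fft': {'hop_size': 128}}
override_config(base, child)
assert base == {'lr': 0.0005, 'fft': {'n_fft': 1024, 'hop_size': 128}}
```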
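Note what the rewrite drops from the command-line override path: dotted keys (`a.b.c=...`), list values, and the `--remove` prompt are gone. Each `k=v` pair in `--hparams` must now name a top-level key; `True`/`False` and existing booleans go through `eval`, and anything else is cast to the type of the value already in the config. A toy reproduction of the new loop (`apply_overrides` and the sample keys are hypothetical):

```python
def apply_overrides(hparams_, hparams_str):
    # Mirrors the new override loop: flat keys only.
    for new_hparam in hparams_str.split(","):
        k, v = new_hparam.split("=")
        if v in ['True', 'False'] or type(hparams_[k]) == bool:
            hparams_[k] = eval(v)               # booleans parsed literally
        else:
            hparams_[k] = type(hparams_[k])(v)  # cast to the existing type

hparams_ = {'max_updates': 160000, 'use_pitch_embed': True, 'lr': 0.001}
apply_overrides(hparams_, "max_updates=320000,use_pitch_embed=False,lr=0.0005")
assert hparams_ == {'max_updates': 320000, 'use_pitch_embed': False, 'lr': 0.0005}
```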