	Stereo demo update (#60)
- updated demo (a16e65ea0e58cfa993cbd11f06f2ded8eafb559c)
Co-authored-by: Alexandre Défossez <adefossez@users.noreply.huggingface.co>
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
- .github/actions/audiocraft_build/action.yml +2 -0
- .github/workflows/audiocraft_docs.yml +3 -3
- .github/workflows/audiocraft_tests.yml +6 -1
- .gitignore +8 -1
- CHANGELOG.md +31 -1
- CONTRIBUTING.md +2 -2
- LICENSE_weights +399 -157
- MANIFEST.in +7 -0
- Makefile +23 -4
- README.md +43 -83
- assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 +0 -0
- assets/sirens_and_a_humming_engine_approach_and_pass.mp3 +0 -0
- audiocraft/__init__.py +17 -1
- audiocraft/adversarial/__init__.py +22 -0
- audiocraft/adversarial/discriminators/__init__.py +10 -0
- audiocraft/adversarial/discriminators/base.py +34 -0
- audiocraft/adversarial/discriminators/mpd.py +106 -0
- audiocraft/adversarial/discriminators/msd.py +126 -0
- audiocraft/adversarial/discriminators/msstftd.py +134 -0
- audiocraft/adversarial/losses.py +228 -0
- audiocraft/data/__init__.py +3 -1
- audiocraft/data/audio.py +37 -21
- audiocraft/data/audio_dataset.py +93 -31
- audiocraft/data/audio_utils.py +12 -10
- audiocraft/data/info_audio_dataset.py +110 -0
- audiocraft/data/music_dataset.py +270 -0
- audiocraft/data/sound_dataset.py +330 -0
- audiocraft/data/zip.py +8 -6
- audiocraft/environment.py +176 -0
- audiocraft/grids/__init__.py +6 -0
- audiocraft/grids/_base_explorers.py +80 -0
- audiocraft/grids/audiogen/__init__.py +6 -0
- audiocraft/grids/audiogen/audiogen_base_16khz.py +23 -0
- audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +68 -0
- audiocraft/grids/compression/__init__.py +6 -0
- audiocraft/grids/compression/_explorers.py +55 -0
- audiocraft/grids/compression/debug.py +31 -0
- audiocraft/grids/compression/encodec_audiogen_16khz.py +29 -0
- audiocraft/grids/compression/encodec_base_24khz.py +28 -0
- audiocraft/grids/compression/encodec_musicgen_32khz.py +34 -0
- audiocraft/grids/diffusion/4_bands_base_32khz.py +27 -0
- audiocraft/grids/diffusion/__init__.py +6 -0
- audiocraft/grids/diffusion/_explorers.py +66 -0
- audiocraft/grids/musicgen/__init__.py +6 -0
- audiocraft/grids/musicgen/_explorers.py +93 -0
- audiocraft/grids/musicgen/musicgen_base_32khz.py +43 -0
- audiocraft/grids/musicgen/musicgen_base_cached_32khz.py +67 -0
- audiocraft/grids/musicgen/musicgen_clapemb_32khz.py +32 -0
- audiocraft/grids/musicgen/musicgen_melody_32khz.py +65 -0
- audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py +99 -0
    	
.github/actions/audiocraft_build/action.yml
CHANGED

@@ -21,6 +21,8 @@ runs:
           python3 -m venv env
           . env/bin/activate
           python -m pip install --upgrade pip
+          pip install torch torchvision torchaudio
+          pip install xformers
           pip install -e '.[dev]'
       - name: System Dependencies
         shell: bash
    	
.github/workflows/audiocraft_docs.yml
CHANGED

@@ -23,9 +23,9 @@ jobs:
       - name: Make docs
         run: |
           . env/bin/activate
-          make docs
-          git add -f docs
-          git commit -m docs
+          make api_docs
+          git add -f api_docs
+          git commit -m api_docs
 
       - name: Push branch
         run: |
    	
.github/workflows/audiocraft_tests.yml
CHANGED

@@ -12,6 +12,11 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - uses: ./.github/actions/audiocraft_build
-      - run: |
+      - name: Run unit tests
+        run: |
          . env/bin/activate
          make tests
+      - name: Run integration tests
+        run: |
+          . env/bin/activate
+          make tests_integ
    	
.gitignore
CHANGED

@@ -35,7 +35,7 @@ wheels/
 .coverage
 
 # docs
-/docs
+/api_docs
 
 # dotenv
 .env
@@ -46,6 +46,13 @@ wheels/
 venv/
 ENV/
 
+# egs with manifest files
+egs/*
+!egs/example
+# local datasets
+dataset/*
+!dataset/example
+
 # personal notebooks & scripts
 */local_scripts
 */notes
    	
CHANGELOG.md
CHANGED

@@ -4,7 +4,37 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
-## [0.0.2a] - TBD
+## [1.2.0a] - TBD
+
+Adding stereo models.
+
+
+## [1.1.0] - 2023-11-06
+
+Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
+
+Fixed DAC support with non default number of codebooks.
+
+Fixed bug when `two_step_cfg` was overriden when calling `generate()`.
+
+Fixed samples being always prompted with audio, rather than having both prompted and unprompted.
+
+**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way in the public release.
+	The released models were trained without this. Those impact linear layers applied to the output of the T5 or melody conditioners.
+	We removed it, so you might need to retrain models.
+
+**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained model with CLAP before).
+
+**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
+	retrained a model with this pattern, so hopefully this won't impact you!
+
+
+## [1.0.0] - 2023-09-07
+
+Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
+Added pretrained model for AudioGen and MultiBandDiffusion.
+
+## [0.0.2] - 2023-08-01
 
 Improved demo, fixed top p (thanks @jnordberg).
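For context on the changelog entries above, this is roughly how the affected pieces surface in the public API: stereo checkpoints load through the same MusicGen entry point, two_step_cfg is set via set_generation_params (the 1.1.0 fix stops it being silently overridden inside generate()), and audio_write, which per 1.1.0 now encodes via the ffmpeg command line rather than torchaudio, saves the result. A minimal sketch; the checkpoint name facebook/musicgen-stereo-small is an assumption based on the stereo models this update advertises:

from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

# Assumed name for one of the stereo checkpoints added in 1.2.0a.
model = MusicGen.get_pretrained('facebook/musicgen-stereo-small')
# two_step_cfg is the generation flag the 1.1.0 fix ensures is respected.
model.set_generation_params(duration=8, two_step_cfg=True)

wavs = model.generate(['80s synth-pop with a driving bass line'])
for i, wav in enumerate(wavs):
    # audio_write shells out to ffmpeg for encoding (see the 1.1.0 note);
    # a stereo model simply yields a [2, T] waveform per sample.
    audio_write(f'sample_{i}', wav.cpu(), model.sample_rate, strategy='loudness')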
    	
CONTRIBUTING.md
CHANGED

@@ -1,11 +1,11 @@
-# Contributing to Audiocraft
+# Contributing to AudioCraft
 
 We want to make contributing to this project as easy and transparent as
 possible.
 
 ## Pull Requests
 
-Audiocraft is the implementation of a research paper.
+AudioCraft is the implementation of a research paper.
 Therefore, we do not plan on accepting many pull requests for new features.
 We certainly welcome them for bug fixes.
    	
LICENSE_weights
CHANGED

@@ -1,157 +1,399 @@
-[the 157 removed lines of the previous weights license are not rendered in this view]
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+     Considerations for licensors: Our public licenses are
+     intended for use by those authorized to give the public
+     permission to use material in ways otherwise restricted by
+     copyright and certain other rights. Our licenses are
+     irrevocable. Licensors should read and understand the terms
+     and conditions of the license they choose before applying it.
+     Licensors should also secure all rights necessary before
+     applying our licenses so that the public can reuse the
+     material as expected. Licensors should clearly mark any
+     material not subject to the license. This includes other CC-
+     licensed material, or material used under an exception or
+     limitation to copyright. More considerations for licensors:
+	wiki.creativecommons.org/Considerations_for_licensors
+
+     Considerations for the public: By using one of our public
+     licenses, a licensor grants the public permission to use the
+     licensed material under specified terms and conditions. If
+     the licensor's permission is not necessary for any reason--for
+     example, because of any applicable exception or limitation to
+     copyright--then that use is not regulated by the license. Our
+     licenses grant only permissions under copyright and certain
+     other rights that a licensor has authority to grant. Use of
+     the licensed material may still be restricted for other
+     reasons, including because others have copyright or other
+     rights in the material. A licensor may make special requests,
+     such as asking that all changes be marked or described.
+     Although not required by our licenses, you are encouraged to
+     respect those requests where reasonable. More_considerations
+     for the public:
+	wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+  a. Adapted Material means material subject to Copyright and Similar
+     Rights that is derived from or based upon the Licensed Material
+     and in which the Licensed Material is translated, altered,
+     arranged, transformed, or otherwise modified in a manner requiring
+     permission under the Copyright and Similar Rights held by the
+     Licensor. For purposes of this Public License, where the Licensed
+     Material is a musical work, performance, or sound recording,
+     Adapted Material is always produced where the Licensed Material is
+     synched in timed relation with a moving image.
+
+  b. Adapter's License means the license You apply to Your Copyright
+     and Similar Rights in Your contributions to Adapted Material in
+     accordance with the terms and conditions of this Public License.
+
+  c. Copyright and Similar Rights means copyright and/or similar rights
+     closely related to copyright including, without limitation,
+     performance, broadcast, sound recording, and Sui Generis Database
+     Rights, without regard to how the rights are labeled or
+     categorized. For purposes of this Public License, the rights
+     specified in Section 2(b)(1)-(2) are not Copyright and Similar
+     Rights.
+  d. Effective Technological Measures means those measures that, in the
+     absence of proper authority, may not be circumvented under laws
+     fulfilling obligations under Article 11 of the WIPO Copyright
+     Treaty adopted on December 20, 1996, and/or similar international
+     agreements.
+
+  e. Exceptions and Limitations means fair use, fair dealing, and/or
+     any other exception or limitation to Copyright and Similar Rights
+     that applies to Your use of the Licensed Material.
+
+  f. Licensed Material means the artistic or literary work, database,
+     or other material to which the Licensor applied this Public
+     License.
+
+  g. Licensed Rights means the rights granted to You subject to the
+     terms and conditions of this Public License, which are limited to
+     all Copyright and Similar Rights that apply to Your use of the
+     Licensed Material and that the Licensor has authority to license.
+
+  h. Licensor means the individual(s) or entity(ies) granting rights
+     under this Public License.
+
+  i. NonCommercial means not primarily intended for or directed towards
+     commercial advantage or monetary compensation. For purposes of
+     this Public License, the exchange of the Licensed Material for
+     other material subject to Copyright and Similar Rights by digital
+     file-sharing or similar means is NonCommercial provided there is
+     no payment of monetary compensation in connection with the
+     exchange.
+
+  j. Share means to provide material to the public by any means or
+     process that requires permission under the Licensed Rights, such
+     as reproduction, public display, public performance, distribution,
+     dissemination, communication, or importation, and to make material
+     available to the public including in ways that members of the
+     public may access the material from a place and at a time
+     individually chosen by them.
+
+  k. Sui Generis Database Rights means rights other than copyright
+     resulting from Directive 96/9/EC of the European Parliament and of
+     the Council of 11 March 1996 on the legal protection of databases,
+     as amended and/or succeeded, as well as other essentially
+     equivalent rights anywhere in the world.
+
+  l. You means the individual or entity exercising the Licensed Rights
+     under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+  a. License grant.
+
+       1. Subject to the terms and conditions of this Public License,
+          the Licensor hereby grants You a worldwide, royalty-free,
+          non-sublicensable, non-exclusive, irrevocable license to
+          exercise the Licensed Rights in the Licensed Material to:
+
+            a. reproduce and Share the Licensed Material, in whole or
+               in part, for NonCommercial purposes only; and
+
+            b. produce, reproduce, and Share Adapted Material for
+               NonCommercial purposes only.
+
+       2. Exceptions and Limitations. For the avoidance of doubt, where
+          Exceptions and Limitations apply to Your use, this Public
+          License does not apply, and You do not need to comply with
+          its terms and conditions.
+
+       3. Term. The term of this Public License is specified in Section
+          6(a).
+
+       4. Media and formats; technical modifications allowed. The
+          Licensor authorizes You to exercise the Licensed Rights in
+          all media and formats whether now known or hereafter created,
+          and to make technical modifications necessary to do so. The
+          Licensor waives and/or agrees not to assert any right or
+          authority to forbid You from making technical modifications
+          necessary to exercise the Licensed Rights, including
+          technical modifications necessary to circumvent Effective
+          Technological Measures. For purposes of this Public License,
+          simply making modifications authorized by this Section 2(a)
+          (4) never produces Adapted Material.
+
+       5. Downstream recipients.
+
+            a. Offer from the Licensor -- Licensed Material. Every
+               recipient of the Licensed Material automatically
+               receives an offer from the Licensor to exercise the
+               Licensed Rights under the terms and conditions of this
+               Public License.
+
+            b. No downstream restrictions. You may not offer or impose
+               any additional or different terms or conditions on, or
+               apply any Effective Technological Measures to, the
+               Licensed Material if doing so restricts exercise of the
+               Licensed Rights by any recipient of the Licensed
+               Material.
+
+       6. No endorsement. Nothing in this Public License constitutes or
+          may be construed as permission to assert or imply that You
+          are, or that Your use of the Licensed Material is, connected
+          with, or sponsored, endorsed, or granted official status by,
+          the Licensor or others designated to receive attribution as
+          provided in Section 3(a)(1)(A)(i).
+
+  b. Other rights.
+
+       1. Moral rights, such as the right of integrity, are not
+          licensed under this Public License, nor are publicity,
+          privacy, and/or other similar personality rights; however, to
+          the extent possible, the Licensor waives and/or agrees not to
+          assert any such rights held by the Licensor to the limited
+          extent necessary to allow You to exercise the Licensed
+          Rights, but not otherwise.
+
+       2. Patent and trademark rights are not licensed under this
+          Public License.
+
+       3. To the extent possible, the Licensor waives any right to
+          collect royalties from You for the exercise of the Licensed
+          Rights, whether directly or through a collecting society
+          under any voluntary or waivable statutory or compulsory
+          licensing scheme. In all other cases the Licensor expressly
+          reserves any right to collect such royalties, including when
+          the Licensed Material is used other than for NonCommercial
+          purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+  a. Attribution.
+
+       1. If You Share the Licensed Material (including in modified
+          form), You must:
+
+            a. retain the following if it is supplied by the Licensor
+               with the Licensed Material:
+
+                 i. identification of the creator(s) of the Licensed
+                    Material and any others designated to receive
+                    attribution, in any reasonable manner requested by
+                    the Licensor (including by pseudonym if
+                    designated);
+
+                ii. a copyright notice;
+
+               iii. a notice that refers to this Public License;
+
+                iv. a notice that refers to the disclaimer of
+                    warranties;
+
+                 v. a URI or hyperlink to the Licensed Material to the
+                    extent reasonably practicable;
+
+            b. indicate if You modified the Licensed Material and
+               retain an indication of any previous modifications; and
+
+            c. indicate the Licensed Material is licensed under this
+               Public License, and include the text of, or the URI or
+               hyperlink to, this Public License.
+
+       2. You may satisfy the conditions in Section 3(a)(1) in any
+          reasonable manner based on the medium, means, and context in
+          which You Share the Licensed Material. For example, it may be
+          reasonable to satisfy the conditions by providing a URI or
+          hyperlink to a resource that includes the required
+          information.
+
+       3. If requested by the Licensor, You must remove any of the
+          information required by Section 3(a)(1)(A) to the extent
+          reasonably practicable.
+
+       4. If You Share Adapted Material You produce, the Adapter's
+          License You apply must not prevent recipients of the Adapted
+          Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+     to extract, reuse, reproduce, and Share all or a substantial
+     portion of the contents of the database for NonCommercial purposes
+     only;
+
+  b. if You include all or a substantial portion of the database
+     contents in a database in which You have Sui Generis Database
+     Rights, then the database in which You have Sui Generis Database
+     Rights (but not its individual contents) is Adapted Material; and
+
+  c. You must comply with the conditions in Section 3(a) if You Share
+     all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+  c. The disclaimer of warranties and limitation of liability provided
+     above shall be interpreted in a manner that, to the extent
+     possible, most closely approximates an absolute disclaimer and
+     waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+  a. This Public License applies for the term of the Copyright and
+     Similar Rights licensed here. However, if You fail to comply with
+     this Public License, then Your rights under this Public License
+     terminate automatically.
+
+  b. Where Your right to use the Licensed Material has terminated under
+     Section 6(a), it reinstates:
+
+       1. automatically as of the date the violation is cured, provided
+          it is cured within 30 days of Your discovery of the
+          violation; or
+
+       2. upon express reinstatement by the Licensor.
+
+     For the avoidance of doubt, this Section 6(b) does not affect any
+     right the Licensor may have to seek remedies for Your violations
+     of this Public License.
+
+  c. For the avoidance of doubt, the Licensor may also offer the
+     Licensed Material under separate terms or conditions or stop
+     distributing the Licensed Material at any time; however, doing so
+     will not terminate this Public License.
+
+  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+     License.
+
+Section 7 -- Other Terms and Conditions.
+
+  a. The Licensor shall not be bound by any additional or different
+     terms or conditions communicated by You unless expressly agreed.
+
+  b. Any arrangements, understandings, or agreements regarding the
+     Licensed Material not stated herein are separate from and
+     independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+  a. For the avoidance of doubt, this Public License does not, and
+     shall not be interpreted to, reduce, limit, restrict, or impose
+     conditions on any use of the Licensed Material that could lawfully
+     be made without permission under this Public License.
+
+  b. To the extent possible, if any provision of this Public License is
+     deemed unenforceable, it shall be automatically reformed to the
+     minimum extent necessary to make it enforceable. If the provision
+     cannot be reformed, it shall be severed from this Public License
+     without affecting the enforceability of the remaining terms and
+     conditions.
+
+  c. No term or condition of this Public License will be waived and no
+     failure to comply consented to unless expressly agreed to by the
+     Licensor.
+
+  d. Nothing in this Public License constitutes or may be interpreted
+     as a limitation upon, or waiver of, any privileges and immunities
+     that apply to the Licensor or You, including from the legal
+     processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
    	
MANIFEST.in
CHANGED

@@ -6,3 +6,10 @@ include *.ini
 include requirements.txt
 include audiocraft/py.typed
 include assets/*.mp3
+include datasets/*.mp3
+recursive-include config *.yaml
+recursive-include demos *.py
+recursive-include demos *.ipynb
+recursive-include scripts *.py
+recursive-include model_cards *.md
+recursive-include docs *.md
    	
Makefile
CHANGED

@@ -1,3 +1,15 @@
+INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
+	dataset.train.num_samples=10 dataset.valid.num_samples=10 \
+	dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
+	logging.level=DEBUG
+INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true   # SIG is 5091833e
+INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
+	transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false  # Using compression model from 5091833e
+INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
+	transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false  # Using compression model from 5091833e
+INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
+	checkpoint.save_last=false  # Using compression model from 616d7b3c
+
 default: linter tests
 
 install:
@@ -10,12 +22,19 @@ linter:
 
 tests:
 	coverage run -m pytest tests
-	coverage report
+	coverage report
+
+tests_integ:
+	$(INTEG_COMPRESSION)
+	$(INTEG_MBD)
+	$(INTEG_MUSICGEN)
+	$(INTEG_AUDIOGEN)
+
+api_docs:
+	pdoc3 --html -o api_docs -f audiocraft
 
 dist:
 	python setup.py sdist
 
-.PHONY: linter tests
+.PHONY: linter tests api_docs dist
    	
README.md
CHANGED

@@ -5,7 +5,7 @@ tags:
   - "music generation"
   - "language models"
   - "LLMs"
+app_file: "demos/musicgen_app.py"
 emoji: 🎵
 colorFrom: gray
 colorTo: blue
@@ -14,33 +14,17 @@ sdk_version: 3.34.0
 pinned: true
 license: "cc-by-nc-4.0"
 ---
+# AudioCraft
 
 
 
-## MusicGen
-
-Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive
-Transformer model trained over a 32kHz <a href="https://github.com/facebookresearch/encodec">EnCodec tokenizer</a> with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require a self-supervised semantic representation, and it generates
-all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict
-them in parallel, thus having only 50 auto-regressive steps per second of audio.
-Check out our [sample page][musicgen_samples] or test the available demo!
-
-<a target="_blank" href="https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing">
-  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
-</a>
-<a target="_blank" href="https://huggingface.co/spaces/facebook/MusicGen">
-  <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HugginFace"/>
-</a>
-<br>
-
-We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
+AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code
+for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen.
 
 ## Installation
+AudioCraft requires Python 3.9 and PyTorch 2.0.0. To install AudioCraft, you can run the following:
 
 ```shell
 # Best to make sure you have torch installed first, in particular before installing xformers.
@@ -49,92 +33,68 @@ pip install 'torch>=2.0'
 # Then proceed to one of the following
 pip install -U audiocraft  # stable release
 pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
-pip install -e .  # or if you cloned the repo locally
+pip install -e .  # or if you cloned the repo locally (mandatory if you want to train).
 ```
 
-4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU).
-5. Finally, checkout [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab) which is regularly
-  updated with contributions from @camenduru and the community.
-
-## API
-
-We provide a simple API and 4 pre-trained models. The pre trained models are:
-- `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
-- `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
-- `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
-- `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
-
-We observe the best trade-off between quality and compute with the `medium` or `melody` model.
-In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
-GPUs will be able to generate short sequences, or longer sequences with the `small` model.
-
-**Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`.
-You can install it with:
-```
-apt-get install ffmpeg
-```
-
-model.set_generation_params(duration=8)  # generate 8 seconds.
-wav = model.generate_unconditional(4)    # generates 4 unconditional audio samples
-descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
-wav = model.generate(descriptions)  # generates 3 samples.
-
-    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
-```
-
-## FAQ
-
-#### I need help for running the demo on Colab
-
-* The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
-
-[arxiv]: https://arxiv.org/abs/2306.05284
-[musicgen_samples]: https://ai.honu.io/papers/musicgen/
+We also recommend having `ffmpeg` installed, either through your system or Anaconda:
+```bash
+sudo apt-get install ffmpeg
+# Or if you are using Anaconda or Miniconda
+conda install "ffmpeg<5" -c conda-forge
+```
+
+## Models
+
+At the moment, AudioCraft contains the training code and inference code for:
+* [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model.
+* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
+* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
+* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
+
+## Training code
+
+AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models.
+For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to
+the [AudioCraft training documentation](./docs/TRAINING.md).
+
+For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model,
+which provide pointers to configuration, example grids, and model/task-specific information and FAQs.
+
+## API documentation
+
+We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft.
+
+## FAQ
+
+#### Is the training code available?
+
+Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md).
+
+#### Where are the models stored?
+
+Hugging Face stores the models in a specific location, which can be overridden for the AudioCraft models by setting the `AUDIOCRAFT_CACHE_DIR` environment variable.
+In order to change the cache location of the other Hugging Face models, please check out the [Hugging Face Transformers documentation for the cache setup](https://huggingface.co/docs/transformers/installation#cache-setup).
+Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and want to change the download location for Demucs, refer to the [Torch Hub documentation](https://pytorch.org/docs/stable/hub.html#where-are-my-downloaded-models-saved).
+
+## License
+* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
+* The model weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
 
 
 ## Citation
+
+For the general framework of AudioCraft, please cite the following.
 ```
 @article{copet2023simple,
+    title={Simple and Controllable Music Generation},
+    author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+    year={2023},
+    journal={arXiv preprint arXiv:2306.05284},
 }
 ```
+
+When referring to a specific model, please cite as mentioned in the model specific README, e.g.
+[./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.
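The API snippet dropped from this README now lives in the model docs. For reference, here is a sketch of that usage reconstructed from the deleted lines; the `get_pretrained` call and the `'facebook/musicgen-melody'` name are assumptions based on the deleted model list, not part of this diff:

```python
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

model = MusicGen.get_pretrained('facebook/musicgen-melody')  # loading call assumed
model.set_generation_params(duration=8)  # generate 8 seconds
wav = model.generate_unconditional(4)    # 4 unconditional audio samples
descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
wav = model.generate(descriptions)       # one sample per description

for idx, one_wav in enumerate(wav):
    # Write {idx}.wav with loudness normalization, as in the deleted snippet.
    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate,
                strategy="loudness", loudness_compressor=True)
```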
    	
assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3
ADDED

    Binary file (15.2 kB).

assets/sirens_and_a_humming_engine_approach_and_pass.mp3
ADDED

    Binary file (15.2 kB).
    	
audiocraft/__init__.py
CHANGED

@@ -3,8 +3,24 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+"""
+AudioCraft is a general framework for training audio generative models.
+At the moment we provide the training code for:
+
+- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
+    text-to-music and melody+text autoregressive generative model.
+    For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
+    `audiocraft.models.musicgen.MusicGen`.
+- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
+    text-to-general-audio generative model.
+- [EnCodec](https://arxiv.org/abs/2210.13438), an efficient and high fidelity
+    neural audio codec which provides an excellent tokenizer for autoregressive language models.
+    See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
+- [MultiBandDiffusion](TODO), an alternative diffusion-based decoder compatible with EnCodec that
+    improves the perceived quality and reduces the artifacts coming from adversarial decoders.
+"""
 
 # flake8: noqa
 from . import data, modules, models
 
+__version__ = '1.1.0'
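A quick sanity check of the bumped version and the re-exported subpackages (assuming the package is installed as described in the README above):

```python
import audiocraft

print(audiocraft.__version__)  # '1.1.0' after this change
from audiocraft import data, modules, models  # subpackages re-exported above
```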
    	
audiocraft/adversarial/__init__.py
ADDED

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Adversarial losses and discriminator architectures."""

# flake8: noqa
from .discriminators import (
    MultiPeriodDiscriminator,
    MultiScaleDiscriminator,
    MultiScaleSTFTDiscriminator
)
from .losses import (
    AdversarialLoss,
    AdvLossType,
    get_adv_criterion,
    get_fake_criterion,
    get_real_criterion,
    FeatLossType,
    FeatureMatchingLoss
)
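A hedged sketch of how these exports might fit together in a GAN training step; the `'hinge'` loss type and the single-tensor criterion signatures are assumptions based on common GAN setups, not confirmed by this diff:

```python
import torch
from audiocraft.adversarial import (
    MultiPeriodDiscriminator, get_adv_criterion,
    get_real_criterion, get_fake_criterion)

disc = MultiPeriodDiscriminator()
real = torch.randn(2, 1, 16000)  # reference waveforms
fake = torch.randn(2, 1, 16000)  # generator outputs

real_logits, _ = disc(real)
fake_logits, _ = disc(fake)

# Discriminator update: score real as real and fake as fake, per sub-discriminator.
d_loss = sum(get_real_criterion('hinge')(r) + get_fake_criterion('hinge')(f)
             for r, f in zip(real_logits, fake_logits))
# Generator update: try to fool every sub-discriminator.
g_loss = sum(get_adv_criterion('hinge')(f) for f in fake_logits)
```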
    	
audiocraft/adversarial/discriminators/__init__.py
ADDED

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from .mpd import MultiPeriodDiscriminator
from .msd import MultiScaleDiscriminator
from .msstftd import MultiScaleSTFTDiscriminator
    	
audiocraft/adversarial/discriminators/base.py
ADDED

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
import typing as tp

import torch
import torch.nn as nn


FeatureMapType = tp.List[torch.Tensor]
LogitsType = torch.Tensor
MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]


class MultiDiscriminator(ABC, nn.Module):
    """Base implementation for discriminators composed of sub-discriminators acting at different scales."""
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        ...

    @property
    @abstractmethod
    def num_discriminators(self) -> int:
        """Number of discriminators."""
        ...
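To make the contract concrete, a minimal, purely illustrative subclass (not part of the codebase; the `Conv1d` "discriminator" is a stand-in):

```python
import torch
import torch.nn as nn

from audiocraft.adversarial.discriminators.base import (
    MultiDiscriminator, MultiDiscriminatorOutputType)


class SingleDiscriminator(MultiDiscriminator):
    """Toy one-scale example of the interface above (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Conv1d(1, 1, kernel_size=15, padding=7)

    @property
    def num_discriminators(self) -> int:
        return 1

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        logits = self.net(x)
        # One logits tensor and one feature-map list per sub-discriminator.
        return [logits], [[logits]]
```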
    	
audiocraft/adversarial/discriminators/mpd.py
ADDED

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType


def get_padding(kernel_size: int, dilation: int = 1) -> int:
    return int((kernel_size * dilation - dilation) / 2)


class PeriodDiscriminator(nn.Module):
    """Period sub-discriminator.

    Args:
        period (int): Period between samples of audio.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_layers (int): Number of convolutional layers.
        kernel_sizes (list of int): Kernel sizes for convolutions.
        stride (int): Stride for convolutions.
        filters (int): Initial number of filters in convolutions.
        filters_scale (int): Multiplier of number of filters as we increase depth.
        max_filters (int): Maximum number of filters.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
    """
    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        self.period = period
        self.n_layers = n_layers
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()
        in_chs = in_channels
        for i in range(self.n_layers):
            out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
            eff_stride = 1 if i == self.n_layers - 1 else stride
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
                                         padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
            in_chs = out_chs
        self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)

    def forward(self, x: torch.Tensor):
        fmap = []
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), 'reflect')
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = self.activation(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        # x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(MultiDiscriminator):
    """Multi-Period (MPD) Discriminator.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
        **kwargs: Additional args for `PeriodDiscriminator`
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
        super().__init__()
        self.discriminators = nn.ModuleList([
            PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
        ])

    @property
    def num_discriminators(self):
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps
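A small usage sketch for the file above, with shapes following the defaults it defines:

```python
import torch
from audiocraft.adversarial import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()   # periods [2, 3, 5, 7, 11] by default
wav = torch.randn(4, 1, 16000)     # (batch, channels, time): 1 s at 16 kHz
logits, fmaps = mpd(wav)
assert len(logits) == mpd.num_discriminators == 5
# Each sub-discriminator reflect-pads to a multiple of its period, then folds
# the waveform into a (time // period, period) image before the 2D convs.
```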
    	
audiocraft/adversarial/discriminators/msd.py
ADDED

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
import torch
import torch.nn as nn

from ...modules import NormConv1d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType


class ScaleDiscriminator(nn.Module):
    """Waveform sub-discriminator.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
        filters (int): Number of initial filters for convolutions.
        max_filters (int): Maximum number of filters.
        downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
        inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
        groups (Sequence[int] or None): Groups for inner convolutions.
        strides (Sequence[int] or None): Strides for inner convolutions.
        paddings (Sequence[int] or None): Paddings for inner convolutions.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        pad (str): Padding for initial convolution.
        pad_params (dict): Parameters to provide to the padding module.
    """
    def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
                 filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
                 inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
                 strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
                 pad_params: dict = {}):
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1
        assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
        assert (groups is None or len(groups) == len(downsample_scales))
        assert (strides is None or len(strides) == len(downsample_scales))
        assert (paddings is None or len(paddings) == len(downsample_scales))
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()
        self.convs.append(
            nn.Sequential(
                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
                NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
            )
        )

        in_chs = filters
        for i, downsample_scale in enumerate(downsample_scales):
            out_chs = min(in_chs * downsample_scale, max_filters)
            default_kernel_size = downsample_scale * 10 + 1
            default_stride = downsample_scale
            default_padding = (default_kernel_size - 1) // 2
            default_groups = in_chs // 4
            self.convs.append(
                NormConv1d(in_chs, out_chs,
                           kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
                           stride=strides[i] if strides else default_stride,
                           groups=groups[i] if groups else default_groups,
                           padding=paddings[i] if paddings else default_padding,
                           norm=norm))
            in_chs = out_chs

        out_chs = min(in_chs * 2, max_filters)
        self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
                                     padding=(kernel_sizes[0] - 1) // 2, norm=norm))
        self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
                                    padding=(kernel_sizes[1] - 1) // 2, norm=norm)

    def forward(self, x: torch.Tensor):
        fmap = []
        for layer in self.convs:
            x = layer(x)
            x = self.activation(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        # x = torch.flatten(x, 1, -1)
        return x, fmap


class MultiScaleDiscriminator(MultiDiscriminator):
    """Multi-Scale (MSD) Discriminator.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        downsample_factor (int): Downsampling factor between the different scales.
        scale_norms (Sequence[str]): Normalization for each sub-discriminator.
        **kwargs: Additional args for ScaleDiscriminator.
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
                 scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
        super().__init__()
        self.discriminators = nn.ModuleList([
            ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
        ])
        self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)

    @property
    def num_discriminators(self):
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        logits = []
        fmaps = []
        for i, disc in enumerate(self.discriminators):
            if i != 0:
                # Fixed: the original called self.downsample(x) and discarded the
                # result, so every scale saw the raw input; assign so that scale i
                # really receives an i-times pooled signal.
                x = self.downsample(x)
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps
    	
audiocraft/adversarial/discriminators/msstftd.py
ADDED

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import torchaudio
import torch
from torch import nn
from einops import rearrange

from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType


def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)


class DiscriminatorSTFT(nn.Module):
    """STFT sub-discriminator.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_fft (int): Size of FFT for each scale.
        hop_length (int): Length of hop between STFT windows for each scale.
        kernel_size (tuple of int): Inner Conv2d kernel sizes.
        stride (tuple of int): Inner Conv2d strides.
        dilations (list of int): Inner Conv2d dilation on the time dimension.
        win_length (int): Window size for each scale.
        normalized (bool): Whether to normalize by magnitude after stft.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        growth (int): Growth factor for the filters.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        assert len(kernel_size) == 2
        assert len(stride) == 2
        self.filters = filters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
            normalized=self.normalized, center=False, pad_mode=None, power=None)
        spec_channels = 2 * self.in_channels
        self.convs = nn.ModuleList()
        self.convs.append(
            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
        )
        in_chs = min(filters_scale * self.filters, max_filters)
                    in_chs = min(filters_scale * self.filters, max_filters)
         | 
| 66 | 
            +
                    for i, dilation in enumerate(dilations):
         | 
| 67 | 
            +
                        out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
         | 
| 68 | 
            +
                        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
         | 
| 69 | 
            +
                                                     dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
         | 
| 70 | 
            +
                                                     norm=norm))
         | 
| 71 | 
            +
                        in_chs = out_chs
         | 
| 72 | 
            +
                    out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
         | 
| 73 | 
            +
                    self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
         | 
| 74 | 
            +
                                                 padding=get_2d_padding((kernel_size[0], kernel_size[0])),
         | 
| 75 | 
            +
                                                 norm=norm))
         | 
| 76 | 
            +
                    self.conv_post = NormConv2d(out_chs, self.out_channels,
         | 
| 77 | 
            +
                                                kernel_size=(kernel_size[0], kernel_size[0]),
         | 
| 78 | 
            +
                                                padding=get_2d_padding((kernel_size[0], kernel_size[0])),
         | 
| 79 | 
            +
                                                norm=norm)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                def forward(self, x: torch.Tensor):
         | 
| 82 | 
            +
                    fmap = []
         | 
| 83 | 
            +
                    z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
         | 
| 84 | 
            +
                    z = torch.cat([z.real, z.imag], dim=1)
         | 
| 85 | 
            +
                    z = rearrange(z, 'b c w t -> b c t w')
         | 
| 86 | 
            +
                    for i, layer in enumerate(self.convs):
         | 
| 87 | 
            +
                        z = layer(z)
         | 
| 88 | 
            +
                        z = self.activation(z)
         | 
| 89 | 
            +
                        fmap.append(z)
         | 
| 90 | 
            +
                    z = self.conv_post(z)
         | 
| 91 | 
            +
                    return z, fmap
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
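For intuition before the multi-scale wrapper below: each sub-discriminator turns the waveform into a complex STFT, stacks real and imaginary parts as channels, and scores overlapping time-frequency patches. A minimal shape sanity-check, not part of the diff (it assumes the audiocraft package is installed so the import resolves; the numbers are illustrative):

import torch
from audiocraft.adversarial.discriminators.msstftd import DiscriminatorSTFT

disc = DiscriminatorSTFT(filters=32)  # defaults: n_fft=1024, hop_length=256
wav = torch.randn(4, 1, 16000)        # [batch, channels, time], ~1 s at 16 kHz
logits, fmap = disc(wav)
# `logits` is a [B, 1, T', F'] map of per-patch scores rather than a single
# scalar; `fmap` holds one activation tensor per conv layer (5 with the
# defaults), which the feature matching loss in losses.py consumes.
print(logits.shape, len(fmap))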
+class MultiScaleSTFTDiscriminator(MultiDiscriminator):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        sep_channels (bool): Separate channels to distinct samples for stereo support.
+        n_ffts (Sequence[int]): Size of FFT for each scale.
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
+        win_lengths (Sequence[int]): Window size for each scale.
+        **kwargs: Additional args for DiscriminatorSTFT.
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
+                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.sep_channels = sep_channels
+        self.discriminators = nn.ModuleList([
+            DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+            for i in range(len(n_ffts))
+        ])
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
+        B, C, T = x.shape
+        return x.view(-1, 1, T)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
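A minimal usage sketch for the wrapper, not part of the diff (assumes audiocraft is importable): with the default three (n_fft, hop, win) triples, the forward pass returns one logits tensor and one feature-map list per STFT scale.

import torch
from audiocraft.adversarial.discriminators.msstftd import MultiScaleSTFTDiscriminator

msstftd = MultiScaleSTFTDiscriminator(filters=32)  # 3 scales by default
wav = torch.randn(2, 1, 32000)
logits, fmaps = msstftd(wav)
assert len(logits) == msstftd.num_discriminators == 3
assert len(fmaps) == 3  # one list of conv activations per scale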
    	
        audiocraft/adversarial/losses.py
    ADDED
    
@@ -0,0 +1,228 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility module to handle adversarial losses without cluttering up the main training loop.
+"""
+
+import typing as tp
+
+import flashy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
+
+
+AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
+FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
+
+
+class AdversarialLoss(nn.Module):
+    """Adversary training wrapper.
+
+    Args:
+        adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
+            We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
+            where the first item is a list of logits and the second item is a list of feature maps.
+        optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
+        loss (AdvLossType): Loss function for generator training.
+        loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
+        loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
+        loss_feat (FeatLossType): Feature matching loss function for generator training.
+        normalize (bool): Whether to normalize by number of sub-discriminators.
+
+    Example of usage:
+        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
+        for real in loader:
+            noise = torch.randn(...)
+            fake = model(noise)
+            adv_loss.train_adv(fake, real)
+            loss, _ = adv_loss(fake, real)
+            loss.backward()
+    """
+    def __init__(self,
+                 adversary: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 loss: AdvLossType,
+                 loss_real: AdvLossType,
+                 loss_fake: AdvLossType,
+                 loss_feat: tp.Optional[FeatLossType] = None,
+                 normalize: bool = True):
+        super().__init__()
+        self.adversary: nn.Module = adversary
+        flashy.distrib.broadcast_model(self.adversary)
+        self.optimizer = optimizer
+        self.loss = loss
+        self.loss_real = loss_real
+        self.loss_fake = loss_fake
+        self.loss_feat = loss_feat
+        self.normalize = normalize
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        # Add the optimizer state dict inside our own.
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
+        return destination
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # Load optimizer state.
+        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def get_adversary_pred(self, x):
+        """Run adversary model, validating expected output format."""
+        logits, fmaps = self.adversary(x)
+        assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
+            f'Expecting a list of tensors as logits but {type(logits)} found.'
+        assert isinstance(fmaps, list), f'Expecting a list of feature maps but {type(fmaps)} found.'
+        for fmap in fmaps:
+            assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
+                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
+        return logits, fmaps
+
+    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
+        """Train the adversary with the given fake and real example.
+
+        We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
+        The first item being the logits and second item being a list of feature maps for each sub-discriminator.
+
+        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
+        and call the optimizer.
+        """
+        loss = torch.tensor(0., device=fake.device)
+        all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
+        all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
+        n_sub_adversaries = len(all_logits_fake_is_fake)
+        for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
+            loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
+
+        if self.normalize:
+            loss /= n_sub_adversaries
+
+        self.optimizer.zero_grad()
+        with flashy.distrib.eager_sync_model(self.adversary):
+            loss.backward()
+        self.optimizer.step()
+
+        return loss
+
+    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Return the loss for the generator, i.e. trying to fool the adversary,
+        and feature matching loss if provided.
+        """
+        adv = torch.tensor(0., device=fake.device)
+        feat = torch.tensor(0., device=fake.device)
+        with flashy.utils.readonly(self.adversary):
+            all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
+            all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
+            n_sub_adversaries = len(all_logits_fake_is_fake)
+            for logit_fake_is_fake in all_logits_fake_is_fake:
+                adv += self.loss(logit_fake_is_fake)
+            if self.loss_feat:
+                for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
+                    feat += self.loss_feat(fmap_fake, fmap_real)
+
+        if self.normalize:
+            adv /= n_sub_adversaries
+            feat /= n_sub_adversaries
+
+        return adv, feat
+
+
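Expanding the docstring example above into a fuller sketch (not part of the diff; `gen` is a stand-in generator, a single-process run is assumed so that flashy's distributed helpers are no-ops, and the hinge criteria come from the helpers defined just below):

import torch
from audiocraft.adversarial.discriminators.msstftd import MultiScaleSTFTDiscriminator
from audiocraft.adversarial.losses import (
    AdversarialLoss, FeatureMatchingLoss,
    get_adv_criterion, get_fake_criterion, get_real_criterion)

adversary = MultiScaleSTFTDiscriminator(filters=32)
adv_loss = AdversarialLoss(
    adversary,
    torch.optim.Adam(adversary.parameters(), lr=3e-4),
    loss=get_adv_criterion('hinge'),
    loss_real=get_real_criterion('hinge'),
    loss_fake=get_fake_criterion('hinge'),
    loss_feat=FeatureMatchingLoss())

gen = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1)  # stand-in generator
real = torch.randn(2, 1, 32000)
fake = gen(real)
d_loss = adv_loss.train_adv(fake, real)  # adversary step: zero_grad, backward, optimizer.step
g_adv, g_feat = adv_loss(fake, real)     # generator-side losses, adversary kept read-only
(g_adv + g_feat).backward()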
+def get_adv_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_loss
+    elif loss_type == 'hinge':
+        return hinge_loss
+    elif loss_type == 'hinge2':
+        return hinge2_loss
+    raise ValueError('Unsupported loss')
+
+
+def get_fake_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_fake_loss
+    elif loss_type in ['hinge', 'hinge2']:
+        return hinge_fake_loss
+    raise ValueError('Unsupported loss')
+
+
+def get_real_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_real_loss
+    elif loss_type in ['hinge', 'hinge2']:
+        return hinge_real_loss
+    raise ValueError('Unsupported loss')
+
+
+def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
+    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
+    return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
+
+
+def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
+    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
+    return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+def mse_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+def hinge_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return -x.mean()
+
+
+def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
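A quick numeric check of the hinge criteria above (not part of the diff): for the adversary, scores on real audio are pushed above 1 and scores on fake audio below -1, while the generator criterion simply maximizes the fake scores.

import torch
from audiocraft.adversarial.losses import hinge_real_loss, hinge_fake_loss, hinge_loss

real_logits = torch.tensor([2.0, 0.5])   # adversary scores on real audio
fake_logits = torch.tensor([-2.0, 0.5])  # adversary scores on fake audio
print(hinge_real_loss(real_logits))  # -mean(min(x - 1, 0))  -> tensor(0.2500)
print(hinge_fake_loss(fake_logits))  # -mean(min(-x - 1, 0)) -> tensor(0.7500)
print(hinge_loss(fake_logits))       # generator side: -mean(x) -> tensor(0.7500)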
+class FeatureMatchingLoss(nn.Module):
+    """Feature matching loss for adversarial training.
+
+    Args:
+        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1Loss).
+        normalize (bool): Whether to normalize the loss
+            by the number of feature maps.
+    """
+    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
+        super().__init__()
+        self.loss = loss
+        self.normalize = normalize
+
+    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
+        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
+        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
+        feat_scale = torch.tensor(0., device=fmap_fake[0].device)
+        n_fmaps = 0
+        for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
+            assert feat_fake.shape == feat_real.shape
+            n_fmaps += 1
+            feat_loss += self.loss(feat_fake, feat_real)
+            feat_scale += torch.mean(torch.abs(feat_real))
+
+        if self.normalize:
+            feat_loss /= n_fmaps
+
+        return feat_loss
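A minimal usage sketch (not part of the diff): the loss is an L1 distance between corresponding discriminator activations, averaged over the feature maps when normalize=True. Note that feat_scale is accumulated in the loop above but does not affect the returned value.

import torch
from audiocraft.adversarial.losses import FeatureMatchingLoss

fm = FeatureMatchingLoss()
fmap_real = [torch.randn(2, 32, 59, 256), torch.randn(2, 32, 30, 128)]
fmap_fake = [t + 0.01 * torch.randn_like(t) for t in fmap_real]
loss = fm(fmap_fake, fmap_real)  # small, since the fake maps nearly match the real ones
print(loss)  # scalar tensor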
    	
        audiocraft/data/__init__.py
    CHANGED
    
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+"""Audio loading and writing support. Datasets for raw audio
+or also including some metadata."""
 
 # flake8: noqa
-from . import audio, audio_dataset
+from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
    	
        audiocraft/data/audio.py
    CHANGED
    
@@ -18,11 +18,11 @@ import numpy as np
 import soundfile
 import torch
 from torch.nn import functional as F
-import torchaudio as ta
 
 import av
+import subprocess as sp
 
-from .audio_utils import f32_pcm, i16_pcm, normalize_audio
+from .audio_utils import f32_pcm, normalize_audio
 
 
 _av_initialized = False
@@ -78,7 +78,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float
         seek_time (float): Time at which to start reading in the file.
         duration (float): Duration to read from the file. If set to -1, the whole file is read.
     Returns:
-        …
+        tuple of torch.Tensor, int: Tuple containing audio data and sample rate
     """
     _init_av()
     with av.open(str(filepath)) as af:
@@ -123,7 +123,7 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
         duration (float): Duration to read from the file. If set to -1, the whole file is read.
         pad (bool): Pad output audio if not reaching expected duration.
     Returns:
-        …
+        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
     """
     fp = Path(filepath)
     if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
@@ -136,12 +136,6 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
         wav = torch.from_numpy(wav).t().contiguous()
         if len(wav.shape) == 1:
             wav = torch.unsqueeze(wav, 0)
-    elif (
-        fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
-        and duration <= 0 and seek_time == 0
-    ):
-        # Torchaudio is faster if we load an entire file at once.
-        wav, sr = ta.load(fp)
     else:
         wav, sr = _av_read(filepath, seek_time, duration)
     if pad and duration > 0:
@@ -150,10 +144,22 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
     return wav, sr
 
 
+def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
+    # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
+    assert wav.dim() == 2, wav.shape
+    command = [
+        'ffmpeg',
+        '-loglevel', 'error',
+        '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
+        '-i', '-'] + flags + [str(out_path)]
+    input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
+    sp.run(command, input=input_, check=True)
+
+
 def audio_write(stem_name: tp.Union[str, Path],
                 wav: torch.Tensor, sample_rate: int,
-                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
-                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
+                normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                 rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                 loudness_compressor: bool = False,
                 log_clipping: bool = True, make_parent_dir: bool = True,
@@ -162,8 +168,11 @@ def audio_write(stem_name: tp.Union[str, Path],
 
     Args:
         stem_name (str or Path): Filename without extension which will be added automatically.
-        …
+        wav (torch.Tensor): Audio data to save.
+        sample_rate (int): Sample rate of audio data.
+        format (str): Either "wav", "mp3", "ogg", or "flac".
         mp3_rate (int): kbps when using mp3s.
+        ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
         normalize (bool): if `True` (default), normalizes according to the prescribed
             strategy (see after). If `False`, the strategy is only used in case clipping
             would happen.
@@ -175,7 +184,7 @@ def audio_write(stem_name: tp.Union[str, Path],
             than the `peak_clip` one to avoid further clipping.
         loudness_headroom_db (float): Target loudness for loudness normalization.
         loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
-         when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still
+        log_clipping (bool): If True, basic logging on stderr when clipping still
             occurs despite strategy (only for 'rms').
         make_parent_dir (bool): Make parent directory if it doesn't exist.
     Returns:
@@ -188,16 +197,23 @@ def audio_write(stem_name: tp.Union[str, Path],
         raise ValueError("Input wav should be at most 2 dimension.")
     assert wav.isfinite().all()
     wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
-                          rms_headroom_db, loudness_headroom_db, …
-                          sample_rate=sample_rate, …
-                          …
+                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
+                          log_clipping=log_clipping, sample_rate=sample_rate,
+                          stem_name=str(stem_name))
     if format == 'mp3':
         suffix = '.mp3'
-        …
+        flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
     elif format == 'wav':
-        wav = i16_pcm(wav)
         suffix = '.wav'
+        flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
-        …
+    elif format == 'ogg':
+        suffix = '.ogg'
+        flags = ['-f', 'ogg', '-c:a', 'libvorbis']
+        if ogg_rate is not None:
+            flags += ['-b:a', f'{ogg_rate}k']
+    elif format == 'flac':
+        suffix = '.flac'
+        flags = ['-f', 'flac']
     else:
         raise RuntimeError(f"Invalid format {format}. Only wav, mp3, ogg and flac are supported.")
     if not add_suffix:
@@ -206,7 +222,7 @@ def audio_write(stem_name: tp.Union[str, Path],
     if make_parent_dir:
         path.parent.mkdir(exist_ok=True, parents=True)
     try:
-        …
+        _piping_to_ffmpeg(path, wav, sample_rate, flags)
     except Exception:
         if path.exists():
             # we do not want to leave half written files around.
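Putting the new write path together (a minimal sketch, not part of the diff): audio_write normalizes, then pipes raw float32 PCM into an ffmpeg subprocess, so the only external requirement is an ffmpeg binary on the PATH. The filenames below are illustrative.

import torch
from audiocraft.data.audio import audio_write

sr = 32000
t = torch.arange(sr) / sr
wav = torch.sin(2 * torch.pi * 440 * t).unsqueeze(0)      # 1 s mono A440 test tone
audio_write('demo', wav, sr, format='ogg', ogg_rate=128)  # demo.ogg via libvorbis
audio_write('demo', wav, sr, format='flac')               # demo.flac
audio_write('demo', wav, sr, format='mp3', mp3_rate=320)  # demo.mp3 via libmp3lame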
    	
        audiocraft/data/audio_dataset.py
    CHANGED
    
@@ -3,12 +3,16 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-…
 import argparse
 import copy
 from concurrent.futures import ThreadPoolExecutor, Future
 from dataclasses import dataclass, fields
 from contextlib import ExitStack
 import gzip
 import json
 import logging
@@ -81,9 +85,12 @@ class AudioMeta(BaseInfo):
 class SegmentInfo(BaseInfo):
     meta: AudioMeta
     seek_time: float
-    …
     total_frames: int  # total number of frames, padding included
-    sample_rate: int
 
 
 DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
@@ -114,8 +121,8 @@ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
 
     Args:
         m (AudioMeta): Audio meta to resolve.
-        fast (bool): If True, uses a really fast check for determining if a file…
-            Only valid on Linux/Mac.
     Returns:
         AudioMeta: Audio meta with resolved path.
     """
@@ -151,7 +158,7 @@ def find_audio_files(path: tp.Union[Path, str],
         progress (bool): Whether to log progress on audio files collection.
         workers (int): number of parallel workers, if 0, use only the current thread.
     Returns:
-        …
     """
     audio_files = []
     futures: tp.List[Future] = []
@@ -203,7 +210,7 @@ def load_audio_meta(path: tp.Union[str, Path],
         resolve (bool): Whether to resolve the path from AudioMeta (default=True).
         fast (bool): activates some tricks to make things faster.
     Returns:
-        …
     """
     open_fn = gzip.open if str(path).lower().endswith('.gz') else open
     with open_fn(path, 'rb') as fp:  # type: ignore
@@ -250,9 +257,14 @@ class AudioDataset:
     allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
     original audio meta.
 
     Args:
-        meta (…
-        segment_duration (float): Optional segment duration of audio to load.
             If not specified, the dataset will load the full audio segment from the file.
         shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
         sample_rate (int): Target sample rate of the loaded audio samples.
@@ -266,10 +278,19 @@ class AudioDataset:
             is shorter than the desired segment.
         max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
         return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
-        min_audio_duration (…
             audio shorter than this will be filtered out.
-        max_audio_duration (…
             audio longer than this will be filtered out.
     """
     def __init__(self,
                  meta: tp.List[AudioMeta],
@@ -285,16 +306,14 @@ class AudioDataset:
                  max_read_retry: int = 10,
                  return_info: bool = False,
                  min_audio_duration: tp.Optional[float] = None,
-                 max_audio_duration: tp.Optional[float] = None
                  ):
-        assert len(meta) > 0, …
         assert segment_duration is None or segment_duration > 0
         assert segment_duration is None or min_segment_ratio >= 0
-        logging.debug(f'sample_on_duration: {sample_on_duration}')
-        logging.debug(f'sample_on_weight: {sample_on_weight}')
-        logging.debug(f'pad: {pad}')
-        logging.debug(f'min_segment_ratio: {min_segment_ratio}')
-        …
         self.segment_duration = segment_duration
         self.min_segment_ratio = min_segment_ratio
         self.max_audio_duration = max_audio_duration
@@ -317,13 +336,25 @@ class AudioDataset:
         self.sampling_probabilities = self._get_sampling_probabilities()
         self.max_read_retry = max_read_retry
         self.return_info = return_info
 
     def __len__(self):
         return self.num_samples
 
     def _get_sampling_probabilities(self, normalized: bool = True):
-        """Return the sampling probabilities for each file inside `self.meta`.
-        """
         scores: tp.List[float] = []
         for file_meta in self.meta:
             score = 1.
| @@ -337,12 +368,32 @@ class AudioDataset: | |
| 337 | 
             
                        probabilities /= probabilities.sum()
         | 
| 338 | 
             
                    return probabilities
         | 
| 339 |  | 
| 340 | 
            -
                 | 
| 341 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 342 | 
             
                    This is only called if `segment_duration` is not None.
         | 
| 343 |  | 
| 344 | 
             
                    You must use the provided random number generator `rng` for reproducibility.
         | 
|  | |
| 345 | 
             
                    """
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 346 | 
             
                    if not self.sample_on_weight and not self.sample_on_duration:
         | 
| 347 | 
             
                        file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
         | 
| 348 | 
             
                    else:
         | 
| @@ -350,6 +401,15 @@ class AudioDataset: | |
| 350 |  | 
| 351 | 
             
                    return self.meta[file_index]
         | 
| 352 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 353 | 
             
                def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
         | 
| 354 | 
             
                    if self.segment_duration is None:
         | 
| 355 | 
             
                        file_meta = self.meta[index]
         | 
| @@ -357,18 +417,22 @@ class AudioDataset: | |
| 357 | 
             
                        out = convert_audio(out, sr, self.sample_rate, self.channels)
         | 
| 358 | 
             
                        n_frames = out.shape[-1]
         | 
| 359 | 
             
                        segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
         | 
| 360 | 
            -
                                                   sample_rate=self.sample_rate)
         | 
| 361 | 
             
                    else:
         | 
| 362 | 
             
                        rng = torch.Generator()
         | 
| 363 | 
             
                        if self.shuffle:
         | 
| 364 | 
            -
                            # We use index, plus extra randomness
         | 
| 365 | 
            -
                             | 
|  | |
|  | |
|  | |
|  | |
| 366 | 
             
                        else:
         | 
| 367 | 
             
                            # We only use index
         | 
| 368 | 
             
                            rng.manual_seed(index)
         | 
| 369 |  | 
| 370 | 
             
                        for retry in range(self.max_read_retry):
         | 
| 371 | 
            -
                            file_meta = self.sample_file(rng)
         | 
| 372 | 
             
                            # We add some variance in the file position even if audio file is smaller than segment
         | 
| 373 | 
             
                            # without ending up with empty segments
         | 
| 374 | 
             
                            max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
         | 
| @@ -381,7 +445,7 @@ class AudioDataset: | |
| 381 | 
             
                                if self.pad:
         | 
| 382 | 
             
                                    out = F.pad(out, (0, target_frames - n_frames))
         | 
| 383 | 
             
                                segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
         | 
| 384 | 
            -
                                                           sample_rate=self.sample_rate)
         | 
| 385 | 
             
                            except Exception as exc:
         | 
| 386 | 
             
                                logger.warning("Error opening file %s: %r", file_meta.path, exc)
         | 
| 387 | 
             
                                if retry == self.max_read_retry - 1:
         | 
| @@ -423,7 +487,7 @@ class AudioDataset: | |
| 423 | 
             
                        if to_pad:
         | 
| 424 | 
             
                            # Each wav could be of a different duration as they are not segmented.
         | 
| 425 | 
             
                            for i in range(len(samples)):
         | 
| 426 | 
            -
                                # Determines the total  | 
| 427 | 
             
                                segment_infos[i].total_frames = max_len
         | 
| 428 | 
             
                                wavs[i] = _pad_wav(wavs[i])
         | 
| 429 |  | 
| @@ -436,9 +500,7 @@ class AudioDataset: | |
| 436 | 
             
                        return torch.stack(samples)
         | 
| 437 |  | 
| 438 | 
             
                def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
         | 
| 439 | 
            -
                    """Filters out audio files with  | 
| 440 | 
            -
                    Removes from meta files that have durations that will not allow to samples examples from them.
         | 
| 441 | 
            -
                    """
         | 
| 442 | 
             
                    orig_len = len(meta)
         | 
| 443 |  | 
| 444 | 
             
                    # Filter data that is too short.
         | 
|  | |
| 3 | 
             
            #
         | 
| 4 | 
             
            # This source code is licensed under the license found in the
         | 
| 5 | 
             
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
            """AudioDataset support. In order to handle a larger number of files
         | 
| 7 | 
            +
            without having to scan again the folders, we precompute some metadata
         | 
| 8 | 
            +
            (filename, sample rate, duration), and use that to efficiently sample audio segments.
         | 
| 9 | 
            +
            """
         | 
| 10 | 
             
            import argparse
         | 
| 11 | 
             
            import copy
         | 
| 12 | 
             
            from concurrent.futures import ThreadPoolExecutor, Future
         | 
| 13 | 
             
            from dataclasses import dataclass, fields
         | 
| 14 | 
             
            from contextlib import ExitStack
         | 
| 15 | 
            +
            from functools import lru_cache
         | 
| 16 | 
             
            import gzip
         | 
| 17 | 
             
            import json
         | 
| 18 | 
             
            import logging
         | 
|  | |
| 85 | 
             
            class SegmentInfo(BaseInfo):
         | 
| 86 | 
             
                meta: AudioMeta
         | 
| 87 | 
             
                seek_time: float
         | 
| 88 | 
            +
                # The following values are given once the audio is processed, e.g.
         | 
| 89 | 
            +
                # at the target sample rate and target number of channels.
         | 
| 90 | 
            +
                n_frames: int      # actual number of frames without padding
         | 
| 91 | 
             
                total_frames: int  # total number of frames, padding included
         | 
| 92 | 
            +
                sample_rate: int   # actual sample rate
         | 
| 93 | 
            +
                channels: int      # number of audio channels.
         | 
| 94 |  | 
| 95 |  | 
| 96 | 
             
            DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
         | 
|  | |
| 121 |  | 
| 122 | 
             
                Args:
         | 
| 123 | 
             
                    m (AudioMeta): Audio meta to resolve.
         | 
| 124 | 
            +
                    fast (bool): If True, uses a really fast check for determining if a file
         | 
| 125 | 
            +
                        is already absolute or not. Only valid on Linux/Mac.
         | 
| 126 | 
             
                Returns:
         | 
| 127 | 
             
                    AudioMeta: Audio meta with resolved path.
         | 
| 128 | 
             
                """
         | 
|  | |
| 158 | 
             
                    progress (bool): Whether to log progress on audio files collection.
         | 
| 159 | 
             
                    workers (int): number of parallel workers, if 0, use only the current thread.
         | 
| 160 | 
             
                Returns:
         | 
| 161 | 
            +
                    list of AudioMeta: List of audio file path and its metadata.
         | 
| 162 | 
             
                """
         | 
| 163 | 
             
                audio_files = []
         | 
| 164 | 
             
                futures: tp.List[Future] = []
         | 
|  | |
| 210 | 
             
                    resolve (bool): Whether to resolve the path from AudioMeta (default=True).
         | 
| 211 | 
             
                    fast (bool): activates some tricks to make things faster.
         | 
| 212 | 
             
                Returns:
         | 
| 213 | 
            +
                    list of AudioMeta: List of audio file path and its total duration.
         | 
| 214 | 
             
                """
         | 
| 215 | 
             
                open_fn = gzip.open if str(path).lower().endswith('.gz') else open
         | 
| 216 | 
             
                with open_fn(path, 'rb') as fp:  # type: ignore
         | 
|  | |
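
The new module docstring spells out the intended workflow: scan a folder once, cache the per-file metadata, then build datasets from the cache. A minimal sketch of that flow, assuming the `find_audio_files`, `save_audio_meta` and `load_audio_meta` helpers this module documents above (the dataset paths are made up):

    import audiocraft.data.audio_dataset as ad

    # One-time scan: collect path, duration and sample rate for every audio file.
    meta = ad.find_audio_files('dataset/train', resolve=True, progress=True, workers=8)
    ad.save_audio_meta('dataset/train.jsonl.gz', meta)  # gzip inferred from the extension

    # Later runs load the cache instead of re-scanning the folders.
    meta = ad.load_audio_meta('dataset/train.jsonl.gz')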
@@ … @@
     allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
     original audio meta.

+    Note that you can call `start_epoch(epoch)` in order to get
+    a deterministic "randomization" for `shuffle=True`.
+    For a given epoch and dataset index, this will always return the same extract.
+    You can get back some diversity by setting the `shuffle_seed` param.
+
     Args:
+        meta (list of AudioMeta): List of audio files metadata.
+        segment_duration (float, optional): Optional segment duration of audio to load.
             If not specified, the dataset will load the full audio segment from the file.
         shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
         sample_rate (int): Target sample rate of the loaded audio samples.
@@ -266,10 +278,19 @@ class AudioDataset:
             is shorter than the desired segment.
         max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
         return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
-        min_audio_duration (…
+        min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
             audio shorter than this will be filtered out.
-        max_audio_duration (…
+        max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
             audio longer than this will be filtered out.
+        shuffle_seed (int): can be used to further randomize
+        load_wav (bool): if False, skip loading the wav but returns a tensor of 0
+            with the expected segment_duration (which must be provided if load_wav is False).
+        permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
+            are False. Will ensure a permutation on files when going through the dataset.
+            In that case the epoch number must be provided in order for the model
+            to continue the permutation across epochs. In that case, it is assumed
+            that `num_samples = total_batch_size * num_updates_per_epoch`, with
+            `total_batch_size` the overall batch size accounting for all gpus.
     """
     def __init__(self,
                  meta: tp.List[AudioMeta],
@@ -285,16 +306,14 @@ class AudioDataset:
                  max_read_retry: int = 10,
                  return_info: bool = False,
                  min_audio_duration: tp.Optional[float] = None,
-                 max_audio_duration: tp.Optional[float] = None
+                 max_audio_duration: tp.Optional[float] = None,
+                 shuffle_seed: int = 0,
+                 load_wav: bool = True,
+                 permutation_on_files: bool = False,
                  ):
-        assert len(meta) > 0, …
+        assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
         assert segment_duration is None or segment_duration > 0
         assert segment_duration is None or min_segment_ratio >= 0
-        logging.debug(f'sample_on_duration: {sample_on_duration}')
-        logging.debug(f'sample_on_weight: {sample_on_weight}')
-        logging.debug(f'pad: {pad}')
-        logging.debug(f'min_segment_ratio: {min_segment_ratio}')
-
         self.segment_duration = segment_duration
         self.min_segment_ratio = min_segment_ratio
         self.max_audio_duration = max_audio_duration
@@ -317,13 +336,25 @@ class AudioDataset:
         self.sampling_probabilities = self._get_sampling_probabilities()
         self.max_read_retry = max_read_retry
         self.return_info = return_info
+        self.shuffle_seed = shuffle_seed
+        self.current_epoch: tp.Optional[int] = None
+        self.load_wav = load_wav
+        if not load_wav:
+            assert segment_duration is not None
+        self.permutation_on_files = permutation_on_files
+        if permutation_on_files:
+            assert not self.sample_on_duration
+            assert not self.sample_on_weight
+            assert self.shuffle
+
+    def start_epoch(self, epoch: int):
+        self.current_epoch = epoch

     def __len__(self):
         return self.num_samples

     def _get_sampling_probabilities(self, normalized: bool = True):
-        """Return the sampling probabilities for each file inside `self.meta`.
-        """
+        """Return the sampling probabilities for each file inside `self.meta`."""
         scores: tp.List[float] = []
         for file_meta in self.meta:
             score = 1.
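
With the new `shuffle_seed` and `start_epoch` machinery, the segment served for a given (epoch, index) pair becomes reproducible. A usage sketch, assuming metadata cached as above (constructor arguments not visible in this hunk are taken from the signature earlier in the file):

    import torch
    from audiocraft.data.audio_dataset import AudioDataset, load_audio_meta

    meta = load_audio_meta('dataset/train.jsonl.gz')
    dataset = AudioDataset(meta, segment_duration=10., sample_rate=32_000,
                           channels=1, shuffle=True, shuffle_seed=42)

    dataset.start_epoch(0)
    first = dataset[3]
    again = dataset[3]                # same epoch and index
    assert torch.equal(first, again)  # deterministic "randomization"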
@@ -337,12 +368,32 @@ class AudioDataset:
             probabilities /= probabilities.sum()
         return probabilities

-    …
-    …
+    @staticmethod
+    @lru_cache(16)
+    def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
+        # Used to keep the most recent files permutation in memory implicitely.
+        # will work unless someone is using a lot of Datasets in parallel.
+        rng = torch.Generator()
+        rng.manual_seed(base_seed + permutation_index)
+        return torch.randperm(num_files, generator=rng)
+
+    def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
+        """Sample a given file from `self.meta`. Can be overridden in subclasses.
         This is only called if `segment_duration` is not None.

         You must use the provided random number generator `rng` for reproducibility.
+        You can further make use of the index accessed.
         """
+        if self.permutation_on_files:
+            assert self.current_epoch is not None
+            total_index = self.current_epoch * len(self) + index
+            permutation_index = total_index // len(self.meta)
+            relative_index = total_index % len(self.meta)
+            permutation = AudioDataset._get_file_permutation(
+                len(self.meta), permutation_index, self.shuffle_seed)
+            file_index = permutation[relative_index]
+            return self.meta[file_index]
+
         if not self.sample_on_weight and not self.sample_on_duration:
             file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
         else:
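
The permutation logic above guarantees that, with `permutation_on_files=True`, each file is visited exactly once per full pass, and the ordering resumes across epochs. A standalone restatement of that arithmetic (the helper mirrors `_get_file_permutation`; the sizes are made up):

    import torch
    from functools import lru_cache

    @lru_cache(16)
    def file_permutation(num_files: int, permutation_index: int, base_seed: int):
        rng = torch.Generator()
        rng.manual_seed(base_seed + permutation_index)
        return torch.randperm(num_files, generator=rng)

    num_files, num_samples, seed = 10, 4, 0
    seen = []
    for epoch in range(5):              # 5 epochs of 4 samples = 2 passes over 10 files
        for index in range(num_samples):
            total_index = epoch * num_samples + index
            perm = file_permutation(num_files, total_index // num_files, seed)
            seen.append(int(perm[total_index % num_files]))
    assert sorted(seen[:10]) == list(range(10))  # the first pass is a permutation

Because the permutation is derived from `(num_files, permutation_index, base_seed)` alone, every dataloader worker reconstructs the same ordering without shared state; `lru_cache(16)` merely avoids recomputing the most recent ones.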
@@ -350,6 +401,15 @@ class AudioDataset:

         return self.meta[file_index]

+    def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
+        # Override this method in subclass if needed.
+        if self.load_wav:
+            return audio_read(path, seek_time, duration, pad=False)
+        else:
+            assert self.segment_duration is not None
+            n_frames = int(self.sample_rate * self.segment_duration)
+            return torch.zeros(self.channels, n_frames), self.sample_rate
+
     def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
         if self.segment_duration is None:
             file_meta = self.meta[index]
@@ -357,18 +417,22 @@ class AudioDataset:
             out = convert_audio(out, sr, self.sample_rate, self.channels)
             n_frames = out.shape[-1]
             segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
-                                       sample_rate=self.sample_rate)
+                                       sample_rate=self.sample_rate, channels=out.shape[0])
         else:
             rng = torch.Generator()
             if self.shuffle:
-                # We use index, plus extra randomness
-                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+                # We use index, plus extra randomness, either totally random if we don't know the epoch.
+                # otherwise we make use of the epoch number and optional shuffle_seed.
+                if self.current_epoch is None:
+                    rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+                else:
+                    rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
             else:
                 # We only use index
                 rng.manual_seed(index)

             for retry in range(self.max_read_retry):
-                file_meta = self.sample_file(rng)
+                file_meta = self.sample_file(index, rng)
                 # We add some variance in the file position even if audio file is smaller than segment
                 # without ending up with empty segments
                 max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
@@ -381,7 +445,7 @@ class AudioDataset:
                     if self.pad:
                         out = F.pad(out, (0, target_frames - n_frames))
                     segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
-                                               sample_rate=self.sample_rate)
+                                               sample_rate=self.sample_rate, channels=out.shape[0])
                 except Exception as exc:
                     logger.warning("Error opening file %s: %r", file_meta.path, exc)
                     if retry == self.max_read_retry - 1:
@@ -423,7 +487,7 @@ class AudioDataset:
                 if to_pad:
                     # Each wav could be of a different duration as they are not segmented.
                     for i in range(len(samples)):
-                        # Determines the total …
+                        # Determines the total length of the signal with padding, so we update here as we pad.
                        segment_infos[i].total_frames = max_len
                        wavs[i] = _pad_wav(wavs[i])

@@ -436,9 +500,7 @@ class AudioDataset:
         return torch.stack(samples)

     def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
-        """Filters out audio files with …
-        Removes from meta files that have durations that will not allow to samples examples from them.
-        """
+        """Filters out audio files with audio durations that will not allow to sample examples from them."""
         orig_len = len(meta)

         # Filter data that is too short.
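
`load_wav=False` turns the dataset into a metadata-only source: `_audio_read` then returns zeros of the expected shape, which is useful when training consumes cached audio tokens rather than waveforms. A sketch under the same assumed setup (the shape follows from `torch.zeros(self.channels, n_frames)` above):

    dataset = AudioDataset(meta, segment_duration=10., sample_rate=32_000, channels=2,
                           return_info=True, load_wav=False)  # segment_duration is mandatory here
    wav, info = dataset[0]
    assert wav.shape == (2, 32_000 * 10)
    assert not wav.any()  # placeholder silence; `info` still carries the real segment metadata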
    	
        audiocraft/data/audio_utils.py
    CHANGED
    
@@ -3,7 +3,8 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
+"""Various utilities for audio convertion (pcm format, sample rate and channels),
+and volume normalization."""
 import sys
 import typing as tp

@@ -47,8 +48,7 @@ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor

 def convert_audio(wav: torch.Tensor, from_rate: float,
                   to_rate: float, to_channels: int) -> torch.Tensor:
-    """Convert audio to new sample rate and number of audio channels.
-    """
+    """Convert audio to new sample rate and number of audio channels."""
     wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
     wav = convert_audio_channels(wav, to_channels)
     return wav
@@ -66,7 +66,7 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db
         loudness_compressor (bool): Uses tanh for soft clipping.
         energy_floor (float): anything below that RMS level will not be rescaled.
     Returns:
-        …
+        torch.Tensor: Loudness normalized output data.
     """
     energy = wav.pow(2).mean().sqrt().item()
     if energy < energy_floor:
@@ -117,7 +117,7 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
         log_clipping (bool): If True, basic logging on stderr when clipping still
             occurs despite strategy (only for 'rms').
         sample_rate (int): Sample rate for the audio data (required for loudness).
-        stem_name (…
+        stem_name (str, optional): Stem name for clipping logging.
     Returns:
         torch.Tensor: Normalized audio.
     """
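
For reference, the loudness helper is driven like this (a sketch; the keyword names come from the signature and docstring in the hunk above):

    import torch
    from audiocraft.data.audio_utils import normalize_loudness

    wav = 0.01 * torch.randn(2, 5 * 32_000)  # quiet 5 s stereo clip at 32 kHz
    out = normalize_loudness(wav, sample_rate=32_000, loudness_headroom_db=14,
                             loudness_compressor=False)
    assert out.shape == wav.shape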
@@ -150,17 +150,19 @@ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
     """
     if wav.dtype.is_floating_point:
         return wav
-    else:
-        assert wav.dtype == torch.int16
+    elif wav.dtype == torch.int16:
         return wav.float() / 2**15
+    elif wav.dtype == torch.int32:
+        return wav.float() / 2**31
+    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")


 def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
     """Convert audio to int 16 bits PCM format.

-    ..Warning:: There exist many formula for doing this …
-    due to the …
-    or …
+    ..Warning:: There exist many formula for doing this conversion. None are perfect
+    due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
+    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
     it is possible that `i16_pcm(f32_pcm)) != Identity`.
     """
     if wav.dtype.is_floating_point:
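
A quick round trip through the two PCM helpers, illustrating both the new int16 path and the warning about non-identity (a sketch):

    import torch
    from audiocraft.data.audio_utils import f32_pcm, i16_pcm

    ints = torch.randint(-2**15, 2**15, (1, 32_000), dtype=torch.int16)
    floats = f32_pcm(ints)   # int16 -> float32, scaled by 2**15
    assert floats.dtype == torch.float32 and floats.abs().max() <= 1.
    back = i16_pcm(floats)   # not guaranteed to be the identity at the extremes (see warning)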
    	
        audiocraft/data/info_audio_dataset.py
    ADDED
    
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Base classes for the datasets that also provide non-audio metadata,
+e.g. description, text transcription etc.
+"""
+from dataclasses import dataclass
+import logging
+import math
+import re
+import typing as tp
+
+import torch
+
+from .audio_dataset import AudioDataset, AudioMeta
+from ..environment import AudioCraftEnvironment
+from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
+
+
+logger = logging.getLogger(__name__)
+
+
+def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
+    """Monkey-patch meta to match cluster specificities."""
+    meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
+    if meta.info_path is not None:
+        meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
+    return meta
+
+
+def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+    """Monkey-patch all meta to match cluster specificities."""
+    return [_clusterify_meta(m) for m in meta]
+
+
+@dataclass
+class AudioInfo(SegmentWithAttributes):
+    """Dummy SegmentInfo with empty attributes.
+
+    The InfoAudioDataset is expected to return metadata that inherits
+    from SegmentWithAttributes class and can return conditioning attributes.
+
+    This basically guarantees all datasets will be compatible with current
+    solver that contain conditioners requiring this.
+    """
+    audio_tokens: tp.Optional[torch.Tensor] = None  # populated when using cached batch for training a LM.
+
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        return ConditioningAttributes()
+
+
+class InfoAudioDataset(AudioDataset):
+    """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
+
+    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
+    """
+    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
+        super().__init__(clusterify_all_meta(meta), **kwargs)
+
+    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
+        if not self.return_info:
+            wav = super().__getitem__(index)
+            assert isinstance(wav, torch.Tensor)
+            return wav
+        wav, meta = super().__getitem__(index)
+        return wav, AudioInfo(**meta.to_dict())
+
+
+def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
+    """Preprocess a single keyword or possible a list of keywords."""
+    if isinstance(value, list):
+        return get_keyword_list(value)
+    else:
+        return get_keyword(value)
+
+
+def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess a single keyword."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    else:
+        return value.strip()
+
+
+def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess a single keyword."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    else:
+        return value.strip().lower()
+
+
+def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
+    """Preprocess a list of keywords."""
+    if isinstance(values, str):
+        values = [v.strip() for v in re.split(r'[,\s]', values)]
+    elif isinstance(values, float) and math.isnan(values):
+        values = []
+    if not isinstance(values, list):
+        logger.debug(f"Unexpected keyword list {values}")
+        values = [str(values)]
+
+    kws = [get_keyword(v) for v in values]
+    kw_list = [k for k in kws if k is not None]
+    if len(kw_list) == 0:
+        return None
+    else:
+        return kw_list
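
The keyword helpers above normalize free-form metadata before it reaches the conditioners. Their behaviour on a few inputs, worked out from the code:

    from audiocraft.data.info_audio_dataset import get_string, get_keyword, get_keyword_list

    get_string('  Daft Punk ')                  # -> 'Daft Punk' (stripped, case preserved)
    get_keyword('Techno ')                      # -> 'techno' (stripped and lowercased)
    get_keyword('None')                         # -> None (the literal string 'None' is rejected)
    get_keyword_list('rock, energetic guitar')  # -> ['rock', 'energetic', 'guitar']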
    	
        audiocraft/data/music_dataset.py
    ADDED
    
    | @@ -0,0 +1,270 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
            """Dataset of music tracks with rich metadata.
         | 
| 7 | 
            +
            """
         | 
| 8 | 
            +
            from dataclasses import dataclass, field, fields, replace
         | 
| 9 | 
            +
            import gzip
         | 
| 10 | 
            +
            import json
         | 
| 11 | 
            +
            import logging
         | 
| 12 | 
            +
            from pathlib import Path
         | 
| 13 | 
            +
            import random
         | 
| 14 | 
            +
            import typing as tp
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            import torch
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .info_audio_dataset import (
         | 
| 19 | 
            +
                InfoAudioDataset,
         | 
| 20 | 
            +
                AudioInfo,
         | 
| 21 | 
            +
                get_keyword_list,
         | 
| 22 | 
            +
                get_keyword,
         | 
| 23 | 
            +
                get_string
         | 
| 24 | 
            +
            )
         | 
| 25 | 
            +
            from ..modules.conditioners import (
         | 
| 26 | 
            +
                ConditioningAttributes,
         | 
| 27 | 
            +
                JointEmbedCondition,
         | 
| 28 | 
            +
                WavCondition,
         | 
| 29 | 
            +
            )
         | 
| 30 | 
            +
            from ..utils.utils import warn_once
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            @dataclass
         | 
| 37 | 
            +
            class MusicInfo(AudioInfo):
         | 
| 38 | 
            +
                """Segment info augmented with music metadata.
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
                # music-specific metadata
         | 
| 41 | 
            +
                title: tp.Optional[str] = None
         | 
| 42 | 
            +
                artist: tp.Optional[str] = None  # anonymized artist id, used to ensure no overlap between splits
         | 
| 43 | 
            +
                key: tp.Optional[str] = None
         | 
| 44 | 
            +
                bpm: tp.Optional[float] = None
         | 
| 45 | 
            +
                genre: tp.Optional[str] = None
         | 
| 46 | 
            +
                moods: tp.Optional[list] = None
         | 
| 47 | 
            +
                keywords: tp.Optional[list] = None
         | 
| 48 | 
            +
                description: tp.Optional[str] = None
         | 
| 49 | 
            +
                name: tp.Optional[str] = None
         | 
| 50 | 
            +
                instrument: tp.Optional[str] = None
         | 
| 51 | 
            +
                # original wav accompanying the metadata
         | 
| 52 | 
            +
                self_wav: tp.Optional[WavCondition] = None
         | 
| 53 | 
            +
                # dict mapping attributes names to tuple of wav, text and metadata
         | 
| 54 | 
            +
                joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                @property
         | 
| 57 | 
            +
                def has_music_meta(self) -> bool:
         | 
| 58 | 
            +
                    return self.name is not None
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def to_condition_attributes(self) -> ConditioningAttributes:
         | 
| 61 | 
            +
                    out = ConditioningAttributes()
         | 
| 62 | 
            +
                    for _field in fields(self):
         | 
| 63 | 
            +
                        key, value = _field.name, getattr(self, _field.name)
         | 
| 64 | 
            +
                        if key == 'self_wav':
         | 
| 65 | 
            +
                            out.wav[key] = value
         | 
| 66 | 
            +
                        elif key == 'joint_embed':
         | 
| 67 | 
            +
                            for embed_attribute, embed_cond in value.items():
         | 
| 68 | 
            +
                                out.joint_embed[embed_attribute] = embed_cond
         | 
| 69 | 
            +
                        else:
         | 
| 70 | 
            +
                            if isinstance(value, list):
         | 
| 71 | 
            +
                                value = ' '.join(value)
         | 
| 72 | 
            +
                            out.text[key] = value
         | 
| 73 | 
            +
                    return out
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                @staticmethod
         | 
| 76 | 
            +
                def attribute_getter(attribute):
         | 
| 77 | 
            +
                    if attribute == 'bpm':
         | 
| 78 | 
            +
                        preprocess_func = get_bpm
         | 
| 79 | 
            +
                    elif attribute == 'key':
         | 
| 80 | 
            +
                        preprocess_func = get_musical_key
         | 
| 81 | 
            +
                    elif attribute in ['moods', 'keywords']:
         | 
| 82 | 
            +
                        preprocess_func = get_keyword_list
         | 
| 83 | 
            +
                    elif attribute in ['genre', 'name', 'instrument']:
         | 
| 84 | 
            +
                        preprocess_func = get_keyword
         | 
| 85 | 
            +
                    elif attribute in ['title', 'artist', 'description']:
         | 
| 86 | 
            +
                        preprocess_func = get_string
         | 
| 87 | 
            +
                    else:
         | 
| 88 | 
            +
                        preprocess_func = None
         | 
| 89 | 
            +
                    return preprocess_func
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                @classmethod
         | 
| 92 | 
            +
                def from_dict(cls, dictionary: dict, fields_required: bool = False):
         | 
| 93 | 
            +
                    _dictionary: tp.Dict[str, tp.Any] = {}
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    # allow a subset of attributes to not be loaded from the dictionary
         | 
| 96 | 
            +
                    # these attributes may be populated later
         | 
| 97 | 
            +
                    post_init_attributes = ['self_wav', 'joint_embed']
         | 
| 98 | 
            +
                    optional_fields = ['keywords']
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    for _field in fields(cls):
         | 
| 101 | 
            +
                        if _field.name in post_init_attributes:
         | 
| 102 | 
            +
                            continue
         | 
| 103 | 
            +
                        elif _field.name not in dictionary:
         | 
| 104 | 
            +
                            if fields_required and _field.name not in optional_fields:
         | 
| 105 | 
            +
                                raise KeyError(f"Unexpected missing key: {_field.name}")
         | 
| 106 | 
            +
                        else:
         | 
| 107 | 
            +
                            preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
         | 
| 108 | 
            +
                            value = dictionary[_field.name]
         | 
| 109 | 
            +
                            if preprocess_func:
         | 
| 110 | 
            +
                                value = preprocess_func(value)
         | 
| 111 | 
            +
                            _dictionary[_field.name] = value
         | 
| 112 | 
            +
                    return cls(**_dictionary)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 

def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
                                   drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
    """Augment MusicInfo description with additional metadata fields and potential dropout.
    Additional textual attributes are added with probability merge_text_p, and
    the original textual description is dropped from the augmented description with probability drop_desc_p.

    Args:
        music_info (MusicInfo): The music metadata to augment.
        merge_text_p (float): Probability of merging additional metadata into the description.
            If the provided value is 0, no merging is performed.
        drop_desc_p (float): Probability of dropping the original description on text merge.
            If the provided value is 0, no dropout is performed.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
    Returns:
        MusicInfo: The MusicInfo with augmented textual description.
    """
    def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
        valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
        valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
        keep_field = random.uniform(0, 1) < drop_other_p
        return valid_field_name and valid_field_value and keep_field

    def process_value(v: tp.Any) -> str:
        if isinstance(v, (int, float, str)):
            return str(v)
        if isinstance(v, list):
            return ", ".join(v)
        else:
            raise ValueError(f"Unknown type for text value! ({type(v), v})")

    description = music_info.description

    metadata_text = ""
    if random.uniform(0, 1) < merge_text_p:
        meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
                      for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
        random.shuffle(meta_pairs)
        metadata_text = ". ".join(meta_pairs)
        description = description if not random.uniform(0, 1) < drop_desc_p else None
        logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")

    if description is None:
        description = metadata_text if len(metadata_text) > 1 else None
    else:
        description = ". ".join([description.rstrip('.'), metadata_text])
    description = description.strip() if description else None

    music_info = replace(music_info)
    music_info.description = description
    return music_info

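A short usage sketch, assuming a music_info whose metadata fields are already populated; the probabilities shown are illustrative:

    # With merge_text_p=0.9 metadata is merged into the description ~90% of the
    # time; on merge, the original text is dropped ~10% of the time, and each
    # candidate field ('key', 'bpm', 'genre', ...) is kept with probability 0.5.
    aug_info = augment_music_info_description(music_info, merge_text_p=0.9,
                                              drop_desc_p=0.1, drop_other_p=0.5)
    # e.g. "Upbeat disco track. bpm: 120.0. genre: disco"
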

class Paraphraser:
    def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
        self.paraphrase_p = paraphrase_p
        open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
        with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
            self.paraphrase_source = json.loads(f.read())
        logger.info(f"loaded paraphrasing source from: {paraphrase_source}")

    def sample_paraphrase(self, audio_path: str, description: str):
        if random.random() >= self.paraphrase_p:
            return description
        # JSON keys are plain strings, so look up the info path as a string
        info_path = str(Path(audio_path).with_suffix('.json'))
        if info_path not in self.paraphrase_source:
            warn_once(logger, f"{info_path} not in paraphrase source!")
            return description
        new_desc = random.choice(self.paraphrase_source[info_path])
        logger.debug(f"{description} -> {new_desc}")
        return new_desc

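The expected paraphrase source layout, sketched with hypothetical paths and descriptions:

    # paraphrases.json (or .json.gz): info path -> list of candidate paraphrases
    # {
    #   "/data/tracks/track_0001.json": ["An upbeat disco tune", "A funky dance track"],
    #   "/data/tracks/track_0002.json": ["A slow ambient pad"]
    # }
    paraphraser = Paraphraser('paraphrases.json.gz', paraphrase_p=0.3)
    new_desc = paraphraser.sample_paraphrase('/data/tracks/track_0001.wav', 'original description')
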

class MusicDataset(InfoAudioDataset):
    """Music dataset is an AudioDataset with music-related metadata.

    Args:
        info_fields_required (bool): Whether to enforce having required fields.
        merge_text_p (float): Probability of merging additional metadata into the description.
        drop_desc_p (float): Probability of dropping the original description on text merge.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
        joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
            paraphrases for the description. The json should be a dict whose keys are the
            original info paths (e.g. track_path.json) and whose values are lists of possible
            paraphrases.
        paraphrase_p (float): Probability of taking a paraphrase.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(self, *args, info_fields_required: bool = True,
                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
                 joint_embed_attributes: tp.List[str] = [],
                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
                 **kwargs):
        kwargs['return_info'] = True  # We require the info for each song of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.merge_text_p = merge_text_p
        self.drop_desc_p = drop_desc_p
        self.drop_other_p = drop_other_p
        self.joint_embed_attributes = joint_embed_attributes
        self.paraphraser = None
        if paraphrase_source is not None:
            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)

    def __getitem__(self, index):
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        music_info_path = Path(info.meta.path).with_suffix('.json')

        if Path(music_info_path).exists():
            with open(music_info_path, 'r') as json_file:
                music_data = json.load(json_file)
                music_data.update(info_data)
                music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
            if self.paraphraser is not None:
                music_info.description = self.paraphraser.sample_paraphrase(
                    music_info.meta.path, music_info.description)
            if self.merge_text_p:
                music_info = augment_music_info_description(
                    music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
        else:
            music_info = MusicInfo.from_dict(info_data, fields_required=False)

        music_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        for att in self.joint_embed_attributes:
            att_value = getattr(music_info, att)
            joint_embed_cond = JointEmbedCondition(
                wav[None], [att_value], torch.tensor([info.n_frames]),
                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
            music_info.joint_embed[att] = joint_embed_cond

        return wav, music_info

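A minimal construction sketch, assuming the from_meta entry point inherited from AudioDataset and a hypothetical manifest path:

    dataset = MusicDataset.from_meta('egs/music/data.jsonl.gz', segment_duration=30,
                                     sample_rate=32000, channels=1,
                                     merge_text_p=0.25, drop_desc_p=0.5, drop_other_p=0.5)
    wav, music_info = dataset[0]  # music_info.self_wav and music_info.joint_embed are populated above
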

def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess key keywords, discarding them if multiple keys are defined."""
    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
        return None
    elif ',' in value:
        # For now, we discard when multiple keys are defined, separated with commas
        return None
    else:
        return value.strip().lower()


def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
    """Preprocess to a float."""
    if value is None:
        return None
    try:
        return float(value)
    except ValueError:
        return None
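
A few concrete input/output pairs for these preprocessors:

    assert get_musical_key('C# minor ') == 'c# minor'
    assert get_musical_key('C minor, A major') is None  # multiple keys are discarded
    assert get_musical_key('None') is None
    assert get_bpm('120.5') == 120.5
    assert get_bpm('fast') is None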
    	
audiocraft/data/sound_dataset.py ADDED
@@ -0,0 +1,330 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of audio with a simple description.
"""

from dataclasses import dataclass, fields, replace
import json
from pathlib import Path
import random
import typing as tp

import numpy as np
import torch

from .info_audio_dataset import (
    InfoAudioDataset,
    get_keyword_or_keyword_list
)
from ..modules.conditioners import (
    ConditioningAttributes,
    SegmentWithAttributes,
    WavCondition,
)


EPS = torch.finfo(torch.float32).eps
TARGET_LEVEL_LOWER = -35
TARGET_LEVEL_UPPER = -15

@dataclass
class SoundInfo(SegmentWithAttributes):
    """Segment info augmented with Sound metadata.
    """
    description: tp.Optional[str] = None
    self_wav: tp.Optional[torch.Tensor] = None

    @property
    def has_sound_meta(self) -> bool:
        return self.description is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        out = ConditioningAttributes()

        for _field in fields(self):
            key, value = _field.name, getattr(self, _field.name)
            if key == 'self_wav':
                out.wav[key] = value
            else:
                out.text[key] = value
        return out

    @staticmethod
    def attribute_getter(attribute):
        if attribute == 'description':
            preprocess_func = get_keyword_or_keyword_list
        else:
            preprocess_func = None
        return preprocess_func

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        _dictionary: tp.Dict[str, tp.Any] = {}

        # allow a subset of attributes to not be loaded from the dictionary
        # these attributes may be populated later
        post_init_attributes = ['self_wav']

        for _field in fields(cls):
            if _field.name in post_init_attributes:
                continue
            elif _field.name not in dictionary:
                if fields_required:
                    raise KeyError(f"Unexpected missing key: {_field.name}")
            else:
                preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
                value = dictionary[_field.name]
                if preprocess_func:
                    value = preprocess_func(value)
                _dictionary[_field.name] = value
        return cls(**_dictionary)

class SoundDataset(InfoAudioDataset):
    """Sound audio dataset: Audio dataset with environmental sound-specific metadata.

    Args:
        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
            The metadata files contained in this folder are expected to match the stem of the audio file with
            a json extension.
        aug_p (float): Probability of performing audio mixing augmentation on the batch.
        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
        mix_snr_low (int): Lower bound for the SNR value sampled for mixing augmentation.
        mix_snr_high (int): Upper bound for the SNR value sampled for mixing augmentation.
        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
        kwargs: Additional arguments for AudioDataset.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(
        self,
        *args,
        info_fields_required: bool = True,
        external_metadata_source: tp.Optional[str] = None,
        aug_p: float = 0.,
        mix_p: float = 0.,
        mix_snr_low: int = -5,
        mix_snr_high: int = 5,
        mix_min_overlap: float = 0.5,
        **kwargs
    ):
        kwargs['return_info'] = True  # We require the info for each song of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.external_metadata_source = external_metadata_source
        self.aug_p = aug_p
        self.mix_p = mix_p
        if self.aug_p > 0:
            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
        self.mix_snr_low = mix_snr_low
        self.mix_snr_high = mix_snr_high
        self.mix_min_overlap = mix_min_overlap

    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
        """Get path of JSON with metadata (description, etc.).
        If there exists a JSON with the same name as 'path.name', then it will be used.
        Otherwise, the JSON is searched for in the external metadata source folder, if one is set.
        """
        info_path = Path(path).with_suffix('.json')
        if Path(info_path).exists():
            return info_path
        elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
            return Path(self.external_metadata_source) / info_path.name
        else:
            raise Exception(f"Unable to find a metadata JSON for path: {path}")

    def __getitem__(self, index):
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        info_path = self._get_info_path(info.meta.path)
        if Path(info_path).exists():
            with open(info_path, 'r') as json_file:
                sound_data = json.load(json_file)
                sound_data.update(info_data)
                sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
                # if there are multiple descriptions, sample one randomly
                if isinstance(sound_info.description, list):
                    sound_info.description = random.choice(sound_info.description)
        else:
            sound_info = SoundInfo.from_dict(info_data, fields_required=False)

        sound_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        return wav, sound_info

    def collater(self, samples):
        # when training, audio mixing is performed in the collate function
        wav, sound_info = super().collater(samples)  # SoundDataset always returns infos
        if self.aug_p > 0:
            wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
                                          snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
                                          min_overlap=self.mix_min_overlap)
        return wav, sound_info

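A sketch of how the mixing-aware collater is meant to be wired into a PyTorch DataLoader; the dataset construction is elided since it depends on the audio manifest:

    from torch.utils.data import DataLoader

    dataset = ...  # a SoundDataset built with aug_p=0.5, mix_p=0.5
    loader = DataLoader(dataset, batch_size=16, collate_fn=dataset.collater)
    for wav, sound_infos in loader:
        pass  # wav: [B, C, T]; sound_infos: SoundInfo items with possibly mixed descriptions
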
def rms_f(x: torch.Tensor) -> torch.Tensor:
    return (x ** 2).mean(1).pow(0.5)


def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
    """Normalize the signal to the target level."""
    rms = rms_f(audio)
    scalar = 10 ** (target_level / 20) / (rms + EPS)
    audio = audio * scalar.unsqueeze(1)
    return audio

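A quick numeric check of the dB math above: the target RMS for target_level dBFS is 10 ** (target_level / 20), so -25 dBFS corresponds to an RMS of about 0.0562:

    x = torch.randn(2, 16000)
    y = normalize(x, target_level=-25)
    print(rms_f(y))  # ~tensor([0.0562, 0.0562])
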
def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
    return (abs(audio) > clipping_threshold).any(1)


def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
    start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
    remainder = src.shape[1] - start
    if dst.shape[1] > remainder:
        src[:, start:] = src[:, start:] + dst[:, :remainder]
    else:
        src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
    return src

def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Function to mix clean speech and noise at various SNR levels.

    Args:
        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
        snr (int): SNR level when mixing.
        min_overlap (float): Minimum overlap between the two mixed sources.
        target_level (int): Gain level in dB.
        clipping_threshold (float): Threshold for clipping the audio.
    Returns:
        torch.Tensor: The mixed audio, of shape [B, T].
    """
    if clean.shape[1] > noise.shape[1]:
        noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
    else:
        noise = noise[:, :clean.shape[1]]

    # normalizing to -25 dB FS
    clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
    clean = normalize(clean, target_level)
    rmsclean = rms_f(clean)

    noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
    noise = normalize(noise, target_level)
    rmsnoise = rms_f(noise)

    # set the noise level for a given SNR
    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
    noisenewlevel = noise * noisescalar

    # mix noise and clean speech
    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)

    # randomly select an RMS value between -35 dBFS and -15 dBFS and normalize noisyspeech with that value.
    # There is a small chance of clipping, which is not a major issue.
    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
    rmsnoisy = rms_f(noisyspeech)
    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
    noisyspeech = noisyspeech * scalarnoisy
    clean = clean * scalarnoisy
    noisenewlevel = noisenewlevel * scalarnoisy

    # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
    clipped = is_clipped(noisyspeech)
    if clipped.any():
        noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel

    return noisyspeech

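The scaling math, worked out: after both sources are normalized to the same RMS, noisescalar reduces to 10 ** (-snr / 20), so the noise is attenuated for positive SNR and boosted for negative SNR:

    # snr = +5 dB -> noise gain ~= 0.562 (clean dominates)
    # snr = -5 dB -> noise gain ~= 1.778 (noise dominates)
    noise_gain = 10 ** (-5 / 20)  # == rmsclean / 10**(snr/20) / rmsnoise when the RMS values match
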
def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
    if snr_low == snr_high:
        snr = snr_low
    else:
        snr = np.random.randint(snr_low, snr_high)
    mix = snr_mixer(src, dst, snr, min_overlap)
    return mix


def mix_text(src_text: str, dst_text: str):
    """Mix text from different sources by concatenating them."""
    if src_text == dst_text:
        return src_text
    return src_text + " " + dst_text

def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
                snr_low: int, snr_high: int, min_overlap: float):
    """Mix samples within a batch, summing the waveforms and concatenating the text infos.

    Args:
        wavs (torch.Tensor): Audio tensors of shape [B, C, T].
        infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
        aug_p (float): Augmentation probability.
        mix_p (float): Proportion of items in the batch to mix (and merge) together.
        snr_low (int): Lower bound for sampling SNR.
        snr_high (int): Upper bound for sampling SNR.
        min_overlap (float): Minimum overlap between mixed samples.
    Returns:
        tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
            and mixed SoundInfo for the given batch.
    """
    # no mixing to perform within the batch
    if mix_p == 0:
        return wavs, infos

    if random.uniform(0, 1) < aug_p:
        # perform all augmentations on waveforms as [B, T]
        # randomly picking pairs of audio to mix
        assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
        wavs = wavs.mean(dim=1, keepdim=False)
        B, T = wavs.shape
        k = int(mix_p * B)
        mixed_sources_idx = torch.randperm(B)[:k]
        mixed_targets_idx = torch.randperm(B)[:k]
        aug_wavs = snr_mix(
            wavs[mixed_sources_idx],
            wavs[mixed_targets_idx],
            snr_low,
            snr_high,
            min_overlap,
        )
        # mixing textual descriptions in metadata
        descriptions = [info.description for info in infos]
        aug_infos = []
        for i, j in zip(mixed_sources_idx, mixed_targets_idx):
            text = mix_text(descriptions[i], descriptions[j])
            m = replace(infos[i])
            m.description = text
            aug_infos.append(m)

        # back to [B, C, T]
        aug_wavs = aug_wavs.unsqueeze(1)
        assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
        assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
        assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"

        return aug_wavs, aug_infos  # [B, C, T]
    else:
        # randomly pick samples in the batch to match
        # the batch size when performing audio mixing
        B, C, T = wavs.shape
        k = int(mix_p * B)
        wav_idx = torch.randperm(B)[:k]
        wavs = wavs[wav_idx]
        infos = [infos[i] for i in wav_idx]
        assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"

        return wavs, infos  # [B, C, T]
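
A toy exercise of the mixing helpers on random mono waveforms (shapes illustrative):

    src = torch.randn(4, 16000)  # [B, T]
    dst = torch.randn(4, 16000)
    mixed = snr_mix(src, dst, snr_low=-5, snr_high=5, min_overlap=0.5)
    assert mixed.shape == (4, 16000)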

audiocraft/data/zip.py CHANGED
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+"""Utility for reading some info from inside a zip file.
+"""
 
 import typing
 import zipfile
@@ -18,13 +20,13 @@ MODE = Literal['r', 'w', 'x', 'a']
 
 @dataclass(order=True)
 class PathInZip:
-    """
+    """Hold a path of file within a zip file.
 
     Args:
-        path: The convention is <path_to_zip>:<relative_path_inside_zip>
+        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
             Let's assume there is a zip file /some/location/foo.zip
             and inside of it is a json file located at /data/file1.json,
-            Then we expect path = "/some/location/foo.zip:/data/file1.json"
+            Then we expect path = "/some/location/foo.zip:/data/file1.json".
     """
 
     INFO_PATH_SEP = ':'
@@ -55,7 +57,7 @@ def set_zip_cache_size(max_size: int):
     """Sets the maximal LRU caching for zip file opening.
 
     Args:
-        max_size: the maximal LRU cache.
+        max_size (int): the maximal LRU cache.
     """
     global _cached_open_zip
     _cached_open_zip = lru_cache(max_size)(_open_zip)
@@ -65,8 +67,8 @@ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
     """Opens a file stored inside a zip and returns a file-like object.
 
     Args:
-        path_in_zip: A PathInZip object representing the file to return a file-like object of.
-        mode: The mode in which to open the file with.
+        path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
+        mode (str): The mode in which to open the file with.
     Returns:
         A file-like object for PathInZip.
     """
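
A small usage sketch of the convention documented in this diff (the zip path is hypothetical):

    from audiocraft.data.zip import PathInZip, open_file_in_zip

    piz = PathInZip('/some/location/foo.zip:/data/file1.json')
    with open_file_in_zip(piz, 'r') as f:
        payload = f.read()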
    	
audiocraft/environment.py ADDED
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
"""

import logging
import os
from pathlib import Path
import re
import typing as tp

import omegaconf

from .utils.cluster import _guess_cluster_type


logger = logging.getLogger(__name__)


class AudioCraftEnvironment:
    """Environment configuration for teams and clusters.

    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across clusters.

    The cluster type is identified automatically and the base configuration file is read from config/teams.yaml.
    Use the following environment variables to specify the cluster, team or configuration:

        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
            cannot be inferred automatically.
        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
            If not set, configuration is read from config/teams.yaml.
        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
            Cluster configurations are shared across teams to match compute allocation,
            specify your cluster configuration in the configuration file under a key mapping
            your team name.
    """
    _instance = None
    DEFAULT_TEAM = "default"

    def __init__(self) -> None:
        """Loads configuration."""
        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
        cluster_type = _guess_cluster_type()
        cluster = os.getenv(
            "AUDIOCRAFT_CLUSTER", cluster_type.value
        )
        logger.info("Detecting cluster type %s", cluster_type)

        self.cluster: str = cluster

        config_path = os.getenv(
            "AUDIOCRAFT_CONFIG",
            Path(__file__)
            .parent.parent.joinpath("config/teams", self.team)
            .with_suffix(".yaml"),
        )
        self.config = omegaconf.OmegaConf.load(config_path)
        self._dataset_mappers = []
        cluster_config = self._get_cluster_config()
        if "dataset_mappers" in cluster_config:
            for pattern, repl in cluster_config["dataset_mappers"].items():
                regex = re.compile(pattern)
                self._dataset_mappers.append((regex, repl))

    def _get_cluster_config(self) -> omegaconf.DictConfig:
        assert isinstance(self.config, omegaconf.DictConfig)
        return self.config[self.cluster]

    @classmethod
    def instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """Clears the environment and forces a reload on next invocation."""
        cls._instance = None

    @classmethod
    def get_team(cls) -> str:
        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
        If not defined, defaults to "default".
        """
        return cls.instance().team

    @classmethod
    def get_cluster(cls) -> str:
        """Gets the detected cluster.
        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
        """
        return cls.instance().cluster

    @classmethod
    def get_dora_dir(cls) -> Path:
        """Gets the path to the dora directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
        logger.warning(f"Dora directory: {dora_dir}")
        return Path(dora_dir)

    @classmethod
    def get_reference_dir(cls) -> Path:
        """Gets the path to the reference directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))

    @classmethod
    def get_slurm_exclude(cls) -> tp.Optional[str]:
        """Get the list of nodes to exclude for that cluster."""
        cluster_config = cls.instance()._get_cluster_config()
         | 
| 125 | 
            +
                    return cluster_config.get("slurm_exclude")
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                @classmethod
         | 
| 128 | 
            +
                def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
         | 
| 129 | 
            +
                    """Gets the requested partitions for the current team and cluster as a comma-separated string.
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    Args:
         | 
| 132 | 
            +
                        partition_types (list[str], optional): partition types to retrieve. Values must be
         | 
| 133 | 
            +
                            from ['global', 'team']. If not provided, the global partition is returned.
         | 
| 134 | 
            +
                    """
         | 
| 135 | 
            +
                    if not partition_types:
         | 
| 136 | 
            +
                        partition_types = ["global"]
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    cluster_config = cls.instance()._get_cluster_config()
         | 
| 139 | 
            +
                    partitions = [
         | 
| 140 | 
            +
                        cluster_config["partitions"][partition_type]
         | 
| 141 | 
            +
                        for partition_type in partition_types
         | 
| 142 | 
            +
                    ]
         | 
| 143 | 
            +
                    return ",".join(partitions)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                @classmethod
         | 
| 146 | 
            +
                def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
         | 
| 147 | 
            +
                    """Converts reference placeholder in path with configured reference dir to resolve paths.
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    Args:
         | 
| 150 | 
            +
                        path (str or Path): Path to resolve.
         | 
| 151 | 
            +
                    Returns:
         | 
| 152 | 
            +
                        Path: Resolved path.
         | 
| 153 | 
            +
                    """
         | 
| 154 | 
            +
                    path = str(path)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    if path.startswith("//reference"):
         | 
| 157 | 
            +
                        reference_dir = cls.get_reference_dir()
         | 
| 158 | 
            +
                        logger.warn(f"Reference directory: {reference_dir}")
         | 
| 159 | 
            +
                        assert (
         | 
| 160 | 
            +
                            reference_dir.exists() and reference_dir.is_dir()
         | 
| 161 | 
            +
                        ), f"Reference directory does not exist: {reference_dir}."
         | 
| 162 | 
            +
                        path = re.sub("^//reference", str(reference_dir), path)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    return Path(path)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                @classmethod
         | 
| 167 | 
            +
                def apply_dataset_mappers(cls, path: str) -> str:
         | 
| 168 | 
            +
                    """Applies dataset mapping regex rules as defined in the configuration.
         | 
| 169 | 
            +
                    If no rules are defined, the path is returned as-is.
         | 
| 170 | 
            +
                    """
         | 
| 171 | 
            +
                    instance = cls.instance()
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    for pattern, repl in instance._dataset_mappers:
         | 
| 174 | 
            +
                        path = pattern.sub(repl, path)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    return path
         | 
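A minimal usage sketch for the class above (the checkpoint and manifest paths are hypothetical, for illustration only; the `//reference` prefix and the `dataset_mappers` rules come from the code itself):

    from audiocraft.environment import AudioCraftEnvironment

    # Team and cluster are resolved once and cached on the singleton.
    team = AudioCraftEnvironment.get_team()
    cluster = AudioCraftEnvironment.get_cluster()

    # '//reference' is replaced with the cluster's configured reference_dir.
    # (hypothetical path)
    ckpt = AudioCraftEnvironment.resolve_reference_path('//reference/checkpoints/model.th')

    # Paths are rewritten by the regex rules under the cluster's 'dataset_mappers'.
    # (hypothetical path)
    manifest = AudioCraftEnvironment.apply_dataset_mappers('/datasets/music/train.jsonl')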
    	
audiocraft/grids/__init__.py
ADDED
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Dora Grids."""
    	
audiocraft/grids/_base_explorers.py
ADDED
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+import time
+import typing as tp
+from dora import Explorer
+import treetable as tt
+
+
+def get_sheep_ping(sheep) -> tp.Optional[str]:
+    """Return the amount of time since the Sheep made some update
+    to its log. Returns a str using the relevant time unit."""
+    ping = None
+    if sheep.log is not None and sheep.log.exists():
+        delta = time.time() - sheep.log.stat().st_mtime
+        if delta > 3600 * 24:
+            ping = f'{delta / (3600 * 24):.1f}d'
+        elif delta > 3600:
+            ping = f'{delta / (3600):.1f}h'
+        elif delta > 60:
+            ping = f'{delta / 60:.1f}m'
+        else:
+            ping = f'{delta:.1f}s'
+    return ping
+
+
+class BaseExplorer(ABC, Explorer):
+    """Base explorer for AudioCraft grids.
+
+    All task-specific solvers are expected to implement the `get_grid_metrics`
+    method to specify logic about metrics to display for a given task.
+
+    If additional stages are used, the child explorer must define how to handle
+    these new stages in the `process_history` and `process_sheep` methods.
+    """
+    def stages(self):
+        return ["train", "valid", "evaluate"]
+
+    def get_grid_meta(self):
+        """Returns the list of Meta information to display for each XP/job.
+        """
+        return [
+            tt.leaf("index", align=">"),
+            tt.leaf("name", wrap=140),
+            tt.leaf("state"),
+            tt.leaf("sig", align=">"),
+            tt.leaf("sid", align="<"),
+        ]
+
+    @abstractmethod
+    def get_grid_metrics(self):
+        """Return the metrics that should be displayed in the tracking table.
+        """
+        ...
+
+    def process_sheep(self, sheep, history):
+        train = {
+            "epoch": len(history),
+        }
+        parts = {"train": train}
+        for metrics in history:
+            for key, sub in metrics.items():
+                part = parts.get(key, {})
+                if 'duration' in sub:
+                    # Convert to minutes for readability.
+                    sub['duration'] = sub['duration'] / 60.
+                part.update(sub)
+                parts[key] = part
+        ping = get_sheep_ping(sheep)
+        if ping is not None:
+            for name in self.stages():
+                if name not in parts:
+                    parts[name] = {}
+                # Add the ping to each part for convenience.
+                parts[name]['ping'] = ping
+        return parts
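For context, a child explorer only has to provide `get_grid_metrics` (stages, meta and sheep processing come from the base class); a minimal sketch with illustrative metric names:

    import treetable as tt

    from audiocraft.grids._base_explorers import BaseExplorer


    class MyExplorer(BaseExplorer):
        def get_grid_metrics(self):
            # One treetable group per stage returned by stages().
            return [
                tt.group('train', [tt.leaf('epoch'), tt.leaf('loss', '.4f')], align='>'),
                tt.group('valid', [tt.leaf('loss', '.4f')], align='>'),
                tt.group('evaluate', [tt.leaf('loss', '.4f')], align='>'),
            ]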
    	
audiocraft/grids/audiogen/__init__.py
ADDED
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""AudioGen grids."""
    	
audiocraft/grids/audiogen/audiogen_base_16khz.py
ADDED
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ..musicgen._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=64, partition=partitions)
+    launcher.bind_(solver='audiogen/audiogen_base_16khz')
+    # replace this with the desired environmental sound dataset
+    launcher.bind_(dset='internal/sounds_16khz')
+
+    fsdp = {'autocast': False, 'fsdp.use': True}
+    medium = {'model/lm/model_scale': 'medium'}
+
+    launcher.bind_(fsdp)
+    launcher(medium)
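Going by the `dora grid` invocation quoted in the eval grid below, this training grid would presumably be scheduled with a command of the form:

    dora grid audiogen.audiogen_base_16khz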
    	
audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Evaluation with objective metrics for the pretrained AudioGen models.
+This grid takes signatures from the training grid and runs an evaluation-only stage.
+
+When running the grid for the first time, please use:
+REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
+and re-use the REGEN=1 option whenever the grid is changed, to force regenerating it.
+
+Note that you need the proper external metrics libraries set up to use all
+the objective metrics activated in this grid. Refer to the README for more information.
+"""
+
+import os
+
+from ..musicgen._explorers import GenerationEvalExplorer
+from ...environment import AudioCraftEnvironment
+from ... import train
+
+
+def eval(launcher, batch_size: int = 32):
+    opts = {
+        'dset': 'audio/audiocaps_16khz',
+        'solver/audiogen/evaluation': 'objective_eval',
+        'execute_only': 'evaluate',
+        '+dataset.evaluate.batch_size': batch_size,
+        '+metrics.fad.tf.batch_size': 32,
+    }
+    # binary for FAD computation: replace this path with your own path
+    metrics_opts = {
+        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
+    }
+    opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
+    opt2 = {'transformer_lm.two_step_cfg': True}
+
+    sub = launcher.bind(opts)
+    sub.bind_(metrics_opts)
+
+    # base objective metrics
+    sub(opt1, opt2)
+
+
+@GenerationEvalExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=4, partition=partitions)
+
+    if 'REGEN' not in os.environ:
+        folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
+        with launcher.job_array():
+            for sig in folder.iterdir():
+                if not sig.is_symlink():
+                    continue
+                xp = train.main.get_xp_from_sig(sig.name)
+                launcher(xp.argv)
+        return
+
+    audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
+    audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
+
+    audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
+    audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
+    eval(audiogen_base_medium, batch_size=128)
    	
audiocraft/grids/compression/__init__.py
ADDED
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""EnCodec grids."""
    	
audiocraft/grids/compression/_explorers.py
ADDED
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import treetable as tt
+
+from .._base_explorers import BaseExplorer
+
+
+class CompressionExplorer(BaseExplorer):
+    eval_metrics = ["sisnr", "visqol"]
+
+    def stages(self):
+        return ["train", "valid", "evaluate"]
+
+    def get_grid_meta(self):
+        """Returns the list of Meta information to display for each XP/job.
+        """
+        return [
+            tt.leaf("index", align=">"),
+            tt.leaf("name", wrap=140),
+            tt.leaf("state"),
+            tt.leaf("sig", align=">"),
+        ]
+
+    def get_grid_metrics(self):
+        """Return the metrics that should be displayed in the tracking table.
+        """
+        return [
+            tt.group(
+                "train",
+                [
+                    tt.leaf("epoch"),
+                    tt.leaf("bandwidth", ".2f"),
+                    tt.leaf("adv", ".4f"),
+                    tt.leaf("d_loss", ".4f"),
+                ],
+                align=">",
+            ),
+            tt.group(
+                "valid",
+                [
+                    tt.leaf("bandwidth", ".2f"),
+                    tt.leaf("adv", ".4f"),
+                    tt.leaf("msspec", ".4f"),
+                    tt.leaf("sisnr", ".2f"),
+                ],
+                align=">",
+            ),
+            tt.group(
+                "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
+            ),
+        ]
    	
audiocraft/grids/compression/debug.py
ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel an experiment by commenting out its line.
+
+This grid is a minimal example for debugging the compression task
+and for overriding parameters directly in a grid.
+Learn more about dora grids: https://github.com/facebookresearch/dora
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=2, partition=partitions)
+    launcher.bind_(solver='compression/debug')
+
+    with launcher.job_array():
+        # base debug task using config from solver=compression/debug
+        launcher()
+        # we can override parameters in the grid to launch additional xps
+        launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
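Following the same `dora grid <package>.<module>` convention used above, this debug grid would presumably be launched with:

    dora grid compression.debug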
    	
audiocraft/grids/compression/encodec_audiogen_16khz.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel an experiment by commenting out its line.
+
+This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=8, partition=partitions)
+    # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
+    # AudioGen's EnCodec is trained with a total stride of 320, leading to a frame rate of 50 Hz
+    launcher.bind_(solver='compression/encodec_audiogen_16khz')
+    # replace this with the desired sound dataset
+    launcher.bind_(dset='internal/sounds_16khz')
+    # launch xp
+    launcher()
    	
audiocraft/grids/compression/encodec_base_24khz.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel an experiment by commenting out its line.
+
+This grid shows how to train a base causal EnCodec model at 24 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=8, partition=partitions)
+    # base causal EnCodec trained on monophonic audio sampled at 24 kHz
+    launcher.bind_(solver='compression/encodec_base_24khz')
+    # replace this with the desired dataset
+    launcher.bind_(dset='audio/example')
+    # launch xp
+    launcher()
    	
audiocraft/grids/compression/encodec_musicgen_32khz.py
ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel an experiment by commenting out its line.
+
+This grid shows how to train a MusicGen EnCodec model at 32 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=8, partition=partitions)
+    # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
+    # MusicGen's EnCodec is trained with a total stride of 640, leading to a frame rate of 50 Hz
+    launcher.bind_(solver='compression/encodec_musicgen_32khz')
+    # replace this with the desired music dataset
+    launcher.bind_(dset='internal/music_400k_32khz')
+    # launch xp
+    launcher()
+    launcher({
+        'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
+        'label': 'visqol',
+        'evaluate.metrics.visqol': True
+    })
    	
audiocraft/grids/diffusion/4_bands_base_32khz.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Training of the 4 diffusion models described in
+"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
+(paper link).
+"""
+
+from ._explorers import DiffusionExplorer
+
+
+@DiffusionExplorer
+def explorer(launcher):
+    launcher.slurm_(gpus=4, partition='learnfair')
+
+    launcher.bind_({'solver': 'diffusion/default',
+                    'dset': 'internal/music_10k_32khz'})
+
+    with launcher.job_array():
+        launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
+        launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
+        launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
+        launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
    	
audiocraft/grids/diffusion/__init__.py
ADDED
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Diffusion grids."""
    	
audiocraft/grids/diffusion/_explorers.py
ADDED
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import treetable as tt
+
+from .._base_explorers import BaseExplorer
+
+
+class DiffusionExplorer(BaseExplorer):
+    eval_metrics = ["sisnr", "visqol"]
+
+    def stages(self):
+        return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
+
+    def get_grid_meta(self):
+        """Returns the list of Meta information to display for each XP/job.
+        """
+        return [
+            tt.leaf("index", align=">"),
+            tt.leaf("name", wrap=140),
+            tt.leaf("state"),
+            tt.leaf("sig", align=">"),
+        ]
+
+    def get_grid_metrics(self):
+        """Return the metrics that should be displayed in the tracking table.
+        """
+        return [
+            tt.group(
+                "train",
+                [
+                    tt.leaf("epoch"),
+                    tt.leaf("loss", ".3%"),
+                ],
+                align=">",
+            ),
+            tt.group(
+                "valid",
+                [
+                    tt.leaf("loss", ".3%"),
+                    # tt.leaf("loss_0", ".3%"),
+                ],
+                align=">",
+            ),
+            tt.group(
+                "valid_ema",
+                [
+                    tt.leaf("loss", ".3%"),
+                    # tt.leaf("loss_0", ".3%"),
+                ],
+                align=">",
+            ),
+            tt.group(
+                "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
+                             tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
+                             tt.leaf("rvm_3", ".4f")], align=">"
+            ),
+            tt.group(
+                "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
+                                 tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
+                                 tt.leaf("rvm_3", ".4f")], align=">"
+            ),
+        ]
    	
audiocraft/grids/musicgen/__init__.py
ADDED
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""MusicGen grids."""
    	
        audiocraft/grids/musicgen/_explorers.py
    ADDED
    
    | @@ -0,0 +1,93 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import typing as tp
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import treetable as tt
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from .._base_explorers import BaseExplorer
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class LMExplorer(BaseExplorer):
         | 
| 15 | 
            +
                eval_metrics: tp.List[str] = []
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def stages(self) -> tp.List[str]:
         | 
| 18 | 
            +
                    return ['train', 'valid']
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def get_grid_metrics(self):
         | 
| 21 | 
            +
                    """Return the metrics that should be displayed in the tracking table."""
         | 
| 22 | 
            +
                    return [
         | 
| 23 | 
            +
                        tt.group(
         | 
| 24 | 
            +
                            'train',
         | 
| 25 | 
            +
                            [
         | 
| 26 | 
            +
                                tt.leaf('epoch'),
         | 
| 27 | 
            +
                                tt.leaf('duration', '.1f'),  # duration in minutes
         | 
| 28 | 
            +
                                tt.leaf('ping'),
         | 
| 29 | 
            +
                                tt.leaf('ce', '.4f'),  # cross entropy
         | 
| 30 | 
            +
                                tt.leaf("ppl", '.3f'),  # perplexity
         | 
| 31 | 
            +
                            ],
         | 
| 32 | 
            +
                            align='>',
         | 
| 33 | 
            +
                        ),
         | 
| 34 | 
            +
                        tt.group(
         | 
| 35 | 
            +
                            'valid',
         | 
| 36 | 
            +
                            [
         | 
| 37 | 
            +
                                tt.leaf('ce', '.4f'),
         | 
| 38 | 
            +
                                tt.leaf('ppl', '.3f'),
                tt.leaf('best_ppl', '.3f'),
            ],
            align='>',
        ),
    ]

    def process_sheep(self, sheep, history):
        parts = super().process_sheep(sheep, history)

        track_by = {'ppl': 'lower'}  # values should be in ['lower', 'higher']
        best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}

        def comparator(mode, a, b):
            return a < b if mode == 'lower' else a > b

        for metrics in history:
            for key, sub in metrics.items():
                for metric in track_by:
                    # for the validation set, keep track of best metrics (ppl in this example)
                    # this is so we can conveniently compare metrics between runs in the grid
                    if key == 'valid' and metric in sub and comparator(
                        track_by[metric], sub[metric], best_metrics[metric]
                    ):
                        best_metrics[metric] = sub[metric]

        if 'valid' in parts:
            parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
        return parts


class GenerationEvalExplorer(BaseExplorer):
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        return ['evaluate']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        return [
            tt.group(
                'evaluate',
                [
                    tt.leaf('epoch', '.3f'),
                    tt.leaf('duration', '.1f'),
                    tt.leaf('ping'),
                    tt.leaf('ce', '.4f'),
                    tt.leaf('ppl', '.3f'),
                    tt.leaf('fad', '.3f'),
                    tt.leaf('kld', '.3f'),
                    tt.leaf('text_consistency', '.3f'),
                    tt.leaf('chroma_cosine', '.3f'),
                ],
                align='>',
            ),
        ]
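The best-metric bookkeeping in `process_sheep` boils down to a fold over the per-epoch history. A minimal standalone sketch with hypothetical history values (plain Python, not dora's actual data structures):

history = [
    {'train': {'ppl': 12.3}, 'valid': {'ppl': 10.1}},
    {'train': {'ppl': 9.8}, 'valid': {'ppl': 9.4}},
    {'train': {'ppl': 9.1}, 'valid': {'ppl': 9.7}},  # valid ppl regressed this epoch
]

track_by = {'ppl': 'lower'}
best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}

def comparator(mode, a, b):
    return a < b if mode == 'lower' else a > b

for metrics in history:
    for key, sub in metrics.items():
        for metric in track_by:
            if key == 'valid' and metric in sub and comparator(
                track_by[metric], sub[metric], best_metrics[metric]
            ):
                best_metrics[metric] = sub[metric]

print(best_metrics)  # {'ppl': 9.4} -> surfaces as the best_ppl column above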
    	
audiocraft/grids/musicgen/musicgen_base_32khz.py ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp = {'autocast': False, 'fsdp.use': True}
    medium = {'model/lm/model_scale': 'medium'}
    large = {'model/lm/model_scale': 'large'}

    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}

    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    launcher.bind_(fsdp)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub()

    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(medium, adam)

    launcher.slurm_(gpus=96).bind_(label='96gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
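This grid and the ones below lean on the launcher's composition semantics: `bind_` mutates overrides in place, `bind` forks a child launcher, and calling the launcher schedules one job. A toy mock of that behavior, assuming nothing about dora's real API beyond the usage shown above (this is not dora's implementation):

class MockLauncher:
    """Toy stand-in for the grid launcher (illustrative only)."""

    def __init__(self, overrides=None):
        self.overrides = dict(overrides or {})

    def bind_(self, *dicts, **kwargs):
        # mutate this launcher's overrides in place
        for d in dicts:
            self.overrides.update(d)
        self.overrides.update(kwargs)

    def bind(self, *dicts, **kwargs):
        # fork a child launcher with extra overrides
        child = MockLauncher(self.overrides)
        child.bind_(*dicts, **kwargs)
        return child

    def __call__(self, *dicts):
        # schedule one job with the merged overrides
        job = dict(self.overrides)
        for d in dicts:
            job.update(d)
        print('would schedule:', job)


launcher = MockLauncher()
launcher.bind_(solver='musicgen/musicgen_base_32khz')
sub = launcher.bind({'model/lm/model_scale': 'large'})
sub({'optim.lr': 1e-4})  # one job combining all three layers of overrides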
    	
audiocraft/grids/musicgen/musicgen_base_cached_32khz.py ADDED
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp = {'autocast': False, 'fsdp.use': True}
    medium = {'model/lm/model_scale': 'medium'}
    large = {'model/lm/model_scale': 'large'}

    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}

    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    # BEGINNING OF CACHE WRITING JOBS.
    cache_write = {
        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
        'cache.write': True,
        'generate.every': 500,
        'evaluate.every': 500,
        'logging.log_updates': 50,
    }

    cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
    cache_sub.bind_({'deadlock.use': True})
    cache_sub.slurm_(gpus=8)
    with launcher.job_array():
        num_shards = 10  # total number of jobs running in parallel.
        for shard in range(0, num_shards):
            launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})

    # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
    # OR SUFFICIENTLY AHEAD.
    return

    cache = {
        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
    }
    launcher.bind_(fsdp, cache)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub()

    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(medium, adam)

    launcher.slurm_(gpus=96).bind_(label='96gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
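The cache-writing phase fans the dataset out over 10 jobs via `cache.write_num_shards` and `cache.write_shard`. A hedged sketch of the partitioning idea (the actual slicing lives in audiocraft's cache code and may differ); a modulo split covers every index exactly once:

def indices_for_shard(num_items: int, num_shards: int, shard: int):
    # each job keeps the indices congruent to its shard id
    return [i for i in range(num_items) if i % num_shards == shard]

# every index is written by exactly one of the 10 jobs
covered = sorted(sum((indices_for_shard(25, 10, s) for s in range(10)), []))
assert covered == list(range(25))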
    	
audiocraft/grids/musicgen/musicgen_clapemb_32khz.py ADDED
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')
    launcher.bind_(conditioner='clapemb2music')

    fsdp = {'autocast': False, 'fsdp.use': True}
    cache_path = {'conditioners.description.clap.cache_path':
                  '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'}
    text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5}

    launcher.bind_(fsdp)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        launcher()
        launcher(text_wav_training_opt)
        launcher(cache_path)
        launcher(cache_path, text_wav_training_opt)
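`conditioners.description.clap.text_p` sets the per-sample probability that the CLAP conditioner is trained on the text embedding rather than the waveform embedding. A hypothetical illustration of that coin flip (not the conditioner's actual code):

import random

def pick_embedding(text_emb, wav_emb, text_p=0.5, rng=random):
    # with probability text_p use the text embedding, else the waveform one
    return text_emb if rng.random() < text_p else wav_emb

picks = [pick_embedding('text', 'wav') for _ in range(10_000)]
print(picks.count('text') / len(picks))  # ~0.5 for text_p=0.5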
    	
audiocraft/grids/musicgen/musicgen_melody_32khz.py ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_melody_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp = {'autocast': False, 'fsdp.use': True}
    medium = {'model/lm/model_scale': 'medium'}
    large = {'model/lm/model_scale': 'large'}

    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}

    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    cache_path = {'conditioners.self_wav.chroma_stem.cache_path':
                  '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'}

    # CACHE GENERATION JOBS
    n_cache_gen_jobs = 4
    gen_sub = launcher.slurm(gpus=1)
    gen_sub.bind_(
        cache_path, {
            # the cache is always computed over the whole file, so duration doesn't matter here.
            'dataset.segment_duration': 2.,
            'dataset.batch_size': 8,
            'dataset.train.permutation_on_files': True,  # try to not repeat files.
            'optim.epochs': 10,
            'model/lm/model_scale': 'xsmall',
        })
    with gen_sub.job_array():
        for gen_job in range(n_cache_gen_jobs):
            gen_sub({'dataset.train.shuffle_seed': gen_job})

    # ACTUAL TRAINING JOBS.
    launcher.bind_(fsdp)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub()
        sub(cache_path)

    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(medium, adam)

    launcher.slurm_(gpus=96).bind_(label='96gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
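Each of the four 1-GPU cache-generation jobs above gets its own `dataset.train.shuffle_seed`, so the jobs traverse the files in different orders and together cover more of the dataset while writing the chroma cache. A small illustration with hypothetical file names:

import random

files = [f'track_{i}.wav' for i in range(6)]
for seed in range(4):  # mirrors n_cache_gen_jobs = 4
    order = files[:]
    random.Random(seed).shuffle(order)
    print(seed, order)  # each job walks the files in a different order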
    	
audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py ADDED
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Evaluation with objective metrics for the pretrained MusicGen models.
This grid takes signatures from the training grid and runs an evaluation-only stage.

When running the grid for the first time, please use:
REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
and re-use the REGEN=1 option when the grid is changed to force regenerating it.

Note that you need the proper external metrics libraries set up to use all
the objective metrics activated in this grid. Refer to the README for more information.
"""

import os

from ._explorers import GenerationEvalExplorer
from ...environment import AudioCraftEnvironment
from ... import train


def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
    opts = {
        'dset': 'audio/musiccaps_32khz',
        'solver/musicgen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 16,
    }
    # chroma-specific evaluation
    chroma_opts = {
        'dset': 'internal/music_400k_32khz',
        'dataset.evaluate.segment_duration': 30,
        'dataset.evaluate.num_samples': 1000,
        'evaluate.metrics.chroma_cosine': True,
        'evaluate.metrics.fad': False,
        'evaluate.metrics.kld': False,
        'evaluate.metrics.text_consistency': False,
    }
    # binary for FAD computation: replace this path with your own path
    metrics_opts = {
        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
    }
    opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
    opt2 = {'transformer_lm.two_step_cfg': True}

    sub = launcher.bind(opts)
    sub.bind_(metrics_opts)

    # base objective metrics
    sub(opt1, opt2)

    if eval_melody:
        # chroma-specific metrics
        sub(opt1, opt2, chroma_opts)


@GenerationEvalExplorer
def explorer(launcher):
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=4, partition=partitions)

    if 'REGEN' not in os.environ:
        folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
        with launcher.job_array():
            for sig in folder.iterdir():
                if not sig.is_symlink():
                    continue
                xp = train.main.get_xp_from_sig(sig.name)
                launcher(xp.argv)
        return

    with launcher.job_array():
        musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz")
        musicgen_base.bind_({'autocast': False, 'fsdp.use': True})

        # base musicgen models
        musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'})
        eval(musicgen_base_small, batch_size=128)

        musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'})
        musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'})
        eval(musicgen_base_medium, batch_size=128)

        musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'})
        musicgen_base_large.bind_({'model/lm/model_scale': 'large'})
        eval(musicgen_base_large, batch_size=128)

        # melody musicgen model
        musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz")
        musicgen_melody.bind_({'autocast': False, 'fsdp.use': True})

        musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
        musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'})
        eval(musicgen_melody_medium, batch_size=128, eval_melody=True)
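Without `REGEN` set, the explorer above replays previously generated experiment signatures rather than rebuilding the grid: only symlinked entries in the grid folder are treated as signatures. A hedged sketch of that filtering (hypothetical folder contents; the real lookup goes through dora's experiment directory):

from pathlib import Path

def replayable_sigs(folder: Path):
    # keep only symlinks, mirroring the `if not sig.is_symlink(): continue` check
    return [p.name for p in folder.iterdir() if p.is_symlink()]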

 
		