liuhuadai committed
Commit 052cf68 · Parent(s): 70bc476

support cot

Browse files: This view is limited to 50 files because it contains too many changes. See raw diff.
- {think_sound → ThinkSound}/__init__.py +0 -0
 - {think_sound/configs/model_configs/autoencoders → ThinkSound/configs/model_configs}/stable_audio_2_0_vae.json +0 -0
 - think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json → ThinkSound/configs/model_configs/thinksound.json +1 -1
 - ThinkSound/configs/multimodal_dataset_demo.json +53 -0
 - {data_utils → ThinkSound/data}/__init__.py +0 -0
 - {think_sound → ThinkSound}/data/datamodule.py +4 -2
 - {think_sound → ThinkSound}/data/dataset.py +6 -8
 - {think_sound → ThinkSound}/data/utils.py +0 -0
 - {think_sound/data → ThinkSound/inference}/__init__.py +0 -0
 - {think_sound → ThinkSound}/inference/generation.py +0 -0
 - {think_sound → ThinkSound}/inference/sampling.py +0 -0
 - {think_sound → ThinkSound}/inference/utils.py +0 -0
 - {think_sound → ThinkSound}/models/__init__.py +0 -0
 - {think_sound → ThinkSound}/models/autoencoders.py +0 -0
 - {think_sound → ThinkSound}/models/blocks.py +92 -1
 - {think_sound → ThinkSound}/models/bottleneck.py +0 -0
 - {think_sound → ThinkSound}/models/codebook_patterns.py +0 -0
 - {think_sound → ThinkSound}/models/conditioners.py +0 -1
 - {think_sound → ThinkSound}/models/diffusion.py +1 -3
 - {think_sound → ThinkSound}/models/dit.py +0 -0
 - {think_sound/models/mmmodules/model → ThinkSound/models}/embeddings.py +36 -0
 - {think_sound → ThinkSound}/models/factory.py +0 -0
 - {think_sound → ThinkSound}/models/local_attention.py +0 -0
 - {think_sound → ThinkSound}/models/mmdit.py +56 -9
 - {think_sound → ThinkSound}/models/pretrained.py +0 -0
 - {think_sound → ThinkSound}/models/pretransforms.py +0 -0
 - {think_sound → ThinkSound}/models/transformer.py +0 -0
 - {think_sound/models/mmmodules/model → ThinkSound/models}/transformer_layers.py +2 -2
 - {think_sound → ThinkSound}/models/utils.py +0 -0
 - {think_sound → ThinkSound}/training/__init__.py +0 -0
 - {think_sound → ThinkSound}/training/autoencoders.py +0 -1
 - {think_sound → ThinkSound}/training/diffusion.py +1 -948
 - {think_sound → ThinkSound}/training/factory.py +0 -0
 - {think_sound → ThinkSound}/training/losses/__init__.py +0 -0
 - {think_sound → ThinkSound}/training/losses/auraloss.py +0 -0
 - {think_sound → ThinkSound}/training/losses/losses.py +0 -0
 - {think_sound → ThinkSound}/training/utils.py +0 -0
 - app.py +50 -59
 - cot_vgg_demo_caption.txt +1 -0
 - data_utils/__pycache__/__init__.cpython-310.pyc +0 -0
 - data_utils/__pycache__/utils.cpython-310.pyc +0 -0
 - data_utils/__pycache__/utils.cpython-39.pyc +0 -0
 - data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc +0 -0
 - data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc +0 -0
 - data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc +0 -0
 - data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc +0 -0
 - data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc +0 -0
 - data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc +0 -0
 - data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc +0 -0
 - data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc +0 -0
 
    	
{think_sound → ThinkSound}/__init__.py RENAMED
    File without changes
    	
{think_sound/configs/model_configs/autoencoders → ThinkSound/configs/model_configs}/stable_audio_2_0_vae.json RENAMED
    File without changes
    	
think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json → ThinkSound/configs/model_configs/thinksound.json RENAMED

@@ -85,7 +85,7 @@
                "clip_dim":1024,
                "sync_dim":768,
                "text_dim":2048,
-               "hidden_dim":1024
+               "hidden_dim":1024,
                "depth":21,
                "fused_depth":14,
                "num_heads":16,
    	
ThinkSound/configs/multimodal_dataset_demo.json ADDED

@@ -0,0 +1,53 @@
+{
+    "dataset_type": "multimodal_dir",
+    "video_datasets": [
+        {
+            "id": "vggsound",
+            "path": "dataset/vggsound/video_latents_t5_clip_npz/train",
+            "split_path": "dataset/vggsound/split_txt/train_cot.txt"
+        }
+    ],
+    "audio_datasets": [
+        {
+            "id": "audiostock",
+            "path": "dataset/Laion-Audio-630k/audiostock_latents_npz",
+            "split_path": "dataset/Laion-Audio-630k/split_txt/cot_audiostock_1.txt"
+        },
+        {
+            "id": "freesound_no_overlap",
+            "path": "dataset/Laion-Audio-630k/freesound_no_overlap_latents_npz",
+            "split_path": "dataset/Laion-Audio-630k/split_txt/cot_freesound.txt"
+        },
+        {
+            "id": "audioset_sl",
+            "path": "dataset/wavcaps/audioset_sl_latents_npz",
+            "split_path": "dataset/wavcaps/split_txt/cot_audio_sl_1.txt"
+        },
+        {
+            "id": "audiocaps",
+            "path": "dataset/1_audiocaps/audiocaps_latents_npz",
+            "split_path": "dataset/1_audiocaps/split_txt/train_cot.txt"
+        },
+        {
+            "id": "bbc",
+            "path": "dataset/Laion-Audio-630k/bbc_latents_npz",
+            "split_path": "dataset/Laion-Audio-630k/split_txt/cot_bbc_1.txt"
+        }
+    ],
+    "val_datasets": [
+        {
+            "id": "vggsound",
+            "path": "dataset/vggsound/video_latents_t5_clip_npz/test",
+            "split_path": "dataset/vggsound/split_txt/test_cot.txt"
+        }
+    ],
+    "test_datasets": [
+        {
+            "id": "vggsound",
+            "path": "cot_coarse",
+            "split_path": "cot_vgg_demo_caption.txt"
+        }
+    ],
+    "random_crop": true,
+    "input_type": "prompt"
+}
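The added config is plain JSON, so the splits can be inspected with the standard library. A minimal sketch (the loop is illustrative and not part of the commit):

    import json

    # Load the dataset config added in this commit.
    with open("ThinkSound/configs/multimodal_dataset_demo.json") as f:
        cfg = json.load(f)

    # Each entry pairs a latent directory ("path") with a split file
    # ("split_path") listing which ids belong to that split.
    for section in ("video_datasets", "audio_datasets", "val_datasets", "test_datasets"):
        for ds in cfg[section]:
            print(section, ds["id"], ds["path"], ds["split_path"])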
    	
{data_utils → ThinkSound/data}/__init__.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/data/datamodule.py RENAMED

@@ -33,13 +33,14 @@ def get_configs(audio_configs):
     return configs

 class DataModule(L.LightningDataModule):
-    def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5):
+    def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5,latent_length=194):
         super().__init__()
         dataset_type = dataset_config.get("dataset_type", None)
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.test_batch_size = test_batch_size
         self.repeat_num = repeat_num
+        self.latent_length = latent_length
         assert dataset_type is not None, "Dataset type must be specified in dataset config"

         if audio_channels == 1:

@@ -140,7 +141,8 @@ class DataModule(L.LightningDataModule):
                 random_crop=random_crop,
                 input_type=self.input_type,
                 fps=self.input_type,
-                force_channels=self.force_channels
+                force_channels=self.force_channels,
+                latent_length=self.latent_length
             )

         if stage == 'fit':
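For context, the new latent_length keyword flows from DataModule down to the dataset constructors. A hypothetical instantiation sketch; every value below is illustrative, and only latent_length=194 matches the commit's default:

    from ThinkSound.data.datamodule import DataModule

    dm = DataModule(
        dataset_config=cfg,      # e.g. the parsed multimodal_dataset_demo.json
        batch_size=8,            # illustrative
        test_batch_size=4,       # illustrative
        sample_size=441000,      # illustrative
        sample_rate=44100,       # illustrative
        latent_length=194,       # forwarded to VideoDataset for the zero-latent fallback
    )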
    	
{think_sound → ThinkSound}/data/dataset.py RENAMED

@@ -342,8 +342,7 @@ class LatentDataset(torch.utils.data.Dataset):
         info = {}
         audio, video = self.load_file(audio_filename, info)
         info["path"] = audio_filename
-
-        assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
+
         info['id'] = Path(audio_filename).stem
         for root_path in self.root_paths:
             if root_path in audio_filename:

@@ -434,8 +433,7 @@ class AudioDataset(torch.utils.data.Dataset):
         info = {}
         audio, video = self.load_file(audio_filename, info)
         info["path"] = audio_filename
-
-        assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
+
         info['id'] = Path(audio_filename).stem
         for root_path in self.root_paths:
             if root_path in audio_filename:

@@ -454,8 +452,9 @@ class VideoDataset(torch.utils.data.Dataset):
         input_type="prompt",
         fps=4,
         force_channels="stereo",
+        latent_length=194,  # default latent length for video dataset
     ):
-
+        self.latent_length = latent_length
         super().__init__()
         self.filenames = []
         print(f'configs: {configs[0]}')

@@ -523,7 +522,7 @@ class VideoDataset(torch.utils.data.Dataset):
         if 'latent' in data.keys():
             audio = data['latent']
         else:
-            audio = torch.zeros(64,
+            audio = torch.zeros(64,self.latent_length)
         info['video_exist'] = self.video_exist
         # except:
         #     print(f'error load file: {filename}')

@@ -540,8 +539,7 @@ class VideoDataset(torch.utils.data.Dataset):
         info = {}
         audio, video = self.load_file(audio_filename, info)
         info["path"] = audio_filename
-
-        assert video.shape == (72,1024), f'{video.shape} input error, id: {id}'
+
         info['id'] = Path(audio_filename).stem
         for root_path in self.root_paths:
             if root_path in audio_filename:
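The else-branch above replaces a previously truncated torch.zeros call: when a record has no 'latent' key, a zero placeholder sized by latent_length is substituted. A standalone sketch of that behavior:

    import torch

    latent_length = 194          # the default passed down from DataModule
    data = {}                    # e.g. an .npz record that lacks a 'latent' key

    if 'latent' in data.keys():
        audio = data['latent']
    else:
        # Placeholder latent: 64 channels x latent_length frames of zeros.
        audio = torch.zeros(64, latent_length)

    print(audio.shape)           # torch.Size([64, 194])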
    	
{think_sound → ThinkSound}/data/utils.py RENAMED
    File without changes
    	
{think_sound/data → ThinkSound/inference}/__init__.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/inference/generation.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/inference/sampling.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/inference/utils.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/models/__init__.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/models/autoencoders.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/models/blocks.py RENAMED

@@ -336,4 +336,95 @@ class SnakeBeta(nn.Module):
             beta = torch.exp(beta)
         x = snake_beta(x, alpha, beta)

-        return x
+        return x
+
+class ChannelLastConv1d(nn.Conv1d):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(0, 2, 1)
+        x = super().forward(x)
+        x = x.permute(0, 2, 1)
+        return x
+
+
+# https://github.com/Stability-AI/sd3-ref
+class MLP(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int = 256,
+    ):
+        """
+        Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+
+        Attributes:
+            w1 (ColumnParallelLinear): Linear transformation for the first layer.
+            w2 (RowParallelLinear): Linear transformation for the second layer.
+            w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class ConvMLP(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int = 256,
+        kernel_size: int = 3,
+        padding: int = 1,
+    ):
+        """
+        Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+
+        Attributes:
+            w1 (ColumnParallelLinear): Linear transformation for the first layer.
+            w2 (RowParallelLinear): Linear transformation for the second layer.
+            w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = ChannelLastConv1d(dim,
+                                    hidden_dim,
+                                    bias=False,
+                                    kernel_size=kernel_size,
+                                    padding=padding)
+        self.w2 = ChannelLastConv1d(hidden_dim,
+                                    dim,
+                                    bias=False,
+                                    kernel_size=kernel_size,
+                                    padding=padding)
+        self.w3 = ChannelLastConv1d(dim,
+                                    hidden_dim,
+                                    bias=False,
+                                    kernel_size=kernel_size,
+                                    padding=padding)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
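The new blocks keep activations channel-last, i.e. (batch, length, dim); ChannelLastConv1d hides the transposes that nn.Conv1d requires, and ConvMLP is the SwiGLU-style feedforward built from it. A minimal shape check, assuming the definitions above are in scope (sizes illustrative):

    import torch

    conv = ChannelLastConv1d(1024, 1024, kernel_size=3, padding=1)
    x = torch.randn(2, 194, 1024)          # (batch, length, dim)
    print(conv(x).shape)                   # torch.Size([2, 194, 1024])

    # ConvMLP: SwiGLU feedforward with convolutions in place of linears.
    mlp = ConvMLP(dim=1024, hidden_dim=4096, kernel_size=3, padding=1)
    print(mlp(x).shape)                    # torch.Size([2, 194, 1024])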
    	
{think_sound → ThinkSound}/models/bottleneck.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/models/codebook_patterns.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/models/conditioners.py RENAMED

@@ -7,7 +7,6 @@ import typing as tp
 import gc
 from typing import Literal, Optional
 import os
-from .adp import NumberEmbedder
 from ..inference.utils import set_audio_channels
 from .factory import create_pretransform_from_config
 from .pretransforms import Pretransform
    	
{think_sound → ThinkSound}/models/diffusion.py RENAMED

@@ -7,14 +7,12 @@ import typing as tp

 from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
 from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
-from .dit import DiffusionTransformer
+# from .dit import DiffusionTransformer
 from .mmdit import MMAudio
 from .factory import create_pretransform_from_config
 from .pretransforms import Pretransform
 from ..inference.generation import generate_diffusion_cond

-from .adp import UNetCFG1d, UNet1d
-
 from time import time

 class Profiler:
    	
{think_sound → ThinkSound}/models/dit.py RENAMED
    File without changes
    	
{think_sound/models/mmmodules/model → ThinkSound/models}/embeddings.py RENAMED

@@ -3,6 +3,42 @@ import torch.nn as nn

 # https://github.com/facebookresearch/DiT

+from typing import Union
+
+import torch
+from einops import rearrange
+from torch import Tensor
+
+# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
+# Ref: https://github.com/lucidrains/rotary-embedding-torch
+
+
+def compute_rope_rotations(length: int,
+                           dim: int,
+                           theta: int,
+                           *,
+                           freq_scaling: float = 1.0,
+                           device: Union[torch.device, str] = 'cpu') -> Tensor:
+    assert dim % 2 == 0
+
+    with torch.amp.autocast(device_type='cuda', enabled=False):
+        pos = torch.arange(length, dtype=torch.float32, device=device)
+        freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+        freqs *= freq_scaling
+
+        rot = torch.einsum('..., f -> ... f', pos, freqs)
+        rot = torch.stack([torch.cos(rot), -torch.sin(rot), torch.sin(rot), torch.cos(rot)], dim=-1)
+        rot = rearrange(rot, 'n d (i j) -> 1 n d i j', i=2, j=2)
+        return rot
+
+
+def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]:
+    with torch.amp.autocast(device_type='cuda', enabled=False):
+        _x = x.float()
+        _x = _x.view(*_x.shape[:-1], -1, 1, 2)
+        x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
+        return x_out.reshape(*x.shape).to(dtype=x.dtype)
+

 class TimestepEmbedder(nn.Module):
     """
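compute_rope_rotations precomputes a (1, length, dim/2, 2, 2) table of 2x2 rotation matrices that apply_rope then broadcasts over a (batch, length, dim) tensor, rotating consecutive feature pairs by position. A minimal sketch, assuming the functions above are importable from ThinkSound.models.embeddings (sizes illustrative):

    import torch

    length, head_dim = 194, 64
    rot = compute_rope_rotations(length, head_dim, 10000)
    print(rot.shape)             # torch.Size([1, 194, 32, 2, 2])

    q = torch.randn(2, length, head_dim)   # (batch, positions, head_dim)
    q_rot = apply_rope(q, rot)
    print(q_rot.shape)           # torch.Size([2, 194, 64])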
    	
{think_sound → ThinkSound}/models/factory.py RENAMED
    File without changes
    	
{think_sound → ThinkSound}/models/local_attention.py RENAMED
    File without changes
    	
        {think_sound → ThinkSound}/models/mmdit.py
    RENAMED
    
    | 
         @@ -6,10 +6,10 @@ import torch 
     | 
|
| 6 | 
         
             
            import torch.nn as nn
         
     | 
| 7 | 
         
             
            import torch.nn.functional as F
         
     | 
| 8 | 
         
             
            import sys
         
     | 
| 9 | 
         
            -
            from . 
     | 
| 10 | 
         
            -
            from . 
     | 
| 11 | 
         
            -
            from . 
     | 
| 12 | 
         
            -
            from . 
     | 
| 13 | 
         
             
            from .utils import resample
         
     | 
| 14 | 
         | 
| 15 | 
         
             
            log = logging.getLogger()
         
     | 
| 
         @@ -24,7 +24,6 @@ class PreprocessedConditions: 
     | 
|
| 24 | 
         
             
                text_f_c: torch.Tensor
         
     | 
| 25 | 
         | 
| 26 | 
         | 
| 27 | 
         
            -
            # Partially from https://github.com/facebookresearch/DiT
         
     | 
| 28 | 
         
             
            class MMAudio(nn.Module):
         
     | 
| 29 | 
         | 
| 30 | 
         
             
                def __init__(self,
         
     | 
| 
         @@ -94,7 +93,6 @@ class MMAudio(nn.Module): 
     | 
|
| 94 | 
         
             
                            nn.Linear(hidden_dim * 4, hidden_dim, bias=False),
         
     | 
| 95 | 
         
             
                            nn.Sigmoid()
         
     | 
| 96 | 
         
             
                        )
         
     | 
| 97 | 
         
            -
                        # 初始化最后一层权重为零,促进初始均匀融合
         
     | 
| 98 | 
         
             
                        nn.init.zeros_(self.gated_mlp_v[3].weight)
         
     | 
| 99 | 
         
             
                        nn.init.zeros_(self.gated_mlp_t[3].weight)
         
     | 
| 100 | 
         
             
                    if v2:
         
     | 
| 
         @@ -441,9 +439,9 @@ class MMAudio(nn.Module): 
     | 
|
| 441 | 
         
             
                        # clip_f = torch.cat([clip_f,empty_clip_f], dim=0)
         
     | 
| 442 | 
         
             
                        # sync_f = torch.cat([sync_f,empty_sync_f], dim=0)
         
     | 
| 443 | 
         
             
                        # text_f = torch.cat([text_f,empty_text_f], dim=0)
         
     | 
| 444 | 
         
            -
                        clip_f =  
     | 
| 445 | 
         
            -
                        sync_f =  
     | 
| 446 | 
         
            -
                        text_f =  
     | 
| 447 | 
         
             
                        if t5_features is not None:
         
     | 
| 448 | 
         
             
                            empty_t5_features = torch.zeros_like(t5_features, device=latent.device)
         
     | 
| 449 | 
         
             
                            # t5_features = torch.cat([t5_features,empty_t5_features], dim=0)
         
     | 
| 
         @@ -529,3 +527,52 @@ class MMAudio(nn.Module): 
     | 
|
| 529 | 
         
             
                def sync_seq_len(self) -> int:
         
     | 
| 530 | 
         
             
                    return self._sync_seq_len
         
     | 
| 531 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 6 | 
         
             
            import torch.nn as nn
         
     | 
| 7 | 
         
             
            import torch.nn.functional as F
         
     | 
| 8 | 
         
             
            import sys
         
     | 
| 9 | 
         
            +
            from .embeddings import compute_rope_rotations
         
     | 
| 10 | 
         
            +
            from .embeddings import TimestepEmbedder
         
     | 
| 11 | 
         
            +
            from .blocks import MLP, ChannelLastConv1d, ConvMLP
         
     | 
| 12 | 
         
            +
            from .transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
         
     | 
| 13 | 
         
             
            from .utils import resample
         
     | 
| 14 | 
         | 
| 15 | 
         
             
            log = logging.getLogger()
         
     | 
| 
         | 
|
| 24 | 
         
             
                text_f_c: torch.Tensor
         
     | 
| 25 | 
         | 
| 26 | 
         | 
| 
         | 
|
| 27 | 
         
             
            class MMAudio(nn.Module):
         
     | 
| 28 | 
         | 
| 29 | 
         
             
                def __init__(self,
         
     | 
| 
         | 
|
| 93 | 
         
             
                            nn.Linear(hidden_dim * 4, hidden_dim, bias=False),
         
     | 
| 94 | 
         
             
                            nn.Sigmoid()
         
     | 
| 95 | 
         
             
                        )
         
     | 
| 
         | 
|
| 96 | 
         
             
                        nn.init.zeros_(self.gated_mlp_v[3].weight)
         
     | 
| 97 | 
         
             
                        nn.init.zeros_(self.gated_mlp_t[3].weight)
         
     | 
| 98 | 
         
             
                    if v2:
         
     | 
| 
         | 
|
| 439 | 
         
             
                        # clip_f = torch.cat([clip_f,empty_clip_f], dim=0)
         
     | 
| 440 | 
         
             
                        # sync_f = torch.cat([sync_f,empty_sync_f], dim=0)
         
     | 
| 441 | 
         
             
                        # text_f = torch.cat([text_f,empty_text_f], dim=0)
         
     | 
| 442 | 
         
            +
                        clip_f = safe_cat(clip_f,self.get_empty_clip_sequence(bsz), dim=0, match_dim=1)
         
     | 
| 443 | 
         
            +
                        sync_f = safe_cat(sync_f,self.get_empty_sync_sequence(bsz), dim=0, match_dim=1)
         
     | 
| 444 | 
         
            +
                        text_f = safe_cat(text_f,self.get_empty_string_sequence(bsz), dim=0, match_dim=1)
         
     | 
| 445 | 
         
             
                        if t5_features is not None:
         
     | 
| 446 | 
         
             
                            empty_t5_features = torch.zeros_like(t5_features, device=latent.device)
         
     | 
| 447 | 
         
             
                            # t5_features = torch.cat([t5_features,empty_t5_features], dim=0)
         
     | 
| 
         | 
|
| 527 | 
         
             
                def sync_seq_len(self) -> int:
         
     | 
                     return self._sync_seq_len
 
+
+def truncate_to_target(tensor, target_size, dim=1):
+    current_size = tensor.size(dim)
+    if current_size > target_size:
+        slices = [slice(None)] * tensor.dim()
+        slices[dim] = slice(0, target_size)
+        return tensor[slices]
+    return tensor
+
+def pad_to_target(tensor, target_size, dim=1, pad_value=0):
+    current_size = tensor.size(dim)
+    if current_size < target_size:
+        pad_size = target_size - current_size
+        pad_config = [0, 0] * tensor.dim()
+        pad_index = 2 * (tensor.dim() - dim - 1) + 1
+        pad_config[pad_index] = pad_size
+        return torch.nn.functional.pad(tensor, pad_config, value=pad_value)
+    return tensor
+
+
+def safe_cat(tensor1, tensor2, dim=0, match_dim=1):
+    target_size = tensor2.size(match_dim)
+    if tensor1.size(match_dim) > target_size:
+        tensor1 = truncate_to_target(tensor1, target_size, match_dim)
+    else:
+        tensor1 = pad_to_target(tensor1, target_size, match_dim)
+    return torch.cat([tensor1, tensor2], dim=dim)
+
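The three helpers added above give tolerant concatenation: safe_cat truncates or right-pads tensor1 along match_dim to tensor2's size before torch.cat. A minimal usage sketch (standalone, with illustrative shapes only; note that torch.nn.functional.pad orders its pad pairs from the last dimension backwards, which is what the pad_index arithmetic targets):

import torch
import torch.nn.functional as F

a = torch.randn(2, 10, 64)  # e.g. 10 visual feature frames
b = torch.randn(2, 16, 64)  # e.g. 16 audio latent steps

# Equivalent of pad_to_target(a, 16, dim=1): pad pairs run from the
# last dim backwards, so (0, 0) covers dim 2 and (0, 6) right-pads dim 1.
a_padded = F.pad(a, (0, 0, 0, 6))
assert a_padded.shape == (2, 16, 64)

# With the functions above in scope, the same alignment plus concat:
# out = safe_cat(a, b, dim=0, match_dim=1)  # -> shape (4, 16, 64)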
    	
{think_sound → ThinkSound}/models/pretrained.py
RENAMED
    File without changes
{think_sound → ThinkSound}/models/pretransforms.py
RENAMED
    File without changes
{think_sound → ThinkSound}/models/transformer.py
RENAMED
    File without changes
    	
{think_sound/models/mmmodules/model → ThinkSound/models}/transformer_layers.py
RENAMED
@@ -6,8 +6,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from einops.layers.torch import Rearrange
-from 
-from 
+from .embeddings import apply_rope
+from .blocks import MLP, ChannelLastConv1d, ConvMLP
 try:
     from flash_attn import flash_attn_func, flash_attn_kvpacked_func
     print('flash_attn installed, using Flash Attention')
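The old absolute imports (truncated in this view) are replaced by relative ones; apply_rope now comes from the relocated embeddings module. For orientation, a standalone sketch of the standard rotary-position-embedding rotation that such a helper performs (a hypothetical reproduction; the actual apply_rope signature in this repo may differ):

import torch

def rope_rotate(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # x: (..., seq, dim) with even dim; angles: (seq, dim // 2).
    # Each consecutive feature pair is rotated by a position-dependent angle,
    # so relative positions appear as phase differences in attention scores.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out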
    	
{think_sound → ThinkSound}/models/utils.py
RENAMED
    File without changes
{think_sound → ThinkSound}/training/__init__.py
RENAMED
    File without changes
{think_sound → ThinkSound}/training/autoencoders.py
RENAMED
@@ -9,7 +9,6 @@ from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss,
 import lightning as L
 from lightning.pytorch.callbacks import Callback
 from ..models.autoencoders import AudioAutoencoder
-from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss
 from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
 from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
 from .utils import create_optimizer_from_config, create_scheduler_from_config
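The auraloss imports retained above provide the spectral reconstruction losses used in autoencoder training. As a rough standalone sketch of what a multi-resolution STFT loss computes (simplified to L1 on magnitude and log-magnitude for mono (batch, time) signals; the auraloss versions add options such as perceptual weighting and sum/difference stereo handling):

import torch

def multires_stft_loss(x, y, fft_sizes=(512, 1024, 2048)):
    # Compare magnitude spectrograms of prediction x and reference y
    # at several time-frequency resolutions, then average.
    loss = x.new_zeros(())
    for n_fft in fft_sizes:
        win = torch.hann_window(n_fft, device=x.device)
        X = torch.stft(x, n_fft, hop_length=n_fft // 4, window=win, return_complex=True).abs()
        Y = torch.stft(y, n_fft, hop_length=n_fft // 4, window=win, return_complex=True).abs()
        loss = loss + (X - Y).abs().mean()
        loss = loss + ((X + 1e-7).log() - (Y + 1e-7).log()).abs().mean()
    return loss / len(fft_sizes)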
    	
{think_sound → ThinkSound}/training/diffusion.py
RENAMED
@@ -20,7 +20,6 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
 from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
 from ..models.autoencoders import DiffusionAutoencoder
-from ..models.diffusion_prior import PriorType
 from .autoencoders import create_loss_modules_from_bottleneck
 from .losses import AuralossLoss, MSELoss, MultiLoss
 from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask
@@ -846,10 +845,9 @@ class DiffusionCondTrainingWrapper(L.LightningModule):
 
     def predict_step(self, batch, batch_idx):
         reals, metadata = batch
-        # import ipdb
-        # ipdb.set_trace()
         ids = [item['id'] for item in metadata]
         batch_size, length = reals.shape[0], reals.shape[2]
         with torch.amp.autocast('cuda'):
             conditioning = self.diffusion.conditioner(metadata, self.device)
 
@@ -878,7 +876,6 @@ class DiffusionCondTrainingWrapper(L.LightningModule):
             end_time = time.time()
             execution_time = end_time - start_time
             print(f"Execution time: {execution_time:.2f} s")
-            breakpoint()
         if self.diffusion.pretransform is not None:
             fakes = self.diffusion.pretransform.decode(fakes)
 
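An aside on the timing block above (not part of the commit): time.time() around GPU sampling under-measures, because CUDA kernels launch asynchronously. A sketch of a more accurate measurement, assuming a CUDA device is in use:

import time
import torch

def timed(fn, *args, **kwargs):
    # Synchronize before and after so the timer brackets the real GPU work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.time() - start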
         | 
| 
@@ -1077,947 +1074,3 @@ class DiffusionCondDemoCallback(Callback):
             gc.collect()
             torch.cuda.empty_cache()
             module.train()
-
-class DiffusionCondInpaintTrainingWrapper(L.LightningModule):
-    '''
-    Wrapper for training a conditional audio diffusion model.
-    '''
-    def __init__(
-            self,
-            model: ConditionedDiffusionModelWrapper,
-            lr: float = 1e-4,
-            max_mask_segments = 10,
-            log_loss_info: bool = False,
-            optimizer_configs: dict = None,
-            use_ema: bool = True,
-            pre_encoded: bool = False,
-            cfg_dropout_prob = 0.1,
-            timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
-    ):
-        super().__init__()
-
-        self.diffusion = model
-
-        self.use_ema = use_ema
-
-        if self.use_ema:
-            self.diffusion_ema = EMA(
-                self.diffusion.model,
-                beta=0.9999,
-                power=3/4,
-                update_every=1,
-                update_after_step=1,
-                include_online_model=False
-            )
-        else:
-            self.diffusion_ema = None
-
-        self.cfg_dropout_prob = cfg_dropout_prob
-
-        self.lr = lr
-        self.max_mask_segments = max_mask_segments
-
-        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
-
-        self.timestep_sampler = timestep_sampler
-
-        self.diffusion_objective = model.diffusion_objective
-
-        self.loss_modules = [
-            MSELoss("output",
-                    "targets",
-                    weight=1.0,
-                    name="mse_loss"
-            )
-        ]
-
-        self.losses = MultiLoss(self.loss_modules)
-
-        self.log_loss_info = log_loss_info
-
-        assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
-
-        if optimizer_configs is None:
-            optimizer_configs = {
-                "diffusion": {
-                    "optimizer": {
-                        "type": "Adam",
-                        "config": {
-                            "lr": lr
-                        }
-                    }
-                }
-            }
-        else:
-            if lr is not None:
-                print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
-
-        self.optimizer_configs = optimizer_configs
-
-        self.pre_encoded = pre_encoded
-
-    def configure_optimizers(self):
-        diffusion_opt_config = self.optimizer_configs['diffusion']
-        opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
-
-        if "scheduler" in diffusion_opt_config:
-            sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
-            sched_diff_config = {
-                "scheduler": sched_diff,
-                "interval": "step"
-            }
-            return [opt_diff], [sched_diff_config]
-
-        return [opt_diff]
-
-    def random_mask(self, sequence, max_mask_length):
-        b, _, sequence_length = sequence.size()
-
-        # Create a mask tensor for each batch element
-        masks = []
-
-        for i in range(b):
-            mask_type = random.randint(0, 2)
-
-            if mask_type == 0:  # Random mask with multiple segments
-                num_segments = random.randint(1, self.max_mask_segments)
-                max_segment_length = max_mask_length // num_segments
-
-                segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)
-
-                mask = torch.ones((1, 1, sequence_length))
-                for length in segment_lengths:
-                    mask_start = random.randint(0, sequence_length - length)
-                    mask[:, :, mask_start:mask_start + length] = 0
-
-            elif mask_type == 1:  # Full mask
-                mask = torch.zeros((1, 1, sequence_length))
-
-            elif mask_type == 2:  # Causal mask
-                mask = torch.ones((1, 1, sequence_length))
-                mask_length = random.randint(1, max_mask_length)
-                mask[:, :, -mask_length:] = 0
-
-            mask = mask.to(sequence.device)
-            masks.append(mask)
-
-        # Concatenate the mask tensors into a single tensor
-        mask = torch.cat(masks, dim=0).to(sequence.device)
-
-        # Apply the mask to the sequence tensor for each batch element
-        masked_sequence = sequence * mask
-
-        return masked_sequence, mask
-
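random_mask above draws one of three mask families per batch element. A compact standalone illustration of the patterns it produces (hypothetical sizes; 0 marks the spans to be inpainted):

import torch

L = 12
segments = torch.ones(1, 1, L)          # mask_type == 0: multiple segments
segments[:, :, 3:6] = 0
segments[:, :, 9:11] = 0
full = torch.zeros(1, 1, L)             # mask_type == 1: everything masked
causal = torch.ones(1, 1, L)            # mask_type == 2: masked tail
causal[:, :, -4:] = 0
# Training then conditions on latents * mask as 'inpaint_masked_input'.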
-    def training_step(self, batch, batch_idx):
-        reals, metadata = batch
-
-        p = Profiler()
-
-        if reals.ndim == 4 and reals.shape[0] == 1:
-            reals = reals[0]
-
-        loss_info = {}
-
-        diffusion_input = reals
-
-        if not self.pre_encoded:
-            loss_info["audio_reals"] = diffusion_input
-
-        p.tick("setup")
-
-        with torch.amp.autocast('cuda'):
-            conditioning = self.diffusion.conditioner(metadata, self.device)
-
-        p.tick("conditioning")
-
-        if self.diffusion.pretransform is not None:
-            self.diffusion.pretransform.to(self.device)
-
-            if not self.pre_encoded:
-                with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
-                    diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
-                    p.tick("pretransform")
-
-                    # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
-                    # if use_padding_mask:
-                    #     padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
-            else:
-                # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
-                if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
-                    diffusion_input = diffusion_input / self.diffusion.pretransform.scale
-
-        # Max mask size is the full sequence length
-        max_mask_length = diffusion_input.shape[2]
-
-        # Create a mask of random length for a random slice of the input
-        masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
-
-        # conditioning['inpaint_mask'] = [mask]
-        conditioning['inpaint_masked_input'] = [masked_input]
-
-        if self.timestep_sampler == "uniform":
-            # Draw uniformly distributed continuous timesteps
-            t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
-        elif self.timestep_sampler == "logit_normal":
-            t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
-
-        # Calculate the noise schedule parameters for those timesteps
-        if self.diffusion_objective == "v":
-            alphas, sigmas = get_alphas_sigmas(t)
-        elif self.diffusion_objective == "rectified_flow":
-            alphas, sigmas = 1-t, t
-
-        # Combine the ground truth data and the noise
-        alphas = alphas[:, None, None]
-        sigmas = sigmas[:, None, None]
-        noise = torch.randn_like(diffusion_input)
-        noised_inputs = diffusion_input * alphas + noise * sigmas
-
-        if self.diffusion_objective == "v":
-            targets = noise * alphas - diffusion_input * sigmas
-        elif self.diffusion_objective == "rectified_flow":
-            targets = noise - diffusion_input
-
-        p.tick("noise")
-
-        extra_args = {}
-
-        with torch.amp.autocast('cuda'):
-            p.tick("amp")
-            output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
-            p.tick("diffusion")
-
-            loss_info.update({
-                "output": output,
-                "targets": targets,
-            })
-
-            loss, losses = self.losses(loss_info)
-
-            if self.log_loss_info:
-                # Loss debugging logs
-                num_loss_buckets = 10
-                bucket_size = 1 / num_loss_buckets
-                loss_all = F.mse_loss(output, targets, reduction="none")
-
-                sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
-
-                # gather loss_all across all GPUs
-                loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
-
-                # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
-                loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
-
-                # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
-                debug_log_dict = {
-                    f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
-                }
-
-                self.log_dict(debug_log_dict)
-
-        log_dict = {
-            'train/loss': loss.detach(),
-            'train/std_data': diffusion_input.std(),
-            'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
-        }
-
-        for loss_name, loss_value in losses.items():
-            log_dict[f"train/{loss_name}"] = loss_value.detach()
-
-        self.log_dict(log_dict, prog_bar=True, on_step=True)
-        p.tick("log")
-        #print(f"Profiler: {p}")
-        return loss
-
-    def on_before_zero_grad(self, *args, **kwargs):
-        if self.diffusion_ema is not None:
-            self.diffusion_ema.update()
-
-    def export_model(self, path, use_safetensors=False):
-        if self.diffusion_ema is not None:
-            self.diffusion.model = self.diffusion_ema.ema_model
-
-        if use_safetensors:
-            save_file(self.diffusion.state_dict(), path)
-        else:
-            torch.save({"state_dict": self.diffusion.state_dict()}, path)
-
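For reference, the two target constructions the deleted training_step implements, as a standalone sketch (assuming get_alphas_sigmas is the usual cosine schedule; the actual helper lives in ..inference.sampling):

import math
import torch

def make_noised_and_targets(x0, t, objective="rectified_flow"):
    # x0: clean latents (b, c, n); t: timesteps in [0, 1], shape (b,).
    if objective == "v":
        alphas = torch.cos(t * math.pi / 2)[:, None, None]
        sigmas = torch.sin(t * math.pi / 2)[:, None, None]
    else:  # rectified flow: straight path from data to noise
        alphas = (1 - t)[:, None, None]
        sigmas = t[:, None, None]
    noise = torch.randn_like(x0)
    noised = x0 * alphas + noise * sigmas
    if objective == "v":
        targets = noise * alphas - x0 * sigmas
    else:
        targets = noise - x0
    return noised, targets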
-class DiffusionCondInpaintDemoCallback(Callback):
-    def __init__(
-        self,
-        demo_dl,
-        demo_every=2000,
-        demo_steps=250,
-        sample_size=65536,
-        sample_rate=48000,
-        demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7]
-    ):
-        super().__init__()
-        self.demo_every = demo_every
-        self.demo_steps = demo_steps
-        self.demo_samples = sample_size
-        self.demo_dl = iter(demo_dl)
-        self.sample_rate = sample_rate
-        self.demo_cfg_scales = demo_cfg_scales
-        self.last_demo_step = -1
-
-    @rank_zero_only
-    @torch.no_grad()
-    def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
-        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
-            return
-
-        self.last_demo_step = trainer.global_step
-
-        try:
-            log_dict = {}
-
-            demo_reals, metadata = next(self.demo_dl)
-
-            # Remove extra dimension added by WebDataset
-            if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
-                demo_reals = demo_reals[0]
-
-            demo_reals = demo_reals.to(module.device)
-
-            if not module.pre_encoded:
-                # Log the real audio
-                log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu()))
-                # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals")
-
-                if module.diffusion.pretransform is not None:
-                    module.diffusion.pretransform.to(module.device)
-                    with torch.amp.autocast('cuda'):
-                        demo_reals = module.diffusion.pretransform.encode(demo_reals)
-
-            demo_samples = demo_reals.shape[2]
-
-            # Get conditioning
-            conditioning = module.diffusion.conditioner(metadata, module.device)
-
-            masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2])
-
-            conditioning['inpaint_mask'] = [mask]
-            conditioning['inpaint_masked_input'] = [masked_input]
-
-            if module.diffusion.pretransform is not None:
-                log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu()))
-            else:
-                log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu()))
-
-            cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
-
-            noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device)
-
-            trainer.logger.experiment.log(log_dict)
-
-            for cfg_scale in self.demo_cfg_scales:
-                model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
-                print(f"Generating demo for cfg scale {cfg_scale}")
-
-                if module.diffusion_objective == "v":
-                    fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
-                elif module.diffusion_objective == "rectified_flow":
-                    fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
-
-                if module.diffusion.pretransform is not None:
-                    with torch.amp.autocast('cuda'):
-                        fakes = module.diffusion.pretransform.decode(fakes)
-
-                # Put the demos together
-                fakes = rearrange(fakes, 'b d n -> d (b n)')
-
-                log_dict = {}
-
-                filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
-                fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
-                torchaudio.save(filename, fakes, self.sample_rate)
-
-                log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
-                                                    sample_rate=self.sample_rate,
-                                                    caption=f'Reconstructed')
-
-                log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
-
-                trainer.logger.experiment.log(log_dict)
-        except Exception as e:
-            print(f'{type(e).__name__}: {e}')
-            raise e
-
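The demo loop above sweeps several cfg_scale values. Classifier-free guidance, in its standard form, mixes conditional and unconditional predictions at every sampling step; a one-line standalone sketch (presumably applied inside sample/sample_discrete_euler here, given the cfg_scale and batch_cfg arguments):

import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # cfg_scale = 1 recovers the conditional prediction; larger values
    # extrapolate further along the conditioning direction.
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)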
| 1448 | 
         
            -
            class DiffusionAutoencoderTrainingWrapper(L.LightningModule):
         
     | 
| 1449 | 
         
            -
                '''
         
     | 
| 1450 | 
         
            -
                Wrapper for training a diffusion autoencoder
         
     | 
| 1451 | 
         
            -
                '''
         
     | 
| 1452 | 
         
            -
                def __init__(
         
     | 
| 1453 | 
         
            -
                        self,
         
     | 
| 1454 | 
         
            -
                        model: DiffusionAutoencoder,
         
     | 
| 1455 | 
         
            -
                        lr: float = 1e-4,
         
     | 
| 1456 | 
         
            -
                        ema_copy = None,
         
     | 
| 1457 | 
         
            -
                        use_reconstruction_loss: bool = False
         
     | 
| 1458 | 
         
            -
                ):
         
     | 
| 1459 | 
         
            -
                    super().__init__()
         
     | 
| 1460 | 
         
            -
             
     | 
| 1461 | 
         
            -
                    self.diffae = model
         
     | 
| 1462 | 
         
            -
                    
         
     | 
| 1463 | 
         
            -
                    self.diffae_ema = EMA(
         
     | 
| 1464 | 
         
            -
                        self.diffae,
         
     | 
| 1465 | 
         
            -
                        ema_model=ema_copy,
         
     | 
| 1466 | 
         
            -
                        beta=0.9999,
         
     | 
| 1467 | 
         
            -
                        power=3/4,
         
     | 
| 1468 | 
         
            -
                        update_every=1,
         
     | 
| 1469 | 
         
            -
                        update_after_step=1,
         
     | 
| 1470 | 
         
            -
                        include_online_model=False
         
     | 
| 1471 | 
         
            -
                    )
         
     | 
| 1472 | 
         
            -
             
     | 
| 1473 | 
         
            -
                    self.lr = lr
         
     | 
| 1474 | 
         
            -
             
     | 
| 1475 | 
         
            -
                    self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
         
     | 
| 1476 | 
         
            -
             
     | 
| 1477 | 
         
            -
                    loss_modules = [
         
     | 
| 1478 | 
         
            -
                        MSELoss("v",
         
     | 
| 1479 | 
         
            -
                                "targets",
         
     | 
| 1480 | 
         
            -
                                weight=1.0,
         
     | 
| 1481 | 
         
            -
                                name="mse_loss"
         
     | 
| 1482 | 
         
            -
                        )
         
     | 
| 1483 | 
         
            -
                    ]
         
     | 
| 1484 | 
         
            -
             
     | 
| 1485 | 
         
            -
                    if model.bottleneck is not None:
         
     | 
| 1486 | 
         
            -
                        # TODO: Use loss config for configurable bottleneck weights and reconstruction losses
         
     | 
| 1487 | 
         
            -
                        loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {})
         
     | 
| 1488 | 
         
            -
             
     | 
| 1489 | 
         
            -
                    self.use_reconstruction_loss = use_reconstruction_loss
         
     | 
| 1490 | 
         
            -
             
     | 
| 1491 | 
         
            -
                    if use_reconstruction_loss:
         
     | 
| 1492 | 
         
            -
                        scales = [2048, 1024, 512, 256, 128, 64, 32]
         
     | 
| 1493 | 
         
            -
                        hop_sizes = []
         
     | 
| 1494 | 
         
            -
                        win_lengths = []
         
     | 
| 1495 | 
         
            -
                        overlap = 0.75
         
     | 
| 1496 | 
         
            -
                        for s in scales:
         
     | 
| 1497 | 
         
            -
                            hop_sizes.append(int(s * (1 - overlap)))
         
     | 
| 1498 | 
         
            -
                            win_lengths.append(s)
         
     | 
| 1499 | 
         
            -
             
     | 
| 1500 | 
         
            -
                        sample_rate = model.sample_rate
         
     | 
| 1501 | 
         
            -
             
     | 
| 1502 | 
         
            -
                        stft_loss_args = {
         
     | 
| 1503 | 
         
            -
                            "fft_sizes": scales,
         
     | 
| 1504 | 
         
            -
                            "hop_sizes": hop_sizes,
         
     | 
| 1505 | 
         
            -
                            "win_lengths": win_lengths,
         
     | 
| 1506 | 
         
            -
                            "perceptual_weighting": True
         
     | 
| 1507 | 
         
            -
                        }
         
     | 
| 1508 | 
         
            -
             
     | 
| 1509 | 
         
            -
                        out_channels = model.out_channels
         
     | 
| 1510 | 
         
            -
             
     | 
| 1511 | 
         
            -
                        if model.pretransform is not None:
         
     | 
| 1512 | 
         
            -
                            out_channels = model.pretransform.io_channels
         
     | 
| 1513 | 
         
            -
             
     | 
| 1514 | 
         
            -
                        if out_channels == 2:
         
     | 
| 1515 | 
         
            -
                            self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
         
     | 
| 1516 | 
         
            -
                        else:
         
     | 
| 1517 | 
         
            -
                            self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
         
     | 
| 1518 | 
         
            -
             
     | 
| 1519 | 
         
            -
                        loss_modules.append(
         
     | 
| 1520 | 
         
            -
                            AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
         
     | 
| 1521 | 
         
            -
                        )
         
     | 
| 1522 | 
         
            -
             
     | 
| 1523 | 
         
            -
                    self.losses = MultiLoss(loss_modules)
         
     | 
| 1524 | 
         
            -
             
     | 
| 1525 | 
         
            -
                def configure_optimizers(self):
         
     | 
| 1526 | 
         
            -
                    return optim.Adam([*self.diffae.parameters()], lr=self.lr)
         
     | 
| 1527 | 
         
            -
             
     | 
| 1528 | 
         
            -
                def training_step(self, batch, batch_idx):
         
     | 
| 1529 | 
         
            -
                    reals = batch[0]
         
     | 
| 1530 | 
         
            -
             
     | 
| 1531 | 
         
            -
                    if reals.ndim == 4 and reals.shape[0] == 1:
         
     | 
| 1532 | 
         
            -
                        reals = reals[0]
         
     | 
| 1533 | 
         
The commit removes the following block from training/diffusion.py; it picks up mid-way through DiffusionAutoencoderTrainingWrapper.training_step:

-        loss_info = {}
-
-        loss_info["audio_reals"] = reals
-
-        if self.diffae.pretransform is not None:
-            with torch.no_grad():
-                reals = self.diffae.pretransform.encode(reals)
-
-        loss_info["reals"] = reals
-
-        # Encode reals, skipping the pretransform since it was already applied
-        latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True)
-
-        loss_info["latents"] = latents
-        loss_info.update(encoder_info)
-
-        if self.diffae.decoder is not None:
-            latents = self.diffae.decoder(latents)
-
-        # Upsample latents to match diffusion length
-        if latents.shape[2] != reals.shape[2]:
-            latents = F.interpolate(latents, size=reals.shape[2], mode='nearest')
-
-        loss_info["latents_upsampled"] = latents
-
-        # Draw uniformly distributed continuous timesteps
-        t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
-
-        # Calculate the noise schedule parameters for those timesteps
-        alphas, sigmas = get_alphas_sigmas(t)
-
-        # Combine the ground truth data and the noise
-        alphas = alphas[:, None, None]
-        sigmas = sigmas[:, None, None]
-        noise = torch.randn_like(reals)
-        noised_reals = reals * alphas + noise * sigmas
-        targets = noise * alphas - reals * sigmas
-
-        with torch.amp.autocast('cuda'):
-            v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents)
-
-            loss_info.update({
-                "v": v,
-                "targets": targets
-            })
-
-            if self.use_reconstruction_loss:
-                pred = noised_reals * alphas - v * sigmas
-
-                loss_info["pred"] = pred
-
-                if self.diffae.pretransform is not None:
-                    pred = self.diffae.pretransform.decode(pred)
-                    loss_info["audio_pred"] = pred
-
-            loss, losses = self.losses(loss_info)
-
-        log_dict = {
-            'train/loss': loss.detach(),
-            'train/std_data': reals.std(),
-            'train/latent_std': latents.std(),
-        }
-
-        for loss_name, loss_value in losses.items():
-            log_dict[f"train/{loss_name}"] = loss_value.detach()
-
-        self.log_dict(log_dict, prog_bar=True, on_step=True)
-        return loss
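This removed training step implements the v-objective: with a trigonometric schedule (get_alphas_sigmas is assumed here to be the usual cosine/sine pair, so that alpha**2 + sigma**2 == 1), the network predicts v = alpha*noise - sigma*x, and a clean estimate can be recovered algebraically from any noised input. A minimal sketch of that identity, under the assumed schedule:

    import math
    import torch

    def get_alphas_sigmas(t):
        # Assumed trig schedule: alpha = cos(t*pi/2), sigma = sin(t*pi/2),
        # so alpha**2 + sigma**2 == 1 for every t.
        return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

    x = torch.randn(4, 64, 256)               # clean latents (batch, channels, time)
    noise = torch.randn_like(x)
    t = torch.rand(4)[:, None, None]          # continuous timesteps in [0, 1)
    alphas, sigmas = get_alphas_sigmas(t)

    noised = x * alphas + noise * sigmas      # forward process
    v_target = noise * alphas - x * sigmas    # what the network is trained to predict

    # Inverting the two equations recovers x exactly when v is perfect:
    x_hat = noised * alphas - v_target * sigmas
    assert torch.allclose(x_hat, x, atol=1e-5)

This is the same algebra the reconstruction branch uses (pred = noised_reals * alphas - v * sigmas) before decoding through the pretransform.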
         
-
-    def on_before_zero_grad(self, *args, **kwargs):
-        self.diffae_ema.update()
-
-    def export_model(self, path, use_safetensors=False):
-
-        model = self.diffae_ema.ema_model
-
-        if use_safetensors:
-            save_file(model.state_dict(), path)
-        else:
-            torch.save({"state_dict": model.state_dict()}, path)
-
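export_model saves the EMA weights either as a raw safetensors file or as a {"state_dict": ...} checkpoint, and the two formats load differently. A hedged usage sketch (the paths and the trained wrapper/model objects are placeholders; save_file and load_file come from the safetensors package):

    import torch
    from safetensors.torch import load_file

    # Safetensors path: the file holds the state dict directly.
    wrapper.export_model("diffae_ema.safetensors", use_safetensors=True)
    model.load_state_dict(load_file("diffae_ema.safetensors"))

    # torch.save path: the state dict sits under the "state_dict" key.
    wrapper.export_model("diffae_ema.ckpt")
    model.load_state_dict(torch.load("diffae_ema.ckpt")["state_dict"])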
         
-class DiffusionAutoencoderDemoCallback(Callback):
-    def __init__(
-        self,
-        demo_dl,
-        demo_every=2000,
-        demo_steps=250,
-        sample_size=65536,
-        sample_rate=48000
-    ):
-        super().__init__()
-        self.demo_every = demo_every
-        self.demo_steps = demo_steps
-        self.demo_samples = sample_size
-        self.demo_dl = iter(demo_dl)
-        self.sample_rate = sample_rate
-        self.last_demo_step = -1
-
-    @rank_zero_only
-    @torch.no_grad()
-    def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
-        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
-            return
-
-        self.last_demo_step = trainer.global_step
-
-        demo_reals, _ = next(self.demo_dl)
-
-        # Remove extra dimension added by WebDataset
-        if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
-            demo_reals = demo_reals[0]
-
-        encoder_input = demo_reals
-
-        encoder_input = encoder_input.to(module.device)
-
-        demo_reals = demo_reals.to(module.device)
-
-        with torch.no_grad() and torch.amp.autocast('cuda'):
-            latents = module.diffae_ema.ema_model.encode(encoder_input).float()
-            fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps)
-
-        # Interleave reals and fakes
-        reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
-
-        # Put the demos together
-        reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
-
-        log_dict = {}
-
-        filename = f'recon_{trainer.global_step:08}.wav'
-        reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
-        torchaudio.save(filename, reals_fakes, self.sample_rate)
-
-        log_dict[f'recon'] = wandb.Audio(filename,
-                                            sample_rate=self.sample_rate,
-                                            caption=f'Reconstructed')
-
-        log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
-        log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
-
-        log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
-
-        if module.diffae_ema.ema_model.pretransform is not None:
-            with torch.no_grad() and torch.amp.autocast('cuda'):
-                initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input)
-                first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents)
-                first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)')
-                first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu()
-                first_stage_filename = f'first_stage_{trainer.global_step:08}.wav'
-                torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate)
-
-                log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents))
-
-                log_dict[f'first_stage'] = wandb.Audio(first_stage_filename,
-                                            sample_rate=self.sample_rate,
-                                            caption=f'First Stage Reconstructed')
-
-                log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes))
-
-        trainer.logger.experiment.log(log_dict)
-
            -
            def create_source_mixture(reals, num_sources=2):
         
     | 
| 1698 | 
         
            -
                # Create a fake mixture source by mixing elements from the training batch together with random offsets
         
     | 
| 1699 | 
         
            -
                source = torch.zeros_like(reals)
         
     | 
| 1700 | 
         
            -
                for i in range(reals.shape[0]):
         
     | 
| 1701 | 
         
            -
                    sources_added = 0
         
     | 
| 1702 | 
         
            -
                    
         
     | 
| 1703 | 
         
            -
                    js = list(range(reals.shape[0]))
         
     | 
| 1704 | 
         
            -
                    random.shuffle(js)
         
     | 
| 1705 | 
         
            -
                    for j in js:
         
     | 
| 1706 | 
         
            -
                        if i == j or (i != j and sources_added < num_sources):
         
     | 
| 1707 | 
         
            -
                            # Randomly offset the mixed element between 0 and the length of the source
         
     | 
| 1708 | 
         
            -
                            seq_len = reals.shape[2]
         
     | 
| 1709 | 
         
            -
                            offset = random.randint(0, seq_len-1)
         
     | 
| 1710 | 
         
            -
                            source[i, :, offset:] += reals[j, :, :-offset]
         
     | 
| 1711 | 
         
            -
                            if i == j:
         
     | 
| 1712 | 
         
            -
                                # If this is the real one, shift the reals as well to ensure alignment
         
     | 
| 1713 | 
         
            -
                                new_reals = torch.zeros_like(reals[i])
         
     | 
| 1714 | 
         
            -
                                new_reals[:, offset:] = reals[i, :, :-offset]
         
     | 
| 1715 | 
         
            -
                                reals[i] = new_reals
         
     | 
| 1716 | 
         
            -
                            sources_added += 1
         
     | 
| 1717 | 
         
            -
             
     | 
| 1718 | 
         
            -
                return source
         
     | 
| 1719 | 
         
            -
             
     | 
| 1720 | 
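One edge case in create_source_mixture: random.randint(0, seq_len - 1) can return 0, and reals[j, :, :-0] is an empty slice, so the in-place add fails with a shape mismatch. A guarded version of the inner step (a sketch, not the repository's code):

    # offset == 0 means "no shift": add the clip as-is instead of slicing with :-0.
    offset = random.randint(0, seq_len - 1)
    if offset > 0:
        source[i, :, offset:] += reals[j, :, :-offset]
    else:
        source[i] += reals[j]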
         
-class DiffusionPriorTrainingWrapper(L.LightningModule):
-    '''
-    Wrapper for training a diffusion prior for inverse problems
-    Prior types:
-        mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version
-    '''
-    def __init__(
-            self,
-            model: ConditionedDiffusionModelWrapper,
-            lr: float = 1e-4,
-            ema_copy = None,
-            prior_type: PriorType = PriorType.MonoToStereo,
-            use_reconstruction_loss: bool = False,
-            log_loss_info: bool = False,
-    ):
-        super().__init__()
-
-        self.diffusion = model
-
-        self.diffusion_ema = EMA(
-            self.diffusion,
-            ema_model=ema_copy,
-            beta=0.9999,
-            power=3/4,
-            update_every=1,
-            update_after_step=1,
-            include_online_model=False
-        )
-
-        self.lr = lr
-
-        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
-
-        self.log_loss_info = log_loss_info
-
-        loss_modules = [
-            MSELoss("v",
-                    "targets",
-                    weight=1.0,
-                    name="mse_loss"
-            )
-        ]
-
-        self.use_reconstruction_loss = use_reconstruction_loss
-
-        if use_reconstruction_loss:
-            scales = [2048, 1024, 512, 256, 128, 64, 32]
-            hop_sizes = []
-            win_lengths = []
-            overlap = 0.75
-            for s in scales:
-                hop_sizes.append(int(s * (1 - overlap)))
-                win_lengths.append(s)
-
-            sample_rate = model.sample_rate
-
-            stft_loss_args = {
-                "fft_sizes": scales,
-                "hop_sizes": hop_sizes,
-                "win_lengths": win_lengths,
-                "perceptual_weighting": True
-            }
-
-            out_channels = model.io_channels
-
-            if model.pretransform is not None:
-                out_channels = model.pretransform.io_channels
-            self.audio_out_channels = out_channels
-
-            if self.audio_out_channels == 2:
-                self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
-                self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
-
-                # Add left and right channel reconstruction losses in addition to the sum and difference
-                loss_modules += [
-                    AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05),
-                    AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05),
-                ]
-
-            else:
-                self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
-
-            loss_modules.append(
-                AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss
-            )
-
-        self.losses = MultiLoss(loss_modules)
-
-        self.prior_type = prior_type
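The reconstruction loss builds a multi-resolution STFT with 75% window overlap, so each hop is a quarter of its FFT size. A quick check of the expansion:

    scales = [2048, 1024, 512, 256, 128, 64, 32]
    overlap = 0.75
    hop_sizes = [int(s * (1 - overlap)) for s in scales]
    print(hop_sizes)  # [512, 256, 128, 64, 32, 16, 8]

For stereo output, the sum-and-difference loss is complemented by per-channel left/right STFT terms at lower weight, so both mid/side and per-channel errors are penalized.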
         
-
-    def configure_optimizers(self):
-        return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
-
-    def training_step(self, batch, batch_idx):
-        reals, metadata = batch
-
-        if reals.ndim == 4 and reals.shape[0] == 1:
-            reals = reals[0]
-
-        loss_info = {}
-
-        loss_info["audio_reals"] = reals
-
-        if self.prior_type == PriorType.MonoToStereo:
-            source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device)
-            loss_info["audio_reals_mono"] = source
-        else:
-            raise ValueError(f"Unknown prior type {self.prior_type}")
-
-        if self.diffusion.pretransform is not None:
-            with torch.no_grad():
-                reals = self.diffusion.pretransform.encode(reals)
-
-                if self.prior_type in [PriorType.MonoToStereo]:
-                    source = self.diffusion.pretransform.encode(source)
-
-        if self.diffusion.conditioner is not None:
-            with torch.amp.autocast('cuda'):
-                conditioning = self.diffusion.conditioner(metadata, self.device)
-        else:
-            conditioning = {}
-
-        loss_info["reals"] = reals
-
-        # Draw uniformly distributed continuous timesteps
-        t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
-
-        # Calculate the noise schedule parameters for those timesteps
-        alphas, sigmas = get_alphas_sigmas(t)
-
-        # Combine the ground truth data and the noise
-        alphas = alphas[:, None, None]
-        sigmas = sigmas[:, None, None]
-        noise = torch.randn_like(reals)
-        noised_reals = reals * alphas + noise * sigmas
-        targets = noise * alphas - reals * sigmas
-
-        with torch.amp.autocast('cuda'):
-
-            conditioning['source'] = [source]
-
-            v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1)
-
-            loss_info.update({
-                "v": v,
-                "targets": targets
-            })
-
-            if self.use_reconstruction_loss:
-                pred = noised_reals * alphas - v * sigmas
-
-                loss_info["pred"] = pred
-
-                if self.diffusion.pretransform is not None:
-                    pred = self.diffusion.pretransform.decode(pred)
-                    loss_info["audio_pred"] = pred
-
-                if self.audio_out_channels == 2:
-                    loss_info["pred_left"] = pred[:, 0:1, :]
-                    loss_info["pred_right"] = pred[:, 1:2, :]
-                    loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :]
-                    loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :]
-
-            loss, losses = self.losses(loss_info)
-
-            if self.log_loss_info:
-                # Loss debugging logs
-                num_loss_buckets = 10
-                bucket_size = 1 / num_loss_buckets
-                loss_all = F.mse_loss(v, targets, reduction="none")
-
-                sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
-
-                # Gather loss_all across all GPUs
-                loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
-
-                # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
-                loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
-
-                # Log bucketed losses with corresponding sigma bucket values, if not NaN
-                debug_log_dict = {
-                    f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
-                }
-
-                self.log_dict(debug_log_dict)
-
-        log_dict = {
-            'train/loss': loss.detach(),
-            'train/std_data': reals.std()
-        }
-
-        for loss_name, loss_value in losses.items():
-            log_dict[f"train/{loss_name}"] = loss_value.detach()
-
-        self.log_dict(log_dict, prog_bar=True, on_step=True)
-        return loss
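The log_loss_info branch buckets per-element MSE by noise level, so the curves can be read per sigma range. A standalone sketch of the bucketing logic (shapes simplified to one loss value per sample; the all_gather reshaping is omitted):

    import torch

    num_loss_buckets = 10
    bucket_size = 1 / num_loss_buckets
    sigmas = torch.rand(8)       # per-sample noise levels in [0, 1)
    loss_all = torch.rand(8)     # matching per-sample losses

    bucketed = torch.stack([
        loss_all[(sigmas >= lo) & (sigmas < lo + bucket_size)].mean()
        for lo in torch.arange(0, 1, bucket_size)
    ])
    # Empty buckets come out as NaN, which the logging code filters before wandb.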
         
-
-    def on_before_zero_grad(self, *args, **kwargs):
-        self.diffusion_ema.update()
-
-    def export_model(self, path, use_safetensors=False):
-
-        #model = self.diffusion_ema.ema_model
-        model = self.diffusion
-
-        if use_safetensors:
-            save_file(model.state_dict(), path)
-        else:
-            torch.save({"state_dict": model.state_dict()}, path)
-
-class DiffusionPriorDemoCallback(Callback):
-    def __init__(
-        self,
-        demo_dl,
-        demo_every=2000,
-        demo_steps=250,
-        sample_size=65536,
-        sample_rate=48000
-    ):
-        super().__init__()
-
-        self.demo_every = demo_every
-        self.demo_steps = demo_steps
-        self.demo_samples = sample_size
-        self.demo_dl = iter(demo_dl)
-        self.sample_rate = sample_rate
-        self.last_demo_step = -1
-
-    @rank_zero_only
-    @torch.no_grad()
-    def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx):
-        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
-            return
-
-        self.last_demo_step = trainer.global_step
-
-        demo_reals, metadata = next(self.demo_dl)
-
-        # Remove extra dimension added by WebDataset
-        if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
-            demo_reals = demo_reals[0]
-
-        demo_reals = demo_reals.to(module.device)
-
-        encoder_input = demo_reals
-
-        if module.diffusion.conditioner is not None:
-            with torch.amp.autocast('cuda'):
-                conditioning_tensors = module.diffusion.conditioner(metadata, module.device)
-
-        else:
-            conditioning_tensors = {}
-
-        with torch.no_grad() and torch.amp.autocast('cuda'):
-            if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1:
-                source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device)
-
-            if module.diffusion.pretransform is not None:
-                encoder_input = module.diffusion.pretransform.encode(encoder_input)
-                source_input = module.diffusion.pretransform.encode(source)
-            else:
-                source_input = source
-
-            conditioning_tensors['source'] = [source_input]
-
-            fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors)
-
-            if module.diffusion.pretransform is not None:
-                fakes = module.diffusion.pretransform.decode(fakes)
-
-        # Interleave reals and fakes
-        reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
-
-        # Put the demos together
-        reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
-
-        log_dict = {}
-
-        filename = f'recon_mono_{trainer.global_step:08}.wav'
-        reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu()
-        torchaudio.save(filename, reals_fakes, self.sample_rate)
-
-        log_dict[f'recon'] = wandb.Audio(filename,
-                                            sample_rate=self.sample_rate,
-                                            caption=f'Reconstructed')
-
-        log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
-
-        # Log the source
-        filename = f'source_{trainer.global_step:08}.wav'
-        source = rearrange(source, 'b d n -> d (b n)')
-        source = source.to(torch.float32).mul(32767).to(torch.int16).cpu()
-        torchaudio.save(filename, source, self.sample_rate)
-
-        log_dict[f'source'] = wandb.Audio(filename,
-                                            sample_rate=self.sample_rate,
-                                            caption=f'Source')
-
-        log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source))
-
-        trainer.logger.experiment.log(log_dict)
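Note that this prior demo callback only assigns `source` when prior_type is MonoToStereo and the input has more than one channel; on mono input the later `conditioning_tensors['source'] = [source_input]` would raise UnboundLocalError. A defensive check (a sketch, not the repository's code):

    if module.prior_type == PriorType.MonoToStereo:
        if encoder_input.shape[1] < 2:
            raise ValueError("MonoToStereo demos need multi-channel input to build a mono source")
        source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device)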
The retained imports and the new logging line added by this commit (context lines unprefixed, additions marked with +):

 from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
 from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
 from ..models.autoencoders import DiffusionAutoencoder
 from .autoencoders import create_loss_modules_from_bottleneck
 from .losses import AuralossLoss, MSELoss, MultiLoss
 from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask

     def predict_step(self, batch, batch_idx):
         reals, metadata = batch
         ids = [item['id'] for item in metadata]
         batch_size, length = reals.shape[0], reals.shape[2]
+        print(f"Predicting {batch_size} samples with length {length} for ids: {ids}")
         with torch.amp.autocast('cuda'):
             conditioning = self.diffusion.conditioner(metadata, self.device)

             end_time = time.time()
             execution_time = end_time - start_time
             print(f"Execution time: {execution_time:.2f} s")

         if self.diffusion.pretransform is not None:
             fakes = self.diffusion.pretransform.decode(fakes)

         gc.collect()
         torch.cuda.empty_cache()
         module.train()
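The timing hunk above measures wall-clock time around the sampler (start_time is set just before the call, outside this hunk). Because CUDA kernels run asynchronously, synchronizing before each clock read gives a more faithful number; a hedged sketch:

    import time
    import torch

    torch.cuda.synchronize()                  # flush queued kernels first
    start_time = time.time()
    # ... sampling call goes here ...
    torch.cuda.synchronize()
    execution_time = time.time() - start_time
    print(f"Execution time: {execution_time:.2f} s")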
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
    	
{think_sound → ThinkSound}/training/factory.py
RENAMED
File without changes

{think_sound → ThinkSound}/training/losses/__init__.py
RENAMED
File without changes

{think_sound → ThinkSound}/training/losses/auraloss.py
RENAMED
File without changes

{think_sound → ThinkSound}/training/losses/losses.py
RENAMED
File without changes

{think_sound → ThinkSound}/training/utils.py
RENAMED
File without changes

app.py
CHANGED
@@ -14,13 +14,12 @@ from lightning.pytorch.tuner import Tuner
 from lightning.pytorch import seed_everything
 import random
 from datetime import datetime
-
-from think_sound.data.datamodule import DataModule
-from think_sound.models import create_model_from_config
-from think_sound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
-from think_sound.training import create_training_wrapper_from_config, create_demo_callback_from_config
-from think_sound.training.utils import copy_state_dict
-from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
+from ThinkSound.data.datamodule import DataModule
+from ThinkSound.models import create_model_from_config
+from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
+from ThinkSound.training import create_training_wrapper_from_config, create_demo_callback_from_config
+from ThinkSound.training.utils import copy_state_dict
+from ThinkSound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
 from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
 from torch.utils.data import Dataset
 from typing import Optional, Union
@@ -34,7 +33,7 @@ import tempfile
 import subprocess
 from huggingface_hub import hf_hub_download
 from moviepy.editor import VideoFileClip
-os.system("conda install -c conda-forge 'ffmpeg<7'")
+# os.system("conda install -c conda-forge 'ffmpeg<7'")
 
 _CLIP_SIZE = 224
 _CLIP_FPS = 8.0
@@ -101,7 +100,7 @@ class VGGSound(Dataset):
 
         self.resampler = {}
 
-    def sample(self, video_path,label):
+    def sample(self, video_path,label,cot):
         video_id = video_path
 
         reader = StreamingMediaDecoder(video_path)
@@ -156,7 +155,7 @@ class VGGSound(Dataset):
             # padding using the last frame, but no more than 2
             current_length = sync_chunk.shape[0]
             last_frame = sync_chunk[-1]
-
+
             padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
             assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
             sync_chunk = torch.cat((sync_chunk, padding), dim=0)
@@ -170,6 +169,7 @@ class VGGSound(Dataset):
         data = {
             'id': video_id,
             'caption': label,
+            'caption_cot': cot,
             # 'audio': audio_chunk,
             'clip_video': clip_chunk,
             'sync_video': sync_chunk,
@@ -187,17 +187,16 @@ else:
 
 print(f"load in device {device}")
 
-vae_ckpt = hf_hub_download(repo_id="
-synchformer_ckpt = hf_hub_download(repo_id="
+vae_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="vae.ckpt",repo_type="model")
+synchformer_ckpt = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="synchformer_state_dict.pth",repo_type="model")
+
 feature_extractor = FeaturesUtils(
-    vae_ckpt=
-    vae_config='
+    vae_ckpt=None,
+    vae_config='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json',
     enable_conditions=True,
     synchformer_ckpt=synchformer_ckpt
 ).eval().to(extra_device)
 
-
-
 args = get_all_args()
 
 seed = 10086
@@ -206,7 +205,7 @@ seed_everything(seed, workers=True)
 
 
 #Get JSON config from args.model_config
-with open("
+with open("ThinkSound/configs/model_configs/thinksound.json") as f:
     model_config = json.load(f)
 
 model = create_model_from_config(model_config)
@@ -229,7 +228,7 @@ model.pretransform.load_state_dict(load_vae_state)
 # Remove weight_norm from the pretransform if specified
 if args.remove_pretransform_weight_norm == "post_load":
     remove_weight_norm_from_model(model.pretransform)
-ckpt_path = hf_hub_download(repo_id="
+ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
 training_wrapper = create_training_wrapper_from_config(model_config, model)
 # choose map_location according to the device when loading the model weights
 training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
@@ -243,16 +242,17 @@ def get_video_duration(video_path):
 @spaces.GPU(duration=60)
 @torch.inference_mode()
 @torch.no_grad()
-def get_audio(video_path, caption):
-    # the caption may be empty
+def get_audio(video_path, caption, cot):
     if caption is None:
         caption = ''
+    if cot is None:
+        cot = caption
     timer = Timer(duration="00:15:00:00")
     #get video duration
     duration_sec = get_video_duration(video_path)
     print(duration_sec)
     preprocesser = VGGSound(duration_sec=duration_sec)
-    data = preprocesser.sample(video_path, caption)
+    data = preprocesser.sample(video_path, caption, cot)
 
 
 
@@ -261,7 +261,7 @@ def get_audio(video_path, caption):
     preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
     preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)
 
-    t5_features = feature_extractor.encode_t5_text(data['
+    t5_features = feature_extractor.encode_t5_text(data['caption_cot'])
     preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)
 
     clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
@@ -305,56 +305,47 @@ def get_audio(video_path, caption):
         fakes = training_wrapper.diffusion.pretransform.decode(fakes)
 
     audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
-    # save a temporary audio file
     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
         torchaudio.save(tmp_audio.name, audios[0], 44100)
         audio_path = tmp_audio.name
+
     return audio_path
 
-def synthesize_video_with_audio(video_file, caption):
-    if caption is None:
-        caption = ''
-    audio_path = get_audio(video_file, caption)
+def synthesize_video_with_audio(video_file, caption, cot):
+    audio_path = get_audio(video_file, caption, cot)
     with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
         output_video_path = tmp_video.name
-
+
     cmd = [
         'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
         '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
         '-shortest', output_video_path
     ]
    subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
     return output_video_path
 
-[... old Gradio Blocks UI; most of its lines are truncated in the diff view, ending with: ...]
-        ["./examples/4_mute.mp4", "train passing by", "./examples/4.mp4"],
-        ["./examples/5_mute.mp4", "Lighting Firecrackers", "./examples/5.mp4"]
-    ],
-    inputs=[video_input, caption_input,output_video],
-)
-
-demo.launch(share=True)
+demo = gr.Interface(
+    fn=synthesize_video_with_audio,
+    inputs=[
+        gr.Video(label="Upload Video"),
+        gr.Textbox(label="Caption (optional)", placeholder="can be empty",),
+        gr.Textbox(label="CoT Description (optional)", lines=6, placeholder="can be empty",),
+    ],
+    outputs=[
+        gr.Video(label="Result"),
+    ],
+    title="ThinkSound Demo",
+    description="Upload a video, caption, or CoT to generate audio. For an enhanced experience, we automatically merge the generated audio with your original silent video. (Note: Flexible audio generation lengths are supported.:)",
+    examples=[
+        ["examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier", "Begin by creating a soft, steady background of light pacifier suckling. Add subtle, breathy rhythms to mimic a newborn's gentle mouth movements. Keep the sound smooth, natural, and soothing."],
+        ["examples/2_mute.mp4", "Printer Printing", "Generate a continuous printer printing sound with periodic beeps and paper movement, plus a cat pawing at the machine. Add subtle ambient room noise for authenticity, keeping the focus on printing, beeps, and the cat's interaction."],
+        ["examples/4_mute.mp4", "Plastic Debris Handling", "Begin with the sound of hands scooping up loose plastic debris, followed by the subtle cascading noise as the pieces fall and scatter back down. Include soft crinkling and rustling to emphasize the texture of the plastic. Add ambient factory background noise with distant machinery to create an industrial atmosphere."],
+        ["examples/5_mute.mp4", "Lighting Firecrackers", "Generate the sound of firecrackers lighting and exploding repeatedly on the ground, followed by fireworks bursting in the sky. Incorporate occasional subtle echoes to mimic an outdoor night ambiance, with no human voices present."]
+    ],
+    cache_examples=True
+)
+
+if __name__ == "__main__":
+    demo.launch(share=True)
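Reviewer's note: the behavioral core of this commit is small. get_audio() gains a third cot argument that falls back to the caption when left blank, the preprocessor stores it as 'caption_cot', and that string (rather than the plain caption) is what gets encoded by T5, while the caption continues to feed the CLIP text features. A minimal sketch of the new fallback logic, written as a standalone function rather than the repo's actual code:

    # Minimal sketch (not the repo's code) of the CoT argument handling
    # added to get_audio(): the caption stays the short label, while the
    # chain-of-thought text defaults to the caption when the user leaves
    # the CoT box empty.
    def resolve_texts(caption, cot):
        """Mirror of the None-handling in get_audio(video_path, caption, cot)."""
        if caption is None:
            caption = ''
        if cot is None:
            cot = caption  # empty CoT falls back to the plain caption
        return caption, cot

    assert resolve_texts("Printer Printing", None) == ("Printer Printing", "Printer Printing")
    assert resolve_texts(None, None) == ("", "")

FeaturesUtils.encode_t5_text itself is not shown in this diff; purely as an illustration, a generic T5 text-encoding step looks like the sketch below. The checkpoint name and the lack of pooling are assumptions, not taken from this repo, which may use a different model.

    # Illustrative only: generic T5 text encoding via Hugging Face
    # transformers. "google/flan-t5-large" is an assumed checkpoint;
    # the repo's FeaturesUtils may differ.
    import torch
    from transformers import T5EncoderModel, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
    encoder = T5EncoderModel.from_pretrained("google/flan-t5-large").eval()

    @torch.no_grad()
    def encode_t5_text(text: str) -> torch.Tensor:
        tokens = tokenizer(text, return_tensors="pt", truncation=True)
        return encoder(**tokens).last_hidden_state  # (1, seq_len, hidden_dim)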
    	
cot_vgg_demo_caption.txt
ADDED
@@ -0,0 +1 @@
+demo.npz
    	
data_utils/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (149 Bytes)

data_utils/__pycache__/utils.cpython-310.pyc
DELETED
Binary file (4.56 kB)

data_utils/__pycache__/utils.cpython-39.pyc
DELETED
Binary file (4.56 kB)

data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (243 Bytes)

data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (241 Bytes)

data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc
DELETED
Binary file (12.7 kB)

data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc
DELETED
Binary file (12.7 kB)

data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc
DELETED
Binary file (1.91 kB)

data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc
DELETED
Binary file (1.9 kB)

data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc
DELETED
Binary file (3.97 kB)

data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc
DELETED
Binary file (3.78 kB)