Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	
		Antoni Bigata
		
	commited on
		
		
					Commit 
							
							·
						
						4fd1a69
	
1
								Parent(s):
							
							fc0dc6f
								
requirements
Browse files- WavLM.py +1 -1
 - sgm/models/diffusion.py +5 -5
 
    	
        WavLM.py
    CHANGED
    
    | 
         @@ -48,7 +48,7 @@ class WavLM_wrapper(nn.Module): 
     | 
|
| 48 | 
         
             
                        )
         
     | 
| 49 | 
         
             
                    if not os.path.exists(model_path):
         
     | 
| 50 | 
         
             
                        self.download_model(model_path, model_size)
         
     | 
| 51 | 
         
            -
                    checkpoint = torch.load(model_path)
         
     | 
| 52 | 
         
             
                    cfg = WavLMConfig(checkpoint["cfg"])
         
     | 
| 53 | 
         
             
                    self.cfg = cfg
         
     | 
| 54 | 
         
             
                    self.model = WavLM(cfg)
         
     | 
| 
         | 
|
| 48 | 
         
             
                        )
         
     | 
| 49 | 
         
             
                    if not os.path.exists(model_path):
         
     | 
| 50 | 
         
             
                        self.download_model(model_path, model_size)
         
     | 
| 51 | 
         
            +
                    checkpoint = torch.load(model_path, weights_only=False)
         
     | 
| 52 | 
         
             
                    cfg = WavLMConfig(checkpoint["cfg"])
         
     | 
| 53 | 
         
             
                    self.cfg = cfg
         
     | 
| 54 | 
         
             
                    self.model = WavLM(cfg)
         
     | 
    	
        sgm/models/diffusion.py
    CHANGED
    
    | 
         @@ -119,7 +119,7 @@ class DiffusionEngine(pl.LightningModule): 
     | 
|
| 119 | 
         
             
                            pattern_to_remove=pattern_to_remove,
         
     | 
| 120 | 
         
             
                        )
         
     | 
| 121 | 
         
             
                        if separate_unet_ckpt is not None:
         
     | 
| 122 | 
         
            -
                            sd = torch.load(separate_unet_ckpt)["state_dict"]
         
     | 
| 123 | 
         
             
                            if remove_keys_from_unet_weights is not None:
         
     | 
| 124 | 
         
             
                                for k in list(sd.keys()):
         
     | 
| 125 | 
         
             
                                    for remove_key in remove_keys_from_unet_weights:
         
     | 
| 
         @@ -190,7 +190,7 @@ class DiffusionEngine(pl.LightningModule): 
     | 
|
| 190 | 
         | 
| 191 | 
         
             
                def load_bad_model_weights(self, path: str) -> None:
         
     | 
| 192 | 
         
             
                    print(f"Restoring bad model from {path}")
         
     | 
| 193 | 
         
            -
                    state_dict = torch.load(path, map_location="cpu")
         
     | 
| 194 | 
         
             
                    new_dict = {}
         
     | 
| 195 | 
         
             
                    for k, v in state_dict["module"].items():
         
     | 
| 196 | 
         
             
                        if "learned_mask" in k:
         
     | 
| 
         @@ -221,13 +221,13 @@ class DiffusionEngine(pl.LightningModule): 
     | 
|
| 221 | 
         
             
                ) -> None:
         
     | 
| 222 | 
         
             
                    print(f"Restoring from {path}")
         
     | 
| 223 | 
         
             
                    if path.endswith("ckpt"):
         
     | 
| 224 | 
         
            -
                        sd = torch.load(path, map_location="cpu")["state_dict"]
         
     | 
| 225 | 
         
             
                    elif path.endswith("pt"):
         
     | 
| 226 | 
         
            -
                        sd = torch.load(path, map_location="cpu")["module"]
         
     | 
| 227 | 
         
             
                        # Remove leading _forward_module from keys
         
     | 
| 228 | 
         
             
                        sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
         
     | 
| 229 | 
         
             
                    elif path.endswith("bin"):
         
     | 
| 230 | 
         
            -
                        sd = torch.load(path, map_location="cpu")
         
     | 
| 231 | 
         
             
                        # Remove leading _forward_module from keys
         
     | 
| 232 | 
         
             
                        sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
         
     | 
| 233 | 
         
             
                    elif path.endswith("safetensors"):
         
     | 
| 
         | 
|
| 119 | 
         
             
                            pattern_to_remove=pattern_to_remove,
         
     | 
| 120 | 
         
             
                        )
         
     | 
| 121 | 
         
             
                        if separate_unet_ckpt is not None:
         
     | 
| 122 | 
         
            +
                            sd = torch.load(separate_unet_ckpt, weights_only=False)["state_dict"]
         
     | 
| 123 | 
         
             
                            if remove_keys_from_unet_weights is not None:
         
     | 
| 124 | 
         
             
                                for k in list(sd.keys()):
         
     | 
| 125 | 
         
             
                                    for remove_key in remove_keys_from_unet_weights:
         
     | 
| 
         | 
|
| 190 | 
         | 
| 191 | 
         
             
                def load_bad_model_weights(self, path: str) -> None:
         
     | 
| 192 | 
         
             
                    print(f"Restoring bad model from {path}")
         
     | 
| 193 | 
         
            +
                    state_dict = torch.load(path, map_location="cpu", weights_only=False)
         
     | 
| 194 | 
         
             
                    new_dict = {}
         
     | 
| 195 | 
         
             
                    for k, v in state_dict["module"].items():
         
     | 
| 196 | 
         
             
                        if "learned_mask" in k:
         
     | 
| 
         | 
|
| 221 | 
         
             
                ) -> None:
         
     | 
| 222 | 
         
             
                    print(f"Restoring from {path}")
         
     | 
| 223 | 
         
             
                    if path.endswith("ckpt"):
         
     | 
| 224 | 
         
            +
                        sd = torch.load(path, map_location="cpu", weights_only=False)["state_dict"]
         
     | 
| 225 | 
         
             
                    elif path.endswith("pt"):
         
     | 
| 226 | 
         
            +
                        sd = torch.load(path, map_location="cpu", weights_only=False)["module"]
         
     | 
| 227 | 
         
             
                        # Remove leading _forward_module from keys
         
     | 
| 228 | 
         
             
                        sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
         
     | 
| 229 | 
         
             
                    elif path.endswith("bin"):
         
     | 
| 230 | 
         
            +
                        sd = torch.load(path, map_location="cpu", weights_only=False)
         
     | 
| 231 | 
         
             
                        # Remove leading _forward_module from keys
         
     | 
| 232 | 
         
             
                        sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
         
     | 
| 233 | 
         
             
                    elif path.endswith("safetensors"):
         
     |