Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update text_image_audio.py
Browse files- text_image_audio.py +15 -4
    	
        text_image_audio.py
    CHANGED
    
    | @@ -83,18 +83,29 @@ class AudioEncoder(nn.Module): | |
| 83 | 
             
                    return self.forward(inputs)
         | 
| 84 |  | 
| 85 | 
             
            class ModalityTokenEncoder(nn.Module):
         | 
| 86 | 
            -
                def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
         | 
| 87 | 
             
                    super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
         | 
| 88 | 
             
                    # Attributes
         | 
| 89 | 
             
                    self.projection_dim = projection_dim
         | 
| 90 | 
             
                    self.device = device
         | 
| 91 | 
             
                    self.token_size = token_size
         | 
|  | |
| 92 | 
             
                    # Models
         | 
| 93 | 
             
                    audio_variance = torch.rand(1) * 0.5 + 0.1
         | 
| 94 | 
             
                    self.audio_token = nn.Parameter(torch.normal(mean=0, std=audio_variance.item(),
         | 
| 95 | 
            -
                                                                  size=(self.token_size, self. | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 96 | 
             
                def forward(self):
         | 
| 97 | 
            -
                    return self.audio_token
         | 
| 98 |  | 
| 99 | 
             
                def __call__(self):
         | 
| 100 | 
             
                    return self.forward()
         | 
| @@ -205,4 +216,4 @@ class OneEncoder(nn.Module, PyTorchModelHubMixin): | |
| 205 | 
             
                        #    fig.suptitle(display(Audio(query['input_values'], rate=self.sample_rate)))
         | 
| 206 | 
             
                        #plt.show()
         | 
| 207 | 
             
                    #return values, indices
         | 
| 208 | 
            -
                   
         | 
|  | |
| 83 | 
             
                    return self.forward(inputs)
         | 
| 84 |  | 
| 85 | 
             
            class ModalityTokenEncoder(nn.Module):
         | 
| 86 | 
            +
                def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', token_dim=CFG.token_dim, *args, **kwargs):
         | 
| 87 | 
             
                    super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
         | 
| 88 | 
             
                    # Attributes
         | 
| 89 | 
             
                    self.projection_dim = projection_dim
         | 
| 90 | 
             
                    self.device = device
         | 
| 91 | 
             
                    self.token_size = token_size
         | 
| 92 | 
            +
                    self.token_dim = token_dim
         | 
| 93 | 
             
                    # Models
         | 
| 94 | 
             
                    audio_variance = torch.rand(1) * 0.5 + 0.1
         | 
| 95 | 
             
                    self.audio_token = nn.Parameter(torch.normal(mean=0, std=audio_variance.item(),
         | 
| 96 | 
            +
                                                                  size=(self.token_size, self.token_dim)).to(self.device))
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    self.token_projection = nn.Sequential(
         | 
| 99 | 
            +
                        nn.Linear(self.token_dim, 64),
         | 
| 100 | 
            +
                        nn.ReLU(),
         | 
| 101 | 
            +
                        nn.Linear(64, 128),
         | 
| 102 | 
            +
                        nn.ReLU(),
         | 
| 103 | 
            +
                        nn.Linear(128, self.projection_dim),
         | 
| 104 | 
            +
                        nn.LayerNorm(self.projection_dim)
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
             | 
| 107 | 
             
                def forward(self):
         | 
| 108 | 
            +
                    return self.token_projection(self.audio_token)
         | 
| 109 |  | 
| 110 | 
             
                def __call__(self):
         | 
| 111 | 
             
                    return self.forward()
         | 
|  | |
| 216 | 
             
                        #    fig.suptitle(display(Audio(query['input_values'], rate=self.sample_rate)))
         | 
| 217 | 
             
                        #plt.show()
         | 
| 218 | 
             
                    #return values, indices
         | 
| 219 | 
            +
                   
         | 
