# PyTorch for deep learning operations
import torch
import torch.nn as nn
# PyTorch multiprocessing utilities
import torch.multiprocessing
# Hugging Face Transformers: pretrained models, tokenizers and processors
from transformers import BertModel, BertTokenizer, AutoModel, AutoImageProcessor
# Project configuration and the pretrained text-image OneEncoder
from configs import CFG
from text_image import OneEncoder as TextImageEncoder

class AlignmentLayer(nn.Module):
    def __init__(self, input_dim=768, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args,
                 **kwargs):
        super(AlignmentLayer, self).__init__(*args, **kwargs)
        # Attributes
        self.input_dim = input_dim
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        # Layers
        self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.normalization_layer = nn.LayerNorm(self.projection_dim)

    def forward(self, inputs):
        x = inputs
        x = self.linear_layer1(x)
        x = self.gelu(x)
        x = self.linear_layer2(x)
        x = self.dropout(x)
        x = self.normalization_layer(x)
        return x

    def __call__(self, inputs):
        return self.forward(inputs)
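
# Usage sketch (illustrative, not part of the original module): AlignmentLayer is a
# small projection head that maps encoder hidden states to the shared projection
# space. Assuming CFG.projection_dim were 256, the shapes would be:
#
#     layer = AlignmentLayer(input_dim=768, projection_dim=256, dropout_rate=0.1)
#     hidden_states = torch.randn(4, 197, 768)   # (batch, tokens, hidden_size)
#     projected = layer(hidden_states)           # -> (4, 197, 256)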

class RadioEncoder(nn.Module):
    def __init__(self, model_name=CFG.radio_name, projection_dim=CFG.projection_dim,
                 trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super(RadioEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.trainable = trainable
        # Models
        self.pretrained_encoder = AutoModel.from_pretrained(self.model_name)
        self.alignment_layer = AlignmentLayer(
            input_dim=self.pretrained_encoder.config.hidden_size,
            projection_dim=self.projection_dim,
            dropout_rate=self.dropout_rate)
        # Freeze the pretrained encoder unless trainable=True
        for parameter in self.pretrained_encoder.parameters():
            parameter.requires_grad = self.trainable

    def forward(self, inputs):
        x = self.pretrained_encoder(inputs).last_hidden_state
        x = self.alignment_layer(x)
        return x

    def __call__(self, inputs):
        return self.forward(inputs)
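
# Usage sketch (illustrative): RadioEncoder wraps a pretrained vision backbone
# (CFG.radio_name) and projects its last_hidden_state token-wise through an
# AlignmentLayer; the backbone stays frozen when trainable=False. The input
# resolution below is an assumption and depends on the matching image processor.
#
#     encoder = RadioEncoder(trainable=False)
#     pixel_values = torch.randn(2, 3, 518, 518)   # preprocessed pixel values
#     features = encoder(pixel_values)             # -> (2, num_tokens, projection_dim)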

class ModalityTokenEncoder(nn.Module):
    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
        super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        # Learnable modality token, initialized from a normal distribution whose
        # standard deviation is drawn uniformly from [0.1, 0.6)
        radio_std = torch.rand(1) * 0.5 + 0.1
        self.radio_token = nn.Parameter(torch.normal(mean=0, std=radio_std.item(),
                                                     size=(self.token_size, self.projection_dim)).to(self.device))

    def forward(self):
        return self.radio_token

    def __call__(self):
        return self.forward()
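
# Note (interpretation, not stated in this file): ModalityTokenEncoder holds a single
# learnable tensor of shape (token_size, projection_dim) that serves as a
# modality-specific token; it is passed together with the projected radio features to
# the universal projection so that module can condition on the source modality.
#
#     token_encoder = ModalityTokenEncoder()
#     radio_token = token_encoder()                # -> (token_size, projection_dim)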

class OneEncoder(nn.Module):
    def __init__(self, device='cpu', modality_token_encoder=ModalityTokenEncoder(),
                 checkpoint="bilalfaye/OneEncoder-text-image",
                 radio_processor=AutoImageProcessor.from_pretrained("microsoft/rad-dino"),
                 sample_rate=CFG.sample_rate, radio_encoder=RadioEncoder(), *args, **kwargs):
        super(OneEncoder, self).__init__(*args, **kwargs)
        self.device = device
        self.checkpoint = checkpoint
        self.modality_token_encoder = modality_token_encoder
        self.modality_token_encoder.device = self.device
        self.text_image_encoder = TextImageEncoder(device=self.device)
        self.text_image_encoder.from_pretrained(self.checkpoint)
        self.radio_processor = radio_processor
        self.sample_rate = sample_rate
        self.radio_encoder = radio_encoder
        self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))
        # Freeze the pretrained text-image encoder
        for parameter in self.text_image_encoder.parameters():
            parameter.requires_grad = False

    def encode_radio(self, pil_radios=None, radios=None):
        """
        Encode radiology images into the shared embedding space.

        :param pil_radios: list of PIL images (preprocessed with radio_processor if given)
        :param radios: already preprocessed tensor of pixel values
        :return: tensor of projected features
        """
        if pil_radios is not None:
            tensors = self.radio_processor(pil_radios, return_tensors="pt")["pixel_values"].to(self.device)
        else:
            tensors = radios.to(self.device)
        features = self.radio_encoder(tensors)
        radio_token = self.modality_token_encoder()
        outputs = self.text_image_encoder.universal_projection_encoder([features, radio_token]).last_hidden_state
        return outputs
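
# Usage sketch (illustrative; the file path is hypothetical): end-to-end encoding of
# radiology images into the space shared with the frozen text-image OneEncoder.
#
#     from PIL import Image
#
#     model = OneEncoder(device="cpu")
#     images = [Image.open("chest_xray.png").convert("RGB")]   # hypothetical input file
#     embeddings = model.encode_radio(pil_radios=images)
#     # embeddings: (batch, num_tokens, projection_dim), comparable with text/image
#     # features from the frozen text_image_encoder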