from types import SimpleNamespace

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset

ModalityType = SimpleNamespace(
    AA="aa",
    DNA="dna",
    PDB="pdb",
    GO="go",
    MSA="msa",
    TEXT="text",
)


class Normalize(nn.Module):
    """L2-normalizes its input along the given dimension, so that dot
    products between normalized embeddings are cosine similarities."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
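
# A quick sanity check of Normalize (illustrative, not part of the module):
#
#     x = Normalize(dim=-1)(torch.randn(2, 8))
#     print(x.norm(dim=-1))  # tensor([1.0000, 1.0000])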


class EmbeddingDataset(Dataset):
    """
    The main class for turning any modality into a torch Dataset that can be
    passed to a torch DataLoader. Each item pairs a raw amino-acid sequence
    with a precomputed embedding of another modality. Any modality that
    doesn't fit this __getitem__ can subclass this and override __getitem__.
    """

    def __init__(self, sequence_file_path, embeddings_file_path, modality):
        # CSV whose first column holds the amino-acid sequences.
        self.sequence = pd.read_csv(sequence_file_path)
        # Tensor of precomputed embeddings, one per sequence.
        self.embedding = torch.load(embeddings_file_path)
        self.modality = modality

    def __len__(self):
        return len(self.sequence)

    def __getitem__(self, idx):
        sequence = self.sequence.iloc[idx, 0]
        embedding = self.embedding[idx]
        # Assumes the paired modality is not "aa" itself; otherwise the two
        # keys would collide.
        return {"aa": sequence, self.modality: embedding}


class DualEmbeddingDataset(Dataset):
    """
    Variant of EmbeddingDataset for when the amino-acid sequences are also
    precomputed embeddings rather than raw strings. Each item pairs a
    sequence embedding with an embedding of another modality, keyed so that
    batches can be fed directly to ProteinBindModel.forward. Any modality
    that doesn't fit this __getitem__ can subclass this and override it.
    """

    def __init__(self, sequence_embeddings_file_path, embeddings_file_path, modality):
        self.sequence_embedding = torch.load(sequence_embeddings_file_path)
        self.embedding = torch.load(embeddings_file_path)
        self.modality = modality

    def __len__(self):
        return len(self.sequence_embedding)

    def __getitem__(self, idx):
        sequence_embedding = self.sequence_embedding[idx]
        embedding = self.embedding[idx]
        return {"aa": sequence_embedding, self.modality: embedding}


class ProteinBindModel(nn.Module):
    def __init__(
        self,
        aa_embed_dim,
        dna_embed_dim,
        pdb_embed_dim,
        go_embed_dim,
        msa_embed_dim,
        text_embed_dim,
        in_embed_dim,
        out_embed_dim,
    ):
        super().__init__()
        # Trunks project each modality's native embedding into a shared
        # in_embed_dim space; heads then project that space to out_embed_dim.
        self.modality_trunks = self._create_modality_trunk(
            aa_embed_dim,
            dna_embed_dim,
            pdb_embed_dim,
            go_embed_dim,
            msa_embed_dim,
            text_embed_dim,
            in_embed_dim,
        )
        self.modality_heads = self._create_modality_head(
            in_embed_dim,
            out_embed_dim,
        )
        self.modality_postprocessors = self._create_modality_postprocessors()

    def _create_modality_trunk(
        self,
        aa_embed_dim,
        dna_embed_dim,
        pdb_embed_dim,
        go_embed_dim,
        msa_embed_dim,
        text_embed_dim,
        in_embed_dim,
    ):
        """
        Builds one MLP trunk per modality, projecting that modality's native
        embedding dimension into the shared in_embed_dim space. The current
        layers are just a proof of concept and are subject to revision.
        """
        input_dims = {
            ModalityType.AA: aa_embed_dim,
            ModalityType.DNA: dna_embed_dim,
            ModalityType.PDB: pdb_embed_dim,
            ModalityType.GO: go_embed_dim,
            ModalityType.MSA: msa_embed_dim,
            ModalityType.TEXT: text_embed_dim,
        }
        modality_trunks = {
            modality: nn.Sequential(
                nn.Linear(embed_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Linear(512, in_embed_dim),
            )
            for modality, embed_dim in input_dims.items()
        }
        return nn.ModuleDict(modality_trunks)

    def _create_modality_head(
        self,
        in_embed_dim,
        out_embed_dim,
    ):
        # Every modality shares the same head: LayerNorm, dropout, and a
        # bias-free linear projection into the shared out_embed_dim space.
        modality_heads = {
            modality: nn.Sequential(
                nn.LayerNorm(normalized_shape=in_embed_dim, eps=1e-6),
                nn.Dropout(p=0.5),
                nn.Linear(in_embed_dim, out_embed_dim, bias=False),
            )
            for modality in vars(ModalityType).values()
        }
        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self):
        # Each modality's output is L2-normalized so that dot products across
        # modalities are cosine similarities.
        modality_postprocessors = {
            modality: Normalize(dim=-1)
            for modality in vars(ModalityType).values()
        }
        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs):
        """
        Maps a dict of modality embeddings, {modality_key: tensor, ...}, to a
        dict of aligned output embeddings under the same keys. Each value is
        passed through its modality's trunk, then its projection head, then
        its postprocessor (L2 normalization).
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            modality_value = self.modality_trunks[modality_key](modality_value)
            modality_value = self.modality_heads[modality_key](modality_value)
            modality_value = self.modality_postprocessors[modality_key](modality_value)
            outputs[modality_key] = modality_value
        return outputs
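
# A quick shape check with random inputs (illustrative; the dims match the
# defaults in create_proteinbind below):
#
#     model = ProteinBindModel(
#         aa_embed_dim=480, dna_embed_dim=1280, pdb_embed_dim=128,
#         go_embed_dim=600, msa_embed_dim=768, text_embed_dim=768,
#         in_embed_dim=1024, out_embed_dim=1024,
#     )
#     out = model({
#         ModalityType.AA: torch.randn(4, 480),
#         ModalityType.DNA: torch.randn(4, 1280),
#     })
#     print(out["aa"].shape, out["dna"].shape)  # both torch.Size([4, 1024])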


def create_proteinbind(pretrained=False):
    """
    Builds a ProteinBindModel. The embedding dimensions here are dummy values
    and should be replaced with the output dimensions of the actual upstream
    encoders.

    :param pretrained: if True, load weights from a saved checkpoint.
    :return: the (optionally pretrained) model.
    """
    model = ProteinBindModel(
        aa_embed_dim=480,
        dna_embed_dim=1280,
        pdb_embed_dim=128,
        go_embed_dim=600,
        msa_embed_dim=768,
        text_embed_dim=768,
        in_embed_dim=1024,
        out_embed_dim=1024,
    )
    if pretrained:
        # TODO: get the checkpoint path from a config instead of hardcoding it.
        PATH = 'best_model.pth'
        model.load_state_dict(torch.load(PATH))
    return model
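
# This file defines no training loop; the sketch below shows one plausible
# way to train the model, assuming an ImageBind-style symmetric InfoNCE
# objective between the "aa" anchor and one other modality. The temperature,
# optimizer settings, and file names are illustrative assumptions, not part
# of this module.
#
#     import torch.nn.functional as F
#     from torch.utils.data import DataLoader
#
#     model = create_proteinbind()
#     dataset = DualEmbeddingDataset(
#         "aa_embeddings.pt", "msa_embeddings.pt", modality=ModalityType.MSA
#     )
#     loader = DataLoader(dataset, batch_size=64, shuffle=True)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     temperature = 0.07
#
#     for batch in loader:
#         out = model(batch)
#         # Outputs are L2-normalized, so this matmul yields cosine similarities.
#         logits = out["aa"] @ out["msa"].T / temperature
#         targets = torch.arange(logits.size(0))
#         # Symmetric cross-entropy: align aa -> msa and msa -> aa.
#         loss = (F.cross_entropy(logits, targets)
#                 + F.cross_entropy(logits.T, targets)) / 2
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()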