# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import time
from typing import List

import numpy as np
import torch
import ml_collections as mlc
from rdkit import Chem

from dockformerpp.data import data_transforms
from dockformerpp.data.data_transforms import get_restype_atom37_mask, get_restypes
from dockformerpp.data.protein_features import make_protein_features
from dockformerpp.data.utils import FeatureTensorDict, FeatureDict
from dockformerpp.utils import protein


def _np_filter_and_to_tensor_dict(np_example: FeatureDict, features_to_keep: List[str]) -> FeatureTensorDict:
    """Creates a dict of tensors from a dict of NumPy arrays.

    Args:
        np_example: A dict of NumPy feature arrays.
        features_to_keep: A list of feature names to be returned in the dataset.

    Returns:
        A dictionary mapping feature names to tensors. Only the given
        features are returned; all others are filtered out.
    """
    # torch emits a warning if torch.tensor() is called on an existing Tensor
    to_tensor = lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else torch.tensor(t)
    tensor_dict = {
        k: to_tensor(v) for k, v in np_example.items() if k in features_to_keep
    }
    return tensor_dict
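
# For example (hypothetical feature names; keys not in the list are dropped):
#   _np_filter_and_to_tensor_dict({"aatype": np.zeros(10), "extra": np.ones(3)}, ["aatype"])
#   # -> {"aatype": torch.Tensor of shape (10,)}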


def _add_protein_probabilistic_features(features: FeatureDict, cfg: mlc.ConfigDict, mode: str) -> FeatureDict:
    if mode == "train":
        # with probability cfg.supervised.clamp_prob, enable clamped FAPE for this sample
        p = torch.rand(1).item()
        use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
        features["use_clamped_fape"] = np.float32(use_clamped_fape_value)
    else:
        features["use_clamped_fape"] = np.float32(0.0)
    return features


def compose(x, fs):
    for f in fs:
        x = f(x)
    return x
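
# A minimal sketch of the ordering (f and g are hypothetical transforms):
#   compose(x, [f, g]) == g(f(x))  # transforms are applied left to right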


def _apply_protein_transforms(tensors: FeatureTensorDict) -> FeatureTensorDict:
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms.squeeze_features,
        data_transforms.make_atom14_masks,
        data_transforms.make_atom14_positions,
        data_transforms.atom37_to_frames,
        data_transforms.atom37_to_torsion_angles(""),
        data_transforms.make_pseudo_beta(),
        data_transforms.get_backbone_frames,
        data_transforms.get_chi_angles,
    ]

    # compose takes the tensors and the transform list as two arguments (see its definition above)
    tensors = compose(tensors, transforms)
    return tensors


def _apply_protein_probabilistic_transforms(tensors: FeatureTensorDict, cfg: mlc.ConfigDict, mode: str) \
        -> FeatureTensorDict:
    transforms = [data_transforms.make_target_feat()]

    crop_feats = dict(cfg.common.feat)
    if cfg[mode].fixed_size:
        transforms.append(data_transforms.select_feat(list(crop_feats)))
        # TODO bshor: restore transforms for training on cropped proteins, need to handle pocket somehow
        # if so, look for random_crop_to_size and make_fixed_size in data_transforms.py

    # assign the result; the transforms return new feature dicts rather than mutating in place
    tensors = compose(tensors, transforms)
    return tensors


class DataPipeline:
    """Assembles input features."""

    def __init__(self, config: mlc.ConfigDict, mode: str):
        self.config = config
        self.mode = mode

        self.feature_names = config.common.unsupervised_features
        if config[mode].supervised:
            self.feature_names += config.supervised.supervised_features

    def process_pdb(self, pdb_path: str) -> FeatureTensorDict:
        """Assembles features for a protein in a PDB file."""
        with open(pdb_path, 'r') as f:
            pdb_str = f.read()

        protein_object = protein.from_pdb_string(pdb_str)
        description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
        pdb_feats = make_protein_features(protein_object, description)
        pdb_feats = _add_protein_probabilistic_features(pdb_feats, self.config, self.mode)

        tensor_feats = _np_filter_and_to_tensor_dict(pdb_feats, self.feature_names)
        tensor_feats = _apply_protein_transforms(tensor_feats)
        tensor_feats = _apply_protein_probabilistic_transforms(tensor_feats, self.config, self.mode)
        return tensor_feats
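
    # A hypothetical usage sketch (the path is illustrative):
    #   pipeline = DataPipeline(config, mode="predict")
    #   feats = pipeline.process_pdb("receptor.pdb")  # FeatureTensorDict with e.g. "aatype", "pseudo_beta"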


def _prepare_recycles(feat: torch.Tensor, num_recycles: int) -> torch.Tensor:
    return feat.unsqueeze(-1).repeat(*([1] * len(feat.shape)), num_recycles)
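
# Shape sketch: _prepare_recycles on an (N, C) tensor returns (N, C, num_recycles),
# i.e. the same values tiled along a new trailing recycling dimension.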


def _fit_to_crop(target_tensor: torch.Tensor, crop_size: int, start_ind: int) -> torch.Tensor:
    # Zero-pad along the first (residue) dimension up to crop_size; the same indexing
    # covers 1-D, 2-D, and higher-rank tensors.
    ret = torch.zeros((crop_size, *target_tensor.shape[1:]), dtype=target_tensor.dtype)
    ret[start_ind:start_ind + target_tensor.shape[0], ...] = target_tensor
    return ret
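
# For example, _fit_to_crop(torch.ones(5, 3), crop_size=8, start_ind=0) returns an
# (8, 3) tensor whose first 5 rows are ones and whose last 3 rows are zero padding.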


def parse_input_json(input_path: str, mode: str, config: mlc.ConfigDict, data_pipeline: DataPipeline,
                     data_dir: str, idx: int) -> FeatureTensorDict:
    start_load_time = time.time()
    with open(input_path, "r") as f:
        input_data = json.load(f)

    if mode == "train" or mode == "eval":
        print("loading", input_data["pdb_id"], end=" ")

    num_recycles = config.common.max_recycling_iters + 1

    input_protein_r_feats = data_pipeline.process_pdb(pdb_path=os.path.join(data_dir, input_data["input_r_structure"]))
    input_protein_l_feats = data_pipeline.process_pdb(pdb_path=os.path.join(data_dir, input_data["input_l_structure"]))

    n_res_r = input_protein_r_feats["protein_target_feat"].shape[0]
    n_res_l = input_protein_l_feats["protein_target_feat"].shape[0]
    n_res_total = n_res_r + n_res_l
    n_affinity = 1

    # add 1 for the affinity token
    crop_size = n_res_total + n_affinity
    if (mode == "train" or mode == "eval") and config.train.fixed_size:
        crop_size = config.train.crop_size
    assert crop_size >= n_res_total + n_affinity, f"crop_size: {crop_size}, n_res_r: {n_res_r}, n_res_l: {n_res_l}"
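
    # Token layout along the crop: [receptor (r) residues | ligand-protein (l) residues | affinity token | padding]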
    token_mask = torch.zeros((crop_size,), dtype=torch.float32)
    token_mask[:n_res_total + n_affinity] = 1

    protein_r_mask = torch.zeros((crop_size,), dtype=torch.float32)
    protein_r_mask[:n_res_r] = 1

    protein_l_mask = torch.zeros((crop_size,), dtype=torch.float32)
    protein_l_mask[n_res_r:n_res_total] = 1

    affinity_mask = torch.zeros((crop_size,), dtype=torch.float32)
    affinity_mask[n_res_total] = 1

    structural_mask = torch.zeros((crop_size,), dtype=torch.float32)
    structural_mask[:n_res_total] = 1

    inter_pair_mask = torch.zeros((crop_size, crop_size), dtype=torch.float32)
    inter_pair_mask[:n_res_r, n_res_r:n_res_total] = 1
    inter_pair_mask[n_res_r:n_res_total, :n_res_r] = 1
    tf_dim = input_protein_r_feats["protein_target_feat"].shape[-1]
    target_feat = torch.zeros((crop_size, tf_dim + 3), dtype=torch.float32)
    target_feat[:n_res_r, :tf_dim] = input_protein_r_feats["protein_target_feat"]
    target_feat[n_res_r:n_res_total, :tf_dim] = input_protein_l_feats["protein_target_feat"]
    target_feat[:n_res_r, tf_dim] = 1  # set the "is_protein_r" flag on receptor rows
    target_feat[n_res_r:n_res_total, tf_dim + 1] = 1  # set the "is_protein_l" flag on ligand-protein rows
    target_feat[n_res_total, tf_dim + 2] = 1  # set the "is_affinity" flag on the affinity row

    input_positions = torch.zeros((crop_size, 3), dtype=torch.float32)
    input_positions[:n_res_r] = input_protein_r_feats["pseudo_beta"]
    input_positions[n_res_r:n_res_total] = input_protein_l_feats["pseudo_beta"]
    distogram_mask = torch.zeros(crop_size)
    if mode == "train":
        # hide a random fraction of residue positions from the model's structural input
        ones_indices = torch.randperm(n_res_total)[:int(n_res_total * config.train.distogram_mask_prob)]
        distogram_mask[ones_indices] = 1
        input_positions = input_positions * (1 - distogram_mask).unsqueeze(-1)
    elif mode == "predict":
        # ignore all positions where pseudo_beta is (0, 0, 0)
        distogram_mask = (input_positions == 0).all(dim=-1).float()
    # Concatenate the per-residue features of the receptor (r) and ligand (l) proteins
    aatype = torch.cat([input_protein_r_feats["aatype"], input_protein_l_feats["aatype"]], dim=0)
    residue_index = torch.cat([input_protein_r_feats["residue_index"], input_protein_l_feats["residue_index"]], dim=0)
    residx_atom37_to_atom14 = torch.cat([input_protein_r_feats["residx_atom37_to_atom14"],
                                         input_protein_l_feats["residx_atom37_to_atom14"]],
                                        dim=0)
    atom37_atom_exists = torch.cat([input_protein_r_feats["atom37_atom_exists"],
                                    input_protein_l_feats["atom37_atom_exists"]], dim=0)
    feats = {
        "token_mask": token_mask,
        "protein_r_mask": protein_r_mask,
        "protein_l_mask": protein_l_mask,
        "affinity_mask": affinity_mask,
        "structural_mask": structural_mask,
        "inter_pair_mask": inter_pair_mask,
        "target_feat": target_feat,
        "input_positions": input_positions,
        "distogram_mask": distogram_mask,
        "residue_index": _fit_to_crop(residue_index, crop_size, 0),
        "aatype": _fit_to_crop(aatype, crop_size, 0),
        "residx_atom37_to_atom14": _fit_to_crop(residx_atom37_to_atom14, crop_size, 0),
        "atom37_atom_exists": _fit_to_crop(atom37_atom_exists, crop_size, 0),
    }

    if mode == "predict":
        feats.update({
            "in_chain_residue_index_r": input_protein_r_feats["in_chain_residue_index"],
            "chain_index_r": input_protein_r_feats["chain_index"],
            "in_chain_residue_index_l": input_protein_l_feats["in_chain_residue_index"],
            "chain_index_l": input_protein_l_feats["chain_index"],
        })
    if mode == 'train' or mode == 'eval':
        gt_protein_r_feats = data_pipeline.process_pdb(pdb_path=os.path.join(data_dir, input_data["gt_r_structure"]))
        gt_protein_l_feats = data_pipeline.process_pdb(pdb_path=os.path.join(data_dir, input_data["gt_l_structure"]))

        affinity_loss_factor = torch.tensor([1.0], dtype=torch.float32)
        if input_data.get("affinity") is None:
            # no affinity label: keep the loss term numerically alive but effectively zero-weighted
            eps = 1e-6
            affinity_loss_factor = torch.tensor([eps], dtype=torch.float32)
            affinity = torch.tensor([0.0], dtype=torch.float32)
        else:
            affinity = torch.tensor([input_data["affinity"]], dtype=torch.float32)

        resolution = torch.tensor(input_data["resolution"], dtype=torch.float32)

        # prepare inter_contacts: residue pairs whose pseudo-beta atoms are within 8 Angstrom
        expanded_prot_r_pos = gt_protein_r_feats["pseudo_beta"].unsqueeze(1)  # Shape: (n_res_r, 1, 3)
        expanded_prot_l_pos = gt_protein_l_feats["pseudo_beta"].unsqueeze(0)  # Shape: (1, n_res_l, 3)

        distances = torch.sqrt(torch.sum((expanded_prot_r_pos - expanded_prot_l_pos) ** 2, dim=-1))
        inter_contact = (distances < 8.0).float()
        binding_site_mask_r = inter_contact.any(dim=1).float()
        binding_site_mask_l = inter_contact.any(dim=0).float()
        print("attaching binding masks", binding_site_mask_r.shape, binding_site_mask_l.shape)
        binding_site_mask = torch.cat([binding_site_mask_r, binding_site_mask_l], dim=0)

        inter_contact_reshaped_to_crop = torch.zeros((crop_size, crop_size), dtype=torch.float32)
        inter_contact_reshaped_to_crop[:n_res_r, n_res_r:n_res_total] = inter_contact
        inter_contact_reshaped_to_crop[n_res_r:n_res_total, :n_res_r] = inter_contact.T
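
        # Hypothetical shape walk-through: with n_res_r=3 and n_res_l=2, inter_contact has
        # shape (3, 2); it fills the top-right block of the (crop_size, crop_size) pair map
        # and its transpose fills the bottom-left block, keeping the map symmetric.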
        atom37_gt_positions = torch.cat([gt_protein_r_feats["all_atom_positions"],
                                         gt_protein_l_feats["all_atom_positions"]], dim=0)
        atom37_atom_exists_in_res = torch.cat([gt_protein_r_feats["atom37_atom_exists"],
                                               gt_protein_l_feats["atom37_atom_exists"]], dim=0)
        atom37_atom_exists_in_gt = torch.cat([gt_protein_r_feats["all_atom_mask"],
                                              gt_protein_l_feats["all_atom_mask"]], dim=0)

        atom14_gt_positions = torch.cat([gt_protein_r_feats["atom14_gt_positions"],
                                         gt_protein_l_feats["atom14_gt_positions"]], dim=0)
        atom14_atom_exists_in_res = torch.cat([gt_protein_r_feats["atom14_atom_exists"],
                                               gt_protein_l_feats["atom14_atom_exists"]], dim=0)
        atom14_atom_exists_in_gt = torch.cat([gt_protein_r_feats["atom14_gt_exists"],
                                              gt_protein_l_feats["atom14_gt_exists"]], dim=0)

        gt_pseudo_beta_joined = torch.cat([gt_protein_r_feats["pseudo_beta"],
                                           gt_protein_l_feats["pseudo_beta"]], dim=0)
        gt_pseudo_beta_joined_mask = torch.cat([gt_protein_r_feats["pseudo_beta_mask"],
                                                gt_protein_l_feats["pseudo_beta_mask"]], dim=0)
        # IGNORES: residx_atom14_to_atom37, rigidgroups_group_exists,
        # rigidgroups_group_is_ambiguous, pseudo_beta_mask, backbone_rigid_mask, protein_target_feat
        gt_protein_feats = {
            "atom37_gt_positions": atom37_gt_positions,  # torch.Size([n_res, 37, 3])
            "atom37_atom_exists_in_res": atom37_atom_exists_in_res,  # torch.Size([n_res, 37])
            "atom37_atom_exists_in_gt": atom37_atom_exists_in_gt,  # torch.Size([n_res, 37])
            "atom14_gt_positions": atom14_gt_positions,  # torch.Size([n_res, 14, 3])
            "atom14_atom_exists_in_res": atom14_atom_exists_in_res,  # torch.Size([n_res, 14])
            "atom14_atom_exists_in_gt": atom14_atom_exists_in_gt,  # torch.Size([n_res, 14])
            "gt_pseudo_beta_joined": gt_pseudo_beta_joined,  # torch.Size([n_res, 3])
            "gt_pseudo_beta_joined_mask": gt_pseudo_beta_joined_mask,  # torch.Size([n_res])

            # These are plain r/l concatenations; the zero padding added by _fit_to_crop below is sufficient
            "atom14_alt_gt_positions": torch.cat([gt_protein_r_feats["atom14_alt_gt_positions"],
                                                  gt_protein_l_feats["atom14_alt_gt_positions"]], dim=0),  # torch.Size([n_res, 14, 3])
            "atom14_alt_gt_exists": torch.cat([gt_protein_r_feats["atom14_alt_gt_exists"],
                                               gt_protein_l_feats["atom14_alt_gt_exists"]], dim=0),  # torch.Size([n_res, 14])
            "atom14_atom_is_ambiguous": torch.cat([gt_protein_r_feats["atom14_atom_is_ambiguous"],
                                                   gt_protein_l_feats["atom14_atom_is_ambiguous"]], dim=0),  # torch.Size([n_res, 14])
            "rigidgroups_gt_frames": torch.cat([gt_protein_r_feats["rigidgroups_gt_frames"],
                                                gt_protein_l_feats["rigidgroups_gt_frames"]], dim=0),  # torch.Size([n_res, 8, 4, 4])
            "rigidgroups_gt_exists": torch.cat([gt_protein_r_feats["rigidgroups_gt_exists"],
                                                gt_protein_l_feats["rigidgroups_gt_exists"]], dim=0),  # torch.Size([n_res, 8])
            "rigidgroups_alt_gt_frames": torch.cat([gt_protein_r_feats["rigidgroups_alt_gt_frames"],
                                                    gt_protein_l_feats["rigidgroups_alt_gt_frames"]], dim=0),  # torch.Size([n_res, 8, 4, 4])
            "backbone_rigid_tensor": torch.cat([gt_protein_r_feats["backbone_rigid_tensor"],
                                                gt_protein_l_feats["backbone_rigid_tensor"]], dim=0),  # torch.Size([n_res, 4, 4])
            "backbone_rigid_mask": torch.cat([gt_protein_r_feats["backbone_rigid_mask"],
                                              gt_protein_l_feats["backbone_rigid_mask"]], dim=0),  # torch.Size([n_res])
            "chi_angles_sin_cos": torch.cat([gt_protein_r_feats["chi_angles_sin_cos"],
                                             gt_protein_l_feats["chi_angles_sin_cos"]], dim=0),
            "chi_mask": torch.cat([gt_protein_r_feats["chi_mask"], gt_protein_l_feats["chi_mask"]], dim=0),
        }
        for k, v in gt_protein_feats.items():
            gt_protein_feats[k] = _fit_to_crop(v, crop_size, 0)

        feats = {
            **feats,
            **gt_protein_feats,
            "resolution": resolution,
            "affinity": affinity,
            "affinity_loss_factor": affinity_loss_factor,
            "seq_length": torch.tensor(n_res_total),
            "binding_site_mask": _fit_to_crop(binding_site_mask, crop_size, 0),
            "gt_inter_contacts": inter_contact_reshaped_to_crop,
        }

    # tile every feature along a trailing recycling dimension (applies in all modes)
    for k, v in feats.items():
        feats[k] = _prepare_recycles(v, num_recycles)

    feats["batch_idx"] = torch.tensor(
        [idx for _ in range(crop_size)], dtype=torch.int64, device=feats["aatype"].device
    )

    print("load time", round(time.time() - start_load_time, 4))

    return feats
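

# A minimal usage sketch (hypothetical paths; config must be an mlc.ConfigDict with the
# common/train/eval sections referenced above, and the JSON must contain the keys read
# above, e.g. "input_r_structure", "input_l_structure", "pdb_id"):
#
#   pipeline = DataPipeline(config, mode="predict")
#   feats = parse_input_json("example.json", mode="predict", config=config,
#                            data_pipeline=pipeline, data_dir="data/", idx=0)
#   # every value in feats carries a trailing recycling dimension of size
#   # config.common.max_recycling_iters + 1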