Spaces:
Sleeping
Sleeping
| import copy | |
| import ml_collections as mlc | |
| from dockformerpp.utils.config_tools import set_inf, enforce_config_constraints | |
def model_config(
    name: str,
    train: bool = False,
    low_prec: bool = False,
    long_sequence_inference: bool = False,
):
    """Build a fresh configuration for the requested preset.

    The module-level ``config`` is deep-copied first, so the returned object
    can be modified without touching the shared defaults.

    Args:
        name: Preset to apply. ``"initial_training"`` keeps the defaults;
            ``"finetune_affinity"`` sets the ``affinity2d``, ``binding_site``
            and ``positions_inter_distogram`` loss weights to 0.5.
        train: When true, sets ``globals.blocks_per_ckpt`` to 1 and leaves
            low-memory attention (``globals.use_lma``) switched off.
        low_prec: When true, sets ``globals.eps`` to 1e-4 and calls
            ``set_inf(c, 1e4)``.
        long_sequence_inference: When true, switches ``globals.use_lma`` on.
            Cannot be combined with ``train``.

    Returns:
        The adjusted copy of ``config``, after it has been passed through
        ``enforce_config_constraints``.

    Raises:
        ValueError: If ``name`` is not a known preset, or if both ``train``
            and ``long_sequence_inference`` are requested.
    """
    c = copy.deepcopy(config)

    # TRAINING PRESETS
    if name == "initial_training":
        # AF2 Suppl. Table 4, "initial training" setting
        pass
    elif name == "finetune_affinity":
        c.loss.affinity2d.weight = 0.5
        c.loss.binding_site.weight = 0.5
        # this is not essential given fape?
        c.loss.positions_inter_distogram.weight = 0.5
    else:
        raise ValueError("Invalid model name")

    # Explicit check instead of `assert`: assertions are removed when Python
    # runs with -O, which would let this contradictory combination through.
    if long_sequence_inference and train:
        raise ValueError(
            "long_sequence_inference cannot be combined with train"
        )

    c.globals.use_lma = False

    if long_sequence_inference:
        c.globals.use_lma = True

    if train:
        c.globals.blocks_per_ckpt = 1
        c.globals.use_lma = False

    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        # NOTE(review): set_inf presumably overwrites every `inf` entry in the
        # tree with 1e4 -- confirm against dockformerpp.utils.config_tools.
        set_inf(c, 1e4)

    # NOTE(review): defined in config_tools; not visible here what it checks.
    enforce_config_constraints(c)

    return c
# Shared defaults. Each FieldReference below is placed into the `config` tree
# in one or more spots; the tree's own comment on "globals" describes them as
# recurring references that can be changed globally from there.

# Channel widths. c_z and c_m feed the embedders and the evoformer stack,
# c_s feeds the evoformer stack, structure module and heads; c_t and c_e
# only appear under "globals" in this file.
c_m = mlc.FieldReference(default=256, field_type=int)
c_z = mlc.FieldReference(default=128, field_type=int)
c_s = mlc.FieldReference(default=384, field_type=int)
c_t = mlc.FieldReference(default=64, field_type=int)
c_e = mlc.FieldReference(default=64, field_type=int)

# Bin counts for the auxiliary distogram and affinity outputs.
aux_distogram_bins = mlc.FieldReference(default=64, field_type=int)
aux_affinity_bins = mlc.FieldReference(default=32, field_type=int)

# Checkpointing granularity (unset by default) and the numerical epsilon.
blocks_per_ckpt = mlc.FieldReference(default=None, field_type=int)
eps = mlc.FieldReference(default=1e-8, field_type=float)

# Placeholder strings used inside the feature-shape lists of `config`.
NUM_RES = "num residues placeholder"
NUM_TOKEN = "num tokens placeholder"
# Default configuration tree. `model_config` deep-copies this object and
# applies preset-specific overrides, so the values here are the baseline.
# Entries bound to a module-level FieldReference (c_z, c_m, c_s, eps,
# blocks_per_ckpt, aux_*_bins) share one value across every place they appear.
config = mlc.ConfigDict(
    {
        "data": {
            "common": {
                # Per-feature shape lists; NUM_TOKEN is a placeholder string.
                # NOTE(review): presumably None marks an axis whose size is
                # left as-is -- confirm against the data pipeline.
                "feat": {
                    "aatype": [NUM_TOKEN],
                    "all_atom_mask": [NUM_TOKEN, None],
                    "all_atom_positions": [NUM_TOKEN, None, None],
                    "atom14_alt_gt_exists": [NUM_TOKEN, None],
                    "atom14_alt_gt_positions": [NUM_TOKEN, None, None],
                    "atom14_atom_exists": [NUM_TOKEN, None],
                    "atom14_atom_is_ambiguous": [NUM_TOKEN, None],
                    "atom14_gt_exists": [NUM_TOKEN, None],
                    "atom14_gt_positions": [NUM_TOKEN, None, None],
                    "atom37_atom_exists": [NUM_TOKEN, None],
                    "backbone_rigid_mask": [NUM_TOKEN],
                    "backbone_rigid_tensor": [NUM_TOKEN, None, None],
                    "chi_angles_sin_cos": [NUM_TOKEN, None, None],
                    "chi_mask": [NUM_TOKEN, None],
                    "no_recycling_iters": [],
                    "pseudo_beta": [NUM_TOKEN, None],
                    "pseudo_beta_mask": [NUM_TOKEN],
                    "residue_index": [NUM_TOKEN],
                    "in_chain_residue_index": [NUM_TOKEN],
                    "chain_index": [NUM_TOKEN],
                    "residx_atom14_to_atom37": [NUM_TOKEN, None],
                    "residx_atom37_to_atom14": [NUM_TOKEN, None],
                    "resolution": [],
                    "rigidgroups_alt_gt_frames": [NUM_TOKEN, None, None, None],
                    "rigidgroups_group_exists": [NUM_TOKEN, None],
                    "rigidgroups_group_is_ambiguous": [NUM_TOKEN, None],
                    "rigidgroups_gt_exists": [NUM_TOKEN, None],
                    "rigidgroups_gt_frames": [NUM_TOKEN, None, None, None],
                    "seq_length": [],
                    "token_mask": [NUM_TOKEN],
                    "target_feat": [NUM_TOKEN, None],
                    "use_clamped_fape": [],
                },
                "max_recycling_iters": 1,
                "unsupervised_features": [
                    "aatype",
                    "residue_index",
                    "in_chain_residue_index",
                    "chain_index",
                    "seq_length",
                    "no_recycling_iters",
                    "all_atom_mask",
                    "all_atom_positions",
                ],
            },
            "supervised": {
                "clamp_prob": 0.9,
                "supervised_features": [
                    "resolution",
                    "use_clamped_fape",
                ],
            },
            # Per-mode data settings (predict / eval / train).
            "predict": {
                "fixed_size": True,
                "crop": False,
                "crop_size": None,
                "supervised": False,
                "uniform_recycling": False,
            },
            "eval": {
                "fixed_size": True,
                "crop": False,
                "crop_size": None,
                "supervised": True,
                "uniform_recycling": False,
            },
            "train": {
                "fixed_size": True,
                "crop": True,
                "crop_size": 355,
                "supervised": True,
                "clamp_prob": 0.9,
                "uniform_recycling": True,
                "distogram_mask_prob": 0.1,
            },
            "data_module": {
                "data_loaders": {
                    "batch_size": 1,
                    # "batch_size": 2,
                    "num_workers": 16,
                    "pin_memory": True,
                    "should_verify": False,
                },
            },
        },
        # Recurring FieldReferences that can be changed globally here
        "globals": {
            "blocks_per_ckpt": blocks_per_ckpt,
            # Use Rabe & Staats' low-memory attention algorithm.
            "use_lma": False,
            "max_lr": 1e-3,
            "c_z": c_z,
            "c_m": c_m,
            "c_t": c_t,
            "c_e": c_e,
            "c_s": c_s,
            "eps": eps,
        },
        "model": {
            "_mask_trans": False,
            "structure_input_embedder": {
                "protein_tf_dim": 20,
                # number of classes (prot_r, prot_l, aff)
                "additional_tf_dim": 3,
                "c_z": c_z,
                "c_m": c_m,
                "relpos_k": 32,
                "prot_min_bin": 3.25,
                "prot_max_bin": 20.75,
                "prot_no_bins": 15,
                "inf": 1e8,
            },
            "recycling_embedder": {
                "c_z": c_z,
                "c_m": c_m,
                "min_bin": 3.25,
                "max_bin": 20.75,
                "no_bins": 15,
                "inf": 1e8,
            },
            "evoformer_stack": {
                "c_m": c_m,
                "c_z": c_z,
                "c_hidden_single_att": 32,
                "c_hidden_mul": 128,
                "c_hidden_pair_att": 32,
                "c_s": c_s,
                "no_heads_single": 8,
                "no_heads_pair": 4,
                # "no_blocks": 48,
                "no_blocks": 2,
                "transition_n": 4,
                "single_dropout": 0.15,
                "pair_dropout": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "clear_cache_between_blocks": False,
                "inf": 1e9,
                "eps": eps,  # 1e-10,
            },
            "structure_module": {
                "c_s": c_s,
                "c_z": c_z,
                "c_ipa": 16,
                "c_resnet": 128,
                "no_heads_ipa": 12,
                "no_qk_points": 4,
                "no_v_points": 8,
                "dropout_rate": 0.1,
                "no_blocks": 8,
                "no_transition_layers": 1,
                "no_resnet_blocks": 2,
                "no_angles": 7,
                "trans_scale_factor": 10,
                "epsilon": eps,  # 1e-12,
                "inf": 1e5,
            },
            "heads": {
                "lddt": {
                    "no_bins": 50,
                    "c_in": c_s,
                    "c_hidden": 128,
                },
                "distogram": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                },
                "affinity_2d": {
                    "c_z": c_z,
                    "num_bins": aux_affinity_bins,
                },
                "affinity_1d": {
                    "c_s": c_s,
                    "num_bins": aux_affinity_bins,
                },
                "affinity_cls": {
                    "c_s": c_s,
                    "num_bins": aux_affinity_bins,
                },
                "binding_site": {
                    "c_s": c_s,
                    "c_out": 1,
                },
                "inter_contact": {
                    "c_s": c_s,
                    "c_z": c_z,
                    "c_out": 1,
                },
            },
            # A negative value indicates that no early stopping will occur, i.e.
            # the model will always run `max_recycling_iters` number of recycling
            # iterations. A positive value will enable early stopping if the
            # difference in pairwise distances is less than the tolerance between
            # recycling steps.
            "recycle_early_stop_tolerance": -1.0,
        },
        "relax": {
            "max_iterations": 0,  # no max
            "tolerance": 2.39,
            "stiffness": 10.0,
            "max_outer_iterations": 20,
            "exclude_residues": [],
        },
        "loss": {
            "distogram": {
                "min_bin": 2.3125,
                "max_bin": 21.6875,
                # Bound to the same reference as model.heads.distogram.no_bins
                # (default 64) so the head and its loss cannot drift apart,
                # mirroring how the affinity losses share aux_affinity_bins.
                "no_bins": aux_distogram_bins,
                "eps": eps,  # 1e-6,
                "weight": 0.3,
            },
            "positions_inter_distogram": {
                "max_dist": 20.0,
                "weight": 0.0,
            },
            "positions_intra_distogram": {
                "max_dist": 10.0,
                "weight": 0.0,
            },
            "binding_site": {
                "weight": 0.0,
                "pos_class_weight": 20.0,
            },
            "inter_contact": {
                "weight": 0.0,
                "pos_class_weight": 200.0,
            },
            "affinity2d": {
                "min_bin": 0,
                "max_bin": 15,
                "no_bins": aux_affinity_bins,
                "weight": 0.0,
            },
            "affinity_cls": {
                "min_bin": 0,
                "max_bin": 15,
                "no_bins": aux_affinity_bins,
                "weight": 0.0,
            },
            "fape_backbone": {
                "clamp_distance": 10.0,
                "loss_unit_distance": 10.0,
                "weight": 0.5,
            },
            "fape_sidechain": {
                "clamp_distance": 10.0,
                "length_scale": 10.0,
                "weight": 0.5,
            },
            "fape_interface": {
                "clamp_distance": 10.0,
                "length_scale": 10.0,
                "weight": 0.0,
            },
            "plddt_loss": {
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "cutoff": 15.0,
                # NOTE(review): same literal as model.heads.lddt.no_bins (50);
                # presumably the two must stay equal -- confirm.
                "no_bins": 50,
                "eps": eps,  # 1e-10,
                "weight": 0.01,
            },
            "supervised_chi": {
                "chi_weight": 0.5,
                "angle_norm_weight": 0.01,
                "eps": eps,  # 1e-6,
                "weight": 1.0,
            },
            "chain_center_of_mass": {
                "clamp_distance": -4.0,
                "weight": 0.0,
                "eps": eps,
                "enabled": False,
            },
            "eps": eps,
        },
        "ema": {"decay": 0.999},
    }
)