Spaces:
Running
Running
| # import json | |
| # from tqdm import tqdm | |
| # import matplotlib.pyplot as plt | |
| # import numpy as np | |
| # f = open("/home/ubuntu/proteinedit-mm-clean/data/esm_subset/abstract.json", "r") | |
| # ann = json.load(f) | |
| # total = 0 | |
| # l_256 = 0 | |
| # l_384 = 0 | |
| # x = [] | |
| # for i in tqdm(range(0, len(ann))): | |
| # total += len(ann[i]["caption"].split()) | |
| # if (len(ann[i]["caption"].split()) <= 256): | |
| # l_256 += 1 | |
| # if (len(ann[i]["caption"].split()) <= 384): | |
| # l_384 += 1 | |
| # x.append(len(ann[i]["caption"].split())) | |
| # x = np.array(x) | |
| # print("avg: ", str(total / len(ann))) | |
| # print("below 256: ", str(l_256 / len(ann))) | |
| # print("below 384: ", str(l_384 / len(ann))) | |
| # plt.hist(x) | |
| # plt.savefig("test.png") | |
| from minigpt4.datasets.qa_dataset import QADataset | |
# Smoke test: build the QA dataset from the local structure/sequence roots and
# print the first sample's question and answer fields plus the dataset size.
datasets_raw = QADataset(
    pdb_root="/home/ubuntu/pt/",
    seq_root="/home/ubuntu/seq/",
    ann_paths="/home/ubuntu/proteinchat/data/esm_subset/qa_all.json",
    dataset_description="/home/ubuntu/dataset.json",
    chain="A",
)

# Each field is read through its own datasets_raw[0] lookup, question first.
for _key in ("q_input", "a_input"):
    print(datasets_raw[0][_key])

print(len(datasets_raw))
| import esm | |
| import torch | |
| from esm.inverse_folding.util import load_coords | |
# Device used for the model and its outputs. Prefer the GPU, but fall back to
# the CPU so the script still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
| # pdb_file = '/home/ubuntu/7md4.pdb' | |
| # pdb_file = "/home/ubuntu/8t3r.pdb" | |
# Holds the loaded inverse-folding model so repeated encode() calls reuse it
# instead of re-instantiating the checkpoint and re-copying it to the device.
_IF1_MODEL_CACHE = {}


def _get_if1_model():
    """Return the ESM-IF1 model in eval mode on `device`, loading it once."""
    if "model" not in _IF1_MODEL_CACHE:
        model, _alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
        _IF1_MODEL_CACHE["model"] = model.eval().to(device)
    return _IF1_MODEL_CACHE["model"]


def encode(file, pdb_dir='/home/ubuntu/test_pdb', chain="A"):
    """Run the inverse-folding model on one PDB file and print its encoding shape.

    Prints the native sequence read from the structure, then the shape of the
    first element of the encoder output.

    Args:
        file: File name of the PDB structure inside ``pdb_dir``.
        pdb_dir: Directory containing the PDB files. Defaults to the location
            this script originally hard-coded.
        chain: Chain identifier passed to ``load_coords``. Defaults to "A".

    Returns:
        The tensor taken from ``encoder_out["encoder_out"][0]``, moved to
        ``device``. (Earlier versions returned nothing; callers that ignore
        the return value are unaffected.)
    """
    pdb_file = f'{pdb_dir}/{file}'
    coords, native_seq = load_coords(pdb_file, chain)
    print(native_seq)

    model = _get_if1_model()

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        # NOTE(review): upstream fair-esm `sample` returns only the sampled
        # sequence; unpacking two values presumably relies on a patched esm
        # build that also returns the encoder output -- confirm.
        _sampled_seq, encoder_out = model.sample(
            coords, temperature=1, device=torch.device(device)
        )

    sample_protein = encoder_out["encoder_out"][0].to(device)
    print(sample_protein.shape)
    return sample_protein
| # python -m pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cu121.html | |
| # python -m pip install torch-sparse -f https://data.pyg.org/whl/torch-2.3.0+cu121.html | |
| # python -m pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html | |
| # python -m pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html | |
| # python -m pip install torch-geometric | |
| # torch.Size([1, 32, 2560]) | |
| # /home/ubuntu/test_pdb | |
| # 1jj9.pdb 2cma.pdb 3lhj.pdb 5p11.pdb 6jzt.pdb | |
# Encode each test structure in turn, in the same order as before.
for _pdb_name in ('1jj9.pdb', '2cma.pdb', '3lhj.pdb', '5p11.pdb', '6jzt.pdb'):
    encode(_pdb_name)