sgoel30 commited on
Commit
9c299b2
·
verified ·
1 Parent(s): 906ae08

Delete src/sampling/unconditional_generator.py

Browse files
src/sampling/unconditional_generator.py DELETED
@@ -1,114 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- import sys
4
- import os
5
-
6
- import random
7
- import torch
8
- import pandas as pd
9
- import numpy as np
10
-
11
- from tqdm import tqdm
12
- from collections import Counter
13
- from omegaconf import OmegaConf
14
- from datetime import datetime
15
- from transformers import AutoTokenizer, AutoModelForMaskedLM
16
-
17
- from MeMDLM_v2.src.lm.diffusion_module import MembraneFlow
18
- from src.sampling.unconditional_sampler import UnconditionalSampler
19
- from src.utils.generate_utils import mask_for_de_novo, calc_ppl
20
- from src.utils.model_utils import _print
21
-
22
-
23
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
- os.chdir('/home/a03-sgoel/MeMDLM_v2')
25
- config = OmegaConf.load("./src/configs/lm.yaml")
26
-
27
- date = datetime.now().strftime("%Y-%m-%d")
28
-
29
-
30
-
31
- def generate_sequence(prior: str, tokenizer, generator, device):
32
- input_ids = tokenizer(prior, return_tensors="pt").to(device)['input_ids']
33
- ids = generator.sample_unconditional(
34
- xt=input_ids,
35
- num_steps=config.sampling.n_steps,
36
- return_logits=False,
37
- banned_token_ids=None
38
- #banned_token_ids=[tokenizer.convert_tokens_to_ids("P"), tokenizer.convert_tokens_to_ids("C")]
39
- )
40
- generated_sequence = tokenizer.decode(ids[0].squeeze())[5:-5].replace(" ", "") # bos/eos tokens & spaces between residues
41
- return generated_sequence
42
-
43
-
44
- def main():
45
- csv_save_path = f'./results/denovo/unconditional/{config.wandb.name}/{date}_tau=3.0_test-set_distribution'
46
-
47
- try: os.makedirs(csv_save_path, exist_ok=False)
48
- except FileExistsError: pass
49
-
50
-
51
- tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)
52
-
53
- flow = MembraneFlow(config).to(device)
54
- state_dict = flow.get_state_dict(f"./checkpoints/{config.wandb.name}/best_model.ckpt")
55
- flow.load_state_dict(state_dict)
56
- flow.eval()
57
-
58
- esm_pth = config.lm.pretrained_esm
59
- esm_model = AutoModelForMaskedLM.from_pretrained(esm_pth).to(device)
60
- esm_model.eval()
61
-
62
- generator = UnconditionalSampler(tokenizer, flow)
63
-
64
- # # Get 100 random sequence lengths to generate
65
- # seq_lengths = [random.randint(50, 250) for _ in range(5000)]
66
-
67
- # # Determine length from positive controls
68
- # df = pd.read_csv(f'./results/denovo/unconditional/{config.wandb.name}/perin_pos_ctrl/raw_seqs.csv')
69
- # seq_lengths = [len(seq) for seq in df['Sequence'].tolist() for _ in range(500)] # generate each length 100 times
70
- # _print(seq_lengths)
71
-
72
- # Determine lengths from test set distribution
73
- df = pd.read_csv("./data/test.csv")
74
- seq_lengths = [len(seq) for seq in df['Sequence'].tolist()]
75
- length_counts = Counter(seq_lengths) # {L1: freq, L2: freq, ...}
76
- total = sum(length_counts.values()) # total number of tokens
77
- lengths = np.array(list(length_counts.keys())) # Frequency of each length
78
- probs = np.array([length_counts[l] / total for l in lengths])
79
- seq_lengths = np.random.choice(lengths, size=len(seq_lengths), p=probs)
80
-
81
- generation_results = []
82
- for seq_len in tqdm(seq_lengths, desc=f"Generating sequences: "):
83
- seq_res = []
84
-
85
- masked_seq = mask_for_de_novo(seq_len) # Sequence of all <mask> tokens
86
- gen_seq = ""
87
- attempts = 0
88
-
89
- while len(gen_seq) != seq_len and attempts < 3:
90
- gen_seq = generate_sequence(masked_seq, tokenizer, generator, device)
91
- attempts += 1
92
-
93
- if len(gen_seq) != seq_len:
94
- esm_ppl, flow_ppl = None, None
95
- else:
96
- esm_ppl = calc_ppl(esm_model, tokenizer, gen_seq, [i for i in range(len(gen_seq))], model_type='esm')
97
- flow_ppl = calc_ppl(flow, tokenizer, gen_seq, [i for i in range(len(gen_seq))], model_type='flow')
98
-
99
- _print(f'gen seq: {gen_seq}')
100
- _print(f'esm ppl: {esm_ppl}')
101
- _print(f'flow ppl: {flow_ppl}')
102
-
103
- seq_res.append(gen_seq)
104
- seq_res.append(esm_ppl)
105
- seq_res.append(flow_ppl)
106
-
107
- generation_results.append(seq_res)
108
-
109
- df = pd.DataFrame(generation_results, columns=['Generated Sequence', 'ESM PPL', 'Flow PPL'])
110
- df.to_csv(csv_save_path + "/seqs_with_ppl.csv", index=False)
111
-
112
-
113
- if __name__ == "__main__":
114
- main()