|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import random as rd |
|
|
|
|
|
from diffusion import Diffusion |
|
|
from scoring.scoring_functions import ScoringFunctions |
|
|
from utils.filter import PeptideAnalyzer |
|
|
import noise_schedule |
|
|
|
|
|
"""" |
|
|
Open design questions:
- Should each node store its rolled-out sequence?
- Should a path hold Node objects or strings?
- Should selection be restricted to valid, expandable leaf nodes?
- Should the similarity between sibling nodes be calculated?
- Should generated sequences be evaluated?
|
|
""" |
|
|
class Node:
    """
    Node class: one entry of the search tree, holding a token sequence.

    - config: configuration object; reads config.mcts.num_objectives,
      config.mcts.sample_prob and config.sampling.steps
    - parentNode: Node object at the previous time step (None for the root)
    - childNodes: list of child Node objects added through addChildNode
    - scoreVector: optional score vector stored as given
    - totalReward: vector of cumulative rewards; defaults to zeros of length
      config.mcts.num_objectives
    - visits: number of times the node has been visited by an iteration
      (starts at 1)
    - timestep: the time step where the sequence was sampled
    - sampleProb: probability of sampling the sequence from the diffusion model
    - tokens: token data for the node, stored as given
    """

    def __init__(self, config, tokens=None, parentNode=None, childNodes=None, scoreVector=None, totalReward=None, timestep=None, sampleProb=None):
        self.config = config
        self.parentNode = parentNode
        # A fresh list per node: a literal [] default would be shared by every
        # node constructed without an explicit childNodes argument.
        self.childNodes = [] if childNodes is None else childNodes
        self.scoreVector = scoreVector

        if totalReward is not None:
            self.totalReward = totalReward
        else:
            self.totalReward = np.zeros(self.config.mcts.num_objectives)

        self.visits = 1
        self.timestep = timestep
        self.sampleProb = sampleProb
        self.tokens = tokens

    def selectNode(self, num_func):
        """
        Selects a node to move to among the children nodes.

        Returns (self, status) when this node is not an expanded node
        (status 1 or 2). Otherwise builds a Pareto front over the select
        scores of the children that are still expandable (status 2 or 3),
        considering the first num_func objectives, and returns a uniformly
        random member of that front together with its status.

        NOTE: raises IndexError (from random.choice) when the node has
        children but none of them has status 2 or 3.
        """
        nodeStatus = self.getExpandStatus()
        if nodeStatus != 3:
            return self, nodeStatus

        paretoFront = {}
        for childNode in self.childNodes:
            if childNode.getExpandStatus() in (2, 3):
                selectScore = childNode.calcSelectScore()
                paretoFront = updateParetoFront(paretoFront, childNode, selectScore, num_func)

        selected = rd.choice(list(paretoFront.keys()))
        return selected, selected.getExpandStatus()

    def addChildNode(self, tokens, totalReward, prob=None):
        """
        Adds a child node at timestep self.timestep + 1 and returns it.
        """
        child = Node(config=self.config,
                     tokens=tokens,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep + 1,
                     sampleProb=prob)
        self.childNodes.append(child)
        return child

    def updateNode(self, rewards):
        """
        Updates the cumulative rewards vector with the reward vector at a descendent leaf node.
        Increments the number of visits to the node.
        """
        self.visits += 1
        self.totalReward += rewards

    def calcSelectScore(self):
        """
        Calculates the select score for the node from the cumulative rewards vector and number of visits.

        Without a sampleProb the score is totalReward / visits. With one, an
        exploration bonus
            config.mcts.sample_prob * sampleProb * sqrt(root.visits) / visits
        is added to every objective.
        """
        # (A disabled early return of 0.0 for nodes without a parent used to sit here.)
        normRewards = self.totalReward / self.visits
        if self.sampleProb is None:
            return normRewards

        print("Sample Prob")
        print(self.sampleProb)

        # The node never stores a `root` attribute itself; honour one if a
        # caller attached it, otherwise follow the parent links up to the root.
        root = getattr(self, "root", None)
        if root is None:
            root = self
            while root.parentNode is not None:
                root = root.parentNode

        return normRewards + (self.config.mcts.sample_prob * self.sampleProb * np.sqrt(root.visits) / self.visits)

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (timestep equals config.sampling.steps)
        2. legal leaf node (timestep below config.sampling.steps and no children yet)
        3. legal non-leaf node (anything else, i.e. already expanded)
        """
        steps = self.config.sampling.steps
        if self.timestep == steps:
            return 1
        if (self.timestep < steps) and (len(self.childNodes) == 0):
            return 2
        return 3
|
|
|
|
|
"""END OF NODE CLASS""" |
|
|
|
|
|
def updateParetoFront(paretoFront, node, scoreVector, num_func):
    """
    Updates a Pareto front (dict mapping node -> score vector) with a new node.

    Only the first num_func objectives are considered. An existing entry is
    removed when scoreVector is >= it on every considered objective and
    strictly greater on at least one. The new node is added when scoreVector
    is >= every existing entry on all considered objectives (or the front is
    empty).

    The dict is modified in place and also returned.
    """
    if len(paretoFront) == 0:
        paretoFront[node] = scoreVector
        return paretoFront

    dominated_keys = []
    matches_or_beats_all = True
    for key, existing in paretoFront.items():
        existing = np.asarray(existing)
        # Slicing clamps to the vector length, so no guard is needed when
        # num_func exceeds the number of objectives. The previous guard left
        # these two names unbound (or stale from an earlier entry) in that case.
        attn_nondominated = (scoreVector >= existing)[:num_func]
        attn_dominant = (scoreVector > existing)[:num_func]

        if attn_nondominated.all():
            if attn_dominant.any():
                dominated_keys.append(key)
        else:
            matches_or_beats_all = False

    if matches_or_beats_all:
        paretoFront[node] = scoreVector

    # Entries strictly beaten by scoreVector are dropped even when the new
    # node itself was not added.
    for key in dominated_keys:
        del paretoFront[key]

    return paretoFront
|
|
|
|
|
"""BEGINNING OF MCTS CLASS""" |
|
|
|
|
|
class MCTS:
    """
    Tree search driver: repeatedly selects a leaf node, expands it by sampling
    children from the MDLM, scores the rolled-out sequences and maintains a
    Pareto front of sequences (self.peptideParetoFront).

    peptideParetoFront maps a decoded sequence to
    {'scores': score vector, 'token_ids': rolled-out tokens}.
    """

    def __init__(self, config, max_sequence_length=None, mdlm=None, score_func_names=None, prot_seqs=None, num_func=None):
        """
        - config: configuration object (reads time_conditioning, sampling.* and mcts.*)
        - max_sequence_length: overrides config.sampling.seq_length when given
        - mdlm: model object; must provide tokenizer, device and the reverse-step
          methods used by expand (it is dereferenced here, so None is not usable)
        - score_func_names, prot_seqs: forwarded to ScoringFunctions
        - num_func: list of iteration thresholds; entry i is the iteration from
          which the first i+1 objectives are considered (see updateParetoFront)
        """
        # None defaults avoid sharing one list object between instances.
        score_func_names = [] if score_func_names is None else score_func_names
        num_func = [] if num_func is None else num_func

        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = self.config.time_conditioning

        self.peptideParetoFront = {}
        self.num_steps = config.sampling.steps
        self.num_sequences = config.sampling.num_sequences

        self.mdlm = mdlm
        self.tokenizer = mdlm.tokenizer
        self.device = mdlm.device

        if max_sequence_length is None:
            self.sequence_length = self.config.sampling.seq_length
        else:
            self.sequence_length = max_sequence_length

        self.num_iter = config.mcts.num_iter
        self.num_child = config.mcts.num_children

        self.score_functions = ScoringFunctions(score_func_names, prot_seqs)
        self.num_func = num_func
        self.iter_num = 0
        self.curr_num_func = 1
        self.analyzer = PeptideAnalyzer()

        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

    def reset(self):
        """
        Clears the iteration counter, all logs and the Pareto front so that
        forward() starts from a clean state.
        """
        self.iter_num = 0
        # Also restore the objective count so a second search does not start
        # with the value left over from the previous one.
        self.curr_num_func = 1
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []
        self.peptideParetoFront = {}

    def forward(self, rootNode):
        """
        Runs config.mcts.num_iter select/expand iterations starting from
        rootNode and returns the resulting Pareto front dict.
        """
        self.reset()

        while (self.iter_num < self.num_iter):
            self.iter_num += 1
            leafNode, _ = self.select(rootNode)
            self.expand(leafNode)

        return self.peptideParetoFront

    def updateParetoFront(self, sequence, scoreVector, tokens):
        """
        Updates self.peptideParetoFront with a scored sequence and returns its
        reward vector.

        self.curr_num_func (how many leading objectives are considered) is set
        to i+1 for the last index i with self.iter_num >= self.num_func[i],
        or 1 when there is none.

        The reward vector holds, per objective, the fraction of current front
        members whose score is <= scoreVector on that objective (all ones when
        the front is empty).

        The sequence is stored when it is >= every member on the considered
        objectives, or while the front holds fewer than num_sequences entries.
        Members it strictly beats are removed only while the front (sized
        before the insertion) exceeds num_sequences.
        """
        paretoSize = len(self.peptideParetoFront)

        self.curr_num_func = 1
        for i, threshold in enumerate(self.num_func):
            if self.iter_num >= threshold:
                self.curr_num_func = i + 1

        if paretoSize == 0:
            self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
            return np.ones(len(scoreVector))

        nondominate = []
        delete = []
        rewardVector = np.zeros(len(scoreVector))
        for k, v in self.peptideParetoFront.items():
            existing = np.asarray(v['scores'])
            nondominated = scoreVector >= existing
            dominant = scoreVector > existing

            rewardVector += nondominated

            # Slicing clamps to the vector length; the previous length guard
            # left these names unbound or stale when curr_num_func was larger
            # than the number of objectives.
            attn_nondominated = nondominated[:self.curr_num_func]
            attn_dominant = dominant[:self.curr_num_func]

            if attn_nondominated.all():
                if attn_dominant.any():
                    delete.append(k)
                nondominate.append(True)
            else:
                nondominate.append(False)

        if all(nondominate) or paretoSize < self.num_sequences:
            self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}

        rewardVector = rewardVector / paretoSize

        while (paretoSize > self.num_sequences) and (len(delete) > 0):
            del self.peptideParetoFront[delete[0]]
            del delete[0]
            paretoSize -= 1

        return rewardVector

    def isPathEnd(self, path, maxDepth):
        """
        Checks if the node is completely unmasked (ie. end of path)
        or if the path is at the max depth
        """
        if (path[-1] != self.config.mcts.mask_token).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode):
        """
        Traverse the tree from the root node until reaching a node whose
        status is not 3 (i.e. not an already expanded node); returns that
        node and its status.
        """
        while True:
            currNode, nodeStatus = currNode.selectNode(self.curr_num_func)
            if nodeStatus != 3:
                return currNode, nodeStatus

    def expand(self, parentNode, eps=1e-5, checkSimilarity = True):
        """
        Sample unmasking steps from the pre-trained MDLM
        adds num_children partially unmasked sequences to the children of the parentNode

        Steps:
        1. one batched reverse step from the parent tokens gives the children;
        2. each child is rolled out for the remaining steps (optionally
           followed by a noise-removal forward pass) and decoded;
        3. sequences accepted by the analyzer are scored and pushed through
           updateParetoFront; rejected ones become children with zero reward;
        4. the summed child rewards, minus invalid_penalty times the invalid
           fraction, are back-propagated from parentNode.

        checkSimilarity is accepted but not used.
        """
        num_children = self.config.mcts.num_children

        allChildReward = np.zeros_like(parentNode.totalReward)

        num_rollout_steps = self.num_steps - parentNode.timestep
        rollout_t = torch.linspace(1, eps, num_rollout_steps, device=self.device)
        dt = (1 - eps) / self.num_steps

        x = parentNode.tokens['input_ids'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)

        t = rollout_t[0] * torch.ones(num_children, 1, device = self.device)

        print("token array:")
        print(x)
        p_x0_cache, x_children = self.mdlm.batch_cached_reverse_step(token_array=x,
                                                                     t=t, dt=dt,
                                                                     batch_size=num_children,
                                                                     attn_mask=attn_mask)
        x_rollout = x_children

        for i in range(1, num_rollout_steps):
            t = rollout_t[i] * torch.ones(num_children, 1, device = self.device)

            p_x0_cache, x_next = self.mdlm.cached_reverse_step(x=x_rollout,
                                                               t=t, dt=dt, p_x0=p_x0_cache,
                                                               attn_mask=attn_mask)

            # The cache is kept only while the rollout still equals the
            # parent's tokens (and time conditioning is off); any change
            # since the parent forces a fresh prediction on the next step.
            if (not torch.allclose(x_next, x) or self.time_conditioning):
                p_x0_cache = None

            x_rollout = x_next

        if self.config.sampling.noise_removal:
            # Size the timestep by the rollout batch. Child nodes store a
            # single row of x_children as input_ids, so the parent's leading
            # dimension is not the batch size.
            t = rollout_t[-1] * torch.ones(x_rollout.shape[0], 1, device=self.device)
            time_cond = self.noise(t)[0]
            x_rollout = self.mdlm.forward(x_rollout, attn_mask, time_cond).argmax(dim=-1)

        childSequences = self.tokenizer.batch_decode(x_rollout)

        validSequences = []
        maskedTokens = []
        unmaskedTokens = []
        for i in range(num_children):
            childSeq = childSequences[i]

            rewardVector = np.zeros(self.config.mcts.num_objectives)

            if self.analyzer.is_peptide(childSeq):
                validSequences.append(childSeq)
                maskedTokens.append(x_children[i])
                unmaskedTokens.append(x_rollout[i])
            else:
                # Invalid sequences still become children, with zero reward.
                childTokens = {'input_ids': x_children[i], 'attention_mask': attn_mask}
                parentNode.addChildNode(tokens=childTokens,
                                        totalReward=rewardVector)

        if (len(validSequences) != 0):
            scoreVectors = self.score_functions(input_seqs=validSequences)
            average_scores = scoreVectors.T
            if self.config.mcts.single:
                self.permeability_log.append(average_scores[0])
            else:
                self.affinity1_log.append(average_scores[0])
                self.sol_log.append(average_scores[1])
                self.hemo_log.append(average_scores[2])
                self.nf_log.append(average_scores[3])
                if self.config.mcts.perm:
                    self.permeability_log.append(average_scores[4])
                elif self.config.mcts.dual:
                    self.affinity2_log.append(average_scores[4])
        else:
            self.affinity1_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.sol_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.hemo_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            self.nf_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))

            if self.config.mcts.perm:
                self.permeability_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))
            elif self.config.mcts.dual:
                self.affinity2_log.append(np.zeros((self.config.mcts.num_objectives, self.config.sampling.num_sequences)))

        for i, validSeq in enumerate(validSequences):
            scoreVector = scoreVectors[i]

            rewardVector = self.updateParetoFront(validSeq, scoreVector, unmaskedTokens[i])
            print(scoreVector)
            print(rewardVector)

            allChildReward += rewardVector

            childTokens = {'input_ids': maskedTokens[i], 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    totalReward=rewardVector)

        invalid = (num_children - len(validSequences)) / num_children

        valid_fraction = len(validSequences) / num_children
        print(f"Valid fraction: {valid_fraction}")
        self.valid_fraction_log.append(valid_fraction)

        print(self.config.mcts.invalid_penalty)

        allChildReward = allChildReward - (self.config.mcts.invalid_penalty * invalid)

        self.backprop(parentNode, allChildReward)

    def backprop(self, node, reward_vector):
        """
        Applies reward_vector to node and every ancestor up to the root via
        updateNode (which also increments each node's visit count).
        """
        while node:
            node.updateNode(reward_vector)
            node = node.parentNode

    def getSequenceForObjective(self, objective_index, k):
        """
        Returns the top-k sequences in the pareto front that has the best score for
        a given objective and their score vectors for all objectives

        The result maps each selected sequence to its Pareto-front entry
        ({'scores': ..., 'token_ids': ...}), ordered from best to worst on
        objective_index. Fewer than k entries are returned when the front is
        smaller than k.
        """
        # Rank on the stored scores. The loop variable must not reuse the name
        # k, and the front may hold fewer than k entries.
        ranked = sorted(self.peptideParetoFront.items(),
                        key=lambda item: float(item[1]['scores'][objective_index]),
                        reverse=True)
        return {seq: entry for seq, entry in ranked[:max(k, 0)]}
|
|
|
|
|
|