# PepTune / pareto_mcts.py
import random as rd
import numpy as np
import torch
import torch.nn.functional as F
import noise_schedule
from diffusion import Diffusion
from scoring.scoring_functions import ScoringFunctions
from utils.filter import PeptideAnalyzer
""""
Notes: store rolled out sequence?
path of node objects or strings?
should we only select valid expandable leaf nodes?
calculate similarity between sibling nodes?
should we evaluate generated sequences?
"""
class Node:
"""
Node class: partially unmasked SMILES string
- parentNode: Node object at previous time step
- childNodes: set of M Node objects generated from sampling M distinct unmasking schemes
- totalReward: vector of cumulative rewards for all K objectives
- visits: number of times the node has been visited by an interation
- path: array of partially unmasked SMILES strings leading to the node from the completely masked root node
- timestep: the time step where the sequence was sampled
- sampleProb: probability of sampling the sequence from the diffusion model
"""
    def __init__(self, config, tokens=None, parentNode=None, childNodes=None, scoreVector=None, totalReward=None, timestep=None, sampleProb=None):
        self.config = config
        self.parentNode = parentNode
        # use None as the default to avoid sharing one mutable list across instances
        self.childNodes = childNodes if childNodes is not None else []
        self.scoreVector = scoreVector
        # initialize total rewards to the reward of the rolled-out unmasked sequence
        if totalReward is not None:
            self.totalReward = totalReward
        else:
            self.totalReward = np.zeros(self.config.mcts.num_objectives)
        # set initial visits to 1
        self.visits = 1
        # set timestep (value between 0 and num_steps)
        self.timestep = timestep
        # set the sampling probability equal to the probability from the reverse posterior
        self.sampleProb = sampleProb
        # dict with 'input_ids' as token array and 'attention_mask'
        self.tokens = tokens
def selectNode(self, num_func):
"""
Selects a node to move to among the children nodes
"""
# extract the status of the current node
nodeStatus = self.getExpandStatus()
# if the node is a legal non-leaf node
if (nodeStatus == 3):
            # dict mapping each Pareto-optimal child node to its select-score vector
            paretoFront = {}
for childNode in self.childNodes:
childStatus = childNode.getExpandStatus()
                # only consider a child that is a legal leaf node (expandable) or a legal non-leaf node
                if childStatus == 2 or childStatus == 3:
selectScore = childNode.calcSelectScore()
paretoFront = updateParetoFront(paretoFront, childNode, selectScore, num_func)
            # randomly select a node on the Pareto front
            selected = rd.choice(list(paretoFront.keys()))
# return selected child node and status
return selected, selected.getExpandStatus()
        # if the node is not a legal non-leaf node, return it unchanged
return self, nodeStatus
def addChildNode(self, tokens, totalReward, prob=None):
""""
Adds a child node
"""
child = Node(config=self.config,
tokens=tokens,
parentNode=self,
childNodes=[],
totalReward=totalReward,
timestep=self.timestep+1,
sampleProb=prob)
self.childNodes.append(child)
return child
def updateNode(self, rewards):
"""
Updates the cumulative rewards vector with the reward vector at a descendent leaf node.
Increments the number of visits to the node.
"""
self.visits += 1
self.totalReward += rewards
    def calcSelectScore(self):
        """
        Calculates the select score for the node from the cumulative reward
        vector and the number of visits (PUCT-style).
        - config.mcts.sample_prob: constant that determines the degree of exploration
        """
        # K-dimensional vector of normalized rewards for each objective
        normRewards = self.totalReward / self.visits
        if self.sampleProb is not None and self.parentNode is not None:
            # exploration bonus weighted by the reverse-posterior sampling probability;
            # Node has no root attribute, so the parent's visit count is used here
            return normRewards + (self.config.mcts.sample_prob * self.sampleProb * np.sqrt(self.parentNode.visits) / self.visits)
        return normRewards
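    # Select score per objective j (a sketch of the formula implemented above):
    #   score_j = totalReward_j / visits + c * sampleProb * sqrt(parentNode.visits) / visits
    # where c = config.mcts.sample_prob controls the degree of exploration.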
def getExpandStatus(self):
"""
Returns an integer indicating whether the node is a:
1. terminal node (sequence is fully unmasked)
2. legal leaf node (partially unmasked sequence that can be expanded)
3. legal non-leaf node (already expanded sequence with M child nodes)
"""
if self.timestep == self.config.sampling.steps:
return 1
elif (self.timestep < self.config.sampling.steps) and (len(self.childNodes) == 0):
return 2
return 3
"""END OF NODE CLASS"""
def updateParetoFront(paretoFront, node, scoreVector, num_func):
"""
Removes sequences that are dominated by scoreVector
adds the SMILES sequence if it is non-dominated and its scoreVector
"""
paretoSize = len(paretoFront)
if paretoSize == 0:
# if pareto front is empty, add sequence and scoreVector
paretoFront[node] = scoreVector
else:
        # per-front-member booleans:
        # True: the new node is non-dominated by that Pareto-optimal node
        # False: the new node is dominated by that Pareto-optimal node
        nondominate = []
        # nodes to be deleted
        delete = []
for k, v in paretoFront.items():
nondominated = scoreVector >= np.asarray(v)
dominant = scoreVector > np.asarray(v)
            # consider only the first num_func objectives (slicing clamps to the
            # vector length, so every front member gets a nondominate entry)
            attn_nondominated = nondominated[:num_func]
            attn_dominant = dominant[:num_func]
            # all scores are greater than or equal to v and at least one score is strictly greater than v
            if attn_nondominated.all() and attn_dominant.any():
                # mark the dominated node for deletion
                delete.append(k)
                # the new node dominates this front member
                nondominate.append(True)
            elif attn_nondominated.all():
                # the new node is non-dominated by this front member
                nondominate.append(True)
            else:
                # the new node is dominated by this front member
                nondominate.append(False)
nondominate = np.asarray(nondominate)
# if sequence is either dominant or non-dominated by all sequences in pareto-front -> add to pareto front
if nondominate.all():
paretoFront[node] = scoreVector
        # delete all dominated nodes from the front
        for k in delete:
            del paretoFront[k]
return paretoFront
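# Dominance on toy score vectors (illustrative): with num_func = 2,
# [0.9, 0.5] dominates [0.8, 0.4] (>= on every objective, > on at least one), while
# [0.9, 0.5] and [0.4, 0.8] are mutually non-dominated, so both stay on the front.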
"""BEGINNING OF MCTS CLASS"""
class MCTS:
    def __init__(self, config, max_sequence_length=None, mdlm=None, score_func_names=None, prot_seqs=None, num_func=None):
self.config = config
self.noise = noise_schedule.get_noise(config)
self.time_conditioning = self.config.time_conditioning
# dictionary of k (SMILES string) and v (score vector) of Pareto-optimal sequences
self.peptideParetoFront = {}
self.num_steps = config.sampling.steps
self.num_sequences = config.sampling.num_sequences
# mdlm model
self.mdlm = mdlm
self.tokenizer = mdlm.tokenizer
self.device = mdlm.device
if max_sequence_length is None:
self.sequence_length = self.config.sampling.seq_length
else:
self.sequence_length = max_sequence_length
self.num_iter = config.mcts.num_iter
self.num_child = config.mcts.num_children
        # score functions (avoid mutable default arguments)
        self.score_functions = ScoringFunctions(score_func_names or [], prot_seqs)
        # K-dimensional vector with the iteration number at which to start conditioning
        # on each of the objectives, in increasing order
        self.num_func = num_func if num_func is not None else []
self.iter_num = 0
self.curr_num_func = 1
self.analyzer = PeptideAnalyzer()
# track fraction of valid peptides
self.valid_fraction_log = []
self.affinity1_log = []
self.affinity2_log = []
self.permeability_log = []
self.sol_log = []
self.hemo_log = []
self.nf_log = []
def reset(self):
self.iter_num = 0
self.valid_fraction_log = []
self.affinity1_log = []
self.affinity2_log = []
self.permeability_log = []
self.sol_log = []
self.hemo_log = []
self.nf_log = []
self.peptideParetoFront = {}
def forward(self, rootNode):
self.reset()
while (self.iter_num < self.num_iter):
self.iter_num += 1
            # traverse the tree from the root node until reaching a leaf node
            leafNode, _ = self.select(rootNode)
# expand leaf node into num_children partially unmasked sequences at the next timestep
self.expand(leafNode)
# return dictionary of pareto front peptides and their score vectors
return self.peptideParetoFront
    # TODO: include more sequences even if dominated, since there is error in the scores?
def updateParetoFront(self, sequence, scoreVector, tokens):
"""
Removes sequences that are dominated by scoreVector
adds the SMILES sequence if it is non-dominated and its scoreVector
num_func: index of the last objective to consider when updating the pareto front from 0 to K
"""
paretoSize = len(self.peptideParetoFront)
self.curr_num_func = 1
for i in range(len(self.num_func)):
if self.iter_num >= self.num_func[i]:
self.curr_num_func = i+1
if paretoSize == 0:
# if pareto front is empty, add sequence and scoreVector
self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
# if pareto front is empty, set reward vector to 1s
rewardVector = np.ones(len(scoreVector))
else:
            # per-front-member booleans:
            # True: the new sequence is non-dominated by that Pareto-optimal sequence
            # False: the new sequence is dominated by that Pareto-optimal sequence
            nondominate = []
            # sequences to be deleted
            delete = []
# initialize reward vector with zeros
rewardVector = np.zeros(len(scoreVector))
for k, v in self.peptideParetoFront.items():
                # boolean vector per objective:
                # True if the new score is equal or larger,
                # False where the Pareto-front sequence scores strictly higher
nondominated = scoreVector >= np.asarray(v['scores']) # [num_objectives]
dominant = scoreVector > np.asarray(v['scores'])
# add to reward vector
rewardVector += nondominated # [num_objectives]
                # consider only the first curr_num_func objectives (slicing clamps to the
                # vector length, so every front member gets a nondominate entry and the
                # assert below holds)
                attn_nondominated = nondominated[:self.curr_num_func]
                attn_dominant = dominant[:self.curr_num_func]
                # only delete a Pareto-optimal sequence if all scores are greater than
                # or equal to v and at least one score is strictly greater than v
                if attn_nondominated.all() and attn_dominant.any():
                    # mark the dominated sequence for deletion
                    delete.append(k)
                    # the new sequence dominates this front member
                    nondominate.append(True)
                elif attn_nondominated.all():
                    # the new sequence is non-dominated by this front member
                    nondominate.append(True)
                else:
                    # the new sequence is dominated by this front member
                    nondominate.append(False)
assert len(nondominate) == paretoSize
nondominate = np.asarray(nondominate)
            # add the sequence to the Pareto front if it is dominant or non-dominated
            # against every current front member, or if the front is not yet full
            if nondominate.all() or paretoSize < self.num_sequences:
                self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
                # normalize the reward by the pre-update Pareto front size
                rewardVector = rewardVector / paretoSize
            # delete dominated sequences only while the front exceeds num_sequences
            while (paretoSize > self.num_sequences) and (len(delete) > 0):
                del self.peptideParetoFront[delete[0]]
                del delete[0]
                paretoSize -= 1
return rewardVector
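    # Reward sketch (illustrative): rewardVector[j] is the fraction of current front
    # members whose objective-j score the new sequence matches or beats; e.g., with a
    # front of 4 sequences, 3 of them matched or beaten on objective 0 gives reward 3/4.
    # With num_func = [0, 50, 100], iterations 0-49 condition on objective 1 only,
    # 50-99 on objectives 1-2, and 100+ on all three.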
def isPathEnd(self, path, maxDepth):
"""
Checks if the node is completely unmasked (ie. end of path)
or if the path is at the max depth
"""
if (path[-1] != self.config.mcts.mask_token).all():
return True
elif len(path) >= maxDepth:
return True
return False
def select(self, currNode):
"""
Traverse the tree from the root node until reaching a legal leaf node
"""
while True:
currNode, nodeStatus = currNode.selectNode(self.curr_num_func)
if nodeStatus != 3:
return currNode, nodeStatus
def expand(self, parentNode, eps=1e-5, checkSimilarity = True):
"""
Sample unmasking steps from the pre-trained MDLM
adds num_children partially unmasked sequences to the children of the parentNode
"""
num_children = self.config.mcts.num_children
# initialize child rewards that will be added to total rewards
allChildReward = np.zeros_like(parentNode.totalReward) # (n_objectives)
        # compute the number of rollout steps;
        # if parentNode.timestep == self.num_steps - 1, then num_rollout_steps = 1
num_rollout_steps = self.num_steps - parentNode.timestep
# array of rollout timesteps from the timestep of parent node to 0
rollout_t = torch.linspace(1, eps, num_rollout_steps, device=self.device)
dt = (1 - eps) / self.num_steps
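        # worked example (assuming config.sampling.steps = 128 and eps = 1e-5):
        # dt = (1 - 1e-5) / 128 ≈ 7.8e-3, and rollout_t spans 1 -> eps over num_rollout_steps points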
p_x0_cache = None
# initialize x and attn_mask
x = parentNode.tokens['input_ids'].to(self.device)
attn_mask = parentNode.tokens['attention_mask'].to(self.device)
t = rollout_t[0] * torch.ones(num_children, 1, device = self.device)
        # generate a (num_children, seq_length) batch of sampled child token arrays
p_x0_cache, x_children = self.mdlm.batch_cached_reverse_step(token_array=x,
t=t, dt=dt,
batch_size=num_children,
attn_mask=attn_mask)
x_rollout = x_children
for i in range(1, num_rollout_steps):
t = rollout_t[i] * torch.ones(num_children, 1, device = self.device)
p_x0_cache, x_next = self.mdlm.cached_reverse_step(x=x_rollout,
t=t, dt=dt, p_x0=p_x0_cache,
attn_mask=attn_mask)
            # disable caching when the sequence changed from the previous rollout state
            # (compare against x_rollout, not the parent tokens x) or when time-conditioned
            if (not torch.allclose(x_next, x_rollout)) or self.time_conditioning:
                p_x0_cache = None
            x_rollout = x_next
        if self.config.sampling.noise_removal:
            t = rollout_t[-1] * torch.ones(x.shape[0], 1, device=self.device)
            # NOTE: an analytic-sampler branch (self.mdlm._denoiser_update) was left
            # commented out here in the original
            time_cond = self.noise(t)[0]
            x_rollout = self.mdlm.forward(x_rollout, attn_mask, time_cond).argmax(dim=-1)  # (num_children, seq_length)
childSequences = self.tokenizer.batch_decode(x_rollout)
validSequences = []
maskedTokens = []
unmaskedTokens = []
for i in range(num_children):
            childSeq = childSequences[i]
            rewardVector = np.zeros(self.config.mcts.num_objectives)
# check if the peptide is valid
if self.analyzer.is_peptide(childSeq):
validSequences.append(childSeq)
maskedTokens.append(x_children[i])
unmaskedTokens.append(x_rollout[i])
else:
childTokens = {'input_ids': x_children[i], 'attention_mask': attn_mask}
parentNode.addChildNode(tokens=childTokens,
totalReward=rewardVector)
        if len(validSequences) != 0:
            scoreVectors = self.score_functions(input_seqs=validSequences)
            # (num_objectives, num_valid): per-objective scores across the valid children
            scores_by_objective = scoreVectors.T
            if self.config.mcts.single:
                self.permeability_log.append(scores_by_objective[0])
            else:
                self.affinity1_log.append(scores_by_objective[0])
                self.sol_log.append(scores_by_objective[1])
                self.hemo_log.append(scores_by_objective[2])
                self.nf_log.append(scores_by_objective[3])
                if self.config.mcts.perm:
                    self.permeability_log.append(scores_by_objective[4])
                elif self.config.mcts.dual:
                    self.affinity2_log.append(scores_by_objective[4])
        else:
            # no valid children: log zeros so the logs stay aligned across iterations
            # (1-D arrays of length num_children, matching the per-objective arrays above)
            zero_scores = np.zeros(num_children)
            if self.config.mcts.single:
                self.permeability_log.append(zero_scores)
            else:
                self.affinity1_log.append(zero_scores)
                self.sol_log.append(zero_scores)
                self.hemo_log.append(zero_scores)
                self.nf_log.append(zero_scores)
                if self.config.mcts.perm:
                    self.permeability_log.append(zero_scores)
                elif self.config.mcts.dual:
                    self.affinity2_log.append(zero_scores)
        for i, validSeq in enumerate(validSequences):
            scoreVector = scoreVectors[i]
            # update the Pareto front and get the reward vector for this sequence
            rewardVector = self.updateParetoFront(validSeq, scoreVector, unmaskedTokens[i])
# add to all child reward vector for backprop
allChildReward += rewardVector
# create node for sequence and add to the children node of parent
childTokens = {'input_ids': maskedTokens[i], 'attention_mask': attn_mask}
parentNode.addChildNode(tokens=childTokens,
totalReward=rewardVector)
        # compute the fraction of invalid child sequences
        invalid = (num_children - len(validSequences)) / num_children
        valid_fraction = len(validSequences) / num_children
        print(f"Valid fraction: {valid_fraction}")
        self.valid_fraction_log.append(valid_fraction)
        # penalize the summed reward by the fraction of invalid sequences
        allChildReward = allChildReward - (self.config.mcts.invalid_penalty * invalid)
        # backpropagate the summed child rewards
        self.backprop(parentNode, allChildReward)
    def backprop(self, node, reward_vector):
        # backpropagate rewards along the path from the leaf node up to the root
while node:
node.updateNode(reward_vector)
node = node.parentNode
    def getSequenceForObjective(self, objective_index, k):
        """
        Returns the top-k sequences on the Pareto front with the best score for
        the given objective, along with their score vectors for all objectives.
        """
        # dictionary of the top-k peptides for the objective
        topk = {}
        peptides = []
        objectiveScores = []
        # note: the loop variable must not shadow the top-k parameter k
        for seq, v in self.peptideParetoFront.items():
            # store peptides in a list
            peptides.append(seq)
            # store the score for the objective (read from 'scores', not 'token_ids')
            objectiveScores.append(v['scores'][objective_index])
        objectiveScores = torch.tensor(objectiveScores)
        # torch.topk returns a (values, indices) named tuple
        _, topKIndices = torch.topk(objectiveScores, min(k, len(peptides)))
        for index in topKIndices:
            seq = peptides[index]
            topk[seq] = self.peptideParetoFront.get(seq)
        return topk
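# End-to-end usage sketch (illustrative; `cfg`, `mdlm`, and the score-function name
# are hypothetical stand-ins for a loaded config and pretrained masked diffusion model):
#   mcts = MCTS(cfg, mdlm=mdlm, score_func_names=['affinity'], num_func=[0])
#   masked = torch.full((1, cfg.sampling.seq_length), cfg.mcts.mask_token)
#   root = Node(cfg, tokens={'input_ids': masked, 'attention_mask': torch.ones_like(masked)}, timestep=0)
#   pareto_front = mcts.forward(root)  # dict: sequence -> {'scores', 'token_ids'}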