import os
import math

import torch
import torch.nn.functional as F

from src.utils.model_utils import _print
from src.guidance.solubility_module import SolubilityClassifier
from src.sampling.unconditional_sampler import UnconditionalSampler


class GuidedSampler:
    def __init__(self, config, esm_model, tokenizer, diffusion, device):
        self.config = config
        self.device = device

        self.esm = esm_model
        self.memdlm = diffusion
        self.tokenizer = tokenizer
        self.uncond_generator = UnconditionalSampler(self.tokenizer, self.memdlm)

        # Load the trained solubility classifier used for guidance.
        ckpt_path = os.path.join("/home/a03-sgoel/MeMDLM_v2/checkpoints", config.wandb.name, "best_model.ckpt")
        self.classifier_model = SolubilityClassifier(config)
        state_dict = self.classifier_model.get_state_dict(ckpt_path)
        self.classifier_model.load_state_dict(state_dict)
        self.classifier_model.eval().to(self.device)

        # Guidance hyperparameters.
        self.top_p = self.config.guidance.top_p                  # nucleus cutoff for attention neighborhoods
        self.alpha = self.config.guidance.alpha                  # scales the saliency weight inside the sigmoid gate
        self.gamma = self.config.guidance.gamma                  # weight of the neighborhood saliency contribution
        self.saliency_eps = self.config.guidance.saliency_eps    # lower clamp on saliency values
        self.saliency_t = self.config.guidance.saliency_t        # saliency temperature (exponent 1/t)
        self.sampling_t = self.config.guidance.sampling_t        # sampling temperature for the LM logits
        self.boltzmann_t = self.config.guidance.boltzmann_t      # temperature of the Boltzmann prior

    def embed_sequence(self, input_ids, attention_masks):
        """Embed a sequence with ESM and return the final hidden states and attention maps."""
        with torch.no_grad():
            outs = self.esm(
                input_ids=input_ids,
                attention_mask=attention_masks,
                output_hidden_states=True,
                output_attentions=True
            )
        embeds = outs.hidden_states[-1]   # final-layer residue embeddings
        attn_matrix = outs.attentions     # per-layer attention maps
        return embeds, attn_matrix

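    # Note on the sampler below: adding Gumbel(0, 1) noise g_i = -log(-log(u_i)) to
    # temperature-scaled logits and taking the argmax is the Gumbel-max trick, i.e.
    #   argmax_i (logits_i / T + g_i)  ~  Categorical(softmax(logits / T)).
    # With noise_scale = 0 the argmax reduces to deterministic (greedy) decoding.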
    def sample_from_categorical(self, logits, temperature, noise_scale=1.0):
        """Sample tokens from temperature-scaled logits via the Gumbel-max trick."""
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
        logits = (logits / temperature) + (noise_scale * gumbel_noise)

        log_probs = F.log_softmax(logits, dim=-1)
        _, tokens = log_probs.max(dim=-1)

        return tokens, log_probs

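    # denoise_sequence runs in two stages: (1) if the input still contains mask
    # tokens, it is fully denoised with the unconditional sampler to obtain the
    # prior logits; otherwise the diffusion LM is queried directly. (2) A single
    # extra denoising step then yields the logits from which x0 and its LM log
    # probs are drawn. Both stages share the same sampling temperature.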
    def denoise_sequence(self, input_ids, attn_masks):
        """Compute the current and prior sequences' log-prob distributions."""
        has_masks = (input_ids == self.tokenizer.mask_token_id).any()

        # Obtain a fully unmasked sequence and its prior logits.
        if has_masks:
            xt_prior, logits_prior = self.uncond_generator.sample_unconditional(
                xt=input_ids,
                num_steps=self.config.guidance.n_steps,
                tau=self.sampling_t,
                return_logits=True
            )
        else:
            xt_prior = input_ids
            logits_prior = self.memdlm(input_ids=input_ids, attention_mask=attn_masks)

        # One additional denoising step to get the current LM logits.
        _, logits = self.uncond_generator.sample_unconditional(
            xt=xt_prior,
            num_steps=1,
            tau=self.sampling_t,
            return_logits=True
        )

        x0, logp_lm = self.sample_from_categorical(logits, temperature=self.sampling_t)

        return x0.squeeze(), logp_lm.squeeze(), logits_prior

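    # Boltzmann prior used below (when config.guidance.prior == "boltzmann"):
    #   p_prior(v | i)  is proportional to  p_LM(v | i) * exp(s_i * b_v / boltzmann_t)
    # where p_LM is the temperature-scaled LM distribution, s_i = sigmoid(solubility
    # logit at position i), and b_v is +1 for hydrophilic residues, -1 for hydrophobic
    # residues, and 0 for all other vocabulary tokens.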
    def get_prior(self, logits_prior, solubility_logits):
        """Build the prior log-prob distribution used for guidance."""
        if self.config.guidance.prior == "boltzmann":
            # Bias residue identities by hydropathy: hydrophilic residues are
            # up-weighted, hydrophobic residues down-weighted, special tokens unchanged.
            hydrophilic = ["D", "E", "K", "R", "N", "Q", "H", "S", "T", "Y"]
            hydrophobic = ["L", "I", "V", "F", "W", "M", "A", "C", "G", "P"]
            amino_acids = hydrophilic + hydrophobic

            tokens = list(self.tokenizer.get_vocab().keys())
            other = [tok for tok in tokens if tok not in amino_acids]

            hydrophilic_idxs = [self.tokenizer.convert_tokens_to_ids(aa) for aa in hydrophilic]
            hydrophobic_idxs = [self.tokenizer.convert_tokens_to_ids(aa) for aa in hydrophobic]
            other_idxs = [self.tokenizer.convert_tokens_to_ids(tok) for tok in other]

            bias = torch.zeros(len(tokens), device=self.device)
            bias[hydrophilic_idxs] = 1.0
            bias[hydrophobic_idxs] = -1.0
            bias[other_idxs] = 0.0

            # Scale the per-token bias by the predicted solubility of each position.
            sol_scores = torch.sigmoid(solubility_logits)
            token_bias = sol_scores.unsqueeze(-1) * bias

            # Reweight the LM distribution with a Boltzmann factor and renormalize.
            lm_probs = F.softmax(logits_prior / self.sampling_t, dim=-1)
            boltz_weight = torch.exp(token_bias / self.boltzmann_t)

            p_prior = lm_probs * boltz_weight
            p_prior = p_prior / p_prior.sum(dim=-1, keepdim=True)
            logp_prior = torch.log(p_prior)

        elif self.config.guidance.prior == "lm_probs":
            _, logp_prior = self.sample_from_categorical(logits_prior, temperature=self.sampling_t)

        else:
            raise ValueError(f"Unknown guidance prior: {self.config.guidance.prior}")

        return logp_prior.squeeze()

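    # Saliency, following LaMBO-2 Eq. 5: per-position scores are the summed absolute
    # gradients of the classifier output w.r.t. the input embeddings,
    #   s_i = ( sum_d | d f(e) / d e_{i,d} | )^(1 / saliency_t),
    # then clamped at saliency_eps and min-max normalized to [0, 1].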
    def compute_saliency_map(self, embeds, solubility_logits):
        """
        Compute a saliency map as in LaMBO-2 (https://arxiv.org/abs/2305.20009), Eq. 5.
        """
        # Backpropagate the summed solubility logits to the input embeddings.
        solubility_logits.sum().backward(retain_graph=True)
        grads = embeds.grad.abs().sum(dim=-1)

        # Temperature-sharpened, clamped, and min-max normalized saliency scores.
        saliency = grads.pow(1.0 / self.saliency_t).clamp(min=self.saliency_eps).to(self.device)
        saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-6)
        return saliency.squeeze()

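    # Edit-position selection: everything starts editable, then (a) soluble scaffold
    # residues (given, or predicted by the classifier when none are given) and
    # (b) the top ~10% highest-saliency residues are frozen. Only the remaining
    # positions are passed downstream for guided edits.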
    def determine_edit_positions(self, saliency_map, soluble_indices, solubility_logits):
        """
        Select editable positions. Soluble scaffold residues and the highest-saliency
        residues are kept fixed to maintain a membrane-like protein structure; the
        remaining (insoluble) positions are eligible for edits.
        """
        seq_len = saliency_map.shape[0]

        # Start with every position editable.
        edit_mask = torch.ones(seq_len, dtype=torch.bool, device=self.device)

        # Freeze known soluble residues; if none are provided, fall back to
        # the classifier's per-residue solubility predictions.
        if soluble_indices is not None and len(soluble_indices) > 0:
            edit_mask[soluble_indices] = False
        else:
            solubility_preds = torch.sigmoid(solubility_logits)
            edit_mask[solubility_preds > 0.5] = False

        # Also freeze the top ~10% most salient positions.
        num_conserved = max(1, int(0.1 * edit_mask.sum()))
        _, topk_idxs = torch.topk(saliency_map, num_conserved)
        edit_mask[topk_idxs] = False

        edit_idxs = edit_mask.nonzero(as_tuple=True)[0]
        return edit_idxs

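    # Neighborhood construction: the attention row of the edit position (with BOS,
    # EOS, and the position itself masked out) is sharpened with a length-dependent
    # temperature 1 / log(L), and the smallest set of tokens whose cumulative
    # probability reaches top_p is kept, i.e. nucleus (top-p) selection over
    # attention scores rather than over token logits.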
    def create_neighborhood(self, edit_pos, attn_matrix, top_p):
        """
        Select a dynamic "neighborhood" of tokens for the edit position via top-p sampling.
        Attention scores identify relevant tokens, avoiding blind updates of the individual token.
        """
        # Exclude the BOS token, the edit position itself, and the EOS token.
        row = attn_matrix[edit_pos].clone().squeeze()
        row = row.index_fill(
            dim=0,
            index=torch.tensor([0, edit_pos, row.size(0) - 1], device=row.device),
            value=float('-inf')
        )

        # Length-dependent temperature sharpens the attention distribution.
        temp = 1.0 / math.log(row.size(0))
        attn_probs = F.softmax(row / temp, dim=0)
        sorted_probs, sorted_idxs = torch.sort(attn_probs, descending=True)
        cum_probs = sorted_probs.cumsum(dim=0)
        cutoff = (cum_probs <= top_p).nonzero(as_tuple=True)[0]

        # Keep at least one neighbor even if the first token already exceeds top_p.
        final_idx = cutoff[-1].item() + 1 if cutoff.numel() > 0 else 1
        neighborhood = sorted_idxs[:final_idx]
        return neighborhood

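    # Context-aware saliency used below:
    #   s'_i = s_i + gamma * sum_{j in N(i)} a_ij * s_j
    # where a_ij are the edit position's attention weights renormalized over its
    # neighborhood N(i) and s_j are the per-position saliency scores.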
    def compute_saliency_weight(self, edit_pos, attn_mat, saliency_map, neighborhood):
        """
        Blend the saliency of the neighborhood's tokens with that of the token at the edit position.
        """
        # Renormalize the attention the edit position pays to its neighborhood.
        neighborhood_attns = attn_mat[edit_pos, neighborhood]
        neighborhood_attns /= neighborhood_attns.sum()

        neighborhood_saliencies = saliency_map[neighborhood]

        # Attention-weighted neighborhood saliency, added to the edit position's own saliency.
        neighborhood_weight = torch.sum(neighborhood_attns * neighborhood_saliencies)
        ctxt_aware_saliency = saliency_map[edit_pos] + (self.gamma * neighborhood_weight)

        return ctxt_aware_saliency

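    # Guidance mixture used below: with gate w = sigmoid(alpha * s'_i),
    #   p_guided = (1 - w) * p_LM + w * p_prior,
    # so highly salient positions lean more on the solubility-informed prior and
    # low-saliency positions stay close to the language model.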
    def compute_guidance_dist(self, logp_lm, logp_prior, saliency_weight):
        """
        Define a guidance distribution that interpolates between the prior and the current LM probs.
        Return the log probs of the "new" (optimized) token distribution.
        """
        # Saliency-gated interpolation weight in (0, 1).
        w = torch.sigmoid(saliency_weight * self.alpha)
        p_lm = torch.exp(logp_lm)
        p_prior = torch.exp(logp_prior)
        mixed_probs = (1 - w) * p_lm + w * p_prior
        guidance_dist = torch.log(mixed_probs + 1e-12)
        return guidance_dist

    def check_scaffold(self, seq1, seq2, idxs):
        """Log whether any residues at the given (soluble scaffold) indices were changed."""
        changed = (seq1[idxs] != seq2[idxs])
        if changed.any():
            _print('soluble residues changed')
        else:
            _print('no soluble residue changes')

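    # optimize_sequence ties the pieces together: denoise -> embed -> classify ->
    # saliency -> choose edit positions -> per-position guidance mixture ->
    # resample -> restore the soluble scaffold.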
    def optimize_sequence(self, input_ids, attn_masks, soluble_indices):
        """Run one round of guided optimization and return the edited token sequence."""
        _print(f'soluble idx: {soluble_indices}')

        # Denoise to a full sequence and get LM / prior log probs.
        x0, logp_lm, logits_prior = self.denoise_sequence(input_ids, attn_masks)
        _print(f'og tokens: {x0}')
        _print(f'og token shape: {x0.shape}')
        _print(f'og log probs: {logp_lm.shape}')

        # Embed the denoised sequence; keep gradients on the embeddings for saliency.
        embeds, attn_mats = self.embed_sequence(x0.unsqueeze(0), attn_masks)
        embeds = embeds.detach().clone().requires_grad_(True)
        attn_matrix = attn_mats[-1].mean(dim=1)[0].squeeze(0)  # last layer, head-averaged

        # Per-residue solubility predictions from the classifier.
        batch = {"embeds": embeds, "attention_mask": attn_masks}
        solubility_logits = self.classifier_model(batch)

        # Saliency map and editable positions.
        saliency_map = self.compute_saliency_map(embeds, solubility_logits)
        _print(f'saliency map: {saliency_map}')
        edit_positions = self.determine_edit_positions(saliency_map, soluble_indices, solubility_logits)
        _print(f'edit positions: {edit_positions}')

        # Prior distribution used for guidance.
        logp_prior = self.get_prior(logits_prior, solubility_logits)
        _print(f'prior log probs: {logp_prior.shape}')

        # Update the LM distribution at each editable position.
        for edit_pos in edit_positions.tolist():
            neighborhood = self.create_neighborhood(
                edit_pos,
                attn_matrix,
                self.top_p
            )
            _print(f'neighborhood: {neighborhood}')

            ctxt_aware_saliency = self.compute_saliency_weight(
                edit_pos,
                attn_matrix,
                saliency_map,
                neighborhood
            )
            _print(f'ctx aware saliency: {ctxt_aware_saliency}')

            logp_lm_prime = self.compute_guidance_dist(
                logp_lm[edit_pos],
                logp_prior[edit_pos],
                ctxt_aware_saliency
            )
            logp_lm[edit_pos] = logp_lm_prime

            # Sanity check: the mixed distribution must still sum to 1.
            tot = torch.exp(logp_lm_prime).sum()
            one = torch.tensor(1.0, dtype=tot.dtype, device=tot.device)
            assert torch.isclose(tot, one, atol=1e-4), f"Invalid prob distribution. Sum = {tot:.5f}"

        # Sample the optimized sequence from the guided distribution.
        x0_prime = torch.distributions.Categorical(logits=logp_lm).sample()

        self.check_scaffold(x0, x0_prime, soluble_indices)

        # Restore the original residues at soluble scaffold positions.
        if soluble_indices is not None and len(soluble_indices) > 0:
            x0_prime[soluble_indices] = x0[soluble_indices]
        self.check_scaffold(x0, x0_prime, soluble_indices)

        return x0_prime
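

# Minimal usage sketch (illustrative only). The ESM checkpoint name, config path,
# and MeMDLM loading shown here are assumptions about the surrounding project,
# not part of this module.
#
#     from omegaconf import OmegaConf
#     from transformers import AutoModel, AutoTokenizer
#
#     config = OmegaConf.load("configs/guidance.yaml")                      # hypothetical path
#     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#     esm = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to("cuda")
#     memdlm = ...  # trained MeMDLM diffusion model
#
#     sampler = GuidedSampler(config, esm, tokenizer, memdlm, device="cuda")
#     batch = tokenizer(sequence, return_tensors="pt")
#     x0_prime = sampler.optimize_sequence(
#         batch["input_ids"].to("cuda"),
#         batch["attention_mask"].to("cuda"),
#         soluble_indices,
#     )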