#ref: https://huggingface.co/blog/AmelieSchreiber/esmbind

import gradio as gr
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#import wandb
import numpy as np
import torch
import torch.nn as nn
import pickle
import xml.etree.ElementTree as ET
from datetime import datetime

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    matthews_corrcoef
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer
)
from peft import PeftModel
from datasets import Dataset
from accelerate import Accelerator

# Imports specific to the custom PEFT LoRA model
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

from plot_pdb import plot_struc
def suggest(option):
    if option == "Plastic degradation protein":
        suggestion = "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ"
    elif option == "Default protein":
        #suggestion = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
        suggestion = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"
    elif option == "Antifreeze protein":
        suggestion = "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH"
    elif option == "AI Generated protein":
        suggestion = "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS"
    elif option == "7-bladed propeller fold":
        suggestion = "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK"
    else:
        suggestion = ""

    return suggestion
# Helper Functions and Data Preparation
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]
def compute_metrics(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove padding (-100 labels)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()

    # Compute accuracy
    accuracy = accuracy_score(labels, predictions)

    # Compute precision, recall, F1 score, and AUC
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)

    # Compute MCC
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
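# Note on the metrics above: every score is computed only over real residues
# (positions whose label is not -100), and the AUC is derived from the hard
# argmax predictions rather than class probabilities, so it is a coarse estimate.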
def compute_loss(model, inputs):
    """Custom compute_loss function."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss
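# How the masking above works, in brief: positions where attention_mask == 0
# (padding) have their label replaced by CrossEntropyLoss's ignore_index (-100),
# so only real residues contribute to the class-weighted loss. For example, a
# batch row with attention_mask = [1, 1, 1, 0, 0] keeps its first three labels
# and ignores the last two. (Illustrative values, not taken from the dataset.)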
# Define Custom Trainer Class
# Because the dataset is imbalanced (far more non-binding residues than binding
# residues), we use class weights and therefore need a custom weighted Trainer.
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = compute_loss(model, inputs)
        return (loss, outputs) if return_outputs else loss
# Predict binding sites with a fine-tuned PEFT model
def predict_bind(base_model_path, PEFT_model_path, input_seq):
    # Load the base model and wrap it with the LoRA adapter
    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
    loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)

    # Ensure the model is in evaluation mode
    loaded_model.eval()

    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    # Tokenize the sequence
    inputs = tokenizer(input_seq, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

    # Run the model
    with torch.no_grad():
        logits = loaded_model(**inputs).logits

    # Get predictions
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
    predictions = torch.argmax(logits, dim=2)

    binding_site = []
    pos = 0
    # Print the predicted label for each token and collect the predicted binding sites
    for token, prediction in zip(tokens, predictions[0].numpy()):
        if token not in ['<pad>', '<cls>', '<eos>']:
            pos += 1
            print((pos, token, id2label[prediction]))
            if prediction == 1:
                print((pos, token, id2label[prediction]))
                binding_site.append([pos, token, id2label[prediction]])

    return binding_site
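# Minimal usage sketch (not executed here; model names are taken from the
# option lists defined below, and the sequence is abbreviated for illustration):
#   predict_bind("facebook/esm2_t12_35M_UR50D",
#                "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
#                "MAVPETRPNHTIYINNLNEKIKKDELKK...")
# returns a list of [position, residue, label] entries for predicted binding sites.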
# Fine-tuning function
def train_function_no_sweeps(base_model_path):  #, train_dataset, test_dataset):
    # Set the LoRA config
    config = {
        "lora_alpha": 1,  # try 0.5, 1, 2, ..., 16
        "lora_dropout": 0.2,
        "lr": 5.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 1,  #3, jw 20240628
        "per_device_train_batch_size": 12,
        "r": 2,
        "weight_decay": 0.2,
        # Add other hyperparameters as needed
    }
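    # Note (general LoRA behaviour, not something configured in this repo): with
    # PEFT's standard scaling the adapter update is multiplied by lora_alpha / r,
    # so the values above (lora_alpha=1, r=2) give an effective scaling of 0.5;
    # raising lora_alpha or lowering r strengthens the adapter's contribution.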
    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)

    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)  #("facebook/esm2_t12_35M_UR50D")
    train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

    train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
    test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=["query", "key", "value"],  # also try "dense_h_to_4h" and "dense_4h_to_h"
        lora_dropout=config["lora_dropout"],
        bias="none"  # or "all" or "lora_only"
    )
    base_model = get_peft_model(base_model, peft_config)
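    # Optional sanity check (standard PEFT utility; left commented out so behaviour is unchanged):
    # base_model.print_trainable_parameters()  # prints how few parameters LoRA leaves trainable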
    # Use the accelerator
    base_model = accelerator.prepare(base_model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)

    model_name_base = base_model_path.split("/")[1]
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    save_path = f"{model_name_base}-lora-binding-sites_{timestamp}"

    # Training setup
    training_args = TrainingArguments(
        output_dir=save_path,  #f"{model_name_base}-lora-binding-sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=True,  #jw 20240701 False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=7,
        no_cuda=False,
        seed=8893,
        fp16=True,
        #report_to='wandb'
        report_to="none",  # "none" disables reporting integrations; None would report to all installed ones
        hub_token=HF_TOKEN,  #jw 20240701
    )

    # Initialize Trainer
    trainer = WeightedTrainer(
        model=base_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    # Train and save the model (push_to_hub handles uploading checkpoints)
    trainer.train()
    return save_path
# Constants & Globals
HF_TOKEN = os.environ.get("HF_token")
print("HF_TOKEN is set:", HF_TOKEN is not None)  # avoid printing the secret itself

MODEL_OPTIONS = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]  # models users can choose from

PEFT_MODEL_OPTIONS = [
    "wangjin2000/esm2_t6_8M-lora-binding-sites_2024-07-02_09-26-54",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
]  # fine-tuned models

# Load the data from pickle files (replace with your local paths)
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("./datasets/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)
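# Expected pickle contents (inferred from how they are used below, not verified
# against the files themselves): each *_sequences file holds a list of protein
# sequence strings, and each *_labels file holds a matching list of per-residue
# 0/1 lists (1 = binding residue).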
max_sequence_length = 1000

# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Compute Class Weights
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)

accelerator = Accelerator()
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
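# For reference, sklearn's 'balanced' heuristic assigns each class the weight
# n_samples / (n_classes * n_c). As a purely illustrative example, if 90% of
# residues were non-binding, class 0 would get a weight of about 0.56 and
# class 1 about 5.0, so the rare binding residues count more in the loss.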
# Define labels and model
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

'''
# debug result
dubug_result = saved_path  #predictions #class_weights
'''
demo = gr.Blocks(title="DEMO FOR ESM2Bind")

with demo:
    gr.Markdown("# DEMO FOR ESM2Bind")
    #gr.Textbox(dubug_result)

    with gr.Column():
        gr.Markdown("## Select a base model and a corresponding PEFT fine-tuned model")
        with gr.Row():
            with gr.Column(scale=5, variant="compact"):
                base_model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Base Model Name",
                    interactive=True,
                )
                PEFT_model_name = gr.Dropdown(
                    choices=PEFT_MODEL_OPTIONS,
                    value=PEFT_MODEL_OPTIONS[0],
                    label="PEFT Model Name",
                    interactive=True,
                )
            with gr.Column(scale=5, variant="compact"):
                name = gr.Dropdown(
                    label="Choose a Sample Protein",
                    value="Default protein",
                    choices=["Default protein", "Antifreeze protein", "Plastic degradation protein", "AI Generated protein", "7-bladed propeller fold", "custom"]
                )

    gr.Markdown(
        "## Predict binding sites and plot the structure for the selected protein sequence:"
    )
    with gr.Row():
        with gr.Column(variant="compact", scale=8):
            input_seq = gr.Textbox(
                lines=1,
                max_lines=12,
                label="Protein sequence to be predicted:",
                value="MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT",
                placeholder="Paste your protein sequence here...",
                interactive=True,
            )
            text_pos = gr.Textbox(
                lines=1,
                max_lines=12,
                label="Sequence position:",
                placeholder=
                "012345678911234567892123456789312345678941234567895123456789612345678971234567898123456789912345678901234567891123456789",
                interactive=False,
            )
        with gr.Column(variant="compact", scale=2):
            predict_btn = gr.Button(
                value="Predict binding site",
                interactive=True,
                variant="primary",
            )
            plot_struc_btn = gr.Button(value="Plot ESMFold Predicted Structure", variant="primary")
    with gr.Row():
        with gr.Column(variant="compact", scale=5):
            output_text = gr.Textbox(
                lines=1,
                max_lines=12,
                label="Output",
                placeholder="Output",
            )
        with gr.Column(variant="compact", scale=5):
            finetune_button = gr.Button(
                value="Finetune Pre-trained Model",
                interactive=True,
                variant="primary",
            )

    with gr.Row():
        output_viewer = gr.HTML()
        output_file = gr.File(
            label="Download as Text File",
            file_count="single",
            type="filepath",
            interactive=False,
        )
    # Select a protein sample
    name.change(fn=suggest, inputs=name, outputs=input_seq)

    # "Predict binding site" actions
    predict_btn.click(
        fn=predict_bind,
        inputs=[base_model_name, PEFT_model_name, input_seq],
        outputs=[output_text],
    )

    # "Finetune Pre-trained Model" actions
    finetune_button.click(
        fn=train_function_no_sweeps,
        inputs=[base_model_name],
        outputs=[output_text],
    )

    # Plot protein structure
    plot_struc_btn.click(fn=plot_struc, inputs=input_seq, outputs=[output_file, output_viewer])

demo.launch()