PepTune / main.py
Yinuo Zhang
upload data
d65f3a2
#!/usr/bin/env python
import os
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'
import uuid
import wandb
import fsspec
import hydra
import lightning as L
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, GradientAccumulationScheduler
import omegaconf
import rich.syntax
import rich.tree
import torch
import sys
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from . import dataset as dataloader
from . import dataloading_for_dynamic_batching as dynamic_dataloader
from .diffusion import Diffusion
from .utils import utils
from .new_tokenizer.ape_tokenizer import APETokenizer
from .tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from .helm_tokenizer.helm_tokenizer import HelmTokenizer
from lightning.pytorch.strategies import DDPStrategy
from datasets import load_dataset
# Custom OmegaConf interpolation resolvers made available to the Hydra configs:
#   cwd          -> current working directory (os.getcwd)
#   device_count -> number of visible CUDA devices (torch.cuda.device_count)
#   eval         -> Python's builtin eval applied to the argument
#   div_up       -> (x + y - 1) // y, i.e. ceil(x / y) for positive integers
#   env_or       -> os.getenv(k, d): environment variable k, or d if unset
# NOTE(review): the 'eval' resolver executes arbitrary Python taken from config
# values / command-line overrides; only run with trusted configs.
for _resolver_name, _resolver_fn in (
        ('cwd', os.getcwd),
        ('device_count', torch.cuda.device_count),
        ('eval', eval),
        ('div_up', lambda x, y: (x + y - 1) // y),
        ("env_or", lambda k, d: os.getenv(k, d)),
):
    omegaconf.OmegaConf.register_new_resolver(_resolver_name, _resolver_fn)
# Keep the module namespace free of loop temporaries.
del _resolver_name, _resolver_fn
def _load_from_checkpoint(config, tokenizer):
    """Build the ``Diffusion`` model used for evaluation and sampling.

    If ``config.backbone`` contains ``"hf"``, a fresh ``Diffusion`` is built
    from the config and no Lightning checkpoint is read. Otherwise the
    checkpoint at ``config.eval.checkpoint_path`` is loaded through
    ``Diffusion.load_from_checkpoint``.

    In both branches the model targets CUDA when it is available and CPU
    otherwise. Previously the "hf" branch always used CUDA.

    Args:
        config: Hydra/OmegaConf run configuration.
        tokenizer: Tokenizer forwarded to ``Diffusion``.

    Returns:
        The ``Diffusion`` model.

    Raises:
        ValueError: If a checkpoint is required but
            ``config.eval.checkpoint_path`` is empty/None.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Backbones whose name contains "hf" are constructed directly from the
    # config instead of being restored from a Lightning checkpoint.
    if "hf" in str(config.get("backbone", "")):
        return Diffusion(config, tokenizer=tokenizer).to(device)

    ckpt_path = config.eval.checkpoint_path
    if not ckpt_path:
        # Fail early with an actionable message instead of an opaque error
        # from inside Lightning's checkpoint loader.
        raise ValueError(
            "config.eval.checkpoint_path must be set to load a Diffusion "
            "checkpoint for evaluation/sampling.")

    return Diffusion.load_from_checkpoint(
        ckpt_path,
        tokenizer=tokenizer,
        config=config,
        map_location=device,
    )
@L.pytorch.utilities.rank_zero_only
def print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True,
        save_cfg: bool = True) -> None:
    """Render the config as a Rich tree, print it, and optionally save it.

    Wrapped with Lightning's ``rank_zero_only`` so it only runs on rank zero.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Resolve interpolations when rendering nested
            DictConfig sections as YAML.
        save_cfg (bool): Also write the rendered tree to
            ``<config.checkpointing.save_dir>/config_tree.txt``.
    """
    dim = 'dim'
    cfg_tree = rich.tree.Tree('CONFIG', style=dim, guide_style=dim)

    for section_name in config.keys():
        node = cfg_tree.add(section_name, style=dim, guide_style=dim)
        section = config.get(section_name)
        # Nested DictConfig sections are dumped as YAML; everything else
        # falls back to its plain string form.
        if isinstance(section, omegaconf.DictConfig):
            rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)
        node.add(rich.syntax.Syntax(rendered, 'yaml'))

    rich.print(cfg_tree)

    if not save_cfg:
        return
    out_path = f'{config.checkpointing.save_dir}/config_tree.txt'
    with fsspec.open(out_path, 'w') as handle:
        rich.print(cfg_tree, file=handle)
@L.pytorch.utilities.rank_zero_only
def print_batch(train_ds, valid_ds, tokenizer, k=64):
    """Print the head and tail of the first sequence in one training batch.

    Pulls a single batch from ``train_ds``. It prints the shape of
    ``batch['input_ids']``, then the decoded text and raw ids of the first
    and last ``k`` tokens of row 0.

    Args:
        train_ds: Iterable of batches; each batch is indexed with
            ``'input_ids'`` and that value is sliced 2-D (``[0, :k]``).
        valid_ds: Accepted for API compatibility; currently not inspected.
        tokenizer: Object exposing ``decode``.
        k: Number of tokens to show from each end.
    """
    inspected = (('train', train_ds),)
    for split_name, loader in inspected:
        print(f'Printing {split_name} dataloader batch.')
        input_ids = next(iter(loader))['input_ids']
        print('Batch input_ids.shape', input_ids.shape)
        head = input_ids[0, :k]
        tail = input_ids[0, -k:]
        print(f'First {k} tokens:', tokenizer.decode(head))
        print('ids:', head)
        print(f'Last {k} tokens:', tokenizer.decode(tail))
        print('ids:', tail)
def generate_samples(config, logger, tokenizer):
    """Sample peptide sequences from a checkpointed model.

    Runs ``config.sampling.num_sample_batches`` rounds of
    ``model.restore_model_and_sample(num_steps=config.sampling.steps)``.
    Each round's samples are decoded and fed to
    ``model.compute_generative_perplexity``. The decoded sequences of the
    FINAL round are printed and returned; earlier rounds only contribute
    through the perplexity call.

    Args:
        config: Run configuration (reads ``config.sampling.*`` and whatever
            ``_load_from_checkpoint`` needs).
        logger: Logger with an ``info`` method.
        tokenizer: Tokenizer forwarded to the model loader.

    Returns:
        The decoded sequences from the last sampling round.

    Raises:
        ValueError: If ``config.sampling.num_sample_batches`` is < 1. The
            previous code hit an unbound local in that case, only after
            loading the model.
    """
    num_batches = config.sampling.num_sample_batches
    if num_batches < 1:
        raise ValueError(
            'config.sampling.num_sample_batches must be >= 1, '
            f'got {num_batches}')

    logger.info('Generating samples.')
    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)

    peptide_sequences = []
    for _ in range(num_batches):
        samples = model.restore_model_and_sample(
            num_steps=config.sampling.steps)
        peptide_sequences = model.tokenizer.batch_decode(samples)
        model.compute_generative_perplexity(peptide_sequences)

    print('Peptide samples:', peptide_sequences)
    print('Generative perplexity:', model.compute_masked_perplexity())
    return peptide_sequences
def ppl_eval(config, logger, tokenizer, data_module):
    """Evaluate a checkpointed model on ``data_module`` with ``Trainer.test``.

    The trainer is instantiated from ``config.trainer`` with a DDP strategy
    (``find_unused_parameters=True``). It uses an optional W&B logger (when
    ``config.wandb`` is set) and every entry of ``config.callbacks``
    instantiated through Hydra.

    Args:
        config: Hydra/OmegaConf run configuration.
        logger: Logger with an ``info`` method.
        tokenizer: Tokenizer forwarded to the model loader.
        data_module: Lightning data module passed to ``trainer.test``.
    """
    logger.info('Starting Zero Shot Eval.')
    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)

    wandb_logger = (
        L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)
        if config.get('wandb', None) is not None
        else None
    )

    callbacks = (
        [hydra.utils.instantiate(cb_cfg)
         for _, cb_cfg in config.callbacks.items()]
        if 'callbacks' in config
        else []
    )

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        strategy=DDPStrategy(find_unused_parameters=True),
        logger=wandb_logger,
    )
    trainer.test(model, data_module)
def _train(config, logger, tokenizer, data_module):
    """Train (or resume training of) a ``Diffusion`` model with Lightning.

    Steps:
      1. If ``config.wandb`` is set, build a W&B logger. The run id gets a
         random UUID suffix, so each launch has a distinct id.
      2. Resume from ``config.checkpointing.resume_ckpt_path`` when
         ``resume_from_ckpt`` is true and that path exists.
      3. Build callbacks from ``config.callbacks``. The ``model_checkpoint``
         entry is constructed as ``ModelCheckpoint`` directly, with
         ``_target_`` dropped. When ``config.training.accumulator`` is set,
         a fixed gradient-accumulation schedule is added.
      4. Instantiate the trainer from ``config.trainer`` with a DDP strategy.
         Accelerator and devices now come from ``config.trainer`` (or
         Lightning's defaults). Previously they were pinned to
         ``accelerator='cuda'`` and GPU indices 2-7.
      5. For the ``finetune_roformer`` backbone, warm-start non-strictly
         from ``config.eval.checkpoint_path`` when it is set.
      6. Run ``trainer.fit``.

    Args:
        config: Hydra/OmegaConf run configuration. ``config.wandb.id`` is
            rewritten in place when W&B is enabled.
        logger: Logger with an ``info`` method.
        tokenizer: Tokenizer forwarded to ``Diffusion``.
        data_module: Lightning data module passed to ``trainer.fit``.
    """
    logger.info('Starting Training.')

    wandb_logger = None
    if config.get('wandb', None) is not None:
        # Suffix a UUID so each launch gets its own run id. Avoid producing a
        # literal "None_<uuid>" when no base id is configured.
        base_id = config.wandb.get('id')
        suffix = str(uuid.uuid4())
        config.wandb.id = f'{base_id}_{suffix}' if base_id else suffix
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    if (config.checkpointing.resume_from_ckpt
            and config.checkpointing.resume_ckpt_path is not None
            and utils.fsspec_exists(config.checkpointing.resume_ckpt_path)):
        ckpt_path = config.checkpointing.resume_ckpt_path
    else:
        ckpt_path = None

    # Lightning callbacks
    callbacks = []
    if 'callbacks' in config:
        for callback_name, callback_config in config.callbacks.items():
            if callback_name == 'model_checkpoint':
                # Build ModelCheckpoint directly from its kwargs (everything
                # except Hydra's `_target_`).
                model_checkpoint_config = {
                    k: v for k, v in callback_config.items() if k != '_target_'
                }
                callbacks.append(ModelCheckpoint(**model_checkpoint_config))
            else:
                callbacks.append(hydra.utils.instantiate(callback_config))

    if config.training.accumulator:
        # Schedule maps epoch -> accumulation factor (Lightning format).
        callbacks.append(
            GradientAccumulationScheduler(scheduling={1: 5, 2: 4, 3: 3, 4: 1}))

    # accelerator/devices are intentionally NOT overridden here, so
    # config.trainer (or Lightning's defaults) decides where to run.
    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        strategy=DDPStrategy(find_unused_parameters=True),
        logger=wandb_logger)

    model = Diffusion(config, tokenizer=tokenizer)

    init_path = config.eval.checkpoint_path
    if config.backbone == "finetune_roformer" and init_path:
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True);
        # checkpoints that pickle non-tensor objects (e.g. config objects)
        # may need weights_only=False. Only do that for trusted files.
        checkpoint = torch.load(init_path, map_location="cpu")
        # Accept both Lightning checkpoints ({'state_dict': ...}) and bare
        # state dicts.
        state = checkpoint.get("state_dict", checkpoint)
        result = model.load_state_dict(state, strict=False)
        # strict=False silently skips mismatched keys; report them so a wrong
        # or incompatible checkpoint does not go unnoticed.
        if result is not None:
            logger.info(
                f'Warm-started from {init_path}: '
                f'{len(result.missing_keys)} missing keys, '
                f'{len(result.unexpected_keys)} unexpected keys.')

    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
@hydra.main(version_base=None, config_path='configs', config_name='config')
def main(config):
    """Hydra entry point: build tokenizer and data module, then dispatch.

    ``config.mode`` selects the action:
      * ``'sample_eval'`` -> ``generate_samples``
      * ``'ppl_eval'``    -> ``ppl_eval``
      * anything else     -> ``_train``

    Raises:
        ValueError: If ``config.vocab`` is not one of ``'new_smiles'``,
            ``'old_smiles'``, ``'selfies'`` or ``'helm'``. Previously this
            left ``tokenizer`` unbound and failed later with a NameError.
    """
    L.seed_everything(config.seed)
    logger = utils.get_logger(__name__)

    # Build the tokenizer selected by config.vocab from config.paths.tokenizers.
    tok_dir = config.paths.tokenizers
    if config.vocab == 'new_smiles':
        tokenizer = APETokenizer()
        tokenizer.load_vocabulary(f'{tok_dir}/peptide_smiles_600_vocab.json')
    elif config.vocab == 'old_smiles':
        tokenizer = SMILES_SPE_Tokenizer(f'{tok_dir}/new_vocab.txt',
                                         f'{tok_dir}/new_splits.txt')
    elif config.vocab == 'selfies':
        tokenizer = APETokenizer()
        tokenizer.load_vocabulary(f'{tok_dir}/peptide_selfies_600_vocab.json')
    elif config.vocab == 'helm':
        tokenizer = HelmTokenizer(f'{tok_dir}/monomer_vocab.txt')
    else:
        raise ValueError(
            f"Unsupported config.vocab={config.vocab!r}; expected one of "
            "'new_smiles', 'old_smiles', 'selfies', 'helm'.")

    if config.backbone == 'finetune_roformer':
        # A single CSV passed as data_files is loaded into the 'train' split.
        train_dataset = load_dataset('csv', data_files=config.data.train)['train']
        val_dataset = load_dataset('csv', data_files=config.data.valid)['train']
        data_module = dataloader.CustomDataModule(
            train_dataset, val_dataset, None, tokenizer,
            batch_size=config.loader.global_batch_size)
    else:
        data_module = dynamic_dataloader.CustomDataModule(
            f'{config.paths.data}/smiles/11M_smiles_old_tokenizer_no_limit',
            tokenizer)

    if config.mode == 'sample_eval':
        generate_samples(config, logger, tokenizer)
    elif config.mode == 'ppl_eval':
        ppl_eval(config, logger, tokenizer, data_module)
    else:
        _train(config, logger, tokenizer, data_module)


if __name__ == '__main__':
    main()