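"""Hydra entry point for the peptide Diffusion model.

Registers OmegaConf resolvers, builds a tokenizer and data module from the Hydra
config, then dispatches to training, sample generation, or zero-shot perplexity
evaluation depending on `config.mode`.
"""
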
import os
import sys
import uuid

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree
import torch
import torch.distributed as dist
import wandb
from datasets import load_dataset
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import GradientAccumulationScheduler, ModelCheckpoint
from lightning.pytorch.strategies import DDPStrategy
from torch.nn.parallel import DistributedDataParallel as DDP

from . import dataset as dataloader
from . import dataloading_for_dynamic_batching as dynamic_dataloader
from .diffusion import Diffusion
from .helm_tokenizer.helm_tokenizer import HelmTokenizer
from .new_tokenizer.ape_tokenizer import APETokenizer
from .tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from .utils import utils

omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver('eval', eval)
omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)
omegaconf.OmegaConf.register_new_resolver('env_or', lambda k, d: os.getenv(k, d))
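
# How these resolvers might be referenced from the Hydra YAML configs. This is
# only a sketch: the field names below are hypothetical and illustrate the
# interpolation syntax, not actual entries of this repo's config files.
#
#   checkpointing:
#     save_dir: ${cwd:}/outputs
#   wandb:
#     mode: ${env_or:WANDB_MODE,online}
#   sampling:
#     num_sample_batches: ${div_up:${sampling.num_samples},${loader.eval_batch_size}}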


def _load_from_checkpoint(config, tokenizer):
    """Creates the Diffusion model; loads weights from `config.eval.checkpoint_path` if set."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Hugging Face backbones ship their own weights, so no local checkpoint is loaded.
    if 'hf' in str(config.get('backbone', '')):
        return Diffusion(config, tokenizer=tokenizer).to(device)

    ckpt_path = config.eval.checkpoint_path
    model = Diffusion.load_from_checkpoint(
        ckpt_path,
        tokenizer=tokenizer,
        config=config,
        map_location=device,
    )
    return model


@L.pytorch.utilities.rank_zero_only
def print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True,
        save_cfg: bool = True) -> None:
    """Prints the content of a DictConfig as a tree, using the Rich library.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether to resolve reference fields of the DictConfig.
        save_cfg (bool): Whether to save the configuration tree to a file.
    """
    style = 'dim'
    tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)

    for field in config.keys():
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        branch_content = str(config_section)
        if isinstance(config_section, omegaconf.DictConfig):
            branch_content = omegaconf.OmegaConf.to_yaml(
                config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, 'yaml'))

    rich.print(tree)
    if save_cfg:
        with fsspec.open(
                '{}/config_tree.txt'.format(config.checkpointing.save_dir),
                'w') as fp:
            rich.print(tree, file=fp)


@L.pytorch.utilities.rank_zero_only
def print_batch(train_ds, valid_ds, tokenizer, k=64):
    """Prints the first and last `k` token ids of one training batch.

    Only the train dataloader is inspected; `valid_ds` is currently unused.
    """
    for dl_type, dl in [('train', train_ds)]:
        print(f'Printing {dl_type} dataloader batch.')
        batch = next(iter(dl))
        print('Batch input_ids.shape', batch['input_ids'].shape)
        first = batch['input_ids'][0, :k]
        last = batch['input_ids'][0, -k:]
        print(f'First {k} tokens:', tokenizer.decode(first))
        print('ids:', first)
        print(f'Last {k} tokens:', tokenizer.decode(last))
        print('ids:', last)


def generate_samples(config, logger, tokenizer):
    """Samples peptide sequences from a checkpointed model and reports perplexity."""
    logger.info('Generating samples.')
    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)

    for _ in range(config.sampling.num_sample_batches):
        samples = model.restore_model_and_sample(
            num_steps=config.sampling.steps)
        peptide_sequences = model.tokenizer.batch_decode(samples)
        model.compute_generative_perplexity(peptide_sequences)

    print('Peptide samples:', peptide_sequences)

    print('Generative perplexity:', model.compute_masked_perplexity())

    # Only the sequences from the final sampling batch are returned.
    return peptide_sequences


def ppl_eval(config, logger, tokenizer, data_module):
    """Runs zero-shot perplexity evaluation via `trainer.test`."""
    logger.info('Starting Zero Shot Eval.')

    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)

    wandb_logger = None
    if config.get('wandb', None) is not None:
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    callbacks = []
    if 'callbacks' in config:
        for _, callback in config.callbacks.items():
            callbacks.append(hydra.utils.instantiate(callback))

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        strategy=DDPStrategy(find_unused_parameters=True),
        logger=wandb_logger)

    trainer.test(model, data_module)


def _train(config, logger, tokenizer, data_module):
    """Trains the Diffusion model, optionally resuming from a checkpoint."""
    logger.info('Starting Training.')
    wandb_logger = None

    if config.get('wandb', None) is not None:
        # Append a unique suffix so repeated runs do not collide on the same wandb id.
        unique_id = str(uuid.uuid4())
        config.wandb.id = f'{config.wandb.id}_{unique_id}'

        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    if (config.checkpointing.resume_from_ckpt
            and config.checkpointing.resume_ckpt_path is not None
            and utils.fsspec_exists(config.checkpointing.resume_ckpt_path)):
        ckpt_path = config.checkpointing.resume_ckpt_path
    else:
        ckpt_path = None

    callbacks = []
    if 'callbacks' in config:
        for callback_name, callback_config in config.callbacks.items():
            if callback_name == 'model_checkpoint':
                # Instantiate ModelCheckpoint directly, dropping Hydra's '_target_' key.
                model_checkpoint_config = {
                    k: v for k, v in callback_config.items() if k != '_target_'}
                callbacks.append(ModelCheckpoint(**model_checkpoint_config))
            else:
                callbacks.append(hydra.utils.instantiate(callback_config))

    if config.training.accumulator:
        # Schedule maps epoch -> accumulation factor: accumulate 5 batches from
        # epoch 1, tapering to no accumulation from epoch 4 onwards.
        accumulator = GradientAccumulationScheduler(
            scheduling={1: 5, 2: 4, 3: 3, 4: 1})
        callbacks.append(accumulator)

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        accelerator='cuda',
        strategy=DDPStrategy(find_unused_parameters=True),
        # NOTE: GPU indices are hard-coded here and override the trainer config.
        devices=[2, 3, 4, 5, 6, 7],
        logger=wandb_logger)

    model = Diffusion(config, tokenizer=tokenizer)

    if config.backbone == 'finetune_roformer' and config.eval.checkpoint_path:
        # Warm-start from a pretrained checkpoint; strict=False tolerates
        # missing or unexpected keys in the state dict.
        checkpoint = torch.load(config.eval.checkpoint_path, map_location='cpu')
        state = checkpoint.get('state_dict', checkpoint)
        model.load_state_dict(state, strict=False)

    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)


@hydra.main(version_base=None, config_path='configs', config_name='config')
def main(config):
    """Main entry point for training, sampling, and perplexity evaluation."""
    L.seed_everything(config.seed)

    logger = utils.get_logger(__name__)

    tok_dir = config.paths.tokenizers
    if config.vocab == 'new_smiles':
        tokenizer = APETokenizer()
        tokenizer.load_vocabulary(f'{tok_dir}/peptide_smiles_600_vocab.json')
    elif config.vocab == 'old_smiles':
        tokenizer = SMILES_SPE_Tokenizer(f'{tok_dir}/new_vocab.txt',
                                         f'{tok_dir}/new_splits.txt')
    elif config.vocab == 'selfies':
        tokenizer = APETokenizer()
        tokenizer.load_vocabulary(f'{tok_dir}/peptide_selfies_600_vocab.json')
    elif config.vocab == 'helm':
        tokenizer = HelmTokenizer(f'{tok_dir}/monomer_vocab.txt')
    else:
        raise ValueError(f'Unknown vocab: {config.vocab}')

    if config.backbone == 'finetune_roformer':
        # load_dataset('csv', ...) returns a DatasetDict whose single split is named 'train'.
        train_dataset = load_dataset('csv', data_files=config.data.train)['train']
        val_dataset = load_dataset('csv', data_files=config.data.valid)['train']
        data_module = dataloader.CustomDataModule(
            train_dataset, val_dataset, None, tokenizer,
            batch_size=config.loader.global_batch_size)
    else:
        data_module = dynamic_dataloader.CustomDataModule(
            f'{config.paths.data}/smiles/11M_smiles_old_tokenizer_no_limit',
            tokenizer)

    if config.mode == 'sample_eval':
        generate_samples(config, logger, tokenizer)
    elif config.mode == 'ppl_eval':
        ppl_eval(config, logger, tokenizer, data_module)
    else:
        _train(config, logger, tokenizer, data_module)


if __name__ == '__main__':
    main()
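
# Example Hydra-style invocations (a sketch: the module path / script name depends
# on how this package is installed, and the override values are illustrative; the
# config keys themselves come from the code above):
#
#   python -m <package>.main mode=train vocab=old_smiles
#   python -m <package>.main mode=ppl_eval vocab=helm eval.checkpoint_path=/path/to/last.ckpt
#   python -m <package>.main mode=sample_eval sampling.num_sample_batches=2 sampling.steps=128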