File size: 2,202 Bytes
d04a061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python3

import os
import wandb
import lightning.pytorch as pl

from omegaconf import OmegaConf
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

from src.utils.model_utils import _print
from src.guidance.solubility_module import SolubilityClassifier
from src.guidance.dataloader import MembraneDataModule, get_datasets


# Path to the guidance config; override with GUIDANCE_CONFIG, falling back to
# the original hard-coded location for backward compatibility.
CONFIG_PATH = os.environ.get(
    "GUIDANCE_CONFIG", "/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml"
)


def main() -> None:
    """Train or evaluate the SolubilityClassifier described by the guidance config.

    Reads the OmegaConf YAML config, builds the membrane data module, sets up
    W&B logging and Lightning checkpointing, then either fits the model
    (``config.training.mode == "train"``) or evaluates the best checkpoint
    (``config.training.mode == "test"``).

    Raises:
        ValueError: if ``config.training.mode`` is neither 'train' nor 'test'.
    """
    config = OmegaConf.load(CONFIG_PATH)

    # SECURITY: the API key was previously hard-coded in source (a leaked
    # secret — it should be rotated). Read it from the environment instead;
    # wandb.login(key=None) falls back to cached credentials / interactive login.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))

    # data
    datasets = get_datasets(config)
    data_module = MembraneDataModule(
        config=config,
        train_dataset=datasets['train'],
        val_dataset=datasets['val'],
        test_dataset=datasets['test'],
    )

    # wandb logging (WandbLogger calls wandb.init itself with these kwargs)
    wandb_logger = WandbLogger(**config.wandb)

    # Ensure the checkpoint directory exists before anything tries to write to it.
    ckpt_dir = config.checkpointing.save_dir
    os.makedirs(ckpt_dir, exist_ok=True)

    # lightning callbacks: track the LR schedule and keep only the best
    # checkpoint by validation loss.
    lr_monitor = LearningRateMonitor(logging_interval="step")
    checkpoint_callback = ModelCheckpoint(
        monitor="val/loss",
        save_top_k=1,
        mode="min",
        dirpath=ckpt_dir,
        filename="best_model",
    )

    # lightning trainer — single CUDA device; DDP is intentionally not enabled here.
    trainer = pl.Trainer(
        max_steps=config.training.max_steps,
        accelerator="cuda",
        devices=1,
        callbacks=[checkpoint_callback, lr_monitor],
        logger=wandb_logger,
        log_every_n_steps=config.training.log_every_n_steps,
    )

    # instantiate model
    model = SolubilityClassifier(config)

    # train or evaluate the model
    if config.training.mode == "train":
        trainer.fit(model, datamodule=data_module)
    elif config.training.mode == "test":
        # Reload the weights saved by ModelCheckpoint above before testing.
        ckpt_path = os.path.join(ckpt_dir, "best_model.ckpt")
        state_dict = model.get_state_dict(ckpt_path)
        model.load_state_dict(state_dict)
        trainer.test(model, datamodule=data_module, ckpt_path=ckpt_path)
    else:
        raise ValueError(f"{config.training.mode} is invalid. Must be 'train' or 'test'")

    wandb.finish()


if __name__ == "__main__":
    main()