# genau-demo/GenAU/train/1d_vae.py
# Author: Moayed Haji Ali
# Email: mh155@rice.edu
# Date: 8 June 2024
# based on code from
# Author: Haohe Liu
# Email: haoheliu@gmail.com
# Date: 11 Feb 2023
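#
# Trains GenAU's 1D audio autoencoder (AutoencoderKL1D) with PyTorch Lightning:
# builds train/val dataloaders from VideoAudioDataset, optionally resumes from
# a checkpoint, and runs DDP training with S3-backed checkpointing.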

import os
import argparse
import shutil

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.strategies.ddp import DDPStrategy

from src.tools.logger import Logger
from src.tools.configuration import Configuration
from src.tools.training_utils import get_restore_step
from src.utilities.data.videoaudio_dataset import VideoAudioDataset
from src.utilities.model.model_checkpoint import S3ModelCheckpoint
from src.modules.latent_encoder.autoencoder_1d import AutoencoderKL1D


def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False):
    if "precision" in configs["training"]:
        # One of: "highest", "high", "medium"
        torch.set_float32_matmul_precision(configs["training"]["precision"])

    batch_size = configs["model"]["params"]["batchsize"]
    max_epochs = configs["step"]["max_epochs"]
    limit_val_batches = configs["step"].get("limit_val_batches", None)
    limit_train_batches = configs["step"].get("limit_train_batches", None)
    log_path = configs["logging"]["log_directory"]
    save_top_k = configs["logging"].get("save_top_k", -1)
    save_checkpoint_every_n_steps = configs["logging"].get("save_checkpoint_every_n_steps", 5000)

    dataloader_add_ons = configs["data"].get("dataloader_add_ons", [])
    augment_p = configs["data"].get("augment_p", 0.0)

    dataset = VideoAudioDataset(
        configs,
        split="train",
        add_ons=dataloader_add_ons,
        load_video=False,
        load_audio=True,
        sample_single_caption=True,
        augment_p=augment_p,
    )
    print("[INFO] Using batch size of:", batch_size)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=configs["data"].get("num_workers", 32),
        pin_memory=True,
        shuffle=True,
    )
    print(
        "[INFO] The length of the dataset is %s, the length of the dataloader is %s, the batch size is %s"
        % (len(dataset), len(loader), batch_size)
    )

    val_dataset = VideoAudioDataset(
        configs,
        split="test",
        add_ons=dataloader_add_ons,
        load_video=False,
        load_audio=True,
        sample_single_caption=True,
    )
    val_loader = DataLoader(
        val_dataset,
        num_workers=configs["data"].get("num_workers", 32),
        batch_size=batch_size,
    )

    devices = torch.cuda.device_count()
    learning_rate = configs["model"]["base_learning_rate"]

    model = AutoencoderKL1D(
        ddconfig=configs["model"]["params"]["ddconfig"],
        lossconfig=configs["model"]["params"]["lossconfig"],
        embed_dim=configs["model"]["params"]["embed_dim"],
        image_key=configs["model"]["params"]["image_key"],
        base_learning_rate=learning_rate,
        subband=configs["model"]["params"]["subband"],
        sampling_rate=configs["preprocessing"]["audio"]["sampling_rate"],
    )

    # Optional checkpoint to reload from; a missing key falls back to None.
    try:
        config_reload_from_ckpt = configs["reload_from_ckpt"]
    except KeyError:
        config_reload_from_ckpt = None

    checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
    checkpoint_callback = S3ModelCheckpoint(
        bucket_name=configs["logging"].get("S3_BUCKET", None),
        s3_folder=configs["logging"].get("S3_FOLDER", None),
        dirpath=checkpoint_path,
        # Monitoring global_step with mode="max" keeps the most recent checkpoints.
        monitor="global_step",
        mode="max",
        filename="checkpoint-fad-{val/frechet_inception_distance:.2f}-global_step={global_step:.0f}",
        every_n_train_steps=save_checkpoint_every_n_steps,
        save_top_k=save_top_k,
        auto_insert_metric_name=False,
        save_last=False,
    )

    # Keep a copy of the config alongside the run's logs.
    wandb_path = os.path.join(log_path, exp_group_name, exp_name)
    config_copy_dir = os.path.join(wandb_path, "config")
    os.makedirs(config_copy_dir, exist_ok=True)
    shutil.copy(config_yaml_path, config_copy_dir)

    model.set_log_dir(log_path, exp_group_name, exp_name)
    os.makedirs(checkpoint_path, exist_ok=True)

    # Decide where to start: resume from this run's own checkpoints, reload a
    # checkpoint named in the config, or train from scratch.
    if os.listdir(checkpoint_path) and configs["training"].get("resume_training", False):
        print("[INFO] Load checkpoint from path: %s" % checkpoint_path)
        restore_step, n_step = get_restore_step(checkpoint_path)
        print("[INFO] Resuming from step", n_step)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("[INFO] Resume from checkpoint", resume_from_checkpoint)
    elif config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("[INFO] Reload ckpt specified in the config file: %s" % resume_from_checkpoint)
    else:
        print("[INFO] Training from scratch")
        resume_from_checkpoint = None

    wandb_logger = Logger(
        config=configs,
        checkpoints_directory=wandb_path,
        run_name="%s/%s" % (exp_group_name, exp_name),
        offline=debug,
    ).get_logger()

    # nodes_count == -1 means "infer it": WORLD_SIZE (total processes set by
    # the launcher) divided by the number of GPUs per node.
    nodes_count = configs["training"]["nodes_count"]
    if nodes_count == -1:
        if "WORLD_SIZE" in os.environ:
            nodes_count = int(os.environ["WORLD_SIZE"]) // torch.cuda.device_count()
        else:
            nodes_count = 1

    print("[INFO] Training on %s GPU(s) per node across %s node(s)" % (devices, nodes_count))
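
    # Illustrative multi-node launch (any launcher that sets WORLD_SIZE, e.g.
    # torchrun, should satisfy the node-count inference above):
    #   torchrun --nnodes=2 --nproc_per_node=8 train/1d_vae.py -c <config>.yaml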

    trainer = Trainer(
        accelerator="gpu",
        devices=devices,
        logger=wandb_logger,
        num_sanity_val_steps=1,
        num_nodes=nodes_count,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback],
        strategy=DDPStrategy(find_unused_parameters=True),
        gradient_clip_val=configs["model"]["params"].get("clip_grad", None),
    )

    # TRAINING
    trainer.fit(model, loader, val_loader, ckpt_path=resume_from_checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--autoencoder_config",
        type=str,
        required=True,
        help="path to autoencoder config .yaml",
    )
    parser.add_argument(
        "-d",
        "--debug",
        action="store_true",
        default=False,
        help="debug mode",
    )
    args = parser.parse_args()

    config_yaml = args.autoencoder_config
    # Experiment naming: <config parent dir>/<config file stem>
    exp_name = os.path.splitext(os.path.basename(config_yaml))[0]
    exp_group_name = os.path.basename(os.path.dirname(config_yaml))

    configuration = Configuration(config_yaml)
    configs = configuration.get_config()

    main(configs, config_yaml, exp_group_name, exp_name, args.debug)
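
# Example invocation (config path is illustrative):
#   python train/1d_vae.py -c configs/<exp_group>/<exp_name>.yaml --debug
# The experiment group and name are derived from the config's parent directory
# and file stem, so logs land under <logging.log_directory>/<group>/<name>.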