Spaces:
Runtime error
Runtime error
| # Author: Moayed Haji Ali | |
| # Email: mh155@rice.edu | |
| # Date: 8 June 2024 | |
| # based on code from | |
| # Author: Haohe Liu | |
| # Email: haoheliu@gmail.com | |
| # Date: 11 Feb 2023 | |
| import os | |
| import argparse | |
| import torch | |
| import shutil | |
| from pytorch_lightning.strategies.ddp import DDPStrategy | |
| from torch.utils.data import DataLoader | |
| from src.tools.logger import Logger | |
| from pytorch_lightning import Trainer | |
| from src.utilities.model.model_checkpoint import S3ModelCheckpoint | |
| from src.utilities.data.videoaudio_dataset import VideoAudioDataset | |
| from src.modules.latent_encoder.autoencoder_1d import AutoencoderKL1D | |
| from src.tools.training_utils import get_restore_step | |
| from src.tools.configuration import Configuration | |
def _resolve_nodes_count(configs, devices):
    """Return the ``num_nodes`` value to hand to the Lightning ``Trainer``.

    ``configs["training"]["nodes_count"]`` is used as-is unless it is -1 (or
    missing), which means "infer": divide the ``WORLD_SIZE`` environment
    variable by the number of local GPUs, or use 1 when ``WORLD_SIZE`` is unset.

    Args:
        configs: Parsed configuration dict.
        devices: Number of local CUDA devices (``torch.cuda.device_count()``).
    """
    nodes_count = configs["training"].get("nodes_count", -1)
    if nodes_count != -1:
        return nodes_count
    if "WORLD_SIZE" not in os.environ:
        return 1
    # max(devices, 1) avoids ZeroDivisionError when no CUDA device is visible;
    # the outer max keeps the result >= 1 when WORLD_SIZE < devices.
    return max(1, int(os.environ["WORLD_SIZE"]) // max(devices, 1))


def _resolve_resume_checkpoint(configs, checkpoint_path):
    """Choose the checkpoint ``trainer.fit`` should resume from, or ``None``.

    Priority order:
    1. If ``checkpoint_path`` is non-empty and
       ``configs["training"]["resume_training"]`` is truthy, the entry
       selected by ``get_restore_step(checkpoint_path)``.
    2. The top-level ``configs["reload_from_ckpt"]`` path, if set.
    3. ``None``, meaning train from scratch.
    """
    resume_training = configs["training"].get("resume_training", False)
    if os.listdir(checkpoint_path) and resume_training:
        print("[INFO] Load checkpoint from path: %s" % checkpoint_path)
        restore_step, n_step = get_restore_step(checkpoint_path)
        print("[INFO] Resuming from step", n_step)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("[INFO] Resume from checkpoint", resume_from_checkpoint)
        return resume_from_checkpoint

    # .get replaces a bare ``except:`` that also hid unrelated errors.
    config_reload_from_ckpt = configs.get("reload_from_ckpt")
    if config_reload_from_ckpt is not None:
        print("[INFO] Reload ckpt specified in the config file %s" % config_reload_from_ckpt)
        return config_reload_from_ckpt

    print("[INFO] Training from scratch")
    return None


def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False):
    """Train the 1-D KL autoencoder described by ``configs``.

    Steps:
    - Build audio-only train ("train") and validation ("test")
      ``VideoAudioDataset`` loaders.
    - Construct an ``AutoencoderKL1D`` model and an ``S3ModelCheckpoint``
      callback.
    - Copy the config YAML into ``<log_directory>/<group>/<name>/config``.
    - Create the logger and run ``Trainer.fit``, optionally resuming from a
      checkpoint.

    Args:
        configs: Parsed configuration dict. Sections read here are
            ``training``, ``model``, ``step``, ``logging``, ``data`` and
            ``preprocessing``, plus the optional top-level ``reload_from_ckpt``.
        config_yaml_path: Path of the YAML file ``configs`` came from. It is
            copied next to the run's logs.
        exp_group_name: Experiment group, used as a log sub-directory.
        exp_name: Experiment name, used as a log sub-directory and run name.
        debug: Forwarded to ``Logger`` as its ``offline`` flag.
    """
    training_cfg = configs["training"]
    model_params = configs["model"]["params"]
    data_cfg = configs["data"]
    logging_cfg = configs["logging"]
    step_cfg = configs["step"]

    if "precision" in training_cfg:
        torch.set_float32_matmul_precision(training_cfg["precision"])  # highest, high, medium

    batch_size = model_params["batchsize"]
    max_epochs = step_cfg["max_epochs"]
    limit_val_batches = step_cfg.get("limit_val_batches", None)
    limit_train_batches = step_cfg.get("limit_train_batches", None)
    log_path = logging_cfg["log_directory"]
    save_top_k = logging_cfg.get("save_top_k", -1)
    save_checkpoint_every_n_steps = logging_cfg.get("save_checkpoint_every_n_steps", 5000)
    dataloader_add_ons = data_cfg.get("dataloader_add_ons", [])
    augment_p = data_cfg.get("augment_p", 0.0)
    num_workers = data_cfg.get("num_workers", 32)

    # --- data -------------------------------------------------------------
    dataset = VideoAudioDataset(
        configs,
        split="train",
        add_ons=dataloader_add_ons,
        load_video=False,
        load_audio=True,
        sample_single_caption=True,
        augment_p=augment_p,
    )
    print("[INFO] Using batch size of:", batch_size)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=True,
    )
    print(
        "[INFO] The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s"
        % (len(dataset), len(loader), batch_size)
    )
    # Validation uses the "test" split without augmentation.
    val_dataset = VideoAudioDataset(
        configs,
        split="test",
        add_ons=dataloader_add_ons,
        load_video=False,
        load_audio=True,
        sample_single_caption=True,
    )
    val_loader = DataLoader(
        val_dataset,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    # --- model ------------------------------------------------------------
    devices = torch.cuda.device_count()
    model = AutoencoderKL1D(
        ddconfig=model_params["ddconfig"],
        lossconfig=model_params["lossconfig"],
        embed_dim=model_params["embed_dim"],
        image_key=model_params["image_key"],
        base_learning_rate=configs["model"]["base_learning_rate"],
        subband=model_params["subband"],
        sampling_rate=configs["preprocessing"]["audio"]["sampling_rate"],
    )

    # --- checkpointing & logging -----------------------------------------
    checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
    checkpoint_callback = S3ModelCheckpoint(
        bucket_name=logging_cfg.get("S3_BUCKET", None),
        s3_folder=logging_cfg.get("S3_FOLDER", None),
        dirpath=checkpoint_path,
        monitor="global_step",
        mode="max",
        filename="checkpoint-fad-{val/frechet_inception_distance:.2f}-global_step={global_step:.0f}",
        every_n_train_steps=save_checkpoint_every_n_steps,
        save_top_k=save_top_k,
        auto_insert_metric_name=False,
        save_last=False,
    )

    wandb_path = os.path.join(log_path, exp_group_name, exp_name)
    config_copy_dir = os.path.join(wandb_path, "config")
    os.makedirs(config_copy_dir, exist_ok=True)
    shutil.copy(config_yaml_path, config_copy_dir)

    model.set_log_dir(log_path, exp_group_name, exp_name)
    os.makedirs(checkpoint_path, exist_ok=True)
    resume_from_checkpoint = _resolve_resume_checkpoint(configs, checkpoint_path)

    wandb_logger = Logger(
        config=configs,
        checkpoints_directory=wandb_path,
        run_name="%s/%s" % (exp_group_name, exp_name),
        offline=debug,
    ).get_logger()

    # --- trainer ----------------------------------------------------------
    nodes_count = _resolve_nodes_count(configs, devices)
    print("[INFO] Training on devices", devices)
    trainer = Trainer(
        accelerator="gpu",
        devices=devices,
        logger=wandb_logger,
        num_sanity_val_steps=1,
        num_nodes=nodes_count,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback],
        strategy=DDPStrategy(find_unused_parameters=True),
        gradient_clip_val=model_params.get("clip_grad", None),
    )

    # TRAINING
    trainer.fit(model, loader, val_loader, ckpt_path=resume_from_checkpoint)
def derive_experiment_names(config_yaml):
    """Return ``(exp_group_name, exp_name)`` derived from a config file path.

    The group is the name of the directory that contains the config file. The
    experiment name is the file's basename up to its first ``.``; for example,
    ``configs/audio/exp.yaml`` gives ``("audio", "exp")``. Only the basename
    is split, so dots in directory components (such as a leading ``./`` or
    ``runs/v1.2/``) do not affect the experiment name.
    """
    exp_name = os.path.basename(config_yaml).split(".")[0]
    exp_group_name = os.path.basename(os.path.dirname(config_yaml))
    return exp_group_name, exp_name


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--autoencoder_config",
        type=str,
        required=True,
        help="path to autoencoder config .yaml",
    )
    parser.add_argument(
        "-d",
        "--debug",
        action="store_true",
        default=False,
        help="debug mode",
    )
    args = parser.parse_args()

    config_yaml = args.autoencoder_config
    exp_group_name, exp_name = derive_experiment_names(config_yaml)

    configuration = Configuration(config_yaml)
    configs = configuration.get_config()
    main(configs, config_yaml, exp_group_name, exp_name, args.debug)