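"""Training entry point for the diffusion model (PyTorch Lightning).

Supports MNIST, CIFAR-10, and CelebA via the local `diffusion` package,
with optional Weights & Biases logging.
"""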
import argparse
import os

import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger

import diffusion

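# Share CPU tensors through the file system rather than file descriptors;
# this avoids "too many open files" errors with many DataLoader workers.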
torch.multiprocessing.set_sharing_strategy('file_system')

def main():
    # PARSER
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', '-d', type=str, default='mnist',
        help='dataset to train on (mnist, cifar10, or celeba)'
    )
    parser.add_argument(
        '--data_dir', '-dd', type=str, default='./data/',
        help='directory where the dataset is stored or downloaded'
    )
    parser.add_argument(
        '--mode', type=str, default='ddim',
        help='sampling mode'
    )
    parser.add_argument(
        '--max_epochs', '-me', type=int, default=200,
        help='maximum number of training epochs'
    )
    parser.add_argument(
        '--batch_size', '-bs', type=int, default=32,
        help='batch size per step'
    )
    parser.add_argument(
        '--train_ratio', '-tr', type=float, default=0.99,
        help='fraction of the data used for training (rest for validation)'
    )
    parser.add_argument(
        '--timesteps', '-ts', type=int, default=1000,
        help='number of diffusion timesteps'
    )
    parser.add_argument(
        '--max_batch_size', '-mbs', type=int, default=32,
        help='effective batch size reached via gradient accumulation'
    )
    parser.add_argument(
        '--lr', '-l', type=float, default=1e-4,
        help='learning rate'
    )
    parser.add_argument(
        '--num_workers', '-nw', type=int, default=4,
        help='number of DataLoader workers'
    )
    parser.add_argument(
        '--seed', '-s', type=int, default=42,
        help='random seed'
    )
    parser.add_argument(
        '--name', '-n', type=str, default=None,
        help='name of the experiment'
    )
    parser.add_argument(
        '--pbar', action='store_true',
        help='enable the progress bar'
    )
    parser.add_argument(
        '--precision', '-p', type=str, default='32',
        help='numerical precision passed to the Trainer'
    )
    parser.add_argument(
        '--sample_per_epochs', '-spe', type=int, default=25,
        help='generate samples every n epochs'
    )
    parser.add_argument(
        '--n_samples', '-ns', type=int, default=4,
        help='number of images to sample'
    )
    parser.add_argument(
        '--monitor', '-m', type=str, default='val_loss',
        help='metric monitored by the checkpoint callback'
    )
    parser.add_argument(
        '--wandb', '-wk', type=str, default=None,
        help='wandb API key'
    )
    args = parser.parse_args()
    # SEED
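    # workers=True also seeds each DataLoader worker process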
    pl.seed_everything(args.seed, workers=True)
    # WANDB (OPTIONAL)
    if args.wandb is not None:
        wandb.login(key=args.wandb)  # log in with the provided API key
        name = args.name or f"diffusion-{args.max_epochs}-{args.batch_size}-{args.lr}"
        logger = WandbLogger(
            project="diffusion-model",
            name=name,
            log_model=False
        )
    else:
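        # NOTE: with logger=None the Trainer falls back to its default logger;
        # pass logger=False to disable logging entirely.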
        logger = None
    # DATAMODULE
    if args.dataset == "mnist":
        datamodule_cls = diffusion.MNISTDataModule
        img_dim = 32
        num_classes = 10
    elif args.dataset == "cifar10":
        datamodule_cls = diffusion.CIFAR10DataModule
        img_dim = 32
        num_classes = 10
    elif args.dataset == "celeba":
        datamodule_cls = diffusion.CelebADataModule
        img_dim = 64
        num_classes = None
    else:
        raise ValueError(f"unknown dataset: {args.dataset}")
    datamodule = datamodule_cls(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        seed=args.seed,
        train_ratio=args.train_ratio,
        img_dim=img_dim
    )
    # MODEL
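    # MNIST is grayscale (1 channel); CIFAR-10 and CelebA are RGB (3 channels)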
    in_channels = 1 if args.dataset == "mnist" else 3
    model = diffusion.DiffusionModel(
        lr=args.lr,
        in_channels=in_channels,
        sample_per_epochs=args.sample_per_epochs,
        max_timesteps=args.timesteps,
        dim=img_dim,
        num_classes=num_classes,
        n_samples=args.n_samples,
        mode=args.mode
    )
    # CALLBACK
    root_path = os.path.join(os.getcwd(), "checkpoints")
    callback = diffusion.ModelCallback(
        root_path=root_path,
        ckpt_monitor=args.monitor
    )
    # STRATEGY
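    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient in a step (e.g. an unused class-conditioning path)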
    strategy = 'ddp_find_unused_parameters_true' if torch.cuda.is_available() else 'auto'
    # TRAINER
    trainer = pl.Trainer(
        default_root_dir=root_path,
        logger=logger,
        callbacks=callback.get_callback(),
        gradient_clip_val=0.5,
        max_epochs=args.max_epochs,
        enable_progress_bar=args.pbar,
        deterministic=False,
        precision=args.precision,
        strategy=strategy,
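        # accumulate gradients so the effective batch size is ~max_batch_size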
        accumulate_grad_batches=max(args.max_batch_size // args.batch_size, 1)
    )
    # FIT MODEL
    trainer.fit(model=model, datamodule=datamodule)


if __name__ == '__main__':
    main()
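
# Example invocations (a sketch, assuming this file is saved as train.py and
# the local `diffusion` package is importable):
#   python train.py --dataset cifar10 --batch_size 64 --max_batch_size 256 --pbar
#   python train.py --dataset mnist --wandb <YOUR_WANDB_API_KEY> --name mnist-run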