Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """Trains Karras et al. (2022) diffusion models.""" | |
| import argparse | |
| from copy import deepcopy | |
| from functools import partial | |
| import math | |
| import json | |
| from pathlib import Path | |
| import accelerate | |
| import torch | |
| from torch import nn, optim | |
| from torch import multiprocessing as mp | |
| from torch.utils import data | |
| from torchvision import datasets, transforms, utils | |
| from tqdm.auto import trange, tqdm | |
| import k_diffusion as K | |
def main():
    """Parse the command line and run the diffusion model training loop.

    Loads the configuration file given by ``--config``, builds the model,
    optimizer, LR/EMA schedules and training dataset from it, optionally
    initializes weights from a ``--grow`` checkpoint or restores a full
    training state (``--resume`` or ``<name>_state.json``), and then trains
    indefinitely until interrupted with Ctrl-C.

    While training it periodically writes demo image grids
    (``--demo-every``), FID/KID metrics (``--evaluate-every``) and
    checkpoints (``--save-every``), and logs to wandb when
    ``--wandb-project`` is given.

    Raises:
        ValueError: on an unknown optimizer, LR schedule or dataset type, or
            when ``--grow`` is given without ``--grow-config``.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--batch-size', type=int, default=64,
                   help='the batch size')
    p.add_argument('--config', type=str, required=True,
                   help='the configuration file')
    p.add_argument('--demo-every', type=int, default=500,
                   help='save a demo grid every this many steps')
    p.add_argument('--evaluate-every', type=int, default=10000,
                   help='evaluate every this many steps')
    p.add_argument('--evaluate-n', type=int, default=2000,
                   help='the number of samples to draw to evaluate')
    p.add_argument('--gns', action='store_true',
                   help='measure the gradient noise scale (DDP only)')
    p.add_argument('--grad-accum-steps', type=int, default=1,
                   help='the number of gradient accumulation steps')
    p.add_argument('--grow', type=str,
                   help='the checkpoint to grow from')
    p.add_argument('--grow-config', type=str,
                   help='the configuration file of the model to grow from')
    p.add_argument('--lr', type=float,
                   help='the learning rate')
    p.add_argument('--name', type=str, default='model',
                   help='the name of the run')
    p.add_argument('--num-workers', type=int, default=8,
                   help='the number of data loader workers')
    p.add_argument('--resume', type=str,
                   help='the checkpoint to resume from')
    p.add_argument('--sample-n', type=int, default=64,
                   help='the number of images to sample for demo grids')
    p.add_argument('--save-every', type=int, default=10000,
                   help='save every this many steps')
    p.add_argument('--seed', type=int,
                   help='the random seed')
    p.add_argument('--start-method', type=str, default='spawn',
                   choices=['fork', 'forkserver', 'spawn'],
                   help='the multiprocessing start method')
    p.add_argument('--wandb-entity', type=str,
                   help='the wandb entity name')
    p.add_argument('--wandb-group', type=str,
                   help='the wandb group name')
    p.add_argument('--wandb-project', type=str,
                   help='the wandb project name (specify this to enable wandb)')
    p.add_argument('--wandb-save-model', action='store_true',
                   help='save model to wandb')
    args = p.parse_args()

    mp.set_start_method(args.start_method)
    torch.backends.cuda.matmul.allow_tf32 = True

    # Close the config file as soon as it has been parsed.
    with open(args.config) as f:
        config = K.config.load_config(f)
    model_config = config['model']
    dataset_config = config['dataset']
    opt_config = config['optimizer']
    sched_config = config['lr_sched']
    ema_sched_config = config['ema_sched']

    # TODO: allow non-square input sizes
    assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
    size = model_config['input_size']

    ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=model_config['skip_stages'] > 0)
    accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps=args.grad_accum_steps)
    device = accelerator.device
    print(f'Process {accelerator.process_index} using device: {device}', flush=True)

    if args.seed is not None:
        # Derive one seed per process from the user-supplied seed so that
        # every process gets a different, but reproducible, RNG stream.
        seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed))
        torch.manual_seed(seeds[accelerator.process_index])

    inner_model = K.config.make_model(config)
    if accelerator.is_main_process:
        print('Parameters:', K.utils.n_params(inner_model))

    # If logging to wandb, initialize the run
    use_wandb = accelerator.is_main_process and args.wandb_project
    if use_wandb:
        import wandb
        # Copy the namespace dict: vars(args) is the namespace's own __dict__,
        # so writing into it directly would overwrite args.config.
        log_config = dict(vars(args))
        log_config['config'] = config
        log_config['parameters'] = K.utils.n_params(inner_model)
        wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True)

    # --lr on the command line overrides the learning rate from the config.
    if opt_config['type'] == 'adamw':
        opt = optim.AdamW(inner_model.parameters(),
                          lr=opt_config['lr'] if args.lr is None else args.lr,
                          betas=tuple(opt_config['betas']),
                          eps=opt_config['eps'],
                          weight_decay=opt_config['weight_decay'])
    elif opt_config['type'] == 'sgd':
        opt = optim.SGD(inner_model.parameters(),
                        lr=opt_config['lr'] if args.lr is None else args.lr,
                        momentum=opt_config.get('momentum', 0.),
                        nesterov=opt_config.get('nesterov', False),
                        weight_decay=opt_config.get('weight_decay', 0.))
    else:
        raise ValueError('Invalid optimizer type')

    if sched_config['type'] == 'inverse':
        sched = K.utils.InverseLR(opt,
                                  inv_gamma=sched_config['inv_gamma'],
                                  power=sched_config['power'],
                                  warmup=sched_config['warmup'])
    elif sched_config['type'] == 'exponential':
        sched = K.utils.ExponentialLR(opt,
                                      num_steps=sched_config['num_steps'],
                                      decay=sched_config['decay'],
                                      warmup=sched_config['warmup'])
    else:
        raise ValueError('Invalid schedule type')

    assert ema_sched_config['type'] == 'inverse'
    ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'],
                                  max_value=ema_sched_config['max_value'])

    tf = transforms.Compose([
        transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(size[0]),
        K.augmentation.KarrasAugmentationPipeline(model_config['augment_prob']),
    ])

    if dataset_config['type'] == 'imagefolder':
        train_set = K.utils.FolderOfImages(dataset_config['location'], transform=tf)
    elif dataset_config['type'] == 'cifar10':
        train_set = datasets.CIFAR10(dataset_config['location'], train=True, download=True, transform=tf)
    elif dataset_config['type'] == 'mnist':
        train_set = datasets.MNIST(dataset_config['location'], train=True, download=True, transform=tf)
    elif dataset_config['type'] == 'huggingface':
        from datasets import load_dataset
        train_set = load_dataset(dataset_config['location'])
        train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_config['image_key']))
        train_set = train_set['train']
    else:
        raise ValueError('Invalid dataset type')

    if accelerator.is_main_process:
        try:
            print('Number of items in dataset:', len(train_set))
        except TypeError:
            # Best effort: the dataset does not support len().
            pass

    # Key used to pull the image entry out of each batch (index 0 by default).
    image_key = dataset_config.get('image_key', 0)

    train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True,
                               num_workers=args.num_workers, persistent_workers=True)

    if args.grow:
        if not args.grow_config:
            raise ValueError('--grow requires --grow-config')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(args.grow, map_location='cpu')
        with open(args.grow_config) as f:
            old_config = K.config.load_config(f)
        old_inner_model = K.config.make_model(old_config)
        old_inner_model.load_state_dict(ckpt['model_ema'])
        # Adapt the old model to the new skip_stages/patch_size settings
        # before copying its weights into the new model.
        if old_config['model']['skip_stages'] != model_config['skip_stages']:
            old_inner_model.set_skip_stages(model_config['skip_stages'])
        if old_config['model']['patch_size'] != model_config['patch_size']:
            old_inner_model.set_patch_size(model_config['patch_size'])
        inner_model.load_state_dict(old_inner_model.state_dict())
        del ckpt, old_inner_model

    inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl)
    if use_wandb:
        wandb.watch(inner_model)
    if args.gns:
        gns_stats_hook = K.gns.DDPGradientStatsHook(inner_model)
        gns_stats = K.gns.GradientNoiseScale()
    else:
        gns_stats = None
    sigma_min = model_config['sigma_min']
    sigma_max = model_config['sigma_max']
    sample_density = K.config.make_sample_density(model_config)

    model = K.config.make_denoiser_wrapper(config)(inner_model)
    model_ema = deepcopy(model)

    # <name>_state.json records the most recent checkpoint written by save().
    state_path = Path(f'{args.name}_state.json')

    if state_path.exists() or args.resume:
        # An explicit --resume takes precedence over the recorded state file.
        if args.resume:
            ckpt_path = args.resume
        else:
            with open(state_path) as f:
                state = json.load(f)
            ckpt_path = state['latest_checkpoint']
        if accelerator.is_main_process:
            print(f'Resuming from {ckpt_path}...')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        accelerator.unwrap_model(model.inner_model).load_state_dict(ckpt['model'])
        accelerator.unwrap_model(model_ema.inner_model).load_state_dict(ckpt['model_ema'])
        opt.load_state_dict(ckpt['opt'])
        sched.load_state_dict(ckpt['sched'])
        ema_sched.load_state_dict(ckpt['ema_sched'])
        # Continue from the step/epoch after the ones stored in the checkpoint.
        epoch = ckpt['epoch'] + 1
        step = ckpt['step'] + 1
        if args.gns and ckpt.get('gns_stats', None) is not None:
            gns_stats.load_state_dict(ckpt['gns_stats'])

        del ckpt
    else:
        epoch = 0
        step = 0

    evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0
    if evaluate_enabled:
        extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
        train_iter = iter(train_dl)
        if accelerator.is_main_process:
            print('Computing features for reals...')
        # The sample function ignores the requested count and returns the next
        # training batch's element 1 of the image entry.
        reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size)
        if accelerator.is_main_process:
            metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'fid', 'kid'])
        del train_iter

    # Sampling is inference only, so no autograd graph is needed.
    @torch.no_grad()
    def demo():
        """Sample ``--sample-n`` images from the EMA model and save them as a grid."""
        if accelerator.is_main_process:
            tqdm.write('Sampling...')
        filename = f'{args.name}_demo_{step:08}.png'
        # Split the requested samples across processes; the surplus from
        # rounding up is trimmed after gathering.
        n_per_proc = math.ceil(args.sample_n / accelerator.num_processes)
        x = torch.randn([n_per_proc, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
        sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
        x_0 = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, disable=not accelerator.is_main_process)
        x_0 = accelerator.gather(x_0)[:args.sample_n]
        if accelerator.is_main_process:
            grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0)
            K.utils.to_pil_image(grid).save(filename)
            if use_wandb:
                wandb.log({'demo_grid': wandb.Image(filename)}, step=step)

    # Evaluation is inference only, so no autograd graph is needed.
    @torch.no_grad()
    def evaluate():
        """Compute FID and KID of EMA-model samples against the real features."""
        if not evaluate_enabled:
            return
        if accelerator.is_main_process:
            tqdm.write('Evaluating...')
        sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)

        def sample_fn(n):
            x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
            x_0 = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, disable=True)
            return x_0

        fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, args.evaluate_n, args.batch_size)
        if accelerator.is_main_process:
            fid = K.evaluation.fid(fakes_features, reals_features)
            kid = K.evaluation.kid(fakes_features, reals_features)
            print(f'FID: {fid.item():g}, KID: {kid.item():g}')
            metrics_log.write(step, fid.item(), kid.item())
            if use_wandb:
                wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step)

    def save():
        """Write a checkpoint for the current step and record it in the state file."""
        accelerator.wait_for_everyone()
        filename = f'{args.name}_{step:08}.pth'
        if accelerator.is_main_process:
            tqdm.write(f'Saving to {filename}...')
        obj = {
            'model': accelerator.unwrap_model(model.inner_model).state_dict(),
            'model_ema': accelerator.unwrap_model(model_ema.inner_model).state_dict(),
            'opt': opt.state_dict(),
            'sched': sched.state_dict(),
            'ema_sched': ema_sched.state_dict(),
            'epoch': epoch,
            'step': step,
            'gns_stats': gns_stats.state_dict() if gns_stats is not None else None,
        }
        accelerator.save(obj, filename)
        if accelerator.is_main_process:
            state_obj = {'latest_checkpoint': filename}
            with open(state_path, 'w') as f:
                json.dump(state_obj, f)
        if args.wandb_save_model and use_wandb:
            wandb.save(filename)

    # ema_decay is only assigned on steps where gradients are synchronized;
    # start it as None so logging before the first such step cannot hit an
    # unbound local.
    ema_decay = None

    try:
        while True:
            for batch in tqdm(train_dl, disable=not accelerator.is_main_process):
                with accelerator.accumulate(model):
                    reals, _, aug_cond = batch[image_key]
                    noise = torch.randn_like(reals)
                    sigma = sample_density([reals.shape[0]], device=device)
                    losses = model.loss(reals, noise, sigma, aug_cond=aug_cond)
                    # The gathered mean is used for reporting only; the
                    # backward pass uses this process's own losses.
                    losses_all = accelerator.gather(losses)
                    loss = losses_all.mean()
                    accelerator.backward(losses.mean())
                    if args.gns:
                        sq_norm_small_batch, sq_norm_large_batch = gns_stats_hook.get_stats()
                        gns_stats.update(sq_norm_small_batch, sq_norm_large_batch, reals.shape[0], reals.shape[0] * accelerator.num_processes)
                    opt.step()
                    sched.step()
                    opt.zero_grad()
                    if accelerator.sync_gradients:
                        ema_decay = ema_sched.get_value()
                        K.utils.ema_update(model, model_ema, ema_decay)
                        ema_sched.step()

                if accelerator.is_main_process:
                    if step % 25 == 0:
                        if args.gns:
                            tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}, gns: {gns_stats.get_gns():g}')
                        else:
                            tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}')

                if use_wandb:
                    log_dict = {
                        'epoch': epoch,
                        'loss': loss.item(),
                        'lr': sched.get_last_lr()[0],
                    }
                    if ema_decay is not None:
                        log_dict['ema_decay'] = ema_decay
                    if args.gns:
                        log_dict['gradient_noise_scale'] = gns_stats.get_gns()
                    wandb.log(log_dict, step=step)

                if step % args.demo_every == 0:
                    demo()

                if evaluate_enabled and step > 0 and step % args.evaluate_every == 0:
                    evaluate()

                if step > 0 and step % args.save_every == 0:
                    save()

                step += 1
            epoch += 1
    except KeyboardInterrupt:
        # Ctrl-C is the intended way to stop training; exit quietly.
        pass
# Entry point: run training only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()