| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import argparse | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						import random | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						import horovod.torch as hvd | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from proard.classification.elastic_nn.modules.dynamic_op import ( | 
					
					
						
						| 
							 | 
						    DynamicSeparableConv2d, | 
					
					
						
						| 
							 | 
						) | 
					
					
						
						| 
							 | 
						from proard.classification.elastic_nn.networks import DYNMobileNetV3,DYNProxylessNASNets,DYNResNets,DYNProxylessNASNets_Cifar,DYNMobileNetV3_Cifar,DYNResNets_Cifar | 
					
					
						
						| 
							 | 
						from proard.classification.run_manager import DistributedClassificationRunConfig | 
					
					
						
						| 
							 | 
						from proard.classification.run_manager.distributed_run_manager import ( | 
					
					
						
						| 
							 | 
						    DistributedRunManager | 
					
					
						
						| 
							 | 
						) | 
					
					
						
						| 
							 | 
						from proard.utils import download_url, MyRandomResizedCrop | 
					
					
						
						| 
							 | 
						from proard.classification.elastic_nn.training.progressive_shrinking import load_models | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						parser = argparse.ArgumentParser() | 
					
					
						
						| 
							 | 
						parser.add_argument( | 
					
					
						
						| 
							 | 
						    "--task", | 
					
					
						
						| 
							 | 
						    type=str, | 
					
					
						
						| 
							 | 
						    default="expand", | 
					
					
						
						| 
							 | 
						    choices=[ | 
					
					
						
						| 
							 | 
						        "kernel",  | 
					
					
						
						| 
							 | 
						        "depth", | 
					
					
						
						| 
							 | 
						        "expand", | 
					
					
						
						| 
							 | 
						        "width",  | 
					
					
						
						| 
							 | 
						    ], | 
					
					
						
						| 
							 | 
						) | 
					
					
						
						| 
							 | 
						parser.add_argument("--phase", type=int, default=2, choices=[1, 2]) | 
					
					
						
						| 
							 | 
						parser.add_argument("--resume", action="store_true") | 
					
					
						
						| 
							 | 
						parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet","MBV2"]) | 
					
					
						
						| 
							 | 
						parser.add_argument("--dataset", type=str, default="cifar100", choices=["cifar10", "cifar100", "imagenet"]) | 
					
					
						
						| 
							 | 
						parser.add_argument("--robust_mode", type=bool, default=True) | 
					
					
						
						| 
							 | 
						parser.add_argument("--epsilon", type=float, default=0.031) | 
					
					
						
						| 
							 | 
						parser.add_argument("--num_steps", type=int, default=10) | 
					
					
						
						| 
							 | 
						parser.add_argument("--step_size", type=float, default=0.0078)  | 
					
					
						
						| 
							 | 
						parser.add_argument("--clip_min", type=int, default=0) | 
					
					
						
						| 
							 | 
						parser.add_argument("--clip_max", type=int, default=1) | 
					
					
						
						| 
							 | 
						parser.add_argument("--const_init", type=bool, default=False) | 
					
					
						
						| 
							 | 
						parser.add_argument("--beta", type=float, default=6.0) | 
					
					
						
						| 
							 | 
						parser.add_argument("--distance", type=str, default="l_inf",choices=["l_inf","l2"]) | 
					
					
						
						| 
							 | 
						parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"]) | 
					
					
						
						| 
							 | 
						parser.add_argument("--test_criterion", type=str, default="ce",choices=["ce"]) | 
					
					
						
						| 
							 | 
						parser.add_argument("--kd_criterion", type=str, default="rslad",choices=["ard","rslad","adaad"]) | 
					
					
						
						| 
							 | 
						parser.add_argument("--attack_type", type=str, default="linf-pgd",choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce']) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args = parser.parse_args() | 
					
					
						
						| 
							 | 
						if args.model_name == "ResNet50": | 
					
					
						
						| 
							 | 
						    args.ks_list = "3" | 
					
					
						
						| 
							 | 
						    if args.task == "width": | 
					
					
						
						| 
							 | 
						        if args.robust_mode: | 
					
					
						
						| 
							 | 
						            args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2width" | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.path = "exp/"+ args.dataset + '/' +args.model_name +'/' + args.train_criterion +"/normal2width"     | 
					
					
						
						| 
							 | 
						        args.dynamic_batch_size = 1 | 
					
					
						
						| 
							 | 
						        args.n_epochs = 120 | 
					
					
						
						| 
							 | 
						        args.base_lr = 3e-2 | 
					
					
						
						| 
							 | 
						        args.warmup_epochs = 5 | 
					
					
						
						| 
							 | 
						        args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						        args.width_mult_list = "0.65,0.8,1.0" | 
					
					
						
						| 
							 | 
						        args.expand_list = "0.35" | 
					
					
						
						| 
							 | 
						        args.depth_list = "2" | 
					
					
						
						| 
							 | 
						    elif args.task == "depth": | 
					
					
						
						| 
							 | 
						        if args.robust_mode: | 
					
					
						
						| 
							 | 
						            args.path = "exp/robust/"+ args.dataset + '/'  + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.path = "exp/"+ args.dataset + '/'  + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase      | 
					
					
						
						| 
							 | 
						        args.dynamic_batch_size = 2 | 
					
					
						
						| 
							 | 
						        if args.phase == 1: | 
					
					
						
						| 
							 | 
						            args.n_epochs = 25 | 
					
					
						
						| 
							 | 
						            args.base_lr = 2.5e-3 | 
					
					
						
						| 
							 | 
						            args.warmup_epochs = 0 | 
					
					
						
						| 
							 | 
						            args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						            args.width_mult_list = "0.65,0.8,1.0" | 
					
					
						
						| 
							 | 
						            args.expand_list ="0.35" | 
					
					
						
						| 
							 | 
						            args.depth_list = "1,2" | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.n_epochs = 120 | 
					
					
						
						| 
							 | 
						            args.base_lr = 7.5e-3 | 
					
					
						
						| 
							 | 
						            args.warmup_epochs = 5 | 
					
					
						
						| 
							 | 
						            args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						            args.width_mult_list = "0.65,0.8,1.0" | 
					
					
						
						| 
							 | 
						            args.expand_list = "0.35" | 
					
					
						
						| 
							 | 
						            args.depth_list = "0,1,2" | 
					
					
						
						| 
							 | 
						    elif args.task == "expand": | 
					
					
						
						| 
							 | 
						        if args.robust_mode :  | 
					
					
						
						| 
							 | 
						            args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase  | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase     | 
					
					
						
						| 
							 | 
						        args.dynamic_batch_size = 4 | 
					
					
						
						| 
							 | 
						        if args.phase == 1: | 
					
					
						
						| 
							 | 
						            args.n_epochs = 25 | 
					
					
						
						| 
							 | 
						            args.base_lr = 2.5e-3 | 
					
					
						
						| 
							 | 
						            args.warmup_epochs = 0 | 
					
					
						
						| 
							 | 
						            args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						            args.width_mult_list = "0.65,0.8,1.0" | 
					
					
						
						| 
							 | 
						            args.expand_list = "0.25,0.35" | 
					
					
						
						| 
							 | 
						            args.depth_list = "0,1,2" | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.n_epochs = 120 | 
					
					
						
						| 
							 | 
						            args.base_lr = 7.5e-3 | 
					
					
						
						| 
							 | 
						            args.warmup_epochs = 5 | 
					
					
						
						| 
							 | 
						            args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						            args.width_mult_list = "0.65,0.8,1.0" | 
					
					
						
						| 
							 | 
						            args.expand_list = "0.2,0.25,0.35" | 
					
					
						
						| 
							 | 
						            args.depth_list = "0,1,2" | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        raise NotImplementedError | 
					
					
						
						| 
							 | 
						else: | 
					
					
						
						| 
							 | 
						    args.width_mult_list = "1.0"     | 
					
					
						
						| 
							 | 
						    if args.task == "kernel": | 
					
					
						
						| 
							 | 
						        if args.robust_mode: | 
					
					
						
						| 
							 | 
						            args.path = "exp/robust/"+ args.dataset + '/' +  args.model_name +'/' + args.train_criterion +"/normal2kernel" | 
					
					
						
						| 
							 | 
						        else:     | 
					
					
						
						| 
							 | 
						            args.path = "exp/"+ args.dataset + '/' +  args.model_name +'/' + args.train_criterion +"/normal2kernel" | 
					
					
						
						| 
							 | 
						        args.dynamic_batch_size = 1 | 
					
					
						
						| 
							 | 
						        args.n_epochs = 120 | 
					
					
						
						| 
							 | 
						        args.base_lr = 3e-2 | 
					
					
						
						| 
							 | 
						        args.warmup_epochs = 5 | 
					
					
						
						| 
							 | 
						        args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						        args.ks_list = "3,5,7" | 
					
					
						
						| 
							 | 
						        args.expand_list = "6" | 
					
					
						
						| 
							 | 
						        args.depth_list = "4" | 
					
					
						
						| 
							 | 
						    elif args.task == "depth": | 
					
					
						
						| 
							 | 
						        if args.robust_mode :  | 
					
					
						
						| 
							 | 
						            args.path = "exp/robust/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.path = "exp/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase     | 
					
					
						
						| 
							 | 
						        args.dynamic_batch_size = 2 | 
					
					
						
						| 
							 | 
						        if args.phase == 1: | 
					
					
						
						| 
							 | 
						            args.n_epochs = 25 | 
					
					
						
						| 
							 | 
						            args.base_lr = 2.5e-3 | 
					
					
						
						| 
							 | 
						            args.warmup_epochs = 0 | 
					
					
						
						| 
							 | 
						            args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						            args.ks_list = "3,5,7" | 
					
					
						
						| 
							 | 
						            args.expand_list = "6" | 
					
					
						
						| 
							 | 
						            args.depth_list = "3,4" | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.n_epochs = 120 | 
					
					
						
						| 
							 | 
						            args.base_lr = 7.5e-3 | 
					
					
						
						| 
							 | 
						            args.warmup_epochs = 5 | 
					
					
						
						| 
							 | 
						            args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						            args.ks_list = "3,5,7" | 
					
					
						
						| 
							 | 
						            args.expand_list = "6" | 
					
					
						
						| 
							 | 
						            args.depth_list = "2,3,4" | 
					
					
						
						| 
							 | 
						    elif args.task == "expand": | 
					
					
						
						| 
							 | 
						        if args.robust_mode: | 
					
					
						
						| 
							 | 
						            args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase%d" % args.phase | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.path = "exp/"+ args.dataset + '/' + args.model_name +  '/' + args.train_criterion + "/kernel_depth2kernel_depth_width/phase%d" % args.phase     | 
					
					
						
						| 
							 | 
						        args.dynamic_batch_size = 4 | 
					
					
						
						| 
							 | 
						        if args.phase == 1: | 
					
					
						
						| 
							 | 
						            args.n_epochs = 25 | 
					
					
						
						| 
							 | 
						            args.base_lr = 2.5e-3 | 
					
					
						
						| 
							 | 
						            args.warmup_epochs = 0 | 
					
					
						
						| 
							 | 
						            args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						            args.ks_list = "3,5,7" | 
					
					
						
						| 
							 | 
						            args.expand_list = "4,6" | 
					
					
						
						| 
							 | 
						            args.depth_list = "2,3,4" | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            args.n_epochs = 120 | 
					
					
						
						| 
							 | 
						            args.base_lr = 7.5e-3 | 
					
					
						
						| 
							 | 
						            args.warmup_epochs = 5 | 
					
					
						
						| 
							 | 
						            args.warmup_lr = -1 | 
					
					
						
						| 
							 | 
						            args.ks_list = "3,5,7" | 
					
					
						
						| 
							 | 
						            args.expand_list = "3,4,6" | 
					
					
						
						| 
							 | 
						            args.depth_list = "2,3,4" | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        raise NotImplementedError | 
					
					
						
						| 
							 | 
						args.manual_seed = 0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args.lr_schedule_type = "cosine" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args.base_batch_size = 64 | 
					
					
						
						| 
							 | 
						args.valid_size = 64 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args.opt_type = "sgd" | 
					
					
						
						| 
							 | 
						args.momentum = 0.9 | 
					
					
						
						| 
							 | 
						args.no_nesterov = False | 
					
					
						
						| 
							 | 
						args.weight_decay = 3e-5 | 
					
					
						
						| 
							 | 
						args.label_smoothing = 0.1 | 
					
					
						
						| 
							 | 
						args.no_decay_keys = "bn#bias" | 
					
					
						
						| 
							 | 
						args.fp16_allreduce = False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args.model_init = "he_fout" | 
					
					
						
						| 
							 | 
						args.validation_frequency = 1 | 
					
					
						
						| 
							 | 
						args.print_frequency = 10 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args.n_worker = 8 | 
					
					
						
						| 
							 | 
						args.resize_scale = 0.08 | 
					
					
						
						| 
							 | 
						args.distort_color = "tf" | 
					
					
						
						| 
							 | 
						if args.dataset == "imagenet":   | 
					
					
						
						| 
							 | 
						    args.image_size = "128,160,192,224" | 
					
					
						
						| 
							 | 
						else: | 
					
					
						
						| 
							 | 
						    args.image_size = "32"     | 
					
					
						
						| 
							 | 
						args.continuous_size = True | 
					
					
						
						| 
							 | 
						args.not_sync_distributed_image_size = False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args.bn_momentum = 0.1 | 
					
					
						
						| 
							 | 
						args.bn_eps = 1e-5 | 
					
					
						
						| 
							 | 
						args.dropout = 0.1 | 
					
					
						
						| 
							 | 
						args.base_stage_width = "google" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args.dy_conv_scaling_mode = 1 | 
					
					
						
						| 
							 | 
						args.independent_distributed_sampling = False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						args.kd_ratio = 1.0 | 
					
					
						
						| 
							 | 
						args.kd_type = "ce" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						    os.makedirs(args.path, exist_ok=True) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    hvd.init() | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    torch.cuda.set_device(hvd.local_rank()) | 
					
					
						
						| 
							 | 
						    if args.robust_mode: | 
					
					
						
						| 
							 | 
						        args.teacher_path = 'exp/robust/teacher/' + args.dataset + '/' +  args.model_name + '/' + args.train_criterion + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        args.teacher_path = 'exp/teacher/' + args.dataset + '/' +  args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						    num_gpus = hvd.size() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    torch.manual_seed(args.manual_seed) | 
					
					
						
						| 
							 | 
						    torch.cuda.manual_seed_all(args.manual_seed) | 
					
					
						
						| 
							 | 
						    np.random.seed(args.manual_seed) | 
					
					
						
						| 
							 | 
						    random.seed(args.manual_seed) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    args.image_size = [int(img_size) for img_size in args.image_size.split(",")] | 
					
					
						
						| 
							 | 
						    if len(args.image_size) == 1: | 
					
					
						
						| 
							 | 
						        args.image_size = args.image_size[0] | 
					
					
						
						| 
							 | 
						    MyRandomResizedCrop.CONTINUOUS = args.continuous_size | 
					
					
						
						| 
							 | 
						    MyRandomResizedCrop.SYNC_DISTRIBUTED = not args.not_sync_distributed_image_size | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    args.lr_schedule_param = None | 
					
					
						
						| 
							 | 
						    args.opt_param = { | 
					
					
						
						| 
							 | 
						        "momentum": args.momentum, | 
					
					
						
						| 
							 | 
						        "nesterov": not args.no_nesterov, | 
					
					
						
						| 
							 | 
						    } | 
					
					
						
						| 
							 | 
						    args.init_lr = args.base_lr * num_gpus   | 
					
					
						
						| 
							 | 
						    if args.warmup_lr < 0: | 
					
					
						
						| 
							 | 
						        args.warmup_lr = args.base_lr | 
					
					
						
						| 
							 | 
						    args.train_batch_size = args.base_batch_size | 
					
					
						
						| 
							 | 
						    args.test_batch_size = args.base_batch_size * 4 | 
					
					
						
						| 
							 | 
						    run_config = DistributedClassificationRunConfig( | 
					
					
						
						| 
							 | 
						        **args.__dict__, num_replicas=num_gpus, rank=hvd.rank() | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if hvd.rank() == 0: | 
					
					
						
						| 
							 | 
						        print("Run config:") | 
					
					
						
						| 
							 | 
						        for k, v in run_config.config.items(): | 
					
					
						
						| 
							 | 
						            print("\t%s: %s" % (k, v)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.dy_conv_scaling_mode == -1: | 
					
					
						
						| 
							 | 
						        args.dy_conv_scaling_mode = None | 
					
					
						
						| 
							 | 
						    DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    args.width_mult_list = [ | 
					
					
						
						| 
							 | 
						        float(width_mult) for width_mult in args.width_mult_list.split(",") | 
					
					
						
						| 
							 | 
						    ] | 
					
					
						
						| 
							 | 
						    args.ks_list = [int(ks) for ks in args.ks_list.split(",")] | 
					
					
						
						| 
							 | 
						    if args.model_name == "ResNet50": | 
					
					
						
						| 
							 | 
						        args.expand_list = [float(e) for e in args.expand_list.split(",")] | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        args.expand_list = [int(e) for e in args.expand_list.split(",")]     | 
					
					
						
						| 
							 | 
						    args.depth_list = [int(d) for d in args.depth_list.split(",")] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    args.width_mult_list = ( | 
					
					
						
						| 
							 | 
						        args.width_mult_list[0] | 
					
					
						
						| 
							 | 
						        if len(args.width_mult_list) == 1 | 
					
					
						
						| 
							 | 
						        else args.width_mult_list | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.model_name == "ResNet50": | 
					
					
						
						| 
							 | 
						        if args.dataset == "cifar10" or args.dataset == "cifar100": | 
					
					
						
						| 
							 | 
						            net = DYNResNets_Cifar( n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                dropout_rate=args.dropout, | 
					
					
						
						| 
							 | 
						                depth_list=args.depth_list, | 
					
					
						
						| 
							 | 
						                expand_ratio_list=args.expand_list, | 
					
					
						
						| 
							 | 
						                width_mult_list=args.width_mult_list,) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            net = DYNResNets( n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                dropout_rate=args.dropout, | 
					
					
						
						| 
							 | 
						                depth_list=args.depth_list, | 
					
					
						
						| 
							 | 
						                expand_ratio_list=args.expand_list, | 
					
					
						
						| 
							 | 
						                width_mult_list=args.width_mult_list,)    | 
					
					
						
						| 
							 | 
						    elif args.model_name == "MBV3":   | 
					
					
						
						| 
							 | 
						        if args.dataset == "cifar10" or args.dataset == "cifar100":   | 
					
					
						
						| 
							 | 
						            net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), | 
					
					
						
						| 
							 | 
						                                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), | 
					
					
						
						| 
							 | 
						                                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)     | 
					
					
						
						| 
							 | 
						    elif args.model_name == "ProxylessNASNet":  | 
					
					
						
						| 
							 | 
						        if args.dataset == "cifar10" or args.dataset == "cifar100":      | 
					
					
						
						| 
							 | 
						            net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), | 
					
					
						
						| 
							 | 
						                                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), | 
					
					
						
						| 
							 | 
						                                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)    | 
					
					
						
						| 
							 | 
						    elif args.model_name == "MBV2":  | 
					
					
						
						| 
							 | 
						        if args.dataset == "cifar10" or args.dataset == "cifar100":      | 
					
					
						
						| 
							 | 
						            net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), | 
					
					
						
						| 
							 | 
						                                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), | 
					
					
						
						| 
							 | 
						                                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)              | 
					
					
						
						| 
							 | 
						    else:  | 
					
					
						
						| 
							 | 
						        raise NotImplementedError   | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if args.kd_ratio > 0: | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if args.model_name =="ResNet50": | 
					
					
						
						| 
							 | 
						            if args.dataset == "cifar10" or args.dataset == "cifar100": | 
					
					
						
						| 
							 | 
						                args.teacher_model = DYNResNets_Cifar( | 
					
					
						
						| 
							 | 
						                    n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                    bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                    dropout_rate=args.dropout, | 
					
					
						
						| 
							 | 
						                    depth_list=[2], | 
					
					
						
						| 
							 | 
						                    expand_ratio_list=[0.35], | 
					
					
						
						| 
							 | 
						                    width_mult_list=[1.0], | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                args.teacher_model = DYNResNets( | 
					
					
						
						| 
							 | 
						                    n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                    bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                    dropout_rate=args.dropout, | 
					
					
						
						| 
							 | 
						                    depth_list=[2], | 
					
					
						
						| 
							 | 
						                    expand_ratio_list=[0.35], | 
					
					
						
						| 
							 | 
						                    width_mult_list=[1.0], | 
					
					
						
						| 
							 | 
						                )    | 
					
					
						
						| 
							 | 
						        elif args.model_name =="MBV3":     | 
					
					
						
						| 
							 | 
						            if args.dataset == "cifar10" or args.dataset == "cifar100": | 
					
					
						
						| 
							 | 
						                args.teacher_model = DYNMobileNetV3_Cifar( | 
					
					
						
						| 
							 | 
						                    n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                    bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                    dropout_rate=0, | 
					
					
						
						| 
							 | 
						                    width_mult=1.0, | 
					
					
						
						| 
							 | 
						                    ks_list=[7], | 
					
					
						
						| 
							 | 
						                    expand_ratio_list=[6], | 
					
					
						
						| 
							 | 
						                    depth_list=[4] | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                args.teacher_model = DYNMobileNetV3( | 
					
					
						
						| 
							 | 
						                    n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                    bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                    dropout_rate=0, | 
					
					
						
						| 
							 | 
						                    width_mult=1.0, | 
					
					
						
						| 
							 | 
						                    ks_list=[7], | 
					
					
						
						| 
							 | 
						                    expand_ratio_list=[6], | 
					
					
						
						| 
							 | 
						                    depth_list=[4] | 
					
					
						
						| 
							 | 
						                )     | 
					
					
						
						| 
							 | 
						        elif args.model_name == "ProxylessNASNet": | 
					
					
						
						| 
							 | 
						            if args.dataset == "cifar10" or args.dataset == "cifar100": | 
					
					
						
						| 
							 | 
						                args.teacher_model  = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                    bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                    dropout_rate=0, | 
					
					
						
						| 
							 | 
						                    width_mult=1.0, | 
					
					
						
						| 
							 | 
						                    ks_list=[7], | 
					
					
						
						| 
							 | 
						                    expand_ratio_list=[6], | 
					
					
						
						| 
							 | 
						                    depth_list=[4])    | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                args.teacher_model  = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                    bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                    dropout_rate=0, | 
					
					
						
						| 
							 | 
						                    width_mult=1.0, | 
					
					
						
						| 
							 | 
						                    ks_list=[7], | 
					
					
						
						| 
							 | 
						                    expand_ratio_list=[6], | 
					
					
						
						| 
							 | 
						                    depth_list=[4])  | 
					
					
						
						| 
							 | 
						        elif args.model_name == "MBV2": | 
					
					
						
						| 
							 | 
						            if args.dataset == "cifar10" or args.dataset == "cifar100": | 
					
					
						
						| 
							 | 
						                args.teacher_model  = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                    bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                    dropout_rate=0, | 
					
					
						
						| 
							 | 
						                    width_mult=1.0, | 
					
					
						
						| 
							 | 
						                    ks_list=[7], | 
					
					
						
						| 
							 | 
						                    expand_ratio_list=[6], | 
					
					
						
						| 
							 | 
						                    depth_list=[4],base_stage_width=args.base_stage_width)    | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                args.teacher_model  = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes, | 
					
					
						
						| 
							 | 
						                    bn_param=(args.bn_momentum, args.bn_eps), | 
					
					
						
						| 
							 | 
						                    dropout_rate=0, | 
					
					
						
						| 
							 | 
						                    width_mult=1.0, | 
					
					
						
						| 
							 | 
						                    ks_list=[7], | 
					
					
						
						| 
							 | 
						                    expand_ratio_list=[6], | 
					
					
						
						| 
							 | 
						                    depth_list=[4],base_stage_width=args.base_stage_width)                  | 
					
					
						
						| 
							 | 
						        args.teacher_model.cuda() | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    """ Distributed RunManager """ | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none | 
					
					
						
						| 
							 | 
						    distributed_run_manager = DistributedRunManager( | 
					
					
						
						| 
							 | 
						        args.path, | 
					
					
						
						| 
							 | 
						        net, | 
					
					
						
						| 
							 | 
						        run_config, | 
					
					
						
						| 
							 | 
						        compression, | 
					
					
						
						| 
							 | 
						        backward_steps=args.dynamic_batch_size, | 
					
					
						
						| 
							 | 
						        is_root=(hvd.rank() == 0), | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						    distributed_run_manager.save_config() | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    distributed_run_manager.broadcast() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if args.kd_ratio > 0: | 
					
					
						
						| 
							 | 
						        load_models( | 
					
					
						
						| 
							 | 
						            distributed_run_manager, args.teacher_model, model_path=args.teacher_path | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    from proard.classification.elastic_nn.training.progressive_shrinking import ( | 
					
					
						
						| 
							 | 
						        validate, | 
					
					
						
						| 
							 | 
						        train, | 
					
					
						
						| 
							 | 
						    ) | 
					
					
						
						| 
							 | 
						    if args.model_name =="ResNet50": | 
					
					
						
						| 
							 | 
						        validate_func_dict = { | 
					
					
						
						| 
							 | 
						            "image_size_list": {224 if args.dataset == "imagenet" else 32} | 
					
					
						
						| 
							 | 
						            if isinstance(args.image_size, int) | 
					
					
						
						| 
							 | 
						            else sorted({160, 224}), | 
					
					
						
						| 
							 | 
						            "width_mult_list": sorted({min(args.width_mult_list), max(args.width_mult_list)}), | 
					
					
						
						| 
							 | 
						            "expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}), | 
					
					
						
						| 
							 | 
						            "depth_list": sorted({min(net.depth_list), max(net.depth_list)}), | 
					
					
						
						| 
							 | 
						        } | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        validate_func_dict = { | 
					
					
						
						| 
							 | 
						            "image_size_list": {224 if args.dataset == "imagenet" else 32} | 
					
					
						
						| 
							 | 
						            if isinstance(args.image_size, int) | 
					
					
						
						| 
							 | 
						            else sorted({160, 224}), | 
					
					
						
						| 
							 | 
						            "width_mult_list": [1.0], | 
					
					
						
						| 
							 | 
						            "ks_list": sorted({min(args.ks_list), max(args.ks_list)}), | 
					
					
						
						| 
							 | 
						            "expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}), | 
					
					
						
						| 
							 | 
						            "depth_list": sorted({min(net.depth_list), max(net.depth_list)}), | 
					
					
						
						| 
							 | 
						        }   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if args.task == "width": | 
					
					
						
						| 
							 | 
						        from proard.classification.elastic_nn.training.progressive_shrinking import ( | 
					
					
						
						| 
							 | 
						            train_elastic_width_mult, | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						        if distributed_run_manager.start_epoch == 0: | 
					
					
						
						| 
							 | 
						            if args.robust_mode: | 
					
					
						
						| 
							 | 
						                args.dyn_checkpoint_path ='exp/robust/teacher/' +args.dataset + '/' +  args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                args.dyn_checkpoint_path ='exp/teacher/' +args.dataset + '/' +  args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"    | 
					
					
						
						| 
							 | 
						            load_models( | 
					
					
						
						| 
							 | 
						                distributed_run_manager, | 
					
					
						
						| 
							 | 
						                distributed_run_manager.net, | 
					
					
						
						| 
							 | 
						                args.dyn_checkpoint_path, | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						            distributed_run_manager.write_log( | 
					
					
						
						| 
							 | 
						                "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s" | 
					
					
						
						| 
							 | 
						                % validate(distributed_run_manager, is_test=True, **validate_func_dict), | 
					
					
						
						| 
							 | 
						                "valid", | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            assert args.resume | 
					
					
						
						| 
							 | 
						        train_elastic_width_mult (train,distributed_run_manager,args,validate_func_dict)     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    elif args.task == "kernel": | 
					
					
						
						| 
							 | 
						        validate_func_dict["ks_list"] = sorted(args.ks_list) | 
					
					
						
						| 
							 | 
						        if distributed_run_manager.start_epoch == 0: | 
					
					
						
						| 
							 | 
						            if args.robust_mode: | 
					
					
						
						| 
							 | 
						                args.dyn_checkpoint_path ='exp/robust/teacher/' + args.dataset + '/' +  args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						            else:  | 
					
					
						
						| 
							 | 
						                args.dyn_checkpoint_path ='exp/teacher/' + args.dataset + '/' +  args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"     | 
					
					
						
						| 
							 | 
						            load_models( | 
					
					
						
						| 
							 | 
						                distributed_run_manager, | 
					
					
						
						| 
							 | 
						                distributed_run_manager.net, | 
					
					
						
						| 
							 | 
						                args.dyn_checkpoint_path, | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						            distributed_run_manager.write_log( | 
					
					
						
						| 
							 | 
						               "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s" | 
					
					
						
						| 
							 | 
						                % validate(distributed_run_manager, is_test=True, **validate_func_dict), | 
					
					
						
						| 
							 | 
						                "valid", | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            assert args.resume | 
					
					
						
						| 
							 | 
						        train( | 
					
					
						
						| 
							 | 
						            distributed_run_manager, | 
					
					
						
						| 
							 | 
						            args, | 
					
					
						
						| 
							 | 
						            lambda _run_manager, epoch, is_test: validate( | 
					
					
						
						| 
							 | 
						                _run_manager, epoch, is_test, **validate_func_dict | 
					
					
						
						| 
							 | 
						            ), | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						    elif args.task == "depth": | 
					
					
						
						| 
							 | 
						        from proard.classification.elastic_nn.training.progressive_shrinking import ( | 
					
					
						
						| 
							 | 
						            train_elastic_depth, | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						        if args.robust_mode: | 
					
					
						
						| 
							 | 
						            if args.model_name =="ResNet50": | 
					
					
						
						| 
							 | 
						                if args.phase == 1: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path =  "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                if args.phase == 1: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path =  "exp/robust/"+ args.dataset + '/' +  args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"      | 
					
					
						
						| 
							 | 
						        else : | 
					
					
						
						| 
							 | 
						            if args.model_name =="ResNet50": | 
					
					
						
						| 
							 | 
						                if args.phase == 1: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path =  "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                if args.phase == 1: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path =  "exp/"+ args.dataset + '/' +  args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"                   | 
					
					
						
						| 
							 | 
						        train_elastic_depth(train, distributed_run_manager, args, validate_func_dict) | 
					
					
						
						| 
							 | 
						    elif args.task == "expand": | 
					
					
						
						| 
							 | 
						        from proard.classification.elastic_nn.training.progressive_shrinking import ( | 
					
					
						
						| 
							 | 
						            train_elastic_expand, | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						        if args.robust_mode :  | 
					
					
						
						| 
							 | 
						            if args.model_name =="ResNet50": | 
					
					
						
						| 
							 | 
						                if args.phase == 1: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path =  "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                if args.phase == 1: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path =  "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"   | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" +  "/checkpoint/model_best.pth.tar"  | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            if args.model_name =="ResNet50": | 
					
					
						
						| 
							 | 
						                if args.phase == 1: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path =  "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar" | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                if args.phase == 1: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path =  "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"   | 
					
					
						
						| 
							 | 
						                else: | 
					
					
						
						| 
							 | 
						                    args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" +  "/checkpoint/model_best.pth.tar"  | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        train_elastic_expand(train, distributed_run_manager, args, validate_func_dict) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        raise NotImplementedError | 
					
					
						
						| 
							 | 
						
 |