|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import numpy as np |
|
|
import os |
|
|
import random |
|
|
|
|
|
import horovod.torch as hvd |
|
|
import torch |
|
|
|
|
|
from proard.classification.elastic_nn.modules.dynamic_op import ( |
|
|
DynamicSeparableConv2d, |
|
|
) |
|
|
from proard.classification.elastic_nn.networks import DYNResNets,DYNMobileNetV3,DYNProxylessNASNets,DYNMobileNetV3_Cifar,DYNResNets_Cifar,DYNProxylessNASNets_Cifar |
|
|
from proard.classification.run_manager import DistributedClassificationRunConfig |
|
|
from proard.classification.networks import WideResNet |
|
|
from proard.classification.run_manager import DistributedRunManager |
|
|
|
|
|
|
|
|
def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a classic trap: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--robust_mode False`` would
    silently enable robust mode.  This converter accepts the usual spellings
    of true/false and raises for anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % value)


# Command-line interface for training a (robust) teacher network.
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet", "WideResNet", "MBV2"])
parser.add_argument("--teacher_model_name", type=str, default="WideResNet", choices=["WideResNet"])
parser.add_argument("--dataset", type=str, default="cifar100", choices=["cifar10", "cifar100", "imagenet"])
# FIX: was ``type=bool`` which treated any non-empty string (e.g. "False") as True.
parser.add_argument("--robust_mode", type=_str2bool, default=True)
# Adversarial perturbation hyper-parameters (PGD-style inner maximization).
parser.add_argument("--epsilon", type=float, default=0.031)
parser.add_argument("--num_steps", type=int, default=10)
parser.add_argument("--step_size", type=float, default=0.0078)
parser.add_argument("--clip_min", type=int, default=0)
parser.add_argument("--clip_max", type=int, default=1)
# FIX: same ``type=bool`` pitfall as --robust_mode.
parser.add_argument("--const_init", type=_str2bool, default=False)
parser.add_argument("--beta", type=float, default=6.0)
parser.add_argument("--distance", type=str, default="l_inf", choices=["l_inf", "l2"])
parser.add_argument("--train_criterion", type=str, default="trades", choices=["trades", "sat", "mart", "hat"])
parser.add_argument("--test_criterion", type=str, default="ce", choices=["ce"])
parser.add_argument("--kd_criterion", type=str, default="rslad", choices=["ard", "rslad", "adaad"])
parser.add_argument("--attack_type", type=str, default="linf-pgd", choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd', 'squar_attack', 'autoattack', 'apgd_ce'])
|
|
|
|
|
|
|
|
# Parse the CLI, then pin every remaining hyper-parameter directly onto the
# namespace so the run config can consume ``args.__dict__`` wholesale.
args = parser.parse_args()

# Experiment output directory depends on whether adversarial training is on.
if args.robust_mode:
    args.path = "/".join(["exp/robust/teacher", args.dataset, args.model_name, args.train_criterion])
else:
    args.path = "/".join(["exp/teacher", args.dataset, args.model_name])

# Optimization / training schedule constants (not exposed on the CLI).
for _name, _value in dict(
    n_epochs=120,
    base_lr=0.1,
    warmup_epochs=5,
    warmup_lr=-1,          # -1 means "derive from base_lr" later on
    manual_seed=0,
    lr_schedule_type="cosine",
    base_batch_size=128,
    valid_size=None,
    opt_type="sgd",
    momentum=0.9,
    no_nesterov=False,
    weight_decay=2e-4,
    label_smoothing=0.0,
    no_decay_keys="bn#bias",
    fp16_allreduce=False,
    model_init="he_fout",
    validation_frequency=1,
    print_frequency=10,
    n_worker=32,
).items():
    setattr(args, _name, _value)

# Input resolution, kept as a string here and split into ints in __main__.
args.image_size = "224" if args.dataset == "imagenet" else "32"

# Data-provider / normalization / architecture-width constants.
for _name, _value in dict(
    continuous_size=True,
    not_sync_distributed_image_size=False,
    bn_momentum=0.1,
    bn_eps=1e-5,
    dropout=0.0,
    base_stage_width="google",
).items():
    setattr(args, _name, _value)

# Elastic search-space dimensions (comma-separated strings, parsed later).
if args.model_name == "ResNet50":
    args.ks_list = "3"
    args.expand_list = "0.35"
    args.depth_list = "2"
else:
    args.ks_list = "7"
    args.expand_list = "6"
    args.depth_list = "4"
args.width_mult_list = "1.0"

# Dynamic-conv / distillation / distributed-run settings.
args.dy_conv_scaling_mode = 1
args.independent_distributed_sampling = False
args.kd_ratio = 0.0
args.kd_type = "ce"
args.dynamic_batch_size = 1
args.num_gpus = 4
|
if __name__ == "__main__": |
|
|
os.makedirs(args.path, exist_ok=True) |
|
|
|
|
|
|
|
|
hvd.init() |
|
|
|
|
|
torch.cuda.set_device(hvd.local_rank()) |
|
|
|
|
|
num_gpus = hvd.size() |
|
|
torch.manual_seed(args.manual_seed) |
|
|
torch.cuda.manual_seed_all(args.manual_seed) |
|
|
np.random.seed(args.manual_seed) |
|
|
random.seed(args.manual_seed) |
|
|
|
|
|
|
|
|
args.image_size = [int(img_size) for img_size in args.image_size.split(",")] |
|
|
if len(args.image_size) == 1: |
|
|
args.image_size = args.image_size[0] |
|
|
|
|
|
|
|
|
args.lr_schedule_param = None |
|
|
args.opt_param = { |
|
|
"momentum": args.momentum, |
|
|
"nesterov": not args.no_nesterov, |
|
|
} |
|
|
args.init_lr = args.base_lr * num_gpus |
|
|
if args.warmup_lr < 0: |
|
|
args.warmup_lr = args.base_lr |
|
|
args.train_batch_size = args.base_batch_size |
|
|
args.test_batch_size = args.base_batch_size |
|
|
print(args.__dict__) |
|
|
run_config = DistributedClassificationRunConfig( |
|
|
**args.__dict__,num_replicas=num_gpus, rank=hvd.rank() |
|
|
) |
|
|
|
|
|
|
|
|
if hvd.rank() == 0: |
|
|
print("Run config:") |
|
|
for k, v in run_config.config.items(): |
|
|
print("\t%s: %s" % (k, v)) |
|
|
|
|
|
if args.dy_conv_scaling_mode == -1: |
|
|
args.dy_conv_scaling_mode = None |
|
|
DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode |
|
|
|
|
|
|
|
|
args.width_mult_list = [ |
|
|
float(width_mult) for width_mult in args.width_mult_list.split(",") |
|
|
] |
|
|
args.ks_list = [int(ks) for ks in args.ks_list.split(",")] |
|
|
args.expand_list = [float(e) for e in args.expand_list.split(",")] |
|
|
args.depth_list = [int(d) for d in args.depth_list.split(",")] |
|
|
|
|
|
args.width_mult_list = ( |
|
|
args.width_mult_list[0] |
|
|
if len(args.width_mult_list) == 1 |
|
|
else args.width_mult_list |
|
|
) |
|
|
if args.model_name == "ResNet50": |
|
|
if args.dataset == "cifar10" or args.dataset == "cifar100": |
|
|
|
|
|
net = DYNResNets_Cifar( n_classes=run_config.data_provider.n_classes, |
|
|
bn_param=(args.bn_momentum, args.bn_eps), |
|
|
dropout_rate=args.dropout, |
|
|
depth_list=args.depth_list, |
|
|
expand_ratio_list=args.expand_list, |
|
|
width_mult_list=args.width_mult_list,) |
|
|
else: |
|
|
net = DYNResNets( n_classes=run_config.data_provider.n_classes, |
|
|
bn_param=(args.bn_momentum, args.bn_eps), |
|
|
dropout_rate=args.dropout, |
|
|
depth_list=args.depth_list, |
|
|
expand_ratio_list=args.expand_list, |
|
|
width_mult_list=args.width_mult_list,) |
|
|
elif args.model_name == "MBV3": |
|
|
if args.dataset == "cifar10" or args.dataset == "cifar100": |
|
|
net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), |
|
|
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list) |
|
|
else: |
|
|
net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), |
|
|
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list) |
|
|
elif args.model_name == "ProxylessNASNet": |
|
|
if args.dataset == "cifar10" or args.dataset == "cifar100": |
|
|
net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), |
|
|
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list) |
|
|
else: |
|
|
net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps), |
|
|
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list) |
|
|
|
|
|
elif args.model_name == "MBV2": |
|
|
if args.dataset == "cifar10" or args.dataset == "cifar100": |
|
|
net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),base_stage_width=args.base_stage_width, |
|
|
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list) |
|
|
else: |
|
|
net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),base_stage_width=args.base_stage_width, |
|
|
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
if args.teacher_model_name == "WideResNet": |
|
|
if args.dataset == "cifar10" or args.dataset == "cifar100": |
|
|
net = WideResNet(num_classes=run_config.data_provider.n_classes) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
else: |
|
|
raise NotImplementedError |
|
|
args.teacher_model = None |
|
|
|
|
|
""" Distributed RunManager """ |
|
|
|
|
|
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none |
|
|
distributed_run_manager = DistributedRunManager( |
|
|
args.path, |
|
|
net, |
|
|
run_config, |
|
|
compression, |
|
|
backward_steps=args.dynamic_batch_size, |
|
|
is_root=(hvd.rank() == 0), |
|
|
) |
|
|
distributed_run_manager.save_config() |
|
|
distributed_run_manager.broadcast() |
|
|
|
|
|
|
|
|
distributed_run_manager.train(args) |
|
|
|