# ProArd / train_ofa_net.py
# (Hugging Face upload metadata — uploader smi08, "Upload 10 files",
#  commit f6a2150 — commented out so this file is valid Python.)
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import argparse
import numpy as np
import os
import random
# using for distributed training
import horovod.torch as hvd
import torch
from proard.classification.elastic_nn.modules.dynamic_op import (
DynamicSeparableConv2d,
)
from proard.classification.elastic_nn.networks import DYNMobileNetV3,DYNProxylessNASNets,DYNResNets,DYNProxylessNASNets_Cifar,DYNMobileNetV3_Cifar,DYNResNets_Cifar
from proard.classification.run_manager import DistributedClassificationRunConfig
from proard.classification.run_manager.distributed_run_manager import (
DistributedRunManager
)
from proard.utils import download_url, MyRandomResizedCrop
from proard.classification.elastic_nn.training.progressive_shrinking import load_models
def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a footgun: ``bool("False")`` is True because
    any non-empty string is truthy, so ``--robust_mode False`` silently kept
    robust mode on.  This converter accepts the usual true/false spellings
    and rejects anything else with a proper CLI error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--task",
    type=str,
    default="expand",
    choices=[
        "kernel",  # for architectures except ResNet
        "depth",
        "expand",
        "width",  # only for ResNet
    ],
)
parser.add_argument("--phase", type=int, default=2, choices=[1, 2])
parser.add_argument("--resume", action="store_true")
parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet", "MBV2"])
parser.add_argument("--dataset", type=str, default="cifar100", choices=["cifar10", "cifar100", "imagenet"])
# Fixed: was ``type=bool`` (made "--robust_mode False" evaluate to True).
parser.add_argument("--robust_mode", type=_str2bool, default=True)
# Adversarial-attack parameters (l_inf budget / PGD schedule).
parser.add_argument("--epsilon", type=float, default=0.031)
parser.add_argument("--num_steps", type=int, default=10)
parser.add_argument("--step_size", type=float, default=0.0078)
parser.add_argument("--clip_min", type=int, default=0)
parser.add_argument("--clip_max", type=int, default=1)
# Fixed: was ``type=bool`` as well.
parser.add_argument("--const_init", type=_str2bool, default=False)
parser.add_argument("--beta", type=float, default=6.0)  # TRADES trade-off weight
parser.add_argument("--distance", type=str, default="l_inf", choices=["l_inf", "l2"])
parser.add_argument("--train_criterion", type=str, default="trades", choices=["trades", "sat", "mart", "hat"])
parser.add_argument("--test_criterion", type=str, default="ce", choices=["ce"])
parser.add_argument("--kd_criterion", type=str, default="rslad", choices=["ard", "rslad", "adaad"])
parser.add_argument("--attack_type", type=str, default="linf-pgd", choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd', 'squar_attack', 'autoattack', 'apgd_ce'])
args = parser.parse_args()
# ------------------------------------------------------------------
# Per-task progressive-shrinking schedule.
# ResNet50 elasticizes width -> depth -> expand; the other backbones
# (MBV2/MBV3/ProxylessNASNet) elasticize kernel -> depth -> expand.
# Each later stage trains in two phases that progressively widen the
# elastic option lists.  All list-valued settings are comma-separated
# strings here; they are parsed into real lists in __main__.
# args.path is the experiment/checkpoint directory for this stage.
# ------------------------------------------------------------------
if args.model_name == "ResNet50":
    # ResNet blocks use a fixed 3x3 kernel, so ks_list is a singleton.
    args.ks_list = "3"
    if args.task == "width":
        # Stage 1: elastic width only.
        if args.robust_mode:
            args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2width"
        else:
            args.path = "exp/"+ args.dataset + '/' +args.model_name +'/' + args.train_criterion +"/normal2width"
        args.dynamic_batch_size = 1
        args.n_epochs = 120
        args.base_lr = 3e-2
        args.warmup_epochs = 5
        args.warmup_lr = -1  # -1 => replaced by base_lr in __main__
        args.width_mult_list = "0.65,0.8,1.0"
        args.expand_list = "0.35"
        args.depth_list = "2"
    elif args.task == "depth":
        # Stage 2: add elastic depth on top of elastic width.
        if args.robust_mode:
            args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase
        else:
            args.path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase
        args.dynamic_batch_size = 2
        if args.phase == 1:
            # Short warm-in phase with a small search space.
            args.n_epochs = 25
            args.base_lr = 2.5e-3
            args.warmup_epochs = 0
            args.warmup_lr = -1
            args.width_mult_list = "0.65,0.8,1.0"
            args.expand_list ="0.35"
            args.depth_list = "1,2"
        else:
            # Full phase with the complete depth options.
            args.n_epochs = 120
            args.base_lr = 7.5e-3
            args.warmup_epochs = 5
            args.warmup_lr = -1
            args.width_mult_list = "0.65,0.8,1.0"
            args.expand_list = "0.35"
            args.depth_list = "0,1,2"
    elif args.task == "expand":
        # Stage 3: add elastic expand ratio (full elastic space).
        if args.robust_mode :
            args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase
        else:
            args.path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase
        args.dynamic_batch_size = 4
        if args.phase == 1:
            args.n_epochs = 25
            args.base_lr = 2.5e-3
            args.warmup_epochs = 0
            args.warmup_lr = -1
            args.width_mult_list = "0.65,0.8,1.0"
            args.expand_list = "0.25,0.35"
            args.depth_list = "0,1,2"
        else:
            args.n_epochs = 120
            args.base_lr = 7.5e-3
            args.warmup_epochs = 5
            args.warmup_lr = -1
            args.width_mult_list = "0.65,0.8,1.0"
            args.expand_list = "0.2,0.25,0.35"
            args.depth_list = "0,1,2"
    else:
        # "kernel" is not a valid task for ResNet50.
        raise NotImplementedError
else:
    # Non-ResNet backbones keep a single width multiplier and
    # elasticize the kernel size instead.
    args.width_mult_list = "1.0"
    if args.task == "kernel":
        # Stage 1: elastic kernel size only.
        if args.robust_mode:
            args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel"
        else:
            args.path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel"
        args.dynamic_batch_size = 1
        args.n_epochs = 120
        args.base_lr = 3e-2
        args.warmup_epochs = 5
        args.warmup_lr = -1
        args.ks_list = "3,5,7"
        args.expand_list = "6"
        args.depth_list = "4"
    elif args.task == "depth":
        # Stage 2: add elastic depth.
        if args.robust_mode :
            args.path = "exp/robust/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase
        else:
            args.path = "exp/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase
        args.dynamic_batch_size = 2
        if args.phase == 1:
            args.n_epochs = 25
            args.base_lr = 2.5e-3
            args.warmup_epochs = 0
            args.warmup_lr = -1
            args.ks_list = "3,5,7"
            args.expand_list = "6"
            args.depth_list = "3,4"
        else:
            args.n_epochs = 120
            args.base_lr = 7.5e-3
            args.warmup_epochs = 5
            args.warmup_lr = -1
            args.ks_list = "3,5,7"
            args.expand_list = "6"
            args.depth_list = "2,3,4"
    elif args.task == "expand":
        # Stage 3: add elastic expand ratio (full elastic space).
        if args.robust_mode:
            args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase%d" % args.phase
        else:
            args.path = "exp/"+ args.dataset + '/' + args.model_name + '/' + args.train_criterion + "/kernel_depth2kernel_depth_width/phase%d" % args.phase
        args.dynamic_batch_size = 4
        if args.phase == 1:
            args.n_epochs = 25
            args.base_lr = 2.5e-3
            args.warmup_epochs = 0
            args.warmup_lr = -1
            args.ks_list = "3,5,7"
            args.expand_list = "4,6"
            args.depth_list = "2,3,4"
        else:
            args.n_epochs = 120
            args.base_lr = 7.5e-3
            args.warmup_epochs = 5
            args.warmup_lr = -1
            args.ks_list = "3,5,7"
            args.expand_list = "3,4,6"
            args.depth_list = "2,3,4"
    else:
        # "width" is ResNet-only; anything else is unsupported here.
        raise NotImplementedError
# ------------------------------------------------------------------
# Task-independent training settings shared by every stage/phase.
# Applied onto the parsed CLI namespace so downstream code sees a
# single flat `args` object.
# ------------------------------------------------------------------
_common_settings = {
    "manual_seed": 0,
    "lr_schedule_type": "cosine",
    "base_batch_size": 64,
    "valid_size": 64,
    # SGD with nesterov momentum; BN params and biases skip weight decay.
    "opt_type": "sgd",
    "momentum": 0.9,
    "no_nesterov": False,
    "weight_decay": 3e-5,
    "label_smoothing": 0.1,
    "no_decay_keys": "bn#bias",
    "fp16_allreduce": False,
    "model_init": "he_fout",
    "validation_frequency": 1,
    "print_frequency": 10,
    # Data-loading / augmentation knobs.
    "n_worker": 8,
    "resize_scale": 0.08,
    "distort_color": "tf",
    # ImageNet trains with elastic resolution; CIFAR is fixed at 32x32.
    "image_size": "128,160,192,224" if args.dataset == "imagenet" else "32",
    "continuous_size": True,
    "not_sync_distributed_image_size": False,
    # Network construction / distillation knobs.
    "bn_momentum": 0.1,
    "bn_eps": 1e-5,
    "dropout": 0.1,
    "base_stage_width": "google",
    "dy_conv_scaling_mode": 1,
    "independent_distributed_sampling": False,
    "kd_ratio": 1.0,
    "kd_type": "ce",
}
for _name, _value in _common_settings.items():
    setattr(args, _name, _value)
if __name__ == "__main__":
os.makedirs(args.path, exist_ok=True)
# Initialize Horovod
hvd.init()
# Pin GPU to be used to process local rank (one GPU per process)
torch.cuda.set_device(hvd.local_rank())
if args.robust_mode:
args.teacher_path = 'exp/robust/teacher/' + args.dataset + '/' + args.model_name + '/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
else:
args.teacher_path = 'exp/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
num_gpus = hvd.size()
torch.manual_seed(args.manual_seed)
torch.cuda.manual_seed_all(args.manual_seed)
np.random.seed(args.manual_seed)
random.seed(args.manual_seed)
# image size
args.image_size = [int(img_size) for img_size in args.image_size.split(",")]
if len(args.image_size) == 1:
args.image_size = args.image_size[0]
MyRandomResizedCrop.CONTINUOUS = args.continuous_size
MyRandomResizedCrop.SYNC_DISTRIBUTED = not args.not_sync_distributed_image_size
# build run config from args
args.lr_schedule_param = None
args.opt_param = {
"momentum": args.momentum,
"nesterov": not args.no_nesterov,
}
args.init_lr = args.base_lr * num_gpus # linearly rescale the learning rate
if args.warmup_lr < 0:
args.warmup_lr = args.base_lr
args.train_batch_size = args.base_batch_size
args.test_batch_size = args.base_batch_size * 4
run_config = DistributedClassificationRunConfig(
**args.__dict__, num_replicas=num_gpus, rank=hvd.rank()
)
# print run config information
if hvd.rank() == 0:
print("Run config:")
for k, v in run_config.config.items():
print("\t%s: %s" % (k, v))
if args.dy_conv_scaling_mode == -1:
args.dy_conv_scaling_mode = None
DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode
    # build net from args
    # Comma-separated CLI strings -> concrete lists of elastic options.
    args.width_mult_list = [
        float(width_mult) for width_mult in args.width_mult_list.split(",")
    ]
    args.ks_list = [int(ks) for ks in args.ks_list.split(",")]
    if args.model_name == "ResNet50":
        # ResNet expand ratios are fractional (e.g. 0.35); the others use
        # integer expansion factors (e.g. 3, 4, 6).
        args.expand_list = [float(e) for e in args.expand_list.split(",")]
    else:
        args.expand_list = [int(e) for e in args.expand_list.split(",")]
    args.depth_list = [int(d) for d in args.depth_list.split(",")]
    # A singleton width list collapses to a scalar for the constructors.
    args.width_mult_list = (
        args.width_mult_list[0]
        if len(args.width_mult_list) == 1
        else args.width_mult_list
    )
    # Instantiate the elastic super-network. *_Cifar variants are the
    # CIFAR-sized versions of the same architectures.
    if args.model_name == "ResNet50":
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            net = DYNResNets_Cifar( n_classes=run_config.data_provider.n_classes,
                bn_param=(args.bn_momentum, args.bn_eps),
                dropout_rate=args.dropout,
                depth_list=args.depth_list,
                expand_ratio_list=args.expand_list,
                width_mult_list=args.width_mult_list,)
        else:
            net = DYNResNets( n_classes=run_config.data_provider.n_classes,
                bn_param=(args.bn_momentum, args.bn_eps),
                dropout_rate=args.dropout,
                depth_list=args.depth_list,
                expand_ratio_list=args.expand_list,
                width_mult_list=args.width_mult_list,)
    elif args.model_name == "MBV3":
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
        else:
            net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
    elif args.model_name == "ProxylessNASNet":
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
        else:
            net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
    elif args.model_name == "MBV2":
        # MBV2 reuses the ProxylessNAS dynamic nets, differing only by the
        # base_stage_width preset passed through from args.
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)
        else:
            net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)
    else:
        raise NotImplementedError
    # teacher model
    # Build the largest static configuration of the chosen backbone as the
    # distillation teacher (its trained weights are loaded later from
    # args.teacher_path). Skipped entirely when kd_ratio == 0.
    if args.kd_ratio > 0:
        if args.model_name =="ResNet50":
            if args.dataset == "cifar10" or args.dataset == "cifar100":
                args.teacher_model = DYNResNets_Cifar(
                    n_classes=run_config.data_provider.n_classes,
                    bn_param=(args.bn_momentum, args.bn_eps),
                    dropout_rate=args.dropout,
                    depth_list=[2],
                    expand_ratio_list=[0.35],
                    width_mult_list=[1.0],
                )
            else:
                args.teacher_model = DYNResNets(
                    n_classes=run_config.data_provider.n_classes,
                    bn_param=(args.bn_momentum, args.bn_eps),
                    dropout_rate=args.dropout,
                    depth_list=[2],
                    expand_ratio_list=[0.35],
                    width_mult_list=[1.0],
                )
        elif args.model_name =="MBV3":
            # Mobile backbones use the max sub-net (k=7, e=6, d=4) as teacher,
            # with dropout disabled.
            if args.dataset == "cifar10" or args.dataset == "cifar100":
                args.teacher_model = DYNMobileNetV3_Cifar(
                    n_classes=run_config.data_provider.n_classes,
                    bn_param=(args.bn_momentum, args.bn_eps),
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[7],
                    expand_ratio_list=[6],
                    depth_list=[4]
                )
            else:
                args.teacher_model = DYNMobileNetV3(
                    n_classes=run_config.data_provider.n_classes,
                    bn_param=(args.bn_momentum, args.bn_eps),
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[7],
                    expand_ratio_list=[6],
                    depth_list=[4]
                )
        elif args.model_name == "ProxylessNASNet":
            if args.dataset == "cifar10" or args.dataset == "cifar100":
                args.teacher_model = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
                    bn_param=(args.bn_momentum, args.bn_eps),
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[7],
                    expand_ratio_list=[6],
                    depth_list=[4])
            else:
                args.teacher_model = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
                    bn_param=(args.bn_momentum, args.bn_eps),
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[7],
                    expand_ratio_list=[6],
                    depth_list=[4])
        elif args.model_name == "MBV2":
            # Same ProxylessNAS dynamic nets with the MBV2 stage-width preset.
            if args.dataset == "cifar10" or args.dataset == "cifar100":
                args.teacher_model = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
                    bn_param=(args.bn_momentum, args.bn_eps),
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[7],
                    expand_ratio_list=[6],
                    depth_list=[4],base_stage_width=args.base_stage_width)
            else:
                args.teacher_model = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
                    bn_param=(args.bn_momentum, args.bn_eps),
                    dropout_rate=0,
                    width_mult=1.0,
                    ks_list=[7],
                    expand_ratio_list=[6],
                    depth_list=[4],base_stage_width=args.base_stage_width)
        # Unknown model names cannot reach here: net construction above
        # already raised NotImplementedError.
        args.teacher_model.cuda()
""" Distributed RunManager """
# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
distributed_run_manager = DistributedRunManager(
args.path,
net,
run_config,
compression,
backward_steps=args.dynamic_batch_size,
is_root=(hvd.rank() == 0),
)
distributed_run_manager.save_config()
# hvd broadcast
distributed_run_manager.broadcast()
# load teacher net weights
if args.kd_ratio > 0:
load_models(
distributed_run_manager, args.teacher_model, model_path=args.teacher_path
)
# training
from proard.classification.elastic_nn.training.progressive_shrinking import (
validate,
train,
)
if args.model_name =="ResNet50":
validate_func_dict = {
"image_size_list": {224 if args.dataset == "imagenet" else 32}
if isinstance(args.image_size, int)
else sorted({160, 224}),
"width_mult_list": sorted({min(args.width_mult_list), max(args.width_mult_list)}),
"expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
"depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
}
else:
validate_func_dict = {
"image_size_list": {224 if args.dataset == "imagenet" else 32}
if isinstance(args.image_size, int)
else sorted({160, 224}),
"width_mult_list": [1.0],
"ks_list": sorted({min(args.ks_list), max(args.ks_list)}),
"expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
"depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
}
if args.task == "width":
from proard.classification.elastic_nn.training.progressive_shrinking import (
train_elastic_width_mult,
)
if distributed_run_manager.start_epoch == 0:
if args.robust_mode:
args.dyn_checkpoint_path ='exp/robust/teacher/' +args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path ='exp/teacher/' +args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
load_models(
distributed_run_manager,
distributed_run_manager.net,
args.dyn_checkpoint_path,
)
distributed_run_manager.write_log(
"%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
% validate(distributed_run_manager, is_test=True, **validate_func_dict),
"valid",
)
else:
assert args.resume
train_elastic_width_mult (train,distributed_run_manager,args,validate_func_dict)
elif args.task == "kernel":
validate_func_dict["ks_list"] = sorted(args.ks_list)
if distributed_run_manager.start_epoch == 0:
if args.robust_mode:
args.dyn_checkpoint_path ='exp/robust/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path ='exp/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
load_models(
distributed_run_manager,
distributed_run_manager.net,
args.dyn_checkpoint_path,
)
distributed_run_manager.write_log(
"%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
% validate(distributed_run_manager, is_test=True, **validate_func_dict),
"valid",
)
else:
assert args.resume
train(
distributed_run_manager,
args,
lambda _run_manager, epoch, is_test: validate(
_run_manager, epoch, is_test, **validate_func_dict
),
)
elif args.task == "depth":
from proard.classification.elastic_nn.training.progressive_shrinking import (
train_elastic_depth,
)
if args.robust_mode:
if args.model_name =="ResNet50":
if args.phase == 1:
args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar"
else:
if args.phase == 1:
args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"
else :
if args.model_name =="ResNet50":
if args.phase == 1:
args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar"
else:
if args.phase == 1:
args.dyn_checkpoint_path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"
train_elastic_depth(train, distributed_run_manager, args, validate_func_dict)
elif args.task == "expand":
from proard.classification.elastic_nn.training.progressive_shrinking import (
train_elastic_expand,
)
if args.robust_mode :
if args.model_name =="ResNet50":
if args.phase == 1:
args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
else:
if args.phase == 1:
args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
else:
if args.model_name =="ResNet50":
if args.phase == 1:
args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
else:
if args.phase == 1:
args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"
else:
args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
train_elastic_expand(train, distributed_run_manager, args, validate_func_dict)
else:
raise NotImplementedError