Spaces:
Runtime error
Runtime error
| import time | |
| import pickle | |
| import logging | |
| import os | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from collections import OrderedDict | |
| from yaml import safe_dump | |
| from yacs.config import load_cfg, CfgNode#, _to_dict | |
| from maskrcnn_benchmark.config import cfg | |
| from maskrcnn_benchmark.engine.inference import _accumulate_predictions_from_multiple_gpus | |
| from maskrcnn_benchmark.modeling.backbone.nas import get_layer_name | |
| from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, get_world_size, all_gather | |
| from maskrcnn_benchmark.data.datasets.evaluation import evaluate | |
| from maskrcnn_benchmark.utils.flops import profile | |
def choice(x):
    """Return one element of ``x`` picked with ``np.random.randint``.

    A tuple is indexed directly; any other iterable is first converted to a
    tuple so that it can be indexed.
    """
    items = x if isinstance(x, tuple) else tuple(x)
    return items[np.random.randint(len(items))]
def gather_candidates(all_candidates):
    """Collect candidates through ``all_gather`` and return them de-duplicated.

    ``all_gather`` is expected to return one iterable of candidates per
    participant (presumably one per process -- confirm against
    ``maskrcnn_benchmark.utils.comm``). The gathered iterables are flattened
    and duplicates are removed via a ``set``, so candidates must be hashable
    and the order of the returned list is the set's iteration order.
    """
    flattened = []
    for per_rank in all_gather(all_candidates):
        flattened.extend(per_rank)
    return list(set(flattened))
def gather_stats(all_candidates):
    """Merge the statistics gathered through ``all_gather`` into one dict.

    Each gathered item is folded into the result with ``dict.update``; when
    the same key appears in more than one item, the value from the item that
    comes last in the gathered sequence wins.
    """
    merged = {}
    for per_rank_stats in all_gather(all_candidates):
        # later entries overwrite earlier ones on key collisions
        merged.update(per_rank_stats)
    return merged
def compute_on_dataset(model, rngs, data_loader, device=cfg.MODEL.DEVICE):
    """Run ``model`` in eval mode over ``data_loader`` and collect its outputs.

    Every batch must unpack into ``(images, targets, image_ids)``. The model
    is called as ``model(images.to(device), rngs=rngs)`` with gradients
    disabled, each output is moved to the CPU, and the outputs are stored
    under their matching image id.

    Returns:
        dict mapping image id -> model output for that image.
    """
    model.eval()
    cpu = torch.device("cpu")
    predictions = {}
    for images, targets, image_ids in data_loader:
        with torch.no_grad():
            outputs = model(images.to(device), rngs=rngs)
            outputs = [out.to(cpu) for out in outputs]
        predictions.update(zip(image_ids, outputs))
    return predictions
def bn_statistic(model, rngs, data_loader, device=cfg.MODEL.DEVICE, max_iter=500):
    """Reset running statistics and re-estimate them with forward passes.

    All buffers whose name contains ``running_mean`` are set to 0 and those
    whose name contains ``running_var`` are set to 1. The model is then put
    in train mode and called as ``model(images, targets, rngs)`` with
    gradients disabled until ``max_iter`` batches have been processed or the
    loader is exhausted. The outputs of these calls are discarded.

    Returns:
        The same ``model`` object that was passed in.
    """
    for buf_name, buf in model.named_buffers():
        if 'running_mean' in buf_name:
            nn.init.constant_(buf, 0)
        if 'running_var' in buf_name:
            nn.init.constant_(buf, 1)

    model.train()
    for step, batch in enumerate(data_loader, start=1):
        images, targets, _ = batch
        images = images.to(device)
        targets = [t.to(device) for t in targets]
        with torch.no_grad():
            # only the side effect on the running statistics matters here
            model(images, targets, rngs)
        if step >= max_iter:
            break
    return model
def inference(
    model,
    rngs,
    data_loader,
    iou_types=("bbox",),
    box_only=False,
    device="cuda",
    expected_results=(),
    expected_results_sigma_tol=4,
    output_folder=None,
):
    """Predict on ``data_loader`` and evaluate the predictions.

    Predictions are produced with ``compute_on_dataset``, all processes are
    synchronised, and the per-process predictions are combined with
    ``_accumulate_predictions_from_multiple_gpus``.

    Returns:
        ``None`` on processes that are not the main process; otherwise the
        value returned by ``evaluate`` for the loader's dataset.
    """
    torch_device = torch.device(device)
    dataset = data_loader.dataset

    predictions = compute_on_dataset(model, rngs, data_loader, torch_device)
    # barrier: every process must be done predicting before accumulation
    synchronize()
    predictions = _accumulate_predictions_from_multiple_gpus(predictions)

    if not is_main_process():
        return

    return evaluate(
        dataset=dataset,
        predictions=predictions,
        output_folder=output_folder,
        box_only=box_only,
        iou_types=iou_types,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )
def fitness(cfg, model, rngs, val_loaders):
    """Evaluate ``model`` (with path selection ``rngs``) on validation loaders.

    ``inference`` is run once per loader, followed by ``synchronize()``.
    Evaluation always includes ``"bbox"`` and additionally ``"segm"`` when
    ``cfg.MODEL.MASK_ON`` is truthy.

    Returns:
        The result of ``inference`` for the LAST loader only (results of
        earlier loaders are discarded), or ``None`` when ``val_loaders`` is
        empty. ``inference`` itself returns ``None`` on non-main processes.
    """
    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)

    # Defined up front so an empty ``val_loaders`` returns None instead of
    # raising UnboundLocalError at the ``return`` below.
    results = None
    for data_loader_val in val_loaders:
        results = inference(
            model,
            rngs,
            data_loader_val,
            iou_types=iou_types,
            box_only=False,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
        )
        synchronize()
    return results
class EvolutionTrainer(object):
    """Evolutionary search over the path choices of a supernet.

    A candidate is a tuple with one integer per entry of ``self.states``
    (read from the model's ``mix_nums``); entry ``i`` is drawn from
    ``range(self.states[i])``. Candidates are scored by ``eval_candidates``
    and the score is stored in ``self.vis_dict``; ``self.keep_top_k`` holds
    the best candidates seen so far, keyed by list size.
    """

    def __init__(self, cfg, model, flops_limit=None, is_distributed=True):
        """Read the search settings from ``cfg.SEARCH`` and snapshot the model.

        Args:
            cfg: config node providing ``OUTPUT_DIR`` and ``SEARCH.*``.
            model: the supernet; when ``is_distributed`` is true its
                attributes are reached through ``model.module``.
            flops_limit: optional upper bound (in units of 1e6 FLOPs as
                reported by ``profile``) enforced by ``legal``.
            is_distributed: whether ``model`` is wrapped (has ``.module``).
        """
        world_size = get_world_size()

        def per_process(total):
            # Integer ceiling division. The counts bound ``len(...) < count``
            # loops, so this yields the same effective number of candidates
            # as a fractional ``total / world_size`` while staying an int.
            return -(-total // world_size)

        self.log_dir = cfg.OUTPUT_DIR
        self.checkpoint_name = os.path.join(self.log_dir, 'evolution.pth')
        self.is_distributed = is_distributed
        self.states = model.module.mix_nums if is_distributed else model.mix_nums
        # Independent copy of the weights (pickle round-trip) so they can be
        # restored before each candidate is evaluated.
        self.supernet_state_dict = pickle.loads(pickle.dumps(model.state_dict()))
        self.flops_limit = flops_limit
        self.model = model
        self.candidates = []
        self.vis_dict = {}
        self.max_epochs = cfg.SEARCH.MAX_EPOCH
        self.select_num = cfg.SEARCH.SELECT_NUM
        self.population_num = per_process(cfg.SEARCH.POPULATION_NUM)
        self.mutation_num = per_process(cfg.SEARCH.MUTATION_NUM)
        self.crossover_num = per_process(cfg.SEARCH.CROSSOVER_NUM)
        # NOTE(review): the per-gene mutation probability is also divided by
        # the world size; kept as-is, but confirm this is intended.
        self.mutation_prob = cfg.SEARCH.MUTATION_PROB / get_world_size()
        self.keep_top_k = {self.select_num: [], 50: []}
        self.epoch = 0
        self.cfg = cfg

    def save_checkpoint(self):
        """Write the search state to ``self.checkpoint_name`` (main process only)."""
        if not is_main_process():
            return
        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(self.log_dir, exist_ok=True)
        info = {
            'candidates': self.candidates,
            'vis_dict': self.vis_dict,
            'keep_top_k': self.keep_top_k,
            'epoch': self.epoch,
        }
        torch.save(info, self.checkpoint_name)
        print('Save checkpoint to', self.checkpoint_name)

    def load_checkpoint(self):
        """Restore the search state from disk.

        Returns:
            ``False`` when the checkpoint file does not exist, else ``True``.
        """
        if not os.path.exists(self.checkpoint_name):
            return False
        info = torch.load(self.checkpoint_name)
        self.candidates = info['candidates']
        self.vis_dict = info['vis_dict']
        self.keep_top_k = info['keep_top_k']
        self.epoch = info['epoch']
        print('Load checkpoint from', self.checkpoint_name)
        return True

    def legal(self, cand):
        """Return whether ``cand`` is new and within the FLOPs budget.

        A candidate already present in ``self.vis_dict`` is rejected. When
        ``self.flops_limit`` is set, the backbone is profiled with the
        candidate's paths and rejected if it exceeds the limit.
        """
        assert isinstance(cand, tuple) and len(cand) == len(self.states)
        if cand in self.vis_dict:
            return False
        if self.flops_limit is not None:
            net = self.model.module.backbone if self.is_distributed else self.model.backbone
            inp = (1, 3, 224, 224)
            flops, params = profile(net, inp, extra_args={'paths': list(cand)})
            flops = flops / 1e6
            print('flops:', flops)
            if flops > self.flops_limit:
                return False
        return True

    def update_top_k(self, candidates, *, k, key, reverse=False):
        """Merge ``candidates`` into ``self.keep_top_k[k]`` and keep the first ``k``.

        The merged list is de-duplicated (first occurrence kept) before it is
        sorted with ``key``/``reverse``, so a candidate submitted twice cannot
        occupy two slots.
        """
        assert k in self.keep_top_k
        merged = list(dict.fromkeys(self.keep_top_k[k] + list(candidates)))
        merged.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = merged[:k]

    def eval_candidates(self, train_loader, val_loader):
        """Score every candidate in ``self.candidates``.

        For each candidate the saved supernet weights are reloaded, running
        statistics are recomputed with ``bn_statistic`` and the model is
        evaluated with ``fitness``. On the main process the bbox AP is stored
        in ``self.vis_dict[cand]``.
        """
        for cand in self.candidates:
            t0 = time.time()
            # start every candidate from the same supernet weights
            self.model.load_state_dict(self.supernet_state_dict)
            model = bn_statistic(self.model, list(cand), train_loader)
            evals = fitness(self.cfg, model, list(cand), val_loader)
            if is_main_process():
                acc = evals[0].results['bbox']['AP']
                self.vis_dict[cand] = acc
                print('candiate ', cand)
                print('time: {}s'.format(time.time() - t0))
                print('acc ', acc)

    def stack_random_cand(self, random_func, *, batchsize=10):
        """Endlessly yield candidates, generating them ``batchsize`` at a time."""
        while True:
            cands = [random_func() for _ in range(batchsize)]
            for cand in cands:
                yield cand

    def random_can(self, num):
        """Return ``num`` legal candidates sampled uniformly at random.

        NOTE(review): there is no attempt limit; this loops until enough
        legal candidates have been found.
        """
        candidates = []
        cand_iter = self.stack_random_cand(
            lambda: tuple(np.random.randint(i) for i in self.states))
        while len(candidates) < num:
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            candidates.append(cand)
        return candidates

    def get_mutation(self, k, mutation_num, m_prob):
        """Return up to ``mutation_num`` legal mutations of ``keep_top_k[k]`` members.

        Each position of a randomly chosen parent is re-sampled with
        probability ``m_prob``. At most ``10 * mutation_num`` candidates are
        tried, so fewer than ``mutation_num`` results may be returned.
        """
        assert k in self.keep_top_k
        if not self.keep_top_k[k]:
            # no parents to mutate (choosing from an empty pool would raise)
            return []
        res = []
        max_iters = mutation_num * 10

        def random_func():
            cand = list(choice(self.keep_top_k[k]))
            for i, num_choices in enumerate(self.states):
                if np.random.random_sample() < m_prob:
                    cand[i] = np.random.randint(num_choices)
            return tuple(cand)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < mutation_num and max_iters > 0:
            # Spend budget on every attempt, legal or not; otherwise the loop
            # never ends once no legal mutation can be produced.
            max_iters -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)
        return res

    def get_crossover(self, k, crossover_num):
        """Return up to ``crossover_num`` legal crossovers of ``keep_top_k[k]`` members.

        Every position of a child is taken from one of two randomly chosen
        parents. At most ``10 * crossover_num`` candidates are tried, so
        fewer than ``crossover_num`` results may be returned.
        """
        assert k in self.keep_top_k
        if not self.keep_top_k[k]:
            # no parents to cross (choosing from an empty pool would raise)
            return []
        res = []
        max_iters = 10 * crossover_num

        def random_func():
            p1 = choice(self.keep_top_k[k])
            p2 = choice(self.keep_top_k[k])
            return tuple(choice([i, j]) for i, j in zip(p1, p2))

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < crossover_num and max_iters > 0:
            # Spend budget on every attempt, legal or not (see get_mutation).
            max_iters -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)
        return res

    def train(self, train_loader, val_loader):
        """Run the search until ``self.epoch`` reaches ``self.max_epochs``.

        Resumes from the checkpoint when one exists, otherwise starts from a
        random population. Each epoch evaluates the current candidates,
        updates the top-k lists, builds the next population from mutation,
        crossover and random candidates, and saves a checkpoint.
        """
        logger = logging.getLogger("maskrcnn_benchmark.evolution")
        if not self.load_checkpoint():
            self.candidates = gather_candidates(self.random_can(self.population_num))

        while self.epoch < self.max_epochs:
            self.eval_candidates(train_loader, val_loader)
            self.vis_dict = gather_stats(self.vis_dict)

            # key is ``1 - score`` so that ascending sort puts the best first
            self.update_top_k(self.candidates, k=self.select_num, key=lambda x: 1 - self.vis_dict[x])
            self.update_top_k(self.candidates, k=50, key=lambda x: 1 - self.vis_dict[x])

            if is_main_process():
                logger.info('Epoch {} : top {} result'.format(self.epoch + 1, len(self.keep_top_k[self.select_num])))
                for i, cand in enumerate(self.keep_top_k[self.select_num]):
                    logger.info(' No.{} {} perf = {}'.format(i + 1, cand, self.vis_dict[cand]))

            mutation = gather_candidates(self.get_mutation(self.select_num, self.mutation_num, self.mutation_prob))
            crossover = gather_candidates(self.get_crossover(self.select_num, self.crossover_num))
            # NOTE(review): ``population_num`` is a per-process count while
            # ``mutation``/``crossover`` are already gathered; the difference
            # can be <= 0, in which case no random candidates are added.
            rand = gather_candidates(self.random_can(self.population_num - len(mutation) - len(crossover)))
            self.candidates = mutation + crossover + rand

            self.epoch += 1
            self.save_checkpoint()

    @staticmethod
    def _cfg_to_dict(node):
        """Recursively convert a dict-like config node into plain ``dict`` objects."""
        if not isinstance(node, dict):
            return node
        return {key: EvolutionTrainer._cfg_to_dict(value) for key, value in node.items()}

    def save_candidates(self, cand, template):
        """Export the ``cand``-th best architecture as a config and a weight file.

        Args:
            cand: 1-based rank into ``self.keep_top_k[self.select_num]``.
            template: path of the supernet YAML config to derive from.

        Writes ``<template basename>`` with ``.yaml`` replaced by
        ``_cand<cand>.yaml`` and ``init_cand<cand>.pth`` into
        ``self.cfg.OUTPUT_DIR``.
        """
        paths = self.keep_top_k[self.select_num][cand - 1]
        with open(template, "r") as f:
            super_cfg = load_cfg(f)

        layer_search = super_cfg.MODEL.BACKBONE.LAYER_SEARCH
        search_spaces = {mix_ops: layer_search[mix_ops] for mix_ops in layer_search}
        search_layers = super_cfg.MODEL.BACKBONE.LAYER_SETUP
        layer_setup = []
        for i, layer in enumerate(search_layers):
            name, setup = get_layer_name(layer, search_spaces)
            if not isinstance(name, list):
                name = [name]
            # keep only the operation selected for this layer
            name = name[paths[i]]
            layer_setup.append("('{}', {})".format(name, str(setup)[1:-1]))
        super_cfg.MODEL.BACKBONE.LAYER_SETUP = layer_setup

        cand_cfg = self._cfg_to_dict(super_cfg)
        del cand_cfg['MODEL']['BACKBONE']['LAYER_SEARCH']
        # rename only the file name, never a '.yaml' inside the directory path
        cfg_name = os.path.basename(template).replace('.yaml', '_cand{}.yaml'.format(cand))
        with open(os.path.join(self.cfg.OUTPUT_DIR, cfg_name), 'w') as f:
            f.write(safe_dump(cand_cfg))

        super_weight = self.supernet_state_dict
        cand_weight = OrderedDict()
        cand_keys = ['layers.{}.ops.{}'.format(i, c) for i, c in enumerate(paths)]
        for key, val in super_weight.items():
            if 'ops' in key:
                # NOTE(review): substring match; 'layers.1.ops.2' would also
                # match 'layers.1.ops.20' if a layer had that many ops.
                for ck in cand_keys:
                    if ck in key:
                        cand_weight[key.replace(ck, ck.split('.ops.')[0])] = val
            else:
                cand_weight[key] = val
        torch.save({'model': cand_weight}, os.path.join(self.cfg.OUTPUT_DIR, 'init_cand{}.pth'.format(cand)))