from typing import Any, Dict
from schema import Schema, Or
import schema
from data import Scenario, MergedDataset
from methods.base.alg import BaseAlg
from data import build_dataloader
from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel
from ...model.base import ElasticDNNUtil
import torch.optim
import tqdm
import torch.nn.functional as F
from torch import nn
from utils.dl.common.env import create_tbwriter
import os
import random
import numpy as np
from copy import deepcopy
from utils.dl.common.model import LayerActivation, get_module
from utils.common.log import logger


class ElasticDNN_MDPretrainingAlg(BaseAlg):
    """
    Construct indexes between each filter/row of the master DNN (MD) and all
    filters/rows of the foundation model (FM) in the same layer.

    Note: the resulting indexes are very large (~1GB), so training is slow and
    hard to optimize.
    """

    def get_required_models_schema(self) -> Schema:
        return Schema({
            'fm': ElasticDNN_OfflineFMModel,
            'md': ElasticDNN_OfflineMDModel
        })

    def get_required_hyp_schema(self) -> Schema:
        return Schema({
            'launch_tbboard': bool,
            'samples_size': (int, int, int, int),
            'generate_md_width_ratio': int,
            'FBS_r': int,
            'FBS_ignore_layers': [str],
            'train_batch_size': int,
            'val_batch_size': int,
            'num_workers': int,
            'optimizer': str,
            'md_optimizer_args': dict,
            'indexes_optimizer_args': dict,
            'scheduler': str,
            'scheduler_args': dict,
            'num_iters': int,
            'val_freq': int,
            'max_sparsity': float,
            'min_sparsity': float,
            'distill_loss_weight': float,
            'index_loss_weight': float,
            'val_num_sparsities': int,
            'bn_cal_num_iters': int,
            'index_guided_linear_comb_split_size': Or(int, None)
        })
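
    # For reference, a hypothetical hyps dict that would satisfy this schema.
    # All values below are illustrative placeholders, not settings taken from
    # the original experiments:
    #
    # hyps = {
    #     'launch_tbboard': False,
    #     'samples_size': (1, 3, 224, 224),
    #     'generate_md_width_ratio': 4,
    #     'FBS_r': 16,
    #     'FBS_ignore_layers': [],
    #     'train_batch_size': 64,
    #     'val_batch_size': 256,
    #     'num_workers': 8,
    #     'optimizer': 'AdamW',
    #     'md_optimizer_args': {'lr': 1e-4},
    #     'indexes_optimizer_args': {'lr': 3e-3},
    #     'scheduler': 'StepLR',
    #     'scheduler_args': {'step_size': 20000, 'gamma': 0.1},
    #     'num_iters': 80000,
    #     'val_freq': 4000,
    #     'max_sparsity': 0.9,
    #     'min_sparsity': 0.0,
    #     'distill_loss_weight': 1.0,
    #     'index_loss_weight': 1e-4,
    #     'val_num_sparsities': 4,
    #     'bn_cal_num_iters': 100,
    #     'index_guided_linear_comb_split_size': 128
    # }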

    def upsample_2d_tensor(self, p: torch.Tensor, target_len: int):
        assert p.dim() == 2  # regard 2d weight as (batch_size, 1d_vector_dim)
        # F.interpolate is the non-deprecated equivalent of F.upsample
        return F.interpolate(p.unsqueeze(1).unsqueeze(3),
                             size=(target_len, 1),
                             mode='bilinear').squeeze(3).squeeze(1)
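
    # Shape sketch (sizes are illustrative only): a flattened MD weight of shape
    # (num_rows, 768) is viewed as (num_rows, 1, 768, 1), bilinearly resized to
    # (num_rows, 1, 3072, 1), then squeezed back to (num_rows, 3072), so each MD
    # row can be compared element-wise against the wider FM row.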

    def two_params_diff_fast(self, trained_p: torch.Tensor, ref_p: torch.Tensor,
                             index: torch.Tensor,
                             split_size: int):
        assert trained_p.dim() == ref_p.dim()
        assert index.size(0) == trained_p.size(0) and index.size(1) == ref_p.size(0)

        ref_p = ref_p.detach()

        if trained_p.dim() > 1:
            trained_p = trained_p.flatten(1)
            ref_p = ref_p.flatten(1)

            # the weight sizes of the master DNN and the foundation model may be totally different
            # MD -> FM: upsample first
            # FM -> MD: downsample first
            if trained_p.size(1) < ref_p.size(1):
                trained_p = self.upsample_2d_tensor(trained_p, ref_p.size(1))

            index = index.unsqueeze(-1)
            # linear_combed_ref_p = (ref_p.unsqueeze(0) * index).sum(1)

        if split_size is None:
            # old version: huge memory consumption, not recommended (although this is fastest)
            linear_combed_ref_p = (ref_p.unsqueeze(0) * index).sum(1)
        else:
            # new version: accumulate the linear combination in chunks over the FM-row dimension
            linear_combed_ref_p = 0

            cur_split_size = split_size
            while index.size(1) % cur_split_size != 0:
                cur_split_size -= 1

            for i in range(0, index.size(1), cur_split_size):
                linear_combed_ref_p += ref_p.unsqueeze(0)[:, i: i + cur_split_size] * index[:, i: i + cur_split_size]
            linear_combed_ref_p = linear_combed_ref_p.sum(1)

        diff = (linear_combed_ref_p - trained_p).norm(2) ** 2
        return diff
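
    # Worked example (illustrative shapes only): for a 1-D parameter with
    # trained_p of shape (2,), ref_p of shape (3,) and index of shape (2, 3),
    # the split_size=None branch computes
    #   linear_combed_ref_p = index @ ref_p              # shape (2,)
    #   diff = ||index @ ref_p - trained_p||_2 ** 2
    # i.e. each MD row is regressed onto a learned linear combination of all FM
    # rows of the same layer. The split_size branch computes the same quantity
    # chunk-by-chunk over the FM-row dimension to bound peak memory.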

    def get_index_loss(self, fm, md, indexes, match_fn, split_size):
        res = 0.

        for name, p in md.named_parameters():
            if p.dim() == 0:
                continue

            raw_p = match_fn(name, fm)
            if raw_p is None:
                continue

            index = indexes[name]
            res += self.two_params_diff_fast(p, raw_p, index, split_size)

        return res
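
    # Note (describing the loop above): the index loss sums two_params_diff_fast
    # over every MD parameter that match_fn can map back to an FM parameter;
    # scalar parameters and parameters without a match are skipped, so they are
    # trained by the task/distillation losses only.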

    def bn_cal(self, model: nn.Module, train_loader, num_iters, device):
        has_bn = False
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                has_bn = True
                break

        if not has_bn:
            return {}

        def bn_calibration_init(m):
            """ calculating post-statistics of batch normalization """
            if getattr(m, 'track_running_stats', False):
                # reset all values for post-statistics
                m.reset_running_stats()
                # set bn in training mode to update post-statistics
                m.training = True

        with torch.no_grad():
            model.eval()
            model.apply(bn_calibration_init)
            for _ in range(num_iters):
                x, _ = next(train_loader)
                model(x.to(device))
            model.eval()

        bn_stats = {}
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                bn_stats[n] = m
        return bn_stats
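
    # Why this exists (summarizing the code above): changing the master DNN's
    # sparsity changes activation statistics, so before validating at a given
    # sparsity the BatchNorm running stats are re-estimated by forwarding
    # num_iters training batches with gradients disabled. Models without any
    # BatchNorm2d layer (e.g. pure-Transformer backbones) return {} immediately.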

    def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]:
        super().run(scenario, hyps)

        # sanity check of two_params_diff_fast():
        # a = torch.tensor([[1, 2, 3], [1, 2, 4]])
        # index = torch.tensor([[1, 2, 3],
        #                       [1, 2, 4]])
        # b = torch.tensor([[1, 2, 3], [1, 2, 4], [2, 3, 4]])
        # print(self.two_params_diff_fast(a, b, index, hyps['index_guided_linear_comb_split_size']))

        assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel)  # for auto completion
        assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel)  # for auto completion

        # 1. add FBS
        device = self.models['md'].device

        logger.info('init master DNN by reducing width of an adapted foundation model (already tuned by LoRA)...')
        before_fm_model = deepcopy(self.models['fm'].models_dict['main'])
        lora_util = self.models['fm'].get_lora_util()
        lora_absorbed_fm_model = lora_util.absorb_lora_and_recover_net_structure(self.models['fm'].models_dict['main'],
                                                                                 torch.rand(hyps['samples_size']).to(device))
        self.models['fm'].models_dict['main'] = lora_absorbed_fm_model
        master_dnn = self.models['fm'].generate_md_by_reducing_width(hyps['generate_md_width_ratio'],
                                                                     torch.rand(hyps['samples_size']).to(device))
        self.models['fm'].models_dict['main'] = before_fm_model

        elastic_dnn_util = self.models['fm'].get_elastic_dnn_util()
        master_dnn = elastic_dnn_util.convert_raw_dnn_to_master_dnn_with_perf_test(master_dnn,
                                                                                   hyps['FBS_r'], hyps['FBS_ignore_layers'])
        self.models['md'].models_dict['main'] = master_dnn
        self.models['md'].to(device)

        # 2. train (knowledge distillation, index relationship)
        offline_datasets = scenario.get_offline_datasets()
        train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])
        val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])
        train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'],
                                             True, None))
        val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'],
                                      False, False)

        # 2.1 train only FBS (skipped because the current MD cannot do proper inference yet)

        # 2.2 train whole master DNN (knowledge distillation, index relationship)
        for p in master_dnn.parameters():
            p.requires_grad = True
        self.models['md'].to_train_mode()

        indexes = {}
        for name, p in self.models['md'].models_dict['main'].named_parameters():
            if p.dim() > 0:
                matched_p_in_fm = self.models['md'].get_matched_param_of_fm(name, self.models['fm'].models_dict['main'])
                if matched_p_in_fm is None:
                    continue
                indexes[name] = torch.zeros((p.size(0), matched_p_in_fm.size(0))).to(device)
                indexes[name].requires_grad = True

        tmp_indexes_file_path = os.path.join(self.res_save_dir, 'tmp-indexes.pt')
        torch.save(indexes, tmp_indexes_file_path)
        logger.info(f'generate indexes ({(os.path.getsize(tmp_indexes_file_path) / 1024**2):.3f}MB)')
        os.remove(tmp_indexes_file_path)
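
        # Illustrative example (actual sizes depend on the FM/MD pair): if an MD
        # linear weight has 768 output rows and its matched FM weight has 3072,
        # the corresponding entry in `indexes` is a trainable, zero-initialized
        # (768, 3072) matrix; the temporary save above only reports the total
        # size of these matrices.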

        optimizer = torch.optim.__dict__[hyps['optimizer']]([
            {'params': self.models['md'].models_dict['main'].parameters(), **hyps['md_optimizer_args']},
            {'params': [v for v in indexes.values()], **hyps['indexes_optimizer_args']}
        ])
        scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])
        tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])
        pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)
        best_avg_val_acc = 0.

        md_output_hook = None

        for iter_index in pbar:
            self.models['md'].to_train_mode()
            self.models['fm'].to_eval_mode()

            rand_sparsity = random.random() * (hyps['max_sparsity'] - hyps['min_sparsity']) + hyps['min_sparsity']
            elastic_dnn_util.set_master_dnn_sparsity(self.models['md'].models_dict['main'], rand_sparsity)

            x, y = next(train_loader)
            x, y = x.to(device), y.to(device)

            with torch.no_grad():
                fm_output = self.models['fm'].infer(x)

            if md_output_hook is None:
                md_output_hook = LayerActivation(self.models['md'].models_dict['main'], False, device)
            task_loss = self.models['md'].forward_to_get_task_loss(x, y)
            md_output = md_output_hook.output
            distill_loss = hyps['distill_loss_weight'] * self.models['md'].get_distill_loss(md_output, fm_output)
            index_loss = hyps['index_loss_weight'] * self.get_index_loss(self.models['fm'].models_dict['main'],
                                                                         self.models['md'].models_dict['main'],
                                                                         indexes,
                                                                         self.models['md'].get_matched_param_of_fm,
                                                                         hyps['index_guided_linear_comb_split_size'])
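
            # Per-iteration objective: task loss on the labeled batch, plus
            # feature distillation against the frozen FM output (scaled by
            # distill_loss_weight above), plus index regression between MD
            # weights and index-weighted combinations of FM weights (scaled by
            # index_loss_weight above).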
            total_loss = task_loss + distill_loss + index_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            scheduler.step()

            if (iter_index + 1) % hyps['val_freq'] == 0:
                elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main'])
                md_output_hook.remove()
                md_output_hook = None

                cur_md = self.models['md'].models_dict['main']
                md_for_test = deepcopy(self.models['md'].models_dict['main'])
                val_accs = {}
                avg_val_acc = 0.
                bn_stats = {}

                for val_sparsity in np.linspace(hyps['min_sparsity'], hyps['max_sparsity'], num=hyps['val_num_sparsities']):
                    elastic_dnn_util.set_master_dnn_sparsity(md_for_test, val_sparsity)
                    bn_stats[f'{val_sparsity:.4f}'] = self.bn_cal(md_for_test, train_loader, hyps['bn_cal_num_iters'], device)
                    self.models['md'].models_dict['main'] = md_for_test
                    self.models['md'].to_eval_mode()
                    val_acc = self.models['md'].get_accuracy(val_loader)

                    val_accs[f'{val_sparsity:.4f}'] = val_acc
                    avg_val_acc += val_acc
                avg_val_acc /= hyps['val_num_sparsities']

                self.models['md'].models_dict['main'] = cur_md
                self.models['md'].models_dict['indexes'] = indexes
                self.models['md'].models_dict['bn_stats'] = bn_stats
                self.models['fm'].models_dict['indexes'] = indexes

                self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt'))
                self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))

                if avg_val_acc > best_avg_val_acc:
                    best_avg_val_acc = avg_val_acc
                    self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt'))
                    self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))

            tb_writer.add_scalars('losses', dict(task=task_loss, distill=distill_loss, index=index_loss, total=total_loss), iter_index)
            pbar.set_description(f'loss: {total_loss:.6f}')
            if (iter_index + 1) >= hyps['val_freq']:
                tb_writer.add_scalars('accs/val_accs', val_accs, iter_index)
                tb_writer.add_scalar('accs/avg_val_acc', avg_val_acc, iter_index)
                pbar.set_description(f'loss: {total_loss:.6f}, avg_val_acc: {avg_val_acc:.4f}')