Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| from methods.elasticdnn.model.base import ElasticDNNUtil | |
def test(raw_dnn: nn.Module, ignore_layers, elastic_dnn_util: ElasticDNNUtil, input_sample: torch.Tensor, sparsity):
    """Drive one raw -> master -> surrogate DNN pass through ``elastic_dnn_util``.

    The raw model is converted to a master DNN, ``sparsity`` is set on that
    master DNN, and a surrogate DNN is then extracted from it using
    ``input_sample``. Conversion and extraction both go through the util's
    ``*_with_perf_test`` methods. Nothing is returned; the extracted
    surrogate is discarded.

    Args:
        raw_dnn: Model handed to the conversion step.
        ignore_layers: Forwarded unchanged to the conversion step.
        elastic_dnn_util: Object providing the three pipeline calls below.
        input_sample: Samples handed to the surrogate extraction step.
        sparsity: Value handed to ``set_master_dnn_sparsity``.
    """
    # Debugging aid: to sanity-check a model, run it on ``input_sample`` under
    # ``eval()`` / ``torch.no_grad()`` before or after the conversion, or
    # print the master DNN right after it is built.

    # The literal 16 is passed straight through as the second argument;
    # presumably a reduction ratio -- TODO confirm against ElasticDNNUtil.
    master = elastic_dnn_util.convert_raw_dnn_to_master_dnn_with_perf_test(raw_dnn, 16, ignore_layers)
    elastic_dnn_util.set_master_dnn_sparsity(master, sparsity)

    # Only the call itself matters here; its result is not used.
    elastic_dnn_util.extract_surrogate_dnn_via_samples_with_perf_test(master, input_sample)
if __name__ == '__main__':
    # Manual smoke run of `test` on a ViT model. Requires CUDA, because the
    # input batch (and, in the sweep below, the model) is moved with .cuda().
    import sys

    from utils.dl.common.env import set_random_seed
    set_random_seed(1)

    # Disabled ResNet-50 variant, kept for reference.
    # NOTE: the `test(...)` call below passes no `sparsity` argument, which
    # `test` requires; add one before re-enabling this path.
    # from torchvision.models import resnet50
    # from methods.elasticdnn.model.cnn import ElasticCNNUtil
    # raw_cnn = resnet50()
    # prunable_layers = []
    # for i in range(1, 5):
    #     for j in range([3, 4, 6, 3][i - 1]):
    #         prunable_layers += [f'layer{i}.{j}.conv1', f'layer{i}.{j}.conv2']
    # ignore_layers = [layer for layer, m in raw_cnn.named_modules() if isinstance(m, nn.Conv2d) and layer not in prunable_layers]
    # test(raw_cnn, ignore_layers, ElasticCNNUtil(), torch.rand(1, 3, 224, 224))

    ignore_layers = []

    from methods.elasticdnn.model.vit import ElasticViTUtil

    # raw_vit = torch.load('tmp-master-dnn.pt')
    # NOTE(review): the checkpoint path is an empty string, which cannot be
    # opened as a file; fill in the real checkpoint path before running.
    raw_vit = torch.load('')
    test(raw_vit, ignore_layers, ElasticViTUtil(), torch.rand(16, 3, 224, 224).cuda(), 0.9)

    # Stop after the single checkpoint run. `sys.exit()` is used instead of
    # the `exit()` helper, which is installed by `site` for interactive use.
    sys.exit()

    # Everything below is unreachable while the early exit above is in place.
    # Remove that exit to run the sparsity sweep on freshly built ViT-B/16
    # models. `ElasticViTUtil` is already imported above.
    from dnns.vit import vit_b_16
    # from methods.elasticdnn.model.vit_new import ElasticViTUtil

    # raw_vit = vit_b_16()
    for s in [0.8, 0.9, 0.95]:
        # A new model is built for every sparsity level.
        raw_vit = vit_b_16().cuda()
        ignore_layers = []
        test(raw_vit, ignore_layers, ElasticViTUtil(), torch.rand(16, 3, 224, 224).cuda(), s)
| # for s in [0, 0.2, 0.4, 0.6, 0.8]: | |
| # pretrained_md_models_dict_path = 'experiments/elasticdnn/vit_b_16/offline/fm_to_md/results/20230518/999999-164524-wo_FBS_trial_dsnet_lr/models/md_best.pt' | |
| # raw_vit = torch.load(pretrained_md_models_dict_path)['main'].cuda() | |
| # ignore_layers = [] | |
| # test(raw_vit, ignore_layers, ElasticViTUtil(), torch.rand(16, 3, 224, 224).cuda(), s) | |
| # exit() | |
| # weight = torch.rand((10, 5)) | |
| # bias = torch.rand(10) | |
| # x = torch.rand((1, 3, 5)) | |
| # t = torch.randperm(5) | |
| # pruned, unpruned = t[0: 3], t[3: ] | |
| # mask = torch.ones_like(x) | |
| # mask[:, :, pruned] = 0 | |
| # print(x, x * mask, (x * mask).sum((0, 1))) | |
| # import torch.nn.functional as F | |
| # o1 = F.linear(x * mask, weight, bias) | |
| # # print(o1) | |
| # o2 = F.linear(x[:, :, unpruned], weight[:, unpruned], bias) | |
| # # print(o2) | |
| # print(o1.size(), o2.size(), ((o1 - o2) ** 2).sum()) | |
| # weight = torch.rand((130, 5)) | |
| # bias = torch.rand(130) | |
| # x = torch.rand((1, 3, 5)) | |
| # t = torch.randperm(5) | |
| # pruned, unpruned = t[0: 3], t[3: ] | |
| # mask = torch.ones_like(x) | |
| # mask[:, :, pruned] = 0 | |
| # print(x, x * mask, (x * mask).sum((0, 1))) | |
| # import torch.nn.functional as F | |
| # o1 = F.linear(x * mask, weight, bias) | |
| # # print(o1) | |
| # o2 = F.linear(x[:, :, unpruned], weight[:, unpruned], bias) | |
| # # print(o2) | |
| # print(o1.size(), o2.size(), ((o1 - o2) ** 2).sum()) | |
| # weight = torch.rand((1768, 768)) | |
| # bias = torch.rand(1768) | |
| # x = torch.rand([1, 197, 768]) | |
| # t = torch.randperm(768) | |
| # unpruned, pruned = t[0: 144], t[144: ] | |
| # unpruned = unpruned.sort()[0] | |
| # pruned = pruned.sort()[0] | |
| # mask = torch.ones_like(x) | |
| # mask[:, :, pruned] = 0 | |
| # print(x.sum((0, 1)).size(), (x * mask).sum((0, 1))[0: 10], x[:, :, unpruned].sum((0, 1))[0: 10]) | |
| # import torch.nn.functional as F | |
| # o1 = F.linear(x * mask, weight, bias) | |
| # o2 = F.linear(x[:, :, unpruned], weight[:, unpruned], bias) | |
| # print(o1.sum((0, 1))[0: 10], o2.sum((0, 1))[0: 10], o1.size(), o2.size(), ((o1 - o2).abs()).sum(), ((o1 - o2) ** 2).sum()) | |
| # unpruned_indexes = torch.randperm(5)[0: 2] | |
| # o2 = F.linear(x[:, unpruned_indexes], weight[:, unpruned_indexes]) | |
| # print(o2) |