import torch
from torch import nn
from copy import deepcopy
from typing import Optional, Tuple

from .base import FM_to_MD_Util
from utils.common.log import logger
from utils.dl.common.model import set_module, get_module, get_super_module
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size

from transformers.models.clip.modeling_clip import CLIPAttention
from transformers import CLIPVisionConfig


class CLIPAttentionPrunable(CLIPAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    def __init__(self):
        # NOTE: the config is hard-coded to CLIP ViT-B/16, matching the foundation model
        # handled by FM_to_MD_CLIP_Util below.
        config = CLIPVisionConfig.from_pretrained('openai/clip-vit-base-patch16')
        super(CLIPAttentionPrunable, self).__init__(config)

    # def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    #     # print(tensor.size(), self.num_heads, self.head_dim, bsz)  # torch.Size([1, 197, 192]) 8 64 1
    #     # head_dim should be modified
    #     # 'b n (h d) -> b h n d', h = self.num_heads
    #     if seq_len == -1:
    #         seq_len = tensor.size(1)
    #     # print(tensor.size(), bsz, seq_len, self.num_heads, -1)
    #     return tensor.view(bsz, seq_len, self.num_heads, -1).transpose(1, 2).contiguous()

    # def forward(
    #     self,
    #     hidden_states: torch.Tensor,
    #     attention_mask: Optional[torch.Tensor] = None,
    #     causal_attention_mask: Optional[torch.Tensor] = None,
    #     output_attentions: Optional[bool] = False,
    # ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    #     """Input shape: Batch x Time x Channel"""
    #     bsz, tgt_len, embed_dim = hidden_states.size()

    #     # get query proj
    #     query_states = self.q_proj(hidden_states) * self.scale
    #     key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
    #     value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

    #     proj_shape = (-1, tgt_len, self.head_dim)
    #     # print(proj_shape, key_states.size())
    #     query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
    #     key_states = key_states.view(*proj_shape)
    #     value_states = value_states.view(*proj_shape)

    #     src_len = key_states.size(1)
    #     attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    #     # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
    #     #     raise ValueError(
    #     #         f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
    #     #         f" {attn_weights.size()}"
    #     #     )

    #     # apply the causal_attention_mask first
    #     if causal_attention_mask is not None:
    #         if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
    #             raise ValueError(
    #                 f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
    #                 f" {causal_attention_mask.size()}"
    #             )
    #         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
    #         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    #     if attention_mask is not None:
    #         if attention_mask.size() != (bsz, 1, tgt_len, src_len):
    #             raise ValueError(
    #                 f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
    #             )
    #         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
    #         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    #     attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    #     if output_attentions:
    #         # this operation is a bit akward, but it's required to
    #         # make sure that attn_weights keeps its gradient.
    #         # In order to do so, attn_weights have to reshaped
    #         # twice and have to be reused in the following
    #         attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
    #         attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
    #     else:
    #         attn_weights_reshaped = None

    #     attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    #     attn_output = torch.bmm(attn_probs, value_states)

    #     # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
    #     #     raise ValueError(
    #     #         f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
    #     #         f" {attn_output.size()}"
    #     #     )

    #     attn_output = attn_output.view(bsz, self.num_heads, tgt_len, -1)
    #     attn_output = attn_output.transpose(1, 2)
    #     attn_output = attn_output.reshape(bsz, tgt_len, -1)

    #     attn_output = self.out_proj(attn_output)

    #     return attn_output, attn_weights_reshaped

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _shape_dynamic_head_dim(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # head_dim is inferred (-1), so a pruned per-head dimension is tolerated
        return tensor.view(bsz, seq_len, self.num_heads, -1).transpose(1, 2).contiguous()

    def _shape_dynamic_num_head(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # num_heads is inferred (-1), so a pruned number of heads is tolerated
        return tensor.view(bsz, seq_len, -1, self.head_dim).transpose(1, 2).contiguous()
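
    # The forward below follows the reference CLIPAttention implementation but reshapes
    # q/k/v with a dynamic head_dim. Rough shape walk-through for CLIP ViT-B/16 at 224x224
    # (values taken from the debug comments below; treat them as illustrative):
    #   hidden_states: (B, 197, 768) -> q/k/v: (B * 12, 197, head_dim')
    #   attn_weights:  (B * 12, 197, 197)
    #   attn_output:   (B, 197, 12 * head_dim') -> out_proj
    # where head_dim' may be smaller than 64 after width reduction.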

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()
        # logger.info(f'hidden state size: {hidden_states.size()}')  # (64, 197, 768)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape_dynamic_head_dim(self.k_proj(hidden_states), tgt_len, bsz)
        value_states = self._shape_dynamic_head_dim(self.v_proj(hidden_states), tgt_len, bsz)
        # (64, 197, 768), num_heads: 12, head_dim: 64, seq_len: 197
        # logger.info(f'key states: {self.k_proj(hidden_states).size()}, bsz: {bsz}, num_heads: {self.num_heads}, head_dim: {self.head_dim}, '
        #             f'seq_len: {self.k_proj(hidden_states).numel() / bsz / self.num_heads / self.head_dim}')
        # (64, 197, 768), num_heads: 12, head_dim: 64, seq_len: 197
        # logger.info(f'value states: {self.v_proj(hidden_states).size()}, bsz: {bsz}, num_heads: {self.num_heads}, head_dim: {self.head_dim}, '
        #             f'seq_len: {self.v_proj(hidden_states).numel() / bsz / self.num_heads / self.head_dim}')

        proj_shape = (bsz * self.num_heads, tgt_len, -1)
        query_states = self._shape_dynamic_head_dim(query_states, tgt_len, bsz).view(*proj_shape)
        # (64, 12, 197, 64), -1 means 197
        # logger.info(f'query states: {self._shape(query_states, tgt_len, bsz).size()}, '
        #             f'-1 in proj_shape: {self._shape(query_states, tgt_len, bsz).numel() / bsz / self.num_heads / self.head_dim}')
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights has to be reshaped
            # twice and has to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        #     raise ValueError(
        #         f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
        #         f" {attn_output.size()}"
        #     )
        # print(attn_output.size(), bsz, tgt_len, embed_dim)

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, -1)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, -1)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

    # reduce num_head
    # def forward(
    #     self,
    #     hidden_states: torch.Tensor,
    #     attention_mask: Optional[torch.Tensor] = None,
    #     causal_attention_mask: Optional[torch.Tensor] = None,
    #     output_attentions: Optional[bool] = False,
    # ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    #     """Input shape: Batch x Time x Channel"""
    #     bsz, tgt_len, embed_dim = hidden_states.size()
    #     # logger.info(f'hidden state size: {hidden_states.size()}')  # (64, 197, 768)

    #     # get query proj
    #     query_states = self.q_proj(hidden_states) * self.scale
    #     key_states = self._shape_dynamic_num_head(self.k_proj(hidden_states), tgt_len, bsz)
    #     value_states = self._shape_dynamic_num_head(self.v_proj(hidden_states), tgt_len, bsz)
    #     # (64, 197, 768), numhead: 12, head_dim: 64, seq_len: 197
    #     # logger.info(f'key states: {self.k_proj(hidden_states).size()}, bsz: {bsz}, num_heads: {self.num_heads}, head_dim: {self.head_dim}, '
    #     #             f'seq_len: {self.k_proj(hidden_states).numel() / bsz / self.num_heads / self.head_dim}')
    #     # (64, 197, 768), numhead: 12, head_dim: 64, seq_len: 197
    #     # logger.info(f'value states: {self.v_proj(hidden_states).size()}, bsz: {bsz}, num_heads: {self.num_heads}, head_dim: {self.head_dim}, '
    #     #             f'seq_len: {self.v_proj(hidden_states).numel() / bsz / self.num_heads / self.head_dim}')

    #     proj_shape = (-1, tgt_len, self.head_dim)
    #     query_states = self._shape_dynamic_head_dim(query_states, tgt_len, bsz).view(*proj_shape)
    #     # (64, 12, 197, 64), -1 means 197
    #     # logger.info(f'query states: {self._shape(query_states, tgt_len, bsz).size()}, '
    #     #             f'-1 in proj_shape: {self._shape(query_states, tgt_len, bsz).numel() / bsz / self.num_heads / self.head_dim}')
    #     key_states = key_states.view(*proj_shape)
    #     value_states = value_states.view(*proj_shape)

    #     src_len = key_states.size(1)
    #     attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    #     # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
    #     #     raise ValueError(
    #     #         f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
    #     #         f" {attn_weights.size()}"
    #     #     )

    #     # apply the causal_attention_mask first
    #     if causal_attention_mask is not None:
    #         if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
    #             raise ValueError(
    #                 f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
    #                 f" {causal_attention_mask.size()}"
    #             )
    #         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
    #         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    #     if attention_mask is not None:
    #         if attention_mask.size() != (bsz, 1, tgt_len, src_len):
    #             raise ValueError(
    #                 f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
    #             )
    #         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
    #         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    #     attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    #     if output_attentions:
    #         # this operation is a bit akward, but it's required to
    #         # make sure that attn_weights keeps its gradient.
    #         # In order to do so, attn_weights have to reshaped
    #         # twice and have to be reused in the following
    #         attn_weights_reshaped = attn_weights.view(bsz, -1, tgt_len, src_len)
    #         attn_weights = attn_weights_reshaped.view(-1, tgt_len, src_len)
    #     else:
    #         attn_weights_reshaped = None

    #     attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    #     attn_output = torch.bmm(attn_probs, value_states)

    #     # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
    #     #     raise ValueError(
    #     #         f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
    #     #         f" {attn_output.size()}"
    #     #     )
    #     # print(attn_output.size(), bsz, tgt_len, embed_dim)

    #     attn_output = attn_output.view(bsz, -1, tgt_len, self.head_dim)
    #     attn_output = attn_output.transpose(1, 2)
    #     attn_output = attn_output.reshape(bsz, tgt_len, -1)

    #     attn_output = self.out_proj(attn_output)

    #     return attn_output, attn_weights_reshaped

    @staticmethod
    def init_from_exist_self_attn(attn: CLIPAttention):
        # print(attn)

        res = CLIPAttentionPrunable()

        for attr in dir(attn):
            # if str(attr) in ['transpose_for_scores'] or str(attr).startswith('_'):
            #     continue
            # if isinstance(getattr(attn, attr), nn.Module):
            #     print(attr)

            if isinstance(getattr(attn, attr), nn.Module):
                try:
                    # print(attr, 'ok')
                    setattr(res, attr, getattr(attn, attr))
                except Exception as e:
                    print(attr, str(e))

        return res
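
    # A minimal usage sketch (mirroring the commented-out call in
    # FM_to_MD_CLIP_Util.init_md_from_fm_by_reducing_width below; illustrative only,
    # not the active code path):
    #   for block in fm_vit.model.vision_model.encoder.layers:
    #       set_module(block, 'self_attn', CLIPAttentionPrunable.init_from_exist_self_attn(block.self_attn))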


from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class PrunableAttention(nn.Module):
    """
    https://github.com/lucidrains/vit-pytorch
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., qkv_bias=False):
        super().__init__()
        self.inner_dim = inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.num_heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)

        # self.proj = nn.Sequential(
        #     nn.Linear(inner_dim, dim),
        #     nn.Dropout(dropout)
        # ) if project_out else nn.Identity()
        self.proj = nn.Linear(inner_dim, dim) if project_out else nn.Identity()
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                causal_attention_mask: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = False,):
        x = hidden_states
        assert attention_mask is None
        assert causal_attention_mask is None
        assert not output_attentions

        # qkv = self.qkv(x).chunk(3, dim=-1)
        raw_qkv = self.qkv(x)
        # q/k may be pruned to a width different from v's; recover the (possibly pruned) q/k
        # width from the total qkv width and the v width (which must equal proj.in_features)
        self.inner_dim = (raw_qkv.size(-1) - self.proj.in_features) // 2
        qkv = raw_qkv[:, :, 0: self.inner_dim], raw_qkv[:, :, self.inner_dim: self.inner_dim * 2], raw_qkv[:, :, self.inner_dim * 2:]

        # print('v', qkv[0].size(), qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size())
        # raw_v = qkv[2]
        # print('after_fbs_q, after_fbs_k', qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size(),
        #       qkv[1].sum((0, 1))[0: 10], qkv[1].sum((0, 1)).nonzero(as_tuple=True)[0].size(),)
        # print('after_fbs_v', raw_v.size(), raw_v.sum((0, 1))[0: 10], raw_v.sum((0, 1)).nonzero(as_tuple=True)[0].size())
        # print('q, before rearrange', qkv[0].size())

        q, k, v = qkv
        # print('raw qkv size', q.size(), k.size(), v.size())
        # exit()
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), qkv)
        # print('raw qkv size', q.size(), k.size(), v.size())

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # print('q, k, dots, after rearrange', q.size(), k.transpose(-1, -2).size(), dots.size())

        attn = self.attend(dots)
        # attn = dots
        attn = self.dropout(attn)
        # print(attn)
        # print('attn', attn.size(), attn.sum((0, 1))[0: 10], attn.sum((0, 1)).nonzero(as_tuple=True)[0].size())
        # print('v2', v.size())

        out = torch.matmul(attn, v)
        # print('out1', out.size())

        # NOTE: just for trial debug
        # out = v
        # print('out before rearrange', out.size())
        # print(v.size(), v)
        # exit()

        out = rearrange(out, 'b h n d -> b n (h d)')
        # print('out', out.size(), out.sum((0, 1))[0: 10], out.sum((0, 1)).nonzero(as_tuple=True)[0].size())
        # exit()

        res = self.proj_dropout(self.proj(out))
        # res = self.proj_dropout(
        #     F.linear(self.proj.weight.T, out.T, self.proj.bias)
        # )
        # print(self.proj, self.proj_dropout)
        # print('res', res.size(), res.sum((0, 1))[0: 10], res.sum((0, 1)).nonzero(as_tuple=True)[0].size())

        return res, None
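

# A minimal sanity-check sketch for PrunableAttention (hypothetical, assuming the ViT-B/16
# geometry used below: dim=768, 12 heads of size 64, 197 tokens):
#   attn = PrunableAttention(dim=768, heads=12, dim_head=64, dropout=0., qkv_bias=True)
#   x = torch.rand(2, 197, 768)
#   out, _ = attn(x)
#   assert out.shape == (2, 197, 768)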


class FM_to_MD_CLIP_Util(FM_to_MD_Util):
    def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
        fm_vit = deepcopy(fm)

        # for block in fm_vit.model.text_model.encoder.layers:
        #     set_module(block, 'self_attn', CLIPAttentionPrunable.init_from_exist_self_attn(block.self_attn))

        # sanity check: replacing self-attention should not change the vision tower's output
        # (224x224 matches the 197-token sequence of CLIP ViT-B/16 noted in the comments above)
        debug_input = torch.rand((1, 3, 224, 224)).cuda()
        fm.eval()
        fm_vit.eval()
        o1 = fm.model.vision_model(debug_input).pooler_output

        for block in fm_vit.model.vision_model.encoder.layers:
            # set_module(block, 'self_attn', CLIPAttentionPrunable.init_from_exist_self_attn(block.self_attn))
            attn: CLIPAttention = block.self_attn
            # from dnns.vit import PrunableAttention
            new_attn = PrunableAttention(
                dim=768,
                heads=12,
                dim_head=64,
                dropout=0,
                qkv_bias=True
            )
            new_attn.qkv.weight.data.copy_(torch.cat([
                attn.q_proj.weight,
                attn.k_proj.weight,
                attn.v_proj.weight
            ], dim=0))
            new_attn.qkv.bias.data.copy_(torch.cat([
                attn.q_proj.bias,
                attn.k_proj.bias,
                attn.v_proj.bias
            ], dim=0))
            new_attn.proj.weight.data.copy_(attn.out_proj.weight)
            new_attn.proj.bias.data.copy_(attn.out_proj.bias)
            set_module(block, 'self_attn', new_attn)

        # NOTE kept from the original code: the diff used to be ZERO only because `fm` was
        # compared with itself (fm_vit holds the replaced attention); the logic of
        # CLIPAttentionPrunable was flagged as incorrect, which is why PrunableAttention is used.
        o2 = fm_vit.model.vision_model(debug_input).pooler_output
        diff = ((o1 - o2) ** 2).sum()
        print('diff before/after replacing self_attn', diff)
        assert diff < 1e-4

        # print('\n\nDEBUG: WITHOUT ADDING CLIPAttentionPrunable\n\n')
        # exit()
        # return fm

        def _f(n):
            return int(n // reducing_width_ratio)
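
        # e.g. with reducing_width_ratio == 4, _f maps the 768-dim hidden size to 192 and
        # the 3072-dim MLP width to 768 (illustrative numbers for CLIP ViT-B/16)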

        # def _rand_indexes(n):
        #     return torch.randperm(n)[0: int(n // reducing_width_ratio)]

        def l1_max_indexes(p: torch.Tensor, dim=0):
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]

            if dim == 1:
                p = p.T
            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            res = p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]
            # print(res)
            return res
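
        # l1_max_indexes keeps the rows (dim=0) or columns (dim=1) of a weight tensor with
        # the largest L1 norm. Hypothetical example: for a (768, 768) weight and
        # reducing_width_ratio == 4, it returns the 192 row indexes with the largest
        # absolute-sum, sorted ascending so the kept weights stay in their original order.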

        # first_attn = True

        # for block_i, block in enumerate(fm_vit.model.text_model.encoder.layers):
        #     for k in ['k_proj', 'q_proj', 'v_proj']:
        #         qkv = get_module(block, f'self_attn.{k}')

        #         new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
        #                             qkv.bias is not None, qkv.weight.device)
        #         indexes = l1_max_indexes(qkv.weight.data, 0)

        #         new_qkv.weight.data.copy_(qkv.weight.data[indexes])
        #         if qkv.bias is not None:
        #             new_qkv.bias.data.copy_(qkv.bias.data[indexes])
        #         set_module(block, f'self_attn.{k}', new_qkv)

        #     proj = block.self_attn.out_proj
        #     new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
        #                          proj.bias is not None, proj.weight.device)
        #     new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
        #     if proj.bias is not None:
        #         new_proj.bias.data.copy_(proj.bias.data)
        #     set_module(block, f'self_attn.out_proj', new_proj)

        #     fc1 = block.mlp.fc1
        #     new_fc1 = nn.Linear(fc1.in_features, _f(fc1.out_features),
        #                         fc1.bias is not None, fc1.weight.device)
        #     indexes = l1_max_indexes(fc1.weight.data, 0)
        #     new_fc1.weight.data.copy_(fc1.weight.data[indexes])
        #     if fc1.bias is not None:
        #         new_fc1.bias.data.copy_(fc1.bias.data[indexes])
        #     set_module(block, f'mlp.fc1', new_fc1)

        #     fc2 = block.mlp.fc2
        #     new_fc2 = nn.Linear(_f(fc2.in_features), fc2.out_features,
        #                         fc2.bias is not None, fc2.weight.device)
        #     new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes(fc2.weight.data, 1)])
        #     if fc2.bias is not None:
        #         new_fc2.bias.data.copy_(fc2.bias.data)
        #     set_module(block, f'mlp.fc2', new_fc2)

        for block_i, block in enumerate(fm_vit.model.vision_model.encoder.layers):
            # for k in ['k_proj', 'q_proj', 'v_proj']:
            #     qkv = get_module(block, f'self_attn.{k}')

            #     new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
            #                         qkv.bias is not None, qkv.weight.device)
            #     indexes = l1_max_indexes(qkv.weight.data, 0)

            #     new_qkv.weight.data.copy_(qkv.weight.data[indexes])
            #     if qkv.bias is not None:
            #         new_qkv.bias.data.copy_(qkv.bias.data[indexes])
            #     set_module(block, f'self_attn.{k}', new_qkv)

            # proj = block.self_attn.out_proj
            # new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
            #                      proj.bias is not None, proj.weight.device)
            # new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
            # if proj.bias is not None:
            #     new_proj.bias.data.copy_(proj.bias.data)
            # set_module(block, f'self_attn.out_proj', new_proj)

            # ------------------
            # prune the fused qkv projection by output rows, and the output projection by input columns
            qkv = block.self_attn.qkv

            new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
                                qkv.bias is not None, qkv.weight.device)
            indexes = l1_max_indexes(qkv.weight.data, 0)

            new_qkv.weight.data.copy_(qkv.weight.data[indexes])
            if qkv.bias is not None:
                new_qkv.bias.data.copy_(qkv.bias.data[indexes])
            set_module(block, 'self_attn.qkv', new_qkv)

            proj = block.self_attn.proj
            new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
                                 proj.bias is not None, proj.weight.device)
            new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
            if proj.bias is not None:
                new_proj.bias.data.copy_(proj.bias.data)
            set_module(block, 'self_attn.proj', new_proj)
            # --------------------

            # prune the MLP: fc1 by output rows, fc2 by input columns
            fc1 = block.mlp.fc1
            new_fc1 = nn.Linear(fc1.in_features, _f(fc1.out_features),
                                fc1.bias is not None, fc1.weight.device)
            indexes = l1_max_indexes(fc1.weight.data, 0)
            new_fc1.weight.data.copy_(fc1.weight.data[indexes])
            if fc1.bias is not None:
                new_fc1.bias.data.copy_(fc1.bias.data[indexes])
            set_module(block, 'mlp.fc1', new_fc1)

            fc2 = block.mlp.fc2
            new_fc2 = nn.Linear(_f(fc2.in_features), fc2.out_features,
                                fc2.bias is not None, fc2.weight.device)
            new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes(fc2.weight.data, 1)])
            if fc2.bias is not None:
                new_fc2.bias.data.copy_(fc2.bias.data)
            set_module(block, 'mlp.fc2', new_fc2)

        return fm_vit
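
    # A minimal usage sketch (hypothetical; assumes `fm` is the wrapped CLIP foundation
    # model used in this repo, exposing `fm.model.vision_model` and living on CUDA):
    #   util = FM_to_MD_CLIP_Util()
    #   master_dnn = util.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio=4)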

    def init_md_from_fm_by_reducing_width_with_perf_test(self, fm: nn.Module, reducing_width_ratio: int,
                                                         samples: torch.Tensor) -> nn.Module:
        fm_size = get_model_size(fm, True)
        fm_latency = self._get_model_latency(fm, samples, 20,
                                             get_model_device(fm), 20, False)

        master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio)
        master_dnn_size = get_model_size(master_dnn, True)
        logger.debug(f'inited master DNN: {master_dnn}')

        # from utils.dl.common.model import get_module
        # print('after generating')
        # get_module(fm, 'head').debug()
        # get_module(master_dnn, 'head').debug()
        # print('test master latency')
        master_dnn_latency = self._get_model_latency(master_dnn, samples, 20,
                                                     get_model_device(master_dnn), 20, False)

        logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)')
        logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> '
                    f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, '
                    f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)')

        return master_dnn

    def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int,
                           device: str, warmup_sample_num: int, return_detail=False):
        import time

        if isinstance(model_input_size, tuple):
            dummy_input = torch.rand(model_input_size).to(device)
        else:
            # in practice `model_input_size` is a dict of prepared keyword inputs
            dummy_input = model_input_size

        model = model.to(device)
        model.eval()

        # warm up
        with torch.no_grad():
            for _ in range(warmup_sample_num):
                model(**dummy_input)

        infer_time_list = []

        if 'cuda' in str(device):
            with torch.no_grad():
                for _ in range(sample_num):
                    s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                    s.record()
                    model(**dummy_input)
                    e.record()
                    torch.cuda.synchronize()
                    cur_model_infer_time = s.elapsed_time(e) / 1000.
                    infer_time_list += [cur_model_infer_time]
        else:
            with torch.no_grad():
                for _ in range(sample_num):
                    start = time.time()
                    model(**dummy_input)
                    cur_model_infer_time = time.time() - start
                    infer_time_list += [cur_model_infer_time]

        avg_infer_time = sum(infer_time_list) / sample_num

        if return_detail:
            return avg_infer_time, infer_time_list
        return avg_infer_time
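
    # Note: this helper calls `model(**dummy_input)`, so the `samples` argument passed to
    # `init_md_from_fm_by_reducing_width_with_perf_test` is expected to be a dict of keyword
    # inputs rather than a raw tensor (the shape-tuple branch above would not work with `**`).
    # Hypothetical example (the exact keyword depends on the wrapped model's forward signature):
    #   samples = {'pixel_values': torch.rand(1, 3, 224, 224).cuda()}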