import torch
from torch import nn
from abc import ABC, abstractmethod

from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, set_module
from utils.common.log import logger

from .base import FMLoRA_Util, LoRA


class ToQKV_WrappedWithLoRA(nn.Module):
    """Wrap a query/key/value nn.Linear with a parallel low-rank (LoRA) branch."""

    def __init__(self, fc: nn.Linear, ab_r: int):
        super(ToQKV_WrappedWithLoRA, self).__init__()

        self.fc = fc
        self.ab = self.create_ab_as_linear(fc.weight.data, ab_r)

    def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int):
        # two low-rank linear layers: in_features -> out_features // ab_r -> out_features
        res = nn.Sequential(
            LoRA(fc_weight.size(1), fc_weight.size(0) // ab_r, bias=False),
            LoRA(fc_weight.size(0) // ab_r, fc_weight.size(0), bias=False)
        ).to(fc_weight.device)
        # standard LoRA init: random A (kaiming uniform), zero B, so the branch starts as a no-op
        nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5)
        nn.init.zeros_(res[1].weight)
        return res

    def forward(self, x):
        x1 = self.fc(x)
        x2 = self.ab(x)
        return x1 + x2
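

def _lora_wrapper_is_identity_at_init(dim: int = 768, ab_r: int = 8) -> bool:
    # Illustration only, not part of the original utility. Because the second
    # LoRA matrix is zero-initialized above, a freshly wrapped layer reproduces
    # its original nn.Linear exactly until the a/b branch is trained. This
    # sketch assumes LoRA from .base takes (in_features, out_features, bias)
    # like nn.Linear, which matches how it is constructed in create_ab_as_linear.
    fc = nn.Linear(dim, dim)
    wrapped = ToQKV_WrappedWithLoRA(fc, ab_r)
    x = torch.randn(2, 197, dim)
    with torch.no_grad():
        return torch.allclose(fc(x), wrapped(x))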


class FMLoRA_Vilt_Util(FMLoRA_Util):
    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: dict):
        fm.eval()
        # print(samples)

        # move the sample inputs onto the same device as the foundation model
        for k, v in samples.items():
            if isinstance(v, torch.Tensor):
                samples[k] = v.to(get_model_device(fm))

        o1 = fm(**samples)

        # replace every query/key/value projection with a LoRA-wrapped version
        for name, module in fm.named_modules():
            if name.endswith(('query', 'key', 'value')):
                set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r))

        o2 = fm(**samples)

        # the zero-initialized LoRA branches must not change the model output
        if isinstance(o1, tuple):
            o1 = o1[-1]
            o2 = o2[-1]
        output_diff = ((o1.logits - o2.logits) ** 2).sum()
        assert output_diff < 1e-5

        return fm
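
    # Illustration only, not part of the original class: after add_lora_ab_to_fm,
    # LoRA fine-tuning typically freezes the pre-trained weights and optimizes
    # only the injected a/b branches, e.g. (a hedged sketch):
    #
    #   for p in fm.parameters():
    #       p.requires_grad = False
    #   for m in fm.modules():
    #       if isinstance(m, ToQKV_WrappedWithLoRA):
    #           for p in m.ab.parameters():
    #               p.requires_grad = True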

    def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: dict):
        fm.eval()
        # print('absorb lora before')

        for k, v in samples.items():
            if isinstance(v, torch.Tensor):
                samples[k] = v.to(get_model_device(fm))

        o1 = fm(**samples)

        for name, module in fm.named_modules():
            if not isinstance(module, ToQKV_WrappedWithLoRA):
                continue

            fc = module.fc
            ab = module.ab

            # merge the low-rank update (B @ A) into the original weight;
            # no_grad avoids an autograd error if fc.weight still requires grad
            with torch.no_grad():
                fc.weight.add_(ab[1].weight @ ab[0].weight)

            set_module(fm, name, fc)

        # print('absorb lora after')
        o2 = fm(**samples)

        # merging must be numerically equivalent to the wrapped model
        if isinstance(o1, tuple):
            o1 = o1[-1]
            o2 = o2[-1]
        output_diff = ((o1.logits - o2.logits) ** 2).sum()
        assert output_diff < 1e-6, output_diff

        return fm
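

# Usage sketch (illustration only, not part of the original file). It assumes a
# HuggingFace ViLT checkpoint and processor; the checkpoint name and the dummy
# inputs below are assumptions, not taken from this repository.
if __name__ == '__main__':
    import numpy as np
    from PIL import Image
    from transformers import ViltProcessor, ViltForQuestionAnswering

    ckpt = 'dandelin/vilt-b32-finetuned-vqa'  # assumed public checkpoint
    processor = ViltProcessor.from_pretrained(ckpt)
    fm = ViltForQuestionAnswering.from_pretrained(ckpt)

    # a dummy RGB image and question, just to build a valid `samples` dict
    image = Image.fromarray(np.zeros((384, 384, 3), dtype=np.uint8))
    samples = dict(processor(image, 'what color is the sky?', return_tensors='pt'))

    util = FMLoRA_Vilt_Util()
    fm = util.add_lora_ab_to_fm(fm, ab_r=8, samples=samples)        # inject LoRA
    # ... LoRA fine-tuning of the a/b branches would happen here ...
    fm = util.absorb_lora_and_recover_net_structure(fm, samples)    # merge LoRA back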