| import torch | |
| import numpy as np | |
| from abc import ABC, abstractmethod | |
| from torch import nn | |
| from hydra.utils import instantiate | |
| import copy | |
| from peft import LoraConfig, get_peft_model | |
| from utils.model_utils import print_trainable_parameters | |
def freeze(model):
    """Turn off gradient tracking for a model and put it in eval mode.

    Args:
        model: Module whose parameters are modified in place.
    """
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
def unfreeze(model):
    """Make a model trainable, keeping the CLIP post-layernorm frozen.

    Every parameter gets ``requires_grad = True`` except the two parameters
    named ``clip.vision_model.post_layernorm.{weight,bias}``, which get
    ``requires_grad = False``. The model is then switched to training mode.

    Args:
        model: Module whose parameters are modified in place.
    """
    always_frozen = (
        "clip.vision_model.post_layernorm.weight",
        "clip.vision_model.post_layernorm.bias",
    )
    for param_name, param in model.named_parameters():
        param.requires_grad = param_name not in always_frozen
    model.train()
def unfreeze_last(model):
    """Leave only one indexed sub-block of a model trainable.

    A parameter is made trainable when its dotted name has more than five
    components and the fifth component is ``"11"``; every other parameter is
    frozen. The model is then switched to training mode.

    Args:
        model: Module whose parameters are modified in place.
    """
    # NOTE(review): presumably targets names such as
    # "clip.vision_model.encoder.layers.11.*" (last block of a 12-layer
    # encoder) -- confirm against the backbone actually used.
    for param_name, param in model.named_parameters():
        parts = param_name.split(".")
        param.requires_grad = len(parts) > 5 and parts[4] == "11"
    model.train()
class FrozenBackbone(nn.Module):
    """Backbone -> mid -> head pipeline whose backbone is frozen.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        # Disables gradients on the backbone and puts it in eval mode.
        # NOTE(review): a later ``.train()`` call on this wrapper recurses
        # into children and puts the backbone back in training mode.
        freeze(self.backbone)

    def forward(self, x):
        """Run the frozen backbone, then the mid module and the head.

        Args:
            x: Input passed as-is to the backbone.

        Returns:
            Whatever the head returns.
        """
        # Only the backbone call is excluded from autograd.
        with torch.no_grad():
            backbone_out = self.backbone(x)
        return self.head(self.mid(backbone_out))
class UnfrozenBackbone(nn.Module):
    """Backbone -> mid -> head pipeline whose backbone is trained.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        # Enables gradients on the backbone and puts it in training mode.
        unfreeze(self.backbone)

    def forward(self, x):
        """Run backbone, mid module and head in sequence.

        Args:
            x: Input passed as-is to the backbone.

        Returns:
            Whatever the head returns.
        """
        backbone_out = self.backbone(x)
        mid_out = self.mid(backbone_out)
        return self.head(mid_out)
class UnfrozenPartBackbone(nn.Module):
    """Backbone -> mid -> head pipeline with a partially trainable backbone.

    Which backbone parameters stay trainable is decided by
    ``unfreeze_last``.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        unfreeze_last(self.backbone)

    def forward(self, x):
        """Run backbone, mid module and head in sequence.

        Args:
            x: Input passed as-is to the backbone.

        Returns:
            Whatever the head returns.
        """
        backbone_out = self.backbone(x)
        mid_out = self.mid(backbone_out)
        return self.head(mid_out)
class NoFeatureBackbone(nn.Module):
    """Model made of a head only: no backbone and no mid module.

    Args:
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
    """

    def __init__(self, head):
        super().__init__()
        self.target_key = head.target_key
        self.head = head.instance

    def forward(self, x):
        """Feed the input straight to the head.

        Args:
            x: Input passed as-is to the head.

        Returns:
            Whatever the head returns.
        """
        head_out = self.head(x)
        return head_out
class ContrastiveFrozenBackbone(FrozenBackbone):
    """Frozen-backbone model that also returns features for a contrastive loss.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
        mode: When equal to ``"eval"`` the positive branch is skipped.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        """Compute head outputs plus backbone features.

        Args:
            x: Mapping of inputs for the backbone. Outside ``"eval"`` mode,
                the entries whose key starts with ``"pos_"`` form the
                positive sample.

        Returns:
            dict: the head's output mapping, plus ``"features"`` and, outside
            ``"eval"`` mode, ``"pos_features"``. Head keys win on collision.
        """
        pos_prefix = "pos_"
        with_positive = self.mode != "eval"
        with torch.no_grad():
            features = self.backbone(x)
            if with_positive:
                # Slice the prefix off: ``str.strip("pos_")`` would remove any
                # of the characters p/o/s/_ from BOTH ends of the key
                # (e.g. "pos_gps" -> "g") instead of just the prefix.
                x_pos = {
                    key[len(pos_prefix):]: (
                        value.clone()
                        if isinstance(value, torch.Tensor)
                        else copy.deepcopy(value)
                    )
                    for key, value in x.items()
                    if key.startswith(pos_prefix)
                }
                pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        # assumes features are (batch, tokens, dim) and token 0 is the
        # summary/CLS token -- TODO confirm against the backbone.
        if with_positive:
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }
class ContrastiveUnFrozenPartBackbone(UnfrozenPartBackbone):
    """Partially trainable backbone that also returns contrastive features.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
        mode: When equal to ``"eval"`` the positive branch is skipped.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        """Compute head outputs plus backbone features.

        Args:
            x: Mapping of inputs for the backbone. Outside ``"eval"`` mode,
                the entries whose key starts with ``"pos_"`` form the
                positive sample.

        Returns:
            dict: the head's output mapping, plus ``"features"`` and, outside
            ``"eval"`` mode, ``"pos_features"``. Head keys win on collision.
        """
        pos_prefix = "pos_"
        with_positive = self.mode != "eval"
        features = self.backbone(x)
        if with_positive:
            # Slice the prefix off: ``str.strip("pos_")`` would remove any of
            # the characters p/o/s/_ from BOTH ends of the key
            # (e.g. "pos_gps" -> "g") instead of just the prefix.
            x_pos = {
                key[len(pos_prefix):]: (
                    value.clone()
                    if isinstance(value, torch.Tensor)
                    else copy.deepcopy(value)
                )
                for key, value in x.items()
                if key.startswith(pos_prefix)
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        # assumes features are (batch, tokens, dim) and token 0 is the
        # summary/CLS token -- TODO confirm against the backbone.
        if with_positive:
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }
class ContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Trainable-backbone model that also returns contrastive features.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
        mode: When equal to ``"eval"`` the positive branch is skipped.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        """Compute head outputs plus backbone features.

        Args:
            x: Mapping of inputs for the backbone. Outside ``"eval"`` mode,
                the entries whose key starts with ``"pos_"`` form the
                positive sample.

        Returns:
            dict: the head's output mapping, plus ``"features"`` and, outside
            ``"eval"`` mode, ``"pos_features"``. Head keys win on collision.
        """
        pos_prefix = "pos_"
        with_positive = self.mode != "eval"
        features = self.backbone(x)
        if with_positive:
            # Slice the prefix off: ``str.strip("pos_")`` would remove any of
            # the characters p/o/s/_ from BOTH ends of the key
            # (e.g. "pos_gps" -> "g") instead of just the prefix.
            x_pos = {
                key[len(pos_prefix):]: (
                    value.clone()
                    if isinstance(value, torch.Tensor)
                    else copy.deepcopy(value)
                )
                for key, value in x.items()
                if key.startswith(pos_prefix)
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        # assumes features are (batch, tokens, dim) and token 0 is the
        # summary/CLS token -- TODO confirm against the backbone.
        if with_positive:
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }
class TextContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Trainable-backbone model whose backbone returns two outputs.

    The backbone must return a pair: the first element is exposed unchanged
    under ``"features"``, the second is fed through the mid module and head.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
    """

    def __init__(self, backbone, mid, head):
        super().__init__(backbone, mid, head)

    def forward(self, x):
        """Compute head outputs plus the backbone's first output.

        Args:
            x: Input passed as-is to the backbone.

        Returns:
            dict: the head's output mapping plus ``"features"``. Head keys
            win on collision.
        """
        contrastive_out, backbone_features = self.backbone(x)
        head_out = self.head(self.mid(backbone_features))
        return {"features": contrastive_out, **head_out}
class LoraBackbone(nn.Module):
    """Backbone -> mid -> head pipeline with a LoRA-adapted backbone.

    The backbone is frozen and then wrapped with ``get_peft_model`` using
    adapters on the ``q_proj``, ``k_proj`` and ``v_proj`` modules.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
        r: Forwarded to ``LoraConfig`` as ``r``.
        alpha: Forwarded to ``LoraConfig`` as ``lora_alpha``.
        dropout: Forwarded to ``LoraConfig`` as ``lora_dropout``.
        bias: Forwarded to ``LoraConfig`` as ``bias``.
    """

    def __init__(self, backbone, mid, head, r, alpha, dropout, bias):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key

        # Freeze the base weights before the PEFT wrapper is applied.
        freeze(self.backbone)
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            bias=bias,
            target_modules=["q_proj", "k_proj", "v_proj"],
        )
        self.backbone = get_peft_model(self.backbone, lora_config)
        print_trainable_parameters(self)

    def forward(self, x):
        """Run the wrapped backbone, then the mid module and the head.

        Args:
            x: Input passed as-is to the backbone.

        Returns:
            Whatever the head returns.
        """
        backbone_out = self.backbone(x)
        mid_out = self.mid(backbone_out)
        return self.head(mid_out)
class HybridFrozenBackbone(FrozenBackbone):
    """Frozen-backbone model whose head also receives the ground-truth label."""

    def forward(self, x):
        """Run the frozen backbone, the mid module, then the label-aware head.

        Args:
            x: Input for the backbone. In training mode it must support
                ``x["label"]``.

        Returns:
            Whatever the head returns when called with the mid output and the
            label (``None`` outside training mode).
        """
        # Read the label before ``x`` is consumed by the backbone.
        gt_label = None
        if self.training:
            gt_label = x["label"]
        with torch.no_grad():
            backbone_out = self.backbone(x)
        mid_out = self.mid(backbone_out)
        return self.head(mid_out, gt_label)
class HybridUnfrozenBackbone(UnfrozenBackbone):
    """Trainable-backbone model whose head also receives the ground-truth label."""

    def forward(self, x):
        """Run backbone, mid module, then the label-aware head.

        Args:
            x: Input for the backbone. In training mode it must support
                ``x["label"]``.

        Returns:
            Whatever the head returns when called with the mid output and the
            label (``None`` outside training mode).
        """
        # Read the label before ``x`` is consumed by the backbone.
        gt_label = None
        if self.training:
            gt_label = x["label"]
        backbone_out = self.backbone(x)
        mid_out = self.mid(backbone_out)
        return self.head(mid_out, gt_label)
class ContrastiveHybridUnFrozenBackbone(UnfrozenBackbone):
    """Trainable backbone with a label-aware head and contrastive features.

    Args:
        backbone: Object exposing the backbone module as ``.instance``.
        mid: Object exposing the intermediate module as ``.instance``.
        head: Object exposing the head module as ``.instance`` and the name
            of the prediction target as ``.target_key``.
        mode: When equal to ``"eval"`` the positive branch is skipped.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        """Compute head outputs plus backbone features.

        Args:
            x: Mapping of inputs for the backbone. In training mode it must
                contain ``"label"``. Outside ``"eval"`` mode, the entries
                whose key starts with ``"pos_"`` form the positive sample.

        Returns:
            dict: the head's output mapping, plus ``"features"`` and, outside
            ``"eval"`` mode, ``"pos_features"``. Head keys win on collision.
        """
        pos_prefix = "pos_"
        with_positive = self.mode != "eval"
        # The head only receives the label while the module is training.
        gt_label = x["label"] if self.training else None
        features = self.backbone(x)
        if with_positive:
            # Slice the prefix off: ``str.strip("pos_")`` would remove any of
            # the characters p/o/s/_ from BOTH ends of the key
            # (e.g. "pos_gps" -> "g") instead of just the prefix.
            x_pos = {
                key[len(pos_prefix):]: (
                    value.clone()
                    if isinstance(value, torch.Tensor)
                    else copy.deepcopy(value)
                )
                for key, value in x.items()
                if key.startswith(pos_prefix)
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x, gt_label)
        # assumes features are (batch, tokens, dim) and token 0 is the
        # summary/CLS token -- TODO confirm against the backbone.
        if with_positive:
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }