from pathlib import Path

import torch
import torch.nn as nn

from utils.torch_utilities import load_pretrained_model, merge_matched_keys

class LoadPretrainedBase(nn.Module):
    def process_state_dict(
        self, model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor]
    ):
        """
        Custom processing functions of each model that transforms `state_dict` loaded from 
        checkpoints to the state that can be used in `load_state_dict`.
        Use `merge_mathced_keys` to update parameters with matched names and shapes by 
        default.  

        Args
            model_dict:
                The state dict of the current model, which is going to load pretrained parameters
            state_dict:
                A dictionary of parameters from a pre-trained model.

            Returns:
                dict[str, torch.Tensor]:
                    The updated state dict, where parameters with matched keys and shape are 
                    updated with values in `state_dict`.      
        """
        state_dict = merge_matched_keys(model_dict, state_dict)
        return state_dict

    def load_pretrained(self, ckpt_path: str | Path):
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )
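

# The helper below is NOT the real `merge_matched_keys` from
# `utils.torch_utilities`; it is a hedged sketch of the default behavior the
# docstring above describes, assuming "matched" means identical key name and
# tensor shape.
def _merge_matched_keys_sketch(
    model_dict: dict[str, torch.Tensor],
    state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    merged = dict(model_dict)
    for key, value in state_dict.items():
        # Copy a pretrained tensor only when the current model has a
        # parameter with the same name and the same shape.
        if key in model_dict and model_dict[key].shape == value.shape:
            merged[key] = value
    return merged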


class CountParamsBase(nn.Module):
    def count_params(self):
        """Count the total and trainable numbers of parameters."""
        num_params = 0
        trainable_params = 0
        for param in self.parameters():
            num_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        return num_params, trainable_params


class SaveTrainableParamsBase(nn.Module):
    @property
    def param_names_to_save(self):
        """Names of all trainable parameters, plus all buffer names."""
        names = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                names.append(name)
        for name, _ in self.named_buffers():
            names.append(name)
        return names

    def load_state_dict(self, state_dict, strict=True, assign=True):
        for key in self.param_names_to_save:
            if key not in state_dict:
                raise KeyError(
                    f"{key} not found in either pre-trained models (e.g. BERT)"
                    " or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
        # Forward `assign` for compatibility with PyTorch/transformers.
        return super().load_state_dict(state_dict, strict=strict, assign=assign)
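

# Hedged usage sketch: `ToyModel` and the commented checkpoint path are
# hypothetical, and only illustrate how the three mixins are meant to be
# combined in a model class.
class ToyModel(LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 4)


if __name__ == "__main__":
    model = ToyModel()
    total, trainable = model.count_params()
    print(f"params: {total} total, {trainable} trainable")
    print("keys to save:", model.param_names_to_save)
    # model.load_pretrained("path/to/checkpoint.pt")  # hypothetical path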