| from typing import Any | |
| import torch | |
| import torch.nn | |
| import torch.optim | |
def _get_nested_attr(obj: Any, key: str) -> Any:
    """Resolve a dotted attribute path such as ``"decoder.embed"`` on *obj*.

    An empty (or whitespace-only) key returns *obj* itself.

    Examples:
        >>> class A(torch.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = torch.nn.Linear(10, 10)
        >>> a = A()
        >>> assert a.linear.weight is _get_nested_attr(a, 'linear.weight')
    """
    if key.strip() == "":
        return obj
    for name in key.split("."):
        obj = getattr(obj, name)
    return obj


def load_pretrained_model(
    init_param: str,
    model: torch.nn.Module,
    map_location: str = "cpu",
) -> None:
    """Load a saved state dict and copy (part of) it into ``model``.

    Args:
        init_param: Specification of the form
            ``<file_path>:<src_key>:<dst_key>:<exclude_keys>``.
            Trailing fields may be omitted and any field after the path may
            be left empty.

            * ``src_key``: only entries of the loaded state whose name starts
              with ``src_key + "."`` are used, and that prefix is stripped.
            * ``dst_key``: dotted attribute path of the sub-module of
              ``model`` that receives the parameters (default: ``model``).
            * ``exclude_keys``: comma-separated key prefixes; entries of the
              loaded state starting with any of them are dropped. Prefixes
              are matched against the full key names, i.e. before
              ``src_key`` is stripped.
        model: Module whose parameters are (partially) overwritten in place.
        map_location: Forwarded to ``torch.load``.

    Raises:
        ValueError: If ``init_param`` contains more than four
            ``:``-separated fields.

    Examples:
        >>> load_pretrained_model("somewhere/model.pth", model)
        >>> load_pretrained_model("somewhere/model.pth:decoder:decoder", model)
        >>> load_pretrained_model("somewhere/model.pth:decoder:decoder:", model)
        >>> load_pretrained_model(
        ...     "somewhere/model.pth:decoder:decoder:decoder.embed", model
        ... )
        >>> load_pretrained_model("somewhere/decoder.pth::decoder", model)
    """
    fields = init_param.split(":", 4)
    if len(fields) > 4:
        raise ValueError(
            "init_param must be <file_path>:<src_key>:<dst_key>:<exclude_keys>, "
            f"but got too many ':'-separated fields: {init_param!r}"
        )
    # Pad the omitted trailing fields so one unpacking covers every arity.
    path, src_key, dst_key, excludes = fields + [None] * (4 - len(fields))

    # An empty field means "not given".
    src_key = src_key or None
    dst_key = dst_key or None

    # Drop empty entries: "".startswith-style matching on an empty prefix
    # would otherwise match (and discard) every key of the checkpoint.
    exclude_prefixes = tuple(e for e in (excludes or "").split(",") if e != "")

    obj = model if dst_key is None else _get_nested_attr(model, dst_key)

    # SECURITY NOTE: torch.load unpickles the file; only use it on
    # checkpoints that come from a trusted source.
    src_state = torch.load(path, map_location=map_location)

    if exclude_prefixes:
        src_state = {
            k: v for k, v in src_state.items() if not k.startswith(exclude_prefixes)
        }

    if src_key is not None:
        # Match on "<src_key>." so that e.g. "decoder" does not also pick up
        # "decoder_aux.*" entries.
        prefix = src_key + "."
        src_state = {
            k[len(prefix):]: v for k, v in src_state.items() if k.startswith(prefix)
        }

    # Start from the destination's current state so that parameters absent
    # from the checkpoint keep their values and the strict load succeeds.
    dst_state = obj.state_dict()
    dst_state.update(src_state)
    obj.load_state_dict(dst_state)