Spaces:
Build error
Build error
| from typing import Mapping, Any | |
| import importlib | |
| from torch import nn | |
def get_obj_from_str(string: str, reload: bool=False) -> object:
    """Import a module and return one attribute of it, given a dotted path.

    Everything before the last ``.`` is treated as the module path and the
    text after it as the attribute name, e.g. ``"package.module.ClassName"``.

    Args:
        string: Dotted path of the object to resolve.
        reload: When True, the module is re-executed with
            ``importlib.reload`` before the attribute is looked up.

    Returns:
        The attribute found on the (possibly reloaded) module.

    Raises:
        ValueError: If ``string`` contains no ``.``.
        ModuleNotFoundError: If the module path cannot be imported.
        AttributeError: If the module has no such attribute.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config: Mapping[str, Any]) -> object:
    """Build an object from a ``{"target": ..., "params": ...}`` config mapping.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``; the resolved callable is then invoked with
    ``config["params"]`` unpacked as keyword arguments.

    Args:
        config: Mapping with a required ``"target"`` key and an optional
            ``"params"`` mapping. A missing ``"params"`` key, or one whose
            value is ``None``, means the target is called with no keyword
            arguments.

    Returns:
        Whatever the target callable returns.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    params = config.get("params")
    if params is None:
        # An empty `params:` entry (e.g. if the config came from YAML) yields
        # None; `**None` would raise TypeError, so treat it like a missing key.
        params = {}
    return get_obj_from_str(config["target"])(**params)
def disabled_train(self: nn.Module, mode: bool = True) -> nn.Module:
    """Stand-in for ``nn.Module.train`` that ignores mode changes.

    Overwrite a model's ``train`` with this function to make sure its
    train/eval mode does not change anymore. The signature mirrors
    ``nn.Module.train(self, mode=True)``. This lets it also work when bound
    as a method, where ``nn.Module.eval()`` calls ``self.train(False)`` and a
    parent's ``train(mode)`` calls ``child.train(mode)``.

    When stored as an *instance* attribute the function is not bound, so the
    first positional argument of the call arrives as ``self``. That argument
    is returned unchanged either way.

    Args:
        self: Object returned unchanged (the module, when bound as a method).
        mode: Ignored; accepted only for signature compatibility.

    Returns:
        ``self``, unchanged.
    """
    return self
def frozen_module(module: nn.Module) -> None:
    """Freeze ``module`` in place: eval mode, fixed mode, no parameter grads.

    Args:
        module: The module to freeze; it is modified in place.
    """
    # Turn off gradient tracking for every parameter returned by
    # module.parameters() (recursive over submodules).
    module.requires_grad_(False)
    # Must happen before the override below: eval() delegates to
    # self.train(False), which would become a no-op afterwards.
    module.eval()
    # Instance-level override so later mode switches leave this module alone.
    # NOTE(review): the function is stored unbound, so a direct
    # `module.train()` with no arguments would raise TypeError (missing
    # `self`); calls that pass a positional mode, e.g. from a parent's
    # train(mode), are accepted and ignored.
    module.train = disabled_train
def load_state_dict(model: nn.Module, state_dict: Mapping[str, Any], strict: bool = False) -> None:
    """Load ``state_dict`` into ``model``, reconciling a ``module.`` key prefix.

    If ``state_dict`` has a ``"state_dict"`` entry, that nested mapping is
    loaded instead. Whether each side uses the ``module.`` prefix (e.g. the
    prefix produced when a model is wrapped in ``nn.DataParallel``) is judged
    from its first key:

    * only the model's keys are prefixed -> ``module.`` is added to every
      checkpoint key;
    * only the checkpoint's keys are prefixed -> ``module.`` is removed from
      the checkpoint keys that start with it (other keys are kept as-is).

    An empty model state dict or an empty checkpoint counts as "not
    prefixed" instead of failing.

    Args:
        model: Module to load the weights into.
        state_dict: Checkpoint mapping, optionally wrapped under a
            ``"state_dict"`` key.
        strict: Forwarded to ``nn.Module.load_state_dict``.

    Raises:
        Whatever ``nn.Module.load_state_dict`` raises (e.g. ``RuntimeError``
        on missing/unexpected keys when ``strict`` is True).
    """
    prefix = "module."
    state_dict = state_dict.get("state_dict", state_dict)

    # next(iter(...), "") avoids IndexError when either side has no keys.
    model_has_prefix = next(iter(model.state_dict()), "").startswith(prefix)
    ckpt_has_prefix = next(iter(state_dict), "").startswith(prefix)

    if model_has_prefix and not ckpt_has_prefix:
        state_dict = {f"{prefix}{key}": value for key, value in state_dict.items()}
    elif ckpt_has_prefix and not model_has_prefix:
        # Strip only where present: slicing unconditionally would mangle any
        # key that lacks the prefix.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict, strict=strict)