Spaces:
Runtime error
Runtime error
| """ | |
| Base class for trainable models. | |
| """ | |
| from abc import ABCMeta, abstractmethod | |
| from copy import copy | |
| import omegaconf | |
| from omegaconf import OmegaConf | |
| from torch import nn | |
class MetaModel(ABCMeta):
    """Metaclass that accumulates the default configurations of parent classes.

    Every class created with this metaclass starts with a ``base_default_conf``
    entry in its namespace, holding the merged defaults of all of its bases.
    """

    @classmethod
    def __prepare__(mcs, name, bases, **kwds):
        """Build the initial class namespace with the merged parent defaults.

        For each base, in order, its ``base_default_conf`` and then its
        ``default_conf`` (when present) are merged into a single OmegaConf
        object; entries merged later take precedence over earlier ones.

        Args:
            name: name of the class being created (unused).
            bases: tuple of base classes of the class being created.
            **kwds: extra class-creation keywords (unused).

        Returns:
            A dict used as the class namespace, pre-populated with the key
            ``base_default_conf``.
        """
        total_conf = OmegaConf.create()
        for base in bases:
            for key in ("base_default_conf", "default_conf"):
                update = getattr(base, key, {})
                # Plain dicts are converted so they can be merged as configs.
                if isinstance(update, dict):
                    update = OmegaConf.create(update)
                total_conf = OmegaConf.merge(total_conf, update)
        return dict(base_default_conf=total_conf)
class BaseModel(nn.Module, metaclass=MetaModel):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes, and
        it is updated by the user-provided configuration passed to __init__.
        Configurations can be nested.

        required_data_keys: list of expected keys in the input data dictionary.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,).
        The total loss to be optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars.
    """

    default_conf = {
        "name": None,
        "trainable": True,  # if false: do not optimize this model parameters
        "freeze_batch_normalization": False,  # use test-time statistics
        "timeit": False,  # time forward pass
    }
    required_data_keys = []
    strict_conf = False

    # Class-level default; set per instance by `set_initialized`.
    are_weights_initialized = False

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model.

        Args:
            conf: user configuration, either a plain dict or an OmegaConf
                object. It is merged on top of the class defaults and the
                result is stored, read-only and in struct mode, as `self.conf`.
        """
        super().__init__()
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            OmegaConf.set_struct(default_conf, True)

        # Convert plain dicts up front: the backward-compatibility rewrite
        # below relies on OmegaConf flag handling, which a dict does not have.
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)

        # fixme: backward compatibility
        # A legacy top-level "pad" entry is moved under "interpolation".
        # NOTE: a DictConfig passed by the caller is modified in place here.
        if "pad" in conf and "pad" not in default_conf:
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf["interpolation"] = {"pad": conf.pop("pad")}

        self.conf = conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        # Shallow per-instance copy so that `_init` can extend the list
        # without touching the class attribute.
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

    def train(self, mode=True):
        """Set the training mode, keeping batch-norm layers in eval mode
        when `conf.freeze_batch_normalization` is set. Returns `self`."""
        super().train(mode)

        def freeze_bn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check the data and call the _forward method of the child model.

        Raises:
            AssertionError: if a key of `required_data_keys` is missing
                from `data`.
        """

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data"
                # Only a dict of expected keys describes nested requirements.
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    def load_state_dict(self, *args, **kwargs):
        """Load the state dict of the model, and set the model to initialized."""
        ret = super().load_state_dict(*args, **kwargs)
        self.set_initialized()
        return ret

    def is_initialized(self):
        """Recursively check if the model is initialized, i.e. weights are loaded"""
        is_initialized = True  # start from True and AND in every child's state
        for child in self.children():
            if isinstance(child, BaseModel):
                # if the child is a BaseModel, we perform a recursive check
                is_initialized = is_initialized and child.is_initialized()
            else:
                # else, we check if self is initialized or the child has no params
                n_params = len(list(child.parameters()))
                is_initialized = is_initialized and (
                    n_params == 0 or self.are_weights_initialized
                )
        return is_initialized

    def set_initialized(self, to: bool = True):
        """Recursively set the initialization state."""
        self.are_weights_initialized = to
        # Walk the child modules (not the parameters, which are tensors and
        # can never be BaseModel instances) so that the flag actually reaches
        # nested BaseModel children, mirroring `is_initialized`.
        for child in self.children():
            if isinstance(child, BaseModel):
                child.set_initialized(to)