Spaces:
Configuration error
Configuration error
| import torch.nn as nn | |
| import numpy as np | |
| from abc import abstractmethod | |
class BaseModel(nn.Module):
    """
    Base class for all models.

    Subclasses are expected to override ``forward``. Converting a model to a
    string appends the number of trainable parameters to the inherited
    representation.
    """

    def forward(self, *inputs):
        """
        Forward pass logic.

        :param inputs: Positional inputs of the model.
        :return: Model output
        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters.

        :return: The inherited string representation followed by a
            ``Trainable parameters: N`` line, where N is the total element
            count of all parameters whose ``requires_grad`` is set.
        """
        # numel() returns a plain int for every shape. np.prod(p.size()) would
        # return float 1.0 for a 0-dim parameter (empty size) and make the
        # whole total print as a float such as "10.0".
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + '\nTrainable parameters: {}'.format(trainable)