Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import torch | |
class BaseModel(torch.nn.Module):
    """Base ``torch.nn.Module`` that adds a helper for restoring weights."""

    def load(self, path: str) -> None:
        """Load model weights from a file into this module.

        The file may hold either a bare state dict or a checkpoint dict that
        stores the state dict under the ``'model'`` key (with or without an
        accompanying ``'optimizer'`` entry).

        Args:
            path (str): file path

        Raises:
            KeyError: if the file contains an ``'optimizer'`` entry but no
                ``'model'`` entry.
        """
        # Always deserialize onto the CPU; ``weights_only=True`` restricts
        # unpickling to tensors and plain containers.
        checkpoint = torch.load(
            path, map_location=torch.device('cpu'), weights_only=True)

        # A checkpoint carrying an optimizer entry keeps the weights under
        # 'model'.
        is_wrapped = 'optimizer' in checkpoint
        if not is_wrapped and isinstance(checkpoint, dict):
            # Also accept {'model': state_dict, ...} saved without an
            # optimizer. The nested value must itself be a dict, so a state
            # dict that merely owns a tensor entry named 'model' is not
            # unwrapped by mistake.
            is_wrapped = isinstance(checkpoint.get('model'), dict)

        state_dict = checkpoint['model'] if is_wrapped else checkpoint
        self.load_state_dict(state_dict)