Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.optim as optim | |
| import numpy as np | |
| import itertools | |
def singleton(class_):
    """Wrap *class_* so every call yields one shared instance.

    The first call builds ``class_(*args, **kwargs)`` and stores it; every
    later call ignores its arguments and hands back that stored object.
    Each application of the decorator keeps its own independent cache.
    """
    cache = {}

    def getinstance(*args, **kwargs):
        # Fast path: the instance already exists.
        if class_ in cache:
            return cache[class_]
        obj = class_(*args, **kwargs)
        cache[class_] = obj
        return obj

    return getinstance
class get_optimizer(object):
    """Factory that builds a ``torch.optim`` optimizer from a config object.

    Optimizer classes are looked up by name in ``self.optimizer``. ``'sgd'``,
    ``'adam'`` and ``'adamw'`` are registered by default; more can be added
    with :meth:`register`.
    """

    def __init__(self):
        # Maps config name -> optimizer class.
        self.optimizer = {}
        self.register(optim.SGD, 'sgd')
        self.register(optim.Adam, 'adam')
        self.register(optim.AdamW, 'adamw')

    def register(self, optim, name):
        """Register optimizer class *optim* under the lookup key *name*.

        The parameter name shadows the module-level ``optim`` import only
        inside this method. It is kept for keyword-argument compatibility.
        """
        self.optimizer[name] = optim

    @staticmethod
    def _collect_params(members):
        """Flatten a module, a parameter, or a list/tuple of those into one parameter iterator."""
        if not isinstance(members, (list, tuple)):
            members = [members]
        return itertools.chain.from_iterable(
            m.parameters() if isinstance(m, torch.nn.Module) else [m]
            for m in members
        )

    def __call__(self, net, cfg):
        """Build the optimizer described by *cfg* for the parameters of *net*.

        Args:
            net: a ``torch.nn.Module``, possibly wrapped in ``DataParallel`` or
                ``DistributedDataParallel`` (the wrapper is unwrapped first).
                If the unwrapped module has a ``parameter_group`` mapping of
                ``group_name -> module | parameter | list/tuple of those``,
                one optimizer param group is created per entry and tagged with
                ``'name': group_name``. Otherwise all parameters form a single
                group.
            cfg: ``None``, or an object with ``type`` (a registered optimizer
                name) and ``args`` (a mapping of extra constructor keyword
                arguments; may be missing or ``None``).

        Returns:
            The constructed optimizer, or ``None`` when *cfg* is ``None``.

        Raises:
            KeyError: if ``cfg.type`` has not been registered.
        """
        if cfg is None:
            return None
        t = cfg.type
        if t not in self.optimizer:
            raise KeyError(
                f"Unknown optimizer type {t!r}; registered: {list(self.optimizer)}"
            )

        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net

        pg = getattr(netm, 'parameter_group', None)
        if pg is not None:
            params = [
                {'params': self._collect_params(members), 'name': group_name}
                for group_name, members in pg.items()
            ]
        else:
            params = netm.parameters()

        # lr=0 is a placeholder default (older torch SGD has no lr default).
        # NOTE(review): presumably the real lr is set later by a scheduler;
        # confirm with the caller.
        # Merging instead of `lr=0, **cfg.args` lets cfg.args override lr
        # without a "multiple values for keyword argument 'lr'" TypeError.
        kwargs = {'lr': 0}
        kwargs.update(getattr(cfg, 'args', None) or {})
        return self.optimizer[t](params, **kwargs)