Spaces:
Runtime error
Runtime error
| from email.policy import strict | |
| import torch | |
| import torchvision.models | |
| import os.path as osp | |
| import copy | |
| from ...log_service import print_log | |
| from .utils import \ | |
| get_total_param, get_total_param_sum, \ | |
| get_unit | |
| # def load_state_dict(net, model_path): | |
| # if isinstance(net, dict): | |
| # for ni, neti in net.items(): | |
| # paras = torch.load(model_path[ni], map_location=torch.device('cpu')) | |
| # new_paras = neti.state_dict() | |
| # new_paras.update(paras) | |
| # neti.load_state_dict(new_paras) | |
| # else: | |
| # paras = torch.load(model_path, map_location=torch.device('cpu')) | |
| # new_paras = net.state_dict() | |
| # new_paras.update(paras) | |
| # net.load_state_dict(new_paras) | |
| # return | |
| # def save_state_dict(net, path): | |
| # if isinstance(net, (torch.nn.DataParallel, | |
| # torch.nn.parallel.DistributedDataParallel)): | |
| # torch.save(net.module.state_dict(), path) | |
| # else: | |
| # torch.save(net.state_dict(), path) | |
def singleton(class_):
    """Wrap ``class_`` so that every call yields one shared instance.

    The first call constructs the instance using the arguments it receives;
    every later call ignores its arguments and hands back that same object.
    """
    cache = {}

    def getinstance(*args, **kwargs):
        try:
            return cache[class_]
        except KeyError:
            # First call: build the instance and remember it for later calls.
            cache[class_] = class_(*args, **kwargs)
            return cache[class_]

    return getinstance
def preprocess_model_args(args):
    """Return a deep copy of ``args`` with nested model configs resolved.

    - If ``layer_units`` is present, each entry is replaced by the result of
      ``get_unit()(entry)``.
    - If ``backbone`` is present, it is replaced by ``get_model()(backbone)``.

    When either key is present, ``args`` must support both ``in`` membership
    tests and attribute access/assignment (an attribute-style config dict).
    The caller's ``args`` object itself is never modified.
    """
    duplicate = copy.deepcopy(args)

    if 'layer_units' in duplicate:
        units = []
        for spec in duplicate.layer_units:
            units.append(get_unit()(spec))
        duplicate.layer_units = units

    if 'backbone' in duplicate:
        duplicate.backbone = get_model()(duplicate.backbone)

    return duplicate
class get_model(object):
    """Registry and factory for model classes, keyed by a config ``type`` string.

    Classes are added with ``register`` (typically through the module-level
    ``register(name)`` decorator) and built by calling an instance with a
    config object.

    Every instance shares the same registry dict. A class registered through
    one ``get_model()`` instance is therefore visible to every later
    ``get_model()`` instance.
    """

    # Shared by all instances. Both `register(name)` and callers construct a
    # fresh `get_model()` each time, so a per-instance dict would silently lose
    # every registration and make every lookup fail.
    _registry = {}

    def __init__(self):
        # Alias (not copy) the shared registry. `self.model` remains the
        # public attribute name for backward compatibility.
        self.model = get_model._registry

    def register(self, model, name):
        """Record ``model`` under ``name``, replacing any earlier entry."""
        self.model[name] = model

    def __call__(self, cfg, verbose=True):
        """Construct a model from ``cfg`` and optionally load pretrained weights.

        Args:
            cfg: config object exposing ``cfg.type``, ``cfg.args`` and a
                dict-style ``cfg.get``. Optional keys: ``pretrained`` (or the
                legacy ``pth``), ``map_location`` (default ``'cpu'``), and
                ``strict_sd`` (default ``True``, forwarded to
                ``load_state_dict(strict=...)``).
            verbose: if true, log the checkpoint path and parameter statistics.

        Returns:
            The constructed network, or ``None`` when ``cfg`` is ``None``.

        Raises:
            KeyError: if ``cfg.type`` has not been registered.
            ValueError: if the pretrained file's extension is not ``.pth``,
                ``.ckpt`` or ``.safetensors``.
        """
        if cfg is None:
            return None
        t = cfg.type

        # Each model family registers its own types when its module is
        # imported, so import the module matching this type prefix first.
        if t.startswith('pfd'):
            from .. import pfd  # noqa: F401
        elif t == 'autoencoderkl':
            from .. import autokl  # noqa: F401
        elif t.startswith(('clip', 'openclip')):
            from .. import clip  # noqa: F401
        elif t.startswith('openai_unet'):
            from .. import openaimodel  # noqa: F401
        elif t.startswith('controlnet'):
            from .. import controlnet  # noqa: F401
        elif t.startswith('seecoder'):
            from .. import seecoder  # noqa: F401
        elif t.startswith('swin'):
            from .. import swin  # noqa: F401

        args = preprocess_model_args(cfg.args)
        try:
            model_cls = self.model[t]
        except KeyError:
            raise KeyError(
                'Model type [{}] is not registered; known types: {}'.format(
                    t, sorted(self.model))) from None
        net = model_cls(**args)

        pretrained = cfg.get('pretrained', None)
        if pretrained is None:  # backward compatible
            pretrained = cfg.get('pth', None)
        map_location = cfg.get('map_location', 'cpu')
        strict_sd = cfg.get('strict_sd', True)

        if pretrained is not None:
            ext = osp.splitext(pretrained)[1]
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            if ext == '.pth':
                sd = torch.load(pretrained, map_location=map_location)
            elif ext == '.ckpt':
                sd = torch.load(pretrained, map_location=map_location)['state_dict']
            elif ext == '.safetensors':
                from safetensors.torch import load_file
                from collections import OrderedDict
                sd = OrderedDict(load_file(pretrained, map_location))
            else:
                # Previously `sd` was left unbound here, which surfaced as an
                # opaque UnboundLocalError.
                raise ValueError(
                    'Unsupported pretrained file extension [{}] for [{}]; '
                    'expected .pth, .ckpt or .safetensors.'.format(ext, pretrained))
            net.load_state_dict(sd, strict=strict_sd)
            if verbose:
                print_log('Load model from [{}] strict [{}].'.format(pretrained, strict_sd))

        # display param_num & param_sum
        if verbose:
            print_log(
                'Load {} with total {} parameters,'
                '{:.3f} parameter sum.'.format(
                    t,
                    get_total_param(net),
                    get_total_param_sum(net) ))
        return net
def register(name):
    """Build a class decorator that registers the class under ``name``.

    The decorator calls ``get_model().register(class_, name)`` and then
    returns ``class_`` unchanged, so the class can still be used directly.
    """
    def wrapper(class_):
        registry = get_model()
        registry.register(class_, name)
        return class_

    return wrapper