Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import torch | |
| import yaml | |
def traverse_dir(
        root_dir,
        extensions,
        amount=None,
        str_include=None,
        str_exclude=None,
        is_pure=False,
        is_sort=False,
        is_ext=True):
    """Recursively collect files under ``root_dir`` whose names end in ``.<ext>``.

    Args:
        root_dir: Directory walked with ``os.walk``; a trailing path separator
            is allowed.
        extensions: Iterable of extensions without the leading dot,
            e.g. ``['pt', 'wav']``.
        amount: If not None, stop once this many files have been collected.
        str_include: If not None, keep only paths containing this substring.
        str_exclude: If not None, drop paths containing this substring.
        is_pure: If True, return paths relative to ``root_dir`` instead of
            ``os.path.join(root, file)`` paths.
        is_sort: If True, sort the result lexicographically.
        is_ext: If False, strip the final ``.<ext>`` suffix from each path.

    Returns:
        List of path strings. The include/exclude filters are matched against
        the (possibly relative) path before the extension is stripped.
    """
    # str.endswith accepts a tuple, so each file needs a single call.
    suffixes = tuple(f".{ext}" for ext in extensions)
    file_list = []
    for root, _, files in os.walk(root_dir):
        for file in files:
            if not file.endswith(suffixes):
                continue
            # Checked before filtering, so amount=0 returns an empty list.
            if amount is not None and len(file_list) == amount:
                if is_sort:
                    file_list.sort()
                return file_list
            mix_path = os.path.join(root, file)
            # relpath instead of slicing off len(root_dir)+1 characters: the
            # slice cut the first character of every name whenever root_dir
            # ended with a path separator.
            pure_path = os.path.relpath(mix_path, root_dir) if is_pure else mix_path
            if str_include is not None and str_include not in pure_path:
                continue
            if str_exclude is not None and str_exclude in pure_path:
                continue
            if not is_ext:
                # Drop everything from the last '.'; the suffix match above
                # guarantees the file name contains one.
                pure_path = pure_path.rsplit('.', 1)[0]
            file_list.append(pure_path)
    if is_sort:
        file_list.sort()
    return file_list
class DotDict(dict):
    """A ``dict`` whose entries can also be read, set and deleted as attributes.

    Reading a missing attribute returns ``None`` rather than raising.
    A value that is exactly a plain ``dict`` is returned wrapped in a *new*
    ``DotDict`` (a shallow copy), so ``cfg.a.b = 1`` modifies that copy and
    does not write back into ``cfg``.
    """

    def __getattr__(self, key):
        found = dict.get(self, key)
        if type(found) is not dict:
            return found
        return DotDict(found)

    # Attribute assignment/deletion map straight onto item assignment/deletion.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
def get_network_paras_amount(model_dict):
    """Count the trainable parameters of each model.

    Args:
        model_dict: Mapping of name -> object exposing ``parameters()``
            (e.g. a ``torch.nn.Module``).

    Returns:
        dict mapping each name to the total number of elements over its
        parameters whose ``requires_grad`` flag is set.
    """
    return {
        name: sum(param.numel() for param in net.parameters() if param.requires_grad)
        for name, net in model_dict.items()
    }
def load_config(path_config):
    """Load a YAML config file into a ``DotDict``.

    Args:
        path_config: Path of the YAML file.

    Returns:
        ``DotDict`` built from the document's top-level mapping; an empty
        document yields an empty ``DotDict``.

    Raises:
        TypeError: If the document's top level is not a mapping.
    """
    # Read as UTF-8 explicitly instead of relying on the platform locale.
    with open(path_config, "r", encoding="utf-8") as config:
        args = yaml.safe_load(config)
    if args is None:
        # safe_load returns None for an empty document; DotDict(None) would raise.
        args = {}
    if not isinstance(args, dict):
        raise TypeError(
            f"config {path_config!r} must contain a mapping at the top level, "
            f"got {type(args).__name__}")
    return DotDict(args)
def save_config(path_config, config):
    """Write ``config`` (a dict, ``DotDict`` or iterable of key/value pairs) as YAML.

    Nested dict subclasses are converted to plain dicts and tuples to lists
    before dumping. ``yaml.dump`` would otherwise tag them as Python objects
    (``!!python/...``), and ``yaml.safe_load`` cannot read those tags back.

    Args:
        path_config: Destination file path (overwritten).
        config: Configuration to serialize.
    """
    def _to_plain(obj):
        # Recursively rebuild containers with builtin types only.
        if isinstance(obj, dict):
            return {key: _to_plain(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [_to_plain(item) for item in obj]
        return obj

    config = _to_plain(dict(config))
    with open(path_config, "w") as f:
        yaml.dump(config, f)
def to_json(path_params, path_json):
    """Dump every tensor of a saved state dict to a JSON file as flat lists.

    Args:
        path_params: File written by ``torch.save``. It may hold a plain state
            dict (name -> tensor) or a training checkpoint whose ``'model'``
            entry is such a dict.
        path_json: Destination path; written as tab-indented JSON.

    Note:
        Tensor shapes are not stored: each tensor is flattened to 1-D.
    """
    params = torch.load(path_params, map_location=torch.device('cpu'))
    # Unwrap full checkpoints ({'global_step', 'model', 'optimizer', ...}).
    # Their non-tensor entries previously crashed the conversion below.
    nested = params.get('model') if isinstance(params, dict) else None
    if isinstance(nested, dict):
        params = nested
    # Tensor.tolist() yields the same Python numbers as .numpy().tolist(),
    # but also works for dtypes NumPy lacks (e.g. bfloat16).
    raw_state_dict = {k: v.flatten().tolist() for k, v in params.items()}
    with open(path_json, 'w') as outfile:
        json.dump(raw_state_dict, outfile, indent="\t")
def convert_tensor_to_numpy(tensor, is_squeeze=True):
    """Return ``tensor`` as a NumPy array, optionally squeezing size-1 dims.

    Args:
        tensor: A ``torch.Tensor`` on any device, with or without autograd
            history.
        is_squeeze: If True (default), drop every dimension of size 1 first.

    Returns:
        ``numpy.ndarray`` with the tensor's values. For a tensor already on the
        CPU, the array shares memory with it (``Tensor.numpy`` semantics).
    """
    if is_squeeze:
        tensor = tensor.squeeze()
    # detach() only creates a view, and cpu() returns the tensor itself when it
    # is already on the CPU. Calling both unconditionally also covers non-CUDA
    # devices (e.g. MPS), which Tensor.numpy() rejects and the old is_cuda
    # check missed.
    return tensor.detach().cpu().numpy()
def load_model(
        expdir,
        model,
        optimizer,
        name='model',
        postfix='',
        device='cpu'):
    """Restore ``model`` (and ``optimizer``) from the newest checkpoint in ``expdir``.

    Checkpoints are looked up as ``<expdir>/<name><postfix><step>.pt``.
    - With the default empty ``postfix``, the separator ``'_'`` is used
      (e.g. ``model_1000.pt``).
    - The file with the largest integer ``<step>`` wins.
    - If there is no numbered checkpoint, ``<name><postfix>best.pt`` is used
      when it exists.
    - If nothing matches, the inputs are returned untouched with
      ``global_step = 0``.

    Args:
        expdir: Experiment directory holding the checkpoints.
        model: Object with ``load_state_dict``; loaded with ``strict=False``.
        optimizer: Object with ``load_state_dict``, or None to skip restoring it.
        name: Checkpoint file-name stem.
        postfix: Appended to ``name``. Only the default ``''`` becomes ``'_'``;
            any other value is used verbatim.
        device: ``map_location`` device passed to ``torch.load``.

    Returns:
        ``(global_step, model, optimizer)``. ``global_step`` comes from the
        checkpoint's ``'global_step'`` entry, or is 0 when nothing was restored.
    """
    if postfix == '':
        postfix = '_' + postfix
    path = os.path.join(expdir, name + postfix)
    path_pt = _find_latest_checkpoint(path)
    global_step = 0
    if path_pt is not None:
        print(' [*] restoring model from', path_pt)
        ckpt = torch.load(path_pt, map_location=torch.device(device))
        global_step = ckpt['global_step']
        # strict=False: missing/unexpected keys are ignored rather than raising.
        model.load_state_dict(ckpt['model'], strict=False)
        if optimizer is not None and ckpt.get("optimizer") is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
    return global_step, model, optimizer


def _find_latest_checkpoint(path):
    """Return the checkpoint file for prefix ``path``, or None if there is none.

    Candidates are files directly inside ``dirname(path)`` named
    ``<basename(path)><step>.pt``, where ``<step>`` is all decimal digits; the
    largest step wins. The real file name is returned, so zero-padded steps
    resolve correctly. Without any numbered file, ``<path>best.pt`` is
    returned if it exists.
    """
    directory, prefix = os.path.split(path)
    if not os.path.isdir(directory):
        return None
    latest_step, latest_file = -1, None
    # Sorted for a deterministic choice when two names parse to the same step.
    for filename in sorted(os.listdir(directory)):
        if not (filename.startswith(prefix) and filename.endswith('.pt')):
            continue
        step = filename[len(prefix):-len('.pt')]
        if not step.isdecimal():
            continue
        candidate = os.path.join(directory, filename)
        if int(step) > latest_step and os.path.isfile(candidate):
            latest_step, latest_file = int(step), candidate
    if latest_file is not None:
        return latest_file
    best = path + 'best.pt'
    return best if os.path.isfile(best) else None