Spaces:
Runtime error
Runtime error
| import timeit | |
| import numpy as np | |
| import os | |
| import os.path as osp | |
| import shutil | |
| import copy | |
| import torch | |
| import torch.nn as nn | |
| import torch.distributed as dist | |
| from .cfg_holder import cfg_unique_holder as cfguh | |
| from . import sync | |
# When True, only the process whose local rank is 0 prints to the console.
print_console_local_rank0_only = True


def print_log(*console_info):
    """Print a message to the console and append it to the configured log file.

    All positional arguments are converted with ``str`` and joined with a
    single space, mirroring ``print``.

    Console output is suppressed on processes whose local rank is non-zero
    while ``print_console_local_rank0_only`` is True. The log file is only
    ever written by the process with local rank 0.

    The log file path is read from ``cfg.train.log_file``; if that lookup
    fails, ``cfg.eval.log_file`` is tried instead. If both lookups fail, or
    the resolved path is None, nothing is written to disk.

    Args:
        *console_info: Objects to log.
    """
    local_rank = sync.get_rank('local')
    if print_console_local_rank0_only and local_rank != 0:
        return

    message = ' '.join(str(item) for item in console_info)
    print(message)

    # Only one process per node appends to the file.
    if local_rank != 0:
        return

    # Best-effort lookup: a missing config section simply disables file
    # logging. ``Exception`` (rather than a bare ``except``) is caught so that
    # KeyboardInterrupt / SystemExit are not swallowed here.
    try:
        log_file = cfguh().cfg.train.log_file
    except Exception:
        try:
            log_file = cfguh().cfg.eval.log_file
        except Exception:
            return

    if log_file is not None:
        with open(log_file, 'a') as f:
            f.write(message + '\n')
class distributed_log_manager(object):
    """Accumulate weighted running statistics and report them.

    Values are accumulated per name as a weighted sum plus a weight count.
    Their means can be averaged across processes (when DDP is active),
    formatted into a console summary line and, optionally, written to
    TensorBoard via ``tensorboardX``.

    Attributes are set in ``__init__``:
        sum / cnt: per-name weighted sums and accumulated weights.
        time_check: timer value taken at construction or the last ``clear``.
        ddp, rank, world_size: values reported by the ``sync`` helper module.
        tb: a ``tensorboardX.SummaryWriter`` or None when disabled.
    """

    def __init__(self):
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

        cfgt = cfguh().cfg.train
        use_tensorboard = getattr(cfgt, 'log_tensorboard', False)

        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')

        # Only the local-rank-0 process owns a TensorBoard writer.
        self.tb = None
        if use_tensorboard and (self.rank == 0):
            # Imported lazily so tensorboardX is only required when enabled.
            import tensorboardX
            monitoring_dir = osp.join(cfgt.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)

    def accumulate(self, n, **data):
        """Add ``value * n`` to each named running sum and ``n`` to its count.

        Args:
            n: Weight of this update (e.g. a batch size). Must be >= 0.
            **data: Mapping of statistic name to its value for this update.

        Raises:
            ValueError: If ``n`` is negative.

        Note:
            A name whose total accumulated weight is 0 makes
            ``get_mean_value_dict`` divide by zero.
        """
        if n < 0:
            raise ValueError(
                'n must be non-negative, got {}'.format(n))

        for itemn, di in data.items():
            if itemn in self.sum:
                self.sum[itemn] += di * n
                self.cnt[itemn] += n
            else:
                self.sum[itemn] = di * n
                self.cnt[itemn] = n

    def get_mean_value_dict(self):
        """Return ``{name: mean}`` for every accumulated statistic.

        The local mean is ``sum / cnt`` per name. When DDP is active the
        means are summed across processes with ``all_reduce`` and divided by
        ``self.world_size``.
        """
        # Sorted so every process packs the values in the same order.
        names = sorted(self.sum.keys())
        value_gather = [self.sum[itemn] / self.cnt[itemn] for itemn in names]
        value_gather_tensor = torch.FloatTensor(value_gather)

        if self.ddp:
            # The device transfer is only needed for the collective; keeping
            # it off the non-DDP path lets single-process runs work without
            # an accelerator.
            value_gather_tensor = value_gather_tensor.to(self.rank)
            dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
            # NOTE(review): the sum runs over the default process group but
            # the divisor is the 'local' world size from sync -- presumably
            # equal for single-node runs; confirm for multi-node.
            value_gather_tensor /= self.world_size

        return {
            itemn: value_gather_tensor[idx].item()
            for idx, itemn in enumerate(names)
        }

    def tensorboard_log(self, step, data, mode='train', **extra):
        """Write scalars to TensorBoard; a no-op when no writer exists.

        Args:
            step: Global step passed to every ``add_scalar`` call.
            data: In 'train' mode a dict of name -> value. In 'eval' mode
                either such a dict or a single scalar.
            mode: 'train' or 'eval'; any other value writes nothing.
            **extra: In 'train' mode ``epochn`` is required and ``lr`` is
                optional (skipped when absent or None).

        Raises:
            KeyError: In 'train' mode when ``epochn`` is not supplied.
        """
        if self.tb is None:
            return

        if mode == 'train':
            self.tb.add_scalar('other/epochn', extra['epochn'], step)
            # train_summary always forwards ``lr``, possibly as None; a None
            # value is not a loggable scalar, so skip it.
            if extra.get('lr') is not None:
                self.tb.add_scalar('other/lr', extra['lr'], step)
            for itemn, di in data.items():
                if itemn.startswith('loss'):
                    self.tb.add_scalar('loss/' + itemn, di, step)
                elif itemn == 'Loss':
                    self.tb.add_scalar('Loss', di, step)
                else:
                    self.tb.add_scalar('other/' + itemn, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)
        return

    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        """Log the current means to TensorBoard and build a console line.

        Args:
            itern: Iteration number shown in the summary.
            epochn: Epoch number shown in the summary.
            samplen: Sample count shown in the summary.
            lr: Learning rate, or None to omit it.
            tbstep: TensorBoard step; defaults to ``itern`` when None.

        Returns:
            A ' , '-joined string with Iter/Epoch/Sample, optional LR, the
            'Loss' mean, every mean whose name starts with 'loss', and the
            seconds elapsed since construction or the last ``clear``.

        Raises:
            KeyError: If no statistic named 'Loss' has been accumulated.
        """
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen),
        ]
        if lr is not None:
            console_info += ['LR:{:.4E}'.format(lr)]

        mean = self.get_mean_value_dict()

        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)

        # 'Loss' is printed first; the remaining 'loss*' entries follow in
        # sorted order.
        loss = mean.pop('Loss')
        mean_info = ['Loss:{:.4f}'.format(loss)] + [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean.keys())
            if itemn.startswith('loss')
        ]
        console_info += mean_info
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)

    def clear(self):
        """Drop all accumulated statistics and restart the timer."""
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        """Close the TensorBoard writer if one was created."""
        if self.tb is not None:
            self.tb.close()
| # ----- also include some small utils ----- | |
def torch_to_numpy(*argv):
    """Recursively convert torch tensors inside the input to numpy arrays.

    With several positional arguments they are treated as one list; with a
    single argument that argument itself is converted.

    Tensors are moved to the CPU, detached and turned into numpy arrays.
    Lists and tuples both come back as lists, dicts keep their keys, and any
    other object is returned unchanged.
    """
    payload = list(argv) if len(argv) > 1 else argv[0]

    if isinstance(payload, torch.Tensor):
        return payload.to('cpu').detach().numpy()
    if isinstance(payload, (list, tuple)):
        return [torch_to_numpy(element) for element in payload]
    if isinstance(payload, dict):
        return {key: torch_to_numpy(value) for key, value in payload.items()}
    return payload