'''
author: wayn391@mastertones
'''
import datetime
import os
import time

import matplotlib.pyplot as plt
import torch
import yaml
from torch.utils.tensorboard import SummaryWriter


class Saver(object):
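    '''Experiment logger: text log, TensorBoard summaries, and checkpoints.'''
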
    def __init__(
            self,
            args,
            initial_global_step=-1):
        self.expdir = args.env.expdir
        self.sample_rate = args.data.sampling_rate

        # cold start
        self.global_step = initial_global_step
        self.init_time = time.time()
        self.last_time = time.time()

        # experiment directory (also holds checkpoints)
        os.makedirs(self.expdir, exist_ok=True)

        # text-log path
        self.path_log_info = os.path.join(self.expdir, 'log_info.txt')

        # tensorboard writer
        self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))

        # save config
        path_config = os.path.join(self.expdir, 'config.yaml')
        with open(path_config, "w") as out_config:
            yaml.dump(dict(args), out_config)

    def log_info(self, msg):
        '''Print a message (str or dict) and append it to the text log.'''
        if isinstance(msg, dict):
            msg_list = []
            for k, v in msg.items():
                if isinstance(v, int):
                    tmp_str = '{}: {:,}'.format(k, v)
                else:
                    tmp_str = '{}: {}'.format(k, v)
                msg_list.append(tmp_str)
            msg_str = '\n'.join(msg_list)
        else:
            msg_str = msg

        # display
        print(msg_str)

        # save
        with open(self.path_log_info, 'a') as fp:
            fp.write(msg_str + '\n')

    def log_value(self, scalars):
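        '''Write a dict of named scalar values to TensorBoard at the current step.'''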
        for k, v in scalars.items():
            self.writer.add_scalar(k, v, self.global_step)

    def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
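        '''Plot the absolute error, ground-truth, and predicted spectrograms
        side by side and send the figure to TensorBoard.'''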
        # concatenate the absolute error (offset by vmin), ground truth,
        # and prediction along the last axis, then plot the first batch item
        spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
        spec = spec_cat[0]
        if isinstance(spec, torch.Tensor):
            spec = spec.detach().cpu().numpy()
        fig = plt.figure(figsize=(12, 9))
        plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
        plt.tight_layout()
        self.writer.add_figure(name, fig, self.global_step)

    def log_audio(self, audios):
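        '''Write a dict of named audio tensors to TensorBoard at the current step.'''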
        for k, v in audios.items():
            self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)

    def get_interval_time(self, update=True):
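        '''Return the seconds elapsed since the last call; optionally reset the timer.'''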
        cur_time = time.time()
        time_interval = cur_time - self.last_time
        if update:
            self.last_time = cur_time
        return time_interval

    def get_total_time(self, to_str=True):
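        '''Return the seconds elapsed since construction, optionally as an
        'h:mm:ss.f' string.'''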
        total_time = time.time() - self.init_time
        if to_str:
            # 'h:mm:ss.ffffff' trimmed to one fractional digit
            total_time = str(datetime.timedelta(
                seconds=total_time))[:-5]
        return total_time

    def save_model(
            self,
            model,
            optimizer,
            name='model',
            postfix='',
            to_json=False):
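        '''Save a checkpoint to <expdir>/<name><postfix>.pt. Note that
        'to_json' is accepted but unused in this implementation.'''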
        # path
        if postfix:
            postfix = '_' + postfix
        path_pt = os.path.join(
            self.expdir, name + postfix + '.pt')

        # save (include the optimizer state only when one is given)
        if optimizer is not None:
            torch.save({
                'global_step': self.global_step,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}, path_pt)
        else:
            torch.save({
                'global_step': self.global_step,
                'model': model.state_dict()}, path_pt)

        # report
        print(' [*] model checkpoint saved: {}'.format(path_pt))

    def delete_model(self, name='model', postfix=''):
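        '''Delete the checkpoint <expdir>/<name><postfix>.pt if it exists.'''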
        # path
        if postfix:
            postfix = '_' + postfix
        path_pt = os.path.join(
            self.expdir, name + postfix + '.pt')

        # delete
        if os.path.exists(path_pt):
            os.remove(path_pt)
            print(' [*] model checkpoint deleted: {}'.format(path_pt))

    def global_step_increment(self):
        self.global_step += 1
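

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original file).
# Saver expects 'args' to support both attribute access (args.env.expdir,
# args.data.sampling_rate) and dict() conversion, since __init__ calls
# yaml.dump(dict(args)). 'DotDict' below is a hypothetical stand-in for
# whatever config class the surrounding project actually uses.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class DotDict(dict):
        '''dict with attribute-style access (hypothetical helper).'''
        __getattr__ = dict.get

    args = DotDict(
        env=DotDict(expdir='exp/demo'),
        data=DotDict(sampling_rate=44100))

    saver = Saver(args)
    saver.log_info({'epoch': 1, 'steps': 1000})   # goes to stdout and log_info.txt
    saver.log_value({'train/loss': 0.123})        # goes to TensorBoard
    saver.global_step_increment()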