Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Dumps things to tensorboard and console | |
| """ | |
import os
import logging
import datetime
from typing import Dict, Optional

import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from tracker.utils.time_estimator import TimeEstimator
def tensor_to_numpy(image):
    """Convert a float image tensor (nominally in [0, 1]) to a uint8 numpy array.

    Values are scaled by 255 and truncated toward zero by the uint8 cast.
    Out-of-range inputs are clipped to [0, 255] before the cast: NumPy leaves
    float-to-uint8 casts of out-of-range values undefined (commonly they wrap,
    so a slightly negative pixel turns bright).

    Args:
        image: tensor exposing ``.numpy()``; it must already be detached and on
            the CPU (e.g. via ``detach_to_cpu``) for ``.numpy()`` to succeed.

    Returns:
        numpy.ndarray of dtype uint8 with the same shape as ``image``.
    """
    scaled = image.numpy() * 255
    return np.clip(scaled, 0, 255).astype('uint8')
def detach_to_cpu(x):
    """Return *x* cut off from the autograd graph and placed on the CPU."""
    detached = x.detach()
    return detached.cpu()
def fix_width_trunc(x):
    """Format *x* with nine decimal places and keep at most its first nine characters.

    Useful for compact, fixed-width numeric columns in log lines.
    """
    full_text = f'{x:0.9f}'
    return full_text[:9]
class TensorboardLogger:
    """Log scalars, text and metrics to a Python logger and, optionally, TensorBoard.

    Images passed to ``log_image`` are written as PNG files under ``run_dir``
    rather than to TensorBoard.
    """

    def __init__(self, run_dir, py_logger: logging.Logger, *, enabled_tb):
        """
        Args:
            run_dir: directory for TensorBoard event files and saved images.
            py_logger: logger receiving all console/text output.
            enabled_tb: when falsy, no SummaryWriter is created and every
                TensorBoard call becomes a no-op.
        """
        self.run_dir = run_dir
        self.py_log = py_logger
        self.tb_log = SummaryWriter(run_dir) if enabled_tb else None

        # Record which code revision produced this run.
        git_info = self._fetch_git_info()
        if git_info is None:
            self.py_log.warning('Failed to fetch git info. Defaulting to None')
            git_info = 'None'
        self.log_string('git', git_info)

        # Only ever None inside this class; when a caller attaches an estimator,
        # log_metrics adds average step time, remaining time and ETA to its output.
        self.time_estimator: Optional[TimeEstimator] = None

    @staticmethod
    def _fetch_git_info():
        """Return '<branch> <commit sha>' for the repo at the CWD, or None if unavailable."""
        try:
            import git
        except ImportError:
            # GitPython is missing (or could not locate a usable git executable).
            return None
        try:
            repo = git.Repo(".")
            return f'{repo.active_branch} {repo.head.commit.hexsha}'
        except (git.exc.GitError, TypeError, ValueError, RuntimeError):
            # GitError: not a repository / bad path; TypeError: detached HEAD;
            # ValueError: repository without commits. All are non-fatal here.
            return None

    def log_scalar(self, tag, x, it):
        """Write scalar ``x`` under ``tag`` at step ``it`` (no-op without TensorBoard)."""
        if self.tb_log is None:
            return
        self.tb_log.add_scalar(tag, x, it)

    def _format_time_msg(self, it):
        """Advance the time estimator and format avg step time, remaining time and ETA."""
        self.time_estimator.update()
        avg_time = self.time_estimator.get_and_reset_avg_time()
        est = datetime.timedelta(seconds=self.time_estimator.get_est_remaining(it))
        if est.days > 0:
            remaining_str = f'{est.days}d {est.seconds // 3600}h'
        else:
            remaining_str = f'{est.seconds // 3600}h {(est.seconds % 3600) // 60}m'
        eta_str = (datetime.datetime.now() + est).strftime('%Y-%m-%d %H:%M:%S')
        return f'avg_time:{avg_time:.3f},remaining:{remaining_str},eta:{eta_str},\t'

    def log_metrics(self, exp_id, prefix, metrics: Dict, it):
        """Log every metric as a scalar ``'{prefix}/{name}'`` and emit one summary line.

        Metrics appear in the summary sorted by name. When ``time_estimator``
        is set, timing information is inserted before the metric values.
        """
        msg = f'{exp_id}-{prefix} - it {it:6d}: '
        metric_parts = []
        for k, v in sorted(metrics.items()):
            self.log_scalar(f'{prefix}/{k}', v, it)
            metric_parts.append(f'{k: >10}:{v:.7f},\t')
        metrics_msg = ''.join(metric_parts)
        if self.time_estimator is not None:
            msg = f'{msg} {self._format_time_msg(it)}'
        msg = f'{msg} {metrics_msg}'
        self.py_log.info(msg)

    def log_image(self, stage_name, tag, image, it):
        """Save ``image`` as ``<run_dir>/<stage_name>_images/<tag>_<it>.png``.

        ``image`` is passed to ``PIL.Image.fromarray`` (presumably an HxW or
        HxWxC uint8 array, e.g. from ``tensor_to_numpy`` — confirm at call sites).
        """
        image_dir = os.path.join(self.run_dir, f'{stage_name}_images')
        path = os.path.join(image_dir, f'{tag}_{it}.png')
        # A tag containing a path separator places the file in a sub-directory;
        # create the file's actual parent, not just image_dir.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        Image.fromarray(image).save(path)

    def log_string(self, tag, x):
        """Log ``'{tag} - {x}'`` to the console and add ``x`` as TensorBoard text."""
        self.py_log.info(f'{tag} - {x}')
        if self.tb_log is None:
            return
        self.tb_log.add_text(tag, x)

    # Thin pass-throughs to the wrapped Python logger.
    def debug(self, x):
        self.py_log.debug(x)

    def info(self, x):
        self.py_log.info(x)

    def warning(self, x):
        self.py_log.warning(x)

    def error(self, x):
        self.py_log.error(x)

    def critical(self, x):
        self.py_log.critical(x)