import os
import warnings
from typing import Dict

import wandb
from pytorch_lightning.loggers import WandbLogger
from torch import nn
import torch.distributed as dist


class Logger:
    """
    Helper class for managing the logging. Provides utilities such as:
    - Automatic saving of the wandb run identifier so that training can be correctly resumed
    - Application of logging callbacks to the training process, e.g. gradient information
    - A resume_run option to resume logging on the same wandb run
    """

    logger = None

    def __init__(self, config: Dict, checkpoints_directory=None, run_name=None, **kwargs):
        self.config = config

        # Constructor arguments take precedence over the values in the configuration
        if checkpoints_directory is not None:
            config["logging"]["checkpoints_directory"] = checkpoints_directory
        if run_name is not None:
            config["logging"]["run_name"] = run_name

        self.resume_run = config["logging"].get("resume_run", False)
        self.checkpoints_directory = config["logging"]["checkpoints_directory"]
        # File where the id of the run is saved. Needed by loggers such as wandb to correctly resume logging
        self.run_id_filename = os.path.join(self.checkpoints_directory, "run_id.txt")
        self.gradient_log_steps = config["logging"].get("gradient_log_steps", 200)
        self.weight_log_steps = config["logging"].get("weight_log_steps", 200)

        # Retrieves the existing wandb id when resuming, otherwise generates and saves a new one
        if self.id_file_exists() and self.resume_run:
            run_id = self.get_id_from_file()
        else:
            run_id = wandb.util.generate_id()
            self.save_id_to_file(run_id)

        # Sets the environment variables needed by wandb
        self.set_wandb_environment_variables()
        self.project_name = config["logging"]["project_name"]

        # Instantiates the logger only on the main process.
        # If this is not done, wandb will crash the application (https://docs.wandb.ai/guides/track/log/distributed-training)
        self.logger = None
        if dist.is_initialized():
            rank = dist.get_rank()
        elif "HOSTNAME" in os.environ:
            # Fallback for clusters where the pod hostname ends with the replica index, e.g. "worker-3"
            rank = int(os.environ["HOSTNAME"].split("-")[-1])
        else:
            rank = 0
        if rank == 0:
            self.logger = WandbLogger(
                name=config["logging"]["run_name"],
                project=self.project_name,
                id=run_id,
                config=config,
                **kwargs,
            )

    def print(self, *args, **kwargs):
        print(*args, **kwargs)

    def get_logger(self):
        return self.logger

    def set_wandb_environment_variables(self):
        """
        Sets the environment variables that are necessary for wandb to work
        :return:
        """
        wandb_key = self.config["logging"]["wandb_key"]
        os.environ["WANDB_API_KEY"] = wandb_key
        # wandb_username = self.config["logging"]["wandb_username"]
        # os.environ["WANDB_USERNAME"] = wandb_username
        # wandb_base_url = self.config["logging"].get("wandb_base_url", "https://snap.wandb.io")
        # os.environ["WANDB_BASE_URL"] = wandb_base_url
        # wandb_entity = self.config["logging"].get("wandb_entity", "generative-ai")  # Previously "rutils-users"
        # os.environ["WANDB_ENTITY"] = wandb_entity

    # def register_grad_hooks(self, model: nn.Module):
    #     """
    #     Registers gradient logging hooks on the model
    #     :param model: model for which to register the logging hooks
    #     :return:
    #     """
    #     for name, parameter in model.named_parameters():
    #         if parameter.requires_grad:
    #             current_hook = GradLoggingHook(self.logger, f"gradient_statistics/{name}", self.gradient_log_steps)
    #             parameter.register_hook(current_hook)

    # def register_weight_hooks(self, model: nn.Module):
    #     """
    #     Registers weight logging hooks on the model
    #     :param model: model for which to register the logging hooks.
    #                   The model must be called through its forward method for this to work, so it does not currently
    #                   work with pytorch lightning modules. If necessary, fix this by implementing it with a
    #                   pytorch lightning callback
    #     :return:
    #     """
    #     current_hook = WeightLoggingHook(self.logger, "weight_statistics", self.weight_log_steps)
    #     model.register_forward_hook(current_hook)
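
    # GradLoggingHook and WeightLoggingHook are referenced by the disabled methods above but are not
    # defined or imported in this file. The following is a hypothetical sketch of what GradLoggingHook
    # might look like, assuming it is a callable that PyTorch invokes with the gradient tensor and that
    # logs a summary statistic through the pytorch_lightning logger every `log_steps` calls; the real
    # implementation may differ.
    #
    # class GradLoggingHook:
    #     def __init__(self, logger, name, log_steps):
    #         self.logger = logger
    #         self.name = name
    #         self.log_steps = log_steps
    #         self.calls = 0
    #
    #     def __call__(self, grad):
    #         # Tensor hooks receive the gradient and may return a (possibly modified) gradient
    #         self.calls += 1
    #         if self.logger is not None and self.calls % self.log_steps == 0:
    #             self.logger.log_metrics({self.name: grad.norm().item()})
    #         return grad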
    def id_file_exists(self):
        """
        Checks whether the wandb id file exists in the checkpoints directory
        :return: True if the run id file exists
        """
        return os.path.isfile(self.run_id_filename)

    def save_id_to_file(self, run_id):
        """
        Saves the wandb id to a file in the checkpoints directory
        """
        with open(self.run_id_filename, "w") as f:
            f.write(run_id)

    def get_id_from_file(self):
        """
        Reads the id file and returns the id
        :return: the run id, or None if the file does not exist
        """
        if not self.id_file_exists():
            warnings.warn(f"Run ID file does not exist: {self.run_id_filename}")
            return None
        with open(self.run_id_filename, "r") as f:
            # Strips any trailing whitespace so the id matches the one originally generated
            return f.readline().strip()
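

# Minimal usage sketch (not part of the module). The config dictionary below only shows the keys that
# __init__ and set_wandb_environment_variables actually read; the values, the model, and the Trainer
# settings are hypothetical.
#
# import pytorch_lightning as pl
#
# config = {
#     "logging": {
#         "checkpoints_directory": "checkpoints/demo",
#         "run_name": "demo-run",
#         "project_name": "demo-project",
#         "wandb_key": "<your wandb api key>",
#         "resume_run": False,
#     }
# }
# os.makedirs(config["logging"]["checkpoints_directory"], exist_ok=True)  # run_id.txt is written here
# logger = Logger(config)
# trainer = pl.Trainer(max_epochs=1, logger=logger.get_logger())
# trainer.fit(model)  # model is any pl.LightningModule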