import os
import warnings
from typing import Dict

import torch.distributed as dist
import wandb
from pytorch_lightning.loggers import WandbLogger
from torch import nn
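

# A minimal sketch of a gradient logging hook compatible with the commented-out
# register_grad_hooks method below. This is an assumption for illustration: the
# project's real GradLoggingHook (and the analogous WeightLoggingHook) are
# defined elsewhere and may log different statistics.
class GradLoggingHook:
    def __init__(self, logger, name: str, log_steps: int):
        self.logger = logger  # Expected to be a WandbLogger, or None on non-main processes
        self.name = name
        self.log_steps = log_steps
        self.calls = 0

    def __call__(self, grad):
        # Invoked by autograd with the gradient of the hooked parameter.
        # Logs the gradient norm once every `log_steps` backward passes.
        if self.logger is not None and self.calls % self.log_steps == 0:
            self.logger.log_metrics({self.name: grad.norm().item()})
        self.calls += 1
        # Returning None leaves the gradient unmodified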


class Logger:
    """
    Helper class for managing logging. Provides utilities such as:
    - Automatic saving of the wandb run identifier so that training can correctly resume
    - Application of logging callbacks to the training process, e.g. gradient information
    - A resume_run option to continue logging to the same wandb run
    """

    logger = None

    def __init__(self, config: Dict, checkpoints_directory=None, run_name=None, **kwargs):
        self.config = config

        # Allows overriding the configured checkpoints directory and run name
        # from the constructor arguments
        if checkpoints_directory is not None:
            config["logging"]["checkpoints_directory"] = checkpoints_directory
        if run_name is not None:
            config["logging"]["run_name"] = run_name

        self.resume_run = config["logging"].get("resume_run", False)
        self.checkpoints_directory = config["logging"]["checkpoints_directory"]
        # Filename where to save the id of the run. Needed by loggers such as wandb to correctly resume logging
        self.run_id_filename = os.path.join(self.checkpoints_directory, "run_id.txt")
        self.gradient_log_steps = config["logging"].get("gradient_log_steps", 200)
        self.weight_log_steps = config["logging"].get("weight_log_steps", 200)

        # Retrieves the existing wandb id or generates a new one
        if self.id_file_exists() and self.resume_run:
            run_id = self.get_id_from_file()
        else:
            os.makedirs(self.checkpoints_directory, exist_ok=True)
            run_id = wandb.util.generate_id()
            self.save_id_to_file(run_id)

        # Sets the environment variables needed by wandb
        self.set_wandb_environment_variables()

        self.project_name = config["logging"]["project_name"]

        # Instantiates the logger only on the main process.
        # If this is not done, wandb will crash the application (https://docs.wandb.ai/guides/track/log/distributed-training)
        self.logger = None
        if dist.is_initialized():
            rank = dist.get_rank()
        elif "HOSTNAME" in os.environ and os.environ["HOSTNAME"].split("-")[-1].isdigit():
            # Fallback when torch.distributed is not initialized: assumes hostnames
            # of the form "<name>-<rank>", as produced e.g. by Kubernetes StatefulSets
            rank = int(os.environ["HOSTNAME"].split("-")[-1])
        else:
            rank = 0
        if rank == 0:
            self.logger = WandbLogger(
                name=config["logging"]["run_name"],
                project=self.project_name,
                id=run_id,
                config=config,
                **kwargs,
            )
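    # The expected structure of config["logging"], inferred from the accesses
    # in __init__ above (all values below are placeholders):
    # {
    #     "project_name": "my-project",
    #     "run_name": "baseline",
    #     "checkpoints_directory": "checkpoints/baseline",
    #     "resume_run": False,            # optional, defaults to False
    #     "gradient_log_steps": 200,      # optional, defaults to 200
    #     "weight_log_steps": 200,        # optional, defaults to 200
    #     "wandb_key": "<wandb-api-key>",
    # }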
    def print(self, *args, **kwargs):
        print(*args, **kwargs)

    def get_logger(self):
        return self.logger
    def set_wandb_environment_variables(self):
        """
        Sets the environment variables that are necessary for wandb to work
        :return:
        """
        wandb_key = self.config["logging"]["wandb_key"]
        os.environ["WANDB_API_KEY"] = wandb_key
        # wandb_username = self.config["logging"]["wandb_username"]
        # os.environ["WANDB_USERNAME"] = wandb_username
        # wandb_base_url = self.config["logging"].get("wandb_base_url", "https://snap.wandb.io")
        # os.environ["WANDB_BASE_URL"] = wandb_base_url
        # wandb_entity = self.config["logging"].get("wandb_entity", "generative-ai")  # Previously "rutils-users"
        # os.environ["WANDB_ENTITY"] = wandb_entity
    # def register_grad_hooks(self, model: nn.Module):
    #     """
    #     Registers grad logging hooks to the model
    #     :param model: model for which to register the logging hooks
    #     :return:
    #     """
    #     for name, parameter in model.named_parameters():
    #         if parameter.requires_grad:
    #             current_hook = GradLoggingHook(self.logger, f"gradient_statistics/{name}", self.gradient_log_steps)
    #             parameter.register_hook(current_hook)

    # def register_weight_hooks(self, model: nn.Module):
    #     """
    #     Registers weight logging hooks to the model
    #     :param model: model for which to register the logging hooks.
    #                   The model is required to call the forward method for this to work, so it does not currently work
    #                   with pytorch lightning modules. If necessary, fix this by implementing it with a pytorch lightning callback
    #     :return:
    #     """
    #     current_hook = WeightLoggingHook(self.logger, "weight_statistics", self.weight_log_steps)
    #     model.register_forward_hook(current_hook)
    def id_file_exists(self):
        """
        Checks if the wandb id file exists in the checkpoints directory
        :return:
        """
        return os.path.isfile(self.run_id_filename)

    def save_id_to_file(self, run_id):
        """
        Saves the wandb id to a file in the checkpoints directory
        """
        with open(self.run_id_filename, "w") as f:
            f.write(run_id)
    def get_id_from_file(self):
        """
        Reads the id file and returns the id
        :return: run id, None if the file does not exist
        """
        if not self.id_file_exists():
            warnings.warn(f"Run ID file does not exist: {self.run_id_filename}")
            return None
        with open(self.run_id_filename, "r") as f:
            # Strips the trailing newline so the id can be passed to wandb verbatim
            return f.readline().strip()
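
# Example usage: a minimal sketch assuming the config layout documented above.
# The API key is a placeholder; in a real run the config would typically be
# loaded from a file and the returned logger passed to a pytorch_lightning
# Trainer via Trainer(logger=...).
if __name__ == "__main__":
    example_config = {
        "logging": {
            "project_name": "example-project",
            "run_name": "example-run",
            "checkpoints_directory": "checkpoints/example-run",
            "resume_run": False,
            "wandb_key": "<wandb-api-key>",
        }
    }
    logger_helper = Logger(example_config)
    pl_logger = logger_helper.get_logger()  # None on non-main processes
    if pl_logger is not None:
        pl_logger.log_metrics({"debug/example_metric": 1.0})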