import os
import warnings
from typing import Dict
import wandb
from pytorch_lightning.loggers import WandbLogger
from torch import nn
import torch.distributed as dist


class Logger:
    """
    Helper class for managing the logging. Enables utilities such as:
    - Automatic saving of wandb run identifiers to correctly resume training
    - Application of logging callbacks to the training process. Eg. gradient information
    - resume_run option to resume on the same wandb run
    """

    logger = None

    def __init__(self, config: Dict, checkpoints_directory=None, run_name=None, **kwargs):
        self.config = config

        # Overrides the config entries when explicit arguments are provided
        if checkpoints_directory is not None:
            config["logging"]["checkpoints_directory"] = checkpoints_directory
        
        if run_name is not None:
            config["logging"]["run_name"] = run_name
        
        self.resume_run = config["logging"].get("resume_run", False)
        
        self.checkpoints_directory = config["logging"]["checkpoints_directory"]
        # File used to persist the id of the run; needed by loggers such as wandb to correctly resume logging
        self.run_id_filename = os.path.join(self.checkpoints_directory, "run_id.txt")
        self.gradient_log_steps = config["logging"].get("gradient_log_steps", 200)
        self.weight_log_steps = config["logging"].get("weight_log_steps", 200)

        # Retrieves existing wandb id or generates a new one
        if self.id_file_exists() and self.resume_run:
            run_id = self.get_id_from_file()
        else:
            run_id = wandb.util.generate_id()
            self.save_id_to_file(run_id)

        # Sets the environment variables needed by wandb
        self.set_wandb_environment_variables()
        self.project_name = config["logging"]["project_name"]
        # Instantiates the logger only on the main process.
        # If this is not done, wandb will crash the application (https://docs.wandb.ai/guides/track/log/distributed-training)
        self.logger = None
        if dist.is_initialized():
            rank = dist.get_rank()
        elif "HOSTNAME" in os.environ:
            # Fallback when torch.distributed is not initialized: infers the rank from the
            # hostname, assuming it ends with "-<index>" (e.g. pods in a Kubernetes StatefulSet)
            rank = int(os.environ["HOSTNAME"].split("-")[-1])
        else:
            rank = 0
        if rank == 0:
            self.logger = WandbLogger(name=config["logging"]["run_name"], project=self.project_name, id=run_id, config=config, **kwargs)

    def print(self, *args, **kwargs):
        print(*args, **kwargs)

    def get_logger(self):
        return self.logger

    def set_wandb_environment_variables(self):
        """
        Sets the environment variables that are necessary for wandb to work
        :return:
        """
        wandb_key = self.config["logging"]["wandb_key"]
        os.environ["WANDB_API_KEY"] = wandb_key
        # wandb_username = self.config["logging"]["wandb_username"]
        # os.environ["WANDB_USERNAME"] = wandb_username
        # wandb_base_url = self.config["logging"].get("wandb_base_url", "https://snap.wandb.io")
        # os.environ["WANDB_BASE_URL"] = wandb_base_url
        # wandb_entity = self.config["logging"].get("wandb_entity", "generative-ai")  # Previously "rutils-users"
        # os.environ["WANDB_ENTITY"] = wandb_entity

    # def register_grad_hooks(self, model: nn.Module):
    #     """
    #     Registers grad logging hooks to the model
    #     :param model: model for which to register the logging hooks
    #     :return:
    #     """
    #     for name, parameter in model.named_parameters():
    #         if parameter.requires_grad:
    #             current_hook = GradLoggingHook(self.logger, f"gradient_statistics/{name}", self.gradient_log_steps)
    #             parameter.register_hook(current_hook)

    # def register_weight_hooks(self, model: nn.Module):
    #     """
    #     Registers weight logging hooks to the model
    #     :param model: model for which to register the logging hooks.
    #     The model's forward method must be called for this to work, so it does not currently work with pytorch lightning modules.
    #     If necessary, fix this by implementing it as a pytorch lightning callback
    #     :return:
    #     """
    #     current_hook = WeightLoggingHook(self.logger, f"weight_statistics", self.weight_log_steps)
    #     model.register_forward_hook(current_hook)

    def id_file_exists(self):
        """
        Checks if the wandb id file exists in the checkpoints directory
        :return:
        """
        return os.path.isfile(self.run_id_filename)

    def save_id_to_file(self, run_id):
        """
        Saves the wandb id to a file in the checkpoints directory
        """
        # Makes sure the checkpoints directory exists before writing the id file
        os.makedirs(self.checkpoints_directory, exist_ok=True)
        with open(self.run_id_filename, "w") as f:
            f.write(run_id)

    def get_id_from_file(self):
        """
        Reads the id file and returns the id
        :return: run id, None if file does not exist
        """
        if not self.id_file_exists():
            warnings.warn(f"Run ID file does not exist: {self.run_id_filename}")
            return None
        with open(self.run_id_filename, "r") as f:
            # Strips any trailing whitespace so the resumed run id matches exactly
            return f.readline().strip()
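

# Usage sketch (for illustration only): the config keys below mirror the ones read
# in Logger.__init__; the concrete values, the Trainer wiring and the import of
# pytorch_lightning are assumptions, not part of this module.
#
#     import pytorch_lightning as pl
#
#     config = {
#         "logging": {
#             "project_name": "my-project",
#             "run_name": "baseline",
#             "checkpoints_directory": "checkpoints/baseline",
#             "wandb_key": "<wandb api key>",
#             "resume_run": False,
#         }
#     }
#     logger = Logger(config)
#     trainer = pl.Trainer(logger=logger.get_logger())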