""" Utility functions for training deep learning models, particularly focused on distributed training, metric logging, and gradient handling. This module provides tools for: - Tracking and logging metrics during training - Setting up distributed training environments - Handling gradient scaling and normalization - Managing learning rates and parameter groups - Saving and loading model checkpoints References: CroCo (https://github.com/naver/croco) """ import builtins import datetime import json import math import os import time from collections import defaultdict, deque from pathlib import Path import torch import torch.distributed as dist from torch import inf class SmoothedValue(object): """ Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=20, fmt=None): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! """ if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): return self.total / self.count @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value, ) class MetricLogger(object): """ Logger for tracking and displaying training metrics. This class maintains a collection of metrics during training, provides methods to update them, and formats them for display. It also handles synchronization of metrics across processes in distributed training. """ def __init__(self, delimiter="\t", print_per_view_stats=False): """ Initialize the MetricLogger. Args: delimiter (str, optional): Delimiter for formatting output. Defaults to "\t". print_per_view_stats (bool, optional): Whether to print per-view statistics. Defaults to False. """ self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter self.print_per_view_stats = print_per_view_stats def update(self, **kwargs): """ Update metrics with new values. Args: **kwargs: Key-value pairs where keys are metric names and values are metric values Values can be tensors or numbers Raises: AssertionError: If a value is not a float or int after conversion from tensor """ for k, v in kwargs.items(): if v is None: continue if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): """ Get a meter by attribute name. This allows accessing meters as attributes of the logger. Args: attr (str): Name of the attribute to get Returns: SmoothedValue: The meter corresponding to the attribute name Raises: AttributeError: If the attribute doesn't exist as a meter or regular attribute """ if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError( "'{}' object has no attribute '{}'".format(type(self).__name__, attr) ) def __str__(self): """ Format all metrics as a string. Returns: str: Formatted string containing all metrics """ loss_str = [] for name, meter in self.meters.items(): # Skip printing per-view stats if not enabled if not self.print_per_view_stats and "view" in name: continue loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): """ Synchronize metrics across processes in distributed training. This method calls synchronize_between_processes on each meter to ensure consistent values across all processes. """ for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): """ Add a custom meter to the logger. Args: name (str): Name of the meter meter (SmoothedValue): The meter to add """ self.meters[name] = meter def log_every(self, iterable, print_freq, header=None, max_iter=None): """ Log metrics at regular intervals while iterating. This method wraps an iterable and logs metrics every print_freq iterations. It also tracks iteration time, data loading time, and memory usage. Args: iterable: Iterable to iterate over (typically a data loader) print_freq (int): How often to log metrics (in iterations) header (str, optional): Header string to print before metrics. Defaults to None. max_iter (int, optional): Maximum number of iterations. Defaults to None. Yields: object: Items from the original iterable """ i = 0 if not header: header = "" start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt="{avg:.4f}") data_time = SmoothedValue(fmt="{avg:.4f}") len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable) space_fmt = ":" + str(len(str(len_iterable))) + "d" log_msg = [ header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}", ] if torch.cuda.is_available(): log_msg.append("max mem: {memory:.0f}") log_msg = self.delimiter.join(log_msg) MB = 1024.0 * 1024.0 for it, obj in enumerate(iterable): data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len_iterable - 1: eta_seconds = iter_time.global_avg * (len_iterable - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print( log_msg.format( i, len_iterable, eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB, ) ) else: print( log_msg.format( i, len_iterable, eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), ) ) i += 1 end = time.time() if max_iter and it >= max_iter: break total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print( "{} Total time: {} ({:.4f} s / it)".format( header, total_time_str, total_time / len_iterable ) ) def setup_for_distributed(is_master): """ This function disables printing when not in master process. It replaces the built-in print function with a custom version that only prints when the current process is the master process or when explicitly forced. Args: is_master (bool): Whether the current process is the master process """ builtin_print = builtins.print def print(*args, **kwargs): force = kwargs.pop("force", False) # force = force or (get_world_size() > 8) if is_master or force: now = datetime.datetime.now().time() builtin_print("[{}] ".format(now), end="") # print with time stamp builtin_print(*args, **kwargs) builtins.print = print def is_dist_avail_and_initialized(): """ Check if distributed training is available and initialized. Returns: bool: True if distributed training is available and initialized, False otherwise """ if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_world_size(): """ Get the number of processes in the distributed training group. Returns: int: Number of processes in the distributed group, or 1 if not using distributed training """ if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def get_rank(): """ Get the rank of the current process in the distributed training group. Returns: int: Rank of the current process, or 0 if not using distributed training """ if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def is_main_process(): """ Check if the current process is the main process (rank 0). Returns: bool: True if the current process is the main process, False otherwise """ return get_rank() == 0 def save_on_master(*args, **kwargs): """ Save a PyTorch object only on the master process. This function is useful in distributed training to avoid multiple processes trying to save the same file simultaneously. Args: *args: Positional arguments to pass to torch.save() **kwargs: Keyword arguments to pass to torch.save() """ if is_main_process(): torch.save(*args, **kwargs) def init_distributed_mode(args): """ Initialize distributed training mode. This function sets up the distributed training environment based on environment variables and command-line arguments. It initializes the process group, sets the appropriate device, and configures printing for the distributed setup. Args: args: Arguments object containing distributed training configuration. Expected to have attributes like dist_url, and will be modified to include rank, world_size, gpu, and distributed flag. """ nodist = args.nodist if hasattr(args, "nodist") else False if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not nodist: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) else: print("Not using distributed mode") setup_for_distributed(is_master=True) # hack args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) args.dist_backend = "nccl" print( "| distributed init (rank {}): {}, gpu {}".format( args.rank, args.dist_url, args.gpu ), flush=True, ) torch.distributed.init_process_group( backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank, ) torch.distributed.barrier() setup_for_distributed(args.rank == 0) class NativeScalerWithGradNormCount: """ A gradient scaler that handles gradient scaling and norm computation for mixed precision training. This class wraps PyTorch's GradScaler to provide additional functionality for gradient norm tracking and clipping during mixed precision training. """ state_dict_key = "amp_scaler" def __init__(self, enabled=True): """Initialize the scaler. Args: enabled (bool): Whether to enable gradient scaling. Default: True """ self._scaler = torch.GradScaler("cuda", enabled=enabled) def __call__( self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, ): """Scales loss and performs backward pass with optional gradient clipping. Args: loss: The loss to backpropagate optimizer: The optimizer being used clip_grad: Max norm for gradient clipping. None means no clipping parameters: Model parameters or list of parameters for gradient norm computation create_graph: Whether to create graph during backward pass update_grad: Whether to update gradients Returns: norm: The gradient norm if computed, else None. Returns list of norms if parameters is a list. """ self._scaler.scale(loss).backward(create_graph=create_graph) if update_grad: if clip_grad is not None: assert parameters is not None self._scaler.unscale_( optimizer ) # unscale the gradients of optimizer's assigned params in-place if isinstance(parameters, (list, tuple)): norm = [ torch.nn.utils.clip_grad_norm_(p, clip_grad) for p in parameters ] else: norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) else: self._scaler.unscale_(optimizer) norm = get_grad_norm_(parameters) self._scaler.step(optimizer) self._scaler.update() else: norm = None return norm def state_dict(self): """Returns the state dict of the underlying scaler. Returns: dict: The state dict of the gradient scaler """ return self._scaler.state_dict() def load_state_dict(self, state_dict): """Loads the state dict into the underlying scaler. Args: state_dict: The state dict to load """ self._scaler.load_state_dict(state_dict) def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: """ Calculate the gradient norm of parameters. This function computes the norm of gradients for a set of parameters. It can handle both single parameter groups and multiple parameter groups (list/tuple of parameters). Args: parameters: A tensor or iterable of tensors or iterable of iterables of tensors containing model parameters for which to compute gradient norms norm_type (float): Type of norm to use (e.g., 2.0 for L2 norm, inf for infinity norm) Returns: torch.Tensor: The computed gradient norm. If parameters is a list/tuple of parameter groups, returns a list of norms, one for each group. """ if isinstance(parameters, (list, tuple)): # If parameters is already a list/tuple, process each parameter group all_norms = [] for params in parameters: if isinstance(params, torch.Tensor): params = [params] params = [p for p in params if p.grad is not None] if len(params) > 0: device = params[0].grad.device if norm_type == inf: group_norm = max( p.grad.detach().abs().max().to(device) for p in params ) else: group_norm = torch.norm( torch.stack( [ torch.norm(p.grad.detach(), norm_type).to(device) for p in params ] ), norm_type, ) else: group_norm = torch.tensor(0.0) all_norms.append(group_norm) return all_norms # Original logic for single parameter group if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = [p for p in parameters if p.grad is not None] norm_type = float(norm_type) if len(parameters) == 0: return torch.tensor(0.0) device = parameters[0].grad.device if norm_type == inf: total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) else: total_norm = torch.norm( torch.stack( [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] ), norm_type, ) return total_norm def save_model( args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None ): """ Save model checkpoint to disk. This function saves the model state, optimizer state, loss scaler state, training arguments, current epoch, and optionally the best metric value so far. The checkpoint is only saved on the master process in distributed training. Args: args: Arguments containing output directory information epoch (int): Current training epoch model_without_ddp (torch.nn.Module): Model without DistributedDataParallel wrapper optimizer (torch.optim.Optimizer): Optimizer instance loss_scaler: Gradient scaler for mixed precision training fname (str, optional): Custom filename suffix. If None, uses the epoch number. Defaults to None. best_so_far (float, optional): Best metric value achieved so far. Defaults to None. """ output_dir = Path(args.output_dir) if fname is None: fname = str(epoch) checkpoint_path = output_dir / ("checkpoint-%s.pth" % fname) to_save = { "model": model_without_ddp.state_dict(), "optimizer": optimizer.state_dict(), "scaler": loss_scaler.state_dict(), "args": args, "epoch": epoch, } if best_so_far is not None: to_save["best_so_far"] = best_so_far print(f">> Saving model to {checkpoint_path} ...") save_on_master(to_save, checkpoint_path) def load_model(train_args, model_without_ddp, optimizer, loss_scaler): """ Load model checkpoint from disk or URL. This function loads a saved checkpoint, restoring the model state, optimizer state, loss scaler state, and training epoch. It can load from a local file or a URL. Args: train_args: Training arguments containing resume information model_without_ddp (torch.nn.Module): Model without DistributedDataParallel wrapper optimizer (torch.optim.Optimizer): Optimizer instance loss_scaler: Gradient scaler for mixed precision training Returns: float or None: Best metric value from the checkpoint if available, otherwise None """ train_args.start_epoch = 0 best_so_far = None if train_args.resume and train_args.resume_ckpt is not None: if train_args.resume_ckpt.startswith("https"): checkpoint = torch.hub.load_state_dict_from_url( train_args.resume_ckpt, map_location="cpu", check_hash=True ) else: checkpoint = torch.load( train_args.resume_ckpt, map_location="cpu", weights_only=False ) print("Resume checkpoint %s" % train_args.resume_ckpt) model_without_ddp.load_state_dict(checkpoint["model"], strict=False) train_args.start_epoch = checkpoint["epoch"] + 1 optimizer.load_state_dict(checkpoint["optimizer"]) if "scaler" in checkpoint: loss_scaler.load_state_dict(checkpoint["scaler"]) if "best_so_far" in checkpoint: best_so_far = checkpoint["best_so_far"] print(" & best_so_far={:g}".format(best_so_far)) else: print("") print( "With optim & sched! start_epoch={:d}".format(train_args.start_epoch), end="", ) return best_so_far def all_reduce_mean(x): """ Compute the mean of a value across all processes in distributed training. This function takes a value, reduces it across all processes using all_reduce, and returns the mean value. Args: x: The value to reduce (typically a scalar) Returns: float: The mean value across all processes """ world_size = get_world_size() if world_size > 1: x_reduce = torch.tensor(x).cuda() dist.all_reduce(x_reduce) x_reduce /= world_size return x_reduce.item() else: return x def _replace(text, src, tgt, rm=""): """ Advanced string replacement utility. Given a text: - replace all elements in src by the corresponding element in tgt - remove all elements in rm Args: text (str): The input text to modify src (str): String of characters to replace tgt (str): String of replacement characters (must be same length as src or length 1) rm (str, optional): String of characters to remove. Defaults to "". Returns: str: The modified text after replacements and removals Raises: AssertionError: If src and tgt have different lengths (unless tgt has length 1) """ if len(tgt) == 1: tgt = tgt * len(src) assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len" for s, t in zip(src, tgt): text = text.replace(s, t) for c in rm: text = text.replace(c, "") return text def filename(obj): """ Transform a Python object or command into a proper filename. This function converts a Python object or command string into a valid filename by replacing special characters and ensuring the filename is not too long. Special replacements: - \1 gets replaced by slash '/' - \2 gets replaced by comma ',' Args: obj: The Python object or string to convert to a filename Returns: str: A valid filename derived from the input object Raises: AssertionError: If any part of the resulting path is longer than 256 characters """ if not isinstance(obj, str): obj = repr(obj) obj = str(obj).replace("()", "") obj = _replace(obj, "_,(*/\1\2", "-__x%/,", rm=" )'\"") assert all(len(s) < 256 for s in obj.split(os.sep)), ( "filename too long (>256 characters):\n" + obj ) return obj def compute_effective_lrs(train_args): """ Compute the effective learning rates based on batch size scaling. This function calculates the effective learning rates for the main model and any submodules based on the effective batch size (accounting for gradient accumulation and distributed training) and the base learning rates. Args: train_args: Training arguments containing batch size, accumulation iterations, learning rates, and submodule configurations Returns: train_args: Updated training arguments with computed effective learning rates """ # Compute the effective batch size eff_batch_size = train_args.batch_size * train_args.accum_iter * get_world_size() print("Accumulate grad iterations: %d" % train_args.accum_iter) print("Effective batch size: %d" % eff_batch_size) # Compute the effective default learning rate if train_args.lr is None: # only base_lr is specified train_args.lr = train_args.blr * math.sqrt( eff_batch_size / train_args.base_eff_batch_size ) print( f"Base default lr for effective batch size {eff_batch_size}: %.2e" % (train_args.lr * math.sqrt(train_args.base_eff_batch_size / eff_batch_size)) ) print("Actual default lr: %.2e" % train_args.lr) for submodule, config in train_args.submodule_configs.items(): if config.get("lr") is None: # only base_lr is specified config["lr"] = config["blr"] * math.sqrt( eff_batch_size / train_args.base_eff_batch_size ) print( f"Submodule {submodule} base lr for effective batch size {eff_batch_size}: %.2e" % ( config["lr"] * math.sqrt(train_args.base_eff_batch_size / eff_batch_size) ) ) print(f"Submodule {submodule} actual lr: %.2e" % config["lr"]) return train_args def get_parameter_groups( model, lr, weight_decay, skip_list=[], submodule_configs=None, warn_not_in_submodule=False, ): """ Get parameter groups for optimizer with customized learning rates and weight decay. This function organizes model parameters into groups for the optimizer, allowing different learning rates and weight decay values for different parts of the model. Parameters are grouped by: 1. Whether they should have weight decay applied (bias terms and 1D tensors typically don't) 2. Which submodule they belong to (if submodule_configs is provided) Args: model (torch.nn.Module): Model to get parameter groups for lr (float): Default learning rate for parameters not in submodule_configs weight_decay (float): Default weight decay for parameters not in submodule_configs skip_list (list): List of parameter names to skip weight decay for submodule_configs (dict, optional): Dictionary mapping submodule prefixes to configs with 'lr' and 'weight_decay' keys warn_not_in_submodule (bool, optional): Whether to warn if a parameter does not belong to any submodule. Defaults to False. Returns: tuple: A tuple containing: - parameter_group_vars (list): List of parameter groups for optimizer - parameter_group_name_to_idx_map (dict): Mapping from submodule name to parameter group indices - parameter_group_idx_to_name_map (dict): Mapping from parameter group index to submodule name """ if submodule_configs is None: submodule_configs = {} parameter_group_names = {} parameter_group_vars = {} parameter_group_name_to_idx_map = {} parameter_group_idx_to_name_map = {} mapping_index = 0 for name, param in model.named_parameters(): # Skip frozen parameters if not param.requires_grad: continue # Determine the submodule this parameter belongs to submodule_name = None for submodule, config in submodule_configs.items(): if name.startswith(submodule): submodule_name = submodule break if submodule_name: config = submodule_configs[submodule_name] this_weight_decay = config.get("weight_decay", weight_decay) this_lr = config.get("lr", lr) # Freeze the parameters if lr is 0 if this_lr == 0: param.requires_grad = False continue else: this_weight_decay = weight_decay this_lr = lr if warn_not_in_submodule and submodule_configs is not None: print( f"Warning: Parameter {name} does not belong to any submodule in {submodule_configs.keys()}." ) # Assign weight decay values if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: group_name = f"{submodule_name}_no_decay" if submodule_name else "no_decay" this_weight_decay = 0.0 else: group_name = f"{submodule_name}_decay" if submodule_name else "decay" if group_name not in parameter_group_names: parameter_group_names[group_name] = { "weight_decay": this_weight_decay, "lr": this_lr, "params": [], } parameter_group_vars[group_name] = { "weight_decay": this_weight_decay, "lr": this_lr, "params": [], } submodule_name_mapping = submodule_name if submodule_name else "default" if submodule_name_mapping not in parameter_group_name_to_idx_map: parameter_group_name_to_idx_map[submodule_name_mapping] = [ mapping_index ] else: parameter_group_name_to_idx_map[submodule_name_mapping].append( mapping_index ) parameter_group_idx_to_name_map[mapping_index] = submodule_name_mapping mapping_index += 1 parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(name) # Print the parameter groups print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) return ( list(parameter_group_vars.values()), parameter_group_name_to_idx_map, parameter_group_idx_to_name_map, ) def adjust_learning_rate( optimizer, epoch, train_args, parameter_group_idx_to_name_map, submodule_configs=None, ): """ Adjust the learning rate based on the schedule type and current epoch. This function updates the learning rates for all parameter groups in the optimizer according to the specified learning rate schedule. Different submodules can have different learning rate schedules. Currently supported schedule types: - linear_warmup_half_cycle_cosine_decay: Linear warmup followed by cosine decay Args: optimizer (torch.optim.Optimizer): The optimizer to update epoch (int): Current training epoch train_args: Training arguments containing schedule type, warmup epochs, etc. parameter_group_idx_to_name_map (dict): Mapping from parameter group index to submodule name submodule_configs (dict, optional): Dictionary of submodule-specific configurations for learning rate schedules Raises: ValueError: If an unsupported schedule type is specified """ if submodule_configs is None: submodule_configs = {} for group_num, param_group in enumerate(optimizer.param_groups): submodule_name = parameter_group_idx_to_name_map.get(group_num) if submodule_name in submodule_configs: config = submodule_configs[submodule_name] lr = config.get("lr", train_args.lr) warmup_epochs = config.get("warmup_epochs", train_args.warmup_epochs) min_lr = config.get("min_lr", train_args.min_lr) schedule_type = config.get("schedule_type", train_args.schedule_type) else: lr = train_args.lr warmup_epochs = train_args.warmup_epochs min_lr = train_args.min_lr schedule_type = train_args.schedule_type if schedule_type == "linear_warmup_half_cycle_cosine_decay": if epoch < warmup_epochs: lr = lr * epoch / warmup_epochs else: lr = min_lr + (lr - min_lr) * 0.5 * ( 1.0 + math.cos( math.pi * (epoch - warmup_epochs) / (train_args.epochs - warmup_epochs) ) ) else: raise ValueError(f"Schedule type {schedule_type} not implemented") param_group["lr"] = lr def debug_after_backward( model, check_missing_gradients=True, check_gradient_mismatch=False, target_size=(256, 256, 1, 1), target_stride=(256, 1, 256, 256), ): """ Debugging function to check for gradient issues after backward pass. This function performs two types of gradient debugging: 1. Gradient mismatch: Checks for parameters with specific gradient shapes and strides that might indicate incorrect gradient computation. 2. Missing gradients: Identifies parameters that require gradients but didn't receive any. Args: model (torch.nn.Module): The model to check gradients for check_missing_gradients (bool, optional): Whether to check for missing gradients. Defaults to True. check_gradient_mismatch (bool, optional): Whether to check for gradient mismatches. Defaults to False. target_size (tuple, optional): Target tensor size to check for gradient mismatch. Defaults to (256, 256, 1, 1). target_stride (tuple, optional): Target tensor stride to check for gradient mismatch. Defaults to (256, 1, 256, 256). """ # Debug for missing gradients if check_missing_gradients: missing_grad_params = [] for name, param in model.named_parameters(): if param.requires_grad and param.grad is None: missing_grad_params.append(name) if missing_grad_params: print("Parameters requiring gradients but missing gradients:") for name in missing_grad_params: print(f" - {name}") else: print("All parameters requiring gradients received gradients!") # Debug for gradient mismatch if check_gradient_mismatch: for name, param in model.named_parameters(): grad = param.grad if grad is None: continue if grad.size() == target_size and grad.stride() == target_stride: print(f"Found parameter with incorrect gradient: '{name}'") print(f"Gradient shape: {grad.size()}, strides: {grad.stride()}")