""" Miscellaneous utility functions. """ import logging import os import random import numpy as np import torch class StreamToLogger: """ A class that redirects stream writes to a logger. This class can be used to redirect stdout or stderr to a logger by implementing a file-like interface with write and flush methods. Parameters: - logger: A logger instance that will receive the log messages - log_level: The logging level to use (default: logging.INFO) """ def __init__(self, logger, log_level=logging.INFO): self.logger = logger self.log_level = log_level self.linebuf = "" def write(self, buf): """ Write the buffer content to the logger. Parameters: - buf: The string buffer to write """ for line in buf.rstrip().splitlines(): self.logger.log(self.log_level, line.rstrip()) def flush(self): """ Flush method to comply with file-like object interface. This method is required but does nothing in this implementation. """ pass def seed_everything(seed: int = 42): """ Set the `seed` value for torch and numpy seeds. Also turns on deterministic execution for cudnn. Parameters: - seed: A hashable seed value """ random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False print(f"Seed set to: {seed}") def invalid_to_nans(arr, valid_mask, ndim=999): """ Replace invalid values in an array with NaN values based on a validity mask. Parameters: - arr: Input array (typically a PyTorch tensor) - valid_mask: Boolean mask indicating valid elements (True) and invalid elements (False) - ndim: Maximum number of dimensions to keep; flattens dimensions if arr.ndim > ndim Returns: - Modified array with invalid values replaced by NaN """ if valid_mask is not None: arr = arr.clone() arr[~valid_mask] = float("nan") if arr.ndim > ndim: arr = arr.flatten(-2 - (arr.ndim - ndim), -2) return arr def invalid_to_zeros(arr, valid_mask, ndim=999): """ Replace invalid values in an array with zeros based on a validity mask. Parameters: - arr: Input array (typically a PyTorch tensor) - valid_mask: Boolean mask indicating valid elements (True) and invalid elements (False) - ndim: Maximum number of dimensions to keep; flattens dimensions if arr.ndim > ndim Returns: - Tuple containing: - Modified array with invalid values replaced by zeros - nnz: Number of non-zero (valid) elements per sample in the batch """ if valid_mask is not None: arr = arr.clone() arr[~valid_mask] = 0 nnz = valid_mask.view(len(valid_mask), -1).sum(1) else: nnz = ( arr[..., 0].numel() // len(arr) if len(arr) else 0 ) # Number of pixels per image if arr.ndim > ndim: arr = arr.flatten(-2 - (arr.ndim - ndim), -2) return arr, nnz