import torch
import numpy as np


def nantensor(*args, **kwargs) -> torch.Tensor:
    """Return a tensor filled with NaN.

    All positional and keyword arguments are forwarded unchanged to
    ``torch.ones``; the resulting tensor of ones is then multiplied by NaN.
    """
    return torch.ones(*args, **kwargs) * np.nan


def nanmean(v: torch.Tensor, *args, inplace: bool = False, **kwargs) -> torch.Tensor:
    """Return the mean of *v*, ignoring NaN entries.

    NaN entries are replaced by zero before summing, and the sum is divided
    by the number of non-NaN entries taken over the same reduction.

    Args:
        v: Input tensor.
        *args: Forwarded to both ``sum`` calls (numerator and count).
        inplace: If true, the NaN entries of *v* itself are overwritten
            with zero; otherwise *v* is left untouched and a zero-filled
            copy is summed instead.
        **kwargs: Forwarded to both ``sum`` calls (numerator and count).

    Returns:
        The sum of the non-NaN entries divided by their count. Where a
        reduction contains no non-NaN entry the division is 0 / 0, so the
        result there is NaN.
    """
    is_nan = torch.isnan(v)
    if inplace:
        # Caller explicitly allowed mutation: zero the NaNs in v itself.
        filled = v.masked_fill_(is_nan, 0)
    else:
        # Out-of-place fill leaves the caller's tensor unchanged.
        filled = v.masked_fill(is_nan, 0)
    valid_count = (~is_nan).float().sum(*args, **kwargs)
    return filled.sum(*args, **kwargs) / valid_count