| import torch | |
| import numpy as np | |
def nantensor(*args, **kwargs):
    """Create a tensor whose every element is NaN.

    Accepts the same arguments as ``torch.ones``: sizes, plus keyword options
    such as ``dtype``, ``device`` and ``requires_grad``.

    Returns:
        A tensor of the requested size filled with NaN. Integer dtypes cannot
        hold NaN, so multiplying by the (Python float) ``np.nan`` promotes them
        to the default floating-point dtype.
    """
    # Create the values without gradient tracking and enable it afterwards.
    # ``torch.ones(..., requires_grad=True) * nan`` would produce a non-leaf
    # tensor (it has a MulBackward grad_fn): its ``.grad`` is never populated
    # and optimizers refuse to accept it as a parameter.
    requires_grad = kwargs.pop("requires_grad", False)
    filled = torch.ones(*args, **kwargs) * np.nan
    return filled.requires_grad_(requires_grad)
def nanmean(v, *args, inplace=False, **kwargs):
    """Compute the mean of ``v`` while ignoring NaN entries.

    Args:
        v: Input tensor.
        *args, **kwargs: Forwarded unchanged to ``Tensor.sum`` for both the
            numerator and the non-NaN count (e.g. ``dim``, ``keepdim``).
        inplace: If True, NaN entries of ``v`` itself are overwritten with 0.
            Otherwise a clone is modified and ``v`` is left untouched.

    Returns:
        The sum of the non-NaN entries divided by their count. A reduction
        over only NaN (or no) elements divides 0 by 0 and therefore yields NaN.
    """
    if not inplace:
        v = v.clone()
    is_nan = torch.isnan(v)
    v[is_nan] = 0
    # Count the valid entries with an integer sum. This is exact at any size
    # (float32 counts lose exactness past 2**24 elements). Dividing a floating
    # tensor by an integer tensor also keeps v's dtype, instead of forcing a
    # float32 result for half/bfloat16 input.
    count = (~is_nan).sum(*args, **kwargs)
    return v.sum(*args, **kwargs) / count