File size: 327 Bytes
3b6a091 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch
import numpy as np
def nantensor(*args, **kwargs):
    """Return a tensor in which every element is NaN.

    All positional and keyword arguments are forwarded unchanged to
    ``torch.ones``. Sizes, ``dtype``, ``device`` and similar are therefore
    accepted exactly as ``torch.ones`` accepts them. The tensor of ones
    is then multiplied by NaN.

    Note: the fill is a multiplication by a Python float. Under PyTorch's
    type-promotion rules, an integer or bool ``dtype`` therefore comes
    back as a floating dtype rather than being preserved.
    """
    ones = torch.ones(*args, **kwargs)
    return ones.mul(np.nan)
def nanmean(v, *args, inplace=False, **kwargs):
    """Compute the mean of ``v`` while ignoring NaN entries.

    NaN elements are replaced by zero. The sum of the tensor is then divided
    by the number of non-NaN elements. Any extra positional or keyword
    arguments (e.g. ``dim``, ``keepdim``) are forwarded to both ``sum``
    calls, so the numerator and denominator are reduced the same way.

    Args:
        v: Input tensor.
        inplace: If True, the NaNs in ``v`` itself are overwritten with 0,
            which mutates the caller's tensor. Otherwise a clone of ``v`` is
            modified and ``v`` is left untouched.

    Returns:
        The NaN-ignoring mean. The non-NaN count is accumulated as float32,
        so integer inputs give a floating result. A reduction slice that is
        entirely NaN evaluates 0 / 0 and therefore yields NaN.
    """
    target = v if inplace else v.clone()
    nan_mask = torch.isnan(target)
    target[nan_mask] = 0
    # Count the surviving (non-NaN) entries with the same reduction arguments.
    valid_count = (~nan_mask).float().sum(*args, **kwargs)
    total = target.sum(*args, **kwargs)
    return total / valid_count
|