Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| from torch.nn import functional as F | |
| ############################################################################### | |
| # Sequence filters | |
| ############################################################################### | |
def mean(signals, win_length=9):
    """Average filtering for signals containing nan values

    Each output frame is the mean of the non-nan values inside a window of
    ``win_length`` frames around it. Frames whose window contains no valid
    values are set to nan. Windows whose valid values average to exactly zero
    return zero. For an even ``win_length`` the window spans
    ``win_length // 2`` frames before the current frame and
    ``win_length // 2 - 1`` frames after it, so the output always has the
    same length as the input.

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"

    # Frames before / after the current one; equal halves when win_length is odd
    left, right = win_length // 2, (win_length - 1) // 2

    # Zero out nans so they do not contribute to the window sums
    valid = ~torch.isnan(signals)
    values = torch.where(valid, signals, torch.zeros_like(signals))

    # Zero padding contributes nothing to either the sums or the counts.
    # Counts use the input dtype so no dtype mismatch can occur downstream.
    values = F.pad(values, (left, right))
    valid = F.pad(valid.to(signals.dtype), (left, right))

    # Sliding-window sums of the values and of the number of valid frames
    totals = values.unfold(-1, win_length, 1).sum(dim=-1)
    counts = valid.unfold(-1, win_length, 1).sum(dim=-1)

    # Masked average. The clamp only guards the division; empty windows are
    # marked nan from the count, so genuine zero averages are preserved.
    averages = totals / counts.clamp(min=1)
    return torch.where(counts > 0, averages, torch.full_like(averages, float("nan")))
def median(signals, win_length):
    """Median filtering for signals containing nan values

    Each output frame is the median of the non-nan values inside a window of
    ``win_length`` frames around it. When the window holds an even number of
    valid values, the lower of the two middle values is returned. Frames whose
    window contains no valid values are set to nan; genuine infinite values
    are preserved. For an even ``win_length`` the window spans
    ``win_length // 2`` frames before the current frame and
    ``win_length // 2 - 1`` frames after it, so the output always has the
    same length as the input.

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"

    # Frames before / after the current one; equal halves when win_length is odd
    left, right = win_length // 2, (win_length - 1) // 2

    # Replace nans (and padding) with +inf so they sort after every valid
    # value. Constant padding is used rather than reflect padding: padded
    # frames are excluded from the count anyway, and reflect padding fails
    # when the signal is not longer than the padding.
    valid = ~torch.isnan(signals)
    values = signals.masked_fill(~valid, float("inf"))
    values = F.pad(values, (left, right), value=float("inf"))
    valid = F.pad(valid.to(signals.dtype), (left, right))

    # Sort each window; valid values occupy the first `counts` positions
    windows, _ = torch.sort(values.unfold(-1, win_length, 1), dim=-1)
    counts = valid.unfold(-1, win_length, 1).sum(dim=-1)

    # Index of the lower median among the valid values of each window
    index = ((counts - 1) // 2).clamp(min=0).long()
    medians = windows.gather(-1, index.unsqueeze(-1)).squeeze(-1)

    # Only windows without any valid value become nan
    return torch.where(counts > 0, medians, torch.full_like(medians, float("nan")))
| ############################################################################### | |
| # Utilities | |
| ############################################################################### | |
def nanfilter(signals, win_length, filter_fn):
    """Filters a sequence, ignoring nan values

    Frame ``i`` of the output is ``filter_fn`` applied to input frames
    ``[i - win_length // 2, i - win_length // 2 + win_length)``, truncated at
    the signal boundaries. For odd ``win_length`` this window is centered on
    ``i``. For even ``win_length`` it holds exactly ``win_length`` frames, one
    more before ``i`` than after. This function passes the raw frames through;
    ``filter_fn`` itself must ignore nan values. The loop is frame-by-frame.

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window
        filter_fn (function)
            The function to use for filtering. It receives a
            (batch, frames) slice and its result is assigned to
            ``filtered[:, i]``, so it should yield one value per batch item.

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    frames = signals.size(1)

    # Output buffer
    filtered = torch.empty_like(signals)

    for i in range(frames):
        # Window bounds before clipping to the signal
        start = i - win_length // 2
        end = start + win_length

        # Apply filter to the clipped window
        filtered[:, i] = filter_fn(signals[:, max(0, start):min(frames, end)])

    return filtered
def nanmean(signals):
    """Computes the mean of each row, ignoring nans

    Arguments
        signals (torch.tensor [shape=(batch, time)])
            The signals to average

    Returns
        means (torch.tensor [shape=(batch,)])
            Mean of the non-nan values of each row. A row with no non-nan
            values divides zero by zero and therefore yields nan.
    """
    is_nan = torch.isnan(signals)

    # Nans contribute zero to the sum and are excluded from the count
    zeroed = torch.where(is_nan, torch.zeros_like(signals), signals)
    valid_counts = (~is_nan).float().sum(dim=1)

    return zeroed.sum(dim=1) / valid_counts
def nanmedian(signals):
    """Computes the median of each row, ignoring nans

    When a row holds an even number of non-nan values, the lower of the two
    middle values is returned.

    Arguments
        signals (torch.tensor [shape=(batch, time)])
            The signals to reduce

    Returns
        medians (torch.tensor [shape=(batch,)])
            Median of the non-nan values of each row; nan for rows with no
            non-nan values
    """
    # Every row is empty: keep returning nan per row, explicitly, rather than
    # asking torch.nanmedian to reduce a zero-sized dimension
    if signals.size(1) == 0:
        return torch.full(
            (signals.size(0),),
            float("nan"),
            dtype=signals.dtype,
            device=signals.device,
        )

    # Vectorized over the batch; same lower-median convention as torch.median,
    # and all-nan rows reduce to nan
    return torch.nanmedian(signals, dim=1).values
def nanmedian1d(signal):
    """Computes the median of a 1D signal, or nan if it is empty

    Nan values are not filtered here; torch.median propagates them.

    Arguments
        signal (torch.tensor [shape=(time,)])

    Returns
        median (0-dim torch.tensor, or the float ``np.nan`` when ``signal``
            has no elements)
    """
    if not signal.numel():
        return np.nan
    return torch.median(signal)