| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
class BatchedKDE(nn.Module):
    """Batched kernel density scorer with an exponential squared-distance kernel.

    ``fit`` stores one set of reference points per batch element and ``score``
    returns, for every query point, the *unnormalised* sum over that batch
    element's reference points of ``exp(-||x - mu||^2 / bandwidth)``.

    NOTE(review): the kernel divides the squared distance by ``bandwidth``
    itself, not by ``2 * bandwidth ** 2`` as in the textbook Gaussian kernel,
    and no normalising constant is applied. This is kept as-is because
    changing it would alter every score; confirm it is intended.

    Args:
        bandwith: Kernel bandwidth (the misspelt parameter name is kept for
            backward compatibility). ``0`` (the default) means "estimate one
            bandwidth per batch element in ``fit``".
    """

    def __init__(self, bandwith=0.0):
        super().__init__()
        self.bandwidth = bandwith
        self.X = None  # never read in this class; kept for compatibility
        # The tensor produced by the last automatic estimate, so a later
        # fit() can tell "auto-estimated, refresh it" from "set by the user".
        self._estimated_bandwidth = None

    def _wants_estimate(self) -> bool:
        """Return True when ``fit`` should (re-)estimate the bandwidth."""
        bw = self.bandwidth
        if isinstance(bw, torch.Tensor):
            if bw is getattr(self, "_estimated_bandwidth", None):
                # Left over from a previous fit(): stale for the new data
                # and possibly of the wrong batch size.
                return True
            return bw.numel() == 1 and bool(bw == 0)
        return bw == 0

    def fit(self, X: torch.Tensor):
        """Store the reference points and, if requested, estimate the bandwidth.

        The automatic bandwidth follows Silverman's rule of thumb,
        ``0.9 * min(std, IQR / 1.34) * n ** -0.2``, where ``std`` and ``IQR``
        are computed separately for each batch element over all of its
        ``n * d`` values.

        Args:
            X: Reference points of shape ``(b, n, d)``.

        Raises:
            ValueError: If ``X`` is not 3-dimensional.
        """
        if X.dim() != 3:
            raise ValueError(
                f"X must have shape (b, n, d), got {tuple(X.shape)}"
            )
        b, n, _ = X.shape

        self.mu = X
        # Squared norms of the reference points, cached for score().
        self.nmu2 = torch.sum(X * X, dim=-1, keepdim=True)

        if self._wants_estimate():
            # reshape (not view) so non-contiguous inputs are accepted.
            flat = X.reshape(b, -1)
            # dim=1 keeps the quartiles per batch element; without it the
            # IQR would be a single value pooled over the whole batch.
            iqr = torch.quantile(flat, 0.75, dim=1) - torch.quantile(
                flat, 0.25, dim=1
            )
            spread = torch.min(torch.std(X, dim=(1, 2)), iqr / 1.34)
            self.bandwidth = 0.9 * spread / pow(n, 0.2)
            self._estimated_bandwidth = self.bandwidth

    def score(self, X):
        """Return the unnormalised kernel sum for every query point.

        Args:
            X: Query points of shape ``(b, n, d)``; ``b`` and ``d`` must match
                the tensor passed to ``fit``.

        Returns:
            Tensor of shape ``(b, n)``.
        """
        nx2 = torch.sum(X * X, dim=-1, keepdim=True)
        dot = torch.einsum("bnd, bmd -> bnm", X, self.mu)
        # ||x||^2 + ||mu||^2 - 2 x.mu can dip slightly below zero through
        # rounding; a squared distance is never negative.
        dist = torch.clamp(nx2 + self.nmu2.transpose(1, 2) - 2 * dot, min=0)
        # Accept both a user-supplied float and a per-batch tensor.
        bw = torch.as_tensor(
            self.bandwidth, dtype=dist.dtype, device=dist.device
        )
        return torch.sum(
            torch.exp(-dist / bw.unsqueeze(-1).unsqueeze(-1)), dim=-1
        )