Spaces:
Runtime error
Runtime error
| import torch | |
def Pdist2(x, y=None):
    """Return the matrix of squared Euclidean distances between rows of x and y.

    Uses the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>.
    Floating-point round-off in that expansion can yield tiny negative values,
    which are clamped to zero.

    Args:
        x: 2-D tensor of shape (n, d) (``torch.mm`` requires 2-D inputs).
        y: 2-D tensor of shape (m, d), or ``None`` (the default) to compare
            ``x`` with itself.

    Returns:
        Tensor of shape (n, m) whose (i, j) entry is ||x_i - y_j||^2.
    """
    x_norm = (x ** 2).sum(1).view(-1, 1)  # (n, 1) column of squared row norms
    if y is None:
        # Self-distance: reuse x's row norms instead of recomputing them.
        y = x
        y_norm = x_norm.view(1, -1)
    else:
        y_norm = (y ** 2).sum(1).view(1, -1)  # (1, m) row of squared row norms
    dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    # Squared distances are non-negative; undo round-off that went below zero.
    dist[dist < 0] = 0
    return dist
def MMD_batch2(Fea, len_s, Fea_org, sigma, sigma0=0.1, epsilon=10**(-10), is_smooth=True,
               is_var_computed=True, use_1sample_U=True, coeff_xy=2):
    """Score each test row against a reference set with a per-row MMD^2.

    The first ``len_s`` rows of ``Fea`` form the reference set X. Every
    remaining row y_j is treated as a one-point sample, and the biased
    (V-statistic) estimate

        mmd2[j] = mean_{i,k} K(x_i, x_k) - 2 * mean_i K(x_i, y_j) + K(y_j, y_j)

    is returned for each j. K(y_j, y_j) is evaluated at a self-distance of
    zero, so it is exp(0) = 1 for either kernel.

    Kernel (D / D_org are squared distances computed in ``Fea`` / ``Fea_org``):
        * ``is_smooth=True``:
          (1 - epsilon) * exp(-(D / sigma0)**L - D_org / sigma)
          + epsilon * exp(-D_org / sigma), with L = 1.
        * ``is_smooth=False``: exp(-D / sigma0); ``Fea_org`` and ``sigma``
          are not read.

    Args:
        Fea: 2-D tensor; rows [0, len_s) are reference rows, the rest are test rows.
        len_s: number of reference rows at the top of ``Fea`` / ``Fea_org``.
        Fea_org: 2-D tensor whose rows are split at ``len_s`` like ``Fea``
            (only used when ``is_smooth``).
        sigma: bandwidth applied to squared distances in ``Fea_org``.
        sigma0: bandwidth applied to squared distances in ``Fea``.
        epsilon: weight of the ``Fea_org``-only term in the smooth kernel.
        is_smooth: selects the kernel, see above.
        is_var_computed, use_1sample_U, coeff_xy: accepted for signature
            compatibility but ignored; the cross term always uses coefficient 2.

    Returns:
        1-D tensor with one MMD^2 value per test row (``Fea.shape[0] - len_s``
        entries). If ``len_s == 0`` the reference means are 0/0 and the
        result is NaN.
    """
    L = 1  # exponent on (D / sigma0); L > 1 would give a generalized Gaussian

    X = Fea[0:len_s, :]  # reference rows
    Y = Fea[len_s:, :]   # test rows, each scored on its own
    ny = Y.shape[0]

    Dxx = Pdist2(X, X)                  # (nx, nx)
    Dxy = Pdist2(X, Y).transpose(0, 1)  # (ny, nx)
    # Each test row is a one-point sample, so its only self-distance is zero.
    # Allocate with the distances' dtype/device rather than a default-dtype CPU
    # tensor that has to be copied (and may not match Fea's dtype).
    Dyy = Dxx.new_zeros(ny, 1)          # (ny, 1)

    if is_smooth:
        X_org = Fea_org[0:len_s, :]
        Y_org = Fea_org[len_s:, :]
        Dxx_org = Pdist2(X_org, X_org)
        Dxy_org = Pdist2(X_org, Y_org).transpose(0, 1)
        Dyy_org = Dxx_org.new_zeros(ny, 1)

        def _smooth_kernel(D, D_org):
            # Joint Fea/Fea_org Gaussian, mixed with a small Fea_org-only term.
            return ((1 - epsilon) * torch.exp(-(D / sigma0) ** L - D_org / sigma)
                    + epsilon * torch.exp(-D_org / sigma))

        Kx = _smooth_kernel(Dxx, Dxx_org)
        Ky = _smooth_kernel(Dyy, Dyy_org)
        Kxy = _smooth_kernel(Dxy, Dxy_org)
    else:
        Kx = torch.exp(-Dxx / sigma0)
        Ky = torch.exp(-Dyy / sigma0)
        Kxy = torch.exp(-Dxy / sigma0)

    nx = Kx.shape[0]
    # Biased mean over all nx * nx reference pairs (diagonal included).
    xx = torch.sum(Kx) / (nx * nx)
    yy = Ky.reshape(-1)               # (ny,): K(y_j, y_j)
    xy = torch.sum(Kxy, dim=1) / nx   # (ny,): mean_i K(x_i, y_j)
    return xx - 2 * xy + yy