Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn.functional as F | |
| from scipy.ndimage import binary_dilation | |
| from skimage.filters import threshold_otsu | |
| def gaussian_blur(image, kernel_size=7, sigma=2): | |
| """ | |
| Apply Gaussian blur to a binary mask image. | |
| Args: | |
| image (torch.Tensor): Input binary mask (1x1xHxW or HxW) as a PyTorch tensor. | |
| kernel_size (int): Size of the Gaussian kernel. Should be odd. | |
| sigma (float): Standard deviation of the Gaussian kernel. | |
| Returns: | |
| torch.Tensor: Blurred mask image. | |
| """ | |
| # Ensure kernel size is odd | |
| if kernel_size % 2 == 0: | |
| kernel_size += 1 | |
| # Generate Gaussian kernel | |
| x = torch.arange(kernel_size, device=image.device, dtype=image.dtype) - kernel_size // 2 | |
| gaussian_1d = torch.exp(-(x**2) / (2 * sigma**2)) | |
| gaussian_1d = gaussian_1d / gaussian_1d.sum() | |
| gaussian_kernel = gaussian_1d[:, None] * gaussian_1d[None, :] | |
| gaussian_kernel = gaussian_kernel / gaussian_kernel.sum() # Normalize | |
| # Reshape to fit convolution: (out_channels, in_channels, kH, kW) | |
| gaussian_kernel = gaussian_kernel.unsqueeze(0).unsqueeze(0) | |
| # Ensure image is 4D (BxCxHxW) | |
| if image.ndim == 2: # HxW | |
| image = image.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions | |
| elif image.ndim == 3: # CxHxW | |
| image = image.unsqueeze(0) # Add batch dimension | |
| # Convolve image with Gaussian kernel | |
| blurred_image = F.conv2d(image, gaussian_kernel, padding=kernel_size // 2) | |
| return blurred_image.squeeze() # Remove extra dimensions | |
| def mask_interpolate(mask, size=128): | |
| mask = torch.tensor(mask) | |
| mask = F.interpolate(mask[None, None, ...], size, mode='bicubic') | |
| mask = mask.squeeze() | |
| return mask | |
| def get_mask(ca, ca_index, gb_kernel=11, gb_sigma=2, dilation=1, nbins=64): | |
| if ca is None: | |
| return None | |
| else: | |
| ca = ca[0].mean(0) | |
| token_ca = ca[..., ca_index].mean(dim=-1).reshape(64, 64) | |
| token_ca = gaussian_blur(token_ca, kernel_size=gb_kernel, sigma=gb_sigma) | |
| token_ca = mask_interpolate(token_ca, size=1024) | |
| thres = threshold_otsu(token_ca.float().cpu().numpy(), nbins=nbins) | |
| mask = token_ca > thres | |
| mask = mask_interpolate(mask.to(ca.dtype), 128) | |
| if dilation: | |
| mask = binary_dilation(mask.float().cpu().numpy(), iterations=dilation) | |
| mask = torch.tensor(mask, device=ca.device, dtype=ca.dtype) | |
| return mask |