# TinyRoMa matcher (XFeat backbone).
# Note: this file was extracted from a web page; the page banner text
# ("Spaces: Running on Zero") has been replaced by this header.
import math
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from PIL import Image
from torchvision.transforms import ToTensor

from romatch.utils.kde import kde
class BasicLayer(nn.Module):
    """A single Conv2d -> BatchNorm2d (non-affine) -> optional ReLU unit."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
        super().__init__()
        # Non-affine BatchNorm: normalization only, no learned scale/shift.
        activation = nn.ReLU(inplace=True) if relu else nn.Identity()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=padding, stride=stride, dilation=dilation, bias=bias),
            nn.BatchNorm2d(out_channels, affine=False),
            activation,
        )

    def forward(self, x):
        # Run the whole conv -> norm -> activation stack.
        return self.layer(x)
class TinyRoMa(nn.Module):
    """
    Implementation of architecture described in
    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
    """

    def __init__(self, xfeat = None,
                 freeze_xfeat = True,
                 sample_mode = "threshold_balanced",
                 symmetric = False,
                 exact_softmax = False):
        super().__init__()
        # Only the XFeat backbone is used here; drop its detection heads.
        del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
        if freeze_xfeat:
            xfeat.train(False)
            self.xfeat = [xfeat]  # plain list hides params from DDP
        else:
            self.xfeat = nn.ModuleList([xfeat])
        self.freeze_xfeat = freeze_xfeat

        match_dim = 256
        # Coarse matcher input: coarse feats A (64) + warped coarse feats B (64) + 2-ch flow.
        coarse_blocks = [BasicLayer(64 + 64 + 2, match_dim)]
        coarse_blocks += [BasicLayer(match_dim, match_dim) for _ in range(3)]
        coarse_blocks.append(nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
        self.coarse_matcher = nn.Sequential(*coarse_blocks)

        fine_match_dim = 64
        # Fine matcher input: fine feats A (24) + warped fine feats B (24) + 2-ch flow.
        fine_blocks = [BasicLayer(24 + 24 + 2, fine_match_dim)]
        fine_blocks += [BasicLayer(fine_match_dim, fine_match_dim) for _ in range(3)]
        fine_blocks.append(nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0))
        self.fine_matcher = nn.Sequential(*fine_blocks)

        self.sample_mode = sample_mode
        self.sample_thresh = 0.05
        self.symmetric = symmetric
        self.exact_softmax = exact_softmax
| def device(self): | |
| return self.fine_matcher[-1].weight.device | |
| def preprocess_tensor(self, x): | |
| """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """ | |
| H, W = x.shape[-2:] | |
| _H, _W = (H//32) * 32, (W//32) * 32 | |
| rh, rw = H/_H, W/_W | |
| x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False) | |
| return x, rh, rw | |
    def forward_single(self, x):
        """Extract (fine, coarse) feature maps for one image batch via the XFeat backbone.

        Gradients are disabled entirely when the backbone is frozen or the
        model is in eval mode. Returns a pair (x2, feats): the shallow block2
        features (fine) and the fused deeper features (coarse).
        """
        with torch.inference_mode(self.freeze_xfeat or not self.training):
            xfeat = self.xfeat[0]
            with torch.no_grad():
                # Backbone expects single-channel input; average channels to gray.
                x = x.mean(dim=1, keepdim = True)
                x = xfeat.norm(x)
            #main backbone
            x1 = xfeat.block1(x)
            x2 = xfeat.block2(x1 + xfeat.skip1(x))
            x3 = xfeat.block3(x2)
            x4 = xfeat.block4(x3)
            x5 = xfeat.block5(x4)
            # Fuse the deeper pyramid levels at x3's spatial resolution.
            x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
            feats = xfeat.block_fusion( x3 + x4 + x5 )
        if self.freeze_xfeat:
            # clone() moves results out of inference-mode storage so they can
            # participate in autograd downstream.
            return x2.clone(), feats.clone()
        return x2, feats
| def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None): | |
| if coords.shape[-1] == 2: | |
| return self._to_pixel_coordinates(coords, H_A, W_A) | |
| if isinstance(coords, (list, tuple)): | |
| kpts_A, kpts_B = coords[0], coords[1] | |
| else: | |
| kpts_A, kpts_B = coords[...,:2], coords[...,2:] | |
| return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B) | |
| def _to_pixel_coordinates(self, coords, H, W): | |
| kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1) | |
| return kpts | |
| def pos_embed(self, corr_volume: torch.Tensor): | |
| B, H1, W1, H0, W0 = corr_volume.shape | |
| grid = torch.stack( | |
| torch.meshgrid( | |
| torch.linspace(-1+1/W1,1-1/W1, W1), | |
| torch.linspace(-1+1/H1,1-1/H1, H1), | |
| indexing = "xy"), | |
| dim = -1).float().to(corr_volume).reshape(H1*W1, 2) | |
| down = 4 | |
| if not self.training and not self.exact_softmax: | |
| grid_lr = torch.stack( | |
| torch.meshgrid( | |
| torch.linspace(-1+down/W1,1-down/W1, W1//down), | |
| torch.linspace(-1+down/H1,1-down/H1, H1//down), | |
| indexing = "xy"), | |
| dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2) | |
| cv = corr_volume | |
| best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, HW, H, W | |
| P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1) | |
| pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr) | |
| pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2) | |
| #print("hej") | |
| else: | |
| P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W | |
| pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid) | |
| return pos_embeddings | |
| def visualize_warp(self, warp, certainty, im_A = None, im_B = None, | |
| im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False): | |
| device = warp.device | |
| H,W2,_ = warp.shape | |
| W = W2//2 if symmetric else W2 | |
| if im_A is None: | |
| from PIL import Image | |
| im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB") | |
| if not isinstance(im_A, torch.Tensor): | |
| im_A = im_A.resize((W,H)) | |
| im_B = im_B.resize((W,H)) | |
| x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1) | |
| if symmetric: | |
| x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1) | |
| else: | |
| if symmetric: | |
| x_A = im_A | |
| x_B = im_B | |
| im_A_transfer_rgb = F.grid_sample( | |
| x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False | |
| )[0] | |
| if symmetric: | |
| im_B_transfer_rgb = F.grid_sample( | |
| x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False | |
| )[0] | |
| warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2) | |
| white_im = torch.ones((H,2*W),device=device) | |
| else: | |
| warp_im = im_A_transfer_rgb | |
| white_im = torch.ones((H, W), device = device) | |
| vis_im = certainty * warp_im + (1 - certainty) * white_im | |
| if save_path is not None: | |
| from romatch.utils import tensor_to_pil | |
| tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path) | |
| return vis_im | |
| def corr_volume(self, feat0, feat1): | |
| """ | |
| input: | |
| feat0 -> torch.Tensor(B, C, H, W) | |
| feat1 -> torch.Tensor(B, C, H, W) | |
| return: | |
| corr_volume -> torch.Tensor(B, H, W, H, W) | |
| """ | |
| B, C, H0, W0 = feat0.shape | |
| B, C, H1, W1 = feat1.shape | |
| feat0 = feat0.view(B, C, H0*W0) | |
| feat1 = feat1.view(B, C, H1*W1) | |
| corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16 | |
| return corr_volume | |
| def match_from_path(self, im0_path, im1_path): | |
| device = self.device | |
| im0 = ToTensor()(Image.open(im0_path))[None].to(device) | |
| im1 = ToTensor()(Image.open(im1_path))[None].to(device) | |
| return self.match(im0, im1, batched = False) | |
| def match(self, im0, im1, *args, batched = True): | |
| # stupid | |
| if isinstance(im0, (str, Path)): | |
| return self.match_from_path(im0, im1) | |
| elif isinstance(im0, Image.Image): | |
| batched = False | |
| device = self.device | |
| im0 = ToTensor()(im0)[None].to(device) | |
| im1 = ToTensor()(im1)[None].to(device) | |
| B,C,H0,W0 = im0.shape | |
| B,C,H1,W1 = im1.shape | |
| self.train(False) | |
| corresps = self.forward({"im_A":im0, "im_B":im1}) | |
| #return 1,1 | |
| flow = F.interpolate( | |
| corresps[4]["flow"], | |
| size = (H0, W0), | |
| mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2) | |
| grid = torch.stack( | |
| torch.meshgrid( | |
| torch.linspace(-1+1/W0,1-1/W0, W0), | |
| torch.linspace(-1+1/H0,1-1/H0, H0), | |
| indexing = "xy"), | |
| dim = -1).float().to(flow.device).expand(B, H0, W0, 2) | |
| certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False) | |
| warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid() | |
| if batched: | |
| return warp, cert | |
| else: | |
| return warp[0], cert[0] | |
| def sample( | |
| self, | |
| matches, | |
| certainty, | |
| num=5_000, | |
| ): | |
| H,W,_ = matches.shape | |
| if "threshold" in self.sample_mode: | |
| upper_thresh = self.sample_thresh | |
| certainty = certainty.clone() | |
| certainty[certainty > upper_thresh] = 1 | |
| matches, certainty = ( | |
| matches.reshape(-1, 4), | |
| certainty.reshape(-1), | |
| ) | |
| expansion_factor = 4 if "balanced" in self.sample_mode else 1 | |
| good_samples = torch.multinomial(certainty, | |
| num_samples = min(expansion_factor*num, len(certainty)), | |
| replacement=False) | |
| good_matches, good_certainty = matches[good_samples], certainty[good_samples] | |
| if "balanced" not in self.sample_mode: | |
| return good_matches, good_certainty | |
| use_half = True if matches.device.type == "cuda" else False | |
| down = 1 if matches.device.type == "cuda" else 8 | |
| density = kde(good_matches, std=0.1, half = use_half, down = down) | |
| p = 1 / (density+1) | |
| p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones | |
| balanced_samples = torch.multinomial(p, | |
| num_samples = min(num,len(good_certainty)), | |
| replacement=False) | |
| return good_matches[balanced_samples], good_certainty[balanced_samples] | |
    def forward(self, batch):
        """
        Two-stage (coarse -> fine) dense matching forward pass.

        input:
            batch -> dict with "im_A", "im_B": torch.Tensor(B, C, H, W) images
        return:
            dict mapping stride -> {"flow", "certainty"}:
              8: coarse flow (B, 2, h, w) and certainty logits at coarse res
              4: refined flow and certainty logits at fine res
        """
        im0 = batch["im_A"]
        im1 = batch["im_B"]
        corresps = {}
        im0, rh0, rw0 = self.preprocess_tensor(im0)
        im1, rh1, rw1 = self.preprocess_tensor(im1)
        B, C, H0, W0 = im0.shape
        B, C, H1, W1 = im1.shape
        # Scale factors turning matcher deltas into normalized [-1, 1] coords;
        # the third channel (certainty logit) is left unscaled.
        to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
        if im0.shape[-2:] == im1.shape[-2:]:
            # Same resolution: run the backbone once on the stacked pair.
            x = torch.cat([im0, im1], dim=0)
            x = self.forward_single(x)
            feats_x0_c, feats_x1_c = x[1].chunk(2)
            feats_x0_f, feats_x1_f = x[0].chunk(2)
        else:
            feats_x0_f, feats_x0_c = self.forward_single(im0)
            feats_x1_f, feats_x1_c = self.forward_single(im1)
        corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
        coarse_warp = self.pos_embed(corr_volume)
        # Append a zero certainty channel to the soft-argmax warp.
        coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
        feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
        coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
        coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
        corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
        coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
        # Detach so the fine stage does not backprop through the coarse stage.
        coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
        feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
        fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
        fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
        corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
        return corresps