Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops.einops import rearrange, repeat | |
class FinePreprocess(nn.Module):
    """Crop fine-level feature windows around the predicted coarse matches.

    For every coarse match listed in ``data`` (``b_ids`` / ``i_ids`` /
    ``j_ids``), a ``W x W`` window is cut out of each fine feature map.
    When ``config["fine_concat_coarse_feat"]`` is truthy, the coarse feature
    of each match is projected to the fine width, broadcast over the window,
    concatenated with the window features and merged by a linear layer.

    Config keys read: ``fine_concat_coarse_feat``, ``fine_window_size``,
    ``coarse.d_model`` and ``fine.d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.cat_c_feat = config["fine_concat_coarse_feat"]
        self.W = self.config["fine_window_size"]

        d_model_c = self.config["coarse"]["d_model"]
        d_model_f = self.config["fine"]["d_model"]
        self.d_model_f = d_model_f
        if self.cat_c_feat:
            # Project coarse features to the fine width, then fuse the
            # concatenated [fine, coarse] features back down to d_model_f.
            self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
            self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)

        # No-op when cat_c_feat is falsy: the module then owns no parameters.
        self._reset_parameters()

    def _reset_parameters(self):
        """Kaiming-normal initialise every parameter with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")

    def _crop_windows(self, feat, stride):
        """Unfold ``feat`` into ``W x W`` windows laid out as [n, l, ww, c]."""
        W = self.W
        windows = F.unfold(
            feat, kernel_size=(W, W), stride=stride, padding=W // 2
        )
        # unfold yields [n, c*ww, l]; move the window axis in front of channels
        return rearrange(windows, "n (c ww) l -> n l ww c", ww=W**2)

    def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
        """Return the per-match fine windows of both images.

        Args:
            feat_f0, feat_f1: fine-level feature maps, passed to ``F.unfold``
                (presumably [N, C, H, W] -- confirm against the caller).
            feat_c0, feat_c1: coarse-level features indexed as
                ``feat[b_ids, i_ids]``; their last dimension must equal
                ``config["coarse"]["d_model"]`` when concatenation is enabled.
                Only read in that case.
            data: dict providing ``hw0_f``, ``hw0_c``, ``b_ids``, ``i_ids``
                and ``j_ids``. The key ``"W"`` is written into it as a side
                effect.

        Returns:
            Tuple ``(feat_f0_unfold, feat_f1_unfold)``, each of shape
            [M, W*W, C] where M is the number of matches. When there are no
            matches, two uninitialised tensors of shape
            [0, W*W, d_model_f] are returned on the device and with the
            dtype of the corresponding fine feature map.
        """
        W = self.W
        # Window stride = fine/coarse resolution ratio along the first
        # spatial axis.
        stride = data["hw0_f"][0] // data["hw0_c"][0]

        data["W"] = W
        if data["b_ids"].shape[0] == 0:
            # Match the input dtype as well as the device so that the empty
            # result is interchangeable with the non-empty one (e.g. under
            # half precision); torch.empty would otherwise fall back to the
            # global default dtype.
            feat0 = torch.empty(
                0,
                W**2,
                self.d_model_f,
                device=feat_f0.device,
                dtype=feat_f0.dtype,
            )
            feat1 = torch.empty(
                0,
                W**2,
                self.d_model_f,
                device=feat_f1.device,
                dtype=feat_f1.dtype,
            )
            return feat0, feat1

        # 1. unfold(crop) all local windows
        feat_f0_unfold = self._crop_windows(feat_f0, stride)
        feat_f1_unfold = self._crop_windows(feat_f1, stride)

        # 2. select only the predicted matches
        feat_f0_unfold = feat_f0_unfold[data["b_ids"], data["i_ids"]]  # [n, ww, cf]
        feat_f1_unfold = feat_f1_unfold[data["b_ids"], data["j_ids"]]

        # option: use coarse-level feature as context: concat and linear
        if self.cat_c_feat:
            # Stack both images along the batch axis so each linear layer
            # runs once over all 2n windows.
            feat_c_win = self.down_proj(
                torch.cat(
                    [
                        feat_c0[data["b_ids"], data["i_ids"]],
                        feat_c1[data["b_ids"], data["j_ids"]],
                    ],
                    0,
                )
            )  # [2n, c]
            feat_cf_win = self.merge_feat(
                torch.cat(
                    [
                        torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
                        repeat(feat_c_win, "n c -> n ww c", ww=W**2),  # [2n, ww, cf]
                    ],
                    -1,
                )
            )
            # Split the stacked result back into the two images.
            feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)

        return feat_f0_unfold, feat_f1_unfold