Spaces:
Runtime error
Runtime error
| from einops.einops import rearrange | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from dkm.utils.utils import warp_kpts | |
class DepthRegressionLoss(nn.Module):
    """Multi-scale dense-matching loss supervised by depth and relative pose.

    For every scale in the predicted correspondence pyramid, target
    correspondences for a regular grid of query points are obtained from
    ``warp_kpts`` (assumed to warp points from image 1 into image 2 using the
    depths, relative pose ``T_1to2`` and intrinsics ``K1``/``K2``). The
    per-scale loss combines:

    * a binary cross-entropy term training the certainty logits to predict the
      ``warp_kpts`` mask, weighted by ``ce_weight``; and
    * the mean Euclidean distance between predicted and target coordinates
      over grid points where the mask is positive.

    Args:
        robust: Stored on the instance; not read by this class.
        center_coords: Stored on the instance; not read by this class.
        scale_normalize: If True, divide each scale's loss by ``scale``.
        ce_weight: Weight of the certainty cross-entropy term.
        local_loss: If True, at scales ``<= local_largest_scale`` only points
            whose error at the previously visited scale was small enough
            remain valid targets.
        local_dist: Radius of that locality gate; multiplied by ``scale``.
        local_largest_scale: Largest scale at which the locality gate applies.
    """

    def __init__(
        self,
        robust=True,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=4.0,
        local_largest_scale=8,
    ):
        super().__init__()
        self.robust = robust  # measured in pixels; not read by this class
        self.center_coords = center_coords  # not read by this class
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
        self.local_dist = local_dist
        self.local_largest_scale = local_largest_scale

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale):
        """Distance between predicted matches and depth-warped target points.

        Args:
            depth1, depth2: Depth inputs forwarded (as float64) to ``warp_kpts``.
            T_1to2: Relative pose forwarded (as float64) to ``warp_kpts``.
            K1, K2: Intrinsics forwarded (as float64) to ``warp_kpts``.
            dense_matches: Predicted coordinates of shape ``(b, h1, w1, 2)``,
                presumably in the same normalized [-1, 1] frame as the query
                grid built here — TODO confirm.
            scale: Unused; kept for interface compatibility.

        Returns:
            (gd, prob):
                gd: ``(b, h1, w1)`` Euclidean distance between ``dense_matches``
                    and the warped grid. It is differentiable w.r.t.
                    ``dense_matches``.
                prob: ``(b, h1, w1)`` float version of the first output of
                    ``warp_kpts`` (presumably a validity mask).
        """
        b, h1, w1, _ = dense_matches.shape
        device = dense_matches.device
        with torch.no_grad():
            # Pixel-centre coordinates in [-1, 1]. They are built directly
            # instead of via torch.meshgrid, whose implicit indexing warns on
            # recent PyTorch and whose extra batch axis was discarded anyway.
            ys = torch.linspace(-1 + 1 / h1, 1 - 1 / h1, h1, device=device)
            xs = torch.linspace(-1 + 1 / w1, 1 - 1 / w1, w1, device=device)
            grid = torch.stack(
                (xs[None, :].expand(h1, w1), ys[:, None].expand(h1, w1)), dim=-1
            )  # (h1, w1, 2), last axis ordered (x, y)
            x1_n = grid.reshape(1, h1 * w1, 2).repeat(b, 1, 1)
            mask, x2 = warp_kpts(
                x1_n.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            prob = mask.float().reshape(b, h1, w1)
        # Computed outside no_grad so the loss back-propagates into dense_matches.
        gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1)  # *scale?
        return gd, prob

    def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8):
        """Certainty cross-entropy and masked regression loss for one scale.

        Args:
            dense_certainty: Certainty logits. Channel 0 is used and must match
                ``prob`` in shape, i.e. ``(b, >=1, h, w)``.
            prob: ``(b, h, w)`` float target for the certainty logits. Points
                with ``prob > 0`` contribute to the regression term.
            gd: ``(b, h, w)`` distance between predicted and target matches.
            scale: Scale identifier; used only to name the returned entries.
            eps: Unused; kept for interface compatibility.

        Returns:
            dict with scalar tensors ``f"ce_loss_{scale}"`` and
            ``f"depth_loss_{scale}"``.
        """
        valid = prob > 0
        ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], prob)
        if valid.any():
            depth_loss = gd[valid].mean()
        else:
            # No valid target anywhere. Return a zero that stays attached to
            # the graph instead of the NaN produced by averaging an empty tensor.
            depth_loss = (gd * 0.0).mean()
        return {
            f"ce_loss_{scale}": ce_loss.mean(),
            f"depth_loss_{scale}": depth_loss.mean(),
        }

    def forward(self, dense_corresps, batch):
        """Sum the per-scale losses over the correspondence pyramid.

        Args:
            dense_corresps: Mapping ``scale -> {"dense_certainty": ...,
                "dense_flow": ...}``. ``dense_flow`` is ``(b, d, h, w)`` and is
                rearranged to ``(b, h, w, d)``. Scales are visited in the
                mapping's iteration order. The locality gate compares against
                the previously visited scale, so coarse-to-fine order is
                presumably expected — TODO confirm with callers.
            batch: Mapping with keys ``"query_depth"``, ``"support_depth"``,
                ``"T_1to2"``, ``"K1"`` and ``"K2"``.

        Returns:
            Total loss summed over scales (``0.0`` if ``dense_corresps`` is empty).
        """
        tot_loss = 0.0
        prev_gd = None  # distance map of the previously visited scale
        for scale, dense_scale_corresps in dense_corresps.items():
            dense_scale_certainty = dense_scale_corresps["dense_certainty"]
            dense_scale_coords = rearrange(
                dense_scale_corresps["dense_flow"], "b d h w -> b h w d"
            )
            _, h, w, _ = dense_scale_coords.shape
            gd, prob = self.geometric_dist(
                batch["query_depth"],
                batch["support_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                dense_scale_coords,
                scale,
            )
            # Fine matching should not be punished for coarse mistakes: keep
            # only points whose previous-scale error is below
            # local_dist * scale. The 2 / 512 factor presumably converts pixels
            # of a 512-px-wide image into [-1, 1] units — TODO confirm. The gate
            # is skipped when no previous scale exists, since there is nothing
            # to compare against.
            if (
                scale <= self.local_largest_scale
                and self.local_loss
                and prev_gd is not None
            ):
                prev_up = F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")
                prob = prob * (prev_up[:, 0] < (2 / 512) * (self.local_dist * scale))
            depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale)
            scale_loss = (
                self.ce_weight * depth_losses[f"ce_loss_{scale}"]
                + depth_losses[f"depth_loss_{scale}"]
            )
            if self.scale_normalize:
                scale_loss = scale_loss / scale
            tot_loss = tot_loss + scale_loss
            prev_gd = gd.detach()
        return tot_loss