# DvD / train_settings / validation / metrics_flow.py
# hanquansanren's picture
# Add application file
# 05fb4ab
import torch
# The metrics below take flows flattened to [B*H*W, 2]; the norm is taken over dim=1.
def compute_epe(input_flow, target_flow, mean=True, calculate_std=False):
    """
    Compute the end-point error (EPE) between an estimated and a reference flow.

    The error of each point is the L2 norm, taken over dim=1, of the
    difference between the two flows.

    Args:
        input_flow: estimated flow, documented as [BxHxW, 2]
        target_flow: ground-truth flow, documented as [BxHxW, 2]
        mean: if True, reduce the per-point errors to their mean and return
            it as a Python number; otherwise return the per-point error tensor
        calculate_std: if True, additionally return the standard deviation
            (torch.std) of the per-point errors as a Python number
    Output:
        The (averaged) end-point error, or the tuple (error, std) when
        calculate_std is True.
    """
    flow_difference = target_flow - input_flow
    per_point_error = torch.norm(flow_difference, p=2, dim=1)

    # The deviation is always taken over the un-averaged errors.
    result = per_point_error.mean().item() if mean else per_point_error

    if not calculate_std:
        return result
    return result, torch.std(per_point_error).item()
def correct_correspondences(input_flow, target_flow, alpha, img_size, epe_tensor=None):
    """
    Count the points whose end-point error is within the PCK threshold.

    Args:
        input_flow: estimated flow, documented as [BxHxW, 2]
        target_flow: ground-truth flow, documented as [BxHxW, 2]
        alpha: threshold, expressed as a fraction of img_size
        img_size: image size the threshold is scaled by
        epe_tensor: pre-computed end-point errors; when given, the two flows
            are not used at all. Default is None.
    Output:
        Number of points with error <= alpha * img_size. This is a raw
        count, not a ratio; normalising it is left to the caller.
    """
    errors = epe_tensor
    if errors is None:
        # No cached errors were supplied: derive them from the flows.
        errors = torch.norm(target_flow - input_flow, p=2, dim=1)

    max_allowed_error = alpha * img_size
    within_threshold = torch.le(errors, max_allowed_error)
    return within_threshold.sum().item()
def Fl_kitti_2015(input_flow, target_flow, tau=(3.0, 0.05)):
    """
    Count the flow outliers following the KITTI-2015 "Fl" criterion.

    A point is an outlier when BOTH conditions hold:
        error > tau[0]                                  (absolute, in pixels)
        error / magnitude(ground-truth flow) > tau[1]   (relative)

    Args:
        input_flow: estimated flow, documented as [BxHxW, 2]
        target_flow: ground-truth flow, documented as [BxHxW, 2]
        tau: pair (absolute threshold, relative threshold); any indexable
            sequence of two numbers is accepted. Defaults to (3.0, 0.05).
    Output:
        Number of outlier points. This is a raw count, not a ratio.
    """
    # Honour the caller-supplied thresholds instead of hard-coded constants.
    abs_threshold, rel_threshold = tau[0], tau[1]

    # Norms are taken over dim=1, giving one value per point.
    dist = torch.norm(target_flow - input_flow, p=2, dim=1)
    gt_magnitude = torch.norm(target_flow, p=2, dim=1)

    # Where the ground-truth magnitude is zero the ratio is inf (or nan when
    # the error is zero too); inf passes the relative test and nan fails it,
    # so the absolute test alone decides those points.
    mask = dist.gt(abs_threshold) & (dist / gt_magnitude).gt(rel_threshold)
    return mask.sum().item()