Spaces:
Running
on
T4
Running
on
T4
| # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. | |
| # !/usr/bin/env python | |
| # -*- coding:utf-8 -*- | |
| import torch | |
| from tqdm import tqdm | |
def _pairwise_q8_lookup(q8_table, src_coeff, tgt_coeff):
    """Look up ``q8_table`` for every (target view, source view) pairing.

    Two singleton axes are inserted right after the batch axis of
    ``src_coeff`` and right after the third axis of ``tgt_coeff``.  Indexing
    ``q8_table`` with the two results broadcasts them against each other, so
    no expanded copy of either input has to be materialised.

    Returns:
        Tensor with axes ``(batch, tgt_axis1, tgt_axis2, src_axis1,
        src_axis2, *trailing)`` holding ``q8_table[src_value, tgt_value]``.
    """
    src_index = src_coeff[:, None, None].long()
    tgt_index = tgt_coeff[:, :, :, None, None].long()
    return q8_table[src_index, tgt_index]


def calculate_lfd_distance(
        q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
        tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8):
    """Return one light-field distance per batch element.

    For each batch element, the cost between every source view and every
    target view is the sum of four table look-ups: the Art term (summed over
    its last axis), the Fd term (summed over its last axis, weighted 2), the
    Cir term (weighted 2) and the Ecc term.  The result is the smallest total
    cost over all combinations of target axis-1 index, source axis-1 index
    and alignment row.

    Args:
        q8_table: 2-D table indexed as ``q8_table[src_value, tgt_value]``.
        align_10: 2-D integer tensor with one alignment per row.  Only the
            first 10 columns are read; entry ``[a, p]`` is the target axis-2
            index paired with source axis-2 index ``p`` under alignment ``a``.
        src_ArtCoeff, src_FdCoeff_q8, tgt_ArtCoeff, tgt_FdCoeff_q8: rank-4
            tensors ``(batch, axis1, axis2, coeff)`` of table indices.
        src_CirCoeff_q8, src_EccCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8:
            rank-3 tensors ``(batch, axis1, axis2)`` of table indices.
            The original comment calls axis1 "camera" and axis2 "angle";
            presumably both have length 10 -- confirm against the caller.

    Returns:
        1-D ``int64`` tensor of length ``batch``.
    """
    with torch.no_grad():
        art_distance = torch.sum(
            _pairwise_q8_lookup(q8_table, src_ArtCoeff, tgt_ArtCoeff), dim=-1)
        fd_distance = torch.sum(
            _pairwise_q8_lookup(q8_table, src_FdCoeff_q8, tgt_FdCoeff_q8), dim=-1) * 2.0
        cir_distance = _pairwise_q8_lookup(q8_table, src_CirCoeff_q8, tgt_CirCoeff_q8) * 2.0
        ecc_distance = _pairwise_q8_lookup(q8_table, src_EccCoeff_q8, tgt_EccCoeff_q8)
        cost = art_distance + fd_distance + cir_distance + ecc_distance

        # Find the closest matching.
        # cost axes here: batch x tgt_camera x tgt_angle x src_camera x src_angle.
        # Regroup to:     batch x tgt_camera x src_camera x tgt_angle x src_angle.
        cost = cost.permute(0, 1, 3, 2, 4).long()

        # Alignment ``a`` pairs source angle ``p`` with target angle
        # ``align_10[a, p]``.  Index those entries directly rather than
        # gathering a full square per alignment and keeping only its diagonal.
        tgt_angle = align_10[:, :10].long()
        src_angle = torch.arange(tgt_angle.shape[1], device=cost.device)
        matched = cost[:, :, :, tgt_angle, src_angle]

        # Total cost of each (tgt_camera, src_camera, alignment) combination.
        per_alignment = matched.sum(dim=-1)

        # Best combination per batch element; flatten (unlike reshape(B, -1))
        # also accepts an empty batch.
        return per_alignment.flatten(start_dim=1).min(dim=-1)[0]
class LightFieldDistanceFunction(torch.autograd.Function):
    """Forward-only autograd function producing an all-pairs distance matrix.

    Every source shape is scored against every target shape with
    ``calculate_lfd_distance``; no gradient is defined.
    """

    @staticmethod
    def forward(
            ctx, q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
            tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8, log):
        """Compute the ``(n, m)`` matrix of source-to-target distances.

        Args:
            ctx: autograd context (unused).
            q8_table, align_10: passed through to ``calculate_lfd_distance``.
            src_*: descriptors of the ``n`` source shapes, batch axis first.
            tgt_*: descriptors of the ``m`` target shapes, batch axis first.
            log: when truthy, show a tqdm progress bar over the sources.

        Returns:
            Tensor of shape ``(n, m)``; entry ``[i, j]`` is the distance
            between source ``i`` and target ``j``.  An empty ``int64`` tensor
            of that shape is returned when ``n`` or ``m`` is zero.
        """
        n = src_ArtCoeff.shape[0]
        m = tgt_ArtCoeff.shape[0]
        print(f"src_size: {n}")
        print(f"tgt_size: {m}")

        # torch.cat rejects an empty list, so handle empty inputs up front.
        if n == 0 or m == 0:
            return torch.empty((n, m), dtype=torch.long, device=src_ArtCoeff.device)

        # Targets are processed in chunks to bound the size of the
        # intermediate tensors built inside calculate_lfd_distance.
        chunk_size = 1000
        rows = []
        with torch.no_grad():
            for i in tqdm(range(n), mininterval=60, disable=not log):
                row_chunks = []
                for start in range(0, m, chunk_size):
                    stop = min(m, start + chunk_size)
                    width = stop - start
                    # Each call scores source i against one chunk of targets;
                    # the source is expanded (a view) to match the chunk.
                    row_chunks.append(calculate_lfd_distance(
                        q8_table, align_10,
                        src_ArtCoeff[i:i + 1].expand(width, -1, -1, -1),
                        src_FdCoeff_q8[i:i + 1].expand(width, -1, -1, -1),
                        src_CirCoeff_q8[i:i + 1].expand(width, -1, -1),
                        src_EccCoeff_q8[i:i + 1].expand(width, -1, -1),
                        tgt_ArtCoeff[start:stop],
                        tgt_FdCoeff_q8[start:stop],
                        tgt_CirCoeff_q8[start:stop],
                        tgt_EccCoeff_q8[start:stop]))
                rows.append(torch.cat(row_chunks, dim=0))
        return torch.stack(rows, dim=0)

    @staticmethod
    def backward(ctx, graddist):
        """Gradients are not implemented for this function."""
        raise NotImplementedError
class LFD(torch.nn.Module):
    """Module front-end that delegates to ``LightFieldDistanceFunction``."""

    def forward(
            self, q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
            tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8, log):
        """Pass every argument, unchanged and in order, to the autograd function.

        Returns whatever ``LightFieldDistanceFunction.apply`` returns.
        """
        source_descriptors = (src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8)
        target_descriptors = (tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8)
        return LightFieldDistanceFunction.apply(
            q8_table, align_10, *source_descriptors, *target_descriptors, log)