import torch
from kornia.utils import create_meshgrid


def project_and_normalize(ref_grid, src_proj, length):
    """
    @param ref_grid: b 3 n -- 3D points in world coordinates
    @param src_proj: b 4 4 -- projection matrix of the source view
    @param length: int -- side length of the (square) source image
    @return: b n 2 -- xy coordinates normalized to [-1, 1]
    """
    src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:]  # b 3 n
    div_val = src_grid[:, -1:].clamp(min=1e-4)  # guard against division by (near-)zero depth without mutating src_grid in place
    src_grid = src_grid[:, :2] / div_val  # divide by depth: b 2 n
    src_grid[:, 0] = src_grid[:, 0] / ((length - 1) / 2) - 1  # scale x to [-1, 1]
    src_grid[:, 1] = src_grid[:, 1] / ((length - 1) / 2) - 1  # scale y to [-1, 1]
    src_grid = src_grid.permute(0, 2, 1)  # b n 2
    return src_grid
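

# Hedged usage sketch (added for illustration, not part of the original file):
# project random world-space points through an identity projection into a
# 256x256 view. All tensors below are placeholders; only the shapes matter.
def _example_project_and_normalize():
    b, n, length = 2, 100, 256
    ref_grid = torch.rand(b, 3, n) + torch.tensor([0.0, 0.0, 1.0]).view(1, 3, 1)  # points with depth >= 1
    src_proj = torch.eye(4).expand(b, 4, 4)  # stand-in for a real K @ [R|t] projection
    coords = project_and_normalize(ref_grid, src_proj, length)
    assert coords.shape == (b, n, 2)  # the xy convention F.grid_sample expects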


def construct_project_matrix(x_ratio, y_ratio, Ks, poses):
    """
    @param x_ratio: float -- horizontal scale from the intrinsics' resolution to the target resolution
    @param y_ratio: float -- vertical scale
    @param Ks: b 3 3 -- camera intrinsics
    @param poses: b 3 4 -- world-to-camera extrinsics [R|t]
    @return: b 4 4 -- projection matrices padded to homogeneous form
    """
    rfn = Ks.shape[0]
    scale_m = torch.tensor([x_ratio, y_ratio, 1.0], dtype=torch.float32, device=Ks.device)
    scale_m = torch.diag(scale_m)
    ref_prj = scale_m[None, :, :] @ Ks @ poses  # rfn 3 4
    pad_vals = torch.zeros([rfn, 1, 4], dtype=torch.float32, device=ref_prj.device)
    pad_vals[:, :, 3] = 1.0
    ref_prj = torch.cat([ref_prj, pad_vals], 1)  # rfn 4 4
    return ref_prj
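

# Hedged sketch (added for illustration): build projection matrices for a view
# at 1/4 of the intrinsics' resolution. The K and pose values are arbitrary
# placeholders, not values from the original code.
def _example_construct_project_matrix():
    b = 2
    Ks = torch.tensor([[128.0, 0.0, 128.0],
                       [0.0, 128.0, 128.0],
                       [0.0, 0.0, 1.0]]).expand(b, 3, 3)  # b 3 3 pinhole intrinsics
    poses = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=1).expand(b, 3, 4)  # b 3 4 identity [R|t]
    proj = construct_project_matrix(0.25, 0.25, Ks, poses)
    assert proj.shape == (b, 4, 4)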


def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose):
    """
    @param volume_xyz: B 3 D H W -- world coordinates of the volume voxels
    @param warp_size, input_size: int -- feature-map resolution and the resolution Ks is defined at
    @param Ks: B 3 3; warp_pose: B 3 4
    @return: B D H W 2 -- grid_sample coordinates in [-1, 1]
    """
    B, _, D, H, W = volume_xyz.shape
    ratio = warp_size / input_size  # rescale the intrinsics to the feature-map resolution
    warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose)  # B 4 4
    warp_coords = project_and_normalize(volume_xyz.view(B, 3, D * H * W), warp_proj, warp_size).view(B, D, H, W, 2)
    return warp_coords
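

# Hedged sketch (added for illustration): warp a 32^3 volume into a source view
# whose feature map is 64x64 while the intrinsics are defined at 256x256.
# All tensors below are random placeholders; only the shapes matter.
def _example_get_warp_coordinates():
    B, D, H, W = 2, 32, 32, 32
    volume_xyz = torch.randn(B, 3, D, H, W)
    Ks = torch.tensor([[128.0, 0.0, 128.0],
                       [0.0, 128.0, 128.0],
                       [0.0, 0.0, 1.0]]).expand(B, 3, 3)
    warp_pose = torch.cat([torch.eye(3), torch.tensor([[0.0], [0.0], [2.0]])], dim=1).expand(B, 3, 4)
    coords = get_warp_coordinates(volume_xyz, 64, 256, Ks, warp_pose)
    assert coords.shape == (B, D, H, W, 2)  # usable with F.grid_sample after folding D into the batch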


def create_target_volume(depth_size, volume_size, input_image_size, pose_target, K, near=None, far=None):
    """
    @param depth_size: int -- D, number of depth planes
    @param volume_size: int -- H = W of the volume
    @param input_image_size: int -- resolution at which K is defined
    @param pose_target: B 3 4; K: B 3 3
    @param near, far: B 1 H W, or None to derive them from the unit sphere
    @return: world coordinates B 3 D H W, depth values B 1 D H W
    """
    device, dtype = pose_target.device, pose_target.dtype

    # compute a depth range: either from the given near/far maps or from the unit sphere
    H, W, D, B = volume_size, volume_size, depth_size, pose_target.shape[0]
    if near is not None and far is not None:
        # near, far: b 1 h w
        depth_values = torch.linspace(0, 1, steps=depth_size, dtype=near.dtype, device=near.device)  # d
        depth_values = depth_values.view(1, D, 1, 1)  # 1 d 1 1
        depth_values = depth_values * (far - near) + near  # b d h w
        depth_values = depth_values.view(B, 1, D, H * W)
    else:
        near, far = near_far_from_unit_sphere_using_camera_poses(pose_target)  # b 1
        depth_values = torch.linspace(0, 1, steps=depth_size, dtype=near.dtype, device=near.device)  # d
        depth_values = depth_values[None, :, None] * (far[:, None, :] - near[:, None, :]) + near[:, None, :]  # b d 1
        depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H * W)
    ratio = volume_size / input_image_size

    # create a mesh grid on the target view; note that "reference" also means "target" here
    ref_grid = create_meshgrid(H, W, normalized_coordinates=False)  # 1 H W 2
    ref_grid = ref_grid.to(device).to(dtype)
    ref_grid = ref_grid.permute(0, 3, 1, 2)  # 1 2 H W
    ref_grid = ref_grid.reshape(1, 2, H * W)  # 1 2 H*W
    ref_grid = ref_grid.expand(B, -1, -1)  # B 2 H*W
    ref_grid = torch.cat((ref_grid, torch.ones(B, 1, H * W, dtype=ref_grid.dtype, device=ref_grid.device)), dim=1)  # B 3 H*W
    ref_grid = ref_grid.unsqueeze(2) * depth_values  # B 3 D H*W

    # unproject to space and transform to world coordinates
    ref_proj = construct_project_matrix(ratio, ratio, K, pose_target)  # B 4 4
    ref_proj_inv = torch.inverse(ref_proj)  # B 4 4
    ref_grid = ref_proj_inv[:, :3, :3] @ ref_grid.view(B, 3, D * H * W) + ref_proj_inv[:, :3, 3:]  # B 3 DHW
    return ref_grid.reshape(B, 3, D, H, W), depth_values.view(B, 1, D, H, W)
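

# Hedged sketch (added for illustration): build a 32^3 target volume, letting
# near/far default to the unit-sphere range derived from the pose. K and pose
# values are placeholders.
def _example_create_target_volume():
    B, D, V = 2, 32, 32
    K = torch.tensor([[128.0, 0.0, 128.0],
                      [0.0, 128.0, 128.0],
                      [0.0, 0.0, 1.0]]).expand(B, 3, 3)
    pose = torch.cat([torch.eye(3), torch.tensor([[0.0], [0.0], [2.0]])], dim=1).expand(B, 3, 4)
    xyz, depth = create_target_volume(D, V, 256, pose, K)
    assert xyz.shape == (B, 3, D, V, V) and depth.shape == (B, 1, D, V, V)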


def near_far_from_unit_sphere_using_camera_poses(camera_poses):
    """
    @param camera_poses: b 3 4 -- world-to-camera extrinsics [R|t]
    @return:
        near: b 1
        far: b 1
    """
    R_w2c = camera_poses[..., :3, :3]  # b 3 3
    t_w2c = camera_poses[..., :3, 3:]  # b 3 1
    camera_origin = -R_w2c.permute(0, 2, 1) @ t_w2c  # b 3 1 -- o = -R^T t
    # viewing direction in world coordinates: d = R^T (0, 0, 1)^T
    camera_orient = R_w2c.permute(0, 2, 1)[..., :3, 2:3]  # b 3 1
    camera_origin, camera_orient = camera_origin[..., 0], camera_orient[..., 0]  # b 3
    # the ray o + t*d passes closest to the sphere center at t = -(d . o) / (d . d)
    a = torch.sum(camera_orient ** 2, dim=-1, keepdim=True)  # b 1
    b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True)  # b 1
    mid = b / a  # b 1
    near, far = mid - 1.0, mid + 1.0  # the unit sphere has radius 1
    return near, far
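

# Hedged sanity check (added for illustration): with identity rotation and
# t = (0, 0, 2), the camera sits at (0, 0, -2) looking down +z, so the unit
# sphere spans depths 1 to 3.
def _example_near_far():
    pose = torch.cat([torch.eye(3), torch.tensor([[0.0], [0.0], [2.0]])], dim=1)[None]  # 1 3 4
    near, far = near_far_from_unit_sphere_using_camera_poses(pose)
    assert torch.allclose(near, torch.tensor([[1.0]])) and torch.allclose(far, torch.tensor([[3.0]]))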