# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import torch
import torch.nn.functional as F
import math
from scipy.spatial.transform import Rotation as Rot

HUGE_NUMBER = 1e10
TINY_NUMBER = 1e-6  # float32 only has 7 decimal digits precision


def get_camera_mat(fov=49.13, invert=True):
    # fov = 2 * arctan(sensor / (2 * focal))
    # focal = (sensor / 2) * 1 / (tan(0.5 * fov))
    # in our case, sensor = 2 as pixels are in [-1, 1]
    focal = 1. / np.tan(0.5 * fov * np.pi / 180.)
    focal = focal.astype(np.float32)
    mat = torch.tensor([
        [focal, 0., 0., 0.],
        [0., focal, 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]
    ]).reshape(1, 4, 4)

    if invert:
        mat = torch.inverse(mat)
    return mat
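# Illustrative usage sketch: build the pinhole intrinsics for the default
# 49.13-degree FOV. With invert=True the returned matrix maps normalized pixel
# coordinates back to camera-space directions, which is how it is consumed by
# image_points_to_world() / get_camera_rays() below.
#
#   camera_mat = get_camera_mat(fov=49.13, invert=True)   # shape (1, 4, 4)
#   assert camera_mat.shape == (1, 4, 4)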
def get_random_pose(range_u, range_v, range_radius, batch_size=32,
                    invert=False, gaussian=False, angular=False):
    loc, (u, v) = sample_on_sphere(range_u, range_v, size=(batch_size), gaussian=gaussian, angular=angular)
    radius = range_radius[0] + torch.rand(batch_size) * (range_radius[1] - range_radius[0])
    loc = loc * radius.unsqueeze(-1)
    R = look_at(loc)
    RT = torch.eye(4).reshape(1, 4, 4).repeat(batch_size, 1, 1)
    RT[:, :3, :3] = R
    RT[:, :3, -1] = loc

    if invert:
        RT = torch.inverse(RT)

    def N(a, range_a):
        if range_a[0] == range_a[1]:
            return a * 0
        return (a - range_a[0]) / (range_a[1] - range_a[0])

    val_u, val_v, val_r = N(u, range_u), N(v, range_v), N(radius, range_radius)
    return RT, (val_u, val_v, val_r)
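# Illustrative usage sketch (the azimuth/elevation ranges below are arbitrary
# example values, expressed as fractions of the unit interval):
#
#   world_mat, (vu, vv, vr) = get_random_pose(
#       range_u=(0., 1.), range_v=(0.25, 0.5), range_radius=(2.7, 2.7), batch_size=8)
#   # world_mat: (8, 4, 4) camera-to-world poses on a sphere around the origin;
#   # (vu, vv, vr) are the sampled parameters normalized back to [0, 1].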
def get_camera_pose(range_u, range_v, range_r, val_u=0.5, val_v=0.5, val_r=0.5,
                    batch_size=32, invert=False, gaussian=False, angular=False):
    r0, rr = range_r[0], range_r[1] - range_r[0]
    r = r0 + val_r * rr
    if not gaussian:
        u0, ur = range_u[0], range_u[1] - range_u[0]
        v0, vr = range_v[0], range_v[1] - range_v[0]
        u = u0 + val_u * ur
        v = v0 + val_v * vr
    else:
        mean_u, mean_v = sum(range_u) / 2, sum(range_v) / 2
        vu, vv = mean_u - range_u[0], mean_v - range_v[0]
        u = mean_u + vu * val_u
        v = mean_v + vv * val_v

    loc, _ = sample_on_sphere((u, u), (v, v), size=(batch_size), angular=angular)
    radius = torch.ones(batch_size) * r
    loc = loc * radius.unsqueeze(-1)
    R = look_at(loc)
    RT = torch.eye(4).reshape(1, 4, 4).repeat(batch_size, 1, 1)
    RT[:, :3, :3] = R
    RT[:, :3, -1] = loc

    if invert:
        RT = torch.inverse(RT)
    return RT


def get_camera_pose_v2(range_u, range_v, range_r, mode, invert=False, gaussian=False, angular=False):
    r0, rr = range_r[0], range_r[1] - range_r[0]
    val_u, val_v = mode[:, 0], mode[:, 1]
    val_r = torch.ones_like(val_u) * 0.5
    if not gaussian:
        u0, ur = range_u[0], range_u[1] - range_u[0]
        v0, vr = range_v[0], range_v[1] - range_v[0]
        u = u0 + val_u * ur
        v = v0 + val_v * vr
    else:
        mean_u, mean_v = sum(range_u) / 2, sum(range_v) / 2
        vu, vv = mean_u - range_u[0], mean_v - range_v[0]
        u = mean_u + vu * val_u
        v = mean_v + vv * val_v

    loc = to_sphere(u, v, angular)
    radius = r0 + val_r * rr
    loc = loc * radius.unsqueeze(-1)
    R = look_at(loc)
    RT = torch.eye(4).to(R.device).reshape(1, 4, 4).repeat(R.size(0), 1, 1)
    RT[:, :3, :3] = R
    RT[:, :3, -1] = loc

    if invert:
        RT = torch.inverse(RT)
    return RT, (val_u, val_v, val_r)


def to_sphere(u, v, angular=False):
    T = torch if isinstance(u, torch.Tensor) else np
    if not angular:
        theta = 2 * math.pi * u
        phi = T.arccos(1 - 2 * v)
    else:
        theta, phi = u, v
    cx = T.sin(phi) * T.cos(theta)
    cy = T.sin(phi) * T.sin(theta)
    cz = T.cos(phi)
    return T.stack([cx, cy, cz], -1)
def sample_on_sphere(range_u=(0, 1), range_v=(0, 1), size=(1,),
                     to_pytorch=True, gaussian=False, angular=False):
    if not gaussian:
        u = np.random.uniform(*range_u, size=size)
        v = np.random.uniform(*range_v, size=size)
    else:
        mean_u, mean_v = sum(range_u) / 2, sum(range_v) / 2
        var_u, var_v = mean_u - range_u[0], mean_v - range_v[0]
        u = np.random.normal(size=size) * var_u + mean_u
        v = np.random.normal(size=size) * var_v + mean_v

    sample = to_sphere(u, v, angular)
    if to_pytorch:
        sample = torch.tensor(sample).float()
        u, v = torch.tensor(u).float(), torch.tensor(v).float()

    return sample, (u, v)
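# Sanity-check sketch: with angular=False, to_sphere() uses theta = 2*pi*u and
# phi = arccos(1 - 2*v), so uniform (u, v) in [0, 1]^2 gives points distributed
# uniformly over the unit sphere.
#
#   pts, (u, v) = sample_on_sphere((0, 1), (0, 1), size=(1024,))
#   assert torch.allclose(pts.norm(dim=-1), torch.ones(1024), atol=1e-5)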
def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5,
            to_pytorch=True):
    if not isinstance(eye, torch.Tensor):
        # this is the original code from GRAF
        at = at.astype(float).reshape(1, 3)
        up = up.astype(float).reshape(1, 3)
        eye = eye.reshape(-1, 3)
        up = up.repeat(eye.shape[0] // up.shape[0], axis=0)
        eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)

        z_axis = eye - at
        z_axis /= np.max(np.stack([np.linalg.norm(z_axis,
                                                  axis=1, keepdims=True), eps]))

        x_axis = np.cross(up, z_axis)
        x_axis /= np.max(np.stack([np.linalg.norm(x_axis,
                                                  axis=1, keepdims=True), eps]))

        y_axis = np.cross(z_axis, x_axis)
        y_axis /= np.max(np.stack([np.linalg.norm(y_axis,
                                                  axis=1, keepdims=True), eps]))

        r_mat = np.concatenate(
            (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(
                -1, 3, 1)), axis=2)

        if to_pytorch:
            r_mat = torch.tensor(r_mat).float()
    else:
        def normalize(x, axis=-1, order=2):
            l2 = x.norm(p=order, dim=axis, keepdim=True).clamp(min=1e-8)
            return x / l2

        at = torch.from_numpy(at).float().to(eye.device)
        up = torch.from_numpy(up).float().to(eye.device)
        z_axis = normalize(eye - at[None, :])
        x_axis = normalize(torch.cross(up[None, :].expand_as(z_axis), z_axis, dim=-1))
        y_axis = normalize(torch.cross(z_axis, x_axis, dim=-1))
        r_mat = torch.stack([x_axis, y_axis, z_axis], dim=-1)

    return r_mat
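# Illustrative usage sketch: camera-to-world rotation for a camera looking at
# the origin; stacking the x/y/z axes as columns makes +z point from the target
# towards the eye, i.e. the camera looks down its local -z axis.
#
#   eye = torch.tensor([[0., 2., 1.]])                      # (1, 3) camera position
#   R = look_at(eye)                                        # (1, 3, 3)
#   assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3), atol=1e-4)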
def get_rotation_matrix(axis='z', value=0., batch_size=32):
    # scipy renamed Rotation.as_dcm() to as_matrix() in 1.4 and removed the old
    # name in 1.6, so use as_matrix() here.
    r = Rot.from_euler(axis, value * 2 * np.pi).as_matrix()
    r = torch.from_numpy(r).reshape(1, 3, 3).repeat(batch_size, 1, 1)
    return r
def get_corner_rays(corner_pixels, camera_matrices, res):
    assert (res + 1) * (res + 1) == corner_pixels.size(1)
    batch_size = camera_matrices[0].size(0)
    rays, origins, _ = get_camera_rays(camera_matrices, corner_pixels)
    corner_rays = torch.cat([rays, torch.cross(origins, rays, dim=-1)], -1)
    corner_rays = corner_rays.reshape(batch_size, res + 1, res + 1, 6).permute(0, 3, 1, 2)
    corner_rays = torch.cat([corner_rays[..., :-1, :-1], corner_rays[..., 1:, :-1],
                             corner_rays[..., 1:, 1:], corner_rays[..., :-1, 1:]], 1)
    return corner_rays


def arange_pixels(
        resolution=(128, 128),
        batch_size=1,
        subsample_to=None,
        invert_y_axis=False,
        margin=0,
        corner_aligned=True,
        jitter=None
):
    ''' Arranges pixels for the given resolution, scaled to [-1, 1].
    The function returns the scaled float pixel locations.
    Args:
        resolution (tuple): image resolution
        batch_size (int): batch size
        subsample_to (int): if integer and > 0, the points are randomly
            subsampled to this value
    '''
    h, w = resolution
    n_points = resolution[0] * resolution[1]
    uh = 1 if corner_aligned else 1 - (1 / h)
    uw = 1 if corner_aligned else 1 - (1 / w)
    if margin > 0:
        uh = uh + (2 / h) * margin
        uw = uw + (2 / w) * margin
        w, h = w + margin * 2, h + margin * 2

    x, y = torch.linspace(-uw, uw, w), torch.linspace(-uh, uh, h)
    if jitter is not None:
        dx = (torch.ones_like(x).uniform_() - 0.5) * 2 / w * jitter
        dy = (torch.ones_like(y).uniform_() - 0.5) * 2 / h * jitter
        x, y = x + dx, y + dy
    x, y = torch.meshgrid(x, y)
    pixel_scaled = torch.stack([x, y], -1).permute(1, 0, 2).reshape(1, -1, 2).repeat(batch_size, 1, 1)

    # Subsample points if subsample_to is not None and > 0
    if (subsample_to is not None and subsample_to > 0 and subsample_to < n_points):
        idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),
                               replace=False)
        pixel_scaled = pixel_scaled[:, idx]

    if invert_y_axis:
        pixel_scaled[..., -1] *= -1.

    return pixel_scaled
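# Illustrative usage sketch: a 4x4 grid of pixel centers in normalized device
# coordinates for one batch entry, with the y axis flipped to match the
# convention used by the ray helpers below.
#
#   pix = arange_pixels((4, 4), batch_size=1, invert_y_axis=True)
#   # pix has shape (1, 16, 2) with coordinates in [-1, 1].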
def to_pytorch(tensor, return_type=False):
    ''' Converts input tensor to pytorch.
    Args:
        tensor (tensor): Numpy or Pytorch tensor
        return_type (bool): whether to return input type
    '''
    is_numpy = False
    if type(tensor) == np.ndarray:
        tensor = torch.from_numpy(tensor)
        is_numpy = True
    tensor = tensor.clone()
    if return_type:
        return tensor, is_numpy
    return tensor


def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,
                       invert=True, use_absolute_depth=True):
    ''' Transforms pixel positions p with given depth value d to world coordinates.
    Args:
        pixels (tensor): pixel tensor of size B x N x 2
        depth (tensor): depth tensor of size B x N x 1
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        invert (bool): whether to invert matrices (default: True)
    '''
    assert(pixels.shape[-1] == 2)
    if scale_mat is None:
        scale_mat = torch.eye(4).unsqueeze(0).repeat(
            camera_mat.shape[0], 1, 1).to(camera_mat.device)

    # Convert to pytorch
    pixels, is_numpy = to_pytorch(pixels, True)
    depth = to_pytorch(depth)
    camera_mat = to_pytorch(camera_mat)
    world_mat = to_pytorch(world_mat)
    scale_mat = to_pytorch(scale_mat)

    # Invert camera matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)
        world_mat = torch.inverse(world_mat)
        scale_mat = torch.inverse(scale_mat)

    # Transform pixels to homogeneous coordinates
    pixels = pixels.permute(0, 2, 1)
    pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)

    # Project pixels into camera space
    if use_absolute_depth:
        pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()
        pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)
    else:
        pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)

    # Transform pixels to world space
    p_world = scale_mat @ world_mat @ camera_mat @ pixels

    # Transform p_world back to 3D coordinates
    p_world = p_world[:, :3].permute(0, 2, 1)

    if is_numpy:
        p_world = p_world.numpy()
    return p_world
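# Illustrative usage sketch (assuming the convention used elsewhere in this
# module: camera_mat is the already-inverted intrinsic from get_camera_mat()
# and world_mat is a camera-to-world pose, so invert=False):
#
#   camera_mat = get_camera_mat().repeat(8, 1, 1)
#   world_mat, _ = get_random_pose((0., 1.), (0.25, 0.5), (2.7, 2.7), batch_size=8)
#   pix = arange_pixels((16, 16), batch_size=8, invert_y_axis=True)
#   depth = -torch.ones(8, 256, 1)   # negative: the camera looks down -z
#   p_world = transform_to_world(pix, depth, camera_mat, world_mat, invert=False)
#   # p_world: (8, 256, 3) points at depth 1 in front of each camera.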
def transform_to_camera_space(p_world, world_mat, camera_mat=None, scale_mat=None):
    ''' Transforms world points to camera space.
    Args:
        p_world (tensor): world points tensor of size B x N x 3
        world_mat (tensor): world matrix
        camera_mat (tensor): camera matrix (optional)
        scale_mat (tensor): scale matrix (optional)
    '''
    batch_size, n_p, _ = p_world.shape
    device = p_world.device

    # Transform world points to homogeneous coordinates
    p_world = torch.cat([p_world, torch.ones(
        batch_size, n_p, 1).to(device)], dim=-1).permute(0, 2, 1)

    # Apply matrices to transform p_world to camera space
    if scale_mat is None:
        if camera_mat is None:
            p_cam = world_mat @ p_world
        else:
            p_cam = camera_mat @ world_mat @ p_world
    else:
        p_cam = camera_mat @ world_mat @ scale_mat @ p_world

    # Transform points back to 3D coordinates
    p_cam = p_cam[:, :3].permute(0, 2, 1)
    return p_cam


def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,
                    invert=False):
    ''' Transforms origin (camera location) to world coordinates.
    Args:
        n_points (int): how often the transformed origin is repeated in the
            form (batch_size, n_points, 3)
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        invert (bool): whether to invert the matrices (default: False)
    '''
    batch_size = camera_mat.shape[0]
    device = camera_mat.device
    # Create origin in homogeneous coordinates
    p = torch.zeros(batch_size, 4, n_points).to(device)
    p[:, -1] = 1.

    if scale_mat is None:
        scale_mat = torch.eye(4).unsqueeze(
            0).repeat(batch_size, 1, 1).to(device)

    # Invert matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)
        world_mat = torch.inverse(world_mat)
        scale_mat = torch.inverse(scale_mat)

    # Apply transformation
    p_world = scale_mat @ world_mat @ camera_mat @ p

    # Transform points back to 3D coordinates
    p_world = p_world[:, :3].permute(0, 2, 1)
    return p_world
def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,
                          invert=False, negative_depth=True):
    ''' Transforms points on image plane to world coordinates.
    In contrast to transform_to_world, no depth value is needed as points on
    the image plane have a fixed depth of 1.
    Args:
        image_points (tensor): image points tensor of size B x N x 2
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        invert (bool): whether to invert matrices
    '''
    batch_size, n_pts, dim = image_points.shape
    assert(dim == 2)
    device = image_points.device
    d_image = torch.ones(batch_size, n_pts, 1).to(device)
    if negative_depth:
        d_image *= -1.
    return transform_to_world(image_points, d_image, camera_mat, world_mat,
                              scale_mat, invert=invert)
def image_points_to_camera(image_points, camera_mat,
                           invert=False, negative_depth=True, use_absolute_depth=True):
    batch_size, n_pts, dim = image_points.shape
    assert(dim == 2)
    device = image_points.device
    d_image = torch.ones(batch_size, n_pts, 1).to(device)
    if negative_depth:
        d_image *= -1.

    # Convert to pytorch
    pixels, is_numpy = to_pytorch(image_points, True)
    depth = to_pytorch(d_image)
    camera_mat = to_pytorch(camera_mat)

    # Invert camera matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)

    # Transform pixels to homogeneous coordinates
    pixels = pixels.permute(0, 2, 1)
    pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)

    # Scale pixels by the depth value
    if use_absolute_depth:
        pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()
        pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)
    else:
        pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)

    # Transform pixels to camera space
    p_camera = camera_mat @ pixels

    # Transform p_camera back to 3D coordinates
    p_camera = p_camera[:, :3].permute(0, 2, 1)

    if is_numpy:
        p_camera = p_camera.numpy()
    return p_camera
def camera_points_to_image(camera_points, camera_mat,
                           invert=False, negative_depth=True, use_absolute_depth=True):
    batch_size, n_pts, dim = camera_points.shape
    assert(dim == 3)
    device = camera_points.device

    # Convert to pytorch
    p_camera, is_numpy = to_pytorch(camera_points, True)
    camera_mat = to_pytorch(camera_mat)

    # Invert camera matrices
    if invert:
        camera_mat = torch.inverse(camera_mat)

    # Project camera-space points to the image plane
    p_camera = p_camera.permute(0, 2, 1)  # B x 3 x N
    pixels = camera_mat[:, :3, :3] @ p_camera

    assert use_absolute_depth and negative_depth
    pixels, p_depths = pixels[:, :2], pixels[:, 2:3]
    p_depths = -p_depths  # negative depth
    pixels = pixels / p_depths
    pixels = pixels.permute(0, 2, 1)

    if is_numpy:
        pixels = pixels.numpy()
    return pixels
def angular_interpolation(res, camera_mat):
    batch_size = camera_mat.shape[0]
    device = camera_mat.device
    input_rays = image_points_to_camera(arange_pixels((res, res), batch_size,
                                        invert_y_axis=True).to(device), camera_mat)
    output_rays = image_points_to_camera(arange_pixels((res * 2, res * 2), batch_size,
                                         invert_y_axis=True).to(device), camera_mat)
    input_rays = input_rays / input_rays.norm(dim=-1, keepdim=True)
    output_rays = output_rays / output_rays.norm(dim=-1, keepdim=True)

    def dir2sph(v):
        u = (v[..., :2] ** 2).sum(-1).sqrt()
        theta = torch.atan2(u, v[..., 2]) / math.pi
        phi = torch.atan2(v[..., 1], v[..., 0]) / math.pi
        return torch.stack([theta, phi], 1)

    input_rays = dir2sph(input_rays).reshape(batch_size, 2, res, res)
    output_rays = dir2sph(output_rays).reshape(batch_size, 2, res * 2, res * 2)
    return input_rays


def interpolate_sphere(z1, z2, t):
    p = (z1 * z2).sum(dim=-1, keepdim=True)
    p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt()
    p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt()
    omega = torch.acos(p)
    s1 = torch.sin((1 - t) * omega) / torch.sin(omega)
    s2 = torch.sin(t * omega) / torch.sin(omega)
    z = s1 * z1 + s2 * z2
    return z


def get_camera_rays(camera_matrices, pixels=None, res=None, margin=0):
    device = camera_matrices[0].device
    batch_size = camera_matrices[0].shape[0]
    if pixels is None:
        assert res is not None
        pixels = arange_pixels((res, res), batch_size, invert_y_axis=True, margin=margin).to(device)
    n_points = pixels.size(1)
    pixels_world = image_points_to_world(
        pixels, camera_mat=camera_matrices[0],
        world_mat=camera_matrices[1])
    camera_world = origin_to_world(
        n_points, camera_mat=camera_matrices[0],
        world_mat=camera_matrices[1])
    ray_vector = pixels_world - camera_world
    ray_vector = ray_vector / ray_vector.norm(dim=-1, keepdim=True)
    return ray_vector, camera_world, pixels_world
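# Illustrative usage sketch: the typical ray-generation pipeline in this module,
# with camera_matrices = (camera_mat, world_mat) where camera_mat is the inverted
# intrinsic (the get_camera_mat default) and world_mat a camera-to-world pose.
#
#   camera_mat = get_camera_mat().repeat(2, 1, 1)                             # (2, 4, 4)
#   world_mat, _ = get_random_pose((0., 1.), (0.25, 0.5), (2.7, 2.7), batch_size=2)
#   rays, origins, _ = get_camera_rays((camera_mat, world_mat), res=32)
#   # rays: (2, 1024, 3) unit directions; origins: (2, 1024, 3) camera centers.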
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1].
    Args:
        d6: 6D rotation representation, of size (*, 6)
    Returns:
        batch of rotation matrices of size (*, 3, 3)
    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)


def camera_9d_to_16d(d9):
    d6, translation = d9[..., :6], d9[..., 6:]
    rotation = rotation_6d_to_matrix(d6)
    RT = torch.eye(4).to(device=d9.device, dtype=d9.dtype).reshape(
        1, 4, 4).repeat(d6.size(0), 1, 1)
    RT[:, :3, :3] = rotation
    RT[:, :3, -1] = translation
    return RT.reshape(-1, 16)


def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """
    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
    by dropping the last row. Note that the 6D representation is not unique.
    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)
    Returns:
        6D rotation representation, of size (*, 6)
    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
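# Round-trip sketch: for a proper rotation matrix the 6D representation keeps
# the first two rows, and the Gram-Schmidt step in rotation_6d_to_matrix()
# reconstructs the third, so the conversion is exact up to floating-point error.
#
#   R = look_at(torch.tensor([[1., 1., 1.]]))               # (1, 3, 3)
#   d6 = matrix_to_rotation_6d(R)                           # (1, 6)
#   assert torch.allclose(rotation_6d_to_matrix(d6), R, atol=1e-5)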
def depth2pts_outside(ray_o, ray_d, depth):
    '''
    ray_o, ray_d: [..., 3]
    depth: [...]; inverse of distance to sphere origin
    '''
    # note: d1 becomes negative if this mid point is behind camera
    d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
    p_mid = ray_o + d1.unsqueeze(-1) * ray_d
    p_mid_norm = torch.norm(p_mid, dim=-1)
    ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
    d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos
    p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d

    rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
    rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
    phi = torch.asin(p_mid_norm)
    theta = torch.asin(p_mid_norm * depth)  # depth is inside [0, 1]
    rot_angle = (phi - theta).unsqueeze(-1)  # [..., 1]

    # now rotate p_sphere
    # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    p_sphere_new = p_sphere * torch.cos(rot_angle) + \
        torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
        rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
    p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
    pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)

    # now calculate conventional depth
    depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1
    return pts, depth_real


def intersect_sphere(ray_o, ray_d, radius=1):
    '''
    ray_o, ray_d: [..., 3]
    compute the depth range of the intersection between each ray and a sphere of
    the given radius centered at the origin
    '''
    # note: d1 becomes negative if this mid point is behind camera
    d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
    p = ray_o + d1.unsqueeze(-1) * ray_d

    # consider the case where the ray does not intersect the sphere
    ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
    d2 = radius ** 2 - torch.sum(p * p, dim=-1)
    mask = (d2 > 0)
    d2 = torch.sqrt(d2.clamp(min=1e-6)) * ray_d_cos

    d1, d2 = d1.unsqueeze(-1), d2.unsqueeze(-1)
    depth_range = [d1 - d2, d1 + d2]
    return depth_range, mask
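# Illustrative usage sketch (NeRF++-style background parameterization): rays are
# first intersected with the unit sphere to get the foreground depth range, and
# points beyond it are addressed by an inverse distance `depth` in (0, 1], which
# depth2pts_outside() turns into a 4D coordinate (unit direction on the sphere,
# inverse distance) plus the conventional depth along the ray.
#
#   ray_o = torch.tensor([[0., 0.5, 2.]])                    # origin outside the unit sphere
#   ray_d = torch.tensor([[0., 0., -1.]])
#   (near, far), hit = intersect_sphere(ray_o, ray_d)        # near/far: (1, 1); hit: (1,)
#   pts, depth_real = depth2pts_outside(ray_o, ray_d, torch.tensor([0.5]))
#   # pts: (1, 4) background coordinates; depth_real: (1,) distance along the ray.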
def normalize(x, axis=-1, order=2):
    if isinstance(x, torch.Tensor):
        l2 = x.norm(p=order, dim=axis, keepdim=True)
        return x / (l2 + 1e-8), l2
    else:
        l2 = np.linalg.norm(x, order, axis)
        l2 = np.expand_dims(l2, axis)
        l2[l2 == 0] = 1
        return x / l2, l2


def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
    """
    Sample @N_importance samples from @bins with distribution defined by @weights.
    Inputs:
        bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
        weights: (N_rays, N_samples_)
        N_importance: the number of samples to draw from the distribution
        det: deterministic or not
        eps: a small number to prevent division by zero
    Outputs:
        samples: the sampled samples
    Source: https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py
    """
    N_rays, N_samples_ = weights.shape
    weights = weights + eps  # prevent division by zero (don't do inplace op!)
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # (N_rays, N_samples_)
    cdf = torch.cumsum(pdf, -1)  # (N_rays, N_samples_), cumulative distribution function
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (N_rays, N_samples_+1)
    # padded to 0~1 inclusive

    if det:
        u = torch.linspace(0, 1, N_importance, device=bins.device)
        u = u.expand(N_rays, N_importance)
    else:
        u = torch.rand(N_rays, N_importance, device=bins.device)
    u = u.contiguous()

    inds = torch.searchsorted(cdf, u)
    below = torch.clamp_min(inds - 1, 0)
    above = torch.clamp_max(inds, N_samples_)

    inds_sampled = torch.stack([below, above], -1).view(N_rays, 2 * N_importance)
    cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
    bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom[denom < eps] = 1  # denom equals 0 means a bin has weight 0, in which case
    # it will not be sampled anyway, therefore any value for it is fine (set to 1 here)
    samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0])
    return samples
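# Illustrative usage sketch (hierarchical sampling as in NeRF): the coarse
# weights along each ray define a piecewise-constant pdf over the bins, and
# sample_pdf() inverts its cdf to concentrate fine samples where weights are large.
#
#   bins = torch.linspace(0., 1., 9).unsqueeze(0).repeat(4, 1)    # (N_rays=4, 9)
#   weights = torch.rand(4, 8)                                     # (N_rays, N_samples_=8)
#   fine = sample_pdf(bins, weights, N_importance=16, det=True)    # (4, 16)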
def normalization_inverse_sqrt_dist_centered(x_in_world, view_cell_center, max_depth):
    localized = x_in_world - view_cell_center
    local = torch.sqrt(torch.linalg.norm(localized, dim=-1))
    res = localized / (math.sqrt(max_depth) * local[..., None])
    return res

######################################################################################