Spaces:
Running
Running
| import torch | |
| import numpy as np | |
| from einops import rearrange | |
def sample_img_rays(x, img_fov=45):
    """
    Samples a unit ray for each pixel in image

    The rays are built on the z=1 image plane and then normalised, so every
    ray has a positive z component; x grows with the column index and y grows
    with the row index.

    Args:
        x: images (...,h,w); only the trailing (h,w) shape, dtype and device are used
        img_fov: assumed image fov in degrees for ray calculation;
            number or tuple(h,w) i.e. (vertical, horizontal)
    Returns:
        img_rays (h,w,3) 3:<x,y,z>
    """
    h, w = x.shape[-2:]
    # Half field of view per axis in radians, ordered [y,x]. A scalar fov is
    # shared by both axes. float64 keeps narrow fovs accurate.
    half_fov = (torch.as_tensor(img_fov, dtype=torch.float64) * (torch.pi / 360)).expand(2)
    # The half extent of the z=1 image plane is tan(half_fov). Computing it as
    # sqrt(1/cos^2 - 1) cancels catastrophically for small angles (the result
    # collapses to 0 in float32), so tan is used directly; abs() keeps the
    # non-negative result the square-root form produced.
    y_max, x_max = half_fov.tan().abs().tolist()
    y_coords = torch.linspace(-y_max, y_max, h, dtype=x.dtype, device=x.device)
    x_coords = torch.linspace(-x_max, x_max, w, dtype=x.dtype, device=x.device)
    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing='ij')
    xyz = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1)  # (h,w,<x,y,z>)
    # Normalise each ray to unit length.
    return xyz / xyz.norm(dim=-1, keepdim=True)
def gen_rotation_matrix(angles):
    """
    Build rotation matrices from per-axis rotation angles

    The result is the composition Rz(angles[2]) @ Ry(angles[1]) @ Rx(angles[0]),
    i.e. the rotation about x is applied first and the rotation about z last.

    Args:
        angles: axis-wise rotations in degrees (...,3:<x,y,z>)
    Returns:
        rot_mat (...,3,3)
    """
    batch_shape = angles.shape[:-1]
    # degrees -> radians, then bring the axis dimension to the front: (3,...)
    radians = rearrange(2*torch.pi*angles/360, '... a -> a ...')
    sx, sy, sz = radians.sin()
    cx, cy, cz = radians.cos()
    # Row-major entries of the 3x3 matrix, each of shape (...)
    entries = [
        cy*cz, sx*sy*cz-cx*sz, cx*sy*cz+sx*sz,
        cy*sz, sx*sy*sz+cx*cz, cx*sy*sz-sx*cz,
        -sy, sx*cy, cx*cy,
    ]
    flat = torch.stack(entries, dim=-1)  # (...,9)
    return flat.reshape(*batch_shape, 3, 3)
def cart_2_spherical(pts):
    """
    Convert Cartesian to spherical coordinates

    Azimuth is measured in the x-z plane from the +z axis towards +x;
    inclination is the elevation of the point above that plane towards +y.

    Args:
        pts: input pts (...,<x,y,z>)
    Returns:
        ret (...,<theta,phi,r>) (<azimuth,inclination,radius>) (radians),
        theta in [-pi,pi], phi in [-pi/2,pi/2]
    """
    x, y, z = pts.unbind(-1)
    r = pts.norm(dim=-1)
    # atan2 replaces sign(x)*arccos(z/|xz|): that form returned azimuth 0 for
    # points straight behind the origin (x == 0, z < 0) because sign(0) == 0,
    # and NaN at the poles (x == z == 0) because of the 0/0 division.
    theta = torch.atan2(x, z)
    # Equivalent to arcsin(y/r) but never leaves the valid domain through
    # rounding and stays finite when r == 0.
    phi = torch.atan2(y, (x**2 + z**2).sqrt())
    return torch.stack([theta, phi, r], dim=-1)
def sample_pano_img(img, pts, h_fov_ratio=1, w_fov_ratio=1):
    """
    Sample points from panoramic image using bilinear interpolation

    Row 0 of the pano corresponds to inclination -pi/2 and column 0 to
    azimuth -pi. Any number of channels is supported.

    Args:
        img: pano-image (...,c,h,w)
        pts: spherical points to sample from img (...,h,w,3:<azimuth,inclination,radius>)
        *_fov_ratio: ratio of full fov for pano
    Returns:
        sampled_img (...,c,h,w) with (h,w) taken from pts
    """
    h, w = img.shape[-2:]
    channels = img.shape[-3]
    out_h, out_w = pts.shape[-3:-1]
    rows_per_unit, cols_per_unit = h / h_fov_ratio, w / w_fov_ratio

    flat_img = img.flatten(-2).transpose(-1, -2)  # (...,c,h,w) -> (...,n,c)
    flat_pts = pts.flatten(-3, -2)                # (...,h,w,3) -> (...,m,3)

    # convert (pts) radians to fractional pixel positions
    # inclination (-pi/2,+pi/2): clamped, not wrapped, so that the +pi/2 pole
    # maps to the last row instead of folding back onto row 0
    row_pos = ((flat_pts[..., 1] + torch.pi / 2) / torch.pi).clamp(0, 1) * rows_per_unit
    # azimuth (-pi,+pi): periodic, so wrap into [0,1)
    col_pos = (((flat_pts[..., 0] + torch.pi) / (2 * torch.pi)) % 1) * cols_per_unit

    # integer neighbours for bilinear interpolation
    row_lo = row_pos.to(torch.long).clamp(0, h - 1)
    col_lo = col_pos.to(torch.long).clamp(0, w - 1)
    row_hi = (row_lo + 1).clamp(0, h - 1)
    if w_fov_ratio == 1:
        # a full 360-degree pano is periodic horizontally: the neighbour to the
        # right of the last column is column 0 (avoids a seam at azimuth +-pi)
        col_hi = (col_lo + 1) % w
    else:
        col_hi = (col_lo + 1).clamp(0, w - 1)

    # interpolation weights; each low/high pair sums to 1
    row_w_hi, col_w_hi = row_pos - row_lo, col_pos - col_lo
    row_w_lo, col_w_lo = 1 - row_w_hi, 1 - col_w_hi

    # linearised (row*w + col) indices and weights of the 4 neighbours: (4,...,m)
    corner_inds = torch.stack([
        row_lo * w + col_lo,
        row_lo * w + col_hi,
        row_hi * w + col_lo,
        row_hi * w + col_hi,
    ])
    corner_weights = torch.stack([
        row_w_lo * col_w_lo,
        row_w_lo * col_w_hi,
        row_w_hi * col_w_lo,
        row_w_hi * col_w_hi,
    ])

    # gather the 4 neighbours of every sample point and blend them
    expanded_img = flat_img.unsqueeze(0).expand(4, *flat_img.shape)               # (4,...,n,c)
    gather_inds = corner_inds.unsqueeze(-1).expand(*corner_inds.shape, channels)  # (4,...,m,c)
    corners = expanded_img.gather(-2, gather_inds)                                # (4,...,m,c)
    sampled = (corner_weights.unsqueeze(-1) * corners).sum(0)                     # (...,m,c)

    # (...,m,c) -> (...,c,h,w)
    return sampled.transpose(-1, -2).reshape(*sampled.shape[:-2], channels, out_h, out_w)
def sample_perspective_img(pano_img, output_shape, fov=None, rot=None):
    """
    Sample perspective image from panoramic

    Casts one ray per output pixel, rotates the rays, converts them to
    spherical coordinates and bilinearly samples the pano at those angles.

    Args:
        pano_img: pano-image numpy.array (h,w,3:<rgb>)
        output_shape: output image dimensions tuple(h,w)
        fov: desired perspective image fov; int or tuple(vertical,horizontal) in degrees [0,180);
            drawn uniformly from [30,90) per axis when None
        rot: axis-wise rotations; tuple(pitch,yaw,roll) in degrees [0,360];
            drawn at random when None
    Returns:
        sampled_img numpy.array (h,w,3:<rgb>) uint8, fov (as used), rot (as used, numpy.array)
    """
    if fov is None:
        fov = tuple((30 + 60*torch.rand(2)).tolist())  # (v-fov,h-fov)
    if rot is None:
        # rot w.r.t (x,y,z) aka (pitch,yaw,roll): pitch in [-10,10), yaw in [-135,90), roll in [-20,20)
        # NOTE(review): the yaw range is not symmetric like pitch/roll (a span of 270 would
        # give [-135,135)) - confirm whether 225 is intended.
        rot = -torch.tensor([10,135,20]) + torch.tensor([20,225,40])*torch.rand(3)
    else:
        rot = torch.tensor(rot)

    pano = torch.tensor(pano_img, dtype=torch.uint8).moveaxis(-1,0)  # (h,w,3) -> (3,h,w)
    out_dtype = pano.dtype
    pano = pano.to(torch.float)

    # only the shape/dtype/device of this placeholder are used by the ray sampler
    canvas = torch.empty(output_shape, dtype=pano.dtype, device=pano.device)
    img_rays = sample_img_rays(canvas, img_fov=fov)  # (h,w,3)
    rot_mat = gen_rotation_matrix(rot.to(pano.dtype))[None,None,:]  # (3,3) -> (1,1,3,3)
    rot_img_rays = torch.matmul(rot_mat, img_rays.unsqueeze(-1)).squeeze(-1)  # (h,w,3)
    spher_rot_img_rays = cart_2_spherical(rot_img_rays)  # (h,w,3)

    # sample img
    sampled = sample_pano_img(pano, spher_rot_img_rays)  # (3,h,w)
    # Round to nearest before the integer cast: a plain float->uint8 cast truncates,
    # so an interpolated 199.99998 would become 199. The clamp keeps the cast in range.
    sampled = sampled.round().clamp(0, 255)
    return sampled.moveaxis(0,-1).to(out_dtype).numpy(), fov, rot.numpy()