Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from shap_e.rendering.view_data import ProjectiveCamera | |
class DifferentiableCamera(ABC):
    """
    An object describing how a camera corresponds to pixels in an image.

    This is an interface: both methods are abstract, so a concrete camera
    must override them before it can be instantiated.
    """

    @abstractmethod
    def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        """
        For every (x, y) coordinate in a rendered image, compute the ray of the
        corresponding pixel.

        :param coords: an [N x ... x 2] integer array of 2D image coordinates.
        :return: an [N x ... x 2 x 3] array of [2 x 3] (origin, direction) tuples.
                 The direction should always be unit length.
        """

    @abstractmethod
    def resize_image(self, width: int, height: int) -> "DifferentiableCamera":
        """
        Creates a new camera with the same intrinsics and direction as this one,
        but with resized image dimensions.

        :param width: the new image width in pixels.
        :param height: the new image height in pixels.
        :return: a new camera; this one is not modified.
        """
@dataclass
class DifferentiableProjectiveCamera(DifferentiableCamera):
    """
    Implements a batch, differentiable, standard pinhole camera.

    Each of the `batch_size` cameras is described by an origin and three axis
    vectors: `z` is the viewing direction and `x` / `y` span the image plane.
    The fields of view are passed to `torch.tan`, so they are in radians.
    """

    origin: torch.Tensor  # [batch_size x 3]
    x: torch.Tensor  # [batch_size x 3]
    y: torch.Tensor  # [batch_size x 3]
    z: torch.Tensor  # [batch_size x 3]
    width: int
    height: int
    x_fov: float
    y_fov: float

    def __post_init__(self):
        """
        Check that origin, x, y and z are all 2D [batch_size x 3] tensors that
        share one batch size.
        """
        # Rank is checked first so the indexed checks below can never fail
        # with an IndexError on a tensor of the wrong rank.
        assert (
            len(self.x.shape)
            == len(self.y.shape)
            == len(self.z.shape)
            == len(self.origin.shape)
            == 2
        )
        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3

    def resolution(self):
        """
        :return: a float32 CPU tensor [width, height].
        """
        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))

    def fov(self):
        """
        :return: a float32 CPU tensor [x_fov, y_fov].
        """
        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))

    def image_coords(self) -> torch.Tensor:
        """
        Enumerate every pixel of the image with x varying fastest.

        :return: coords of shape (width * height, 2), each row being (x, y).
        """
        pixel_indices = torch.arange(self.height * self.width)
        coords = torch.stack(
            [
                pixel_indices % self.width,
                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
            ],
            dim=1,
        )
        return coords

    def camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Compute the (origin, direction) ray through each requested pixel.

        :param coords: a [batch_size x ... x 2] tensor of (x, y) pixel coordinates,
                       where batch_size matches the number of cameras.
        :return: a [batch_size x ... x 2 x 3] tensor; index 0 of the second-to-last
                 axis is the ray origin and index 1 is the unit-length direction.
        """
        batch_size, *shape, n_coords = coords.shape
        assert n_coords == 2
        assert batch_size == self.origin.shape[0]

        # reshape (rather than view) so non-contiguous coords are accepted too.
        flat = coords.reshape(batch_size, -1, 2)

        res = self.resolution().to(flat.device)
        fov = self.fov().to(flat.device)

        # Map pixel indices 0 .. res - 1 onto [-1, 1], then scale by tan(fov / 2)
        # to get offsets along x / y for a unit step along z.
        # NOTE: a width or height of 1 makes (res - 1) zero, so this divides by zero.
        fracs = (flat.float() / (res - 1)) * 2 - 1
        fracs = fracs * torch.tan(fov / 2)

        directions = (
            self.z.view(batch_size, 1, 3)
            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
        )
        directions = directions / directions.norm(dim=-1, keepdim=True)

        # Every ray of a camera shares that camera's origin.
        rays = torch.stack(
            [
                torch.broadcast_to(
                    self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]
                ),
                directions,
            ],
            dim=2,
        )
        return rays.view(batch_size, *shape, 2, 3)

    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
        """
        Creates a new camera for the resized view assuming the aspect ratio does not change.

        :param width: the new image width in pixels.
        :param height: the new image height in pixels.
        :return: a new camera sharing this camera's origin, axes and fields of view.
        """
        assert width * self.height == height * self.width, "The aspect ratio should not change."
        return DifferentiableProjectiveCamera(
            origin=self.origin,
            x=self.x,
            y=self.y,
            z=self.z,
            width=width,
            height=height,
            x_fov=self.x_fov,
            y_fov=self.y_fov,
        )
@dataclass
class DifferentiableCameraBatch(ABC):
    """
    Annotate a differentiable camera with a multi-dimensional batch shape.
    """

    # The multi-dimensional batch shape; it may have any number of axes.
    shape: Tuple[int, ...]
    # The underlying camera.
    # NOTE(review): presumably it holds prod(shape) cameras in flattened order;
    # confirm against callers.
    flat_camera: "DifferentiableCamera"
def normalize(vec: torch.Tensor) -> torch.Tensor:
    """
    Scale vectors to unit length along the last dimension.

    :param vec: a [... x D] tensor of vectors.
    :return: a tensor of the same shape, each vector divided by its norm.
    """
    length = vec.norm(dim=-1, keepdim=True)
    return vec / length
def project_out(vec1: torch.Tensor, vec2: torch.Tensor) -> torch.Tensor:
    """
    Removes the vec2 component from vec1.

    :param vec1: a [... x D] tensor of vectors to project.
    :param vec2: a [... x D] tensor of directions to remove; it is normalized
                 here, so it does not need to be unit length.
    :return: vec1 minus its projection onto vec2.
    """
    direction = normalize(vec2)
    # Length of vec1 along the unit direction, kept as a trailing axis of size 1.
    along = torch.sum(vec1 * direction, dim=-1, keepdim=True)
    return vec1 - along * direction
def camera_orientation(toward: torch.Tensor, up: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Build an orthonormal camera frame that looks along `toward`.

    :param toward: [batch_size x 3] vector from camera position to the object.
                   It is normalized here, so it does not need to be unit length.
    :param up: Optional [batch_size x 3] specifying the physical up direction in
               the world frame. Defaults to the +z axis (0, 0, 1).
    :return: [batch_size x 3 x 3] where index 0, 1, 2 of the middle axis holds
             the camera x, y and z axis respectively.

    NOTE: if `toward` is parallel to `up`, the projected up vector is zero and
    normalizing it divides by zero, so the result is not finite.
    """
    # Validate `toward` before using it, so that a mis-shaped input fails the
    # assertion rather than the indexing in the default-`up` construction.
    assert len(toward.shape) == 2
    assert toward.shape[1] == 3

    if up is None:
        up = torch.zeros_like(toward)
        up[:, 2] = 1

    assert len(up.shape) == 2
    assert up.shape[1] == 3

    z = normalize(toward)
    # y is the negated part of `up` that is perpendicular to the view direction.
    y = -normalize(project_out(up, toward))
    x = torch.cross(y, z, dim=1)
    return torch.stack([x, y, z], dim=1)
def projective_camera_frame(
    origin: torch.Tensor,
    toward: torch.Tensor,
    camera_params: Union[ProjectiveCamera, DifferentiableProjectiveCamera],
) -> DifferentiableProjectiveCamera:
    """
    Given the origin and the direction of a view, return a differentiable
    projective camera with the given parameters.

    Only `width`, `height`, `x_fov` and `y_fov` are read from `camera_params`;
    the axes come from `camera_orientation(toward)` with its default up vector.

    TODO: We need to support the rotation of the camera frame about the
    `toward` vector to fully implement 6 degrees of freedom.
    """
    frame = camera_orientation(toward)
    # The middle axis of `frame` indexes the camera x, y and z axes.
    x_axis, y_axis, z_axis = frame[:, 0], frame[:, 1], frame[:, 2]
    return DifferentiableProjectiveCamera(
        origin=origin,
        x=x_axis,
        y=y_axis,
        z=z_axis,
        width=camera_params.width,
        height=camera_params.height,
        x_fov=camera_params.x_fov,
        y_fov=camera_params.y_fov,
    )
def get_image_coords(width, height) -> torch.Tensor:
    """
    Enumerate every pixel of a `width` x `height` image with x varying fastest.

    :param width: image width in pixels.
    :param height: image height in pixels.
    :return: an integer tensor of shape (width * height, 2), each row being (x, y).
    """
    flat_index = torch.arange(height * width)
    xs = flat_index % width
    # torch.div with an explicit rounding mode avoids the warning torch
    # throws for flat_index // width.
    ys = torch.div(flat_index, width, rounding_mode="trunc")
    return torch.stack([xs, ys], dim=1)