from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from shap_e.models.generation.perceiver import SimplePerceiver
from shap_e.models.generation.transformer import Transformer
from shap_e.models.nn.camera import DifferentiableProjectiveCamera
from shap_e.models.nn.encoding import (
    MultiviewPointCloudEmbedding,
    MultiviewPoseEmbedding,
    PosEmbLinear,
)
from shap_e.models.nn.ops import PointSetEmbedding
from shap_e.rendering.point_cloud import PointCloud
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict

from .base import ChannelsEncoder


class TransformerChannelsEncoder(ChannelsEncoder, ABC):
    """
    Encode point clouds using a transformer model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int = 512,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        latent_scale: float = 1.0,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.width = width
        self.device = device
        self.dtype = dtype
        self.n_ctx = n_ctx
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + self.latent_ctx,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
        )
        self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)
        self.latent_scale = latent_scale

    @abstractmethod
    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        pass

    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        h = self.encode_input(batch, options=options)
        h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1)
        h = self.ln_pre(h)
        h = self.backbone(h)
        h = h[:, -self.latent_ctx :]
        h = self.ln_post(h)
        h = self.output_proj(h)
        return h
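
    # Editorial sketch (symbolic shapes, not code from the original module): how the
    # learned output tokens are appended and the latent block is read back out in
    # encode_to_channels above:
    #
    #   h = encode_input(batch)              # [B, n_ctx, width]
    #   h = cat([h, output_tokens], dim=1)   # [B, n_ctx + latent_ctx, width]
    #   h = backbone(ln_pre(h))              # [B, n_ctx + latent_ctx, width]
    #   latents = output_proj(               # [B, latent_ctx, d_latent]
    #       ln_post(h[:, -latent_ctx:]))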


class PerceiverChannelsEncoder(ChannelsEncoder, ABC):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        min_unrolls: int,
        max_unrolls: int,
        d_latent: int = 512,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
        width: int = 512,
        layers: int = 12,
        xattn_layers: int = 1,
        heads: int = 8,
        init_scale: float = 0.25,
        # Training hparams
        inner_batch_size: Union[int, List[int]] = 1,
        data_ctx: int = 1,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.width = width
        self.device = device
        self.dtype = dtype

        if isinstance(inner_batch_size, int):
            inner_batch_size = [inner_batch_size]
        self.inner_batch_size = inner_batch_size
        self.data_ctx = data_ctx
        self.min_unrolls = min_unrolls
        self.max_unrolls = max_unrolls

        encoder_fn = lambda inner_batch_size: SimplePerceiver(
            device=device,
            dtype=dtype,
            n_ctx=self.data_ctx + self.latent_ctx,
            n_data=inner_batch_size,
            width=width,
            layers=xattn_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.encoder = (
            encoder_fn(self.inner_batch_size[0])
            if len(self.inner_batch_size) == 1
            else nn.ModuleList([encoder_fn(inner_bsz) for inner_bsz in self.inner_batch_size])
        )
        self.processor = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=self.data_ctx + self.latent_ctx,
            layers=layers - xattn_layers,
            width=width,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.register_parameter(
            "output_tokens",
            nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)),
        )
        self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype)

    @abstractmethod
    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable[Union[torch.Tensor, Tuple]]]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """

    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        h, it = self.get_h_and_iterator(batch, options=options)
        n_unrolls = self.get_n_unrolls()

        for _ in range(n_unrolls):
            data = next(it)
            if isinstance(data, tuple):
                for data_i, encoder_i in zip(data, self.encoder):
                    h = encoder_i(h, data_i)
            else:
                h = self.encoder(h, data)
            h = self.processor(h)

        h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :]))
        return h

    def get_n_unrolls(self):
        if self.training:
            n_unrolls = torch.randint(
                self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device
            )
            dist.broadcast(n_unrolls, 0)
            n_unrolls = n_unrolls.item()
        else:
            n_unrolls = self.max_unrolls
        return n_unrolls
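
    # Editorial note (sketch, not original code): each unroll in encode_to_channels
    # pulls one inner batch from the data iterator, cross-attends the
    # [data_ctx + latent_ctx] query tokens to it, then refines them with the
    # self-attention processor. Broadcasting n_unrolls with dist.broadcast assumes
    # torch.distributed has been initialized (as in multi-GPU training); at eval
    # time the loop simply runs max_unrolls times.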


@dataclass
class DatasetIterator:

    embs: torch.Tensor  # [batch_size, dataset_size, *shape]
    batch_size: int

    def __iter__(self):
        self._reset()
        return self

    def __next__(self):
        _outer_batch_size, dataset_size, *_shape = self.embs.shape
        while True:
            start = self.idx
            self.idx += self.batch_size
            end = self.idx
            if end <= dataset_size:
                break
            self._reset()
        return self.embs[:, start:end]

    def _reset(self):
        self._shuffle()
        self.idx = 0  # pylint: disable=attribute-defined-outside-init

    def _shuffle(self):
        outer_batch_size, dataset_size, *shape = self.embs.shape
        idx = torch.stack(
            [
                torch.randperm(dataset_size, device=self.embs.device)
                for _ in range(outer_batch_size)
            ],
            dim=0,
        )
        idx = idx.view(outer_batch_size, dataset_size, *([1] * len(shape)))
        idx = torch.broadcast_to(idx, self.embs.shape)
        self.embs = torch.gather(self.embs, 1, idx)
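

def _dataset_iterator_example() -> torch.Tensor:
    """
    Editorial sketch (not used by the model code): DatasetIterator shuffles along
    dim=1 independently for each outer batch element and yields fixed-size inner
    batches, reshuffling whenever the dataset dimension is exhausted.
    """
    embs = torch.randn(2, 10, 4)  # [outer_batch_size, dataset_size, feature]
    it = iter(DatasetIterator(embs, batch_size=3))
    chunk = next(it)  # a [2, 3, 4] slice drawn from the shuffled dataset dimension
    assert chunk.shape == (2, 3, 4)
    return chunk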


class PointCloudTransformerChannelsEncoder(TransformerChannelsEncoder):
    """
    Encode point clouds using a transformer model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        input_channels: int = 6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.input_proj = nn.Linear(
            input_channels, self.width, device=self.device, dtype=self.dtype
        )

    def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        _ = options
        points = batch.points
        h = self.input_proj(points.permute(0, 2, 1))  # NCL -> NLC
        return h


class PointCloudPerceiverChannelsEncoder(PerceiverChannelsEncoder):
    """
    Encode point clouds using a perceiver model with an extra output
    token used to extract a latent vector.
    """

    def __init__(
        self,
        *,
        cross_attention_dataset: str = "pcl",
        fps_method: str = "fps",
        # point cloud hyperparameters
        input_channels: int = 6,
        pos_emb: Optional[str] = None,
        # multiview hyperparameters
        image_size: int = 256,
        patch_size: int = 32,
        pose_dropout: float = 0.0,
        use_depth: bool = False,
        max_depth: float = 5.0,
        # point conv hyperparameters
        pointconv_radius: float = 0.5,
        pointconv_samples: int = 32,
        pointconv_hidden: Optional[List[int]] = None,
        pointconv_patch_size: int = 1,
        pointconv_stride: int = 1,
        pointconv_padding_mode: str = "zeros",
        use_pointconv: bool = False,
        # other hyperparameters
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert cross_attention_dataset in (
            "pcl",
            "multiview",
            "dense_pose_multiview",
            "multiview_pcl",
            "pcl_and_multiview_pcl",
            "incorrect_multiview_pcl",
            "pcl_and_incorrect_multiview_pcl",
        )
        assert fps_method in ("fps", "first")
        self.cross_attention_dataset = cross_attention_dataset
        self.fps_method = fps_method
        self.input_channels = input_channels
        self.input_proj = PosEmbLinear(
            pos_emb,
            input_channels,
            self.width,
            device=self.device,
            dtype=self.dtype,
        )
        self.use_pointconv = use_pointconv
        if use_pointconv:
            if pointconv_hidden is None:
                pointconv_hidden = [self.width]
            self.point_conv = PointSetEmbedding(
                n_point=self.data_ctx,
                radius=pointconv_radius,
                n_sample=pointconv_samples,
                d_input=self.input_proj.weight.shape[0],
                d_hidden=pointconv_hidden,
                patch_size=pointconv_patch_size,
                stride=pointconv_stride,
                padding_mode=pointconv_padding_mode,
                fps_method=fps_method,
                device=self.device,
                dtype=self.dtype,
            )
        if self.cross_attention_dataset == "multiview":
            self.image_size = image_size
            self.patch_size = patch_size
            self.pose_dropout = pose_dropout
            self.use_depth = use_depth
            self.max_depth = max_depth
            pos_ctx = (image_size // patch_size) ** 2
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    torch.randn(
                        pos_ctx * self.inner_batch_size[0],
                        self.width,
                        device=self.device,
                        dtype=self.dtype,
                    )
                ),
            )
            self.patch_emb = nn.Conv2d(
                in_channels=3 if not use_depth else 4,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
            self.camera_emb = nn.Sequential(
                nn.Linear(
                    3 * 4 + 1, self.width, device=self.device, dtype=self.dtype
                ),  # input size is for origin+x+y+z+fov
                nn.GELU(),
                nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype),
            )
        elif self.cross_attention_dataset == "dense_pose_multiview":
            # The number of output features is halved, because a patch_size of
            # 32 ends up with a large patch_emb weight.
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.use_depth = use_depth
            self.max_depth = max_depth
            self.mv_pose_embed = MultiviewPoseEmbedding(
                posemb_version="nerf",
                n_channels=4 if self.use_depth else 3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            pos_ctx = (image_size // patch_size) ** 2
            # Positional embedding is unnecessary because pose information is baked into each pixel
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
        elif (
            self.cross_attention_dataset == "multiview_pcl"
            or self.cross_attention_dataset == "incorrect_multiview_pcl"
        ):
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.max_depth = max_depth
            assert use_depth
            self.mv_pcl_embed = MultiviewPointCloudEmbedding(
                posemb_version="nerf",
                n_channels=3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )
        elif (
            self.cross_attention_dataset == "pcl_and_multiview_pcl"
            or self.cross_attention_dataset == "pcl_and_incorrect_multiview_pcl"
        ):
            self.view_pose_width = self.width // 2
            self.image_size = image_size
            self.patch_size = patch_size
            self.max_depth = max_depth
            assert use_depth
            self.mv_pcl_embed = MultiviewPointCloudEmbedding(
                posemb_version="nerf",
                n_channels=3,
                out_features=self.view_pose_width,
                device=self.device,
                dtype=self.dtype,
            )
            self.patch_emb = nn.Conv2d(
                in_channels=self.view_pose_width,
                out_channels=self.width,
                kernel_size=patch_size,
                stride=patch_size,
                device=self.device,
                dtype=self.dtype,
            )

    def get_h_and_iterator(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> Tuple[torch.Tensor, Iterable]:
        """
        :return: a tuple of (
            the initial output tokens of size [batch_size, data_ctx + latent_ctx, width],
            an iterator over the given data
        )
        """
        options = AttrDict() if options is None else options

        # Build the initial query embeddings
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        if self.use_pointconv:
            points = self.input_proj(points).permute(0, 2, 1)  # NLC -> NCL
            xyz = batch.points[:, :3]
            data_tokens = self.point_conv(xyz, points).permute(0, 2, 1)  # NCL -> NLC
        else:
            fps_samples = self.sample_pcl_fps(points)
            data_tokens = self.input_proj(fps_samples)
        batch_size = points.shape[0]
        latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1)
        h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1))
        assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width)

        # Build the dataset embedding iterator
        dataset_fn = {
            "pcl": self.get_pcl_dataset,
            "multiview": self.get_multiview_dataset,
            "dense_pose_multiview": self.get_dense_pose_multiview_dataset,
            "pcl_and_multiview_pcl": self.get_pcl_and_multiview_pcl_dataset,
            "multiview_pcl": self.get_multiview_pcl_dataset,
        }[self.cross_attention_dataset]
        it = dataset_fn(batch, options=options)

        return h, it
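
    # Editorial sketch of the query construction above (symbolic shapes, not code
    # from the original module):
    #
    #   batch.points          [B, C, N]              raw cloud, channels-first
    #   data_tokens           [B, data_ctx, width]   via point conv or FPS + input_proj
    #   latent_tokens         [B, latent_ctx, width] learned output tokens
    #   h = ln_pre(cat(...))  [B, data_ctx + latent_ctx, width]
    #
    # The returned iterator then streams the cross-attention context selected by
    # cross_attention_dataset (point-cloud tokens, view patches, or both).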

    def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor:
        return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method)

    def get_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict[str, Any]] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        _ = options
        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]
        points = batch.points.permute(0, 2, 1)  # NCL -> NLC
        dataset_emb = self.input_proj(points)
        assert dataset_emb.shape[1] >= inner_batch_size
        return iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

    def get_multiview_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        dataset_emb = self.encode_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width) + self.pos_emb
                yield views

        return gen()

    def get_dense_pose_multiview_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
    ) -> Iterable:
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        dataset_emb = self.encode_dense_pose_views(batch)
        batch_size, num_views, n_patches, width = dataset_emb.shape

        assert num_views >= inner_batch_size

        it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                examples = next(it)
                assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width)
                views = examples.reshape(batch_size, -1, width)
                yield views

        return gen()

    def get_pcl_and_multiview_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        use_distance: bool = True,
    ) -> Iterable:
        _ = options

        pcl_it = self.get_pcl_dataset(
            batch, options=options, inner_batch_size=self.inner_batch_size[0]
        )
        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape

        assert num_views >= self.inner_batch_size[1]

        multiview_pcl_it = iter(
            DatasetIterator(multiview_pcl_emb, batch_size=self.inner_batch_size[1])
        )

        def gen():
            while True:
                pcl = next(pcl_it)
                multiview_pcl = next(multiview_pcl_it)
                assert multiview_pcl.shape == (
                    batch_size,
                    self.inner_batch_size[1],
                    n_patches,
                    self.width,
                )
                yield pcl, multiview_pcl.reshape(batch_size, -1, width)

        return gen()

    def get_multiview_pcl_dataset(
        self,
        batch: AttrDict,
        options: Optional[AttrDict] = None,
        inner_batch_size: Optional[int] = None,
        use_distance: bool = True,
    ) -> Iterable:
        _ = options

        if inner_batch_size is None:
            inner_batch_size = self.inner_batch_size[0]

        multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance)
        batch_size, num_views, n_patches, width = multiview_pcl_emb.shape

        assert num_views >= inner_batch_size

        multiview_pcl_it = iter(DatasetIterator(multiview_pcl_emb, batch_size=inner_batch_size))

        def gen():
            while True:
                multiview_pcl = next(multiview_pcl_it)
                assert multiview_pcl.shape == (
                    batch_size,
                    inner_batch_size,
                    n_patches,
                    self.width,
                )
                yield multiview_pcl.reshape(batch_size, -1, width)

        return gen()

    def encode_views(self, batch: AttrDict) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2)
        all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device)

        batch_size, num_views, _, _, _ = all_views.shape

        views_proj = self.patch_emb(
            all_views.reshape([batch_size * num_views, *all_views.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        # [batch_size, num_views, 1, 2 * width]
        camera_proj = self.camera_emb(all_cameras).reshape(
            [batch_size, num_views, 1, self.width * 2]
        )
        pose_dropout = self.pose_dropout if self.training else 0.0
        mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout
        camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj))
        scale, shift = camera_proj.chunk(2, dim=3)
        views_proj = views_proj * (scale + 1.0) + shift
        return views_proj
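
    # Editorial note (sketch, not original code): the camera conditioning above is a
    # FiLM-style modulation. camera_emb maps each 13-dim camera vector to 2 * width
    # features, split into (scale, shift), and the patch tokens are transformed as
    #
    #   views_proj = views_proj * (scale + 1.0) + shift
    #
    # During training, pose_dropout zeroes the modulation for a random subset of
    # examples, so the encoder also learns to work without camera information.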

    def encode_dense_pose_views(self, batch: AttrDict) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        if self.use_depth:
            depths = self.depths_to_tensor(batch.depths)
            all_views = torch.cat([all_views, depths], dim=2)

        dense_poses, _ = self.dense_pose_cameras_to_tensor(batch.cameras)
        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)
        position, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
        all_view_poses = self.mv_pose_embed(all_views, position, direction)

        batch_size, num_views, _, _, _ = all_view_poses.shape

        views_proj = self.patch_emb(
            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        return views_proj

    def encode_multiview_pcl(self, batch: AttrDict, use_distance: bool = True) -> torch.Tensor:
        """
        :return: [batch_size, num_views, n_patches, width]
        """
        all_views = self.views_to_tensor(batch.views).to(self.device)
        depths = self.raw_depths_to_tensor(batch.depths)
        all_view_alphas = self.view_alphas_to_tensor(batch.view_alphas).to(self.device)
        mask = all_view_alphas >= 0.999

        dense_poses, camera_z = self.dense_pose_cameras_to_tensor(batch.cameras)
        dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3)
        origin, direction = dense_poses[:, :, 0], dense_poses[:, :, 1]
        if use_distance:
            ray_depth_factor = torch.sum(direction * camera_z[..., None, None], dim=2, keepdim=True)
            depths = depths / ray_depth_factor
        position = origin + depths * direction
        all_view_poses = self.mv_pcl_embed(all_views, origin, position, mask)

        batch_size, num_views, _, _, _ = all_view_poses.shape

        views_proj = self.patch_emb(
            all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]])
        )
        views_proj = (
            views_proj.reshape([batch_size, num_views, self.width, -1])
            .permute(0, 1, 3, 2)
            .contiguous()
        )  # [batch_size x num_views x n_patches x width]

        return views_proj

    def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1].
        """
        if isinstance(views, torch.Tensor):
            return views

        tensor_batch = []
        num_views = len(views[0])
        for inner_list in views:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                img = img.resize((self.image_size,) * 2).convert("RGB")
                inner_batch.append(
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 127.5
                    - 1
                )
            tensor_batch.append(torch.stack(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3)

    def depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1].
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth
                tensor = tensor * 2 - 1
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def view_alphas_to_tensor(
        self, view_alphas: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor in the range [0, 1].
        """
        if isinstance(view_alphas, torch.Tensor):
            return view_alphas

        tensor_batch = []
        num_views = len(view_alphas[0])
        for inner_list in view_alphas:
            assert len(inner_list) == num_views
            inner_batch = []
            for img in inner_list:
                tensor = (
                    torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32)
                    / 255.0
                )
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor)
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def raw_depths_to_tensor(
        self, depths: Union[torch.Tensor, List[List[Image.Image]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 1 x size x size] tensor.
        """
        if isinstance(depths, torch.Tensor):
            return depths

        tensor_batch = []
        num_views = len(depths[0])
        for inner_list in depths:
            assert len(inner_list) == num_views
            inner_batch = []
            for arr in inner_list:
                tensor = torch.from_numpy(arr).clamp(max=self.max_depth)
                tensor = F.interpolate(
                    tensor[None, None],
                    (self.image_size,) * 2,
                    mode="nearest",
                )
                inner_batch.append(tensor.to(device=self.device, dtype=torch.float32))
            tensor_batch.append(torch.cat(inner_batch, dim=0))
        return torch.stack(tensor_batch, dim=0)

    def cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> torch.Tensor:
        """
        Returns a [batch x num_views x 3*4+1] tensor of camera information.
        """
        if isinstance(cameras, torch.Tensor):
            return cameras
        outer_batch = []
        for inner_list in cameras:
            inner_batch = []
            for camera in inner_list:
                inner_batch.append(
                    np.array(
                        [
                            *camera.x,
                            *camera.y,
                            *camera.z,
                            *camera.origin,
                            camera.x_fov,
                        ]
                    )
                )
            outer_batch.append(np.stack(inner_batch, axis=0))
        return torch.from_numpy(np.stack(outer_batch, axis=0)).float()
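
    # Editorial note (sketch): each camera is flattened into 3 * 4 + 1 = 13 values,
    # laid out as [x (3), y (3), z (3), origin (3), x_fov (1)]. This matches the
    # input size of camera_emb in the "multiview" branch of __init__.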

    def dense_pose_cameras_to_tensor(
        self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a tuple of (rays, z_directions) where
            - rays: [batch, num_views, height, width, 2, 3] tensor of camera information.
            - z_directions: [batch, num_views, 3] tensor of camera z directions.
        """
        if isinstance(cameras, torch.Tensor):
            raise NotImplementedError

        for inner_list in cameras:
            assert len(inner_list) == len(cameras[0])

        camera = cameras[0][0]
        flat_camera = DifferentiableProjectiveCamera(
            origin=torch.from_numpy(
                np.stack(
                    [cam.origin for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device),
            x=torch.from_numpy(
                np.stack(
                    [cam.x for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device),
            y=torch.from_numpy(
                np.stack(
                    [cam.y for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device),
            z=torch.from_numpy(
                np.stack(
                    [cam.z for inner_list in cameras for cam in inner_list],
                    axis=0,
                )
            ).to(self.device),
            width=camera.width,
            height=camera.height,
            x_fov=camera.x_fov,
            y_fov=camera.y_fov,
        )
        batch_size = len(cameras) * len(cameras[0])
        coords = (
            flat_camera.image_coords()
            .to(flat_camera.origin.device)
            .unsqueeze(0)
            .repeat(batch_size, 1, 1)
        )
        rays = flat_camera.camera_rays(coords)
        return (
            rays.view(len(cameras), len(cameras[0]), camera.height, camera.width, 2, 3).to(
                self.device
            ),
            flat_camera.z.view(len(cameras), len(cameras[0]), 3).to(self.device),
        )
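
    # Editorial note (sketch): the rays tensor packs one (origin, direction) pair per
    # pixel, which is why callers permute it to [batch, views, 2, 3, height, width]
    # and then split it as dense_poses[:, :, 0] (origins) and dense_poses[:, :, 1]
    # (directions).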


def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "fps") -> torch.Tensor:
    """
    Run farthest-point sampling on a batch of point clouds.

    :param points: batch of shape [N x num_points x n_channels].
    :param data_ctx: subsample count.
    :param method: either 'fps' or 'first'. Using 'first' assumes that the
        points are already sorted according to FPS sampling.
    :return: batch of shape [N x min(num_points, data_ctx) x n_channels].
    """
    n_points = points.shape[1]
    if n_points == data_ctx:
        return points
    if method == "first":
        return points[:, :data_ctx]
    elif method == "fps":
        batch = points.cpu().split(1, dim=0)
        fps = [sample_fps(x, n_samples=data_ctx) for x in batch]
        return torch.cat(fps, dim=0).to(points.device)
    else:
        raise ValueError(f"unsupported farthest-point sampling method: {method}")


def sample_fps(example: torch.Tensor, n_samples: int) -> torch.Tensor:
    """
    :param example: [1, n_points, 3 + n_channels]
    :return: [1, n_samples, 3 + n_channels]
    """
    points = example.cpu().squeeze(0).numpy()
    coords, raw_channels = points[:, :3], points[:, 3:]
    n_points, n_channels = raw_channels.shape
    assert n_samples <= n_points
    channels = {str(idx): raw_channels[:, idx] for idx in range(n_channels)}
    max_points = min(32768, n_points)
    fps_pcl = (
        PointCloud(coords=coords, channels=channels)
        .random_sample(max_points)
        .farthest_point_sample(n_samples)
    )
    fps_channels = np.stack([fps_pcl.channels[str(idx)] for idx in range(n_channels)], axis=1)
    fps = np.concatenate([fps_pcl.coords, fps_channels], axis=1)
    fps = torch.from_numpy(fps).unsqueeze(0)
    assert fps.shape == (1, n_samples, 3 + n_channels)
    return fps
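

def _sample_pcl_fps_example() -> torch.Tensor:
    """
    Editorial sketch (not part of the original module): subsample a batch of
    6-channel point clouds (e.g. xyz + RGB) down to a fixed context size. Using
    method="first" simply truncates and assumes the clouds are already FPS-ordered;
    method="fps" would route each example through sample_fps / PointCloud instead.
    """
    pts = torch.rand(2, 4096, 6)  # [batch, n_points, 3 + n_channels]
    sub = sample_pcl_fps(pts, data_ctx=1024, method="first")
    assert sub.shape == (2, 1024, 6)
    return sub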