import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import v2
from einops import rearrange

from freesplatter.models.transformer import Transformer
from freesplatter.utils.infer_util import instantiate_from_config
from freesplatter.utils.recon_util import estimate_focal, fast_pnp


# Zeroth-order (DC) spherical-harmonics basis constant, 1 / (2 * sqrt(pi)).
C0 = 0.28209479177387814


def RGB2SH(rgb):
    # Convert an RGB colour in [0, 1] to its DC spherical-harmonics coefficient.
    return (rgb - 0.5) / C0
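
# Illustrative companion helper (not part of the original file): under the
# same zeroth-order SH convention, a rendered colour is recovered from the DC
# coefficient as C0 * sh + 0.5, so the inverse of RGB2SH is simply:
def SH2RGB(sh):
    return sh * C0 + 0.5
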
class FreeSplatterModel(nn.Module):
    def __init__(
        self,
        transformer_config=None,
        renderer_config=None,
        use_2dgs=False,
        sh_residual=False,
    ):
        super().__init__()

        self.sh_dim = (renderer_config.sh_degree + 1) ** 2 * 3
        self.sh_residual = sh_residual
        self.use_2dgs = use_2dgs

        self.transformer = instantiate_from_config(transformer_config)

        if not use_2dgs:
            from .renderer.gaussian_renderer import GaussianRenderer
        else:
            from .renderer_2dgs.gaussian_renderer import GaussianRenderer
        self.gs_renderer = GaussianRenderer(renderer_config=renderer_config)

        self.register_buffer('pp', torch.tensor([256, 256], dtype=torch.float32), persistent=False)
    def forward_gaussians(self, images, **kwargs):
        """
        images: B x N x 3 x H x W
        """
        gaussians = self.transformer(images)    # B x N x H x W x C

        if self.sh_residual:
            # Add the input RGB (as DC SH coefficients) onto the predicted colour channels.
            residual = torch.zeros_like(gaussians)
            sh = RGB2SH(rearrange(images, 'b n c h w -> b n h w c'))
            residual[..., 3:6] = sh
            gaussians = gaussians + residual

        gaussians = rearrange(gaussians, 'b n h w c -> b (n h w) c')
        return gaussians
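
    # -----------------------------------------------------------------------
    # Note (added for clarity; not part of the original file): with
    # sh_degree=0 the per-Gaussian channel count is 14, and the indices used
    # in this file imply the layout
    #   [0:3]        xyz position (used as a per-pixel point map below),
    #   [3:3+sh_dim] SH colour coefficients (the slice sh_residual writes to),
    #   [3+sh_dim]   opacity logit (sigmoided in estimate_poses),
    #   remainder    presumably 3 scales + 4 quaternion components; that
    #                ordering is an assumption and is not relied on here.
    # -----------------------------------------------------------------------
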
    def forward_renderer(self, gaussians, c2ws, fxfycxcy, **kwargs):
        """
        gaussians: B x K x 14
        c2ws: B x N x 4 x 4
        fxfycxcy: B x N x 4
        """
        render_results = self.gs_renderer.render(gaussians, fxfycxcy, c2ws, **kwargs)
        return render_results
    def estimate_focals(
        self,
        images,
        masks=None,
        use_first_focal=False,
    ):
        """
        Estimate the focal lengths of N input images.
        images: N x 3 x H x W
        masks: N x 1 x H x W
        """
        assert images.ndim == 4

        N, _, H, W = images.shape
        assert H == W, "Non-square images are not supported."
        pp = self.pp.to(images)
        # pp = torch.tensor([W/2, H/2]).to(images)

        focals = []
        for i in range(N):
            if use_first_focal and i > 0:
                break
            # Rotate the view order so that view i comes first; its Gaussians
            # then occupy the first H*W entries of the flattened prediction.
            images_input = torch.cat([images[i:], images[:i]], dim=0)
            gaussians = self.forward_gaussians(images_input.unsqueeze(0))    # 1 x (N x H x W) x 14
            points = rearrange(gaussians[0, :H*W, :3], '(h w) c -> h w c', h=H, w=W)
            mask = masks[i] if masks is not None else None
            focal = estimate_focal(points, pp=pp, mask=mask)
            focals.append(focal)
        focals = torch.stack(focals).to(images)
        focals = focals.mean().reshape(1).repeat(N)
        return focals
    def estimate_poses(
        self,
        images,
        gaussians=None,
        masks=None,
        focals=None,
        use_first_focal=True,
        opacity_threshold=5e-2,
        pnp_iter=20,
    ):
        """
        Estimate the camera poses of N input images.
        images: N x 3 x H x W
        gaussians: K x 14 or 1 x K x 14
        masks: N x 1 x H x W
        focals: N
        """
        assert images.ndim == 4

        N, _, H, W = images.shape
        assert H == W, "Non-square images are not supported."

        # Predict Gaussians from the images if they are not provided.
        if gaussians is None:
            gaussians = self.forward_gaussians(images.unsqueeze(0))    # 1 x (N x H x W) x 14
        else:
            if gaussians.ndim == 2:
                gaussians = gaussians.unsqueeze(0)
            assert gaussians.shape[1] == N * H * W

        points = gaussians[..., :3].reshape(1, N, H, W, 3).squeeze(0)    # N x H x W x 3
        opacities = gaussians[..., 3+self.sh_dim].reshape(1, N, H, W).squeeze(0)
        opacities = torch.sigmoid(opacities)    # N x H x W

        # Estimate focals if not provided.
        if focals is None:
            focals = self.estimate_focals(images, masks=masks, use_first_focal=use_first_focal)

        # Run PnP per view to recover camera-to-world poses.
        c2ws = []
        for i in range(N):
            pts3d = points[i].float().detach().cpu().numpy()
            # If masks are not provided, we use Gaussian opacities instead.
            if masks is None:
                mask = (opacities[i] > opacity_threshold).detach().cpu().numpy()
            else:
                mask = masks[i].reshape(H, W).bool().detach().cpu().numpy()
            focal = focals[i].item()
            _, c2w = fast_pnp(pts3d, mask, focal=focal, niter_PnP=pnp_iter)
            c2ws.append(torch.from_numpy(c2w))
        c2ws = torch.stack(c2ws, dim=0).to(images)
        return c2ws, focals
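

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original file). It assumes a
# FreeSplatterModel instance has already been built from its transformer and
# renderer configs, and that fxfycxcy packs pixel-space intrinsics
# (fx, fy, cx, cy) per view with the principal point at the image centre --
# both conventions are assumptions here, as is the helper name
# run_inference_sketch.
def run_inference_sketch(model, images):
    """images: N x 3 x H x W, square RGB crops in [0, 1]."""
    N, _, H, W = images.shape
    # 1. Predict per-pixel Gaussians for all views jointly.
    gaussians = model.forward_gaussians(images.unsqueeze(0))    # 1 x (N*H*W) x 14
    # 2. Recover focal lengths and camera-to-world poses from the Gaussians.
    c2ws, focals = model.estimate_poses(images, gaussians=gaussians)
    # 3. Re-render the input views with the estimated cameras.
    pp = model.pp.to(images)
    fxfycxcy = torch.stack([focals, focals, pp[0].expand(N), pp[1].expand(N)], dim=-1)
    render_results = model.forward_renderer(
        gaussians, c2ws.unsqueeze(0), fxfycxcy.unsqueeze(0)
    )
    return gaussians, c2ws, focals, render_results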