Spaces: Running on Zero
| import torch | |
| import sys | |
| sys.path.append('./submodules/gaussian-splatting/') | |
| from random import randint | |
| from scene import Scene, GaussianModel | |
| from gaussian_renderer import render | |
| from source.data_utils import scene_cameras_train_test_split | |
class Warper3DGS(torch.nn.Module):
    """``nn.Module`` wrapper that renders a Gaussian Splatting scene.

    Holds a ``GaussianModel``, the ``Scene`` built from ``dataset`` around it,
    a background colour tensor and the pipeline options. ``forward`` renders
    the gaussians from a given camera or from a randomly sampled one.
    """

    def __init__(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose,
                 do_train_test_split=True):
        """Create the gaussian model and scene needed for rendering.

        Args:
            sh_degree: spherical-harmonics degree passed to ``GaussianModel``.
            opt: optimisation options; stored as ``self.gs_config_opt`` and
                passed to ``GaussianModel.training_setup``.
            pipe: pipeline options forwarded to ``render``.
            dataset: dataset/model options. ``dataset.white_background``
                selects a white (1, 1, 1) or black (0, 0, 0) background, and
                the object itself is passed to ``Scene``.
            viewpoint_stack: cameras that ``forward`` samples from when it is
                called without a camera. If falsy (e.g. ``None`` or empty),
                a copy of the scene's train cameras is used instead. A
                non-empty value is stored as-is (aliased, not copied).
            verbose: forwarded to ``scene_cameras_train_test_split``.
            do_train_test_split: if true, run
                ``scene_cameras_train_test_split`` on the loaded scene.
        """
        super().__init__()
        self.gaussians = GaussianModel(sh_degree)
        # NOTE(review): pre-seeds ``tmp_radii`` with a size taken from the model
        # right after construction, before ``Scene`` is built. Presumably this
        # is needed by the vendored GaussianModel's densification code; confirm.
        self.gaussians.tmp_radii = torch.zeros(
            (self.gaussians.get_xyz.shape[0]), device="cuda")
        self.render = render
        self.gs_config_opt = opt
        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        self.bg = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.pipe = pipe
        # Scene receives self.gaussians (presumably to initialise it from the
        # dataset). Keep it before the split and training_setup calls below.
        self.scene = Scene(dataset, self.gaussians, shuffle=False)
        if do_train_test_split:
            scene_cameras_train_test_split(self.scene, verbose=verbose)
        self.gaussians.training_setup(opt)
        self.viewpoint_stack = viewpoint_stack
        self._refill_viewpoint_stack()

    def _refill_viewpoint_stack(self):
        """Replace a falsy ``viewpoint_stack`` with a copy of the train cameras."""
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()

    def forward(self, viewpoint_cam=None):
        """Render the gaussians from ``viewpoint_cam``.

        If no camera is given, pick one uniformly at random from
        ``self.viewpoint_stack``, refilling the stack from ``self.scene``
        first if it is empty. Cameras are not removed from the stack, so
        repeated calls sample with replacement.

        Args:
            viewpoint_cam: camera to render from, or ``None`` to sample one.

        Returns:
            Whatever ``render`` returns for this camera (the render package).
        """
        if viewpoint_cam is None:
            self._refill_viewpoint_stack()
            viewpoint_cam = self.viewpoint_stack[
                randint(0, len(self.viewpoint_stack) - 1)]
        return self.render(viewpoint_cam, self.gaussians, self.pipe, self.bg)