import os
import random
import sys
import uuid

import torch
import numpy as np
import torchvision.transforms as T
import imageio
import lpips

from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from argparse import ArgumentParser, Namespace
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torchvision.utils import make_grid
from PIL import Image

from arguments import ModelParams, PipelineParams, OptimizationParams
from compute_gradient import calc_gradient
from gaussian_renderer import render
from scene import Scene, GaussianModel
from utils.compress_utils import save_comp
from utils.general_utils import safe_state, knn
from utils.image_utils import psnr, easy_cmap
from utils.loss_utils import l1_loss, ssim, msssim

try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False

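
# `training` resumes from a standard 30k-iteration checkpoint and then runs the compression
# pipeline implemented below: gradient-based sampling and pruning, iterative Gaussian merging
# with learned alphas, an auxiliary network, sparse vector quantization of the 3D and then 4D
# attributes, and finally encoding of the result to `comp.xz`.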
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint, debug_from,
             gaussian_dim, time_duration, num_pts, num_pts_ratio, rot_4d, force_sh_3d, batch_size):
    if dataset.frame_ratio > 1:
        time_duration = [time_duration[0] / dataset.frame_ratio, time_duration[1] / dataset.frame_ratio]

    first_iter = 0
    tb_writer = prepare_output_and_logger(dataset)
    gaussians = GaussianModel(dataset.sh_degree, gaussian_dim=gaussian_dim, time_duration=time_duration,
                              rot_4d=rot_4d, force_sh_3d=force_sh_3d,
                              sh_degree_t=2 if pipe.eval_shfs_4d else 0)
    scene = Scene(dataset, gaussians, num_pts=num_pts, num_pts_ratio=num_pts_ratio, time_duration=time_duration)
    gaussians.training_setup(opt)

    os.makedirs(scene.model_path, exist_ok=True)
    loss_log_path = os.path.join(scene.model_path, "loss_log.txt")
    loss_log_file = open(loss_log_path, "a")

    tau_sim = opt.tau_sim
    sim_cutoff = opt.sim_cutoff
    grid_size = opt.grid_size

    if checkpoint:
        (model_params, first_iter) = torch.load(checkpoint, weights_only=False)
        gaussians.restore(model_params, opt)

    # Compression always starts from the 30k pre-training point, so override whatever
    # iteration count the checkpoint carried.
    first_iter = 30000
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    os.makedirs(os.path.join(dataset.model_path, "train"), exist_ok=True)
    os.makedirs(os.path.join(dataset.model_path, "test"), exist_ok=True)

    num_merge = opt.num_merge

    # Stage schedule: each milestone is offset from the previous one by the
    # corresponding optimization budget in `opt`.
    compression_start = 30000
    grad_pruning_iter = compression_start + opt.grad_pruning_iter
    first_merge_iter = grad_pruning_iter + opt.grad_pruning_opt_iter
    svq3d_iter = first_merge_iter + opt.merge_opt_iter * num_merge + opt.net_opt_iter
    svq4d_iter = svq3d_iter + opt.svq3d_opt_iter
    encode_iter = svq4d_iter + opt.svq4d_opt_iter
    final_iteration = encode_iter
    testing_iterations.append(final_iteration)
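    # Illustrative walk-through with hypothetical budgets (not defaults): with
    # grad_pruning_iter=1000, grad_pruning_opt_iter=2000, num_merge=3, merge_opt_iter=1000,
    # net_opt_iter=1000, svq3d_opt_iter=1000 and svq4d_opt_iter=1000, the milestones land at
    # 31000 (pruning) -> 33000 (first merge) -> 37000 (svq3d) -> 38000 (svq4d) -> 39000 (encode/final).
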
    # `args` is the module-level namespace parsed in `__main__`; precomputed gradients can be
    # loaded from disk to skip the expensive gradient pass.
    if args.grad:
        view_grad = np.load(os.path.join(args.grad, 'view_grad.npy'))
        t_grad = np.load(os.path.join(args.grad, 't_grad.npy'))
    else:
        view_grad, t_grad = calc_gradient(dataset, opt, pipe, scene, gaussians, batch_size, bg_color, background)

    mask = gaussians.gradient_sampling(opt.tau_GS, view_grad, t_grad, args)
    torch.cuda.empty_cache()

    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    ema_loss_for_log = 0.0
    ema_l1loss_for_log = 0.0
    ema_ssimloss_for_log = 0.0
    lambda_all = [key for key in opt.__dict__.keys() if key.startswith('lambda') and key != 'lambda_dssim']
    # One EMA tracker per regularizer, created dynamically. Note this relies on CPython
    # keeping manually inserted keys in the dict returned by vars().
    for lambda_name in lambda_all:
        vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"] = 0.0

    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    first_iter += 1

    if pipe.env_map_res:
        env_map = nn.Parameter(torch.zeros((3, pipe.env_map_res, pipe.env_map_res),
                                           dtype=torch.float, device="cuda").requires_grad_(True))
        env_map_optimizer = torch.optim.Adam([env_map], lr=opt.feature_lr, eps=1e-15)
    else:
        env_map = None

    gaussians.env_map = env_map

    training_dataset = scene.getTrainCameras()
    training_dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True,
                                     num_workers=12 if dataset.dataloader else 0,
                                     collate_fn=lambda x: x, drop_last=True)

    iteration = first_iter
    lpips_model = lpips.LPIPS(net='alex')
    lpips_model.eval()
    lpips_model.requires_grad_(False)
    lpips_model = lpips_model.to("cuda")
    actual_storage = 0.0

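    # Main loop: render each camera in the batch, accumulate the photometric loss, then run
    # whichever compression stage the current iteration has reached.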
    while iteration < opt.iterations + 1:
        for batch_data in training_dataloader:
            iteration += 1
            if iteration > opt.iterations:
                break

            iter_start.record()
            gaussians.update_learning_rate(iteration)

            if iteration % opt.sh_increase_interval == 0:
                gaussians.oneupSHdegree()

            if (iteration - 1) == debug_from:
                pipe.debug = True

            batch_point_grad = []
            batch_visibility_filter = []
            batch_radii = []

            for batch_idx in range(batch_size):
                gt_image, viewpoint_cam = batch_data[batch_idx]
                gt_image = gt_image.cuda()
                viewpoint_cam = viewpoint_cam.cuda()

                render_pkg = render(viewpoint_cam, gaussians, pipe, background)
                image, viewspace_point_tensor, visibility_filter, radii = (
                    render_pkg["render"], render_pkg["viewspace_points"],
                    render_pkg["visibility_filter"], render_pkg["radii"])
                depth = render_pkg["depth"]
                alpha = render_pkg["alpha"]

                Ll1 = l1_loss(image, gt_image)
                Lssim = 1.0 - ssim(image, gt_image)
                loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * Lssim

                loss = loss / batch_size
                loss.backward()
                if iteration <= first_merge_iter:
                    batch_point_grad.append(torch.norm(viewspace_point_tensor.grad[:, :2], dim=-1))
                    batch_radii.append(radii)
                    batch_visibility_filter.append(visibility_filter)

            if iteration % 100 == 0:
                img = image.detach().cpu().clamp(0, 1)
                img = T.ToPILImage()(img)

                save_dir = os.path.join(scene.model_path, "renders")
                os.makedirs(save_dir, exist_ok=True)
                # The original block prepared the image and directory but never wrote a file;
                # save the most recent render of the batch.
                img.save(os.path.join(save_dir, f"{iteration:06d}.png"))

            if batch_size > 1:
                visibility_count = torch.stack(batch_visibility_filter, 1).sum(1)
                visibility_filter = visibility_count > 0
                radii = torch.stack(batch_radii, 1).max(1)[0]
                if iteration <= first_merge_iter:
                    batch_viewspace_point_grad = torch.stack(batch_point_grad, 1).sum(1)
                    batch_viewspace_point_grad[visibility_filter] = (
                        batch_viewspace_point_grad[visibility_filter] * batch_size
                        / visibility_count[visibility_filter])
                    batch_viewspace_point_grad = batch_viewspace_point_grad.unsqueeze(1)

                    if gaussians.gaussian_dim == 4:
                        batch_t_grad = gaussians._t.grad.clone()[:, 0].detach()
                        batch_t_grad[visibility_filter] = (
                            batch_t_grad[visibility_filter] * batch_size
                            / visibility_count[visibility_filter])
                        batch_t_grad = batch_t_grad.unsqueeze(1)
            else:
                if gaussians.gaussian_dim == 4:
                    batch_t_grad = gaussians._t.grad.clone().detach()

            iter_end.record()
            loss_dict = {"Ll1": Ll1,
                         "Lssim": Lssim}

            with torch.no_grad():
                psnr_for_log = psnr(image, gt_image).mean().double()

                ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
                ema_l1loss_for_log = 0.4 * Ll1.item() + 0.6 * ema_l1loss_for_log
                ema_ssimloss_for_log = 0.4 * Lssim.item() + 0.6 * ema_ssimloss_for_log

                # Update the dynamically named EMA trackers for every active regularizer.
                for lambda_name in lambda_all:
                    if opt.__dict__[lambda_name] > 0:
                        ema = vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"]
                        vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"] = (
                            0.4 * vars()[f"L{lambda_name.replace('lambda_', '')}"].item() + 0.6 * ema)
                        loss_dict[lambda_name.replace("lambda_", "L")] = vars()[lambda_name.replace("lambda_", "L")]

                if iteration % 10 == 0:
                    postfix = {"Loss": f"{ema_loss_for_log:.7f}",
                               "PSNR": f"{psnr_for_log:.2f}",
                               "Ll1": f"{ema_l1loss_for_log:.4f}",
                               "N": f"{gaussians._xyz.shape[0]}",
                               "Lssim": f"{ema_ssimloss_for_log:.4f}"}

                    for lambda_name in lambda_all:
                        if opt.__dict__[lambda_name] > 0:
                            ema_loss = vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"]
                            postfix[lambda_name.replace("lambda_", "L")] = f"{ema_loss:.4f}"

                    progress_bar.set_postfix(postfix)
                    progress_bar.update(10)
                if iteration == opt.iterations:
                    progress_bar.close()

                if iteration == grad_pruning_iter:
                    gaussians.gradient_pruning(view_grad, t_grad, opt.tau_GP, opt.tau_GP, args, mask)
                    torch.cuda.empty_cache()

                if iteration == first_merge_iter:
                    gaussians.calc_clusters(grid_size=grid_size, tau_sim=tau_sim,
                                            sim_cutoff=sim_cutoff, t_grid_size=opt.t_grid_size)
                    if opt.grid_exp_ratio:
                        grid_size *= opt.grid_exp_ratio
                    gaussians.set_alpha_groups()

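                # Every 1000 iterations after the first clustering: prune the merged Gaussians
                # via their learned alphas, then either re-cluster for the next merge round or,
                # once all merges are done, switch to training the compression network.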
                if num_merge and iteration % 1000 == 0 and iteration > first_merge_iter:
                    print(iteration, "Pruning merged Gaussians using learned alpha...")
                    num_merge -= 1
                    gaussians.training_alpha = False
                    N = gaussians.get_xyz.shape[0]
                    gaussians.alpha_pruning_groups()
                    loss_log_file.write(f"Merge done. {N} -> {gaussians._xyz.shape[0]}\n")
                    print(f"Merge done. {N} -> {gaussians._xyz.shape[0]}")
                    if num_merge:
                        gaussians.calc_clusters(grid_size=grid_size, tau_sim=tau_sim,
                                                sim_cutoff=sim_cutoff, t_grid_size=opt.t_grid_size)
                        gaussians.set_alpha_groups()
                    else:
                        print("Start training network.")
                        loss_log_file.write("Start training network.\n")
                        gaussians.construct_net()

                if iteration == svq3d_iter:
                    loss_log_file.write("3D SVQ start.\n")
                    gaussians.apply_svq_3d(args)

                if iteration == svq4d_iter:
                    loss_log_file.write("4D SVQ start.\n")
                    gaussians.apply_svq_4d(args)

                if iteration == encode_iter:
                    print("Encoding compressed representation...")
                    save_dict = gaussians.encode()
                    save_comp(scene.model_path + "/comp.xz", save_dict)

                    actual_storage = os.path.getsize(scene.model_path + "/comp.xz") / 1024 / 1024
                    # Reload the compressed model so the final evaluation measures it.
                    gaussians.decode(save_dict, decompress=True)

                test_psnr = training_report(tb_writer, iteration, Ll1, loss, l1_loss,
                                            iter_start.elapsed_time(iter_end), testing_iterations,
                                            scene, render, (pipe, background), loss_dict,
                                            lpips_model, loss_log_file, actual_storage)
                if iteration in saving_iterations:
                    print("\n[ITER {}] Saving Gaussians".format(iteration))
                    scene.save(iteration)

                if iteration < opt.densify_until_iter and (
                        opt.densify_until_num_points < 0
                        or gaussians.get_xyz.shape[0] < opt.densify_until_num_points):
                    gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter],
                                                                         radii[visibility_filter])
                    if batch_size == 1:
                        gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter,
                                                          batch_t_grad if gaussians.gaussian_dim == 4 else None)
                    else:
                        gaussians.add_densification_stats_grad(batch_viewspace_point_grad, visibility_filter,
                                                               batch_t_grad if gaussians.gaussian_dim == 4 else None)

                    if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                        size_threshold = 20 if iteration > opt.opacity_reset_interval else None
                        gaussians.densify_and_prune(opt.densify_grad_threshold, opt.thresh_opa_prune,
                                                    scene.cameras_extent, size_threshold,
                                                    opt.densify_grad_t_threshold)

                    if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
                        gaussians.reset_opacity()

                if iteration % 100 == 0 and loss_log_file:
                    log_line = (f"[ITER {iteration}] Loss: {ema_loss_for_log:.6f} | PSNR: {psnr_for_log:.2f}"
                                f" | Ll1: {ema_l1loss_for_log:.4f} | N: {gaussians._xyz.shape[0]}"
                                f" | Lssim: {ema_ssimloss_for_log:.4f}")
                    for lambda_name in lambda_all:
                        if opt.__dict__[lambda_name] > 0:
                            ema_val = vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"]
                            log_line += f" | L{lambda_name.replace('lambda_', '')}: {ema_val:.4f}"
                    loss_log_file.write(log_line + "\n")

                if iteration in testing_iterations and loss_log_file:
                    loss_log_file.write(f"[ITER {iteration}] test_psnr: {test_psnr:.4f}\n")

                if loss_log_file:
                    loss_log_file.flush()

                if iteration < final_iteration:
                    gaussians.optimizer.step()
                    gaussians.optimizer.zero_grad(set_to_none=True)
                    if pipe.env_map_res and iteration < pipe.env_optimize_until:
                        env_map_optimizer.step()
                        env_map_optimizer.zero_grad(set_to_none=True)
                    if gaussians.net_enabled:
                        gaussians.optimizer_net.step()
                        gaussians.optimizer_net.zero_grad(set_to_none=True)
                        gaussians.scheduler_net.step()
                    if gaussians.vq_enabled:
                        if hasattr(gaussians, "optimizer_code") and gaussians.optimizer_code is not None:
                            gaussians.optimizer_code.step()
                            gaussians.optimizer_code.zero_grad()

                        if hasattr(gaussians, "optimizer_code_4d") and gaussians.optimizer_code_4d is not None:
                            gaussians.optimizer_code_4d.step()
                            gaussians.optimizer_code_4d.zero_grad()


def prepare_output_and_logger(args):
    if not args.model_path:
        if os.getenv('OAR_JOB_ID'):
            unique_str = os.getenv('OAR_JOB_ID')
        else:
            unique_str = str(uuid.uuid4())
        args.model_path = os.path.join("./output/", unique_str[0:10])

    print("Output folder: {}".format(args.model_path))
    os.makedirs(args.model_path, exist_ok=True)
    with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
        cfg_log_f.write(str(Namespace(**vars(args))))

    tb_writer = None
    if TENSORBOARD_FOUND:
        tb_writer = SummaryWriter(args.model_path)
    else:
        print("Tensorboard not available: not logging progress")
    return tb_writer


def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene,
                    renderFunc, renderArgs, loss_dict=None, lpips_model=None, log_file=None, actual_storage=0.0):
    if tb_writer:
        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
        # The original logged Ll1 under the ssim tag as well; report the actual SSIM loss.
        if loss_dict is not None and "Lssim" in loss_dict:
            tb_writer.add_scalar('train_loss_patches/ssim_loss', loss_dict["Lssim"].item(), iteration)
        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
        tb_writer.add_scalar('iter_time', elapsed, iteration)
        tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
        tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
        if loss_dict is not None:
            for key, tag in [("Lrigid", "rigid_loss"), ("Ldepth", "depth_loss"), ("Ltv", "tv_loss"),
                             ("Lopa", "opa_loss"), ("Lptsopa", "pts_opa_loss"),
                             ("Lsmooth", "smooth_loss"), ("Llaplacian", "laplacian_loss")]:
                if key in loss_dict:
                    tb_writer.add_scalar(f'train_loss_patches/{tag}', loss_dict[key].item(), iteration)

    psnr_test_iter = 0.0

    if iteration in testing_iterations:
        validation_configs = ({'name': 'train',
                               'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())]
                                           for idx in range(5, 30, 5)]},
                              {'name': 'test',
                               'cameras': [scene.getTestCameras()[idx]
                                           for idx in range(len(scene.getTestCameras()))]})

        for config in validation_configs:
            if config['cameras'] and len(config['cameras']) > 0:
                l1_test = 0.0
                psnr_test = 0.0
                ssim_test = 0.0
                msssim_test = 0.0
                lpips_test = 0.0
                for idx, batch_data in enumerate(tqdm(config['cameras'])):
                    gt_image, viewpoint = batch_data
                    gt_image = gt_image.cuda()
                    viewpoint = viewpoint.cuda()

                    render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs)
                    image = torch.clamp(render_pkg["render"], 0.0, 1.0)

                    depth = easy_cmap(render_pkg['depth'][0])
                    alpha = torch.clamp(render_pkg['alpha'], 0.0, 1.0).repeat(3, 1, 1)
                    if tb_writer and (idx < 5):
                        grid = [gt_image, image, alpha, depth]
                        grid = make_grid(grid, nrow=2)
                        tb_writer.add_images(config['name'] + "_view_{}/gt_vs_render".format(viewpoint.image_name),
                                             grid[None], global_step=iteration)

                    l1_test += l1_loss(image, gt_image).mean().double()
                    psnr_test += psnr(image, gt_image).mean().double()

                    lpips_test += lpips_model(image[None], gt_image[None]).squeeze().item()

                    if idx < 5:
                        try:
                            imageio.imwrite(os.path.join(scene.model_path, config['name'],
                                                         "render_{:05d}_{}.png".format(iteration, viewpoint.image_name)),
                                            (image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                        except Exception:
                            pass

                    test_log_path = os.path.join(scene.model_path, "test.txt")
                    with open(test_log_path, "a") as test_log_file:
                        test_log_file.write(f"[Time {viewpoint.timestamp}] test_psnr: {psnr(image, gt_image).mean().double():.4f}\n")

                    ssim_test += ssim(image, gt_image).mean().double()
                    msssim_test += msssim(image[None].cpu(), gt_image[None].cpu())
                psnr_test /= len(config['cameras'])
                l1_test /= len(config['cameras'])
                ssim_test /= len(config['cameras'])
                msssim_test /= len(config['cameras'])
                lpips_test /= len(config['cameras'])
                print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} lpips {}".format(iteration, config['name'],
                                                                                 l1_test, psnr_test, lpips_test))
                if tb_writer:
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ssim', ssim_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - msssim', msssim_test, iteration)
                if config['name'] == 'test':
                    psnr_test_iter = psnr_test.item()
                    # `log_file` defaults to None, so guard before writing.
                    if log_file:
                        log_file.write(f"psnr: {psnr_test_iter} ssim: {ssim_test} lpips: {lpips_test}\n")
                        log_file.flush()
                    if iteration == testing_iterations[-1]:
                        name = "_".join(scene.model_path.split("/")[-2:])
                        with open("./res.txt", "a") as f:
                            num_pts = scene.gaussians.get_xyz.shape[0]
                            f.write("{}: PSNR {:.3f}, SSIM {:.5f}, MS-SSIM {:.5f}, LPIPS {:.5f}, num_pts {}, MB {:.2f}\n".format(
                                name, psnr_test, ssim_test, msssim_test, lpips_test, num_pts, actual_storage))

        torch.cuda.empty_cache()

    return psnr_test_iter


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Deterministic cuDNN kernels trade some speed for reproducibility.
    torch.backends.cudnn.deterministic = True


if __name__ == "__main__": |
|
|
|
|
|
parser = ArgumentParser(description="Training script parameters") |
|
|
lp = ModelParams(parser) |
|
|
op = OptimizationParams(parser) |
|
|
pp = PipelineParams(parser) |
|
|
parser.add_argument("--config", type=str) |
|
|
parser.add_argument('--debug_from', type=int, default=-1) |
|
|
parser.add_argument('--detect_anomaly', action='store_true', default=False) |
|
|
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000]) |
|
|
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000]) |
|
|
parser.add_argument("--quiet", action="store_true") |
|
|
parser.add_argument("--start_checkpoint", type=str, default = None) |
|
|
|
|
|
parser.add_argument("--gaussian_dim", type=int, default=3) |
|
|
parser.add_argument("--time_duration", nargs=2, type=float, default=[-0.5, 0.5]) |
|
|
parser.add_argument('--num_pts', type=int, default=100_000) |
|
|
parser.add_argument('--num_pts_ratio', type=float, default=1.0) |
|
|
parser.add_argument("--rot_4d", action="store_true") |
|
|
parser.add_argument("--force_sh_3d", action="store_true") |
|
|
parser.add_argument("--batch_size", type=int, default=1) |
|
|
parser.add_argument("--seed", type=int, default=6666) |
|
|
parser.add_argument("--exhaust_test", action="store_true") |
|
|
parser.add_argument("--grad", type=str, default = None) |
|
|
parser.add_argument("--out_path", type=str, default = None) |
|
|
|
|
|
args = parser.parse_args(sys.argv[1:]) |
|
|
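    # Example invocation (script name and config path are illustrative):
    #   python train.py --config configs/my_scene.yaml --gaussian_dim 4 --rot_4d --batch_size 4
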
    args.save_iterations.append(args.iterations)

    # Values from the YAML config override the parsed defaults, recursively.
    cfg = OmegaConf.load(args.config)

    def recursive_merge(key, host):
        if isinstance(host[key], DictConfig):
            for key1 in host[key].keys():
                recursive_merge(key1, host[key])
        else:
            assert hasattr(args, key), key
            setattr(args, key, host[key])

    for k in cfg.keys():
        recursive_merge(k, cfg)

    if args.exhaust_test:
        args.test_iterations = args.test_iterations + [i for i in range(0, op.iterations, 3000)]

    setup_seed(args.seed)

    print("Optimizing " + args.model_path)

    safe_state(args.quiet)

    torch.autograd.set_detect_anomaly(args.detect_anomaly)
    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations,
             args.save_iterations, args.start_checkpoint, args.debug_from,
             args.gaussian_dim, args.time_duration, args.num_pts, args.num_pts_ratio,
             args.rot_4d, args.force_sh_3d, args.batch_size)

    print("\nTraining complete.")