#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
import random
import sys
import uuid
import torch
import numpy as np
import torchvision.transforms as T
import imageio
import lpips
from torch import nn
from utils.loss_utils import l1_loss, ssim, msssim
from gaussian_renderer import render
from scene import Scene, GaussianModel
from utils.general_utils import safe_state, knn
from tqdm import tqdm
from utils.image_utils import psnr, easy_cmap
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
from torchvision.utils import make_grid
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.utils.data import DataLoader
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_FOUND = True
except ImportError:
TENSORBOARD_FOUND = False
from PIL import Image
from compute_gradient import calc_gradient
from utils.compress_utils import save_comp
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint, debug_from,
gaussian_dim, time_duration, num_pts, num_pts_ratio, rot_4d, force_sh_3d, batch_size):
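    # If frames are temporally subsampled (frame_ratio > 1), shrink the covered time interval to match.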
if dataset.frame_ratio > 1:
time_duration = [time_duration[0] / dataset.frame_ratio, time_duration[1] / dataset.frame_ratio]
first_iter = 0
tb_writer = prepare_output_and_logger(dataset)
gaussians = GaussianModel(dataset.sh_degree, gaussian_dim=gaussian_dim, time_duration=time_duration, rot_4d=rot_4d, force_sh_3d=force_sh_3d, sh_degree_t=2 if pipe.eval_shfs_4d else 0)
scene = Scene(dataset, gaussians, num_pts=num_pts, num_pts_ratio=num_pts_ratio, time_duration=time_duration)
gaussians.training_setup(opt)
os.makedirs(scene.model_path, exist_ok=True)
loss_log_path = os.path.join(scene.model_path, "loss_log.txt")
loss_log_file = open(loss_log_path, "a")
tau_sim = opt.tau_sim
sim_cutoff = opt.sim_cutoff
grid_size = opt.grid_size
if checkpoint:
(model_params, first_iter) = torch.load(checkpoint, weights_only=False)
gaussians.restore(model_params, opt)
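            # Resume the compression stages from the end of standard 4DGS training (iteration 30000), regardless of the iteration stored in the checkpoint.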
first_iter = 30000
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
os.makedirs(os.path.join(dataset.model_path, "train"), exist_ok=True)
os.makedirs(os.path.join(dataset.model_path, "test"), exist_ok=True)
num_merge = opt.num_merge
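    # Compression schedule, laid out as offsets from the end of 4DGS training:
    # gradient pruning -> alternating merge/optimization rounds -> network training -> 3D SVQ -> 4D SVQ -> final encoding.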
compression_start = 30000 # After 4DGS training
grad_pruning_iter = compression_start + opt.grad_pruning_iter
first_merge_iter = grad_pruning_iter + opt.grad_pruning_opt_iter
svq3d_iter = first_merge_iter + opt.merge_opt_iter * num_merge + opt.net_opt_iter
svq4d_iter = svq3d_iter + opt.svq3d_opt_iter
encode_iter = svq4d_iter + opt.svq4d_opt_iter
final_iteration = encode_iter
testing_iterations.append(final_iteration)
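    # Per-view and per-timestep gradient statistics: load precomputed arrays if --grad is given, otherwise compute them now.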
if args.grad:
view_grad = np.load(os.path.join(args.grad,'view_grad.npy'))
t_grad = np.load(os.path.join(args.grad,'t_grad.npy'))
else:
view_grad, t_grad = calc_gradient(dataset, opt, pipe, scene, gaussians, batch_size, bg_color, background)
    # Gradient sampling
mask = gaussians.gradient_sampling(opt.tau_GS, view_grad, t_grad, args)
torch.cuda.empty_cache()
iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)
ema_loss_for_log = 0.0
ema_l1loss_for_log = 0.0
ema_ssimloss_for_log = 0.0
lambda_all = [key for key in opt.__dict__.keys() if key.startswith('lambda') and key!='lambda_dssim']
for lambda_name in lambda_all:
vars()[f"ema_{lambda_name.replace('lambda_','')}_for_log"] = 0.0
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
if pipe.env_map_res:
env_map = nn.Parameter(torch.zeros((3,pipe.env_map_res, pipe.env_map_res),dtype=torch.float, device="cuda").requires_grad_(True))
env_map_optimizer = torch.optim.Adam([env_map], lr=opt.feature_lr, eps=1e-15)
else:
env_map = None
gaussians.env_map = env_map
training_dataset = scene.getTrainCameras()
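    # collate_fn keeps each batch as a plain list of (gt_image, camera) pairs instead of stacking tensors.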
training_dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, num_workers=12 if dataset.dataloader else 0, collate_fn=lambda x: x, drop_last=True)
iteration = first_iter
lpips_model = lpips.LPIPS(net='alex') # vgg for Bartender
lpips_model.eval()
lpips_model.requires_grad_(False)
lpips_model = lpips_model.to("cuda")
actual_storage = 0.0
while iteration < opt.iterations + 1:
for batch_data in training_dataloader:
iteration += 1
if iteration > opt.iterations:
break
iter_start.record()
gaussians.update_learning_rate(iteration)
# Every 1000 its we increase the levels of SH up to a maximum degree
if iteration % opt.sh_increase_interval == 0:
gaussians.oneupSHdegree()
# Render
if (iteration - 1) == debug_from:
pipe.debug = True
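            # Render each view in the batch separately; per-sample losses are scaled by 1/batch_size so gradients accumulate into a single optimizer step per batch.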
batch_point_grad = []
batch_visibility_filter = []
batch_radii = []
for batch_idx in range(batch_size):
gt_image, viewpoint_cam = batch_data[batch_idx]
gt_image = gt_image.cuda()
viewpoint_cam = viewpoint_cam.cuda()
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
depth = render_pkg["depth"]
alpha = render_pkg["alpha"]
# Loss
Ll1 = l1_loss(image, gt_image)
Lssim = 1.0 - ssim(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * Lssim
loss = loss / batch_size
loss.backward()
if iteration<=first_merge_iter:
batch_point_grad.append(torch.norm(viewspace_point_tensor.grad[:,:2], dim=-1))
batch_radii.append(radii)
batch_visibility_filter.append(visibility_filter)
                if iteration % 100 == 0:
                    img = image.detach().cpu().clamp(0, 1)
                    img = T.ToPILImage()(img)
                    save_dir = os.path.join(scene.model_path, "renders")
                    os.makedirs(save_dir, exist_ok=True)
                    # Write the current render to disk; the filename pattern is an assumed choice.
                    img.save(os.path.join(save_dir, f"{iteration:06d}.png"))
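            # Combine per-view statistics: a Gaussian counts as visible if any view saw it, its radius is the per-view maximum, and its accumulated gradients are rescaled by the number of views that actually saw it.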
if batch_size > 1:
visibility_count = torch.stack(batch_visibility_filter,1).sum(1)
visibility_filter = visibility_count > 0
radii = torch.stack(batch_radii,1).max(1)[0]
if iteration<=first_merge_iter:
batch_viewspace_point_grad = torch.stack(batch_point_grad,1).sum(1)
batch_viewspace_point_grad[visibility_filter] = batch_viewspace_point_grad[visibility_filter] * batch_size / visibility_count[visibility_filter]
batch_viewspace_point_grad = batch_viewspace_point_grad.unsqueeze(1)
if gaussians.gaussian_dim == 4:
batch_t_grad = gaussians._t.grad.clone()[:,0].detach()
batch_t_grad[visibility_filter] = batch_t_grad[visibility_filter] * batch_size / visibility_count[visibility_filter]
batch_t_grad = batch_t_grad.unsqueeze(1)
else:
if gaussians.gaussian_dim == 4:
batch_t_grad = gaussians._t.grad.clone().detach()
iter_end.record()
loss_dict = {"Ll1": Ll1,
"Lssim": Lssim}
with torch.no_grad():
psnr_for_log = psnr(image, gt_image).mean().double()
# Progress bar
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
ema_l1loss_for_log = 0.4 * Ll1.item() + 0.6 * ema_l1loss_for_log
ema_ssimloss_for_log = 0.4 * Lssim.item() + 0.6 * ema_ssimloss_for_log
for lambda_name in lambda_all:
if opt.__dict__[lambda_name] > 0:
ema = vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"]
vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"] = 0.4 * vars()[f"L{lambda_name.replace('lambda_', '')}"].item() + 0.6*ema
loss_dict[lambda_name.replace("lambda_", "L")] = vars()[lambda_name.replace("lambda_", "L")]
                if iteration % 10 == 0:
                    postfix = {"Loss": f"{ema_loss_for_log:.7f}",
                               "PSNR": f"{psnr_for_log:.2f}",
                               "Ll1": f"{ema_l1loss_for_log:.4f}",
                               "N": f"{gaussians._xyz.shape[0]}",
                               "Lssim": f"{ema_ssimloss_for_log:.4f}"}
for lambda_name in lambda_all:
if opt.__dict__[lambda_name] > 0:
ema_loss = vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"]
postfix[lambda_name.replace("lambda_", "L")] = f"{ema_loss:.{4}f}"
progress_bar.set_postfix(postfix)
progress_bar.update(10)
if iteration == opt.iterations:
progress_bar.close()
            # Gradient pruning
if iteration == grad_pruning_iter:
gaussians.gradient_pruning(view_grad, t_grad, opt.tau_GP, opt.tau_GP, args, mask)
torch.cuda.empty_cache()
            # Gaussian merging
if iteration == first_merge_iter:
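                # Group similar Gaussians into clusters on a spatio-temporal grid; each cluster receives a learnable alpha that drives later pruning.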
gaussians.calc_clusters(grid_size=grid_size, tau_sim=tau_sim, sim_cutoff=sim_cutoff, t_grid_size=opt.t_grid_size)
if opt.grid_exp_ratio:
grid_size *= opt.grid_exp_ratio
gaussians.set_alpha_groups()
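            # Merge rounds: every 1000 iterations prune merged Gaussians via their learned alpha; after the last round, construct and start training the network.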
if num_merge and iteration % 1000 == 0 and iteration > first_merge_iter:
print(iteration,"Pruning merged Gaussians using learned alpha...")
num_merge -= 1
gaussians.training_alpha = False
N = gaussians.get_xyz.shape[0]
gaussians.alpha_pruning_groups()
loss_log_file.write(f"Merge done. {N} -> {gaussians._xyz.shape[0]}\n")
print(f"Merge done. {N} -> {gaussians._xyz.shape[0]}\n")
if num_merge: # should prepare for next merge
gaussians.calc_clusters(grid_size=grid_size, tau_sim=tau_sim, sim_cutoff=sim_cutoff, t_grid_size=opt.t_grid_size)
gaussians.set_alpha_groups()
                    else:  # merging is done; construct the network
                        print("Start training network")
                        loss_log_file.write("Start training network.\n")
gaussians.construct_net()
            # 3D SVQ
            if iteration == svq3d_iter:
                loss_log_file.write("3D SVQ start.\n")
gaussians.apply_svq_3d(args)
            # 4D SVQ
            if iteration == svq4d_iter:
                loss_log_file.write("4D SVQ start.\n")
gaussians.apply_svq_4d(args)
            if iteration == encode_iter:
                print("Encoding and compressing Gaussian attributes...")
                save_dict = gaussians.encode()
                save_comp(scene.model_path + "/comp.xz", save_dict)
                actual_storage = os.path.getsize(scene.model_path + "/comp.xz") / 1024 / 1024  # archive header included, so slightly above the pure payload size
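                # Decode right away so the remaining iterations and the final evaluation run on the decompressed model.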
gaussians.decode(save_dict, decompress=True)
# Log and save
test_psnr = training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), loss_dict, lpips_model, loss_log_file, actual_storage)
if (iteration in saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
scene.save(iteration)
# Densification
if iteration < opt.densify_until_iter and (opt.densify_until_num_points < 0 or gaussians.get_xyz.shape[0] < opt.densify_until_num_points):
# Keep track of max radii in image-space for pruning
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
if batch_size == 1:
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter, batch_t_grad if gaussians.gaussian_dim == 4 else None)
else:
gaussians.add_densification_stats_grad(batch_viewspace_point_grad, visibility_filter, batch_t_grad if gaussians.gaussian_dim == 4 else None)
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
gaussians.densify_and_prune(opt.densify_grad_threshold, opt.thresh_opa_prune, scene.cameras_extent, size_threshold, opt.densify_grad_t_threshold)
if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
gaussians.reset_opacity()
if iteration % 100 == 0 and loss_log_file:
log_line = f"[ITER {iteration}] Loss: {ema_loss_for_log:.6f} | PSNR: {psnr_for_log:.2f} | Ll1: {ema_l1loss_for_log:.4f} | xyz: {gaussians._xyz.shape[0]:.4f} | Lssim: {ema_ssimloss_for_log:.4f}"
for lambda_name in lambda_all:
if opt.__dict__[lambda_name] > 0:
ema_val = vars()[f"ema_{lambda_name.replace('lambda_', '')}_for_log"]
log_line += f" | L{lambda_name.replace('lambda_', '')}: {ema_val:.4f}"
loss_log_file.write(log_line + "\n")
if iteration in testing_iterations and loss_log_file:
loss_log_file.write(f"[ITER {iteration}] test_psnr: {test_psnr:.4f}\n")
if loss_log_file:
loss_log_file.flush()
# Optimizer step
if iteration < final_iteration:
gaussians.optimizer.step()
gaussians.optimizer.zero_grad(set_to_none = True)
if pipe.env_map_res and iteration < pipe.env_optimize_until:
env_map_optimizer.step()
env_map_optimizer.zero_grad(set_to_none = True)
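                    # Step the auxiliary optimizers once the network and the SVQ codebooks have been constructed.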
if gaussians.net_enabled:
gaussians.optimizer_net.step()
gaussians.optimizer_net.zero_grad(set_to_none = True)
gaussians.scheduler_net.step()
if gaussians.vq_enabled:
if hasattr(gaussians, "optimizer_code") and gaussians.optimizer_code is not None:
gaussians.optimizer_code.step()
gaussians.optimizer_code.zero_grad()
if hasattr(gaussians, "optimizer_code_4d") and gaussians.optimizer_code_4d is not None:
gaussians.optimizer_code_4d.step()
gaussians.optimizer_code_4d.zero_grad()
def prepare_output_and_logger(args):
if not args.model_path:
if os.getenv('OAR_JOB_ID'):
unique_str=os.getenv('OAR_JOB_ID')
else:
unique_str = str(uuid.uuid4())
args.model_path = os.path.join("./output/", unique_str[0:10])
# Set up output folder
print("Output folder: {}".format(args.model_path))
os.makedirs(args.model_path, exist_ok = True)
with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
cfg_log_f.write(str(Namespace(**vars(args))))
# Create Tensorboard writer
tb_writer = None
if TENSORBOARD_FOUND:
tb_writer = SummaryWriter(args.model_path)
else:
print("Tensorboard not available: not logging progress")
return tb_writer
def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, loss_dict=None, lpips_model=None, log_file=None, actual_storage=0.0):
if tb_writer:
        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
        if loss_dict is not None and "Lssim" in loss_dict:
            tb_writer.add_scalar('train_loss_patches/ssim_loss', loss_dict["Lssim"].item(), iteration)
        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
tb_writer.add_scalar('iter_time', elapsed, iteration)
tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
if loss_dict is not None:
if "Lrigid" in loss_dict:
tb_writer.add_scalar('train_loss_patches/rigid_loss', loss_dict['Lrigid'].item(), iteration)
if "Ldepth" in loss_dict:
tb_writer.add_scalar('train_loss_patches/depth_loss', loss_dict['Ldepth'].item(), iteration)
if "Ltv" in loss_dict:
tb_writer.add_scalar('train_loss_patches/tv_loss', loss_dict['Ltv'].item(), iteration)
if "Lopa" in loss_dict:
tb_writer.add_scalar('train_loss_patches/opa_loss', loss_dict['Lopa'].item(), iteration)
if "Lptsopa" in loss_dict:
tb_writer.add_scalar('train_loss_patches/pts_opa_loss', loss_dict['Lptsopa'].item(), iteration)
if "Lsmooth" in loss_dict:
tb_writer.add_scalar('train_loss_patches/smooth_loss', loss_dict['Lsmooth'].item(), iteration)
if "Llaplacian" in loss_dict:
tb_writer.add_scalar('train_loss_patches/laplacian_loss', loss_dict['Llaplacian'].item(), iteration)
psnr_test_iter = 0.0
# Report test and samples of training set
if iteration in testing_iterations:
validation_configs = ({'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]},
{'name': 'test', 'cameras' : [scene.getTestCameras()[idx] for idx in range(len(scene.getTestCameras()))]})
for config in validation_configs:
if config['cameras'] and len(config['cameras']) > 0:
l1_test = 0.0
psnr_test = 0.0
ssim_test = 0.0
msssim_test = 0.0
lpips_test = 0.0
for idx, batch_data in enumerate(tqdm(config['cameras'])):
gt_image, viewpoint = batch_data
gt_image = gt_image.cuda()
viewpoint = viewpoint.cuda()
render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs)
image = torch.clamp(render_pkg["render"], 0.0, 1.0)
depth = easy_cmap(render_pkg['depth'][0])
alpha = torch.clamp(render_pkg['alpha'], 0.0, 1.0).repeat(3,1,1)
if tb_writer and (idx < 5):
grid = [gt_image, image, alpha, depth]
grid = make_grid(grid, nrow=2)
tb_writer.add_images(config['name'] + "_view_{}/gt_vs_render".format(viewpoint.image_name), grid[None], global_step=iteration)
l1_test += l1_loss(image, gt_image).mean().double()
psnr_test += psnr(image, gt_image).mean().double()
lpips_test += lpips_model(image[None], gt_image[None]).squeeze().item()
if idx < 5:
try:
imageio.imwrite(os.path.join(scene.model_path, config['name'], "render_{:05d}_{}.png".format(iteration, viewpoint.image_name)), (image.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8))
                        except Exception:
pass
test_log_path = os.path.join(scene.model_path, "test.txt")
with open(test_log_path, "a") as test_log_file:
test_log_file.write(f"[Time {viewpoint.timestamp}] test_psnr: {psnr(image, gt_image).mean().double():.4f}\n")
ssim_test += ssim(image, gt_image).mean().double()
msssim_test += msssim(image[None].cpu(), gt_image[None].cpu())
psnr_test /= len(config['cameras'])
l1_test /= len(config['cameras'])
ssim_test /= len(config['cameras'])
msssim_test /= len(config['cameras'])
lpips_test /= len(config['cameras'])
print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} lpips {}".format(iteration, config['name'], l1_test, psnr_test, lpips_test))
if tb_writer:
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ssim', ssim_test, iteration)
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - msssim', msssim_test, iteration)
if config['name'] == 'test':
psnr_test_iter = psnr_test.item()
log_file.write(f"psnr: {psnr_test_iter} ssim: {ssim_test} lpips: {lpips_test}\n")
log_file.flush()
if iteration == testing_iterations[-1]:
name = "_".join(scene.model_path.split("/")[-2:])
with open("./res.txt", "a") as f:
num_pts = scene.gaussians.get_xyz.shape[0]
f.write("{}: PSNR {:.3f}, SSIM {:.5f}, MS-SSIM {:.5f}, LPIPS {:.5f}, num_pts {}, MB {:.2f}\n".format(name, psnr_test, ssim_test, msssim_test, lpips_test, num_pts, actual_storage))
torch.cuda.empty_cache()
return psnr_test_iter
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
if __name__ == "__main__":
# Set up command line argument parser
parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)
op = OptimizationParams(parser)
pp = PipelineParams(parser)
parser.add_argument("--config", type=str)
parser.add_argument('--debug_from', type=int, default=-1)
parser.add_argument('--detect_anomaly', action='store_true', default=False)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000])
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--start_checkpoint", type=str, default = None)
parser.add_argument("--gaussian_dim", type=int, default=3)
parser.add_argument("--time_duration", nargs=2, type=float, default=[-0.5, 0.5])
parser.add_argument('--num_pts', type=int, default=100_000)
parser.add_argument('--num_pts_ratio', type=float, default=1.0)
parser.add_argument("--rot_4d", action="store_true")
parser.add_argument("--force_sh_3d", action="store_true")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--seed", type=int, default=6666)
parser.add_argument("--exhaust_test", action="store_true")
parser.add_argument("--grad", type=str, default = None)
parser.add_argument("--out_path", type=str, default = None)
args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations)
cfg = OmegaConf.load(args.config)
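    # Recursively copy every leaf value from the YAML config onto args; each key must already exist as an argparse option.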
def recursive_merge(key, host):
if isinstance(host[key], DictConfig):
for key1 in host[key].keys():
recursive_merge(key1, host[key])
else:
assert hasattr(args, key), key
setattr(args, key, host[key])
for k in cfg.keys():
recursive_merge(k, cfg)
if args.exhaust_test:
args.test_iterations = args.test_iterations + [i for i in range(0,op.iterations,3000)]
setup_seed(args.seed)
print("Optimizing " + args.model_path)
# Initialize system state (RNG)
safe_state(args.quiet)
torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.start_checkpoint, args.debug_from,
args.gaussian_dim, args.time_duration, args.num_pts, args.num_pts_ratio, args.rot_4d, args.force_sh_3d, args.batch_size)
# All done
print("\nTraining complete.")