File size: 3,525 Bytes
11f5b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import random
import torch
from torch import nn
from utils.loss_utils import l1_loss, ssim, msssim
from gaussian_renderer import render
import sys
from scene import Scene, GaussianModel
from utils.general_utils import safe_state, knn
import uuid
from tqdm import tqdm
from utils.image_utils import psnr, easy_cmap
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
from torchvision.utils import make_grid
import numpy as np
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.utils.data import DataLoader
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False
from PIL import Image
import torchvision.transforms as T

def calc_gradient(dataset, opt, pipe, scene, gaussians, batch_size, bg_color, background):

    training_dataset = scene.getTrainCameras()
    training_dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, num_workers=12 if dataset.dataloader else 0, collate_fn=lambda x: x, drop_last=True)
     
    
    ###grad ์ธก์ •
    N = gaussians._xyz.shape[0]
    train_cameras = scene.train_cameras[1.0]
    timestamps = sorted(set(cam.timestamp for cam in train_cameras))
    T = len(timestamps)

    training_dataset = scene.getTrainCameras()
    viewspace_grad = torch.zeros((N, T), dtype=torch.float32, device='cuda')
    t_grad = torch.zeros((N, T), dtype=torch.float32, device='cuda')

    for idx in tqdm(range(len(training_dataset)), desc="Computing Gradients"):
        gt_image, viewpoint_cam = training_dataset[idx]
        gt_image= gt_image.cuda()
        viewpoint_cam = viewpoint_cam.cuda()
        timestamp = viewpoint_cam.timestamp
        index = timestamps.index(timestamp)


        render_pkg = render(viewpoint_cam, gaussians, pipe, background)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
        depth = render_pkg["depth"]
        alpha = render_pkg["alpha"]

        # Loss
        Ll1 = l1_loss(image, gt_image)
        Lssim = 1.0 - ssim(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * Lssim

        loss = loss / batch_size
        loss.backward()


        batch_point_grad = (torch.norm(viewspace_point_tensor.grad[:,:2], dim=-1)) ##batch 1์ผ๋•Œ๋Š” grad:2 ์•ˆํ•˜๋Š”๋ฐ ํ™•์ธ
        viewspace_grad[:, index] += batch_point_grad
        t_grad[:, index]  += gaussians._t.grad.clone().detach().squeeze(1)  # abs ์ถ”๊ฐ€


    final_view_grad = viewspace_grad.sum(dim = 1)
    final_t_grad = t_grad.sum(dim = 1)
    
    if torch.is_tensor(final_view_grad):
        final_view_grad = final_view_grad.detach().cpu().numpy()
    if torch.is_tensor(final_t_grad):
        final_t_grad = final_t_grad.detach().cpu().numpy()

    os.makedirs(os.path.join(scene.model_path, "gradient"), exist_ok=True)

    np.save(os.path.join(scene.model_path, "gradient/view_grad.npy"), final_view_grad)
    np.save(os.path.join(scene.model_path, "gradient/t_grad.npy"), final_t_grad)

    return final_view_grad, final_t_grad