import torch
from torch import nn
from torchvision.models import resnet34, resnet50
import torchvision.models.vision_transformer as vit


class LightSourceRegressor(nn.Module):
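    """Regress up to `num_lights` light sources from an image.

    For each candidate light the model predicts a center (x, y) and radius r
    (interpreted as fractions of the image size) plus a presence probability p.
    `forward_render` rasterizes the accepted lights into a soft merged mask.
    """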
    def __init__(self, num_lights=4, alpha=2.0, beta=8.0, **kwargs):
        super().__init__()

        self.num_lights = num_lights
        # alpha/beta are stored for external consumers; they are not used
        # inside this module.
        self.alpha = alpha
        self.beta = beta

        # `weights=` replaces the deprecated `pretrained=True` argument
        # (torchvision >= 0.13).
        self.model = resnet34(weights="IMAGENET1K_V1")
        # self.model = resnet50(weights="IMAGENET1K_V1")
        # self.model = vit.vit_b_16(weights="IMAGENET1K_V1")
        self.init_resnet()
        # self.init_vit()  # use together with the vit_b_16 backbone above

        # Per-light (x, y, r); forward_render interprets these as fractions
        # of the image size.
        self.xyr_mlp = nn.Sequential(
            nn.Linear(self.last_dim, 3 * self.num_lights),
        )
        # Per-light presence probability.
        self.p_mlp = nn.Sequential(
            nn.Linear(self.last_dim, self.num_lights),
            nn.Sigmoid(),  # ensure p is in [0, 1]
        )

    def init_resnet(self):
        # Drop the classification head and expose the pooled features.
        self.last_dim = self.model.fc.in_features
        self.model.fc = nn.Identity()

    def init_vit(self):
        # vit_b_16 is pretrained at 224x224 with 16x16 patches; resize the
        # positional embedding so the encoder accepts 512x512 inputs.
        self.model.image_size = 512
        old_pos_embed = self.model.encoder.pos_embedding  # [1, 1 + 196, hidden]
        num_patches_old = (224 // 16) ** 2
        num_patches_new = (512 // 16) ** 2

        if num_patches_new != num_patches_old:
            # Keep the class-token embedding; interpolate the patch embeddings.
            # (1D linear interpolation over the flattened sequence; 2D bicubic
            # interpolation over the patch grid would be the more common choice.)
            old_pos_embed = old_pos_embed[:, 1:]
            old_pos_embed = nn.functional.interpolate(
                old_pos_embed.permute(0, 2, 1), size=(num_patches_new,), mode="linear"
            )
            old_pos_embed = old_pos_embed.permute(0, 2, 1)

            # new positional embedding
            self.model.encoder.pos_embedding = nn.Parameter(
                torch.cat(
                    [self.model.encoder.pos_embedding[:, :1], old_pos_embed], dim=1
                )
            )

        # Drop the classification head and expose the CLS features.
        self.last_dim = self.model.hidden_dim
        self.model.heads.head = nn.Identity()

    def forward(self, x):
        _x = self.model(x)  # [B, last_dim]

        _xyr = self.xyr_mlp(_x)
        _xyr = _xyr.view(-1, self.num_lights, 3)

        _p = self.p_mlp(_x)
        _p = _p.view(-1, self.num_lights)

        # [B, num_lights, 4] with channels (x, y, r, p)
        output = torch.cat([_xyr, _p.unsqueeze(-1)], dim=-1)

        return output

    def forward_render(self, x, height=512, width=512, smoothness=0.1):
        preds = self.forward(x)  # [B, num_lights, 4]

        _xy = preds[:, :, :2]
        _r = preds[:, :, 2]
        _p = preds[:, :, 3]

        # Pixel coordinate grid, shared by every light and every batch element.
        y_coords, x_coords = torch.meshgrid(
            torch.arange(height, device=x.device),
            torch.arange(width, device=x.device),
            indexing="ij",
        )

        mask_batch = []
        for b in range(preds.size(0)):
            # Scale normalized predictions to pixels (x by width, y by height).
            cx = _xy[b, :, 0] * width
            cy = _xy[b, :, 1] * height
            rad = _r[b] * width / 2
            p = _p[b]

            mask_list = []
            for i in range(self.num_lights):
                # Skip out-of-range radii and lights predicted as absent.
                if rad[i] < 0 or rad[i] > width or p[i] < 0.5:
                    continue

                distances = torch.sqrt(
                    (x_coords - cx[i]) ** 2 + (y_coords - cy[i]) ** 2
                )
                # Soft disk: sigmoid ramps from ~1 inside the circle to ~0 outside.
                mask_i = torch.sigmoid(smoothness * (rad[i] - distances))
                mask_list.append(mask_i)

            if len(mask_list) == 0:
                mask_merge = torch.zeros(1, 1, height, width, device=x.device)
            else:
                mask_merge = (
                    torch.stack(mask_list, dim=0).sum(dim=0).view(1, 1, height, width)
                )

            mask_batch.append(mask_merge)

        masks_merge = torch.clamp(torch.cat(mask_batch, dim=0), 0, 1)

        return masks_merge  # [B, 1, H, W]


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = LightSourceRegressor(num_lights=4).to(device)
    x = torch.randn(8, 3, 512, 512, device=device)
    y = model.forward_render(x)
    print(y.shape)  # torch.Size([8, 1, 512, 512])
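
    # Supervision sketch (an assumption, not from the original source): the
    # rendered soft mask is differentiable, so it could be trained against
    # ground-truth binary light-source masks with a pixel-wise BCE. The random
    # `target` below is a stand-in for real annotations.
    target = (torch.rand_like(y) > 0.5).float()
    loss = nn.functional.binary_cross_entropy(y, target)
    loss.backward()
    print(f"example BCE loss: {loss.item():.4f}")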