import torch
from torch import nn
from torchvision.models import resnet34, resnet50
import torchvision.models.vision_transformer as vit


class LightSourceRegressor(nn.Module):
    """Regresses up to `num_lights` circular light sources from an image.

    Each light is predicted as (x, y, r, p): normalized center coordinates,
    a normalized radius, and a presence probability p in [0, 1].
    """

    def __init__(self, num_lights=4, alpha=2.0, beta=8.0, **kwargs):
        super().__init__()
        self.num_lights = num_lights
        self.alpha = alpha
        self.beta = beta
        # ImageNet-pretrained backbone.
        self.model = resnet34(weights="IMAGENET1K_V1")
        # self.model = resnet50(weights="IMAGENET1K_V1")
        # self.model = vit.vit_b_16(weights="IMAGENET1K_V1")
        self.init_resnet()
        # self.init_vit()
        self.xyr_mlp = nn.Sequential(
            nn.Linear(self.last_dim, 3 * self.num_lights),
        )
        self.p_mlp = nn.Sequential(
            nn.Linear(self.last_dim, self.num_lights),
            nn.Sigmoid(),  # ensure p is in [0, 1]
        )

    def init_resnet(self):
        # Strip the classification head so the backbone returns features.
        self.last_dim = self.model.fc.in_features
        self.model.fc = nn.Identity()

    def init_vit(self):
        # Adapt a 224x224 ViT to 512x512 inputs by resizing its positional
        # embedding (1-D linear interpolation over the flattened patch grid).
        self.model.image_size = 512
        old_pos_embed = self.model.encoder.pos_embedding
        num_patches_old = (224 // 16) ** 2
        num_patches_new = (512 // 16) ** 2
        if num_patches_new != num_patches_old:
            old_pos_embed = old_pos_embed[:, 1:]  # drop the class token
            old_pos_embed = nn.functional.interpolate(
                old_pos_embed.permute(0, 2, 1), size=(num_patches_new,), mode="linear"
            )
            old_pos_embed = old_pos_embed.permute(0, 2, 1)
            # Re-attach the class-token embedding in front of the resized grid.
            self.model.encoder.pos_embedding = nn.Parameter(
                torch.cat(
                    [self.model.encoder.pos_embedding[:, :1], old_pos_embed], dim=1
                )
            )
        # num_classes = 4 * self.num_lights  # x, y, r, p
        # self.model.heads.head = nn.Linear(self.model.hidden_dim, num_classes)
        # Remove the head; the backbone returns the class-token features.
        self.last_dim = self.model.hidden_dim
        self.model.heads.head = nn.Identity()

    def forward(self, x):
        _x = self.model(x)  # [B, last_dim]
        _xyr = self.xyr_mlp(_x).view(-1, self.num_lights, 3)
        _p = self.p_mlp(_x).view(-1, self.num_lights)
        return torch.cat([_xyr, _p.unsqueeze(-1)], dim=-1)  # [B, num_lights, 4]

    def forward_render(self, x, height=512, width=512, smoothness=0.1):
        """Render the predicted lights as one soft mask per image."""
        preds = self.forward(x)
        _xy = preds[:, :, :2]
        _r = preds[:, :, 2]
        _p = preds[:, :, 3]
        # The pixel grid is identical for every light, so build it once.
        y_coords, x_coords = torch.meshgrid(
            torch.arange(height, device=x.device),
            torch.arange(width, device=x.device),
            indexing="ij",
        )
        masks_merge = None
        for b in range(preds.size(0)):
            # Denormalize predictions to pixel units (assumes roughly square inputs).
            cx = _xy[b, :, 0] * width
            cy = _xy[b, :, 1] * height
            r = _r[b] * width / 2
            p = _p[b]
            mask_list = []
            for i in range(self.num_lights):
                # Skip degenerate radii and lights predicted as absent.
                if r[i] < 0 or r[i] > width or p[i] < 0.5:
                    continue
                distances = torch.sqrt(
                    (x_coords - cx[i]) ** 2 + (y_coords - cy[i]) ** 2
                )
                # Soft disk: ~1 inside the radius, ~0 outside.
                mask_i = torch.sigmoid(smoothness * (r[i] - distances))
                mask_list.append(mask_i)
            if len(mask_list) == 0:
                _mask_merge = torch.zeros(1, 1, height, width, device=x.device)
            else:
                _mask_merge = torch.stack(mask_list, dim=0).sum(dim=0)
                _mask_merge = _mask_merge.view(1, 1, height, width)
            masks_merge = (
                _mask_merge
                if masks_merge is None
                else torch.cat([masks_merge, _mask_merge], dim=0)
            )
        masks_merge = torch.clamp(masks_merge, 0, 1)
        return masks_merge  # [B, 1, H, W]


if __name__ == "__main__":
    model = LightSourceRegressor(num_lights=4).cuda()
    x = torch.randn(8, 3, 512, 512, device="cuda")
    y = model.forward_render(x)
    print(y.shape)  # torch.Size([8, 1, 512, 512])
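
    # -----------------------------------------------------------------------
    # Minimal supervision sketch (an assumption, not part of the original
    # training code): the rendered masks lie in [0, 1], so they can be
    # compared against ground-truth light masks with a pixel-wise BCE loss.
    # `gt_masks` is a hypothetical target, filled randomly here only so the
    # snippet runs. Note the hard p < 0.5 gate in forward_render: if every
    # light is gated off, the output is a constant zero mask with no gradient
    # path, hence the requires_grad guard before backward().
    # -----------------------------------------------------------------------
    gt_masks = (torch.rand_like(y) > 0.9).float()
    loss = nn.functional.binary_cross_entropy(y, gt_masks)
    if loss.requires_grad:
        loss.backward()
    print(loss.item())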