# NOTE: Hugging Face Spaces page banner ("Running on Zero") removed from scraped source.
import math
from typing import Any

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
from torch.nn.init import zeros_

import triton
import triton.language as tl
@triton.jit
def forward_kernel(
    B: tl.constexpr,
    H: tl.constexpr,  # image_size_h
    W: tl.constexpr,  # image_size_w
    G: tl.constexpr,  # num_groups (gid strides the group axis below)
    C: tl.constexpr,  # num_channels_per_group (block loop runs over C)
    K: tl.constexpr,  # number of sampling points per group
    input_ptr,        # input features [B, H, W, G, C]
    deformable_ptr,   # deformable offsets [B, H, W, G, K, 2] (x, y pairs)
    weights_ptr,      # per-point weights [B, H, W, G, K]
    out_ptr,          # out [B, H, W, G, C]
    BLOCK_SIZE: tl.constexpr = 64,  # micro block of channels processed per step
):
    """One program per (b, h, w, g): bilinearly sample `input` at K offset
    locations and accumulate the weighted samples into `out`.

    out[b,h,w,g,:] = sum_k weights[b,h,w,g,k] * bilinear(input[b,:,:,g,:],
                                                         (w,h) + offsets[b,h,w,g,k])
    Out-of-image corners contribute zero (masked loads).
    """
    pid = tl.program_id(0)
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)
    id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
    common_offset = bid * H * W * G + hid * W * G + wid * G + gid
    batch_base = bid * H * W * G * C
    for block_base in tl.static_range(0, C, BLOCK_SIZE):
        buffer = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        block_offset = tl.arange(0, BLOCK_SIZE) + block_base
        block_mask = (block_offset < C) & id_mask
        for k in tl.static_range(K):
            deformable_offset = (common_offset * K + k) * 2
            # absolute sampling coordinates = pixel position + learned offset
            x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
            y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
            # BUGFIX: int cast truncates toward zero, which is NOT floor for
            # negative coordinates (e.g. -0.5 -> 0) and yields bilinear weights
            # outside [0, 1]; compute a true floor instead.
            trunc_x = x.to(tl.int32)
            trunc_y = y.to(tl.int32)
            floor_x = tl.where(x < trunc_x, trunc_x - 1, trunc_x)
            floor_y = tl.where(y < trunc_y, trunc_y - 1, trunc_y)
            ceil_x = floor_x + 1
            ceil_y = floor_y + 1
            # top-left corner
            tl_weight = (ceil_x - x) * (ceil_y - y)
            tl_block_offset = batch_base + floor_y * W * G * C + floor_x * G * C + gid * C
            tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
            # top-right corner
            tr_weight = (x - floor_x) * (ceil_y - y)
            tr_block_offset = batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C
            tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
            # bottom-left corner
            bl_weight = (ceil_x - x) * (y - floor_y)
            bl_block_offset = batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C
            bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
            # bottom-right corner
            br_weight = (x - floor_x) * (y - floor_y)
            br_block_offset = batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C
            br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
            # dynamic per-point weight
            weights_offset = common_offset * K + k
            weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
            # load the four corners (out-of-bounds corners read as 0)
            tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
            tl_block_input = tl_block_input * tl_weight
            tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
            tr_block_input = tr_block_input * tr_weight
            bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
            bl_block_input = bl_block_input * bl_weight
            br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
            br_block_input = br_block_input * br_weight
            # bilinear blend, then apply the dynamic weight
            sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
            buffer = buffer + sampled_input * weight
        # store the accumulated channel block
        tl.store(out_ptr + common_offset * C + block_offset, buffer, mask=block_mask)
@triton.jit
def backward_kernel(
    B: tl.constexpr,
    H: tl.constexpr,  # image_size_h
    W: tl.constexpr,  # image_size_w
    G: tl.constexpr,  # num_groups
    C: tl.constexpr,  # num_channels_per_group
    K: tl.constexpr,  # number of sampling points per group
    input_ptr,            # input features [B, H, W, G, C]
    deformable_ptr,       # deformable offsets [B, H, W, G, K, 2]
    weights_ptr,          # weights [B, H, W, G, K]
    grad_ptr,             # upstream gradient w.r.t. out [B, H, W, G, C]
    grad_input_ptr,       # gradient w.r.t. input features [B, H, W, G, C]
    grad_deformable_ptr,  # gradient w.r.t. offsets [B, H, W, G, K, 2]
    grad_weights_ptr,     # gradient w.r.t. weights [B, H, W, G, K]
    BLOCK_SIZE: tl.constexpr = 64,  # micro block of channels processed per step
):
    """Backward of `forward_kernel`: one program per (b, h, w, g).

    Accumulates d loss/d input via atomic adds (several sampling points may
    touch the same pixel) and writes the per-point offset/weight gradients.
    """
    pid = tl.program_id(0)
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)
    id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
    common_offset = bid * H * W * G + hid * W * G + wid * G + gid
    batch_base = bid * H * W * G * C
    for k in tl.static_range(K):
        # dynamic per-point weight
        weights_offset = common_offset * K + k
        weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
        # scalar accumulators for this sampling point (summed over channels)
        dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
        dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
        dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty)
        deformable_offset = (common_offset * K + k) * 2
        x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
        y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
        for block_base in tl.static_range(0, C, BLOCK_SIZE):
            block_offset = tl.arange(0, BLOCK_SIZE) + block_base
            block_mask = (block_offset < C) & id_mask
            grad = tl.load(grad_ptr + common_offset * C + block_offset, mask=block_mask, other=0.0)
            # d out/d sampled = weight, so d loss/d sampled = weight * grad
            dods = weight * grad
            # BUGFIX: int cast truncates toward zero, which is NOT floor for
            # negative coordinates; use a true floor to match forward_kernel.
            trunc_x = x.to(tl.int32)
            trunc_y = y.to(tl.int32)
            floor_x = tl.where(x < trunc_x, trunc_x - 1, trunc_x)
            floor_y = tl.where(y < trunc_y, trunc_y - 1, trunc_y)
            ceil_x = floor_x + 1
            ceil_y = floor_y + 1
            # top-left corner: w = (ceil_x - x)(ceil_y - y); dw/dx = -(ceil_y - y); dw/dy = -(ceil_x - x)
            tl_weight = (ceil_x - x) * (ceil_y - y)
            tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset
            tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
            tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
            tl_block_input_dot_grad = tl.sum(tl_block_input * grad, axis=0)
            dodx = dodx - tl_block_input_dot_grad * (ceil_y - y)
            dody = dody - tl_block_input_dot_grad * (ceil_x - x)
            dodw = dodw + tl_block_input_dot_grad * tl_weight
            tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dods * tl_weight)
            # top-right corner: w = (x - floor_x)(ceil_y - y)
            tr_weight = (x - floor_x) * (ceil_y - y)
            tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
            tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
            tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
            tr_block_input_dot_grad = tl.sum(tr_block_input * grad, axis=0)
            dodx = dodx + tr_block_input_dot_grad * (ceil_y - y)
            dody = dody - tr_block_input_dot_grad * (x - floor_x)
            dodw = dodw + tr_block_input_dot_grad * tr_weight
            tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dods * tr_weight)
            # bottom-left corner: w = (ceil_x - x)(y - floor_y)
            bl_weight = (ceil_x - x) * (y - floor_y)
            bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset
            bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
            bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
            bl_block_input_dot_grad = tl.sum(bl_block_input * grad, axis=0)
            dodx = dodx - bl_block_input_dot_grad * (y - floor_y)
            dody = dody + bl_block_input_dot_grad * (ceil_x - x)
            dodw = dodw + bl_block_input_dot_grad * bl_weight
            tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dods * bl_weight)
            # bottom-right corner: w = (x - floor_x)(y - floor_y)
            # (dropped the redundant `* br_block_mask` factor: the masked load
            # already yields 0 for out-of-bounds corners, matching the other three)
            br_weight = (x - floor_x) * (y - floor_y)
            br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
            br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
            br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
            br_block_input_dot_grad = tl.sum(br_block_input * grad, axis=0)
            dodx = dodx + br_block_input_dot_grad * (y - floor_y)
            dody = dody + br_block_input_dot_grad * (x - floor_x)
            dodw = dodw + br_block_input_dot_grad * br_weight
            tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dods * br_weight)
        # chain rule: out = weight * sampled, so the coordinate gradients carry `weight`
        dodx = dodx * weight
        dody = dody * weight
        tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
        tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
        tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)
class DCNFunction(Function):
    """Autograd bridge to the Triton deformable-sampling kernels.

    inputs:      [B, H, W, G, C] value features (assumed contiguous — TODO confirm)
    deformables: [B, H, W, G, K, 2] sampling offsets
    weights:     [B, H, W, G, K] per-point weights
    returns:     [B, H, W, G, C]
    """

    # Channel tile size forwarded to both kernels (they process C in chunks).
    BLOCK_SIZE = 64

    # BUGFIX: autograd.Function requires forward/backward to be static methods;
    # also apply the amp custom_fwd/custom_bwd decorators that were imported
    # at the top of the file but never used.
    @staticmethod
    @custom_fwd
    def forward(ctx: Any, inputs, deformables, weights) -> Any:
        B, H, W, G, C = inputs.shape
        _, _, _, _, K, _ = deformables.shape
        out = torch.zeros_like(inputs)
        # one program per (batch, row, col, group)
        grid = lambda META: (B * H * W * G,)
        forward_kernel[grid](
            B, H, W, G, C, K,
            inputs, deformables, weights, out,
            BLOCK_SIZE=DCNFunction.BLOCK_SIZE,  # BUGFIX: was never passed
        )
        ctx.save_for_backward(inputs, deformables, weights)
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        grad_output = grad_outputs[0].contiguous()
        inputs, deformables, weights = ctx.saved_tensors
        B, H, W, G, C = inputs.shape
        _, _, _, _, K, _ = deformables.shape
        # grad_inputs must start zeroed: the kernel accumulates via atomic adds
        grad_inputs = torch.zeros_like(inputs)
        grad_deformables = torch.zeros_like(deformables)
        grad_weights = torch.zeros_like(weights)
        grid = lambda META: (B * H * W * G,)
        backward_kernel[grid](
            B, H, W, G, C, K,
            inputs,
            deformables,
            weights,
            grad_output,
            grad_inputs,
            grad_deformables,
            grad_weights,
            BLOCK_SIZE=DCNFunction.BLOCK_SIZE,  # BUGFIX: was never passed
        )
        return (grad_inputs, grad_deformables, grad_weights)
class MultiScaleDCN(nn.Module):
    """Multi-scale deformable-convolution attention layer.

    Projects the input to per-group values, per-point deformable offsets,
    per-point scale logits and per-point weights, then samples the value map
    at the offset locations through the Triton DCN kernels.

    Args:
        in_channels: size of the trailing feature dimension of the input.
        groups: number of sampling groups G.
        channels: channels per group C (value projection is groups * channels).
        kernels: sampling points per group K.
        deformable_biass: unused; kept for interface compatibility.
          NOTE(review): the offset projection always uses bias=True — confirm
          whether this flag was meant to control it.
    """

    def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True):
        super().__init__()
        self.in_channels = in_channels
        self.groups = groups
        self.channels = channels
        self.kernels = kernels
        self.v = nn.Linear(in_channels, groups * channels, bias=True)
        self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True)
        self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
        self.qk_weights = nn.Linear(in_channels, groups * kernels, bias=True)
        self.out = nn.Linear(groups * channels, in_channels)
        # Fixed (non-trainable) grid prior added to the predicted offsets.
        self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
        # Learnable per-group scale logit added to the predicted scale logits.
        self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
        self.max_scale = 6  # offsets are bounded to sigmoid(logit) * max_scale pixels
        self._init_weights()

    def _init_weights(self):
        """Zero the offset/scale/weight projections and build the grid prior.

        Zero-initialized projections make the layer start as a plain grid
        sampler (offsets come purely from the prior) with uniform weights.
        """
        zeros_(self.qk_deformables.weight.data)
        zeros_(self.qk_scales.weight.data)
        zeros_(self.qk_deformables.bias.data)
        zeros_(self.qk_weights.weight.data)
        zeros_(self.v.bias.data)
        zeros_(self.out.bias.data)
        num_prior = int(self.kernels ** 0.5)
        # BUGFIX: build the prior grid on the parameter's own device instead of
        # hard-coding "cuda", so the module can be constructed on CPU-only hosts.
        device = self.deformables_prior.device
        dx = torch.linspace(-1, 1, num_prior, device=device)
        dy = torch.linspace(-1, 1, num_prior, device=device)
        dxy = torch.meshgrid([dx, dy], indexing="xy")
        dxy = torch.stack(dxy, dim=-1)
        dxy = dxy.view(-1, 2)
        # NOTE(review): if `kernels` is not a perfect square, the trailing
        # kernels - num_prior**2 points keep their random init — confirm intended.
        self.deformables_prior.data[..., :num_prior * num_prior, :] = dxy
        for i in range(self.groups):
            # Spread the group scale logits so sigmoid(logit) ≈ (i+1)/groups
            # (the -0.0001 keeps the inverse-sigmoid finite at scale = 1).
            scale = (i + 1) / self.groups - 0.0001
            inv_scale = math.log(scale / (1 - scale))
            self.deformables_scale.data[..., i, :, :] = inv_scale

    def forward(self, x):
        """x: [B, H, W, in_channels] -> [B, H, W, in_channels]."""
        B, H, W, _ = x.shape
        v = self.v(x).view(B, H, W, self.groups, self.channels)
        deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
        scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
        # bounded offsets: (predicted + prior) * sigmoid(scale logit) * max_scale
        deformables = (deformables + self.deformables_prior) * scale.sigmoid() * self.max_scale
        weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
        out = DCNFunction.apply(v, deformables, weights)
        out = out.view(B, H, W, -1)
        out = self.out(out)
        return out