Spaces:
Runtime error
Runtime error
| # flake8: noqa | |
| from torch.autograd import Function, Variable | |
| from torch.nn.modules.module import Module | |
| import channelnorm_cuda | |
class ChannelNormFunction(Function):
    """Autograd function wrapping the ``channelnorm_cuda`` extension.

    ``forward`` produces a single-channel map from a 4-D input and saves what
    ``backward`` needs; ``backward`` delegates the gradient computation to the
    same extension. The exact reduction is implemented in the extension and is
    not visible here -- presumably a per-position ``norm_deg``-norm across the
    channel dimension (NOTE(review): confirm against the kernel).
    """

    @staticmethod
    def forward(ctx, input1, norm_deg=2):
        """Run the extension's forward pass.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input1: contiguous 4-D tensor; its size is unpacked as
                ``(b, _, h, w)``.
            norm_deg: degree passed through to the extension (default 2).

        Returns:
            A tensor of shape ``(b, 1, h, w)`` with ``input1``'s dtype and
            device, presumably filled in place by ``channelnorm_cuda.forward``.

        Raises:
            AssertionError: if ``input1`` is not contiguous (note that the
                check is stripped when Python runs with ``-O``).
        """
        # The extension receives raw tensors; it presumably relies on
        # contiguous memory, hence the guard.
        assert input1.is_contiguous()
        b, _, h, w = input1.size()
        # new_zeros keeps input1's dtype/device; replaces the legacy
        # ``new(...).zero_()`` allocation idiom.
        output = input1.new_zeros((b, 1, h, w))
        channelnorm_cuda.forward(input1, output, norm_deg)
        ctx.save_for_backward(input1, output)
        ctx.norm_deg = norm_deg
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the extension's backward pass.

        Args:
            ctx: autograd context holding ``input1``, ``output`` and
                ``norm_deg`` saved by ``forward``.
            grad_output: gradient with respect to ``forward``'s output.

        Returns:
            ``(grad_input1, None)``; ``None`` because ``norm_deg`` is a plain
            Python value and receives no gradient.
        """
        input1, output = ctx.saved_tensors
        # Zero-initialized buffer the extension presumably writes into. The
        # deprecated ``Variable`` wrapper is unnecessary: tensors carry
        # autograd state themselves.
        grad_input1 = input1.new_zeros(input1.size())
        channelnorm_cuda.backward(input1, output, grad_output.data,
                                  grad_input1, ctx.norm_deg)
        return grad_input1, None
class ChannelNorm(Module):
    """``Module`` front-end for :class:`ChannelNormFunction`.

    Args:
        norm_deg: degree handed to ``ChannelNormFunction.apply`` on every
            call (default 2).
    """

    def __init__(self, norm_deg=2):
        super().__init__()
        # Kept as a plain attribute, not a registered parameter or buffer.
        self.norm_deg = norm_deg

    def forward(self, input1):
        """Apply ``ChannelNormFunction`` to ``input1`` using the stored degree."""
        degree = self.norm_deg
        return ChannelNormFunction.apply(input1, degree)