# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

import torch
# triton is only available on CUDA installations of PyTorch
import triton
import triton.language as tl

BLOCK_SIZE_BATCH = 4
BLOCK_SIZE_NZ = 8
BLOCK_SIZE_POUT = 8


@triton.jit
def _disco_s2_contraction_kernel(
    inz_ptr,
    vnz_ptr,
    nnz,
    inz_stride_ii,
    inz_stride_nz,
    vnz_stride,
    x_ptr,
    batch_size,
    nlat_in,
    nlon_in,
    x_stride_b,
    x_stride_t,
    x_stride_p,
    y_ptr,
    kernel_size,
    nlat_out,
    nlon_out,
    y_stride_b,
    y_stride_f,
    y_stride_t,
    y_stride_p,
    pscale,
    backward: tl.constexpr,
    BLOCK_SIZE_BATCH: tl.constexpr,
    BLOCK_SIZE_NZ: tl.constexpr,
    BLOCK_SIZE_POUT: tl.constexpr,
):
    """
    Kernel for the sparse-dense contraction for the S2 DISCO convolution.
    """
    pid_batch = tl.program_id(0)
    pid_pout = tl.program_id(2)
    # pid_nz should always be 0, as we do not account for larger grids in this dimension
    pid_nz = tl.program_id(1)
    tl.device_assert(pid_nz == 0)
    # create the pointer block for pout
    pout = pid_pout * BLOCK_SIZE_POUT + tl.arange(0, BLOCK_SIZE_POUT)
    b = pid_batch * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
    # create pointer blocks for the psi datastructure
    iinz = tl.arange(0, BLOCK_SIZE_NZ)
    # get the initial pointers
    fout_ptrs = inz_ptr + iinz * inz_stride_nz
    tout_ptrs = inz_ptr + iinz * inz_stride_nz + inz_stride_ii
    tpnz_ptrs = inz_ptr + iinz * inz_stride_nz + 2 * inz_stride_ii
    vals_ptrs = vnz_ptr + iinz * vnz_stride

    # iterate in a blocked fashion over the non-zero entries
    for offs_nz in range(0, nnz, BLOCK_SIZE_NZ):
        # load the input/output latitude coordinate pairs
        fout = tl.load(
            fout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
        )
        tout = tl.load(
            tout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
        )
        tpnz = tl.load(
            tpnz_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1
        )
        # load the corresponding values
        vals = tl.load(
            vals_ptrs + offs_nz * vnz_stride, mask=(offs_nz + iinz < nnz), other=0.0
        )
        # compute the shifted longitude coordinates p+p' to read in a coalesced fashion
        tnz = tpnz // nlon_in
        pnz = tpnz % nlon_in
        # make sure the values are not out of bounds
        tl.device_assert(fout < kernel_size)
        tl.device_assert(tout < nlat_out)
        tl.device_assert(tnz < nlat_in)
        tl.device_assert(pnz < nlon_in)
        # load corresponding portion of the input array
        x_ptrs = (
            x_ptr
            + tnz[None, :, None] * x_stride_t
            + ((pnz[None, :, None] + pout[None, None, :] * pscale) % nlon_in)
            * x_stride_p
            + b[:, None, None] * x_stride_b
        )
        y_ptrs = (
            y_ptr
            + fout[None, :, None] * y_stride_f
            + tout[None, :, None] * y_stride_t
            + (pout[None, None, :] % nlon_out) * y_stride_p
            + b[:, None, None] * y_stride_b
        )
        # precompute the mask
        mask = (
            (b[:, None, None] < batch_size)
            & (offs_nz + iinz[None, :, None] < nnz)
            & (pout[None, None, :] < nlon_out)
        )
        # do the actual computation. Backward is essentially just the same operation with swapped tensors.
        if not backward:
            x = tl.load(x_ptrs, mask=mask, other=0.0)
            y = vals[None, :, None] * x
            # store it to the output array
            tl.atomic_add(y_ptrs, y, mask=mask)
        else:
            y = tl.load(y_ptrs, mask=mask, other=0.0)
            x = vals[None, :, None] * y
            # store it to the output array
            tl.atomic_add(x_ptrs, x, mask=mask)
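
# Note on the psi layout assumed by the kernel above (an illustration, not executed
# code): psi.indices() is a (3, nnz) COO index array whose rows hold
#
#   row 0: filter index f             in [0, kernel_size)
#   row 1: output latitude t_out      in [0, nlat_out)
#   row 2: flattened input location   t_in * nlon_in + p_in
#
# This is why fout_ptrs, tout_ptrs and tpnz_ptrs are offset by 0, 1 and 2 times
# inz_stride_ii, and why tpnz is split into tnz and pnz by division/modulo with nlon_in.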


def _disco_s2_contraction_fwd(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
| """ | |
| Wrapper function for the triton implementation of the efficient DISCO convolution on the sphere. | |
| Parameters | |
| ---------- | |
| x: torch.Tensor | |
| Input signal on the sphere. Expects a tensor of shape batch_size x channels x nlat_in x nlon_in). | |
| psi : torch.Tensor | |
| Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in). | |
| nlon_out: int | |
| Number of longitude points the output should have. | |
| """ | |
    # check the shapes of all input tensors
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    assert psi.is_sparse, "Psi must be a sparse COO tensor"
    # TODO: check that Psi is also coalesced

    # get the dimensions of the problem
    kernel_size, nlat_out, n_in = psi.shape
    nnz = psi.indices().shape[-1]
    batch_size, n_chans, nlat_in, nlon_in = x.shape
    assert nlat_in * nlon_in == n_in
    # TODO: check that Psi index vector is of type long

    # make sure that the grid-points of the output grid fall onto the grid points of the input grid
    assert nlon_in % nlon_out == 0
    pscale = nlon_in // nlon_out

    # to simplify things, we merge batch and channel dimensions
    x = x.reshape(batch_size * n_chans, nlat_in, nlon_in)

    # prepare the output tensor
    y = torch.zeros(
        batch_size * n_chans,
        kernel_size,
        nlat_out,
        nlon_out,
        device=x.device,
        dtype=x.dtype,
    )

    # determine the grid for the computation
    grid = (
        triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH),
        1,
        triton.cdiv(nlon_out, BLOCK_SIZE_POUT),
    )

    # launch the kernel
    _disco_s2_contraction_kernel[grid](
        psi.indices(),
        psi.values(),
        nnz,
        psi.indices().stride(-2),
        psi.indices().stride(-1),
        psi.values().stride(-1),
        x,
        batch_size * n_chans,
        nlat_in,
        nlon_in,
        x.stride(0),
        x.stride(-2),
        x.stride(-1),
        y,
        kernel_size,
        nlat_out,
        nlon_out,
        y.stride(0),
        y.stride(1),
        y.stride(-2),
        y.stride(-1),
        pscale,
        False,
        BLOCK_SIZE_BATCH,
        BLOCK_SIZE_NZ,
        BLOCK_SIZE_POUT,
    )

    # reshape y back to expose the correct dimensions
    y = y.reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)
    return y
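
# Minimal usage sketch for the wrapper above (assumes a CUDA device and a coalesced
# sparse COO psi; the concrete sizes are illustrative only):
#
#   x = torch.randn(2, 4, 64, 128, device="cuda")        # (batch, chans, nlat_in, nlon_in)
#   # psi: sparse COO tensor of shape (kernel_size, nlat_out, 64 * 128)
#   y = _disco_s2_contraction_fwd(x, psi, nlon_out=64)   # nlon_in % nlon_out == 0
#   # y has shape (2, 4, kernel_size, nlat_out, 64)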


def _disco_s2_contraction_bwd(grad_y: torch.Tensor, psi: torch.Tensor, nlon_in: int):
| """ | |
| Backward pass for the triton implementation of the efficient DISCO convolution on the sphere. | |
| Parameters | |
| ---------- | |
| grad_y: torch.Tensor | |
| Input gradient on the sphere. Expects a tensor of shape batch_size x channels x kernel_size x nlat_out x nlon_out. | |
| psi : torch.Tensor | |
| Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in). | |
| nlon_in: int | |
| Number of longitude points the input used. Is required to infer the correct dimensions | |
| """ | |
    # check the shapes of all input tensors
    assert len(psi.shape) == 3
    assert len(grad_y.shape) == 5
    assert psi.is_sparse, "psi must be a sparse COO tensor"
    # TODO: check that Psi is also coalesced

    # get the dimensions of the problem
    kernel_size, nlat_out, n_in = psi.shape
    nnz = psi.indices().shape[-1]
    assert grad_y.shape[-2] == nlat_out
    assert grad_y.shape[-3] == kernel_size
    assert n_in % nlon_in == 0
    nlat_in = n_in // nlon_in
    batch_size, n_chans, _, _, nlon_out = grad_y.shape

    # make sure that the grid-points of the output grid fall onto the grid points of the input grid
    assert nlon_in % nlon_out == 0
    pscale = nlon_in // nlon_out

    # to simplify things, we merge batch and channel dimensions
    grad_y = grad_y.reshape(batch_size * n_chans, kernel_size, nlat_out, nlon_out)

    # prepare the output tensor
    grad_x = torch.zeros(
        batch_size * n_chans, nlat_in, nlon_in, device=grad_y.device, dtype=grad_y.dtype
    )

    # determine the grid for the computation
    grid = (
        triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH),
        1,
        triton.cdiv(nlon_out, BLOCK_SIZE_POUT),
    )

    # launch the kernel
    _disco_s2_contraction_kernel[grid](
        psi.indices(),
        psi.values(),
        nnz,
        psi.indices().stride(-2),
        psi.indices().stride(-1),
        psi.values().stride(-1),
        grad_x,
        batch_size * n_chans,
        nlat_in,
        nlon_in,
        grad_x.stride(0),
        grad_x.stride(-2),
        grad_x.stride(-1),
        grad_y,
        kernel_size,
        nlat_out,
        nlon_out,
        grad_y.stride(0),
        grad_y.stride(1),
        grad_y.stride(-2),
        grad_y.stride(-1),
        pscale,
        True,
        BLOCK_SIZE_BATCH,
        BLOCK_SIZE_NZ,
        BLOCK_SIZE_POUT,
    )
    # reshape grad_x back to expose the correct dimensions
    grad_x = grad_x.reshape(batch_size, n_chans, nlat_in, nlon_in)
    return grad_x


class _DiscoS2ContractionTriton(torch.autograd.Function):
    """
    Helper class to make the triton implementation work with PyTorch autograd functionality.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
        ctx.save_for_backward(psi)
        ctx.nlon_in = x.shape[-1]
        return _disco_s2_contraction_fwd(x, psi, nlon_out)

    @staticmethod
    def backward(ctx, grad_output):
        (psi,) = ctx.saved_tensors
        grad_input = _disco_s2_contraction_bwd(grad_output, psi, ctx.nlon_in)
        return grad_input, None, None


class _DiscoS2TransposeContractionTriton(torch.autograd.Function):
    """
    Helper class to make the triton implementation work with PyTorch autograd functionality.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
        ctx.save_for_backward(psi)
        ctx.nlon_in = x.shape[-1]
        return _disco_s2_contraction_bwd(x, psi, nlon_out)

    @staticmethod
    def backward(ctx, grad_output):
        (psi,) = ctx.saved_tensors
        grad_input = _disco_s2_contraction_fwd(grad_output, psi, ctx.nlon_in)
        return grad_input, None, None


def _disco_s2_contraction_triton(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    return _DiscoS2ContractionTriton.apply(x, psi, nlon_out)


def _disco_s2_transpose_contraction_triton(
    x: torch.Tensor, psi: torch.Tensor, nlon_out: int
):
    return _DiscoS2TransposeContractionTriton.apply(x, psi, nlon_out)
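
# Autograd sketch (assumes CUDA, a working triton install, and a coalesced sparse
# COO psi; tensor contents are placeholders):
#
#   x = torch.randn(2, 4, 64, 128, device="cuda", requires_grad=True)
#   y = _disco_s2_contraction_triton(x, psi, nlon_out=64)
#   y.sum().backward()      # gradients flow through _disco_s2_contraction_bwd
#   assert x.grad.shape == x.shape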


def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in Triton.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    psi = psi.to(x.device)

    batch_size, n_chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
    assert nlon_in >= nlat_out
    pscale = nlon_in // nlon_out

    # add a dummy dimension for nkernel and move the batch and channel dims to the end
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)

    y = torch.zeros(
        nlon_out,
        kernel_size,
        nlat_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )

    for pout in range(nlon_out):
        # sparse contraction with psi
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        x = torch.roll(x, -pscale, dims=2)
    # reshape y back to expose the correct dimensions
    y = y.permute(3, 1, 2, 0).reshape(
        batch_size, n_chans, kernel_size, nlat_out, nlon_out
    )
    return y
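
# Consistency sketch (assumes CUDA; the atomic adds in the triton kernel make the
# comparison exact only up to floating-point accumulation order):
#
#   y_ref = _disco_s2_contraction_torch(x, psi, nlon_out)
#   y_tri = _disco_s2_contraction_triton(x, psi, nlon_out)
#   assert torch.allclose(y_ref, y_tri, atol=1e-5)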


def _disco_s2_transpose_contraction_torch(
    x: torch.Tensor, psi: torch.Tensor, nlon_out: int
):
    """
    Reference implementation of the custom transpose contraction as described in [1]. This requires
    repeated shifting of the input tensor, which can potentially be costly. For an efficient
    implementation on GPU, make sure to use the custom kernel written in Triton.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 5
    psi = psi.to(x.device)

    batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
    kernel_size, _, n_out = psi.shape

    assert psi.shape[-2] == nlat_in
    assert n_out % nlon_out == 0
    nlat_out = n_out // nlon_out
    assert nlon_out >= nlat_in
    pscale = nlon_out // nlon_in
    # we do a semi-transposition to facilitate the computation
    inz = psi.indices()
    tout = inz[2] // nlon_out
    pout = inz[2] % nlon_out
    # flip the axis of longitudes
    pout = nlon_out - 1 - pout
    tin = inz[1]
    inz = torch.stack([inz[0], tout, tin * nlon_out + pout], dim=0)
    psi_mod = torch.sparse_coo_tensor(
        inz, psi.values(), size=(kernel_size, nlat_out, nlat_in * nlon_out)
    )
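    # Illustration of the remapping above: an entry of psi at COO index
    #   (f, t_in, t_out * nlon_out + p_out)
    # moves to
    #   (f, t_out, t_in * nlon_out + (nlon_out - 1 - p_out))
    # in psi_mod, so the contraction below can run over input latitudes while the
    # longitude flip is absorbed into the index.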
    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
    x_ext = torch.zeros(
        kernel_size,
        nlat_in,
        nlon_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )
    x_ext[:, :, ::pscale, :] = x.reshape(
        batch_size * n_chans, kernel_size, nlat_in, nlon_in
    ).permute(1, 2, 3, 0)
    # the flip along the longitude axis has already been absorbed into psi_mod above;
    # make x_ext contiguous for the repeated rolls below
    x_ext = x_ext.contiguous()
    y = torch.zeros(
        kernel_size,
        nlon_out,
        nlat_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )
    for pout in range(nlon_out):
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        # TODO: double-check why this has to happen first
        x_ext = torch.roll(x_ext, -1, dims=2)
        # sparse contraction with the modified psi
        y[:, pout, :, :] = torch.bmm(
            psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)
        )

    # sum over the kernel dimension and reshape to the correct output size
    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)
    return y
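

# Hand-rolled smoke test (a sketch, not part of the library API): build a tiny
# random sparse psi with the index layout documented above and compare the triton
# forward against the pure-torch reference when CUDA is available.
if __name__ == "__main__":
    kernel_size, nlat_in, nlon_in, nlat_out, nlon_out = 2, 8, 16, 4, 8
    nnz = 32
    # random COO indices: (filter, output latitude, flattened input location)
    inz = torch.stack(
        [
            torch.randint(0, kernel_size, (nnz,)),
            torch.randint(0, nlat_out, (nnz,)),
            torch.randint(0, nlat_in * nlon_in, (nnz,)),
        ]
    )
    psi = torch.sparse_coo_tensor(
        inz, torch.randn(nnz), size=(kernel_size, nlat_out, nlat_in * nlon_in)
    ).coalesce()
    x = torch.randn(2, 3, nlat_in, nlon_in)
    y_ref = _disco_s2_contraction_torch(x, psi, nlon_out)
    print("reference output shape:", tuple(y_ref.shape))
    if torch.cuda.is_available():
        y_tri = _disco_s2_contraction_triton(x.cuda(), psi.cuda(), nlon_out)
        print("max abs difference:", (y_tri.cpu() - y_ref).abs().max().item())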