Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
| import os | |
| import torch | |
| from .. import custom_ops | |
# Cached result of ``custom_ops.get_plugin`` for the ``nerf_utils_plugin``
# extension; stays None until ``_init()`` has been called successfully.
_plugin = None


def _init():
    """Lazily obtain the ``nerf_utils_plugin`` extension and cache it.

    On the first call this asks ``custom_ops.get_plugin`` for the extension
    built from ``nerf_utils.cu`` (header ``utils.h``), located in this file's
    directory, passing ``--use_fast_math`` as an extra CUDA flag.  The result
    is stored in the module-level ``_plugin``.  Subsequent calls return
    immediately and reuse the cached handle.

    Returns:
        Always ``True``.
    """
    global _plugin
    if _plugin is not None:
        # Already obtained on an earlier call; nothing to do.
        return True
    here = os.path.dirname(__file__)
    _plugin = custom_ops.get_plugin(
        module_name='nerf_utils_plugin',
        sources=['nerf_utils.cu'],
        headers=['utils.h'],
        source_dir=here,
        extra_cuda_cflags=['--use_fast_math'],
    )
    return True
def topp_masking(w, p=0.99):
    """Return a boolean mask selecting the top-p (nucleus) entries of ``w``.

    Entries are ranked along the last dimension in descending order of
    weight (ties ordered as ``torch.sort`` returns them).  An entry is kept
    when the cumulative weight of all entries ranked strictly above it is
    below ``p``.  Consequently, the highest-ranked entry is always kept.

    Args:
        w: weights, expected to be normalized along the last dimension,
            e.g. shape B x N x S (S = number of samples).  Only the last
            dimension is used for ranking.
        p: cumulative-weight threshold (the "top-p" value).

    Returns:
        A ``torch.bool`` tensor with the same shape as ``w``, expressed in
        ``w``'s original (unsorted) order.

    Note:
        A CUDA-kernel variant (``_plugin.topp_masking``) is not used here;
        this is the pure-PyTorch path.
    """
    ranked = torch.sort(w, dim=-1, descending=True)
    running_total = torch.cumsum(ranked.values, dim=-1)
    below_p = running_total < p
    # Shift right by one so each entry is judged by the mass ranked *before*
    # it; the top-ranked slot is unconditionally kept.
    always_keep = torch.ones_like(below_p[..., :1])
    keep_sorted = torch.cat((always_keep, below_p[..., :-1]), dim=-1)
    # ranked.indices is a permutation along the last dim, so this scatter
    # writes every position exactly once, restoring the original order.
    return keep_sorted.scatter(-1, ranked.indices, keep_sorted)