Spaces:
Build error
Build error
File size: 1,007 Bytes
94ada0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import torch
from .. import custom_ops
# Module-level cache for the compiled extension; populated on the first _init() call.
_plugin = None


def _init():
    """Lazily obtain the ``nerf_utils_plugin`` extension and cache it in ``_plugin``.

    The plugin is requested from ``custom_ops.get_plugin`` using the CUDA source
    ``nerf_utils.cu`` and header ``utils.h`` located next to this file, built
    with ``--use_fast_math``. Once ``_plugin`` is set, later calls do nothing.

    Returns:
        True. Any exception raised by ``custom_ops.get_plugin`` propagates
        unchanged and leaves ``_plugin`` as ``None``.
    """
    global _plugin
    # Already initialised: nothing to do.
    if _plugin is not None:
        return True
    plugin_dir = os.path.dirname(__file__)
    _plugin = custom_ops.get_plugin(
        module_name='nerf_utils_plugin',
        sources=['nerf_utils.cu'],
        headers=['utils.h'],
        source_dir=plugin_dir,
        extra_cuda_cflags=['--use_fast_math'],
    )
    return True
def topp_masking(w, p=0.99):
    """Return a boolean top-p (nucleus) mask over the last dimension of ``w``.

    Entries of ``w`` are ranked in descending order along the last dimension.
    An entry is kept when the total weight of the entries ranked strictly
    above it is still below ``p``. The highest-ranked entry is therefore
    always kept, and so is the entry whose inclusion first reaches ``p``.

    Args:
        w: weight tensor, documented by the author as B x N x S (S = number of
            samples) and normalized; the computation itself only operates
            along the last dimension.
        p: top-p threshold on the cumulative weight.

    Returns:
        A bool tensor with the same shape as ``w``, in the original
        (unsorted) element order.
    """
    # NOTE: a CUDA implementation via _plugin.topp_masking exists but is
    # currently disabled; this pure-torch path is used instead.
    ranked, order = torch.sort(w, dim=-1, descending=True)
    under_p = torch.cumsum(ranked, dim=-1) < p
    # Shift right by one so each entry is judged on the mass *before* it;
    # the first (largest) entry is unconditionally kept.
    first = torch.ones_like(under_p[..., :1])
    keep_ranked = torch.cat([first, under_p[..., :-1]], dim=-1)
    # `order` is a permutation along dim -1, so scattering writes every slot:
    # result[..., order[..., k]] = keep_ranked[..., k] undoes the sort.
    return keep_ranked.scatter(-1, order, keep_ranked)