# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Please refer to original code: https://github.com/NVlabs/instant-ngp
# and the pytorch wrapper from https://github.com/ashawkey/torch-ngp
import os

import torch
from .. import custom_ops
from torch.cuda.amp import custom_bwd, custom_fwd

_plugin = None
_null_tensor = torch.empty([0])


def _init():
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='hash_sample_plugin',
            sources=['hash_sample.cpp', 'hash_sample.cu'],
            headers=['hash_sample.h', 'utils.h'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math'],
        )
    return True


def hash_sample(x, h, offsets, beta=2, base_res=16, calc_grad=True, mode='fast_hash'):
    """Hash-table look-up followed by d-linear interpolation.

    x:       B x N x D coordinates in [0, 1]
    h:       B x sO x C hash-table entries flattened across L levels, with sO = offsets[-1]
             (equivalently B x L x T x C when every level uses the same table size T)
    offsets: (L + 1,) cumulative per-level entry offsets into the table dimension
    """
    assert x.device.type == 'cuda'
    assert (x.size(-1) == 3) or (x.size(-1) == 2), "currently only 2D/3D is implemented"
    _init()
    return _hash_sample_cuda(mode).apply(x, h, offsets, beta, base_res, calc_grad)
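

# Illustrative sketch only (not part of the original module): one plausible way to build
# the (L + 1,) cumulative `offsets` tensor consumed by hash_sample, assuming each level
# stores at most 2 ** log2_max_entries entries and the grid resolution grows by `beta`
# per level, as in the instant-ngp scheme referenced in the header. The name
# `build_hash_offsets`, the capping rule, and the int32 dtype are assumptions.
def build_hash_offsets(num_levels, dim=3, base_res=16, beta=2, log2_max_entries=19):
    offsets, total = [0], 0
    for level in range(num_levels):
        res = int(base_res * (beta ** level))                       # per-level grid resolution
        total += min((res + 1) ** dim, 2 ** log2_max_entries)       # dense grid or capped hash table
        offsets.append(total)
    return torch.tensor(offsets, dtype=torch.int32)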


_hash_sample_cuda_cache = dict()


def _hash_sample_cuda(mode='fast_hash'):
    """CUDA implementation of the hash-table look-up."""
    if mode in _hash_sample_cuda_cache:
        return _hash_sample_cuda_cache[mode]

    if mode == 'fast_hash':
        h_mode = 0
    elif mode == 'grid_hash':
        h_mode = 1
    else:
        raise NotImplementedError(f'unknown mode {mode!r}; only "fast_hash" and "grid_hash" are supported.')

    class HashSampleCuda(torch.autograd.Function):
        @staticmethod
        @custom_fwd(cast_inputs=torch.half)  # AMP hooks; follows the torch-ngp wrapper credited above
        def forward(ctx, inputs, embeddings, offsets, beta, base_resolution, calc_grad_inputs=False):
            # inputs: [B, N, D], float in [0, 1]
            # embeddings: [B, sO, C], float, with sO = offsets[-1]
            # offsets: [L + 1], int
            # RETURN: [B, N, L * C], float
            inputs = inputs.contiguous()
            embeddings = embeddings.contiguous()
            offsets = offsets.contiguous().to(inputs.device)

            B, N, D = inputs.shape    # batch size, # of samples, coord dim
            L = offsets.shape[0] - 1  # number of levels
            C = embeddings.shape[-1]  # embedding dim for each level
            H = base_resolution       # base resolution

            outputs = torch.zeros(B, N, L * C, device=inputs.device, dtype=inputs.dtype)
            if calc_grad_inputs:
                # buffer for dy/dx, needed to chain gradients back to the input coordinates
                dy_dx = torch.zeros(B, N, L * D * C, device=inputs.device, dtype=inputs.dtype)
            else:
                dy_dx = torch.zeros(1, device=inputs.device, dtype=inputs.dtype)

            _plugin.hash_encode_forward(inputs, embeddings, offsets, outputs, beta, B, N, D, C, L, H, calc_grad_inputs, dy_dx, h_mode)

            ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
            ctx.dims = [B, N, D, C, L, H, beta]
            ctx.calc_grad_inputs = calc_grad_inputs
            return outputs

        @staticmethod
        @custom_bwd
        def backward(ctx, grad):
            # grad: [B, N, L * C]
            grad = grad.contiguous()

            inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
            B, N, D, C, L, H, beta = ctx.dims
            calc_grad_inputs = ctx.calc_grad_inputs

            grad_embeddings = torch.zeros_like(embeddings)
            if calc_grad_inputs:
                grad_inputs = torch.zeros_like(inputs)
            else:
                grad_inputs = torch.zeros(1, device=inputs.device, dtype=inputs.dtype)

            _plugin.hash_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, beta, B, N, D, C, L, H, calc_grad_inputs, dy_dx, grad_inputs, h_mode)

            if calc_grad_inputs:
                return grad_inputs, grad_embeddings, None, None, None, None
            else:
                return None, grad_embeddings, None, None, None, None

    # Add to cache.
    _hash_sample_cuda_cache[mode] = HashSampleCuda
    return HashSampleCuda
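

# Minimal usage sketch (assumption: a CUDA device and the compiled plugin are available;
# `build_hash_offsets` is the hypothetical helper defined above, and `_smoke_test` itself
# is not part of the original API). Call it manually after importing the module.
def _smoke_test(B=2, N=1024, D=3, L=8, C=2):
    offsets = build_hash_offsets(L, dim=D, base_res=16, beta=2)
    x = torch.rand(B, N, D, device='cuda')                                # coordinates in [0, 1]
    h = (torch.randn(B, int(offsets[-1]), C, device='cuda') * 1e-4).requires_grad_(True)
    y = hash_sample(x, h, offsets, beta=2, base_res=16, calc_grad=False)
    assert y.shape == (B, N, L * C)                                       # one C-dim feature per level
    y.sum().backward()                                                    # populates h.grad via backward()
    assert h.grad.shape == h.shape
    return y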