91 lines
3.4 KiB
Python
91 lines
3.4 KiB
Python
import os
|
|
import time
|
|
import functools
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
# import tinycudann as tcnn
|
|
# Absolute path of the directory that contains this source file.
parent_dir = os.path.split(os.path.abspath(__file__))[0]
|
|
|
|
|
|
''' Dense 3D grid
'''
|
|
class DenseGrid(nn.Module):
    """Learnable dense 3D feature grid sampled by trilinear interpolation.

    The grid is stored as a single ``nn.Parameter`` of shape
    ``[1, channels, *world_size]`` initialised to ones.  Query points are
    given in global coordinates and mapped into the grid through an
    axis-aligned bounding box that must be supplied with :meth:`set_aabb`
    before :meth:`forward` is called (``forward`` reads ``xyz_min`` /
    ``xyz_max``, which this class only creates in ``set_aabb``).

    Args:
        channels: number of feature channels stored per voxel.
        world_size: iterable with the grid resolution along each of the
            three axes.
        **kwargs: accepted and ignored.
    """

    def __init__(self, channels, world_size, **kwargs):
        super(DenseGrid, self).__init__()
        self.channels = channels
        self.world_size = world_size
        # The bounding box is not set here; see ``set_aabb``.
        self.grid = nn.Parameter(torch.ones([1, channels, *world_size]))

    def forward(self, xyz):
        """Interpolate the grid at the given global coordinates.

        Args:
            xyz: tensor whose last dimension has size 3 (global coordinates
                to query); any leading shape is allowed.

        Returns:
            Tensor of shape ``(*xyz.shape[:-1], channels)``.  The channel
            axis is kept even when ``channels == 1``.
        """
        shape = xyz.shape[:-1]
        # grid_sample expects a 5-D sampling grid: [N, D_out, H_out, W_out, 3].
        xyz = xyz.reshape(1, 1, 1, -1, 3)
        # Normalise into [0, 1] inside the bounding box, reverse the axis
        # order (grid_sample's first coordinate indexes the LAST spatial dim
        # of the volume), then rescale to the [-1, 1] range it expects.
        unit = (xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)
        ind_norm = unit.flip((-1,)) * 2 - 1
        # For 5-D inputs, mode='bilinear' performs trilinear interpolation.
        out = F.grid_sample(self.grid, ind_norm, mode='bilinear', align_corners=True)
        # [1, C, 1, 1, P] -> [C, P] -> [P, C] -> [*shape, C]
        out = out.reshape(self.channels, -1).T.reshape(*shape, self.channels)
        return out

    def scale_volume_grid(self, new_world_size):
        """Resize the grid to ``new_world_size``.

        With zero channels a fresh (empty) grid of the new size is
        allocated; otherwise the current values are trilinearly resampled.
        ``self.grid`` is replaced by a *new* ``nn.Parameter`` object.
        NOTE(review): anything that holds the old parameter (e.g. an
        optimizer) presumably has to be rebuilt by the caller - confirm.

        Args:
            new_world_size: iterable with the new resolution per axis.
        """
        if self.channels == 0:
            resized = torch.ones([1, self.channels, *new_world_size])
        else:
            resized = F.interpolate(
                self.grid.data,
                size=tuple(new_world_size),
                mode='trilinear',
                align_corners=True,
            )
        self.grid = nn.Parameter(resized)
        # Keep the recorded resolution in sync with the actual grid so that
        # ``extra_repr`` (and any reader of ``world_size``) is not stale.
        self.world_size = new_world_size

    def _as_bound(self, values):
        """Return ``values`` as a detached tensor on the grid's device."""
        bound = torch.as_tensor(
            values,
            dtype=torch.get_default_dtype(),
            device=self.grid.device,
        )
        # Copy so the buffer never aliases (or tracks gradients of) the input.
        return bound.detach().clone()

    def set_aabb(self, xyz_max, xyz_min):
        """Register the bounding box used to normalise query coordinates.

        The bounds are stored as the buffers ``xyz_min`` and ``xyz_max`` on
        the same device as the grid, so calling this after the module has
        been moved to another device does not cause a device mismatch in
        :meth:`forward`.

        Args:
            xyz_max: upper corner of the bounding box (sequence, array or
                tensor).
            xyz_min: lower corner of the bounding box (sequence, array or
                tensor).
        """
        self.register_buffer('xyz_min', self._as_bound(xyz_min))
        self.register_buffer('xyz_max', self._as_bound(xyz_max))

    def get_dense_grid(self):
        """Return the underlying grid parameter."""
        return self.grid

    @torch.no_grad()
    def __isub__(self, val):
        """Subtract ``val`` from the grid values in place (``grid -= val``)."""
        self.grid.data -= val
        return self

    def extra_repr(self):
        """Return the extra text shown in the module's ``repr``."""
        return f'channels={self.channels}, world_size={self.world_size}'
|
|
|
|
# class HashHexPlane(nn.Module):
|
|
# def __init__(self,hparams,
|
|
# desired_resolution=1024,
|
|
# base_solution=128,
|
|
# n_levels=4,
|
|
# ):
|
|
# super(HashHexPlane, self).__init__()
|
|
|
|
# per_level_scale = np.exp2(np.log2(desired_resolution / base_solution) / (int(n_levels) - 1))
|
|
# encoding_2d_config = {
|
|
# "otype": "Grid",
|
|
# "type": "Hash",
|
|
# "n_levels": n_levels,
|
|
# "n_features_per_level": 2,
|
|
# "base_resolution": base_solution,
|
|
# "per_level_scale":per_level_scale,
|
|
# }
|
|
# self.xy = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
|
|
# self.yz = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
|
|
# self.xz = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
|
|
# self.xt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
|
|
# self.yt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
|
|
# self.zt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config)
|
|
|
|
# self.feat_dim = n_levels * 2 *3
|
|
|
|
# def forward(self, x, bound):
|
|
# x = (x + bound) / (2 * bound) # zyq: map to [0, 1]
|
|
# xy_feat = self.xy(x[:, [0, 1]])
|
|
# yz_feat = self.yz(x[:, [0, 2]])
|
|
# xz_feat = self.xz(x[:, [1, 2]])
|
|
# xt_feat = self.xt(x[:, []])
|
|
# return torch.cat([xy_feat, yz_feat, xz_feat], dim=-1) |