184 lines
6.4 KiB
Python
184 lines
6.4 KiB
Python
import itertools
|
|
import logging as log
|
|
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def get_normalized_directions(directions):
    """Shift and halve direction components: ``(directions + 1) / 2``.

    The SH encoding consuming the result must receive values in [0, 1];
    that holds when every input component lies in [-1, 1] (assumed here,
    e.g. unit vectors -- not checked).

    Args:
        directions: batch of directions

    Returns:
        The remapped directions, same shape as the input.
    """
    shifted = directions + 1.0
    return shifted / 2.0
|
|
|
|
|
|
def normalize_aabb(pts, aabb):
    """Affinely rescale ``pts`` so that ``aabb[0]`` maps to -1 and ``aabb[1]`` to +1.

    Args:
        pts: points to rescale; must broadcast against ``aabb[0]``.
        aabb: two-row box description; row 0 is the corner sent to -1 and
            row 1 the corner sent to +1.

    Returns:
        The rescaled points.
    """
    origin = aabb[0]
    # Factor that stretches the box extent along each axis to a length of 2.
    scale = 2.0 / (aabb[1] - aabb[0])
    return (pts - origin) * scale - 1.0
|
|
def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Sample a 2D or 3D feature grid at a flat list of coordinates.

    Thin adapter around ``F.grid_sample`` (bilinear mode, border padding)
    that adds missing batch axes, reshapes the coordinates into the layout
    ``grid_sample`` expects, and returns features point-major.

    Args:
        grid: feature grid ``[B, feature_dim, reso, ...]``; the leading batch
            axis may be omitted.
        coords: sample locations ``[B, n, d]`` or ``[n, d]`` where ``d`` (the
            last axis) is the grid dimensionality, 2 or 3.
        align_corners: forwarded unchanged to ``F.grid_sample``.

    Returns:
        Interpolated features, nominally ``[B, n, feature_dim]``. Every
        size-1 axis is squeezed away, so the actual rank depends on ``B``,
        ``n`` and ``feature_dim``.

    Raises:
        NotImplementedError: if the last axis of ``coords`` is not 2 or 3.
    """
    n_coord_dims = coords.shape[-1]

    # Add the batch axis when the caller left it out.
    if grid.dim() == n_coord_dims + 1:
        grid = grid.unsqueeze(0)
    if coords.dim() == 2:
        coords = coords.unsqueeze(0)

    if n_coord_dims not in (2, 3):
        raise NotImplementedError(f"Grid-sample was called with {n_coord_dims}D data but is only "
                                  f"implemented for 2 and 3D data.")

    # grid_sample wants one output axis per grid axis; all but the last are
    # singleton so the n points lie along a single axis: [B, 1, ..., n, d].
    singleton_axes = [1] * (n_coord_dims - 1)
    coords = coords.view([coords.shape[0], *singleton_axes, *coords.shape[1:]])

    batch_size, n_features = grid.shape[:2]
    n_points = coords.shape[-2]
    sampled = F.grid_sample(
        grid,    # [B, feature_dim, reso, ...]
        coords,  # [B, 1, ..., n, d]
        align_corners=align_corners,
        mode='bilinear',
        padding_mode='border',
    )
    # [B, feature_dim, 1, ..., n] -> [B, n, feature_dim]
    sampled = sampled.view(batch_size, n_features, n_points).transpose(-1, -2)
    return sampled.squeeze()  # [B?, n, feature_dim?]
|
|
|
|
def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one learnable feature grid per ``grid_nd``-subset of the input axes.

    Args:
        grid_nd: number of axes spanned by each grid (2 gives planes).
        in_dim: number of input coordinates; must equal ``len(reso)``.
        out_dim: feature channels stored in every grid.
        reso: resolution along each input axis.
        a: lower bound of the uniform initialisation.
        b: upper bound of the uniform initialisation.

    Returns:
        An ``nn.ParameterList`` with one parameter of shape
        ``[1, out_dim, *resolutions]`` per axis combination, in
        ``itertools.combinations`` order. Resolutions are listed with the
        combination's axes reversed.
    """
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    assert grid_nd <= in_dim
    # With four input coordinates, axis index 3 is treated as time.
    time_axis_present = in_dim == 4

    planes = nn.ParameterList()
    for axes in itertools.combinations(range(in_dim), grid_nd):
        plane_shape = [1, out_dim, *(reso[axis] for axis in reversed(axes))]
        plane = nn.Parameter(torch.empty(plane_shape))
        if time_axis_present and 3 in axes:
            # Time planes start at 1.
            nn.init.ones_(plane)
        else:
            nn.init.uniform_(plane, a=a, b=b)
        planes.append(plane)

    return planes
|
|
|
|
|
|
def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Interpolate features from multi-scale grids at the given points.

    Within one scale, each grid is sampled at the matching subset of point
    coordinates and the per-grid features are multiplied element-wise. The
    per-scale results are then concatenated along the feature axis or summed.

    Args:
        pts: points whose last axis holds the coordinates.
        ms_grids: one indexable collection of grids per scale, ordered like
            ``itertools.combinations(range(pts.shape[-1]), grid_dimensions)``.
            Must support slicing.
        grid_dimensions: number of coordinates each grid spans.
        concat_features: concatenate the scales when true, add them otherwise.
        num_levels: how many leading scales to use; ``None`` uses all of them.

    Returns:
        The combined features with one row per point. With no scales and
        ``concat_features`` false the initial accumulator ``0.`` is returned.
    """
    axis_groups = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))
    if num_levels is None:
        num_levels = len(ms_grids)

    per_scale = []  # filled only when concatenating
    summed = 0.     # running total, used only when summing

    for planes in ms_grids[:num_levels]:  # planes: nn.ParameterList
        scale_features = 1.
        for plane_idx, axes in enumerate(axis_groups):
            plane = planes[plane_idx]
            n_features = plane.shape[1]  # plane shape: 1, out_dim, *reso
            # Sample this grid at the coordinates it spans.
            plane_features = grid_sample_wrapper(plane, pts[..., axes]).view(-1, n_features)
            # Product over grids within the scale.
            scale_features = scale_features * plane_features

        # Combine over scales.
        if concat_features:
            per_scale.append(scale_features)
        else:
            summed = summed + scale_features

    if concat_features:
        return torch.cat(per_scale, dim=-1)
    return summed
|
|
|
|
|
|
class HexPlaneField(nn.Module):
    """Multi-resolution plane-factorised feature field.

    Holds one set of learnable grids per resolution multiplier and turns
    (optionally time-augmented) points into a feature vector by sampling
    every grid and combining the results across grids and scales.

    Args:
        bounds: scalar half-extent of the initial axis-aligned bounding box.
        planeconfig: dict with the keys ``"grid_dimensions"``,
            ``"input_coordinate_dim"``, ``"output_coordinate_dim"`` and
            ``"resolution"``.
        multires: iterable of multipliers applied to the first three
            resolution entries, one grid set per multiplier.
    """

    def __init__(
        self,
        bounds,
        planeconfig,
        multires
    ) -> None:
        super().__init__()
        # Row 0 holds +bounds and row 1 holds -bounds. ``normalize_aabb``
        # maps row 0 to -1 and row 1 to +1, so the normalised axes are
        # flipped relative to world space.
        aabb = torch.tensor([[bounds, bounds, bounds],
                             [-bounds, -bounds, -bounds]])
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config = [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True

        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for res in self.multiscale_res_multipliers:
            # Shallow copy is enough: "resolution" is rebound to a new list
            # below, so the caller's config is not mutated.
            config = self.grid_config[0].copy()
            # Resolution fix: multi-res only on spatial planes
            config["resolution"] = [
                r * res for r in config["resolution"][:3]
            ] + config["resolution"][3:]
            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out-dim - Concatenate over feature len for each scale
            if self.concat_features:
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]
            self.grids.append(gp)
        print("feature_dim:", self.feat_dim)

    @property
    def get_aabb(self):
        """Return the two rows of the bounding box as ``(aabb[0], aabb[1])``."""
        return self.aabb[0], self.aabb[1]

    def set_aabb(self, xyz_max, xyz_min):
        """Replace the bounding box with ``[xyz_max, xyz_min]``.

        The new parameter is created on the device of the current one so a
        module that has already been moved (e.g. to GPU) stays consistent.
        """
        aabb = torch.tensor([
            xyz_max,
            xyz_min
        ], dtype=torch.float32, device=self.aabb.device)
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        print("Voxel Plane: set aabb=", self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Compute the interpolated grid features for ``pts``.

        Args:
            pts: spatial points; the last axis must broadcast against the
                rows of the bounding box.
            timestamps: optional extra coordinate(s) appended to the
                normalised points along the last axis. May be omitted only
                when the configured input dimension equals the spatial one.

        Returns:
            Features with one row per point, or an empty ``(0, 1)`` tensor
            when there are no points.

        Raises:
            ValueError: if ``timestamps`` is omitted although the configured
                planes expect a different number of input coordinates.
        """
        pts = normalize_aabb(pts, self.aabb)
        if timestamps is None:
            # Without this guard a 4D configuration would silently pair
            # coordinate subsets with the wrong planes.
            expected_dim = self.grid_config[0]["input_coordinate_dim"]
            if pts.shape[-1] != expected_dim:
                raise ValueError(
                    f"timestamps is required: planes expect {expected_dim} input "
                    f"coordinates but pts provides {pts.shape[-1]}"
                )
        else:
            pts = torch.cat((pts, timestamps), dim=-1)  # e.g. [n_rays, n_samples, 4]

        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)

        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Alias for :meth:`get_density`."""
        features = self.get_density(pts, timestamps)

        return features
|