# 4DGaussians/scene/deformation.py
# guanjunwu 8bf73f413d 123
# 2023-09-24 19:51:57 +08:00

# 230 lines
# 10 KiB
# Python

import functools
import math
import os
import time
import tinycudann as tcnn
from tkinter import W
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load
import torch.nn.init as init
class TriPlaneGrid(nn.Module):
    """Three dense multi-level 2-D grid encodings over pairs of spatial axes.

    Each plane is a ``tcnn.Encoding`` with ``n_levels`` levels and 4 features
    per level; ``forward`` concatenates the three plane features along the
    last dimension, so ``feat_dim == n_levels * 4 * 3``.
    """

    def __init__(self,
                 desired_resolution=256,
                 base_solution=32,
                 n_levels=4,
                 ):
        """
        Args:
            desired_resolution: target resolution used to derive the growth
                factor between levels.
            base_solution: resolution passed to tcnn as ``base_resolution``.
            n_levels: number of grid levels per plane.
        """
        super(TriPlaneGrid, self).__init__()
        features_per_level = 4
        per_level_scale = self._per_level_scale(desired_resolution, base_solution, n_levels)
        encoding_2d_config = {
            "otype": "Grid",
            "type": "Dense",
            "n_levels": n_levels,
            "n_features_per_level": features_per_level,
            "base_resolution": base_solution,
            "per_level_scale": per_level_scale,
        }
        self.xy = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config, dtype=torch.float32)
        self.yz = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config, dtype=torch.float32)
        self.xz = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config, dtype=torch.float32)
        self.feat_dim = n_levels * features_per_level * 3

    @staticmethod
    def _per_level_scale(desired_resolution, base_solution, n_levels):
        """Return the geometric growth factor between consecutive levels.

        With a single level there is no growth step; return 1.0 instead of
        dividing by ``n_levels - 1 == 0`` (which yields inf/nan in numpy).
        """
        steps = int(n_levels) - 1
        if steps <= 0:
            return 1.0
        return float(np.exp2(np.log2(desired_resolution / base_solution) / steps))

    def forward(self, x, bound):
        """Encode the first three columns of ``x`` on the three planes.

        Args:
            x: 2-D tensor; columns 0..2 are read as coordinates.
            bound: scalar half-extent; coordinates in ``[-bound, bound]`` are
                mapped to ``[0, 1]`` before lookup.

        Returns:
            Concatenation of the three plane features along the last dim.
        """
        x = (x + bound) / (2 * bound)  # map [-bound, bound] to [0, 1]
        xy_feat = self.xy(x[:, [0, 1]])
        # NOTE(review): the attribute names and the sampled columns are crossed
        # here (``yz`` reads columns 0 and 2, ``xz`` reads columns 1 and 2).
        # Kept as-is on purpose: swapping them would silently change what
        # weights stored under these names mean for existing checkpoints.
        yz_feat = self.yz(x[:, [0, 2]])
        xz_feat = self.xz(x[:, [1, 2]])
        return torch.cat([xy_feat, yz_feat, xz_feat], dim=-1)
class TriPlanetimeGrid(nn.Module):
    """Three dense multi-level 2-D grid encodings over (axis, time) pairs.

    One ``tcnn.Encoding`` per spatial axis (``xt``, ``yt``, ``zt``), each with
    ``n_levels`` levels and 4 features per level. ``forward`` concatenates
    the three results, so ``feat_dim == n_levels * 4 * 3``.
    """

    def __init__(self,
                 desired_resolution=256,
                 base_solution=16,
                 n_levels=6,
                 ):
        """
        Args:
            desired_resolution: target resolution used to derive the growth
                factor between levels.
            base_solution: resolution passed to tcnn as ``base_resolution``.
            n_levels: number of grid levels per plane.
        """
        super(TriPlanetimeGrid, self).__init__()
        features_per_level = 4
        per_level_scale = self._per_level_scale(desired_resolution, base_solution, n_levels)
        encoding_2d_config = {
            "otype": "Grid",
            "type": "Dense",
            "n_levels": n_levels,
            "n_features_per_level": features_per_level,
            "base_resolution": base_solution,
            "per_level_scale": per_level_scale,
        }
        self.xt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config, dtype=torch.float32)
        self.yt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config, dtype=torch.float32)
        self.zt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config, dtype=torch.float32)
        self.feat_dim = n_levels * features_per_level * 3

    @staticmethod
    def _per_level_scale(desired_resolution, base_solution, n_levels):
        """Return the geometric growth factor between consecutive levels.

        With a single level there is no growth step; return 1.0 instead of
        dividing by ``n_levels - 1 == 0`` (which yields inf/nan in numpy).
        """
        steps = int(n_levels) - 1
        if steps <= 0:
            return 1.0
        return float(np.exp2(np.log2(desired_resolution / base_solution) / steps))

    def forward(self, x, time, bound):
        """Encode each spatial coordinate jointly with ``time``.

        Args:
            x: 2-D tensor; columns 0..2 are read as coordinates.
            time: tensor concatenated after each single coordinate column;
                must make a 2-column input for the 2-D encoders
                (presumably shape ``(N, 1)`` - confirm with the caller).
            bound: scalar half-extent; coordinates in ``[-bound, bound]`` are
                mapped to ``[0, 1]`` before lookup. ``time`` is not rescaled.

        Returns:
            ``[xt_feat, yt_feat, zt_feat]`` concatenated along the last dim.
        """
        x = (x + bound) / (2 * bound)  # map [-bound, bound] to [0, 1]
        xt = torch.cat([x[:, 0:1], time], -1)
        yt = torch.cat([x[:, 1:2], time], -1)
        zt = torch.cat([x[:, 2:3], time], -1)
        xt_feat = self.xt(xt)
        yt_feat = self.yt(yt)
        # The z-time plane has its own encoder. Previously this input was run
        # through ``self.xt`` and overwrote ``xt_feat``, so ``self.zt`` was
        # never used and the x-time features were dropped from the output.
        zt_feat = self.zt(zt)
        return torch.cat([xt_feat, yt_feat, zt_feat], dim=-1)
class Deformation(nn.Module):
    """Deformation field: an MLP trunk fused with tri-plane grid features.

    ``forward`` predicts additive offsets for position (3), scale (3) and
    rotation (4) and adds them to the leading columns of the corresponding
    input embeddings.
    """

    def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=None):
        """
        Args:
            D: number of ``W -> W`` linear layers appended after the input layer.
            W: hidden width used by every linear layer.
            input_ch: width of the concatenated position/scale/rotation embedding.
            input_ch_time: width of the time feature concatenated to it.
            skips: stored on the instance; not read anywhere in this class.
        """
        super(Deformation, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_time = input_ch_time
        # ``None`` sentinel instead of a ``[]`` default, which would be one
        # list object shared by every instance built without ``skips``.
        self.skips = [] if skips is None else skips
        self.grid = TriPlaneGrid()
        self.timegrid = TriPlanetimeGrid()
        (self._time,
         self.pos_deform,
         self.scales_deform,
         self.rotations_deform,
         self.opacity_deform) = self.create_net()

    def create_net(self):
        """Build the trunk and the four output heads.

        Side effect: also assigns ``self.mlp_out`` and ``self.feature_out``.
        Layers are constructed in the same order as before so module
        registration and RNG consumption are unchanged.

        Returns:
            Tuple of (trunk ``ModuleList``, position head -> 3, scale head -> 3,
            rotation head -> 4, opacity head -> 1).
        """
        trunk = [nn.Linear(self.input_ch + self.input_ch_time, self.W)]
        trunk.extend(nn.Linear(self.W, self.W) for _ in range(self.D))
        self.mlp_out = nn.Linear(self.W, self.W)
        grid_dim = self.grid.feat_dim + self.timegrid.feat_dim
        self.feature_out = nn.Linear(self.W + grid_dim, self.W)

        def make_head(out_dim):
            # Two-layer head mapping the fused feature to ``out_dim`` values.
            return nn.Sequential(nn.Linear(self.W, self.W), nn.ReLU(), nn.Linear(self.W, out_dim))

        return (nn.ModuleList(trunk), make_head(3), make_head(3), make_head(4), make_head(1))

    def query_time(self, rays_pts_emb, scales_emb, rotations_emb, t, net, time_emb):
        """Compute the fused hidden feature for every input row.

        Args:
            rays_pts_emb: position embedding; columns 0..2 are used as the
                coordinates for the grid lookups.
            scales_emb: scale embedding.
            rotations_emb: rotation embedding.
            t: time feature concatenated to the embeddings for the trunk.
            net: iterable of layers applied in order, each followed by ReLU.
            time_emb: tensor whose first column is fed to the time grid
                (presumably the raw timestamp - confirm with the caller).

        Returns:
            ReLU-activated output of ``self.feature_out``.
        """
        h = torch.cat([rays_pts_emb, scales_emb, rotations_emb, t], dim=-1)
        for layer in net:
            h = F.relu(layer(h))
        mlp_feature = self.mlp_out(h)
        xyz = rays_pts_emb[:, :3]
        grid_feature = self.grid(xyz, bound=2)
        time_feature = self.timegrid(xyz, time_emb[:, 0:1], bound=2)
        fused = torch.cat([mlp_feature, grid_feature, time_feature], -1)
        return F.relu(self.feature_out(fused))

    def forward(self, rays_pts_emb, scales_emb, rotations_emb, ts, time_emb):
        """Return deformed ``(pts, scales, rotations)``.

        Each output is the leading 3 / 3 / 4 columns of the matching input
        embedding plus the offset predicted by its head.
        """
        # Pass scales before rotations, matching ``query_time``'s signature.
        # The call used to swap the two; checkpoints trained with the swapped
        # order will see their trunk input columns permuted.
        hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, ts, self._time, time_emb).float()
        dx = self.pos_deform(hidden)
        pts = rays_pts_emb[:, :3] + dx
        ds = self.scales_deform(hidden)
        scales = scales_emb[:, :3] + ds
        dr = self.rotations_deform(hidden)
        rotations = rotations_emb[:, :4] + dr
        return pts, scales, rotations

    def get_mlp_parameters(self):
        """Return every parameter whose qualified name does not contain "grid"."""
        return [param for name, param in self.named_parameters() if "grid" not in name]

    def get_grid_parameters(self):
        """Return the parameters of ``self.grid`` followed by ``self.timegrid``."""
        return list(self.grid.parameters()) + list(self.timegrid.parameters())
class deform_network(nn.Module):
    """Frequency-encode the inputs, embed time, and run ``Deformation``.

    Positions, scales, rotations and time are expanded with ``poc_fre``;
    time additionally goes through a small MLP (``timenet``). Opacity is
    returned unchanged.
    """

    def __init__(self):
        super(deform_network, self).__init__()
        net_width = 256
        defor_depth = 1
        timenet_width = 256
        timenet_output = 32
        # Number of frequency bands per encoded quantity, in buffer order.
        bands = {
            'time_poc': 4,
            'pos_poc': 10,
            'rotation_scaling_poc': 4,
            'opacity_poc': 2,
        }
        timebase_pe = bands['time_poc']
        posbase_pe = bands['pos_poc']
        scale_rotation_pe = bands['rotation_scaling_poc']

        # One raw time value plus a sin and a cos term per band.
        times_ch = 2 * timebase_pe + 1
        self.timenet = nn.Sequential(
            nn.Linear(times_ch, timenet_width),
            nn.ReLU(),
            nn.Linear(timenet_width, timenet_output),
        )

        # Raw position(3) + rotation(4) + scale(3) columns, plus sin/cos terms.
        raw_ch = 3 + 4 + 3
        encoded_ch = (3 * posbase_pe + (4 + 3) * scale_rotation_pe) * 2
        self.deformation_net = Deformation(
            W=net_width,
            D=defor_depth,
            input_ch=raw_ch + encoded_ch,
            input_ch_time=timenet_output,
        )

        # Frequency multipliers 2**0 .. 2**(n-1) for each encoded quantity.
        for buffer_name, n_bands in bands.items():
            self.register_buffer(buffer_name, torch.FloatTensor([2 ** k for k in range(n_bands)]))
        self.apply(initialize_weights)

    def forward(self, point, scales, rotations, opacity, times_sel):
        """Return ``(means3D, scales, rotations, opacity)``.

        ``opacity`` is passed through untouched; the other three come from
        ``self.deformation_net``.
        """
        times_emb = poc_fre(times_sel, self.time_poc)
        times_feature = self.timenet(times_emb)
        pts_emb = poc_fre(point, self.pos_poc)
        scales_emb = poc_fre(scales, self.rotation_scaling_poc)
        rotations_emb = poc_fre(rotations, self.rotation_scaling_poc)
        deformed = self.deformation_net(pts_emb, scales_emb, rotations_emb, times_feature, times_emb)
        means3D, scales, rotations = deformed
        return means3D, scales, rotations, opacity

    def get_mlp_parameters(self):
        """Return the deformation net's MLP parameters followed by ``timenet``'s."""
        time_params = list(self.timenet.parameters())
        return self.deformation_net.get_mlp_parameters() + time_params

    def get_grid_parameters(self):
        """Return the grid parameters of the deformation net."""
        return self.deformation_net.get_grid_parameters()
class dynamic_gate(nn.Module):
    """Small MLP that maps encoded point attributes and time to one value per row.

    Inputs are frequency-encoded with ``poc_fre``; time is additionally
    embedded by ``timenet`` before everything is concatenated.
    """

    def __init__(self) -> None:
        # ``super(dynamic_gate).__init__()`` (without ``self``) never runs
        # ``nn.Module.__init__``, so the first submodule assignment below
        # raised "cannot assign module before Module.__init__() call".
        super(dynamic_gate, self).__init__()
        net_width = 256
        timebase_pe = 4
        posbase_pe = 10
        scale_rotation_pe = 4
        timenet_width = 256
        timenet_output = 32
        # One raw time value plus a sin and a cos term per band.
        times_ch = 2 * timebase_pe + 1
        self.timenet = nn.Sequential(
            nn.Linear(times_ch, timenet_width), nn.ReLU(),
            nn.Linear(timenet_width, timenet_output))
        # Raw position/rotation/scale columns + their sin/cos terms + time feature.
        gate_in = (3 + 4 + 3) + (3 * posbase_pe + (4 + 3) * scale_rotation_pe) * 2 + timenet_output
        self.deformation_net = nn.Sequential(
            nn.Linear(gate_in, net_width),
            nn.ReLU(),
            nn.Linear(net_width, 1)
        )
        self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)]))
        self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)]))
        self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in range(scale_rotation_pe)]))
        self.apply(initialize_weights)

    def forward(self, point, scales, rotations, opacity, times_sel):
        """Return the output of ``self.deformation_net`` (one value per row).

        ``opacity`` is accepted for signature parity with ``deform_network``
        but is not used.
        """
        times_emb = poc_fre(times_sel, self.time_poc)
        times_feature = self.timenet(times_emb)
        pts_emb = poc_fre(point, self.pos_poc)
        scales_emb = poc_fre(scales, self.rotation_scaling_poc)
        rotations_emb = poc_fre(rotations, self.rotation_scaling_poc)
        motion_rate = self.deformation_net(torch.cat([pts_emb,
                                                      scales_emb,
                                                      rotations_emb,
                                                      times_feature], -1))
        return motion_rate
class Tineuvox(nn.Module):
    """Empty placeholder module; defines no parameters and no ``forward``."""

    def __init__(self) -> None:
        # ``super(Tineuvox).__init__()`` (without ``self``) skipped
        # ``nn.Module.__init__``, leaving the instance without the internal
        # registries that ``parameters()``, ``to()`` etc. rely on.
        super(Tineuvox, self).__init__()
def poc_fre(input_data, poc_buf):
    """Append sin/cos frequency features to ``input_data``.

    Every element of the last dimension is multiplied by each entry of
    ``poc_buf``; the products are flattened back into the last dimension.

    Returns:
        ``[input_data, sin(products), cos(products)]`` concatenated along the
        last dimension.
    """
    phases = (input_data.unsqueeze(-1) * poc_buf).flatten(-2)
    pieces = [input_data, phases.sin(), phases.cos()]
    return torch.cat(pieces, -1)
def initialize_weights(m):
    """Xavier-uniform initialise the weight of ``nn.Linear`` modules.

    Intended for ``module.apply(initialize_weights)``; any module that is not
    an ``nn.Linear`` is left untouched.

    The bias is deliberately not modified. The earlier bias branch re-ran
    ``xavier_uniform_`` on ``m.weight`` (a redundant second random draw) and
    never touched ``m.bias``, so leaving the bias alone preserves the
    resulting distributions while dropping the wasted draw.
    """
    if not isinstance(m, nn.Linear):
        return
    init.xavier_uniform_(m.weight, gain=1)