240 lines
8.1 KiB
Python
240 lines
8.1 KiB
Python
#
|
|
# Copyright (C) 2023, Inria
|
|
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
|
# All rights reserved.
|
|
#
|
|
# This software is free for non-commercial, research and evaluation use
|
|
# under the terms of the LICENSE.md file.
|
|
#
|
|
# For inquiries contact george.drettakis@inria.fr
|
|
#
|
|
|
|
from typing import NamedTuple
|
|
import torch.nn as nn
|
|
import torch
|
|
import os
|
|
from . import _C
|
|
|
|
def cpu_deep_copy_tuple(input_tuple):
    """Return a tuple where every tensor of ``input_tuple`` is replaced by a CPU clone.

    Non-tensor items are passed through unchanged (the very same object).

    Args:
        input_tuple: any iterable of arguments (tensors and/or plain values).

    Returns:
        A new tuple with the same length and ordering as ``input_tuple``.
    """
    def _detach_to_cpu(item):
        # ``.cpu()`` may return the same tensor for CPU inputs, so ``.clone()``
        # guarantees the result never aliases the caller's storage.
        if isinstance(item, torch.Tensor):
            return item.cpu().clone()
        return item

    return tuple(map(_detach_to_cpu, input_tuple))
|
|
|
|
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    deformation,
    cov3Ds_precomp,
    raster_settings,
):
    """Rasterize Gaussians through the ``_RasterizeGaussians`` autograd function.

    This is a functional wrapper: every argument is forwarded positionally, in
    the order listed here, to ``_RasterizeGaussians.apply``, and its result is
    returned unchanged.
    """
    autograd_inputs = (
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        deformation,
        cov3Ds_precomp,
        raster_settings,
    )
    return _RasterizeGaussians.apply(*autograd_inputs)
|
|
|
|
def _call_rasterizer(fn, args, debug, snapshot_prefix, failure_message):
    """Invoke a C++/CUDA rasterizer entry point, optionally snapshotting its inputs.

    When ``debug`` is False this is exactly ``fn(*args)``.

    When ``debug`` is True:

    - a CPU copy of ``args`` is taken before the call;
    - if the call raises, the copy is written to ``<snapshot_prefix>_bad.dump``,
      ``failure_message`` is printed, and the original exception is re-raised;
    - if the call succeeds and ``<snapshot_prefix>_good.dump`` does not exist
      yet, the copy is saved there.

    Args:
        fn: callable to invoke (a function from the ``_C`` extension).
        args: positional arguments for ``fn``.
        debug: whether to take and save snapshots.
        snapshot_prefix: file-name prefix for the ``.dump`` snapshots.
        failure_message: text printed when ``fn`` raises.

    Returns:
        Whatever ``fn(*args)`` returns.
    """
    if not debug:
        return fn(*args)

    # Copy the inputs before the call so the snapshot reflects them as they
    # were handed to the kernel.
    cpu_args = cpu_deep_copy_tuple(args)
    try:
        outputs = fn(*args)
    except Exception:
        torch.save(cpu_args, f"{snapshot_prefix}_bad.dump")
        print(failure_message)
        raise

    # Record a "good" snapshot only once the call has actually succeeded.
    good_path = f"{snapshot_prefix}_good.dump"
    if not os.path.exists(good_path):
        torch.save(cpu_args, good_path)
    return outputs


def _unpack_forward_outputs(outputs):
    """Normalize the result of ``_C.rasterize_gaussians`` to a 9-tuple.

    Returns:
        ``(num_rendered, color, depth, acc, velocity, radii, geomBuffer,
        binningBuffer, imgBuffer)``.

    Raises:
        ValueError: if the extension returns neither 9 nor 10 values.
    """
    outputs = tuple(outputs)
    if len(outputs) == 9:
        return outputs
    if len(outputs) == 10:
        # NOTE(review): the previous non-debug path unpacked 10 values and bound
        # ``velocity`` twice, so the value actually used was the LAST one. The
        # debug path unpacked 9. Both layouts are accepted here, and the old
        # non-debug result is preserved. Confirm the C++ binding's return
        # arity and drop the unused layout.
        (num_rendered, color, depth, acc, _overwritten, radii,
         geomBuffer, binningBuffer, imgBuffer, velocity) = outputs
        return (num_rendered, color, depth, acc, velocity, radii,
                geomBuffer, binningBuffer, imgBuffer)
    raise ValueError(
        f"_C.rasterize_gaussians returned {len(outputs)} values; expected 9 or 10"
    )


class _RasterizeGaussians(torch.autograd.Function):
    """Autograd bridge between PyTorch and the ``_C`` rasterizer extension.

    ``forward`` calls ``_C.rasterize_gaussians``; ``backward`` calls
    ``_C.rasterize_gaussians_backward`` with the tensors saved in ``forward``.
    """

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        deformation,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Run the forward rasterizer.

        ``means2D`` is not passed to the C++ forward call. It appears to be an
        input only so that the ``grad_means2D`` produced by the backward kernel
        can be routed back to it.

        Returns:
            ``(color, radii, depth, acc, velocity)``.
        """
        # Restructure arguments the way that the C++ lib expects them.
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            deformation,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
        )

        # Invoke C++/CUDA rasterizer (debug and non-debug share one unpacking).
        outputs = _call_rasterizer(
            _C.rasterize_gaussians,
            args,
            raster_settings.debug,
            "snapshot_fw",
            "\nAn error occured in forward. Please forward snapshot_fw_bad.dump for debugging.",
        )
        (num_rendered, color, depth, acc, velocity, radii,
         geomBuffer, binningBuffer, imgBuffer) = _unpack_forward_outputs(outputs)

        # Keep relevant tensors for backward.
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(
            colors_precomp, means3D, scales, rotations, deformation,
            cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer,
        )
        return color, radii, depth, acc, velocity

    @staticmethod
    def backward(ctx, grad_out_color, grad_radii, grad_depth, grad_acc, grad_velocity):
        """Run the backward rasterizer.

        ``grad_radii`` and ``grad_depth`` are accepted to match the five
        forward outputs, but they are not passed to the C++ backward call.

        Returns:
            One gradient per ``forward`` input, in ``forward``'s argument order.
            The entry for ``raster_settings`` is ``None``.
        """
        # Restore necessary values from context.
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        (colors_precomp, means3D, scales, rotations, deformation, cov3Ds_precomp,
         radii, sh, geomBuffer, binningBuffer, imgBuffer) = ctx.saved_tensors

        # Restructure args as the C++ method expects them.
        args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            scales,
            rotations,
            deformation,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_out_color,
            grad_acc,
            grad_velocity,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geomBuffer,
            num_rendered,
            binningBuffer,
            imgBuffer,
            raster_settings.debug,
        )

        # Compute gradients for relevant tensors by invoking the backward method.
        (grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D,
         grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations,
         grad_deformation) = _call_rasterizer(
            _C.rasterize_gaussians_backward,
            args,
            raster_settings.debug,
            "snapshot_bw",
            "\nAn error occured in backward. Writing snapshot_bw_bad.dump for debugging.\n",
        )

        # Order must match forward()'s inputs; raster_settings has no gradient.
        return (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_deformation,
            grad_cov3Ds_precomp,
            None,
        )
|
|
|
|
class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-view settings consumed by the rasterizer.

    Every field is read by ``_RasterizeGaussians`` and forwarded to the C++
    extension. Field order is part of the tuple interface (positional
    construction and unpacking), so it must not change.
    """

    # Output image size, forwarded to the forward kernel.
    image_height: int
    image_width: int
    # Presumably tan(fov/2) per axis; TODO confirm against the camera code.
    tanfovx: float
    tanfovy: float
    # Passed as the first argument of both the forward and backward C++ calls.
    bg: torch.Tensor
    # Passed alongside scales/covariances to the C++ calls.
    scale_modifier: float
    # Camera matrices; also used by GaussianRasterizer.markVisible.
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor
    # Degree used by the kernel for the ``sh`` coefficients.
    sh_degree: int
    # Camera position, forwarded to forward and backward calls.
    campos: torch.Tensor
    # Forwarded to the forward kernel only.
    prefiltered: bool
    # When True, kernel inputs are copied to CPU and saved with torch.save.
    debug: bool
|
|
|
|
class GaussianRasterizer(nn.Module):
    """``nn.Module`` front end for the Gaussian rasterizer.

    Stores a settings object (e.g. ``GaussianRasterizationSettings``) and, on
    ``forward``, validates which optional Gaussian attributes were supplied
    before dispatching to ``rasterize_gaussians``.
    """

    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return ``_C.mark_visible(positions, viewmatrix, projmatrix)`` without autograd.

        Per the original author, the result is a boolean visibility mask
        obtained by frustum culling against the current camera.
        """
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix)

        return visible

    def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, deformation = None ,cov3D_precomp = None):
        """Validate the inputs and rasterize.

        Exactly one of ``shs`` / ``colors_precomp`` must be given. Exactly one
        of the (``scales``, ``rotations``) pair / ``cov3D_precomp`` must be
        given. Omitted optional tensors are replaced by empty tensors before
        the call.

        Raises:
            ValueError: if either "exactly one of" rule is violated. This is a
                subclass of ``Exception``, so existing ``except Exception``
                handlers still catch it.
        """
        raster_settings = self.raster_settings

        # Exactly one colour source: SH coefficients or precomputed colours.
        if (shs is None) == (colors_precomp is None):
            raise ValueError('Please provide excatly one of either SHs or precomputed colors!')

        # Exactly one geometry source: a scale/rotation pair or precomputed covariances.
        pair_incomplete = scales is None or rotations is None
        pair_partially_given = scales is not None or rotations is not None
        has_cov = cov3D_precomp is not None
        if (pair_incomplete and not has_cov) or (pair_partially_given and has_cov):
            raise ValueError('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        def _or_empty(tensor):
            # The extension receives an empty tensor in place of an omitted input.
            return torch.Tensor([]) if tensor is None else tensor

        shs = _or_empty(shs)
        colors_precomp = _or_empty(colors_precomp)
        scales = _or_empty(scales)
        rotations = _or_empty(rotations)
        cov3D_precomp = _or_empty(cov3D_precomp)
        # NOTE(review): ``deformation`` is forwarded as-is even when None, unlike
        # the inputs above. Confirm whether the C++ side accepts None here.

        # Invoke C++/CUDA rasterization routine.
        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            deformation,
            cov3D_precomp,
            raster_settings,
        )
|
|
|