4DGaussians/scene/__init__.py
2023-10-13 17:06:11 +08:00

130 lines
6.5 KiB
Python

#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
import random
import json
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from scene.dataset import FourDGSdataset
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
from torch.utils.data import Dataset
class Scene:
    """Tie a dataset's cameras to the Gaussian model that is trained on it.

    Constructing a ``Scene`` detects the dataset layout under
    ``args.source_path``, builds the train/test/video camera collections,
    passes the input point cloud's per-axis max/min to the deformation grid
    via ``set_aabb``, and then either restores the Gaussians from a saved
    iteration or initialises them from the input point cloud.
    """

    gaussians : GaussianModel

    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], load_coarse=False):
        """Load the scene description and prepare ``gaussians``.

        :param args: Model parameters. ``model_path`` and ``source_path`` are
            read here (plus ``images``, ``white_background`` and ``eval``,
            depending on the detected dataset type); the object is also
            forwarded to the camera/dataset constructors.
        :param gaussians: Model to populate; stored as ``self.gaussians``.
        :param load_iteration: Falsy to initialise from the input point cloud.
            ``-1`` restores the iteration reported by ``searchForMaxIteration``
            for ``<model_path>/point_cloud``; any other truthy value restores
            that iteration.
        :param shuffle: Not read by this constructor (camera shuffling is
            disabled); retained so existing call sites keep working.
        :param resolution_scales: Not read (and never mutated) by this
            constructor; retained so existing call sites keep working.
        :param load_coarse: Not read by this constructor (coarse-stage
            restoring is disabled); retained so existing call sites keep
            working.
        :raises AssertionError: If none of the known dataset marker files is
            found under ``args.source_path``.
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        # These per-scale dictionaries are not filled in by this constructor;
        # the attributes are kept so code that reads them does not break.
        self.train_cameras = {}
        self.test_cameras = {}
        self.video_cameras = {}

        scene_info = self._load_scene_info(args)
        self.maxtime = scene_info.maxtime
        # NOTE: copying the input point cloud to <model_path>/input.ply,
        # dumping cameras.json and shuffling the camera lists are all
        # disabled in this version of the constructor.
        self.cameras_extent = scene_info.nerf_normalization["radius"]

        print("Loading Training Cameras")
        self.train_camera = FourDGSdataset(scene_info.train_cameras, args)
        print("Loading Test Cameras")
        self.test_camera = FourDGSdataset(scene_info.test_cameras, args)
        print("Loading Video Cameras")
        self.video_camera = cameraList_from_camInfos(scene_info.video_cameras, -1, args)

        # Per-axis extrema of the input points, handed to the deformation grid.
        points = scene_info.point_cloud.points
        xyz_max = points.max(axis=0)
        xyz_min = points.min(axis=0)
        self.gaussians._deformation.deformation_net.grid.set_aabb(xyz_max, xyz_min)

        if self.loaded_iter:
            checkpoint_dir = os.path.join(self.model_path,
                                          "point_cloud",
                                          "iteration_" + str(self.loaded_iter))
            self.gaussians.load_ply(os.path.join(checkpoint_dir, "point_cloud.ply"))
            self.gaussians.load_model(checkpoint_dir)
        else:
            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, self.maxtime)

    @staticmethod
    def _load_scene_info(args):
        """Select a dataset reader from marker files under ``args.source_path``.

        The markers are probed in a fixed order and the first match wins:
        ``sparse`` -> "Colmap", ``transforms_train.json`` -> "Blender",
        ``poses_bounds.npy`` -> "dynerf", ``dataset.json`` -> "nerfies".

        :param args: Object providing ``source_path`` and the reader-specific
            options (``images``, ``white_background``, ``eval``).
        :return: Whatever the selected ``sceneLoadTypeCallbacks`` entry returns.
        :raises AssertionError: If no marker is present.
        """
        if os.path.exists(os.path.join(args.source_path, "sparse")):
            return sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
        if os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            return sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
        if os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")):
            return sceneLoadTypeCallbacks["dynerf"](args.source_path, args.white_background, args.eval)
        if os.path.exists(os.path.join(args.source_path, "dataset.json")):
            return sceneLoadTypeCallbacks["nerfies"](args.source_path, False, args.eval)
        # Raised explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let execution continue with
        # no scene loaded. The exception type and message are unchanged.
        raise AssertionError("Could not recognize scene type!")

    def save(self, iteration, stage):
        """Write the Gaussians and the deformation model for ``iteration``.

        :param iteration: Iteration number used in the output directory name.
        :param stage: ``"coarse"`` writes to
            ``<model_path>/point_cloud/coarse_iteration_<iteration>``; any
            other value writes to
            ``<model_path>/point_cloud/iteration_<iteration>``.
        """
        if stage == "coarse":
            folder_name = "coarse_iteration_{}".format(iteration)
        else:
            folder_name = "iteration_{}".format(iteration)
        # Joined component-wise, matching how __init__ builds the load path.
        point_cloud_path = os.path.join(self.model_path, "point_cloud", folder_name)
        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
        self.gaussians.save_deformation(point_cloud_path)

    def getTrainCameras(self, scale=1.0):
        """Return the training camera dataset; ``scale`` is accepted but ignored."""
        return self.train_camera

    def getTestCameras(self, scale=1.0):
        """Return the test camera dataset; ``scale`` is accepted but ignored."""
        return self.test_camera

    def getVideoCameras(self, scale=1.0):
        """Return the video camera list; ``scale`` is accepted but ignored."""
        return self.video_camera