diff --git a/README.md b/README.md index 6b5d39c..f937246 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,18 @@ For training synthetic scenes such as `bouncingballs`, run python train.py -s data/dnerf/bouncingballs --port 6017 --expname "dnerf/bouncingballs" --configs arguments/dnerf/bouncingballs.py ``` You can customize your training config through the config files. + +# Checkpoint +Also, you can training your model with checkpoint. +```python +python train.py -s data/dnerf/bouncingballs --port 6017 --expname "dnerf/bouncingballs" --configs arguments/dnerf/bouncingballs.py --checkpoint_iterations 200 # change it. +``` +Then load checkpoint with: +```python +python train.py -s data/dnerf/bouncingballs --port 6017 --expname "dnerf/bouncingballs" --configs arguments/dnerf/bouncingballs.py --start_checkpoint "output/dnerf/bouncingballs/chkpnt_coarse_200.pth" +# finestage: --start_checkpoint "output/dnerf/bouncingballs/chkpnt_fine_200.pth" +``` + ## Rendering Run the following script to render the images. diff --git a/submodules/depth-diff-gaussian-rasterization b/submodules/depth-diff-gaussian-rasterization index 2eb32ea..e495066 160000 --- a/submodules/depth-diff-gaussian-rasterization +++ b/submodules/depth-diff-gaussian-rasterization @@ -1 +1 @@ -Subproject commit 2eb32ea251d3b339dab3af8b6fd78d7dec3caf8e +Subproject commit e49506654e8e11ed8a62d22bcb693e943fdecacf diff --git a/train.py b/train.py index d2c0384..13f9e51 100644 --- a/train.py +++ b/train.py @@ -45,9 +45,15 @@ def scene_reconstruction(dataset, opt, hyper, pipe, testing_iterations, saving_i gaussians.training_setup(opt) if checkpoint: - breakpoint() - (model_params, first_iter) = torch.load(checkpoint) - gaussians.restore(model_params, opt) + # breakpoint() + if stage == "coarse" and stage not in checkpoint: + print("start from fine stage, skip coarse stage.") + # process is in the coarse stage, but start from fine stage + return + if stage in checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") @@ -274,7 +280,7 @@ def scene_reconstruction(dataset, opt, hyper, pipe, testing_iterations, saving_i if (iteration in checkpoint_iterations): print("\n[ITER {}] Saving Checkpoint".format(iteration)) - torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") + torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" +f"_{stage}_" + str(iteration) + ".pth") def training(dataset, hyper, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, expname): # first_iter = 0 tb_writer = prepare_output_and_logger(expname)