Supplement: data reading for multipleview scenes

Geralt_of_Rivia 2024-04-12 08:28:07 +08:00 committed by Hongyuan-Tao
parent 504d25f8eb
commit a59e346db9
7 changed files with 279 additions and 6 deletions

View File

@@ -57,7 +57,9 @@ In our environment, we use pytorch=1.13.1+cu116.
The dataset provided in [D-NeRF](https://github.com/albertpumarola/D-NeRF) is used. You can download the dataset from [dropbox](https://www.dropbox.com/s/0bf6fl0ye2vz3vr/data.zip?dl=0).
**For real dynamic scenes:**
The dataset provided in [HyperNeRF](https://github.com/google/hypernerf) is used. You can download scenes from [Hypernerf Dataset](https://github.com/google/hypernerf/releases/tag/v0.1) and organize them as [Nerfies](https://github.com/google/nerfies#datasets).
Meanwhile, the [Plenoptic Dataset](https://github.com/facebookresearch/Neural_3D_Video) can be downloaded from its official website. To save memory, you should extract the frames of each video and then organize your dataset as follows.
```
├── data
@@ -86,6 +88,50 @@ The dataset provided in [HyperNeRF](https://github.com/google/hypernerf) is used
| ├── ...
```
**For multipleview scenes:**
If you want to train your own multipleview-scene dataset, organize it as follows:
```
├── data
│   ├── multipleview
│   │   ├── (your dataset name)
│   │   │   ├── cam01
│   │   │   │   ├── frame_00001.jpg
│   │   │   │   ├── frame_00002.jpg
│   │   │   │   ├── ...
│   │   │   ├── cam02
│   │   │   │   ├── frame_00001.jpg
│   │   │   │   ├── frame_00002.jpg
│   │   │   │   ├── ...
│   │   │   ├── ...
```
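Before running COLMAP it can be worth sanity-checking this layout. A minimal sketch (the `check_layout` helper and the dataset name are illustrative, not part of the repo); like the loader, it assumes every camera holds the same number of frames as `cam01`:
```python
import os

def check_layout(dataset_dir):
    # Illustrative helper: verify all camXX folders hold equally many frames,
    # since the loader takes cam01's frame count as the length for every camera.
    cams = sorted(d for d in os.listdir(dataset_dir) if d.startswith("cam"))
    counts = {c: len(os.listdir(os.path.join(dataset_dir, c))) for c in cams}
    assert len(set(counts.values())) == 1, f"unequal frame counts: {counts}"
    print(f"{len(cams)} cameras, {counts[cams[0]]} frames each")

check_layout("data/multipleview/your_dataset_name")  # path is illustrative
```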
After that, you can use the `multipleviewprogress.sh` script we provide to generate the pose and point-cloud data. Use it as follows:
```bash
bash multipleviewprogress.sh (your dataset name)
```
You need to ensure that the data folder is organized as follows after running `multipleviewprogress.sh`:
```
├── data
│   ├── multipleview
│   │   ├── (your dataset name)
│   │   │   ├── cam01
│   │   │   │   ├── frame_00001.jpg
│   │   │   │   ├── frame_00002.jpg
│   │   │   │   ├── ...
│   │   │   ├── cam02
│   │   │   │   ├── frame_00001.jpg
│   │   │   │   ├── frame_00002.jpg
│   │   │   │   ├── ...
│   │   │   ├── ...
│   │   │   ├── sparse_
│   │   │   │   ├── cameras.bin
│   │   │   │   ├── images.bin
│   │   │   │   ├── ...
│   │   │   ├── points3D_multipleview.ply
│   │   │   ├── poses_bounds_multipleview.npy
```
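A quick way to confirm the script produced everything the loader expects is to check for the files `readMultipleViewinfos` opens (see `scene/dataset_readers.py` below); a minimal sketch, with an illustrative dataset path:
```python
import os

root = "data/multipleview/your_dataset_name"  # illustrative path
required = [
    "sparse_/cameras.bin",           # COLMAP intrinsics
    "sparse_/images.bin",            # COLMAP extrinsics
    "points3D_multipleview.ply",     # downsampled dense point cloud
    "poses_bounds_multipleview.npy", # LLFF-style poses for the spiral video path
]
missing = [p for p in required if not os.path.exists(os.path.join(root, p))]
print("layout OK" if not missing else f"missing: {missing}")
```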
## Training
For training synthetic scenes such as `bouncingballs`, run
@@ -105,17 +151,23 @@ python scripts/downsample_point.py data/dynerf/cut_roasted_beef/colmap/dense/wor
# Finally, train.
python train.py -s data/dynerf/cut_roasted_beef --port 6017 --expname "dynerf/cut_roasted_beef" --configs arguments/dynerf/cut_roasted_beef.py
```
For training hypernerf scenes such as `virg/broom`: pregenerated COLMAP point clouds are provided [here](https://drive.google.com/file/d/1fUHiSgimVjVQZ2OOzTFtz02E9EqCoWr5/view). Download them and put them into the corresponding folder to skip the first two steps; alternatively, run the commands directly.
```bash
# First, computing dense point clouds by COLMAP
bash colmap.sh data/hypernerf/virg/broom2 hypernerf
# Second, downsample the point clouds generated in the first step.
python scripts/downsample_point.py data/hypernerf/virg/broom2/colmap/dense/workspace/fused.ply data/hypernerf/virg/broom2/points3D_downsample2.ply
# Finally, train.
python train.py -s data/hypernerf/virg/broom2/ --port 6017 --expname "hypernerf/broom2" --configs arguments/hypernerf/broom2.py
```
For training multipleview scenes, you should build a configuration file named (your dataset name).py under `./arguments/multipleview`; after that, run
```bash
python train.py -s data/multipleview/(your dataset name) --port 6017 --expname "multipleview/(your dataset name)" --configs arguments/multipleview/(your dataset name).py
```
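The configuration file can mirror the example this commit adds under `arguments/multipleview/` (shown in full further below); a minimal sketch with values taken from that example:
```python
# arguments/multipleview/(your dataset name).py -- minimal sketch
ModelHiddenParams = dict(
    kplanes_config = {
        'grid_dimensions': 2,
        'input_coordinate_dim': 4,       # 4D input: x, y, z, t
        'output_coordinate_dim': 16,
        'resolution': [64, 64, 64, 150]  # the 4th entry is the temporal resolution
    },
)
OptimizationParams = dict(
    dataloader=True,
    iterations = 15000,
    batch_size=1,
)
```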
For your custom datasets, install nerfstudio and follow their [COLMAP](https://colmap.github.io/) pipeline. Install COLMAP first, then:
```bash
pip install nerfstudio
```

View File

@@ -0,0 +1,33 @@
ModelHiddenParams = dict(
    kplanes_config = {
        'grid_dimensions': 2,
        'input_coordinate_dim': 4,
        'output_coordinate_dim': 16,
        'resolution': [64, 64, 64, 150]
    },
    multires = [1,2],
    defor_depth = 0,
    net_width = 128,
    plane_tv_weight = 0.0002,
    time_smoothness_weight = 0.001,
    l1_time_planes = 0.0001,
    no_do=False,
    no_dshs=False,
    no_ds=False,
    empty_voxel=False,
    render_process=False,
    static_mlp=False
)
OptimizationParams = dict(
    dataloader=True,
    iterations = 15000,
    batch_size=1,
    coarse_iterations = 3000,
    densify_until_iter = 10_000,
    # opacity_reset_interval = 60000,
    opacity_threshold_coarse = 0.005,
    opacity_threshold_fine_init = 0.005,
    opacity_threshold_fine_after = 0.005,
    # pruning_interval = 2000
)
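When this file is passed to `train.py` via `--configs`, these two dicts are presumably merged over the matching defaults under `arguments/`, so the file only needs to list the values that differ from the defaults.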

multipleviewprogress.sh Normal file
View File

@@ -0,0 +1,27 @@
workdir=$1
# Gather the first frame of every camera into ./colmap_tmp/images for COLMAP
python scripts/extractimages.py $workdir
# Sparse reconstruction: feature extraction, matching, mapping
colmap feature_extractor --database_path ./colmap_tmp/database.db --image_path ./colmap_tmp/images --SiftExtraction.max_image_size 4096 --SiftExtraction.max_num_features 16384 --SiftExtraction.estimate_affine_shape 1 --SiftExtraction.domain_size_pooling 1
colmap exhaustive_matcher --database_path ./colmap_tmp/database.db
mkdir -p ./colmap_tmp/sparse
colmap mapper --database_path ./colmap_tmp/database.db --image_path ./colmap_tmp/images --output_path ./colmap_tmp/sparse
# Keep the sparse model (cameras.bin, images.bin, ...) next to the dataset
mkdir -p ./data/multipleview/$workdir/sparse_
cp -r ./colmap_tmp/sparse/0/* ./data/multipleview/$workdir/sparse_
# Dense reconstruction and fusion into a point cloud
mkdir -p ./colmap_tmp/dense
colmap image_undistorter --image_path ./colmap_tmp/images --input_path ./colmap_tmp/sparse/0 --output_path ./colmap_tmp/dense --output_type COLMAP
colmap patch_match_stereo --workspace_path ./colmap_tmp/dense --workspace_format COLMAP --PatchMatchStereo.geom_consistency true
colmap stereo_fusion --workspace_path ./colmap_tmp/dense --workspace_format COLMAP --input_type geometric --output_path ./colmap_tmp/dense/fused.ply
python scripts/downsample_point.py ./colmap_tmp/dense/fused.ply ./data/multipleview/$workdir/points3D_multipleview.ply
# Generate LLFF-style poses_bounds.npy from the COLMAP model
git clone https://github.com/Fyusion/LLFF.git
pip install scikit-image
python LLFF/imgs2poses.py ./colmap_tmp/
cp ./colmap_tmp/poses_bounds.npy ./data/multipleview/$workdir/poses_bounds_multipleview.npy
# Clean up temporary files
rm -rf ./colmap_tmp
rm -rf ./LLFF
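Note that every path in the script is relative (`./colmap_tmp`, `./data/multipleview/...`), so it should be run from the repository root, as in the README command above.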

View File

@@ -58,6 +58,9 @@ class Scene:
        elif os.path.exists(os.path.join(args.source_path,"train_meta.json")):
            scene_info = sceneLoadTypeCallbacks["PanopticSports"](args.source_path)
            dataset_type="PanopticSports"
        elif os.path.exists(os.path.join(args.source_path,"points3D_multipleview.ply")):
            # Multi-view scenes are recognized by the point cloud written by multipleviewprogress.sh
            scene_info = sceneLoadTypeCallbacks["MultipleView"](args.source_path)
            dataset_type="MultipleView"
        else:
            assert False, "Could not recognize scene type!"
        self.maxtime = scene_info.maxtime

View File

@@ -570,6 +570,7 @@ def readPanopticmeta(datadir, json_path):
    scene_radius = 1.1 * np.max(np.linalg.norm(cam_centers - np.mean(cam_centers, 0)[None], axis=-1))
    # breakpoint()
    return cam_infos, max_time, scene_radius

def readPanopticSportsinfos(datadir):
    train_cam_infos, max_time, scene_radius = readPanopticmeta(datadir, "train_meta.json")
    test_cam_infos,_, _ = readPanopticmeta(datadir, "test_meta.json")
@@ -599,11 +600,51 @@ def readPanopticSportsinfos(datadir):
        maxtime=max_time,
    )
    return scene_info
def readMultipleViewinfos(datadir, llffhold=8):
    # Load the COLMAP model written by multipleviewprogress.sh into sparse_/
    cameras_extrinsic_file = os.path.join(datadir, "sparse_/images.bin")
    cameras_intrinsic_file = os.path.join(datadir, "sparse_/cameras.bin")
    cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
    cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)

    from scene.multipleview_dataset import multipleview_dataset
    train_cam_infos = multipleview_dataset(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, cam_folder=datadir, split="train")
    test_cam_infos = multipleview_dataset(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, cam_folder=datadir, split="test")

    train_cam_infos_ = format_infos(train_cam_infos, "train")
    nerf_normalization = getNerfppNorm(train_cam_infos_)

    # Prefer the .ply produced by multipleviewprogress.sh; fall back to
    # converting a .bin/.txt point cloud on first use.
    ply_path = os.path.join(datadir, "points3D_multipleview.ply")
    bin_path = os.path.join(datadir, "points3D_multipleview.bin")
    txt_path = os.path.join(datadir, "points3D_multipleview.txt")
    if not os.path.exists(ply_path):
        print("Converting points3D_multipleview.bin to .ply, will happen only the first time you open the scene.")
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except Exception:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)
    try:
        pcd = fetchPly(ply_path)
    except Exception:
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           video_cameras=test_cam_infos.video_cam_infos,
                           maxtime=0,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info

sceneLoadTypeCallbacks = {
    "Colmap": readColmapSceneInfo,
    "Blender" : readNerfSyntheticInfo,
    "dynerf" : readdynerfInfo,
    "nerfies": readHyperDataInfos, # NeRFies & HyperNeRF dataset proposed by [https://github.com/google/hypernerf/releases/tag/v0.1]
    "PanopticSports" : readPanopticSportsinfos,
    "MultipleView": readMultipleViewinfos
}
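For reference, a minimal sketch of how the new entry is reached (the dataset path is illustrative); `Scene.__init__`, shown above, dispatches here whenever `points3D_multipleview.ply` exists in the source path:
```python
from scene.dataset_readers import sceneLoadTypeCallbacks

# Mirrors the dispatch in Scene.__init__: the presence of
# points3D_multipleview.ply selects the MultipleView loader.
scene_info = sceneLoadTypeCallbacks["MultipleView"]("data/multipleview/your_dataset_name")  # illustrative path
print(scene_info.maxtime, scene_info.ply_path)  # maxtime is 0 for this loader
```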

View File

@@ -0,0 +1,95 @@
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from utils.graphics_utils import focal2fov
from scene.colmap_loader import qvec2rotmat
from scene.dataset_readers import CameraInfo
from scene.neural_3D_dataset_NDC import get_spiral
from torchvision import transforms as T

class multipleview_dataset(Dataset):
    def __init__(
        self,
        cam_extrinsics,
        cam_intrinsics,
        cam_folder,
        split
    ):
        # All cameras are assumed to share the intrinsics of COLMAP camera 1.
        self.focal = [cam_intrinsics[1].params[0], cam_intrinsics[1].params[0]]
        height = cam_intrinsics[1].height
        width = cam_intrinsics[1].width
        self.FovY = focal2fov(self.focal[0], height)
        self.FovX = focal2fov(self.focal[0], width)
        self.transform = T.ToTensor()
        self.image_paths, self.image_poses, self.image_times = self.load_images_path(cam_folder, cam_extrinsics, cam_intrinsics, split)
        if split == "test":
            self.video_cam_infos = self.get_video_cam_infos(cam_folder)

    def load_images_path(self, cam_folder, cam_extrinsics, cam_intrinsics, split):
        # Every camera is assumed to hold as many frames as cam01.
        image_length = len(os.listdir(os.path.join(cam_folder, "cam01")))
        image_paths = []
        image_poses = []
        image_times = []
        for idx, key in enumerate(cam_extrinsics):
            extr = cam_extrinsics[key]
            R = np.transpose(qvec2rotmat(extr.qvec))
            T = np.array(extr.tvec)
            # extr.name is "image{N}.jpg"; N maps back to the cam{N:02d} folder.
            number = os.path.basename(extr.name)[5:-4]
            images_folder = os.path.join(cam_folder, "cam" + number.zfill(2))
            image_range = range(image_length)
            if split == "test":
                # Keep only three frames per camera for the test split.
                image_range = [image_range[0], image_range[int(image_length/3)], image_range[int(image_length*2/3)]]
            for i in image_range:
                num = i + 1
                image_path = os.path.join(images_folder, "frame_" + str(num).zfill(5) + ".jpg")
                image_paths.append(image_path)
                image_poses.append((R, T))
                # Normalized timestamp in [0, 1).
                image_times.append(float(i / image_length))
        return image_paths, image_poses, image_times

    def get_video_cam_infos(self, datadir):
        # Build a 300-frame spiral camera path for novel-view video rendering.
        poses_arr = np.load(os.path.join(datadir, "poses_bounds_multipleview.npy"))
        poses = poses_arr[:, :-2].reshape([-1, 3, 5])  # (N_cams, 3, 5)
        near_fars = poses_arr[:, -2:]
        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
        N_views = 300
        val_poses = get_spiral(poses, near_fars, N_views=N_views)
        cameras = []
        len_poses = len(val_poses)
        times = [i / len_poses for i in range(len_poses)]
        # Reuse the first training image as a placeholder for the video cameras.
        image = Image.open(self.image_paths[0])
        image = self.transform(image)
        for idx, p in enumerate(val_poses):
            image_path = None
            image_name = f"{idx}"
            time = times[idx]
            pose = np.eye(4)
            pose[:3, :] = p[:3, :]
            R = pose[:3, :3]
            R = -R
            R[:, 0] = -R[:, 0]
            T = -pose[:3, 3].dot(R)
            FovX = self.FovX
            FovY = self.FovY
            cameras.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                                      image_path=image_path, image_name=image_name, width=image.shape[2], height=image.shape[1],
                                      time=time, mask=None))
        return cameras

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img = Image.open(self.image_paths[index])
        img = self.transform(img)
        return img, self.image_poses[index], self.image_times[index]

    def load_pose(self, index):
        return self.image_poses[index]
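A minimal usage sketch mirroring `readMultipleViewinfos` above (the dataset path is illustrative):
```python
import os
from scene.colmap_loader import read_extrinsics_binary, read_intrinsics_binary
from scene.multipleview_dataset import multipleview_dataset

datadir = "data/multipleview/your_dataset_name"  # illustrative path
cam_extrinsics = read_extrinsics_binary(os.path.join(datadir, "sparse_/images.bin"))
cam_intrinsics = read_intrinsics_binary(os.path.join(datadir, "sparse_/cameras.bin"))
train_set = multipleview_dataset(cam_extrinsics=cam_extrinsics,
                                 cam_intrinsics=cam_intrinsics,
                                 cam_folder=datadir, split="train")
img, (R, T), t = train_set[0]  # image tensor (C,H,W), camera pose, normalized time
print(len(train_set), img.shape, t)
```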

scripts/extractimages.py Normal file
View File

@@ -0,0 +1,22 @@
import os
import sys
import shutil

# Copy the first frame (frame_00001.*) of every camera folder into
# ./colmap_tmp/images so COLMAP can estimate the camera poses from a
# single time step.
folder_path = sys.argv[1]
colmap_path = "./colmap_tmp"
images_path = os.path.join(colmap_path, "images")
os.makedirs(images_path, exist_ok=True)
i = 0
dir1 = os.path.join("data", "multipleview", folder_path)
for folder_name in sorted(os.listdir(dir1)):
    dir2 = os.path.join(dir1, folder_name)
    for file_name in os.listdir(dir2):
        if file_name.startswith("frame_00001"):
            i = i + 1
            src_path = os.path.join(dir2, file_name)
            # Sorted folder order keeps image{i}.jpg aligned with cam{i:02d}.
            dst_path = os.path.join(images_path, f"image{i}.jpg")
            shutil.copyfile(src_path, dst_path)
print("End")
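The copied frames are numbered `image1.jpg`, `image2.jpg`, ..., which is what `multipleview_dataset` relies on when it parses the camera index back out of `extr.name` with `os.path.basename(extr.name)[5:-4]`.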