4DGaussians/scene/multipleview_dataset.py

import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from utils.graphics_utils import focal2fov
from scene.colmap_loader import qvec2rotmat
from scene.dataset_readers import CameraInfo
from scene.neural_3D_dataset_NDC import get_spiral
from torchvision import transforms as T


class multipleview_dataset(Dataset):
    def __init__(
        self,
        cam_extrinsics,
        cam_intrinsics,
        cam_folder,
        split
    ):
        # COLMAP camera ids are 1-based; camera 1's intrinsics are assumed to be
        # shared by every view, and its fx is reused for fy.
        self.focal = [cam_intrinsics[1].params[0], cam_intrinsics[1].params[0]]
        height = cam_intrinsics[1].height
        width = cam_intrinsics[1].width
        self.FovY = focal2fov(self.focal[0], height)
        self.FovX = focal2fov(self.focal[0], width)
        self.transform = T.ToTensor()
        self.image_paths, self.image_poses, self.image_times = self.load_images_path(
            cam_folder, cam_extrinsics, cam_intrinsics, split
        )
        if split == "test":
            self.video_cam_infos = self.get_video_cam_infos(cam_folder)
    def load_images_path(self, cam_folder, cam_extrinsics, cam_intrinsics, split):
        # Number of frames per camera, taken from the first camera folder.
        image_length = len(os.listdir(os.path.join(cam_folder, "cam01")))
        image_paths = []
        image_poses = []
        image_times = []
        for idx, key in enumerate(cam_extrinsics):
            extr = cam_extrinsics[key]
            # Rotation from the COLMAP quaternion, transposed to match the
            # convention used by the other dataset readers; tvec is kept as-is.
            R = np.transpose(qvec2rotmat(extr.qvec))
            T = np.array(extr.tvec)
            # Camera number parsed from the COLMAP image name by stripping a
            # five-character prefix and the four-character extension.
            number = os.path.basename(extr.name)[5:-4]
            images_folder = os.path.join(cam_folder, "cam" + number.zfill(2))
            image_range = range(image_length)
            if split == "test":
                # The test split keeps only three frames per camera: the first
                # one and the frames at 1/3 and 2/3 of the sequence.
                image_range = [
                    image_range[0],
                    image_range[int(image_length / 3)],
                    image_range[int(image_length * 2 / 3)],
                ]
            for i in image_range:
                num = i + 1
                image_path = os.path.join(images_folder, "frame_" + str(num).zfill(5) + ".jpg")
                image_paths.append(image_path)
                image_poses.append((R, T))
                # Frame index normalized to a time value in [0, 1).
                image_times.append(float(i / image_length))
        return image_paths, image_poses, image_times
    def get_video_cam_infos(self, datadir):
        # Build a spiral of virtual cameras for novel-view video rendering
        # from the LLFF-style poses_bounds file.
        poses_arr = np.load(os.path.join(datadir, "poses_bounds_multipleview.npy"))
        poses = poses_arr[:, :-2].reshape([-1, 3, 5])  # (N_cams, 3, 5)
        near_fars = poses_arr[:, -2:]
        # Reorder the rotation columns to the convention expected by get_spiral.
        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
        N_views = 300
        val_poses = get_spiral(poses, near_fars, N_views=N_views)

        cameras = []
        len_poses = len(val_poses)
        times = [i / len_poses for i in range(len_poses)]
        # The first training image is reused as a placeholder so that every
        # CameraInfo carries a tensor with the correct height and width.
        image = Image.open(self.image_paths[0])
        image = self.transform(image)
        for idx, p in enumerate(val_poses):
            image_path = None
            image_name = f"{idx}"
            time = times[idx]
            pose = np.eye(4)
            pose[:3, :] = p[:3, :]
            # Flip the sign of the second and third rotation columns and derive
            # the translation, matching the R/T convention used in CameraInfo.
            R = pose[:3, :3]
            R = -R
            R[:, 0] = -R[:, 0]
            T = -pose[:3, 3].dot(R)
            FovX = self.FovX
            FovY = self.FovY
            cameras.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                                      image_path=image_path, image_name=image_name,
                                      width=image.shape[2], height=image.shape[1],
                                      time=time, mask=None))
        return cameras
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        # Images are loaded lazily; each item is (image tensor, (R, T) pose, time).
        img = Image.open(self.image_paths[index])
        img = self.transform(img)
        return img, self.image_poses[index], self.image_times[index]

    def load_pose(self, index):
        return self.image_poses[index]
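

# --- Minimal usage sketch (illustration only, not part of the original file). ---
# It assumes the standard COLMAP readers in scene.colmap_loader and a data folder
# laid out as <datadir>/cam01/frame_00001.jpg, cam02/..., alongside the COLMAP
# model and poses_bounds_multipleview.npy produced by the multi-view preprocessing;
# the paths below are hypothetical and should be adjusted to the actual layout.
if __name__ == "__main__":
    from scene.colmap_loader import read_extrinsics_binary, read_intrinsics_binary

    datadir = "data/multipleview/example_scene"  # hypothetical path
    cam_extrinsics = read_extrinsics_binary(os.path.join(datadir, "sparse_", "images.bin"))
    cam_intrinsics = read_intrinsics_binary(os.path.join(datadir, "sparse_", "cameras.bin"))
    train_set = multipleview_dataset(cam_extrinsics, cam_intrinsics, datadir, split="train")
    img, (R, T), t = train_set[0]  # image tensor (C, H, W), pose tuple, normalized time
    print(len(train_set), img.shape, t)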