95 lines
3.6 KiB
Python
95 lines
3.6 KiB
Python
import os
|
|
import numpy as np
|
|
from torch.utils.data import Dataset
|
|
from PIL import Image
|
|
from utils.graphics_utils import focal2fov
|
|
from scene.colmap_loader import qvec2rotmat
|
|
from scene.dataset_readers import CameraInfo
|
|
from scene.neural_3D_dataset_NDC import get_spiral
|
|
from torchvision import transforms as T
|
|
|
|
|
|
class multipleview_dataset(Dataset):
|
|
def __init__(
|
|
self,
|
|
cam_extrinsics,
|
|
cam_intrinsics,
|
|
cam_folder,
|
|
split
|
|
):
|
|
self.focal = [cam_intrinsics[1].params[0], cam_intrinsics[1].params[0]]
|
|
height=cam_intrinsics[1].height
|
|
width=cam_intrinsics[1].width
|
|
self.FovY = focal2fov(self.focal[0], height)
|
|
self.FovX = focal2fov(self.focal[0], width)
|
|
self.transform = T.ToTensor()
|
|
self.image_paths, self.image_poses, self.image_times= self.load_images_path(cam_folder, cam_extrinsics,cam_intrinsics,split)
|
|
if split=="test":
|
|
self.video_cam_infos=self.get_video_cam_infos(cam_folder)
|
|
|
|
|
|
def load_images_path(self, cam_folder, cam_extrinsics,cam_intrinsics,split):
|
|
image_length = len(os.listdir(os.path.join(cam_folder,"cam01")))
|
|
#len_cam=len(cam_extrinsics)
|
|
image_paths=[]
|
|
image_poses=[]
|
|
image_times=[]
|
|
for idx, key in enumerate(cam_extrinsics):
|
|
extr = cam_extrinsics[key]
|
|
R = np.transpose(qvec2rotmat(extr.qvec))
|
|
T = np.array(extr.tvec)
|
|
|
|
number = os.path.basename(extr.name)[5:-4]
|
|
images_folder=os.path.join(cam_folder,"cam"+number.zfill(2))
|
|
|
|
image_range=range(image_length)
|
|
if split=="test":
|
|
image_range = [image_range[0],image_range[int(image_length/3)],image_range[int(image_length*2/3)]]
|
|
|
|
for i in image_range:
|
|
num=i+1
|
|
image_path=os.path.join(images_folder,"frame_"+str(num).zfill(5)+".jpg")
|
|
image_paths.append(image_path)
|
|
image_poses.append((R,T))
|
|
image_times.append(float(i/image_length))
|
|
|
|
return image_paths, image_poses,image_times
|
|
|
|
def get_video_cam_infos(self,datadir):
|
|
poses_arr = np.load(os.path.join(datadir, "poses_bounds_multipleview.npy"))
|
|
poses = poses_arr[:, :-2].reshape([-1, 3, 5]) # (N_cams, 3, 5)
|
|
near_fars = poses_arr[:, -2:]
|
|
poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
|
|
N_views = 300
|
|
val_poses = get_spiral(poses, near_fars, N_views=N_views)
|
|
|
|
cameras = []
|
|
len_poses = len(val_poses)
|
|
times = [i/len_poses for i in range(len_poses)]
|
|
image = Image.open(self.image_paths[0])
|
|
image = self.transform(image)
|
|
|
|
for idx, p in enumerate(val_poses):
|
|
image_path = None
|
|
image_name = f"{idx}"
|
|
time = times[idx]
|
|
pose = np.eye(4)
|
|
pose[:3,:] = p[:3,:]
|
|
R = pose[:3,:3]
|
|
R = - R
|
|
R[:,0] = -R[:,0]
|
|
T = -pose[:3,3].dot(R)
|
|
FovX = self.FovX
|
|
FovY = self.FovY
|
|
cameras.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
|
image_path=image_path, image_name=image_name, width=image.shape[2], height=image.shape[1],
|
|
time = time, mask=None))
|
|
return cameras
|
|
def __len__(self):
|
|
return len(self.image_paths)
|
|
def __getitem__(self, index):
|
|
img = Image.open(self.image_paths[index])
|
|
img = self.transform(img)
|
|
return img, self.image_poses[index], self.image_times[index]
|
|
def load_pose(self,index):
|
|
return self.image_poses[index] |