import concurrent.futures import gc import glob import os import cv2 import numpy as np import torch from PIL import Image from torch.utils.data import Dataset from torchvision import transforms as T from tqdm import tqdm def normalize(v): """Normalize a vector.""" return v / np.linalg.norm(v) def average_poses(poses): """ Calculate the average pose, which is then used to center all poses using @center_poses. Its computation is as follows: 1. Compute the center: the average of pose centers. 2. Compute the z axis: the normalized average z axis. 3. Compute axis y': the average y axis. 4. Compute x' = y' cross product z, then normalize it as the x axis. 5. Compute the y axis: z cross product x. Note that at step 3, we cannot directly use y' as y axis since it's not necessarily orthogonal to z axis. We need to pass from x to y. Inputs: poses: (N_images, 3, 4) Outputs: pose_avg: (3, 4) the average pose """ # 1. Compute the center center = poses[..., 3].mean(0) # (3) # 2. Compute the z axis z = normalize(poses[..., 2].mean(0)) # (3) # 3. Compute axis y' (no need to normalize as it's not the final output) y_ = poses[..., 1].mean(0) # (3) # 4. Compute the x axis x = normalize(np.cross(z, y_)) # (3) # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) y = np.cross(x, z) # (3) pose_avg = np.stack([x, y, z, center], 1) # (3, 4) return pose_avg def center_poses(poses, blender2opencv): """ Center the poses so that we can use NDC. See https://github.com/bmild/nerf/issues/34 Inputs: poses: (N_images, 3, 4) Outputs: poses_centered: (N_images, 3, 4) the centered poses pose_avg: (3, 4) the average pose """ poses = poses @ blender2opencv pose_avg = average_poses(poses) # (3, 4) pose_avg_homo = np.eye(4) pose_avg_homo[ :3 ] = pose_avg # convert to homogeneous coordinate for faster computation pose_avg_homo = pose_avg_homo # by simply adding 0, 0, 0, 1 as the last row last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) poses_homo = np.concatenate( [poses, last_row], 1 ) # (N_images, 4, 4) homogeneous coordinate poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) # poses_centered = poses_centered @ blender2opencv poses_centered = poses_centered[:, :3] # (N_images, 3, 4) return poses_centered, pose_avg_homo def viewmatrix(z, up, pos): vec2 = normalize(z) vec1_avg = up vec0 = normalize(np.cross(vec1_avg, vec2)) vec1 = normalize(np.cross(vec2, vec0)) m = np.eye(4) m[:3] = np.stack([-vec0, vec1, vec2, pos], 1) return m def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120): render_poses = [] rads = np.array(list(rads) + [1.0]) for theta in np.linspace(0.0, 2.0 * np.pi * N_rots, N + 1)[:-1]: c = np.dot( c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) * rads, ) z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) render_poses.append(viewmatrix(z, up, c)) return render_poses def process_video(video_data_save, video_path, img_wh, downsample, transform): """ Load video_path data to video_data_save tensor. """ video_frames = cv2.VideoCapture(video_path) count = 0 video_images_path = video_path.split('.')[0] image_path = os.path.join(video_images_path,"images") if not os.path.exists(image_path): os.makedirs(image_path) while video_frames.isOpened(): ret, video_frame = video_frames.read() if ret: video_frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB) video_frame = Image.fromarray(video_frame) if downsample != 1.0: img = video_frame.resize(img_wh, Image.LANCZOS) img.save(os.path.join(image_path,"%04d.png"%count)) img = transform(img) video_data_save[count] = img.permute(1,2,0) count += 1 else: break else: images_path = os.listdir(image_path) images_path.sort() for path in images_path: img = Image.open(os.path.join(image_path,path)) if downsample != 1.0: img = img.resize(img_wh, Image.LANCZOS) img = transform(img) video_data_save[count] = img.permute(1,2,0) count += 1 video_frames.release() print(f"Video {video_path} processed.") return None # define a function to process all videos def process_videos(videos, skip_index, img_wh, downsample, transform, num_workers=1): """ A multi-threaded function to load all videos fastly and memory-efficiently. To save memory, we pre-allocate a tensor to store all the images and spawn multi-threads to load the images into this tensor. """ all_imgs = torch.zeros(len(videos) - 1, 300, img_wh[-1] , img_wh[-2], 3) with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: # start a thread for each video current_index = 0 futures = [] for index, video_path in enumerate(videos): # skip the video with skip_index (eval video) if index == skip_index: continue else: future = executor.submit( process_video, all_imgs[current_index], video_path, img_wh, downsample, transform, ) futures.append(future) current_index += 1 return all_imgs def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120): """ Generate a set of poses using NeRF's spiral camera trajectory as validation poses. """ # center pose c2w = average_poses(c2ws_all) # Get average pose up = normalize(c2ws_all[:, :3, 1].sum(0)) # Find a reasonable "focus depth" for this dataset dt = 0.75 close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0 focal = 1.0 / ((1.0 - dt) / close_depth + dt / inf_depth) # Get radii for spiral path zdelta = near_fars.min() * 0.2 tt = c2ws_all[:, :3, 3] rads = np.percentile(np.abs(tt), 90, 0) * rads_scale render_poses = render_path_spiral( c2w, up, rads, focal, zdelta, zrate=0.5, N=N_views ) return np.stack(render_poses) class Neural3D_NDC_Dataset(Dataset): def __init__( self, datadir, split="train", downsample=1.0, is_stack=True, cal_fine_bbox=False, N_vis=-1, time_scale=1.0, scene_bbox_min=[-1.0, -1.0, -1.0], scene_bbox_max=[1.0, 1.0, 1.0], N_random_pose=1000, bd_factor=0.75, eval_step=1, eval_index=0, sphere_scale=1.0, ): self.img_wh = ( int(1352 / downsample), int(1014 / downsample), ) # According to the neural 3D paper, the default resolution is 1024x768 self.root_dir = datadir self.split = split self.downsample = 2704 / self.img_wh[0] self.is_stack = is_stack self.N_vis = N_vis self.time_scale = time_scale self.scene_bbox = torch.tensor([scene_bbox_min, scene_bbox_max]) self.world_bound_scale = 1.1 self.bd_factor = bd_factor self.eval_step = eval_step self.eval_index = eval_index self.blender2opencv = np.eye(4) self.transform = T.ToTensor() self.near = 0.0 self.far = 1.0 self.near_far = [self.near, self.far] # NDC near far is [0, 1.0] self.white_bg = False self.ndc_ray = True self.depth_data = False self.load_meta() print(f"meta data loaded, total image:{len(self)}") def load_meta(self): """ Load meta data from the dataset. """ # Read poses and video file paths. poses_arr = np.load(os.path.join(self.root_dir, "poses_bounds.npy")) poses = poses_arr[:, :-2].reshape([-1, 3, 5]) # (N_cams, 3, 5) self.near_fars = poses_arr[:, -2:] videos = glob.glob(os.path.join(self.root_dir, "cam*")) videos = sorted(videos) assert len(videos) == poses_arr.shape[0] H, W, focal = poses[0, :, -1] focal = focal / self.downsample self.focal = [focal, focal] poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) poses, _ = center_poses( poses, self.blender2opencv ) # Re-center poses so that the average is near the center. near_original = self.near_fars.min() scale_factor = near_original * 0.75 self.near_fars /= ( scale_factor # rescale nearest plane so that it is at z = 4/3. ) poses[..., 3] /= scale_factor # Sample N_views poses for validation - NeRF-like camera trajectory. N_views = 120 self.val_poses = get_spiral(poses, self.near_fars, N_views=N_views) # self.val_poses = self.directions W, H = self.img_wh poses_i_train = [] for i in range(len(poses)): if i != self.eval_index: poses_i_train.append(i) self.poses = poses[poses_i_train] self.poses_all = poses self.image_paths, self.image_poses, self.image_times, N_cam, N_time = self.load_images_path(videos, self.split) self.cam_number = N_cam self.time_number = N_time def get_val_pose(self): render_poses = self.val_poses render_times = torch.linspace(0.0, 1.0, render_poses.shape[0]) * 2.0 - 1.0 return render_poses, self.time_scale * render_times def load_images_path(self,videos,split): image_paths = [] image_poses = [] image_times = [] N_cams = 0 N_time = 0 countss = 300 for index, video_path in enumerate(videos): if index == self.eval_index: if split =="train": continue else: if split == "test": continue N_cams +=1 count = 0 video_images_path = video_path.split('.')[0] image_path = os.path.join(video_images_path,"images") video_frames = cv2.VideoCapture(video_path) if not os.path.exists(image_path): print(f"no images saved in {image_path}, extract images from video.") os.makedirs(image_path) this_count = 0 while video_frames.isOpened(): ret, video_frame = video_frames.read() if this_count >= countss:break if ret: video_frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB) video_frame = Image.fromarray(video_frame) if self.downsample != 1.0: img = video_frame.resize(self.img_wh, Image.LANCZOS) img.save(os.path.join(image_path,"%04d.png"%count)) # img = transform(img) count += 1 this_count+=1 else: break images_path = os.listdir(image_path) images_path.sort() this_count = 0 for idx, path in enumerate(images_path): if this_count >=countss:break image_paths.append(os.path.join(image_path,path)) pose = np.array(self.poses_all[index]) R = pose[:3,:3] R = -R R[:,0] = -R[:,0] T = -pose[:3,3].dot(R) image_times.append(idx/countss) image_poses.append((R,T)) # if self.downsample != 1.0: # img = video_frame.resize(self.img_wh, Image.LANCZOS) # img.save(os.path.join(image_path,"%04d.png"%count)) this_count+=1 N_time = len(images_path) # video_data_save[count] = img.permute(1,2,0) # count += 1 return image_paths, image_poses, image_times, N_cams, N_time def __len__(self): return len(self.image_paths) def __getitem__(self,index): img = Image.open(self.image_paths[index]) img = img.resize(self.img_wh, Image.LANCZOS) img = self.transform(img) return img, self.image_poses[index], self.image_times[index] def load_pose(self,index): return self.image_poses[index]