4DGaussians/scene/neural_3D_dataset_NDC.py
import concurrent.futures
import gc
import glob
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm
def normalize(v):
"""Normalize a vector."""
return v / np.linalg.norm(v)
def average_poses(poses):
"""
Calculate the average pose, which is then used to center all poses
using @center_poses. Its computation is as follows:
1. Compute the center: the average of pose centers.
2. Compute the z axis: the normalized average z axis.
    3. Compute axis y': the average y axis.
    4. Compute the x axis: normalize(z cross product y').
    5. Compute the y axis: x cross product z.
    Note that at step 3, we cannot directly use y' as the y axis because it is
    not necessarily orthogonal to the z axis, so we derive x first and then y from it.
Inputs:
poses: (N_images, 3, 4)
Outputs:
pose_avg: (3, 4) the average pose
"""
# 1. Compute the center
center = poses[..., 3].mean(0) # (3)
# 2. Compute the z axis
z = normalize(poses[..., 2].mean(0)) # (3)
# 3. Compute axis y' (no need to normalize as it's not the final output)
y_ = poses[..., 1].mean(0) # (3)
# 4. Compute the x axis
x = normalize(np.cross(z, y_)) # (3)
# 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
y = np.cross(x, z) # (3)
pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
return pose_avg
def center_poses(poses, blender2opencv):
"""
Center the poses so that we can use NDC.
See https://github.com/bmild/nerf/issues/34
Inputs:
poses: (N_images, 3, 4)
Outputs:
poses_centered: (N_images, 3, 4) the centered poses
pose_avg: (3, 4) the average pose
"""
poses = poses @ blender2opencv
pose_avg = average_poses(poses) # (3, 4)
    # Convert to a (4, 4) homogeneous matrix by adding 0, 0, 0, 1 as the last row.
    pose_avg_homo = np.eye(4)
    pose_avg_homo[:3] = pose_avg
last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
poses_homo = np.concatenate(
[poses, last_row], 1
) # (N_images, 4, 4) homogeneous coordinate
poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
# poses_centered = poses_centered @ blender2opencv
poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
return poses_centered, pose_avg_homo
def viewmatrix(z, up, pos):
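    """Build a 4x4 camera-to-world matrix from a viewing direction z, an up hint
    and a camera position pos (LLFF-style look-at construction)."""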
vec2 = normalize(z)
vec1_avg = up
vec0 = normalize(np.cross(vec1_avg, vec2))
vec1 = normalize(np.cross(vec2, vec0))
m = np.eye(4)
m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
return m
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
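    """Generate N camera-to-world poses along a spiral around the average pose c2w,
    with each camera looking at a point `focal` units in front of it (NeRF/LLFF-style render path)."""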
render_poses = []
rads = np.array(list(rads) + [1.0])
for theta in np.linspace(0.0, 2.0 * np.pi * N_rots, N + 1)[:-1]:
c = np.dot(
c2w[:3, :4],
np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0])
* rads,
)
z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
render_poses.append(viewmatrix(z, up, c))
return render_poses
def process_video(video_data_save, video_path, img_wh, downsample, transform):
"""
    Decode the frames of video_path (or reuse cached frames on disk) and write them into the pre-allocated video_data_save tensor.
"""
video_frames = cv2.VideoCapture(video_path)
count = 0
    # Strip the file extension robustly; split('.') would truncate paths containing other dots.
    video_images_path = os.path.splitext(video_path)[0]
image_path = os.path.join(video_images_path,"images")
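    # If the per-camera "images" folder does not exist yet, decode the video and cache
    # every (optionally downsampled) frame to disk; otherwise reuse the cached PNGs.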
if not os.path.exists(image_path):
os.makedirs(image_path)
while video_frames.isOpened():
ret, video_frame = video_frames.read()
if ret:
video_frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB)
video_frame = Image.fromarray(video_frame)
                if downsample != 1.0:
                    img = video_frame.resize(img_wh, Image.LANCZOS)
                else:
                    # Keep the original frame so that img is always defined before saving.
                    img = video_frame
                img.save(os.path.join(image_path, "%04d.png" % count))
img = transform(img)
video_data_save[count] = img.permute(1,2,0)
count += 1
else:
break
else:
images_path = os.listdir(image_path)
images_path.sort()
for path in images_path:
img = Image.open(os.path.join(image_path,path))
if downsample != 1.0:
img = img.resize(img_wh, Image.LANCZOS)
img = transform(img)
video_data_save[count] = img.permute(1,2,0)
count += 1
video_frames.release()
print(f"Video {video_path} processed.")
return None
# define a function to process all videos
def process_videos(videos, skip_index, img_wh, downsample, transform, num_workers=1):
"""
    A multi-threaded function to load all videos quickly and memory-efficiently.
    To save memory, we pre-allocate a tensor to store all the images and spawn multiple threads to load the images into this tensor.
"""
all_imgs = torch.zeros(len(videos) - 1, 300, img_wh[-1] , img_wh[-2], 3)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
# start a thread for each video
current_index = 0
futures = []
for index, video_path in enumerate(videos):
# skip the video with skip_index (eval video)
if index == skip_index:
continue
else:
future = executor.submit(
process_video,
all_imgs[current_index],
video_path,
img_wh,
downsample,
transform,
)
futures.append(future)
current_index += 1
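    # Leaving the `with` block waits for every worker thread to finish,
    # so all_imgs is fully populated before it is returned.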
return all_imgs
def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
"""
Generate a set of poses using NeRF's spiral camera trajectory as validation poses.
"""
    # Center pose: the average over all camera-to-world matrices.
    c2w = average_poses(c2ws_all)
    # Average up vector across all cameras.
    up = normalize(c2ws_all[:, :3, 1].sum(0))
# Find a reasonable "focus depth" for this dataset
dt = 0.75
close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
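    # Focus depth as a weighted harmonic mean of the near and far depths
    # (the same heuristic used by LLFF's spiral render path).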
focal = 1.0 / ((1.0 - dt) / close_depth + dt / inf_depth)
# Get radii for spiral path
zdelta = near_fars.min() * 0.2
tt = c2ws_all[:, :3, 3]
rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
render_poses = render_path_spiral(
c2w, up, rads, focal, zdelta, zrate=0.5, N=N_views
)
return np.stack(render_poses)
class Neural3D_NDC_Dataset(Dataset):
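    """Multi-view video dataset for the Neural 3D Video (DyNeRF) scenes.

    Each item is an (image, (R, T), time) tuple, where (R, T) follow the camera
    convention built in load_images_path and time is normalized to [0, 1).
    """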
def __init__(
self,
datadir,
split="train",
downsample=1.0,
is_stack=True,
cal_fine_bbox=False,
N_vis=-1,
time_scale=1.0,
scene_bbox_min=[-1.0, -1.0, -1.0],
scene_bbox_max=[1.0, 1.0, 1.0],
N_random_pose=1000,
bd_factor=0.75,
eval_step=1,
eval_index=0,
sphere_scale=1.0,
):
        self.img_wh = (
            int(1352 / downsample),
            int(1014 / downsample),
        )  # The Neural 3D Video frames are natively 2704x2028; the base resolution here is half of that (1352x1014) before any extra downsampling.
self.root_dir = datadir
self.split = split
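        # Effective downsample factor relative to the native 2704-pixel frame width
        # (the `downsample` argument above is applied to the already-halved 1352x1014 base size).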
self.downsample = 2704 / self.img_wh[0]
self.is_stack = is_stack
self.N_vis = N_vis
self.time_scale = time_scale
self.scene_bbox = torch.tensor([scene_bbox_min, scene_bbox_max])
self.world_bound_scale = 1.1
self.bd_factor = bd_factor
self.eval_step = eval_step
self.eval_index = eval_index
self.blender2opencv = np.eye(4)
self.transform = T.ToTensor()
self.near = 0.0
self.far = 1.0
self.near_far = [self.near, self.far] # NDC near far is [0, 1.0]
self.white_bg = False
self.ndc_ray = True
self.depth_data = False
self.load_meta()
print(f"meta data loaded, total image:{len(self)}")
def load_meta(self):
"""
Load meta data from the dataset.
"""
# Read poses and video file paths.
poses_arr = np.load(os.path.join(self.root_dir, "poses_bounds.npy"))
poses = poses_arr[:, :-2].reshape([-1, 3, 5]) # (N_cams, 3, 5)
self.near_fars = poses_arr[:, -2:]
videos = glob.glob(os.path.join(self.root_dir, "cam*"))
videos = sorted(videos)
# breakpoint()
assert len(videos) == poses_arr.shape[0]
H, W, focal = poses[0, :, -1]
focal = focal / self.downsample
self.focal = [focal, focal]
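        # Reorder the rotation columns from LLFF's [down, right, back] convention
        # to [right, up, back] and drop the trailing hwf column.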
poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
# poses, _ = center_poses(
# poses, self.blender2opencv
# ) # Re-center poses so that the average is near the center.
# near_original = self.near_fars.min()
# scale_factor = near_original * 0.75
# self.near_fars /= (
# scale_factor # rescale nearest plane so that it is at z = 4/3.
# )
# poses[..., 3] /= scale_factor
# Sample N_views poses for validation - NeRF-like camera trajectory.
N_views = 300
self.val_poses = get_spiral(poses, self.near_fars, N_views=N_views)
# self.val_poses = self.directions
W, H = self.img_wh
poses_i_train = []
for i in range(len(poses)):
if i != self.eval_index:
poses_i_train.append(i)
self.poses = poses[poses_i_train]
self.poses_all = poses
self.image_paths, self.image_poses, self.image_times, N_cam, N_time = self.load_images_path(videos, self.split)
self.cam_number = N_cam
self.time_number = N_time
def get_val_pose(self):
render_poses = self.val_poses
render_times = torch.linspace(0.0, 1.0, render_poses.shape[0]) * 2.0 - 1.0
return render_poses, self.time_scale * render_times
def load_images_path(self,videos,split):
image_paths = []
image_poses = []
image_times = []
N_cams = 0
N_time = 0
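        # Only the first 300 frames of each video are used.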
countss = 300
for index, video_path in enumerate(videos):
if index == self.eval_index:
if split =="train":
continue
else:
if split == "test":
continue
N_cams +=1
count = 0
            video_images_path = os.path.splitext(video_path)[0]
image_path = os.path.join(video_images_path,"images")
video_frames = cv2.VideoCapture(video_path)
if not os.path.exists(image_path):
print(f"no images saved in {image_path}, extract images from video.")
os.makedirs(image_path)
this_count = 0
while video_frames.isOpened():
ret, video_frame = video_frames.read()
                    if this_count >= countss:
                        break
if ret:
video_frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB)
video_frame = Image.fromarray(video_frame)
                        if self.downsample != 1.0:
                            img = video_frame.resize(self.img_wh, Image.LANCZOS)
                        else:
                            # Keep the full-resolution frame so that img is always defined before saving.
                            img = video_frame
                        img.save(os.path.join(image_path, "%04d.png" % count))
# img = transform(img)
count += 1
this_count+=1
else:
break
images_path = os.listdir(image_path)
images_path.sort()
this_count = 0
for idx, path in enumerate(images_path):
                if this_count >= countss:
                    break
image_paths.append(os.path.join(image_path,path))
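                # Flip the up and back axes to the OpenCV-style [right, down, forward] convention;
                # R stays a camera-to-world rotation while T = -center @ R is the
                # corresponding world-to-camera translation.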
pose = np.array(self.poses_all[index])
R = pose[:3,:3]
R = -R
R[:,0] = -R[:,0]
T = -pose[:3,3].dot(R)
image_times.append(idx/countss)
image_poses.append((R,T))
# if self.downsample != 1.0:
# img = video_frame.resize(self.img_wh, Image.LANCZOS)
# img.save(os.path.join(image_path,"%04d.png"%count))
this_count+=1
N_time = len(images_path)
# video_data_save[count] = img.permute(1,2,0)
# count += 1
return image_paths, image_poses, image_times, N_cams, N_time
def __len__(self):
return len(self.image_paths)
def __getitem__(self,index):
img = Image.open(self.image_paths[index])
img = img.resize(self.img_wh, Image.LANCZOS)
img = self.transform(img)
return img, self.image_poses[index], self.image_times[index]
def load_pose(self,index):
return self.image_poses[index]
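

# Example usage: a minimal, hypothetical sketch. It assumes a scene directory such as
# "data/dynerf/cook_spinach" containing one cam*.mp4 video per camera and the
# LLFF-style poses_bounds.npy file; adjust the path and downsample factor as needed.
if __name__ == "__main__":
    scene_dir = "data/dynerf/cook_spinach"  # hypothetical path
    train_set = Neural3D_NDC_Dataset(scene_dir, split="train", downsample=2.0)
    test_set = Neural3D_NDC_Dataset(scene_dir, split="test", downsample=2.0)
    img, (R, T), t = train_set[0]
    print(len(train_set), len(test_set), img.shape, R.shape, T.shape, t)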