# scene/hyper_loader.py — HyperNeRF-format dataset loading for 4DGaussians.
import warnings
warnings.filterwarnings("ignore")
import json
import os
import random
import numpy as np
import torch
from PIL import Image
import math
from tqdm import tqdm
from scene.utils import Camera
from typing import NamedTuple
# from scene.dataset_readers import CameraInfo
class CameraInfo(NamedTuple):
    """Immutable per-frame camera record produced by format_hyper_data.

    Mirrors scene.dataset_readers.CameraInfo: pose (R, T), intrinsics as
    vertical/horizontal field of view (radians), the decoded frame image,
    and a timestamp normalized to [0, 1].
    """
    uid: int            # sequential index within the split
    R: np.ndarray       # rotation taken from Camera.orientation in load_raw
    T: np.ndarray       # translation taken from Camera.position in load_raw
    FovY: np.ndarray    # float in practice — computed via math.atan in load_raw
    FovX: np.ndarray    # float in practice — computed via math.atan in load_raw
    image: np.ndarray   # PIL.Image in practice (load_raw's composited frame)
    image_path: str     # directory containing the image file
    image_name: str     # file name including extension
    width: int
    height: int
    time: float         # normalized warp_id timestamp in [0, 1]
    # flow_f: np.ndarray
    # flow_mask_f: np.ndarray
    # flow_b: np.ndarray
    # flow_mask_b: np.ndarray
    # motion_mask: np.ndarray
class Load_hyper_data():
    """Loader for HyperNeRF-format captures.

    Reads scene.json (depth bounds + coordinate normalization),
    metadata.json (per-image camera/warp ids), dataset.json (image ids and
    train/val splits) and the per-image camera JSONs, then converts every
    camera from the HyperNeRF convention into the one used downstream.
    """

    def __init__(self,
                 datadir,
                 ratio=1.0,
                 use_bg_points=False,
                 add_cam=False):
        """Parse the dataset directory.

        datadir       : root of the HyperNeRF capture (``~`` is expanded).
        ratio         : downscale factor; images are read from rgb/<1/ratio>x/.
        use_bg_points : also load points.npy and composite on white background.
        add_cam       : expected multi-camera flag; asserted against the data.
        """
        # Resolves to the same scene.utils.Camera as the top-level import.
        from .utils import Camera
        datadir = os.path.expanduser(datadir)
        with open(f'{datadir}/scene.json', 'r') as f:
            scene_json = json.load(f)
        with open(f'{datadir}/metadata.json', 'r') as f:
            meta_json = json.load(f)
        with open(f'{datadir}/dataset.json', 'r') as f:
            dataset_json = json.load(f)

        # Depth bounds and the similarity transform (center + scale) that is
        # applied to every camera position and to the background points.
        self.near = scene_json['near']
        self.far = scene_json['far']
        self.coord_scale = scene_json['scale']
        self.scene_center = scene_json['center']

        self.all_img = dataset_json['ids']
        self.val_id = dataset_json['val_ids']
        self.add_cam = False
        if len(self.val_id) == 0:
            # No explicit val split: train on every 4th frame; test on the
            # frame two positions after each train frame (last one dropped
            # so the index stays in range).
            self.i_train = np.array([i for i in np.arange(len(self.all_img)) if
                            (i % 4 == 0)])
            self.i_test = self.i_train + 2
            self.i_test = self.i_test[:-1, ]
        else:
            # Multi-camera capture: dataset.json carries explicit splits.
            self.add_cam = True
            self.train_id = dataset_json['train_ids']
            self.i_test = []
            self.i_train = []
            for i in range(len(self.all_img)):
                id = self.all_img[i]
                if id in self.val_id:
                    self.i_test.append(i)
                if id in self.train_id:
                    self.i_train.append(i)
        # Caller's expectation must match what the data actually implies.
        assert self.add_cam == add_cam
        print('self.i_train', self.i_train)
        print('self.i_test', self.i_test)

        self.all_cam = [meta_json[i]['camera_id'] for i in self.all_img]
        # Normalize warp ids into [0, 1] so they can serve as timestamps.
        self.all_time = [meta_json[i]['warp_id'] for i in self.all_img]
        max_time = max(self.all_time)
        self.all_time = [meta_json[i]['warp_id'] / max_time for i in self.all_img]
        self.selected_time = set(self.all_time)
        self.ratio = ratio
        self.max_time = max(self.all_time)

        # Load every camera, rescale it to the chosen image resolution, apply
        # the scene normalization, then convert orientation/position from the
        # HyperNeRF convention into the one expected by the renderer.
        self.all_cam_params = []
        for im in self.all_img:
            camera = Camera.from_json(f'{datadir}/camera/{im}.json')
            camera = camera.scale(ratio)
            camera.position = camera.position - self.scene_center
            camera.position = camera.position * self.coord_scale
            camera.orientation = camera.orientation.T
            camera.orientation = camera.orientation[[1, 0, 2], :]  # switch world x,y
            camera.orientation = -camera.orientation
            camera.orientation[:, 0] = -camera.orientation[:, 0]
            camera.orientation = camera.orientation.T
            camera.position = -camera.position.dot(camera.orientation)
            self.all_cam_params.append(camera)

        # Images live in an "<N>x" directory matching the downscale ratio.
        self.all_img = [f'{datadir}/rgb/{int(1/ratio)}x/{i}.png' for i in self.all_img]
        self.h, self.w = self.all_cam_params[0].image_shape

        self.use_bg_points = use_bg_points
        if use_bg_points:
            with open(f'{datadir}/points.npy', 'rb') as f:
                points = np.load(f)
            # Same normalization as the camera positions.
            self.bg_points = (points - self.scene_center) * self.coord_scale
            self.bg_points = torch.tensor(self.bg_points).float()
        print(f'total {len(self.all_img)} images ',
              'use cam =', self.add_cam,
              'use bg_point=', self.use_bg_points)

    def load_idx(self, idx, not_dic=False):
        """Load frame `idx`.

        Returns the full per-frame dict from load_raw, or — when `not_dic`
        is truthy — only the ray tuple (origins, directions, viewdirs, colors).
        """
        all_data = self.load_raw(idx)
        if not_dic:
            rays_o = all_data['rays_ori']
            rays_d = all_data['rays_dir']
            viewdirs = all_data['viewdirs']
            rays_color = all_data['rays_color']
            return rays_o, rays_d, viewdirs, rays_color
        return all_data

    def load_raw(self, idx):
        """Decode image `idx` and assemble the full per-frame record.

        Returns a dict with the camera, pose (R, T), pinhole FoV, the frame
        alpha-composited onto white (bg-points mode) or black, per-pixel rays,
        the normalized timestamp, and the near/far bounds as tensors.
        """
        camera = self.all_cam_params[idx]
        image = Image.open(self.all_img[idx])
        im_data = np.array(image.convert("RGBA"))
        norm_data = im_data / 255.0
        bg = np.array([1, 1, 1]) if self.use_bg_points else np.array([0, 0, 0])
        # Alpha-composite RGB onto the chosen background.
        arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
        # np.uint8, not np.byte (int8): pixel values in [0, 255] must not rely
        # on signed wrap-around when cast down from float.
        rays_color = Image.fromarray(np.array(arr * 255.0, dtype=np.uint8), "RGB")
        time = self.all_time[idx]

        pixels = camera.get_pixel_centers()
        rays_dir_tensor = torch.tensor(camera.pixels_to_rays(pixels)).float().view([-1, 3])
        rays_ori_tensor = torch.tensor(camera.position[None, :]).float().expand_as(rays_dir_tensor)
        # NOTE(review): assumes the source PNG decodes to exactly 3 channels;
        # a 4-channel image would break view([-1, 3]) — confirm upstream data.
        rays_color_tensor = torch.tensor(np.array(image)).view([-1, 3]) / 255.
        return {'camera': camera,
                'image_path': "/".join(self.all_img[idx].split("/")[:-1]),
                "image_name": self.all_img[idx].split("/")[-1],
                'image': rays_color,
                'width': int(self.w),
                'height': int(self.h),
                # Pinhole FoV from focal length; lens distortion is ignored.
                'FovX': 2 * math.atan(self.w / (2 * camera.focal_length)),
                'FovY': 2 * math.atan(self.h / (2 * camera.focal_length)),
                'R': camera.orientation,
                'T': camera.position,
                'time': time,
                'rays_ori': rays_ori_tensor,
                'rays_dir': rays_dir_tensor,
                'viewdirs': rays_dir_tensor / rays_dir_tensor.norm(dim=-1, keepdim=True),
                'rays_color': rays_color_tensor,
                'near': torch.tensor(self.near).float().view([-1]),
                'far': torch.tensor(self.far).float().view([-1]),
                }
def format_hyper_data(data_class, split):
    """Convert one split of a Load_hyper_data instance into CameraInfo records.

    data_class : Load_hyper_data (or anything exposing i_train/i_test and
                 load_idx returning load_raw's dict).
    split      : "train" or "test".

    Returns a list of CameraInfo, one per frame index in the split.
    Raises ValueError for an unknown split name (previously this crashed
    later with UnboundLocalError).
    """
    if split == "train":
        data_idx = data_class.i_train
    elif split == "test":
        data_idx = data_class.i_test
    else:
        raise ValueError(f"unknown split: {split}")
    cam_infos = []
    # total= lets tqdm show progress even though enumerate hides the length.
    for uid, index in tqdm(enumerate(data_idx), total=len(data_idx)):
        frame_info = data_class.load_idx(index)
        cam_info = CameraInfo(uid=uid,
                              R=frame_info["R"],
                              T=frame_info["T"],
                              FovY=frame_info["FovY"],
                              FovX=frame_info["FovX"],
                              image=frame_info['image'],
                              image_path=frame_info["image_path"],
                              image_name=frame_info["image_name"],
                              width=frame_info["width"],
                              height=frame_info["height"],
                              time=frame_info["time"],
                              )
        cam_infos.append(cam_info)
    return cam_infos
# matrix = np.linalg.inv(np.array(poses))
# R = -np.transpose(matrix[:3,:3])
# R[:,0] = -R[:,0]
# T = -matrix[:3, 3]