# 4DGaussians/scene/dataset.py
# 2023-12-02 14:13:12 +08:00
#
# 47 lines
# 1.5 KiB
# Python

from torch.utils.data import Dataset
from scene.cameras import Camera
import numpy as np
from utils.general_utils import PILtoTorch
from utils.graphics_utils import fov2focal, focal2fov
import torch
from utils.camera_utils import loadCam
from utils.graphics_utils import focal2fov
class FourDGSdataset(Dataset):
    """Wrap a scene dataset so that indexing yields ``Camera`` objects.

    When ``dataset_type == "PanopticSports"`` items are passed through
    unchanged. For every other type, ``dataset[index]`` must be either

    * a 3-item sequence ``(image, w2c, time)`` where ``w2c`` unpacks into
      ``(R, T)``; the fields of view are then derived from
      ``dataset.focal[0]`` and the image size, and no mask is attached; or
    * a camera-info-like object exposing ``image``, ``R``, ``T``, ``FovX``,
      ``FovY``, ``time`` and ``mask`` attributes.

    Either form is converted into a ``Camera`` constructed with
    ``data_device=torch.device("cuda")``.
    """

    def __init__(self, dataset, args, dataset_type):
        """Store the wrapped dataset, its args, and its type tag.

        Args:
            dataset: Indexable, sized source of samples (see class docstring).
            args: Kept on ``self.args``; not read by this class.
            dataset_type: Type tag; ``"PanopticSports"`` disables conversion.
        """
        self.dataset = dataset
        self.args = args
        self.dataset_type = dataset_type

    def __getitem__(self, index):
        """Return sample ``index`` as a ``Camera`` (unchanged for PanopticSports).

        Raises:
            AttributeError: if an item that is not a 3-item sequence lacks
                one of the camera-info attributes read below.
        """
        if self.dataset_type == "PanopticSports":
            return self.dataset[index]

        # Fetch once and guard only the unpacking. A broad except around the
        # whole conversion would reload the sample and mask genuine errors
        # (e.g. a missing ``dataset.focal`` or a malformed ``w2c``).
        item = self.dataset[index]
        try:
            image, w2c, time = item
        except (TypeError, ValueError):
            # Not a 3-item sequence: treat it as a camera-info record.
            image = item.image
            R = item.R
            T = item.T
            FovX = item.FovX
            FovY = item.FovY
            time = item.time
            mask = item.mask
        else:
            R, T = w2c
            # The indexing assumes a channel-first image (C, H, W):
            # shape[2] is taken as width and shape[1] as height.
            # NOTE(review): focal[0] is used for both axes, i.e. fx == fy is
            # assumed -- confirm against how the dataset fills ``focal``.
            FovX = focal2fov(self.dataset.focal[0], image.shape[2])
            FovY = focal2fov(self.dataset.focal[0], image.shape[1])
            mask = None

        return Camera(colmap_id=index, R=R, T=T, FoVx=FovX, FoVy=FovY, image=image, gt_alpha_mask=None,
                      image_name=f"{index}", uid=index, data_device=torch.device("cuda"), time=time,
                      mask=mask)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)