4DGaussians/utils/loader_utils.py
2023-12-02 14:13:12 +08:00

53 lines
1.8 KiB
Python

import os
import cv2
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
from torchvision import transforms, utils
import random
def get_stamp_list(dataset, timestamp):
frame_length = int(len(dataset)/len(dataset.dataset.poses))
# print(frame_length)
if timestamp > frame_length:
raise IndexError("input timestamp bigger than total timestamp.")
print("select index:",[i*frame_length+timestamp for i in range(len(dataset.dataset.poses))])
return [dataset[i*frame_length+timestamp] for i in range(len(dataset.dataset.poses))]
class FineSampler(Sampler):
def __init__(self, dataset):
self.len_dataset = len(dataset)
self.len_pose = len(dataset.dataset.poses)
self.frame_length = int(self.len_dataset/ self.len_pose)
sample_list = []
for i in range(self.frame_length):
for j in range(4):
idx = torch.randperm(self.len_pose) *self.frame_length + i
# print(idx)
# breakpoint()
now_list = []
cnt = 0
for item in idx.tolist():
now_list.append(item)
cnt+=1
if cnt % 2 == 0 and len(sample_list)>2:
select_element = [x for x in random.sample(sample_list,2)]
now_list += select_element
sample_list += now_list
self.sample_list = sample_list
# print(self.sample_list)
# breakpoint()
print("one epoch containing:",len(self.sample_list))
def __iter__(self):
return iter(self.sample_list)
def __len__(self):
return len(self.sample_list)