import os import cv2 import random import numpy as np from PIL import Image import torch from torch.utils.data import Dataset, DataLoader from torch.utils.data.sampler import Sampler from torchvision import transforms, utils import random def get_stamp_list(dataset, timestamp): frame_length = int(len(dataset)/len(dataset.dataset.poses)) # print(frame_length) if timestamp > frame_length: raise IndexError("input timestamp bigger than total timestamp.") print("select index:",[i*frame_length+timestamp for i in range(len(dataset.dataset.poses))]) return [dataset[i*frame_length+timestamp] for i in range(len(dataset.dataset.poses))] class FineSampler(Sampler): def __init__(self, dataset): self.len_dataset = len(dataset) self.len_pose = len(dataset.dataset.poses) self.frame_length = int(self.len_dataset/ self.len_pose) sample_list = [] for i in range(self.frame_length): for j in range(4): idx = torch.randperm(self.len_pose) *self.frame_length + i # print(idx) # breakpoint() now_list = [] cnt = 0 for item in idx.tolist(): now_list.append(item) cnt+=1 if cnt % 2 == 0 and len(sample_list)>2: select_element = [x for x in random.sample(sample_list,2)] now_list += select_element sample_list += now_list self.sample_list = sample_list # print(self.sample_list) # breakpoint() print("one epoch containing:",len(self.sample_list)) def __iter__(self): return iter(self.sample_list) def __len__(self): return len(self.sample_list)