import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import collections
from easydict import EasyDict as edict

import util
from util import log,debug

class Pose():
    """
    A class of operations on camera poses (PyTorch tensors with shape [...,3,4]);
    each [3,4] camera pose takes the form of [R|t].
    """

    def __call__(self,R=None,t=None):
        # construct a camera pose from the given R and/or t
        assert(R is not None or t is not None)
        if R is None:
            if not isinstance(t,torch.Tensor): t = torch.tensor(t)
            R = torch.eye(3,device=t.device).repeat(*t.shape[:-1],1,1)
        elif t is None:
            if not isinstance(R,torch.Tensor): R = torch.tensor(R)
            t = torch.zeros(R.shape[:-1],device=R.device)
        else:
            if not isinstance(R,torch.Tensor): R = torch.tensor(R)
            if not isinstance(t,torch.Tensor): t = torch.tensor(t)
        assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3,3))
        R = R.float()
        t = t.float()
        pose = torch.cat([R,t[...,None]],dim=-1) # [...,3,4]
        assert(pose.shape[-2:]==(3,4))
        return pose

    def invert(self,pose,use_inverse=False):
        # invert a camera pose
        R,t = pose[...,:3],pose[...,3:]
        R_inv = R.inverse() if use_inverse else R.transpose(-1,-2)
        t_inv = (-R_inv@t)[...,0]
        pose_inv = self(R=R_inv,t=t_inv)
        return pose_inv

    def compose(self,pose_list):
        # compose a sequence of poses together
        # pose_new(x) = poseN o ... o pose2 o pose1(x)
        pose_new = pose_list[0]
        for pose in pose_list[1:]:
            pose_new = self.compose_pair(pose_new,pose)
        return pose_new

    def compose_pair(self,pose_a,pose_b):
        # pose_new(x) = pose_b o pose_a(x)
        R_a,t_a = pose_a[...,:3],pose_a[...,3:]
        R_b,t_b = pose_b[...,:3],pose_b[...,3:]
        R_new = R_b@R_a
        t_new = (R_b@t_a+t_b)[...,0]
        pose_new = self(R=R_new,t=t_new)
        return pose_new

class Lie():
    """
    Lie algebra for SO(3) and SE(3) operations in PyTorch
    """

    def so3_to_SO3(self,w): # [...,3]
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        R = I+A*wx+B*wx@wx
        return R

    def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
        trace = R[...,0,0]+R[...,1,1]+R[...,2,2]
        theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi
        lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird
        w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]
        w = torch.stack([w0,w1,w2],dim=-1)
        return w

    def se3_to_SE3(self,wu): # [...,3]
        w,u = wu.split([3,3],dim=-1)
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        C = self.taylor_C(theta)
        R = I+A*wx+B*wx@wx
        V = I+B*wx+C*wx@wx
        Rt = torch.cat([R,(V@u[...,None])],dim=-1)
        return Rt

    def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
        R,t = Rt.split([3,1],dim=-1)
        w = self.SO3_to_so3(R)
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
        u = (invV@t)[...,0]
        wu = torch.cat([w,u],dim=-1)
        return wu

    def skew_symmetric(self,w):
        w0,w1,w2 = w.unbind(dim=-1)
        O = torch.zeros_like(w0)
        wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),
                          torch.stack([w2,O,-w0],dim=-1),
                          torch.stack([-w1,w0,O],dim=-1)],dim=-2)
        return wx

    def taylor_A(self,x,nth=10):
        # Taylor expansion of sin(x)/x
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            if i>0: denom *= (2*i)*(2*i+1)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans

    def taylor_B(self,x,nth=10):
        # Taylor expansion of (1-cos(x))/x**2
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+1)*(2*i+2)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans

    def taylor_C(self,x,nth=10):
        # Taylor expansion of (x-sin(x))/x**3
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+2)*(2*i+3)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans
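# A minimal consistency sketch for the Lie maps above (not part of the original
# module; assumes rotation angles stay below pi, where the log map is well-defined):
#
#   lie = Lie()
#   w = torch.randn(4,3)*0.5                # [B,3] so(3) vectors with small norm
#   R = lie.so3_to_SO3(w)                   # [B,3,3] rotation matrices (exp map)
#   w_back = lie.SO3_to_so3(R)              # [B,3], recovers w up to numerical error
#   assert torch.allclose(w,w_back,atol=1e-4)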
class Quaternion():

    def q_to_R(self,q):
        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
        qa,qb,qc,qd = q.unbind(dim=-1)
        R = torch.stack([torch.stack([1-2*(qc**2+qd**2),2*(qb*qc-qa*qd),2*(qa*qc+qb*qd)],dim=-1),
                         torch.stack([2*(qb*qc+qa*qd),1-2*(qb**2+qd**2),2*(qc*qd-qa*qb)],dim=-1),
                         torch.stack([2*(qb*qd-qa*qc),2*(qa*qb+qc*qd),1-2*(qb**2+qc**2)],dim=-1)],dim=-2)
        return R

    def R_to_q(self,R,eps=1e-8): # [B,3,3]
        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
        # FIXME: this function seems a bit problematic, need to double-check
        row0,row1,row2 = R.unbind(dim=-2)
        R00,R01,R02 = row0.unbind(dim=-1)
        R10,R11,R12 = row1.unbind(dim=-1)
        R20,R21,R22 = row2.unbind(dim=-1)
        t = R[...,0,0]+R[...,1,1]+R[...,2,2]
        r = (1+t+eps).sqrt()
        qa = 0.5*r
        qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt()
        qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt()
        qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt()
        q = torch.stack([qa,qb,qc,qd],dim=-1)
        # if the direct formula produced NaNs (near-degenerate trace), fall back to
        # the eigendecomposition method for those batch elements
        for i,qi in enumerate(q):
            if torch.isnan(qi).any():
                K = torch.stack([torch.stack([R00-R11-R22,R10+R01,R20+R02,R12-R21],dim=-1),
                                 torch.stack([R10+R01,R11-R00-R22,R21+R12,R20-R02],dim=-1),
                                 torch.stack([R20+R02,R21+R12,R22-R00-R11,R01-R10],dim=-1),
                                 torch.stack([R12-R21,R20-R02,R01-R10,R00+R11+R22],dim=-1)],dim=-2)/3.0
                K = K[i]
                eigval,eigvec = torch.linalg.eigh(K)
                V = eigvec[:,eigval.argmax()]
                q[i] = torch.stack([V[3],V[0],V[1],V[2]])
        return q

    def invert(self,q):
        qa,qb,qc,qd = q.unbind(dim=-1)
        norm = q.norm(dim=-1,keepdim=True)
        q_inv = torch.stack([qa,-qb,-qc,-qd],dim=-1)/norm**2
        return q_inv

    def product(self,q1,q2): # [B,4]
        q1a,q1b,q1c,q1d = q1.unbind(dim=-1)
        q2a,q2b,q2c,q2d = q2.unbind(dim=-1)
        hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d,
                                  q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c,
                                  q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b,
                                  q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a],dim=-1)
        return hamil_prod

pose = Pose()
lie = Lie()
quaternion = Quaternion()

def to_hom(X):
    # get homogeneous coordinates of the input
    X_hom = torch.cat([X,torch.ones_like(X[...,:1])],dim=-1)
    return X_hom

# basic operations of transforming 3D points between world/camera/image coordinates
def world2cam(X,pose): # [B,N,3]
    X_hom = to_hom(X)
    return X_hom@pose.transpose(-1,-2)
def cam2img(X,cam_intr):
    return X@cam_intr.transpose(-1,-2)
def img2cam(X,cam_intr):
    return X@cam_intr.inverse().transpose(-1,-2)
def cam2world(X,pose):
    X_hom = to_hom(X)
    pose_inv = Pose().invert(pose)
    return X_hom@pose_inv.transpose(-1,-2)

def angle_to_rotation_matrix(a,axis):
    # get the rotation matrix from Euler angle around specific axis
    roll = dict(X=1,Y=2,Z=0)[axis]
    O = torch.zeros_like(a)
    I = torch.ones_like(a)
    M = torch.stack([torch.stack([a.cos(),-a.sin(),O],dim=-1),
                     torch.stack([a.sin(),a.cos(),O],dim=-1),
                     torch.stack([O,O,I],dim=-1)],dim=-2)
    M = M.roll((roll,roll),dims=(-2,-1))
    return M
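# A hedged round-trip sketch for the transforms above (not from the original source;
# assumes [B,N,3] points and [B,3,4] poses as produced by Pose() above):
#
#   X_world = torch.randn(1,8,3)
#   P = pose(R=torch.eye(3)[None],t=torch.tensor([[0.,0.,2.]]))
#   X_cam = world2cam(X_world,P)            # world -> camera: X@[R|t]^T
#   X_back = cam2world(X_cam,P)             # camera -> world: applies the inverse pose
#   # X_back matches X_world up to floating-point error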
def get_center_and_ray(opt,pose,intr=None): # [HW,2]
    # given the intrinsic/extrinsic matrices, get the camera center and ray directions
    assert(opt.camera.model=="perspective")
    with torch.no_grad():
        # compute image coordinate grid (pixel centers, hence the +0.5 offset)
        y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
        x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
        Y,X = torch.meshgrid(y_range,x_range) # [H,W]
        xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
    # compute center and ray
    batch_size = len(pose)
    xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]
    grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
    center_3D = torch.zeros_like(grid_3D) # [B,HW,3]
    # transform from camera to world coordinates
    grid_3D = cam2world(grid_3D,pose) # [B,HW,3]
    center_3D = cam2world(center_3D,pose) # [B,HW,3]
    ray = grid_3D-center_3D # [B,HW,3]
    return center_3D,ray

def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False):
    if multi_samples: center,ray = center[:,:,None],ray[:,:,None]
    # x = c+dv
    points_3D = center+ray*depth # [B,HW,3]/[B,HW,N,3]/[N,3]
    return points_3D

def convert_NDC(opt,center,ray,intr,near=1):
    # shift camera center (ray origins) to near plane (z=1)
    # (unlike conventional NDC, we assume the cameras are facing towards the +z direction)
    center = center+(near-center[...,2:])/ray[...,2:]*ray
    # projection
    cx,cy,cz = center.unbind(dim=-1) # [B,HW]
    rx,ry,rz = ray.unbind(dim=-1) # [B,HW]
    scale_x = intr[:,0,0]/intr[:,0,2] # [B]
    scale_y = intr[:,1,1]/intr[:,1,2] # [B]
    cnx = scale_x[:,None]*(cx/cz)
    cny = scale_y[:,None]*(cy/cz)
    cnz = 1-2*near/cz
    rnx = scale_x[:,None]*(rx/rz-cx/cz)
    rny = scale_y[:,None]*(ry/rz-cy/cz)
    rnz = 2*near/cz
    center_ndc = torch.stack([cnx,cny,cnz],dim=-1) # [B,HW,3]
    ray_ndc = torch.stack([rnx,rny,rnz],dim=-1) # [B,HW,3]
    return center_ndc,ray_ndc

def rotation_distance(R1,R2,eps=1e-7):
    # http://www.boris-belousov.net/2016/12/01/quat-dist/
    R_diff = R1@R2.transpose(-2,-1)
    trace = R_diff[...,0,0]+R_diff[...,1,1]+R_diff[...,2,2]
    angle = ((trace-1)/2).clamp(-1+eps,1-eps).acos_() # numerical stability near -1/+1
    return angle

def procrustes_analysis(X0,X1): # [N,3]
    # translation
    t0 = X0.mean(dim=0,keepdim=True)
    t1 = X1.mean(dim=0,keepdim=True)
    X0c = X0-t0
    X1c = X1-t1
    # scale
    s0 = (X0c**2).sum(dim=-1).mean().sqrt()
    s1 = (X1c**2).sum(dim=-1).mean().sqrt()
    X0cs = X0c/s0
    X1cs = X1c/s1
    # rotation (use double for SVD, float loses precision)
    U,S,V = (X0cs.t()@X1cs).double().svd(some=True)
    R = (U@V.t()).float()
    if R.det()<0: R[2] *= -1
    # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
    sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)
    return sim3

def get_novel_view_poses(opt,pose_anchor,N=60,scale=1):
    # create circular viewpoints (small oscillations)
    theta = torch.arange(N)/N*2*np.pi
    R_x = angle_to_rotation_matrix((theta.sin()*0.05).asin(),"X")
    R_y = angle_to_rotation_matrix((theta.cos()*0.05).asin(),"Y")
    pose_rot = pose(R=R_y@R_x)
    pose_shift = pose(t=[0,0,-4*scale])
    pose_shift2 = pose(t=[0,0,3.8*scale])
    pose_oscil = pose.compose([pose_shift,pose_rot,pose_shift2])
    pose_novel = pose.compose([pose_oscil,pose_anchor.cpu()[None]])
    return pose_novel
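if __name__ == "__main__":
    # Smoke test (not part of the original module): exercise the ray-generation
    # pipeline on a dummy 4x3 camera. The `opt` fields mirror exactly what
    # get_center_and_ray() reads; the intrinsics values are made up for the demo.
    opt = edict(H=4,W=3,device="cpu",camera=edict(model="perspective"))
    intr = torch.tensor([[[2.,0.,1.5],
                          [0.,2.,2.0],
                          [0.,0.,1.0]]]) # [1,3,3] pinhole intrinsics
    cam_pose = pose(R=torch.eye(3)[None],t=torch.zeros(1,3)) # [1,3,4] identity pose
    center,ray = get_center_and_ray(opt,cam_pose,intr=intr) # [1,HW,3] each
    points = get_3D_points_from_depth(opt,center,ray,depth=torch.ones(1,opt.H*opt.W,1))
    print(center.shape,ray.shape,points.shape) # expect torch.Size([1,12,3]) for all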