# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from https://github.com/facebookresearch/PoseDiffusion
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Literal
from einops import rearrange, repeat
import cv2
import numpy as np
# from torchmetrics.functional.regression import pearson_corrcoef
from easydict import EasyDict as edict
from enum import Enum
import torch.distributed as dist
from models.monoD.depth_anything_v2.util.transform import Resize
from models.SpaTrackV2.utils.model_utils import sample_features5d
# PyTorch3D utilities used by pose_encoding_to_camera / vis_pcd / vis_result.
# (PerspectiveCamerasVisual, used in vis_result, is assumed to be provided by
# the project's own visualization helpers.)
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.structures import Pointclouds
from pytorch3d.vis.plotly_vis import plot_scene
EPS = 1e-9
class Summary(Enum):
NONE = 0
AVERAGE = 1
SUM = 2
COUNT = 3
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
self.name = name
self.fmt = fmt
self.summary_type = summary_type
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def all_reduce(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(self.sum, np.ndarray):
total = torch.tensor(
self.sum.tolist()
+ [
self.count,
],
dtype=torch.float32,
device=device,
)
else:
total = torch.tensor(
[self.sum, self.count], dtype=torch.float32, device=device
)
dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
if total.shape[0] > 2:
self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
else:
self.sum, self.count = total.tolist()
self.avg = self.sum / (self.count + 1e-5)
def __str__(self):
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = ""
if self.summary_type is Summary.NONE:
fmtstr = ""
elif self.summary_type is Summary.AVERAGE:
fmtstr = "{name} {avg:.3f}"
elif self.summary_type is Summary.SUM:
fmtstr = "{name} {sum:.3f}"
elif self.summary_type is Summary.COUNT:
fmtstr = "{name} {count:.3f}"
else:
raise ValueError("invalid summary type %r" % self.summary_type)
return fmtstr.format(**self.__dict__)
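# Usage sketch (illustrative only; the values below are assumed, not from the
# training code): keep a running average of a scalar metric across batches.
#
#   meter = AverageMeter("loss", fmt=":.4f", summary_type=Summary.AVERAGE)
#   meter.update(0.5, n=8)    # batch of 8 samples with mean loss 0.5
#   meter.update(0.3, n=8)
#   str(meter)                # -> "loss 0.3000 (0.4000)"
#   meter.summary()           # -> "loss 0.400"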
def procrustes_analysis(X0,X1): # [N,3]
# translation
t0 = X0.mean(dim=0,keepdim=True)
t1 = X1.mean(dim=0,keepdim=True)
X0c = X0-t0
X1c = X1-t1
# scale
s0 = (X0c**2).sum(dim=-1).mean().sqrt()
s1 = (X1c**2).sum(dim=-1).mean().sqrt()
X0cs = X0c/s0
X1cs = X1c/s1
# rotation (use double for SVD, float loses precision)
U,S,V = (X0cs.t()@X1cs).double().svd(some=True)
R = (U@V.t()).float()
if R.det()<0: R[2] *= -1
# align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)
return sim3
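# Usage sketch (illustrative only; random points are assumed): estimate the
# Sim(3) between two point sets and align X1 into X0's frame, following the
# formula in the comment above.
#
#   X0, X1 = torch.randn(100, 3), torch.randn(100, 3)
#   sim3 = procrustes_analysis(X0, X1)
#   X1to0 = (X1 - sim3.t1) / sim3.s1 @ sim3.R.t() * sim3.s0 + sim3.t0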
def create_intri_matrix(focal_length, principal_point):
"""
    Creates an intrinsic matrix from the focal length and principal point.
Args:
focal_length (torch.Tensor): A Bx2 or BxSx2 tensor containing the focal lengths (fx, fy) for each image.
principal_point (torch.Tensor): A Bx2 or BxSx2 tensor containing the principal point coordinates (cx, cy) for each image.
Returns:
torch.Tensor: A Bx3x3 or BxSx3x3 tensor containing the camera matrix for each image.
"""
if len(focal_length.shape) == 2:
B = focal_length.shape[0]
intri_matrix = torch.zeros(B, 3, 3, dtype=focal_length.dtype, device=focal_length.device)
intri_matrix[:, 0, 0] = focal_length[:, 0]
intri_matrix[:, 1, 1] = focal_length[:, 1]
intri_matrix[:, 2, 2] = 1.0
intri_matrix[:, 0, 2] = principal_point[:, 0]
intri_matrix[:, 1, 2] = principal_point[:, 1]
else:
B, S = focal_length.shape[0], focal_length.shape[1]
intri_matrix = torch.zeros(B, S, 3, 3, dtype=focal_length.dtype, device=focal_length.device)
intri_matrix[:, :, 0, 0] = focal_length[:, :, 0]
intri_matrix[:, :, 1, 1] = focal_length[:, :, 1]
intri_matrix[:, :, 2, 2] = 1.0
intri_matrix[:, :, 0, 2] = principal_point[:, :, 0]
intri_matrix[:, :, 1, 2] = principal_point[:, :, 1]
return intri_matrix
def closed_form_inverse_OpenCV(se3, R=None, T=None):
"""
Computes the inverse of each 4x4 SE3 matrix in the batch.
Args:
- se3 (Tensor): Nx4x4 tensor of SE3 matrices.
Returns:
- Tensor: Nx4x4 tensor of inverted SE3 matrices.
| R t |
| 0 1 |
-->
| R^T -R^T t|
| 0 1 |
"""
if R is None:
R = se3[:, :3, :3]
if T is None:
T = se3[:, :3, 3:]
# Compute the transpose of the rotation
R_transposed = R.transpose(1, 2)
# -R^T t
top_right = -R_transposed.bmm(T)
inverted_matrix = torch.eye(4, 4)[None].repeat(len(se3), 1, 1)
inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
inverted_matrix[:, :3, :3] = R_transposed
inverted_matrix[:, :3, 3:] = top_right
return inverted_matrix
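# Usage sketch (illustrative only; the batch below is assumed): the closed-form
# inverse composed with the original SE(3) should give the identity.
#
#   se3 = torch.eye(4)[None].repeat(2, 1, 1)
#   se3[:, :3, 3] = torch.randn(2, 3)                  # random translations
#   inv = closed_form_inverse_OpenCV(se3)
#   torch.allclose(torch.bmm(se3, inv),
#                  torch.eye(4)[None].repeat(2, 1, 1), atol=1e-6)   # -> True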
def get_EFP(pred_cameras, image_size, B, S, default_focal=False):
"""
    Convert PyTorch3D cameras to extrinsics and intrinsics matrices.
    Returns extrinsics [B, S, 3, 4] and intrinsics [B, S, 3, 3].
"""
scale = image_size.min()
focal_length = pred_cameras.focal_length
principal_point = torch.zeros_like(focal_length)
focal_length = focal_length * scale / 2
principal_point = (image_size[None] - principal_point * scale) / 2
Rots = pred_cameras.R.clone()
Trans = pred_cameras.T.clone()
extrinsics = torch.cat([Rots, Trans[..., None]], dim=-1)
# reshape
extrinsics = extrinsics.reshape(B, S, 3, 4)
focal_length = focal_length.reshape(B, S, 2)
principal_point = principal_point.reshape(B, S, 2)
# only one dof focal length
if default_focal:
focal_length[:] = scale
else:
focal_length = focal_length.mean(dim=-1, keepdim=True).expand(-1, -1, 2)
focal_length = focal_length.clamp(0.2 * scale, 5 * scale)
intrinsics = create_intri_matrix(focal_length, principal_point)
return extrinsics, intrinsics
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to rotation matrices.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
r, i, j, k = torch.unbind(quaternions, -1)
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def pose_encoding_to_camera(
pose_encoding,
pose_encoding_type="absT_quaR_logFL",
log_focal_length_bias=1.8,
min_focal_length=0.1,
max_focal_length=30,
return_dict=False,
to_OpenCV=True,
):
"""
Args:
pose_encoding: A tensor of shape `BxNxC`, containing a batch of
`BxN` `C`-dimensional pose encodings.
pose_encoding_type: The type of pose encoding,
"""
pose_encoding_reshaped = pose_encoding.reshape(-1, pose_encoding.shape[-1]) # Reshape to BNxC
if pose_encoding_type == "absT_quaR_logFL":
# 3 for absT, 4 for quaR, 2 for absFL
abs_T = pose_encoding_reshaped[:, :3]
quaternion_R = pose_encoding_reshaped[:, 3:7]
R = quaternion_to_matrix(quaternion_R)
log_focal_length = pose_encoding_reshaped[:, 7:9]
# log_focal_length_bias was the hyperparameter
# to ensure the mean of logFL close to 0 during training
# Now converted back
focal_length = (log_focal_length + log_focal_length_bias).exp()
# clamp to avoid weird fl values
focal_length = torch.clamp(focal_length,
min=min_focal_length, max=max_focal_length)
elif pose_encoding_type == "absT_quaR_OneFL":
# 3 for absT, 4 for quaR, 1 for absFL
# [absolute translation, quaternion rotation, normalized focal length]
abs_T = pose_encoding_reshaped[:, :3]
quaternion_R = pose_encoding_reshaped[:, 3:7]
R = quaternion_to_matrix(quaternion_R)
focal_length = pose_encoding_reshaped[:, 7:8]
focal_length = torch.clamp(focal_length,
min=min_focal_length, max=max_focal_length)
else:
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
if to_OpenCV:
### From Pytorch3D coordinate to OpenCV coordinate:
# I hate coordinate conversion
R = R.clone()
abs_T = abs_T.clone()
R[:, :, :2] *= -1
abs_T[:, :2] *= -1
R = R.permute(0, 2, 1)
extrinsics_4x4 = torch.eye(4, 4).to(R.dtype).to(R.device)
extrinsics_4x4 = extrinsics_4x4[None].repeat(len(R), 1, 1)
extrinsics_4x4[:, :3, :3] = R.clone()
extrinsics_4x4[:, :3, 3] = abs_T.clone()
rel_transform = closed_form_inverse_OpenCV(extrinsics_4x4[0:1])
rel_transform = rel_transform.expand(len(extrinsics_4x4), -1, -1)
# relative to the first camera
# NOTE it is extrinsics_4x4 x rel_transform instead of rel_transform x extrinsics_4x4
extrinsics_4x4 = torch.bmm(extrinsics_4x4, rel_transform)
R = extrinsics_4x4[:, :3, :3].clone()
abs_T = extrinsics_4x4[:, :3, 3].clone()
if return_dict:
return {"focal_length": focal_length, "R": R, "T": abs_T}
pred_cameras = PerspectiveCameras(focal_length=focal_length,
R=R, T=abs_T, device=R.device, in_ndc=False)
return pred_cameras
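# Usage sketch (illustrative only; the encoding below is hand-built and assumed):
# decode a [B, S, 8] "absT_quaR_OneFL" encoding into OpenCV-style R / T / focal.
#
#   enc = torch.zeros(1, 4, 8)
#   enc[..., 3] = 1.0                                   # identity quaternion
#   enc[..., 7] = 1.0                                   # unit normalized focal length
#   cams = pose_encoding_to_camera(enc, pose_encoding_type="absT_quaR_OneFL",
#                                  to_OpenCV=True, return_dict=True)
#   cams["R"].shape, cams["T"].shape, cams["focal_length"].shape
#   # -> torch.Size([4, 3, 3]), torch.Size([4, 3]), torch.Size([4, 1])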
def camera_to_pose_encoding(
camera, pose_encoding_type="absT_quaR_logFL",
log_focal_length_bias=1.8, min_focal_length=0.1, max_focal_length=30
):
"""
Inverse to pose_encoding_to_camera
"""
if pose_encoding_type == "absT_quaR_logFL":
# Convert rotation matrix to quaternion
quaternion_R = matrix_to_quaternion(camera.R)
# Calculate log_focal_length
log_focal_length = (
torch.log(torch.clamp(camera.focal_length,
min=min_focal_length, max=max_focal_length))
- log_focal_length_bias
)
# Concatenate to form pose_encoding
pose_encoding = torch.cat([camera.T, quaternion_R, log_focal_length], dim=-1)
elif pose_encoding_type == "absT_quaR_OneFL":
# [absolute translation, quaternion rotation, normalized focal length]
quaternion_R = matrix_to_quaternion(camera.R)
focal_length = (torch.clamp(camera.focal_length,
min=min_focal_length,
max=max_focal_length))[..., 0:1]
pose_encoding = torch.cat([camera.T, quaternion_R, focal_length], dim=-1)
else:
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
return pose_encoding
def init_pose_enc(B: int,
S: int, pose_encoding_type: str="absT_quaR_logFL",
device: Optional[torch.device]=None):
"""
Initialize the pose encoding tensor
args:
B: batch size
S: number of frames
pose_encoding_type: the type of pose encoding
device: device to put the tensor
return:
pose_enc: [B S C]
"""
if pose_encoding_type == "absT_quaR_logFL":
C = 9
elif pose_encoding_type == "absT_quaR_OneFL":
C = 8
else:
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
pose_enc = torch.zeros(B, S, C, device=device)
pose_enc[..., :3] = 0 # absT
pose_enc[..., 3] = 1 # quaR
pose_enc[..., 7:] = 1 # logFL
return pose_enc
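# Usage sketch (illustrative only): the initial encoding is the identity pose
# with a unit focal-length entry.
#
#   enc = init_pose_enc(B=2, S=16, pose_encoding_type="absT_quaR_OneFL")
#   enc.shape                                           # -> torch.Size([2, 16, 8])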
def first_pose_enc_norm(pose_enc: torch.Tensor,
pose_encoding_type: str="absT_quaR_OneFL",
pose_mode: str = "W2C"):
"""
    Make sure the poses within one window are normalized by the first frame, so that
    the first-frame transformation is the identity matrix.
NOTE: Poses are all W2C
args:
pose_enc: [B S C]
return:
pose_enc_norm: [B S C]
"""
B, S, C = pose_enc.shape
# Pose encoding to Cameras (Pytorch3D coordinate)
pred_cameras = pose_encoding_to_camera(
pose_enc, pose_encoding_type=pose_encoding_type,
to_OpenCV=False
) #NOTE: the camera parameters are not in NDC
R = pred_cameras.R # [B*S, 3, 3]
T = pred_cameras.T # [B*S, 3]
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*S, 3, 4]
extra_ = torch.tensor([[[0, 0, 0, 1]]],
device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
Tran_M = torch.cat([Tran_M, extra_
], dim=1)
Tran_M = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)
# Take the first frame as the base of world coordinate
if pose_mode == "C2W":
Tran_M_new = (Tran_M[:,:1,...].inverse())@Tran_M
elif pose_mode == "W2C":
Tran_M_new = Tran_M@(Tran_M[:,:1,...].inverse())
Tran_M_new = rearrange(Tran_M_new, 'b s c d -> (b s) c d')
R_ = Tran_M_new[:, :3, :3]
T_ = Tran_M_new[:, :3, 3]
# Cameras to Pose encoding
pred_cameras.R = R_
pred_cameras.T = T_
pose_enc_norm = camera_to_pose_encoding(pred_cameras,
pose_encoding_type=pose_encoding_type)
pose_enc_norm = rearrange(pose_enc_norm, '(b s) c -> b s c', b=B)
return pose_enc_norm
def first_pose_enc_denorm(
pose_enc: torch.Tensor,
pose_enc_1st: torch.Tensor,
pose_encoding_type: str="absT_quaR_OneFL",
pose_mode: str = "W2C"):
"""
    Make sure the poses within one window are de-normalized by the first frame, whose
    transformation is the identity matrix.
args:
pose_enc: [B S C]
pose_enc_1st: [B 1 C]
return:
pose_enc_denorm: [B S C]
"""
B, S, C = pose_enc.shape
pose_enc_all = torch.cat([pose_enc_1st, pose_enc], dim=1)
# Pose encoding to Cameras (Pytorch3D coordinate)
pred_cameras = pose_encoding_to_camera(
pose_enc_all, pose_encoding_type=pose_encoding_type,
to_OpenCV=False
) #NOTE: the camera parameters are not in NDC
R = pred_cameras.R # [B*(1+S), 3, 3]
T = pred_cameras.T # [B*(1+S), 3]
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B*(1+S), 3, 4]
extra_ = torch.tensor([[[0, 0, 0, 1]]],
device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)
Tran_M = torch.cat([Tran_M, extra_
], dim=1)
Tran_M_new = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)[:, 1:]
Tran_M_1st = rearrange(Tran_M, '(b s) c d -> b s c d', b=B)[:,:1]
if pose_mode == "C2W":
Tran_M_new = Tran_M_1st@Tran_M_new
elif pose_mode == "W2C":
Tran_M_new = Tran_M_new@Tran_M_1st
Tran_M_new_ = torch.cat([Tran_M_1st, Tran_M_new], dim=1)
R_ = Tran_M_new_[..., :3, :3].view(-1, 3, 3)
T_ = Tran_M_new_[..., :3, 3].view(-1, 3)
# Cameras to Pose encoding
pred_cameras.R = R_
pred_cameras.T = T_
# Cameras to Pose encoding
pose_enc_denorm = camera_to_pose_encoding(pred_cameras,
pose_encoding_type=pose_encoding_type)
pose_enc_denorm = rearrange(pose_enc_denorm, '(b s) c -> b s c', b=B)
return pose_enc_denorm[:, 1:]
def compute_scale_and_shift(prediction, target, mask):
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
a_00 = torch.sum(mask * prediction * prediction, (1, 2))
a_01 = torch.sum(mask * prediction, (1, 2))
a_11 = torch.sum(mask, (1, 2))
# right hand side: b = [b_0, b_1]
b_0 = torch.sum(mask * prediction * target, (1, 2))
b_1 = torch.sum(mask * target, (1, 2))
# solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
x_0 = torch.zeros_like(b_0)
x_1 = torch.zeros_like(b_1)
det = a_00 * a_11 - a_01 * a_01
# A needs to be a positive definite matrix.
valid = det > 0
x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
return x_0, x_1
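# Usage sketch (illustrative only; the synthetic target below is assumed): the
# closed-form least-squares fit recovers a known per-image scale and shift.
#
#   pred = torch.rand(2, 32, 32)
#   gt = 2.0 * pred + 0.5
#   mask = torch.ones_like(pred)
#   s, t = compute_scale_and_shift(pred, gt, mask)      # s ≈ [2., 2.], t ≈ [0.5, 0.5]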
def normalize_prediction_robust(target, mask, Bs):
ssum = torch.sum(mask, (1, 2))
valid = ssum > 0
m = torch.zeros_like(ssum).to(target.dtype)
s = torch.ones_like(ssum).to(target.dtype)
m[valid] = torch.median(
(mask[valid] * target[valid]).view(valid.sum(), -1), dim=1
).values
target = rearrange(target, '(b c) h w -> b c h w', b=Bs)
m_vid = rearrange(m, '(b c) -> b c 1 1', b=Bs) #.mean(dim=1, keepdim=True)
mask = rearrange(mask, '(b c) h w -> b c h w', b=Bs)
target = target - m_vid
sq = torch.sum(mask * target.abs(), (2, 3))
sq = rearrange(sq, 'b c -> (b c)')
s[valid] = torch.clamp((sq[valid] / ssum[valid]), min=1e-6)
s_vid = rearrange(s, '(b c) -> b c 1 1', b=Bs) #.mean(dim=1, keepdim=True)
target = target / s_vid
target = rearrange(target, 'b c h w -> (b c) h w', b=Bs)
return target, m_vid, s_vid
def normalize_video_robust(target, mask, Bs):
vid_valid = target[mask]
    # randomly subsample the valid values to 1/5 to keep the quantile computation cheap
    with torch.no_grad():
        vid_valid = vid_valid[torch.randperm(vid_valid.shape[0], device=vid_valid.device)[:vid_valid.shape[0]//5]]
t_2, t_98 = torch.quantile(vid_valid, 0.02), torch.quantile(vid_valid, 0.98)
# normalize
target = (target - t_2) / (t_98 - t_2)*2 - 1
return target, t_2, t_98
def video_loss(prediction, target, mask, Bs):
    # robust normalization using the 2% / 98% quantiles
prediction_nm, a_norm, b_norm = normalize_video_robust(prediction, mask, Bs)
target_nm, a_norm_gt, b_norm_gt = normalize_video_robust(target.float(), mask, Bs)
depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
# rel depth 2 metric --> (pred - a')/(b'-a')*(b-a) + a
scale = (b_norm_gt - a_norm_gt) / (b_norm - a_norm)
shift = a_norm_gt - a_norm*scale
return depth_loss, scale, shift, prediction_nm, target_nm
def median_loss(prediction, target, mask, Bs):
# median norm
prediction_nm, a_norm, b_norm = normalize_prediction_robust(prediction, mask, Bs)
target_nm, a_norm_gt, b_norm_gt = normalize_prediction_robust(target.float(), mask, Bs)
depth_loss = nn.functional.l1_loss(prediction_nm[mask], target_nm[mask])
scale = b_norm_gt/b_norm
shift = a_norm_gt - a_norm*scale
return depth_loss, scale, shift, prediction_nm, target_nm
def reduction_batch_based(image_loss, M):
# average of all valid pixels of the batch
# avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
divisor = torch.sum(M)
if divisor == 0:
return 0
else:
return torch.sum(image_loss) / divisor
def reduction_image_based(image_loss, M):
# mean of average of valid pixels of an image
# avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
valid = M.nonzero()
image_loss[valid] = image_loss[valid] / M[valid]
return torch.mean(image_loss)
class ScaleAndShiftInvariantLoss(nn.Module):
def __init__(self):
super().__init__()
self.name = "SSILoss"
def forward(self, prediction, target, mask, Bs,
interpolate=True, return_interpolated=False):
        if prediction.shape[-1] != target.shape[-1] and interpolate:
            prediction = nn.functional.interpolate(
                prediction, target.shape[-2:], mode='bilinear', align_corners=True)
prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze()
assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}."
scale, shift = compute_scale_and_shift(prediction, target, mask)
a_norm = scale.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
b_norm = shift.view(Bs, -1, 1, 1).mean(dim=1, keepdim=True)
prediction = rearrange(prediction, '(b c) h w -> b c h w', b=Bs)
target = rearrange(target, '(b c) h w -> b c h w', b=Bs)
mask = rearrange(mask, '(b c) h w -> b c h w', b=Bs)
scaled_prediction = a_norm * prediction + b_norm
loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask])
        # both branches of `return_interpolated` returned the same values here
        return loss, a_norm, b_norm
ScaleAndShiftInvariantLoss_fn = ScaleAndShiftInvariantLoss()
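# Usage sketch (illustrative only; shapes and the synthetic target are assumed):
# prediction/target are stacked frames of shape (B*T, H, W); Bs is the number of
# videos, so one scale/shift is shared per video.
#
#   pred = torch.rand(8, 64, 64)                        # 2 videos x 4 frames
#   gt   = 3.0 * pred + 0.2
#   mask = gt > 0
#   loss, a, b = ScaleAndShiftInvariantLoss_fn(pred, gt, mask, Bs=2)
#   # loss ≈ 0, a ≈ 3.0 and b ≈ 0.2 per video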
class GradientLoss(nn.Module):
def __init__(self, scales=4, reduction='batch-based'):
super().__init__()
if reduction == 'batch-based':
self.__reduction = reduction_batch_based
else:
self.__reduction = reduction_image_based
self.__scales = scales
def forward(self, prediction, target, mask):
total = 0
for scale in range(self.__scales):
step = pow(2, scale)
l1_ln, a_nm, b_nm = ScaleAndShiftInvariantLoss_fn(prediction[:, ::step, ::step],
target[:, ::step, ::step], mask[:, ::step, ::step], 1)
total += l1_ln
a_nm = a_nm.squeeze().detach() # [B, 1, 1]
b_nm = b_nm.squeeze().detach() # [B, 1, 1]
total += 2*gradient_loss(a_nm*prediction[:, ::step, ::step]+b_nm, target[:, ::step, ::step],
mask[:, ::step, ::step], reduction=self.__reduction)
return total
Grad_fn = GradientLoss()
def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
M = torch.sum(mask, (1, 2))
diff = prediction - target
diff = torch.mul(mask, diff)
grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
grad_x = torch.mul(mask_x, grad_x)
grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
grad_y = torch.mul(mask_y, grad_y)
image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
return reduction(image_loss, M)
def loss_fn(
poses_preds: List[torch.Tensor],
poses_pred_all: List[torch.Tensor],
poses_gt: torch.Tensor,
inv_depth_preds: List[torch.Tensor],
inv_depth_raw: List[torch.Tensor],
depths_gt: torch.Tensor,
S: int = 16,
gamma: float = 0.8,
logger=None,
logger_tf=None,
global_step=0,
):
"""
Args:
poses_preds: list of predicted poses
poses_gt: ground truth poses
inv_depth_preds: list of predicted inverse depth maps
depths_gt: ground truth depth maps
S: length of sliding window
"""
B, T, _, H, W = depths_gt.shape
loss_total = 0
for i in range(len(poses_preds)):
poses_preds_i = poses_preds[i][0]
poses_unc_i = poses_preds[i][1]
poses_gt_i = poses_gt[:, i*S//2:i*S//2+S,:]
poses_gt_i_norm = first_pose_enc_norm(poses_gt_i,
pose_encoding_type="absT_quaR_OneFL")
pose_loss = 0.0
for idx, poses_preds_ij in enumerate(poses_preds_i):
i_weight = gamma ** (len(poses_preds_i) - idx - 1)
if logger is not None:
if poses_preds_ij.max()>5e1:
logger.info(f"pose_pred_max_and_mean: {poses_preds_ij.max(), poses_preds_ij.mean()}")
trans_loss = (poses_preds_ij[...,:3] - poses_gt_i_norm[...,:3]).abs().sum(dim=-1).mean()
rot_loss = (poses_preds_ij[...,3:7] - poses_gt_i_norm[...,3:7]).abs().sum(dim=-1).mean()
focal_loss = (poses_preds_ij[...,7:] - poses_gt_i_norm[...,7:]).abs().sum(dim=-1).mean()
if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
pose_loss += 0
else:
pose_loss += i_weight*(trans_loss + rot_loss + focal_loss)
if (logger_tf is not None)&(i==len(poses_preds)-1):
logger_tf.add_scalar(f"loss@pose/trans_iter{idx}",
trans_loss, global_step=global_step)
logger_tf.add_scalar(f"loss@pose/rot_iter{idx}",
rot_loss, global_step=global_step)
logger_tf.add_scalar(f"loss@pose/focal_iter{idx}",
focal_loss, global_step=global_step)
# compute the uncertainty loss
with torch.no_grad():
pose_loss_dist = (poses_preds_ij-poses_gt_i_norm).detach().abs()
pose_loss_std = 3*pose_loss_dist.view(-1,8).std(dim=0)[None,None,:]
gt_dist = F.relu(pose_loss_std - pose_loss_dist) / (pose_loss_std + 1e-3)
unc_loss = (gt_dist - poses_unc_i).abs().mean()
if (logger_tf is not None)&(i==len(poses_preds)-1):
logger_tf.add_scalar(f"loss@uncertainty/unc",
unc_loss,
global_step=global_step)
# if logger is not None:
# logger.info(f"pose_loss: {pose_loss}, unc_loss: {unc_loss}")
# total loss
loss_total += 0.1*unc_loss + 2*pose_loss
poses_gt_norm = poses_gt
pose_all_loss = 0.0
prev_loss = None
for idx, poses_preds_all_j in enumerate(poses_pred_all):
i_weight = gamma ** (len(poses_pred_all) - idx - 1)
trans_loss = (poses_preds_all_j[...,:3] - poses_gt_norm[...,:3]).abs().sum(dim=-1).mean()
rot_loss = (poses_preds_all_j[...,3:7] - poses_gt_norm[...,3:7]).abs().sum(dim=-1).mean()
focal_loss = (poses_preds_all_j[...,7:] - poses_gt_norm[...,7:]).abs().sum(dim=-1).mean()
if (logger_tf is not None):
if prev_loss is None:
prev_loss = (trans_loss + rot_loss + focal_loss)
else:
des_loss = (trans_loss + rot_loss + focal_loss) - prev_loss
prev_loss = trans_loss + rot_loss + focal_loss
logger_tf.add_scalar(f"loss@global_pose/des_iter{idx}",
des_loss, global_step=global_step)
logger_tf.add_scalar(f"loss@global_pose/trans_iter{idx}",
trans_loss, global_step=global_step)
logger_tf.add_scalar(f"loss@global_pose/rot_iter{idx}",
rot_loss, global_step=global_step)
logger_tf.add_scalar(f"loss@global_pose/focal_iter{idx}",
focal_loss, global_step=global_step)
if torch.isnan((trans_loss + rot_loss + focal_loss)).any():
pose_all_loss += 0
else:
pose_all_loss += i_weight*(trans_loss + rot_loss + focal_loss)
# if logger is not None:
# logger.info(f"global_pose_loss: {pose_all_loss}")
# compute the depth loss
if inv_depth_preds[0] is not None:
depths_gt = depths_gt[:,:,0]
msk = depths_gt > 5e-2
inv_gt = 1.0 / (depths_gt.clamp(1e-3, 1e16))
inv_gt_reshp = rearrange(inv_gt, 'b t h w -> (b t) h w')
inv_depth_preds_reshp = rearrange(inv_depth_preds[0], 'b t h w -> (b t) h w')
inv_raw_reshp = rearrange(inv_depth_raw[0], 'b t h w -> (b t) h w')
msk_reshp = rearrange(msk, 'b t h w -> (b t) h w')
        # NOTE: the loss expects the number of videos as `Bs` and returns
        # (loss, scale, shift); only the scalar loss is kept here (Bs=B assumed).
        huber_loss, _, _ = ScaleAndShiftInvariantLoss_fn(inv_depth_preds_reshp, inv_gt_reshp, msk_reshp, B)
        huber_loss_raw, _, _ = ScaleAndShiftInvariantLoss_fn(inv_raw_reshp, inv_gt_reshp, msk_reshp, B)
# huber_loss = (inv_depth_preds[0][msk]-inv_gt[msk]).abs().mean()
        # Pearson-correlation loss (currently disabled)
        pearson_loss = 0
        # for i in range(B):
        #     pearson_loss += (1 - pearson_corrcoef(inv_depth_preds[0].view(B*T,-1), inv_gt.view(B*T,-1))).mean()
        # pearson_loss = pearson_loss/B
        if torch.isnan(huber_loss).any():
            huber_loss = 0
        depth_loss = huber_loss + pearson_loss
if (logger_tf is not None)&(i==len(poses_preds)-1):
logger_tf.add_scalar(f"loss@depth/huber_iter{idx}",
depth_loss,
global_step=global_step)
# if logger is not None:
# logger.info(f"opt_depth: {huber_loss_raw - huber_loss}")
    else:
        depth_loss = 0.0
        # keep the returned depth-improvement term well defined when no depth is predicted
        huber_loss = huber_loss_raw = torch.zeros((), device=depths_gt.device)
    loss_total = loss_total/(len(poses_preds)) + 20*depth_loss + pose_all_loss
    return loss_total, (huber_loss_raw - huber_loss)
def vis_depth(x: torch.Tensor,
logger_tf = None, title: str = "depth", step: int = 0):
"""
args:
x: H W
"""
assert len(x.shape) == 2
depth_map_normalized = cv2.normalize(x.cpu().numpy(),
None, 0, 255, cv2.NORM_MINMAX)
depth_map_colored = cv2.applyColorMap(depth_map_normalized.astype(np.uint8),
cv2.COLORMAP_JET)
depth_map_tensor = torch.from_numpy(depth_map_colored).permute(2, 0, 1).unsqueeze(0)
if logger_tf is not None:
logger_tf.add_image(title, depth_map_tensor[0], step)
else:
return depth_map_tensor
def vis_pcd(
rgbs: torch.Tensor,
R: torch.Tensor,
T: torch.Tensor,
xy_depth: torch.Tensor,
focal_length: torch.Tensor,
pick_idx: List = [0]
):
"""
args:
rgbs: [S C H W]
R: [S 3 3]
T: [S 3]
xy_depth: [S H W 3]
focal_length: [S]
pick_idx: list of the index to pick
"""
S, C, H, W = rgbs.shape
rgbs_pick = rgbs[pick_idx]
R_pick = R[pick_idx]
T_pick = T[pick_idx]
xy_depth_pick = xy_depth[pick_idx]
focal_length_pick = focal_length[pick_idx]
pcd_world = depth2pcd(xy_depth_pick.clone(),
focal_length_pick, R_pick.clone(), T_pick.clone(),
device=xy_depth.device, H=H, W=W)
pcd_world = pcd_world.permute(0, 2, 1) #[...,[1,0,2]]
mask = pcd_world.reshape(-1,3)[:,2] < 20
rgb_world = rgbs_pick.view(len(pick_idx), 3, -1).permute(0, 2, 1)
pcl = Pointclouds(points=[pcd_world.reshape(-1,3)[mask]],
features=[rgb_world.reshape(-1,3)[mask]/255])
return pcl
def vis_result(rgbs, poses_pred, poses_gt,
depth_gt, depth_pred, iter_num=0,
vis=None, logger_tf=None, cfg=None):
"""
Args:
rgbs: [S C H W]
depths_gt: [S C H W]
poses_gt: [S C]
poses_pred: [S C]
depth_pred: [S H W]
"""
    assert len(rgbs.shape) == 4, "only supports one sequence of RGB frames with shape [T, 3, H, W]"
if vis is None:
return
S, _, H, W = depth_gt.shape
# get the xy
yx = torch.meshgrid(torch.arange(H).to(depth_pred.device),
torch.arange(W).to(depth_pred.device),indexing='ij')
xy = torch.stack(yx[::-1], dim=0).float().to(depth_pred.device)
xy_norm = (xy / torch.tensor([W, H],
device=depth_pred.device).view(2, 1, 1) - 0.5)*2
xy = xy[None].repeat(S, 1, 1, 1)
xy_depth = torch.cat([xy, depth_pred[:,None]], dim=1).permute(0, 2, 3, 1)
xy_depth_gt = torch.cat([xy, depth_gt], dim=1).permute(0, 2, 3, 1)
# get the focal length
focal_length = poses_gt[:,-1]*max(H, W)
# vis the camera poses
poses_gt_vis = pose_encoding_to_camera(poses_gt,
pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
poses_pred_vis = pose_encoding_to_camera(poses_pred,
pose_encoding_type="absT_quaR_OneFL",to_OpenCV=False)
R_gt = poses_gt_vis.R.float()
R_pred = poses_pred_vis.R.float()
T_gt = poses_gt_vis.T.float()
T_pred = poses_pred_vis.T.float()
# C2W poses
R_gt_c2w = R_gt.permute(0,2,1)
T_gt_c2w = (-R_gt_c2w @ T_gt[:, :, None]).squeeze(-1)
R_pred_c2w = R_pred.permute(0,2,1)
T_pred_c2w = (-R_pred_c2w @ T_pred[:, :, None]).squeeze(-1)
with torch.cuda.amp.autocast(enabled=False):
pick_idx = torch.randperm(S)[:min(24, S)]
# pick_idx = [1]
        # NOTE: oddly, the visualization cameras take the C2W rotation together with the W2C translation as input
poses_gt_vis = PerspectiveCamerasVisual(
R=R_gt_c2w[pick_idx], T=T_gt[pick_idx],
device=poses_gt_vis.device, image_size=((H, W),)
)
poses_pred_vis = PerspectiveCamerasVisual(
R=R_pred_c2w[pick_idx], T=T_pred[pick_idx],
device=poses_pred_vis.device
)
visual_dict = {"scenes": {"cameras": poses_pred_vis, "cameras_gt": poses_gt_vis}}
env_name = f"train_visualize_iter_{iter_num:05d}"
print(f"Visualizing the scene by visdom at env: {env_name}")
# visualize the depth map
vis_depth(depth_pred[0].detach(), logger_tf, title="vis/depth_pred",step=iter_num)
msk = depth_pred[0] > 1e-3
vis_depth(depth_gt[0,0].detach(), logger_tf, title="vis/depth_gt",step=iter_num)
depth_res = (depth_gt[0,0] - depth_pred[0]).abs()
vis_depth(depth_res.detach(), logger_tf, title="vis/depth_res",step=iter_num)
# visualize the point cloud
if cfg.debug.vis_pcd:
visual_dict["scenes"]["points_gt"] = vis_pcd(rgbs, R_gt, T_gt,
xy_depth_gt, focal_length, pick_idx)
else:
visual_dict["scenes"]["points_pred"] = vis_pcd(rgbs, R_pred, T_pred,
xy_depth, focal_length, pick_idx)
# visualize in visdom
fig = plot_scene(visual_dict, camera_scale=0.05)
vis.plotlyplot(fig, env=env_name, win="3D")
vis.save([env_name])
return
def depth2pcd(
xy_depth: torch.Tensor,
focal_length: torch.Tensor,
R: torch.Tensor,
T: torch.Tensor,
device: torch.device = None,
H: int = 518,
W: int = 518
):
"""
args:
xy_depth: [S H W 3]
focal_length: [S]
R: [S 3 3] W2C
T: [S 3] W2C
return:
xyz: [S 3 (H W)]
"""
S, H, W, _ = xy_depth.shape
# get the intrinsic
K = torch.eye(3, device=device)[None].repeat(len(focal_length), 1, 1).to(device)
K[:, 0, 0] = focal_length
K[:, 1, 1] = focal_length
K[:, 0, 2] = 0.5 * W
K[:, 1, 2] = 0.5 * H
K_inv = K.inverse()
# xyz
xyz = xy_depth.view(S, -1, 3).permute(0, 2, 1) # S 3 (H W)
depth = xyz[:, 2:].clone() # S (H W) 1
xyz[:, 2] = 1
xyz = K_inv @ xyz # S 3 (H W)
xyz = xyz * depth
# to world coordinate
xyz = R.permute(0,2,1) @ (xyz - T[:, :, None])
return xyz
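# Usage sketch (illustrative only; the unit-depth frame and focal length are
# assumed): back-project pixel coordinates plus depth into world-space points.
#
#   S, H, W = 2, 64, 64
#   yx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
#   xy = torch.stack(yx[::-1], dim=0).float()[None].repeat(S, 1, 1, 1)   # [S 2 H W]
#   xy_depth = torch.cat([xy, torch.ones(S, 1, H, W)], dim=1).permute(0, 2, 3, 1)
#   R = torch.eye(3)[None].repeat(S, 1, 1)
#   T = torch.zeros(S, 3)
#   pcd = depth2pcd(xy_depth, torch.full((S,), 100.0), R, T, device="cpu", H=H, W=W)
#   # pcd: [S, 3, H*W] world-coordinate points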
def pose_enc2mat(poses_pred,
H_resize, W_resize, resolution=336):
"""
    Convert the pose encoding into `intrinsic` and `extrinsic` matrices.
Args:
poses_pred: B T 8
Return:
Intrinsic B T 3 3
Extrinsic B T 4 4
"""
B, T, _ = poses_pred.shape
focal_pred = poses_pred[:, :, -1].clone()
pos_quat_preds = poses_pred[:, :, :7].clone()
pos_quat_preds = pos_quat_preds.view(B*T, -1)
# get extrinsic
c2w_rot = quaternion_to_matrix(pos_quat_preds[:, 3:])
c2w_tran = pos_quat_preds[:, :3]
c2w_traj = torch.eye(4)[None].repeat(B*T, 1, 1).to(poses_pred.device)
c2w_traj[:, :3, :3], c2w_traj[:, :3, 3] = c2w_rot, c2w_tran
c2w_traj = c2w_traj.view(B, T, 4, 4)
# get intrinsic
fxs, fys = focal_pred*resolution, focal_pred*resolution
intrs = torch.eye(3).to(c2w_traj.device).to(c2w_traj.dtype)[None, None].repeat(B, T, 1, 1)
intrs[:,:,0,0], intrs[:,:,1,1] = fxs, fys
intrs[:,:,0,2], intrs[:,:,1,2] = W_resize/2, H_resize/2
return intrs, c2w_traj
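# Usage sketch (illustrative only; identity poses from init_pose_enc are assumed):
# convert an 8-dim pose encoding into pinhole intrinsics and a C2W trajectory.
#
#   enc = init_pose_enc(B=1, S=8, pose_encoding_type="absT_quaR_OneFL")
#   intrs, c2w_traj = pose_enc2mat(enc, H_resize=336, W_resize=336, resolution=336)
#   # intrs:    [1, 8, 3, 3] with fx = fy = 336, cx = cy = 168
#   # c2w_traj: [1, 8, 4, 4] identity transforms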
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
Args:
quaternions: Quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
q_abs = _sqrt_positive_part(
torch.stack(
[1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
)
)
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
return standardize_quaternion(out)
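# Usage sketch (illustrative only; the quaternion below is assumed): converting
# to a matrix and back recovers the (standardized) quaternion.
#
#   q = torch.tensor([[0.9239, 0.0, 0.3827, 0.0]])      # ~45 deg about the y-axis
#   q_back = matrix_to_quaternion(quaternion_to_matrix(q))
#   torch.allclose(q, q_back, atol=1e-3)                 # -> True (real part kept >= 0)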
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
# returns a meshgrid sized B x Y x X
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
grid_y = torch.reshape(grid_y, [1, Y, 1])
grid_y = grid_y.repeat(B, 1, X)
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
grid_x = torch.reshape(grid_x, [1, 1, X])
grid_x = grid_x.repeat(B, Y, 1)
if stack:
# note we stack in xy order
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
grid = torch.stack([grid_x, grid_y], dim=-1)
return grid
else:
return grid_y, grid_x
def get_points_on_a_grid(grid_size, interp_shape,
grid_center=(0, 0), device="cuda"):
if grid_size == 1:
return torch.tensor([interp_shape[1] / 2,
interp_shape[0] / 2], device=device)[
None, None
]
grid_y, grid_x = meshgrid2d(
1, grid_size, grid_size, stack=False, norm=False, device=device
)
step = interp_shape[1] // 64
if grid_center[0] != 0 or grid_center[1] != 0:
grid_y = grid_y - grid_size / 2.0
grid_x = grid_x - grid_size / 2.0
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[0] - step * 2
)
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[1] - step * 2
)
grid_y = grid_y + grid_center[0]
grid_x = grid_x + grid_center[1]
xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
return xy
def normalize_rgb(x,input_size=224,
resize_mode: Literal['resize', 'padding'] = 'resize',
if_da=False):
"""
    Resize (or pad) the images and optionally apply ImageNet normalization for Depth-Anything input.
args:
x: the input images [B T C H W]
"""
if isinstance(x, np.ndarray):
x = torch.from_numpy(x) / 255.0
elif isinstance(x, torch.Tensor):
x = x / 255.0
B, T, C, H, W = x.shape
x = x.view(B * T, C, H, W)
Resizer = Resize(
width=input_size,
height=input_size,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method='lower_bound',
)
if resize_mode == 'padding':
        # resize while keeping the aspect ratio, then zero-pad to a square of side input_size
if H > W:
H_scale = input_size
W_scale = W * input_size // H
else:
W_scale = input_size
H_scale = H * input_size // W
# resize the image
x = F.interpolate(x, size=(H_scale, W_scale),
mode='bilinear', align_corners=False)
# central padding the image
padding_x = (input_size - W_scale) // 2
padding_y = (input_size - H_scale) // 2
extra_x = (input_size - W_scale) % 2
extra_y = (input_size - H_scale) % 2
x = F.pad(x, (padding_x, padding_x+extra_x,
padding_y, padding_y+extra_y), value=0.)
elif resize_mode == 'resize':
H_scale, W_scale = Resizer.get_size(H, W)
x = F.interpolate(x, size=(int(H_scale), int(W_scale)),
mode='bicubic', align_corners=True)
# get the mean and std
__mean__ = torch.tensor([0.485,
0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
__std__ = torch.tensor([0.229,
0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
# normalize the image
if if_da:
x = (x - __mean__) / __std__
else:
x = x
return x.view(B, T, C, x.shape[-2], x.shape[-1])
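# Usage sketch (illustrative only; the square input resolution is assumed so that
# the resized side is a clean multiple of 14):
#
#   imgs = torch.randint(0, 256, (1, 4, 3, 512, 512)).float()   # [B T C H W], 0..255
#   out = normalize_rgb(imgs, input_size=518, resize_mode='resize', if_da=True)
#   out.shape                                            # -> torch.Size([1, 4, 3, 518, 518])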
def get_track_points(H, W, T, device, size=100, support_frame=0,
query_size=768, unc_metric=None, mode="mixed"):
"""
    Sample query points on a regular grid for tracking.
    args:
        H: the height of the image.
        W: the width of the image.
        T: the number of frames.
        device: the device of the points.
        size: the side length of the query grid (size x size points).
        support_frame: frame index used for the support queries in "mixed" mode.
        query_size: maximum number of queries kept after random subsampling.
        unc_metric: optional per-pixel confidence used to filter out uncertain queries.
        mode: one of "incremental", "first" or "mixed".
"""
grid_pts = get_points_on_a_grid(size, (H, W), device=device)
grid_pts = grid_pts.round()
if mode == "incremental":
queries = torch.cat(
[torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
dim=2,
)
elif mode == "first":
queries_first = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts],
dim=2,
)
queries_support = torch.cat(
[torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
dim=2,
)
queries = torch.cat([queries_first, queries_support, queries_support], dim=1)
elif mode == "mixed":
queries = torch.cat(
[torch.randint_like(grid_pts[:, :, :1], T), grid_pts],
dim=2,
)
queries_first = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * support_frame, grid_pts],
dim=2,
)
queries = torch.cat([queries_first, queries, queries], dim=1)
if unc_metric is not None:
# filter the points with high uncertainty
sample_unc = sample_features5d(unc_metric[None], queries[:,None]).squeeze()
        # keep only the confident queries, unless too few of them would remain
        if (sample_unc > 0.5).sum() >= 20:
            queries = queries[:, sample_unc > 0.5, :]
idx_ = torch.randperm(queries.shape[1], device=device)[:query_size]
queries = queries[:, idx_]
return queries
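# Usage sketch (illustrative only; the image size and grid size are assumed):
# sample (t, x, y) query points for tracking from a regular grid.
#
#   queries = get_track_points(H=384, W=512, T=16, device="cpu", size=30,
#                              query_size=768, unc_metric=None, mode="mixed")
#   queries.shape                                         # -> torch.Size([1, 768, 3])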