# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from https://github.com/amyxlase/relpose-plus-plus

import math

import numpy as np
import torch

# The camera utilities below assume PyTorch3D camera objects
# (get_world_to_view_transform, unproject_points, ...), so the Translate
# transform used by compute_optical_transform is imported from PyTorch3D.
from pytorch3d.transforms import Translate


def bbox_xyxy_to_xywh(xyxy):
    wh = xyxy[2:] - xyxy[:2]
    xywh = np.concatenate([xyxy[:2], wh])
    return xywh


def adjust_camera_to_bbox_crop_(fl, pp, image_size_wh: torch.Tensor, clamp_bbox_xywh: torch.Tensor):
    # Convert the intrinsics to pixels, shift the principal point to the crop
    # origin, then re-express both in the NDC frame of the cropped image.
    focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, image_size_wh)
    principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]

    focal_length, principal_point_cropped = _convert_pixels_to_ndc(
        focal_length_px, principal_point_px_cropped, clamp_bbox_xywh[2:]
    )

    return focal_length, principal_point_cropped


def adjust_camera_to_image_scale_(fl, pp, original_size_wh: torch.Tensor, new_size_wh: torch.LongTensor):
    focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, original_size_wh)

    # Now scale in pixel space and convert from pixels back to NDC.
    image_size_wh_output = new_size_wh.float()
    scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
    focal_length_px_scaled = focal_length_px * scale
    principal_point_px_scaled = principal_point_px * scale

    focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
        focal_length_px_scaled, principal_point_px_scaled, image_size_wh_output
    )
    return focal_length_scaled, principal_point_scaled
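
# A minimal sketch (not part of the original module) of how the two
# adjustment helpers above chain together when an image is cropped and then
# resized. All numbers are made-up placeholders; `fl`/`pp` stand in for a
# camera's focal length and principal point in NDC.
def _example_crop_then_rescale():
    fl = torch.tensor([2.0, 2.0])                 # hypothetical NDC focal length
    pp = torch.tensor([0.0, 0.0])                 # hypothetical NDC principal point
    image_size_wh = torch.tensor([640.0, 480.0])  # original image size
    bbox_xywh = torch.tensor([100.0, 50.0, 320.0, 320.0])  # crop box

    # Intrinsics in the NDC frame of the crop ...
    fl, pp = adjust_camera_to_bbox_crop_(fl, pp, image_size_wh, bbox_xywh)
    # ... and then in the NDC frame of the crop resized to 224x224.
    fl, pp = adjust_camera_to_image_scale_(fl, pp, bbox_xywh[2:], torch.tensor([224, 224]))
    return fl, pp
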
def _convert_ndc_to_pixels(focal_length: torch.Tensor, principal_point: torch.Tensor, image_size_wh: torch.Tensor):
    # PyTorch3D-style NDC: the shorter image side spans [-1, 1], so the pixel
    # scale is min(W, H) / 2, and the NDC axes run opposite to the pixel axes.
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point_px = half_image_size - principal_point * rescale
    focal_length_px = focal_length * rescale
    return focal_length_px, principal_point_px


def _convert_pixels_to_ndc(
    focal_length_px: torch.Tensor, principal_point_px: torch.Tensor, image_size_wh: torch.Tensor
):
    # Exact inverse of _convert_ndc_to_pixels.
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point = (half_image_size - principal_point_px) / rescale
    focal_length = focal_length_px / rescale
    return focal_length, principal_point
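
# Worked example of the convention used by the two helpers above: for a
# 200x100 image, half_image_size = (100, 50) and rescale = min(100, 50) = 50,
# so an NDC principal point of (0, 0) corresponds to the pixel center
# (100, 50), and an NDC focal length of 3.0 corresponds to 150 px.
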
def normalize_cameras(
    cameras, compute_optical=True, first_camera=True, normalize_trans=True, scale=1.0, points=None, max_norm=False,
    pose_mode="C2W"
):
    """
    Normalizes cameras such that
    (1) the optical axes point to the origin and the average distance to the origin is 1,
    (2) the first camera is the origin,
    (3) the translation vector is normalized.

    TODO: some transforms overlap with others; no need to do so many transforms.

    Args:
        cameras (List[camera]).
    """
    # Let the distance from the first camera to the origin be unit.
    new_cameras = cameras.clone()
    # Note: the `scale` argument is unused; it is overwritten here and only
    # updated again if normalize_trans is True.
    scale = 1.0

    if compute_optical:
        new_cameras, points = compute_optical_transform(new_cameras, points=points)
    if first_camera:
        new_cameras, points = first_camera_transform(new_cameras, points=points, pose_mode=pose_mode)
    if normalize_trans:
        new_cameras, points, scale = normalize_translation(new_cameras, points=points, max_norm=max_norm)
    return new_cameras, points, scale
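
# A minimal usage sketch (not part of the original module), assuming a recent
# PyTorch3D install: `cameras` must provide the PerspectiveCameras interface
# (get_camera_center, unproject_points with from_ndc, ...) that the helpers
# below rely on. The three viewpoints and intrinsics are placeholder values.
def _example_normalize_cameras():
    from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform

    # Three cameras looking at the origin from different directions.
    R, T = look_at_view_transform(dist=3.0, elev=torch.tensor([0.0, 10.0, 20.0]),
                                  azim=torch.tensor([0.0, 30.0, 60.0]))
    cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.full((3,), 2.0),
                                 principal_point=torch.zeros(3, 2))
    new_cameras, _, scale = normalize_cameras(cameras)
    return new_cameras, scale
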
def compute_optical_transform(new_cameras, points=None):
    """
    Adapted from https://github.com/amyxlase/relpose-plus-plus.

    Re-centers the scene at the intersection of the cameras' optical axes and
    rescales it by the first camera's distance to that point.
    """
    new_transform = new_cameras.get_world_to_view_transform()
    p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(new_cameras)
    t = Translate(p_intersect)
    scale = dist.squeeze()[0]

    if points is not None:
        points = t.inverse().transform_points(points)
        points = points / scale

    # Degenerate case: the first camera sits exactly at the intersection point.
    if scale == 0:
        scale = torch.norm(new_cameras.T, dim=(0, 1))
        scale = torch.sqrt(scale)
        new_cameras.T = new_cameras.T / scale
    else:
        new_matrix = t.compose(new_transform).get_matrix()
        new_cameras.R = new_matrix[:, :3, :3]
        new_cameras.T = new_matrix[:, 3, :3] / scale

    return new_cameras, points


def compute_optical_axis_intersection(cameras):
    centers = cameras.get_camera_center()
    principal_points = cameras.principal_point

    # One optical-axis point per camera: the principal point at unit depth.
    one_vec = torch.ones((len(cameras), 1), device=centers.device)
    optical_axis = torch.cat((principal_points, one_vec), -1)

    # Unprojecting the N points with all N cameras gives an (N, N, 3) array;
    # the diagonal picks each camera's own unprojected principal point.
    pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
    pp2 = pp[torch.arange(pp.shape[0]), torch.arange(pp.shape[0])]

    directions = pp2 - centers
    centers = centers.unsqueeze(0).unsqueeze(0)
    directions = directions.unsqueeze(0).unsqueeze(0)

    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(p=centers, r=directions, mask=None)

    p_intersect = p_intersect.squeeze().unsqueeze(0)
    dist = (p_intersect - centers).norm(dim=-1)

    return p_intersect, dist, p_line_intersect, pp2, r


def intersect_skew_line_groups(p, r, mask):
    # p, r both of shape (B, N, n_intersected_lines, 3)
    # mask of shape (B, N, n_intersected_lines)
    p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    _, p_line_intersect = _point_line_distance(p, r, p_intersect[..., None, :].expand_as(p))
    intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(dim=-1)
    return p_intersect, p_line_intersect, intersect_dist_squared, r


def intersect_skew_lines_high_dim(p, r, mask=None):
    # Least-squares nearest point to a set of lines, which works in any number
    # of dimensions; see https://en.wikipedia.org/wiki/Skew_lines
    dim = p.shape[-1]
    # Make sure the heading vectors are l2-normed.
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    r = torch.nn.functional.normalize(r, dim=-1)

    # Solve sum_i (I - r_i r_i^T) x = sum_i (I - r_i r_i^T) p_i for x.
    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    if torch.any(torch.isnan(p_intersect)):
        print(p_intersect)
        raise ValueError("p_intersect is NaN")

    return p_intersect, r
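
# A small self-contained sanity check (illustrative, not part of the original
# module): two coplanar lines that genuinely intersect at (1, 0, 0) should
# recover exactly that point.
def _example_intersect_two_lines():
    # Shapes are (B, N, n_intersected_lines, 3) with B = N = 1.
    p = torch.tensor([[[[0.0, 0.0, 0.0],      # line 1 passes through the origin
                        [1.0, -1.0, 0.0]]]])  # line 2 passes through (1, -1, 0)
    r = torch.tensor([[[[1.0, 0.0, 0.0],      # line 1 heads along +x
                        [0.0, 1.0, 0.0]]]])   # line 2 heads along +y
    p_intersect, _ = intersect_skew_lines_high_dim(p, r)
    assert torch.allclose(p_intersect, torch.tensor([[[1.0, 0.0, 0.0]]]))
    return p_intersect
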
def _point_line_distance(p1, r1, p2):
    # Distance from each point p2 to the line through p1 with unit direction
    # r1, together with the nearest point on that line.
    df = p2 - p1
    proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
    line_pt_nearest = p2 - proj_vector
    d = proj_vector.norm(dim=-1)
    return d, line_pt_nearest


def first_camera_transform(cameras, rotation_only=False,
                           points=None, pose_mode="C2W"):
    """
    Transform the cameras so that the first camera sits at the origin.

    Note: `rotation_only` is kept for API compatibility but is unused, and
    `points`, if given, are returned unchanged.
    """
    new_cameras = cameras.clone()

    R = cameras.R
    T = cameras.T
    # Assemble 4x4 pose matrices [R | T; 0 0 0 1] for the whole batch.
    Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1)  # [B, 3, 4]
    bottom_row = torch.tensor([[[0, 0, 0, 1]]], device=Tran_M.device, dtype=Tran_M.dtype)
    Tran_M = torch.cat([Tran_M, bottom_row.expand(Tran_M.shape[0], -1, -1)], dim=1)  # [B, 4, 4]

    # Express every pose relative to the first camera.
    if pose_mode == "C2W":
        Tran_M_new = Tran_M[:1].inverse() @ Tran_M
    elif pose_mode == "W2C":
        Tran_M_new = Tran_M @ Tran_M[:1].inverse()
    else:
        raise ValueError(f"unknown pose_mode: {pose_mode}")

    new_cameras.R = Tran_M_new[:, :3, :3]
    new_cameras.T = Tran_M_new[:, :3, 3]

    return new_cameras, points


def normalize_translation(new_cameras, points=None, max_norm=False):
    # Rescale translations (and points) by a statistic of the translation
    # norms. The first camera is excluded: after first_camera_transform it
    # sits at the origin.
    t_gt = new_cameras.T.clone()
    t_gt = t_gt[1:, :]

    if max_norm:
        # Scale by the largest translation norm.
        t_gt_norm = torch.norm(t_gt, dim=-1)
        t_gt_scale = t_gt_norm.max()
        if t_gt_norm.max() < 0.001:
            t_gt_scale = torch.ones_like(t_gt_scale)
        t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
    else:
        # Scale by half the root-mean-square translation norm.
        t_gt_norm = torch.norm(t_gt, dim=(0, 1))
        t_gt_scale = t_gt_norm / math.sqrt(len(t_gt))
        t_gt_scale = t_gt_scale / 2
        if t_gt_norm.max() < 0.001:
            t_gt_scale = torch.ones_like(t_gt_scale)
        t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)

    new_cameras.T = new_cameras.T / t_gt_scale

    if points is not None:
        points = points / t_gt_scale

    return new_cameras, points, t_gt_scale
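
# Illustrative entry point (not part of the original module): runs the
# pure-PyTorch sanity check defined above.
if __name__ == "__main__":
    print(_example_intersect_two_lines())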