eval_pose_fist

This commit is contained in:
parent a625632571
commit 3b9e9a1e8e

.gitignore (vendored)
@@ -49,4 +49,8 @@ models/**/build
 models/**/dist
 temp_local
 examples/results
 
+dyn_check
+
+evaluation/saved
+
@@ -39,7 +39,7 @@ model:
   window_len: 60
   stablizer: True
   mode: "online"
-  s_wind: 200
+  s_wind: 500
   overlap: 4
   track_num: 0
 
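For context, mode: "online" together with s_wind and overlap suggests sliding-window inference, and this commit widens the window from 200 to 500 frames. A minimal sketch of how such chunking might work; the chunk_indices helper is hypothetical, not part of the repository:

# Hypothetical sketch of sliding-window chunking, assuming `s_wind` is the
# window length and `overlap` is the number of frames shared between
# consecutive windows (requires overlap < s_wind).
def chunk_indices(num_frames: int, s_wind: int = 500, overlap: int = 4):
    """Yield (start, end) index pairs covering num_frames frames."""
    start = 0
    while start < num_frames:
        end = min(start + s_wind, num_frames)
        yield start, end
        if end == num_frames:
            break
        start = end - overlap  # re-process `overlap` frames to stitch windows

print(list(chunk_indices(12, s_wind=5, overlap=2)))
# [(0, 5), (3, 8), (6, 11), (9, 12)]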
@@ -31,6 +31,8 @@ class VGGT4Track(nn.Module, PyTorchModelHubMixin):
         self,
         images: torch.Tensor,
         annots = {},
+        fx_prev = None,
+        fy_prev = None,
         **kwargs):
         """
         Forward pass of the VGGT4Track model.
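A usage sketch for the new arguments, assuming fx_prev/fy_prev are per-frame focal lengths from a previously processed window, passed as NumPy arrays (the torch.from_numpy calls in the next hunk imply NumPy input); the checkpoint ID and shapes are placeholders:

import numpy as np
import torch

model = VGGT4Track.from_pretrained("...")  # checkpoint ID elided
images = torch.rand(1, 8, 3, 518, 518)     # (B, T, C, H, W), values in [0, 1]
fx_prev = np.array([600.0, 605.0, 598.0, 602.0], dtype=np.float32)  # overlapping frames
fy_prev = np.array([600.0, 605.0, 598.0, 602.0], dtype=np.float32)

with torch.no_grad():
    predictions = model(images, fx_prev=fx_prev, fy_prev=fy_prev)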
@@ -85,16 +87,29 @@ class VGGT4Track(nn.Module, PyTorchModelHubMixin):
         predictions["unc_metric"] = depth_conf.view(B*T, H_proc, W_proc)
 
         predictions["images"] = (images)*255.0
 
         # output the camera pose
         predictions["poses_pred"] = torch.eye(4)[None].repeat(T, 1, 1)[None]
         predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
                                                                                                   images_proc.shape[-2:])
         predictions["poses_pred"] = torch.inverse(predictions["poses_pred"])
+
+        if fx_prev is not None:
+            scale_x = torch.from_numpy(fx_prev).to(predictions["intrs"].device) / predictions["intrs"][0, :fx_prev.shape[0], 0, 0]
+            scale_x = scale_x.mean() * W_proc / W
+            predictions["intrs"][:, :, 0, 0] *= scale_x
+        if fy_prev is not None:
+            scale_y = torch.from_numpy(fy_prev).to(predictions["intrs"].device) / predictions["intrs"][0, :fy_prev.shape[0], 1, 1]
+            scale_y = scale_y.mean() * H_proc / H
+            predictions["intrs"][:, :, 1, 1] *= scale_y
+
+        # get the points map
         points_map = depth_to_points_colmap(depth.view(B*T, H_proc, W_proc), predictions["intrs"].view(B*T, 3, 3))
         predictions["points_map"] = points_map
         #NOTE: resize back
         predictions["points_map"] = F.interpolate(points_map.permute(0,3,1,2),
                                                   size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
 
         predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
                                                   size=(H, W), mode='bilinear', align_corners=True)[:,0]
         predictions["intrs"][..., :1, :] *= W/W_proc
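The added block estimates one scalar ratio between the previous window's focal lengths and the current prediction over the overlapping frames, then applies it to every frame; the extra W_proc / W (and H_proc / H) factor appears to pre-compensate for the W/W_proc rescaling of the intrinsics at the end of the forward pass. A standalone numeric sketch with made-up values:

# Numeric sketch of the focal-length rescaling above. `intrs` mimics
# predictions["intrs"] with shape (B, T, 3, 3); all numbers are illustrative.
import numpy as np
import torch

intrs = torch.eye(3).repeat(1, 6, 1, 1)               # (B=1, T=6, 3, 3)
intrs[..., 0, 0] = 500.0                              # predicted fx per frame
fx_prev = np.array([600.0, 610.0], dtype=np.float32)  # fx of the overlapping frames
W_proc, W = 518, 1036                                 # processed vs. original width

scale_x = torch.from_numpy(fx_prev).to(intrs.device) / intrs[0, :fx_prev.shape[0], 0, 0]
scale_x = scale_x.mean() * W_proc / W                 # one scalar for the whole window
intrs[:, :, 0, 0] *= scale_x                          # fx * mean(fx_prev / fx) * W_proc / W
print(intrs[0, 0, 0, 0])                              # tensor(302.5000) = 500 * 1.21 * 0.5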
@@ -145,7 +145,7 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
 
     return images
 
-def preprocess_image(img_tensor, mode="crop", target_size=518):
+def preprocess_image(img_tensor, mode="crop", target_size=518, keep_ratio=False):
     """
     Preprocess image tensor(s) to target size with crop or pad mode.
     Args:
@@ -190,9 +190,10 @@ def preprocess_image(img_tensor, mode="crop", target_size=518):
             new_W = target_size
             new_H = round(H * (new_W / W) / 14) * 14
             out = torch.nn.functional.interpolate(img.unsqueeze(0), size=(new_H, new_W), mode="bicubic", align_corners=False).squeeze(0)
-            if new_H > target_size:
-                start_y = (new_H - target_size) // 2
-                out = out[:, start_y : start_y + target_size, :]
+            if keep_ratio==False:
+                if new_H > target_size:
+                    start_y = (new_H - target_size) // 2
+                    out = out[:, start_y : start_y + target_size, :]
             processed.append(out)
     result = torch.stack(processed)
     if squeeze:
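Effect of the new flag, assuming a single (C, H, W) input is accepted (the squeeze logic in the hunk suggests it is): with keep_ratio=True the center crop is skipped and the full aspect-preserving height, rounded to a multiple of 14, is returned. The shapes below follow from the arithmetic in the hunk:

import torch

img = torch.rand(3, 1920, 1080)  # portrait input, H > W

out_crop = preprocess_image(img, mode="crop", target_size=518)
# new_H = round(1920 * (518 / 1080) / 14) * 14 = 924, center-cropped to (3, 518, 518)

out_full = preprocess_image(img, mode="crop", target_size=518, keep_ratio=True)
# crop skipped, result stays (3, 924, 518)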
@@ -46,7 +46,6 @@ def affine_invariant_global_loss(
     scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)
-
     pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]
 
     # Compute loss
     weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
     weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True)) # In case your data contains extremely small depth values
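The context lines show the weighting trick the trailing comment refers to: per-pixel weights are inverse ground-truth depth, clamped at 10x their masked mean so a handful of near-zero depths cannot dominate the loss. A simplified standalone sketch; weighted_mean here is a stand-in for the repository's helper, and the 1-D shapes are illustrative only:

import torch

def weighted_mean(x, mask, dim, keepdim=False):
    # Simplified stand-in: mean of x over the masked entries.
    m = mask.float()
    return (x * m).sum(dim, keepdim=keepdim) / m.sum(dim, keepdim=keepdim).clamp_min(1)

gt_depth = torch.ones(1, 100)
gt_depth[0, 0] = 1e-7  # one degenerate, near-zero depth
mask = torch.ones_like(gt_depth, dtype=torch.bool)

weight = mask.float() / gt_depth.clamp_min(1e-5)  # [1e5, 1, 1, ...]
cap = 10.0 * weighted_mean(weight, mask, dim=-1, keepdim=True)  # ~1.0e4
weight = weight.clamp_max(cap)  # outlier reduced from 1e5 to ~1e4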