eval_pose_fist

This commit is contained in:
xiaoyuxi 2025-07-13 15:07:14 +08:00
parent a625632571
commit 3b9e9a1e8e
5 changed files with 26 additions and 7 deletions

.gitignore

@@ -49,4 +49,8 @@ models/**/build
models/**/dist
temp_local
examples/results
examples/results
dyn_check
evaluation/saved

@@ -39,7 +39,7 @@ model:
window_len: 60
stablizer: True
mode: "online"
s_wind: 200
s_wind: 500
overlap: 4
track_num: 0
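For context, s_wind appears to set the sliding-window length (in frames) used by the online mode, with overlap frames shared between consecutive windows; this commit raises it from 200 to 500. A purely illustrative sketch of how such overlapping windows could be laid out (a hypothetical helper, not code from this repo):

# Illustrative only: split a sequence into overlapping windows of s_wind frames,
# sharing `overlap` frames between consecutive windows (hypothetical helper).
def split_windows(num_frames: int, s_wind: int = 500, overlap: int = 4):
    step = max(s_wind - overlap, 1)
    return [(s, min(s + s_wind, num_frames)) for s in range(0, num_frames, step)]

print(split_windows(1200))  # [(0, 500), (496, 996), (992, 1200)]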

@@ -31,6 +31,8 @@ class VGGT4Track(nn.Module, PyTorchModelHubMixin):
        self,
        images: torch.Tensor,
        annots = {},
        fx_prev = None,
        fy_prev = None,
        **kwargs):
        """
        Forward pass of the VGGT4Track model.
@@ -85,16 +87,29 @@ class VGGT4Track(nn.Module, PyTorchModelHubMixin):
        predictions["unc_metric"] = depth_conf.view(B*T, H_proc, W_proc)
        predictions["images"] = (images)*255.0
        # output the camera pose
        predictions["poses_pred"] = torch.eye(4)[None].repeat(T, 1, 1)[None]
        predictions["poses_pred"][:,:,:3,:4], predictions["intrs"] = pose_encoding_to_extri_intri(predictions["pose_enc_list"][-1],
                                                                                                  images_proc.shape[-2:])
        predictions["poses_pred"] = torch.inverse(predictions["poses_pred"])
        if fx_prev is not None:
            scale_x = torch.from_numpy(fx_prev).to(predictions["intrs"].device) / predictions["intrs"][0, :fx_prev.shape[0], 0, 0]
            scale_x = scale_x.mean() * W_proc / W
            predictions["intrs"][:, :, 0, 0] *= scale_x
        if fy_prev is not None:
            scale_y = torch.from_numpy(fy_prev).to(predictions["intrs"].device) / predictions["intrs"][0, :fy_prev.shape[0], 1, 1]
            scale_y = scale_y.mean() * H_proc / H
            predictions["intrs"][:, :, 1, 1] *= scale_y
        # get the points map
        points_map = depth_to_points_colmap(depth.view(B*T, H_proc, W_proc), predictions["intrs"].view(B*T, 3, 3))
        predictions["points_map"] = points_map
        #NOTE: resize back
        predictions["points_map"] = F.interpolate(points_map.permute(0,3,1,2),
                                                  size=(H, W), mode='bilinear', align_corners=True).permute(0,2,3,1)
        predictions["unc_metric"] = F.interpolate(predictions["unc_metric"][:,None],
                                                  size=(H, W), mode='bilinear', align_corners=True)[:,0]
        predictions["intrs"][..., :1, :] *= W/W_proc


@@ -145,7 +145,7 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
    return images
def preprocess_image(img_tensor, mode="crop", target_size=518):
def preprocess_image(img_tensor, mode="crop", target_size=518, keep_ratio=False):
    """
    Preprocess image tensor(s) to target size with crop or pad mode.
    Args:
@@ -190,9 +190,10 @@ def preprocess_image(img_tensor, mode="crop", target_size=518):
        new_W = target_size
        new_H = round(H * (new_W / W) / 14) * 14
        out = torch.nn.functional.interpolate(img.unsqueeze(0), size=(new_H, new_W), mode="bicubic", align_corners=False).squeeze(0)
        if new_H > target_size:
            start_y = (new_H - target_size) // 2
            out = out[:, start_y : start_y + target_size, :]
        if keep_ratio==False:
            if new_H > target_size:
                start_y = (new_H - target_size) // 2
                out = out[:, start_y : start_y + target_size, :]
        processed.append(out)
    result = torch.stack(processed)
    if squeeze:
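The effect of the new keep_ratio flag is that the bicubic resize to width target_size (with height rounded to a multiple of 14) still happens, but the centre crop back to target_size along the height is skipped. A self-contained sketch of just this branch, assuming a CHW float tensor (hypothetical helper name, not the repo's full preprocess_image):

import torch
import torch.nn.functional as F

# Sketch of the crop-mode path with the new keep_ratio switch (mirrors the diff above).
def resize_for_crop(img: torch.Tensor, target_size: int = 518, keep_ratio: bool = False) -> torch.Tensor:
    C, H, W = img.shape
    new_W = target_size
    new_H = round(H * (new_W / W) / 14) * 14   # keep height a multiple of 14
    out = F.interpolate(img.unsqueeze(0), size=(new_H, new_W),
                        mode="bicubic", align_corners=False).squeeze(0)
    if not keep_ratio and new_H > target_size:
        start_y = (new_H - target_size) // 2   # centre crop along height
        out = out[:, start_y:start_y + target_size, :]
    return out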


@@ -46,7 +46,6 @@ def affine_invariant_global_loss(
    scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)
    pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]
    # Compute loss
    weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
    weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True))  # In case your data contains extremely small depth values
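The clamp above caps per-pixel weights at ten times their mask-weighted mean so that extremely small ground-truth depths do not dominate the loss. weighted_mean itself is not shown in this hunk; a plausible stand-alone version consistent with how it is called here (an assumption, not the repo's definition) would be:

import torch

# Assumed behaviour of weighted_mean as called above: a mask-weighted mean over
# the given dimensions (here the two spatial ones), optionally keeping dims.
def weighted_mean(x: torch.Tensor, mask: torch.Tensor, dim, keepdim=False) -> torch.Tensor:
    m = mask.float()
    return (x * m).sum(dim=dim, keepdim=keepdim) / m.sum(dim=dim, keepdim=keepdim).clamp_min(1e-8)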