import argparse
import glob
import os

import cv2
import numpy as np
import torch
import torchvision.transforms as T
import tqdm
from rich import print

from models.moge.train.losses import affine_invariant_global_loss
from models.SpaTrackV2.models.predictor import Predictor
from models.SpaTrackV2.models.utils import get_points_on_a_grid
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.utils.visualizer import Visualizer
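
# Run the SpatialTrackerV2 pipeline over the dyn_check dataset: a VGGT4Track
# front-end predicts camera poses, intrinsics, and point/depth maps in sliding
# windows, the windows are stitched into one globally scaled sequence, and the
# SpaTrackV2 predictor then tracks 3D points, saving results in the tapip3d
# npz format.
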
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--track_mode", type=str, default="offline")
parser.add_argument("--data_type", type=str, default="RGBD")
parser.add_argument("--data_dir", type=str, default="dyn_check")
parser.add_argument("--video_name", type=str, default="snowboard")
parser.add_argument("--grid_size", type=int, default=2)
parser.add_argument("--vo_points", type=int, default=1024)
parser.add_argument("--fps", type=int, default=1)
    return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
    out_dir = os.path.join(args.data_dir, "saved")
    # frame-sampling stride from the CLI (overridden per sequence below)
    fps = int(args.fps)
    # optional per-video foreground mask
    mask_dir = os.path.join(args.data_dir, f"{args.video_name}.png")
# dyn check root dir
dyn_check_root = "/mnt/bn/xyxdata/data/4d_data/dyn_check"
dyn_check_list = os.listdir(dyn_check_root)
dyn_check_list = [os.path.join(dyn_check_root, i) for i in dyn_check_list if os.path.isdir(os.path.join(dyn_check_root, i))]
    # load the VGGT4Track front-end (predicts poses, intrinsics, and point maps)
    vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")
if args.track_mode == "offline":
model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
else:
model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
    # sliding-window settings for the VGGT4Track front-end
MAX_LEN = 2000
wind_size = 60
overlap = 16
non_overlap = wind_size - overlap
model.S_wind = 500
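    # Windows of wind_size frames advance by non_overlap = 44 frames, so
    # consecutive windows share overlap = 16 frames. E.g. 200 frames yield
    # ceil((200 - 60) / 44) + 1 = 5 windows, the last clipped to the sequence end.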
for dyn_check_dir in tqdm.tqdm(dyn_check_list):
img_files = sorted(glob.glob(os.path.join(dyn_check_dir, "dense", "images", "*.png")))
        # temporal subsampling: keep every `fps`-th frame; the commented
        # heuristic would instead cap the sequence length at MAX_LEN
        # fps = max(1, len(img_files) // MAX_LEN)
        fps = 2
        video_tensor = torch.stack(
            [torch.from_numpy(cv2.imread(i)).permute(2, 0, 1) for i in img_files]
        )[::fps].float()
raw_len = len(video_tensor)
wind_num = max((raw_len - wind_size) // non_overlap + 1, 1)
if (wind_num-1)*non_overlap + wind_size < raw_len:
wind_num += 1
# record the intermediate results
video_tensor = preprocess_image(video_tensor)
T_vid, _, H, W = video_tensor.shape
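        # full-sequence buffers filled window by window: 3D point maps, 4x4
        # world-to-camera extrinsics, 3x3 intrinsics, and per-pixel
        # depth-confidence masks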
points_map_list = np.zeros((T_vid, H, W, 3))
extrs_list = np.zeros((T_vid, 4, 4))
intrs_list = np.zeros((T_vid, 3, 3))
unc_metric_list = np.zeros((T_vid, H, W))
for i in tqdm.tqdm(range(wind_num)):
start_idx = i * non_overlap
end_idx = start_idx + wind_size
video_tensor_i = video_tensor[start_idx:end_idx]
# run the model
video_tensor_i = video_tensor_i.float().clone()
# process the image tensor
video_tensor_i = video_tensor_i[None]
with torch.no_grad():
                with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
# Predict attributes including cameras, depth maps, and point maps.
if i > 0:
overlap_len_i = min(overlap, depth_tensor.shape[0])
prev_intrs = intrs_list[start_idx:start_idx+overlap_len_i]
fx_prev, fy_prev = prev_intrs[:,0,0], prev_intrs[:,1,1]
else:
fx_prev, fy_prev = None, None
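                    # carry the previous window's focal lengths into the new
                    # window (fx_prev/fy_prev appear to condition the front-end
                    # so intrinsics stay consistent across window boundaries)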
predictions = vggt4track_model(video_tensor_i.cuda()/255, fx_prev=fx_prev, fy_prev=fy_prev)
extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
depth_tensor = depth_map.squeeze().cpu().numpy()
            extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
            extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
intrs = intrinsic.squeeze().cpu().numpy()
video_tensor = video_tensor.squeeze()
            # keep only confident depth: unc_metric > 0.5 marks reliable pixels
            unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
# update the list
if i == 0:
points_map_list[start_idx:end_idx] = predictions["points_map"].squeeze().cpu().numpy()
extrs_list[start_idx:end_idx] = extrs
intrs_list[start_idx:end_idx] = intrs
unc_metric_list[start_idx:end_idx] = unc_metric
else:
# merge the list
next_clip_points = predictions["points_map"][:overlap_len_i].cuda().float().clone()
prev_clip_points = torch.from_numpy(points_map_list[start_idx:start_idx+overlap_len_i]).cuda().float()
mask_i = torch.from_numpy(unc_metric_list[start_idx:start_idx+overlap_len_i]).cuda().bool()
loss_i, _, scale_x = affine_invariant_global_loss(next_clip_points,
prev_clip_points,
mask_i,
align_resolution=32)
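                # affine_invariant_global_loss also returns the scale (scale_x)
                # that best aligns the new window's point map with the already
                # stitched frames over the shared overlap; its mean rescales
                # the entire window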
# update the list
scale_mean = scale_x.mean()
current_points = predictions["points_map"].clone()
current_points = current_points * scale_mean
#NOTE: chain the results
points_map_list[start_idx:end_idx] = current_points.squeeze().cpu().numpy()
intrs_list[start_idx:end_idx] = predictions["intrs"].squeeze().cpu().numpy()
unc_metric_list[start_idx:end_idx] = unc_metric
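                # chain camera poses across windows: left-multiply by the
                # stored pose of the first shared frame so the new window's
                # extrinsics land in the global frame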
prev_extr = extrs_list[start_idx:start_idx+overlap_len_i][:1]
extrs_list[start_idx:end_idx] = prev_extr@extrs
# get the final results
intrs = intrs_list.astype(np.float32)
extrs = extrs_list.astype(np.float32)
unc_metric = np.bool_(unc_metric_list)
depth_tensor = points_map_list[..., 2]
data_npz_load = {}
        if os.path.exists(mask_dir):
            mask = cv2.imread(mask_dir)
            mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
            mask = mask.sum(axis=-1) > 0
        else:
            mask = np.ones_like(video_tensor[0, 0].numpy()) > 0
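        # the mask restricts query points to the foreground; without a mask
        # file every pixel is a valid query location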
        viz = True
        os.makedirs(out_dir, exist_ok=True)
        # configure the tracker: track_num is the number of support points used
        # for visual odometry (--vo_points)
        model.spatrack.track_num = args.vo_points
model.eval()
model.to("cuda")
viser = Visualizer(save_dir=out_dir, grayscale=True,
fps=10, pad_value=0, tracks_leave_trace=5)
grid_size = args.grid_size
        # frame size for the query grid
        frame_H, frame_W = video_tensor.shape[2:]
grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
# Sample mask values at grid points and filter out points where mask=0
if os.path.exists(mask_dir):
grid_pts_int = grid_pts[0].long()
mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
grid_pts = grid_pts[:, mask_values]
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
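        # queries are (t, x, y) triplets; t = 0 roots every track in the first frame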
# Run model inference
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
(
c2w_traj, intrs, point_map, conf_depth,
track3d_pred, track2d_pred, vis_pred, conf_pred, video
) = model.forward(video_tensor, depth=depth_tensor,
intrs=intrs, extrs=extrs,
queries=query_xyt,
fps=1, full_point=False, iters_track=6,
query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
support_frame=len(video_tensor)-1, replace_ratio=0.2)
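        # outputs: per-frame camera-to-world trajectory, refined intrinsics,
        # point and depth-confidence maps, plus 3D/2D tracks with per-point
        # visibility and confidence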
        # downscale results to limit I/O: cap the longer side at max_size
        max_size = 336
        h, w = video.shape[2:]
        scale = min(max_size / h, max_size / w)
if scale < 1:
new_h, new_w = int(h * scale), int(w * scale)
video = T.Resize((new_h, new_w))(video)
video_tensor = T.Resize((new_h, new_w))(video_tensor)
point_map = T.Resize((new_h, new_w))(point_map)
conf_depth = T.Resize((new_h, new_w))(conf_depth)
track2d_pred[...,:2] = track2d_pred[...,:2] * scale
intrs[:,:2,:] = intrs[:,:2,:] * scale
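            # fx, fy, cx, cy all live in the first two rows of K, so scaling
            # those rows keeps the intrinsics consistent with the resize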
if depth_tensor is not None:
if isinstance(depth_tensor, torch.Tensor):
depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
else:
depth_tensor = T.Resize((new_h, new_w))(torch.from_numpy(depth_tensor))
        if viz:
            viser.visualize(video=video[None],
                            tracks=track2d_pred[None][..., :2],
                            visibility=vis_pred[None],
                            filename="test")
        # save in the tapip3d format: rotate camera-frame tracks into world
        # space with c2w_traj and add the camera position
data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
data_npz_load["intrinsics"] = intrs.cpu().numpy()
        depth_save = point_map[:, 2, ...]
        depth_save[conf_depth < 0.5] = 0  # zero out low-confidence depth
data_npz_load["depths"] = depth_save.cpu().numpy()
data_npz_load["video"] = (video_tensor).cpu().numpy()/255
data_npz_load["visibs"] = vis_pred.cpu().numpy()
data_npz_load["unc_metric"] = conf_depth.cpu().numpy()
data_npz_load["fps"] = fps
        out_name = f"{os.path.basename(dyn_check_dir)}.npz"
        np.savez(os.path.join(out_dir, out_name), **data_npz_load)
        print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/{out_name}[/bold yellow]")