From 06e36ca0b179fde356f80c1d84863f4bfc458593 Mon Sep 17 00:00:00 2001
From: xiaoyuxi
Date: Sun, 13 Jul 2025 15:11:26 +0800
Subject: [PATCH] online_prev

---
 evaluation/eval_dyn.py | 237 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 237 insertions(+)
 create mode 100644 evaluation/eval_dyn.py

diff --git a/evaluation/eval_dyn.py b/evaluation/eval_dyn.py
new file mode 100644
index 0000000..aeafac1
--- /dev/null
+++ b/evaluation/eval_dyn.py
@@ -0,0 +1,237 @@
import pycolmap
from models.SpaTrackV2.models.predictor import Predictor
import yaml
import easydict
import os
import numpy as np
import cv2
import torch
import torchvision.transforms as T
from PIL import Image
import io
import moviepy.editor as mp
from models.SpaTrackV2.utils.visualizer import Visualizer
import tqdm
from models.SpaTrackV2.models.utils import get_points_on_a_grid
import glob
from rich import print
import argparse
import decord
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
from models.moge.train.losses import (
    affine_invariant_global_loss,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--track_mode", type=str, default="offline")
    parser.add_argument("--data_type", type=str, default="RGBD")
    parser.add_argument("--data_dir", type=str, default="dyn_check")
    parser.add_argument("--video_name", type=str, default="snowboard")
    parser.add_argument("--grid_size", type=int, default=2)
    parser.add_argument("--vo_points", type=int, default=1024)
    parser.add_argument("--fps", type=int, default=1)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    out_dir = args.data_dir + "/saved"
    # fps
    fps = int(args.fps)
    mask_dir = args.data_dir + f"/{args.video_name}.png"
    # dyn check root dir
    dyn_check_root = "/mnt/bn/xyxdata/data/4d_data/dyn_check"
    dyn_check_list = os.listdir(dyn_check_root)
    dyn_check_list = [os.path.join(dyn_check_root, i) for i in dyn_check_list if os.path.isdir(os.path.join(dyn_check_root, i))]
    # load the front-end VGGT model
    vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
    vggt4track_model.eval()
    vggt4track_model = vggt4track_model.to("cuda")
    if args.track_mode == "offline":
        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
    else:
        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
    # sliding-window configuration
    MAX_LEN = 2000
    wind_size = 60
    overlap = 16
    non_overlap = wind_size - overlap
    model.S_wind = 500
    for dyn_check_dir in tqdm.tqdm(dyn_check_list):

        img_files = sorted(glob.glob(os.path.join(dyn_check_dir, "dense", "images", "*.png")))
        # fps = max(1, len(img_files) // MAX_LEN)
        fps = 2
        video_tensor = torch.stack([torch.from_numpy(cv2.imread(i)).permute(2, 0, 1) for i in img_files])[::fps].float()
        raw_len = len(video_tensor)
        wind_num = max((raw_len - wind_size) // non_overlap + 1, 1)

        if (wind_num - 1) * non_overlap + wind_size < raw_len:
            wind_num += 1

        # record the intermediate results
        video_tensor = preprocess_image(video_tensor)
        T_vid, _, H, W = video_tensor.shape
        points_map_list = np.zeros((T_vid, H, W, 3))
        extrs_list = np.zeros((T_vid, 4, 4))
        intrs_list = np.zeros((T_vid, 3, 3))
        unc_metric_list = np.zeros((T_vid, H, W))

        for i in tqdm.tqdm(range(wind_num)):
            start_idx = i * non_overlap
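            # Illustrative note on the sliding-window bookkeeping (derived from the defaults above):
            # with wind_size=60 and overlap=16, non_overlap=44, so window i covers frames
            # [44*i, 44*i + 60). Consecutive windows share 16 frames, which are used further
            # down to align the new window's point maps and camera poses to the frames that
            # have already been processed.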
            end_idx = start_idx + wind_size
            video_tensor_i = video_tensor[start_idx:end_idx]
            # run the model
            video_tensor_i = video_tensor_i.float().clone()
            # process the image tensor
            video_tensor_i = video_tensor_i[None]

            with torch.no_grad():
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    # Predict attributes including cameras, depth maps, and point maps.
                    if i > 0:
                        overlap_len_i = min(overlap, depth_tensor.shape[0])
                        prev_intrs = intrs_list[start_idx:start_idx+overlap_len_i]
                        fx_prev, fy_prev = prev_intrs[:, 0, 0], prev_intrs[:, 1, 1]
                    else:
                        fx_prev, fy_prev = None, None
                    predictions = vggt4track_model(video_tensor_i.cuda()/255, fx_prev=fx_prev, fy_prev=fy_prev)
                    extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
                    depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]

            depth_tensor = depth_map.squeeze().cpu().numpy()
            extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
            extrs = extrinsic.squeeze().cpu().numpy()
            intrs = intrinsic.squeeze().cpu().numpy()
            video_tensor = video_tensor.squeeze()
            #NOTE: 20% of the depth is not reliable
            # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
            unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
            # update the list
            if i == 0:
                points_map_list[start_idx:end_idx] = predictions["points_map"].squeeze().cpu().numpy()
                extrs_list[start_idx:end_idx] = extrs
                intrs_list[start_idx:end_idx] = intrs
                unc_metric_list[start_idx:end_idx] = unc_metric
            else:
                # merge the list
                next_clip_points = predictions["points_map"][:overlap_len_i].cuda().float().clone()
                prev_clip_points = torch.from_numpy(points_map_list[start_idx:start_idx+overlap_len_i]).cuda().float()
                mask_i = torch.from_numpy(unc_metric_list[start_idx:start_idx+overlap_len_i]).cuda().bool()
                loss_i, _, scale_x = affine_invariant_global_loss(next_clip_points,
                                                                  prev_clip_points,
                                                                  mask_i,
                                                                  align_resolution=32)
                # update the list
                scale_mean = scale_x.mean()
                current_points = predictions["points_map"].clone()
                current_points = current_points * scale_mean
                #NOTE: chain the results
                points_map_list[start_idx:end_idx] = current_points.squeeze().cpu().numpy()
                intrs_list[start_idx:end_idx] = predictions["intrs"].squeeze().cpu().numpy()
                unc_metric_list[start_idx:end_idx] = unc_metric
                prev_extr = extrs_list[start_idx:start_idx+overlap_len_i][:1]
                extrs_list[start_idx:end_idx] = prev_extr @ extrs

        # get the final results
        intrs = intrs_list.astype(np.float32)
        extrs = extrs_list.astype(np.float32)
        unc_metric = np.bool_(unc_metric_list)
        depth_tensor = points_map_list[..., 2]

        data_npz_load = {}

        if os.path.exists(mask_dir):
            mask_files = mask_dir
            mask = cv2.imread(mask_files)
            mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
            mask = mask.sum(axis=-1) > 0
        else:
            mask = np.ones_like(video_tensor[0, 0].numpy()) > 0

        # get all data pieces
        viz = True
        os.makedirs(out_dir, exist_ok=True)

        # config the model; the track_num is the number of points in the grid
        model.spatrack.track_num = args.vo_points

        model.eval()
        model.to("cuda")
        viser = Visualizer(save_dir=out_dir, grayscale=True,
                           fps=10, pad_value=0, tracks_leave_trace=5)

        grid_size = args.grid_size

        # get frame H W (video_tensor is always built above, so read the size from it directly)
        frame_H, frame_W = video_tensor.shape[2:]
        grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
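        # Descriptive note (no behaviour change): grid_pts appears to hold (x, y) pixel
        # coordinates, so the mask lookup below indexes the mask as [y, x]; query_xyt then
        # prepends a frame index of 0 to every point, i.e. all tracks are queried from frame 0.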
        # Sample mask values at grid points and filter out points where mask=0
        if os.path.exists(mask_dir):
            grid_pts_int = grid_pts[0].long()
            mask_values = mask[grid_pts_int[..., 1], grid_pts_int[..., 0]]
            grid_pts = grid_pts[:, mask_values]

        query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()

        # Run model inference
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            (
                c2w_traj, intrs, point_map, conf_depth,
                track3d_pred, track2d_pred, vis_pred, conf_pred, video
            ) = model.forward(video_tensor, depth=depth_tensor,
                              intrs=intrs, extrs=extrs,
                              queries=query_xyt,
                              fps=1, full_point=False, iters_track=6,
                              query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
                              support_frame=len(video_tensor)-1, replace_ratio=0.2)

        # resize the results to avoid too large an I/O burden
        # depth and image, the maximum side is 336
        max_size = 336
        h, w = video.shape[2:]
        scale = min(max_size / h, max_size / w)
        if scale < 1:
            new_h, new_w = int(h * scale), int(w * scale)
            video = T.Resize((new_h, new_w))(video)
            video_tensor = T.Resize((new_h, new_w))(video_tensor)
            point_map = T.Resize((new_h, new_w))(point_map)
            conf_depth = T.Resize((new_h, new_w))(conf_depth)
            track2d_pred[..., :2] = track2d_pred[..., :2] * scale
            intrs[:, :2, :] = intrs[:, :2, :] * scale
            if depth_tensor is not None:
                if isinstance(depth_tensor, torch.Tensor):
                    depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
                else:
                    depth_tensor = T.Resize((new_h, new_w))(torch.from_numpy(depth_tensor))

        if viz:
            viser.visualize(video=video[None],
                            tracks=track2d_pred[None][..., :2],
                            visibility=vis_pred[None], filename="test")

        # save as the tapip3d format
        data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:, :3, :3], track3d_pred[:, :, :3].cpu()) + c2w_traj[:, :3, 3][:, None, :]).numpy()
        data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
        data_npz_load["intrinsics"] = intrs.cpu().numpy()
        depth_save = point_map[:, 2, ...]
        depth_save[conf_depth < 0.5] = 0
        data_npz_load["depths"] = depth_save.cpu().numpy()
        data_npz_load["video"] = (video_tensor).cpu().numpy()/255
        data_npz_load["visibs"] = vis_pred.cpu().numpy()
        data_npz_load["unc_metric"] = conf_depth.cpu().numpy()
        data_npz_load["fps"] = fps
        np.savez(os.path.join(out_dir, f'{dyn_check_dir.split("/")[-1]}.npz'), **data_npz_load)

        print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/{dyn_check_dir.split('/')[-1]}.npz[/bold yellow]")
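# Example invocation (illustrative; the defaults above assume the dyn_check layout hard-coded
# in dyn_check_root, so adjust --data_dir / dyn_check_root for a different machine):
#   python evaluation/eval_dyn.py --track_mode offline --grid_size 2 --vo_points 1024
#
# Each sequence writes <sequence_name>.npz under <data_dir>/saved with the keys populated above
# (coords, extrinsics, intrinsics, depths, video, visibs, unc_metric, fps), which is the file
# that tapip3d_viz.py consumes.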