From f427674f3f29a2a4f0ff79153c63e82cffe13235 Mon Sep 17 00:00:00 2001
From: xiaoyuxi
Date: Wed, 9 Jul 2025 18:31:17 +0800
Subject: [PATCH] fix_md

---
 .gitignore                                    |  3 +-
 app.py                                        | 60 +++++++-------
 inference.py                                  | 52 ++++++-------
 models/SpaTrackV2/models/SpaTrack.py          | 78 +------------------
 models/SpaTrackV2/models/predictor.py         | 70 +----------------
 .../models/tracker3D/TrackRefiner.py          | 58 ++------------
 .../models/vggt4track}/__init__.py            |  0
 .../models/vggt4track}/heads/camera_head.py   |  6 +-
 .../models/vggt4track}/heads/dpt_head.py      |  0
 .../models/vggt4track}/heads/head_act.py      |  0
 .../models/vggt4track}/heads/scale_head.py    |  6 +-
 .../models/vggt4track}/heads/track_head.py    |  0
 .../heads/track_modules/__init__.py           |  0
 .../track_modules/base_track_predictor.py     |  0
 .../vggt4track}/heads/track_modules/blocks.py |  0
 .../heads/track_modules/modules.py            |  0
 .../vggt4track}/heads/track_modules/utils.py  |  0
 .../models/vggt4track}/heads/utils.py         |  0
 .../models/vggt4track}/layers/__init__.py     |  0
 .../models/vggt4track}/layers/attention.py    |  0
 .../models/vggt4track}/layers/block.py        |  0
 .../models/vggt4track}/layers/drop_path.py    |  0
 .../models/vggt4track}/layers/layer_scale.py  |  0
 .../models/vggt4track}/layers/mlp.py          |  0
 .../models/vggt4track}/layers/patch_embed.py  |  0
 .../models/vggt4track}/layers/rope.py         |  0
 .../models/vggt4track}/layers/swiglu_ffn.py   |  0
 .../vggt4track}/layers/vision_transformer.py  |  0
 .../models/vggt4track}/models/aggregator.py   |  8 +-
 .../vggt4track}/models/aggregator_front.py    |  8 +-
 .../vggt4track}/models/tracker_front.py       | 10 +--
 .../models/vggt4track}/models/vggt.py         |  0
 .../models/vggt4track}/models/vggt_moe.py     | 14 ++--
 .../models/vggt4track}/utils/__init__.py      |  0
 .../models/vggt4track}/utils/geometry.py      |  0
 .../models/vggt4track}/utils/load_fn.py       |  0
 .../models/vggt4track}/utils/loss.py          |  2 +-
 .../models/vggt4track}/utils/pose_enc.py      |  0
 .../models/vggt4track}/utils/rotation.py      |  0
 .../models/vggt4track}/utils/visual_track.py  |  0
 models/vggt/setup.py                          |  8 --
 41 files changed, 93 insertions(+), 290 deletions(-)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/__init__.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/camera_head.py (96%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/dpt_head.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/head_act.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/scale_head.py (96%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/track_head.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/track_modules/__init__.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/track_modules/base_track_predictor.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/track_modules/blocks.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/track_modules/modules.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/track_modules/utils.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/heads/utils.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/__init__.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/attention.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/block.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/drop_path.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/layer_scale.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/mlp.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/patch_embed.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/rope.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/swiglu_ffn.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/layers/vision_transformer.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/models/aggregator.py (97%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/models/aggregator_front.py (97%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/models/tracker_front.py (95%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/models/vggt.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/models/vggt_moe.py (90%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/utils/__init__.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/utils/geometry.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/utils/load_fn.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/utils/loss.py (97%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/utils/pose_enc.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/utils/rotation.py (100%)
 rename models/{vggt/vggt => SpaTrackV2/models/vggt4track}/utils/visual_track.py (100%)
 delete mode 100644 models/vggt/setup.py

diff --git a/.gitignore b/.gitignore
index 038265d..12d8869 100755
--- a/.gitignore
+++ b/.gitignore
@@ -48,4 +48,5 @@
 config/fix_2d.yaml
 models/**/build
 models/**/dist
-temp_local
\ No newline at end of file
+temp_local
+examples/results
\ No newline at end of file
diff --git a/app.py b/app.py
index a14c49e..2255da9 100644
--- a/app.py
+++ b/app.py
@@ -26,6 +26,9 @@ import logging
 from concurrent.futures import ThreadPoolExecutor
 import atexit
 import uuid
+from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.predictor import Predictor
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -78,20 +81,15 @@ def create_user_temp_dir():
     return temp_dir
 
 from huggingface_hub import hf_hub_download
 
-# init the model
-os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
-if os.environ.get("VGGT_DIR", None) is not None:
-    from models.vggt.vggt.models.vggt_moe import VGGT4Track
-    from models.vggt.vggt.utils.load_fn import preprocess_image
-    vggt_model = VGGT4Track()
-    vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
-    vggt_model.eval()
-    vggt_model = vggt_model.to("cuda")
+vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+vggt4track_model.eval()
+vggt4track_model = vggt4track_model.to("cuda")
 
 # Global model initialization
 print("🚀 Initializing local models...")
-tracker_model, _ = get_tracker_predictor(".", vo_points=756)
+tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
+tracker_model.eval()
 predictor = get_sam_predictor()
 print("✅ Models loaded successfully!")
 
@@ -129,7 +127,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
         print("Initializing tracker models inside GPU function...")
         out_dir = os.path.join(temp_dir, "results")
         os.makedirs(out_dir, exist_ok=True)
-        tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
+        tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
+                                                                     tracker_model=tracker_model.cuda())
 
     # Setup paths
     video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@@ -159,25 +158,23 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
     data_npz_load = {}
 
     # run vggt
-    if os.environ.get("VGGT_DIR", None) is not None:
-        # process the image tensor
-        video_tensor = preprocess_image(video_tensor)[None]
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                # Predict attributes including cameras, depth maps, and point maps.
-                predictions = vggt_model(video_tensor.cuda()/255)
-                extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
-                depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
-
-        depth_tensor = depth_map.squeeze().cpu().numpy()
-        extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
-        extrs = extrinsic.squeeze().cpu().numpy()
-        intrs = intrinsic.squeeze().cpu().numpy()
-        video_tensor = video_tensor.squeeze()
-        #NOTE: 20% of the depth is not reliable
-        # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
-        unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+    # process the image tensor
+    video_tensor = preprocess_image(video_tensor)[None]
+    with torch.no_grad():
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            # Predict attributes including cameras, depth maps, and point maps.
+            predictions = vggt4track_model(video_tensor.cuda()/255)
+            extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
+            depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
+    depth_tensor = depth_map.squeeze().cpu().numpy()
+    extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+    extrs = extrinsic.squeeze().cpu().numpy()
+    intrs = intrinsic.squeeze().cpu().numpy()
+    video_tensor = video_tensor.squeeze()
+    #NOTE: 20% of the depth is not reliable
+    # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
+    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
 
     # Load and process mask
     if os.path.exists(mask_path):
         mask = cv2.imread(mask_path)
@@ -199,7 +196,6 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
     query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
     print(f"Query points shape: {query_xyt.shape}")
-
     # Run model inference
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
         (
@@ -210,8 +206,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
             queries=query_xyt,
             fps=1, full_point=False, iters_track=4,
             query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
-            support_frame=len(video_tensor)-1, replace_ratio=0.2) 
-        
+            support_frame=len(video_tensor)-1, replace_ratio=0.2)
+
     # Resize results to avoid large I/O
     max_size = 224
     h, w = video.shape[2:]
diff --git a/inference.py b/inference.py
index a95c78c..de9d958 100644
--- a/inference.py
+++ b/inference.py
@@ -17,16 +17,14 @@ import glob
 from rich import print
 import argparse
 import decord
-from models.vggt.vggt.models.vggt_moe import VGGT4Track
-from models.vggt.vggt.utils.load_fn import preprocess_image
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
-from huggingface_hub import hf_hub_download
+from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--cfg_dir", type=str, default="config/magic_infer_offline.yaml")
+    parser.add_argument("--track_mode", type=str, default="offline")
     parser.add_argument("--data_type", type=str, default="RGBD")
-    parser.add_argument("--VGGT", action="store_true")
     parser.add_argument("--data_dir", type=str, default="assets/example0")
     parser.add_argument("--video_name", type=str, default="snowboard")
     parser.add_argument("--grid_size", type=int, default=10)
@@ -36,20 +34,14 @@
 
 if __name__ == "__main__":
     args = parse_args()
-    cfg_dir = args.cfg_dir
     out_dir = args.data_dir + "/results"
     # fps
     fps = int(args.fps)
     mask_dir = args.data_dir + f"/{args.video_name}.png"
 
-    os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts",
-                                             "spatrack_front.pth") #, force_download=True)
-    VGGT_DIR = os.environ["VGGT_DIR"]
-    assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist"
-    front_track = VGGT4Track()
-    front_track.load_state_dict(torch.load(VGGT_DIR), strict=False)
-    front_track.eval()
-    front_track = front_track.to("cuda")
+    vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+    vggt4track_model.eval()
+    vggt4track_model = vggt4track_model.to("cuda")
 
     if args.data_type == "RGBD":
         npz_dir = args.data_dir + f"/{args.video_name}.npz"
@@ -76,7 +68,7 @@
         with torch.no_grad():
             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                 # Predict attributes including cameras, depth maps, and point maps.
-                predictions = front_track(video_tensor.cuda()/255)
+                predictions = vggt4track_model(video_tensor.cuda()/255)
                 extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
                 depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
 
@@ -103,21 +95,23 @@
     viz = True
     os.makedirs(out_dir, exist_ok=True)
 
-    with open(cfg_dir, "r") as f:
-        cfg = yaml.load(f, Loader=yaml.FullLoader)
-    cfg = easydict.EasyDict(cfg)
-    cfg.out_dir = out_dir
-    cfg.model.track_num = args.vo_points
-    print(f"Downloading model from HuggingFace: {cfg.ckpts}")
-    checkpoint_path = hf_hub_download(
-        repo_id=cfg.ckpts,
-        repo_type="model",
-        filename="SpaTrack3_offline.pth"
-    )
-    model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
+    # with open(cfg_dir, "r") as f:
+    #     cfg = yaml.load(f, Loader=yaml.FullLoader)
+    # cfg = easydict.EasyDict(cfg)
+    # cfg.out_dir = out_dir
+    # cfg.model.track_num = args.vo_points
+    # print(f"Downloading model from HuggingFace: {cfg.ckpts}")
+    if args.track_mode == "offline":
+        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
+    else:
+        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
+
+    # config the model; the track_num is the number of points in the grid
+    model.spatrack.track_num = args.vo_points
+
     model.eval()
     model.to("cuda")
-    viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
+    viser = Visualizer(save_dir=out_dir, grayscale=True,
                        fps=10, pad_value=0, tracks_leave_trace=5)
 
     grid_size = args.grid_size
diff --git a/models/SpaTrackV2/models/SpaTrack.py b/models/SpaTrackV2/models/SpaTrack.py
index c369733..c5f43c0 100644
--- a/models/SpaTrackV2/models/SpaTrack.py
+++ b/models/SpaTrackV2/models/SpaTrack.py
@@ -66,15 +66,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
 
         # Tracker model
         self.Track3D = TrackRefiner3D(Track_cfg)
-        track_base_ckpt_dir = Track_cfg.base_ckpt
+        track_base_ckpt_dir = Track_cfg["base_ckpt"]
         if os.path.exists(track_base_ckpt_dir):
             track_pretrain = torch.load(track_base_ckpt_dir)
             self.Track3D.load_state_dict(track_pretrain, strict=False)
 
         # wrap the function of make lora trainable
         self.make_paras_trainable = partial(self.make_paras_trainable,
-                                            mode=ft_cfg.mode,
-                                            paras_name=ft_cfg.paras_name)
+                                            mode=ft_cfg["mode"],
+                                            paras_name=ft_cfg["paras_name"])
         self.track_num = track_num
 
     def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
@@ -149,7 +149,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
     ):
         # step 1 allocate the query points on the grid
         T, C, H, W = video.shape
-        
+
         if annots_train is not None:
             vis_gt = annots_train["vis"]
             _, _, N = vis_gt.shape
@@ -300,39 +300,6 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
                                **kwargs, annots=annots)
                 if self.training:
                     loss += out["loss"].squeeze()
-                # from models.SpaTrackV2.utils.visualizer import Visualizer
-                # vis_track = Visualizer(grayscale=False,
-                #                        fps=10, pad_value=50, tracks_leave_trace=0)
-                # vis_track.visualize(video=segment,
-                #                     tracks=out["traj_est"][...,:2],
-                #                     visibility=out["vis_est"],
-                #                     save_video=True)
-                # # visualize 4d
-                # import os, json
-                # import os.path as osp
-                # viser4d_dir = os.path.join("viser_4d_results")
-                # os.makedirs(viser4d_dir, exist_ok=True)
-                # depth_est = annots["depth_gt"][0]
-                # unc_metric = out["unc_metric"]
-                # mask = (unc_metric > 0.5).squeeze(1)
-                # # pose_est = out["poses_pred"].squeeze(0)
-                # pose_est = annots["traj_mat"][0]
-                # rgb_tracks = out["rgb_tracks"].squeeze(0)
-                # intrinsics = out["intrs"].squeeze(0)
-                # for i_k in range(out["depth"].shape[0]):
-                #     img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
-                #     img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
-                #     cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
-                #     if stage == 1:
-                #         depth = depth_est[i_k].squeeze().cpu().numpy()
-                #         np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
-                #     else:
-                #         point_map_vis = out["points_map"][i_k].cpu().numpy()
-                #         np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
-                # np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
-                # np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
-                # np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
-                # np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
 
                 queries_len = len(queries_new)
                 # update the track3d and track2d
@@ -724,40 +691,3 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         }
 
         return ret
-
-
-
-
-# three stages of training
-
-# stage 1:
-# gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
-# Tracking and Pose as well -> based on gt depth and intrinsics
-# (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
-
-# stage 2: fixed 3D tracking
-# Joint depth refiner
-# input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
-# estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
-# ongoing two days
-
-# stage 3: train multi windows by propagation
-# 4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
-
-# types of scenarioes:
-# 1. auto driving (waymo open dataset)
-# 2. robot
-# 3. internet ego video
-
-
-
-# Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
-# Update Variables:
-# 1. 3D tracks B T N 3 xyz.
-# 2. 2D tracks B T N 2 x y.
-# 3. Dynamic Mask B T H W.
-# 4. Camera Pose B T 4 4.
-# 5. Video Depth.
-
-# (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
-# Campatiablity by product.
\ No newline at end of file
diff --git a/models/SpaTrackV2/models/predictor.py b/models/SpaTrackV2/models/predictor.py
index a6a2861..b331041 100644
--- a/models/SpaTrackV2/models/predictor.py
+++ b/models/SpaTrackV2/models/predictor.py
@@ -16,82 +16,21 @@ from typing import Union, Optional
 import cv2
 import os
 import decord
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
 
-class Predictor(torch.nn.Module):
+class Predictor(torch.nn.Module, PyTorchModelHubMixin):
     def __init__(self, args=None):
         super().__init__()
         self.args = args
         self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
-        self.S_wind = args.Track_cfg.s_wind
-        self.overlap = args.Track_cfg.overlap
+        self.S_wind = args["Track_cfg"]["s_wind"]
+        self.overlap = args["Track_cfg"]["overlap"]
 
     def to(self, device: Union[str, torch.device]):
         self.spatrack.to(device)
         if self.spatrack.base_model is not None:
             self.spatrack.base_model.to(device)
 
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, Path],
-        *,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        device: Optional[Union[str, torch.device]] = None,
-        model_cfg: Optional[dict] = None,
-        **kwargs,
-    ) -> "SpaTrack2":
-        """
-        Load a pretrained model from a local file or a remote repository.
-
-        Args:
-            pretrained_model_name_or_path (str or Path):
-                - Path to a local model file (e.g., `./model.pth`).
-                - HuggingFace Hub model ID (e.g., `username/model-name`).
-            force_download (bool, optional):
-                Whether to force re-download even if cached. Default: False.
-            cache_dir (str, optional):
-                Custom cache directory. Default: None (use default cache).
-            device (str or torch.device, optional):
-                Target device (e.g., "cuda", "cpu"). Default: None (keep original).
-            **kwargs:
-                Additional config overrides.
-
-        Returns:
-            SpaTrack2: Loaded pretrained model.
-        """
-        # (1) check the path is local or remote
-        if isinstance(pretrained_model_name_or_path, Path):
-            model_path = str(pretrained_model_name_or_path)
-        else:
-            model_path = pretrained_model_name_or_path
-        # (2) if the path is remote, download it
-        if not os.path.exists(model_path):
-            raise NotImplementedError("Remote download not implemented yet. Use a local path.")
-        # (3) load the model weights
-
-        state_dict = torch.load(model_path, map_location="cpu")
-        # (4) initialize the model (can load config.json if exists)
-        config_path = os.path.join(os.path.dirname(model_path), "config.json")
-        config = {}
-        if os.path.exists(config_path):
-            import json
-            with open(config_path, "r") as f:
-                config.update(json.load(f))
-        config.update(kwargs) # allow override the config
-        if model_cfg is not None:
-            config = model_cfg
-        model = cls(config)
-        if "model" in state_dict:
-            model.spatrack.load_state_dict(state_dict["model"], strict=False)
-        else:
-            model.spatrack.load_state_dict(state_dict, strict=False)
-        # (5) device management
-        if device is not None:
-            model.to(device)
-
-        return model
-
     def forward(self, video: str|torch.Tensor|np.ndarray,
                 depth: str|torch.Tensor|np.ndarray=None,
                 unc_metric: str|torch.Tensor|np.ndarray=None,
@@ -146,7 +85,6 @@
                 window_len=self.S_wind, overlap_len=self.overlap,
                 track2d_gt=track2d_gt, full_point=full_point,
                 iters_track=iters_track, fixed_cam=fixed_cam, query_no_BA=query_no_BA,
                 stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
-
         return ret
 
diff --git a/models/SpaTrackV2/models/tracker3D/TrackRefiner.py b/models/SpaTrackV2/models/tracker3D/TrackRefiner.py
index cb1d5f7..cb39db4 100644
--- a/models/SpaTrackV2/models/tracker3D/TrackRefiner.py
+++ b/models/SpaTrackV2/models/tracker3D/TrackRefiner.py
@@ -30,7 +30,7 @@ from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import
 
 class TrackRefiner3D(CoTrackerThreeOffline):
 
     def __init__(self, args=None):
-        super().__init__(**args.base)
+        super().__init__(**args["base"])
         """
         This is 3D warpper from cotracker, which load the cotracker pretrain and
@@ -46,15 +46,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
         # get the anchor point's embedding, and init the pts refiner
         update_pts = True
-        # self.corr_transformer = nn.ModuleList([
-        #     CorrPointformer(
-        #         dim=128,
-        #         num_heads=8,
-        #         head_dim=128 // 8,
-        #         mlp_ratio=4.0,
-        #     )
-        #     for _ in range(self.corr_levels)
-        # ])
+
         self.corr_transformer = nn.ModuleList([
             CorrPointformer(
                 dim=128,
@@ -67,29 +59,11 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         self.fnet = BasicEncoder(input_dim=3,
             output_dim=self.latent_dim, stride=self.stride)
         self.corr3d_radius = 3
-        
-        if args.stablizer:
-            self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
-            self.upsample_kernel_size = 5
-            self.residual_embedding = nn.Parameter(torch.randn(
-                self.latent_dim, self.model_resolution[0]//16,
-                self.model_resolution[1]//16, requires_grad=True))
-            self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
-            self.upsample_factor = 4
-            self.upsample_transformer = UpsampleTransformerAlibi(
-                kernel_size=self.upsample_kernel_size, # kernel_size=3, #
-                stride=self.stride,
-                latent_dim=self.latent_dim,
-                num_attn_blocks=2,
-                upsample_factor=4,
-            )
-        else:
-            self.update_pointmap = None
 
-        self.mode = args.mode
+        self.mode = args["mode"]
         if self.mode == "online":
-            self.s_wind = args.s_wind
-            self.overlap = args.overlap
+            self.s_wind = args["s_wind"]
+            self.overlap = args["overlap"]
 
     def upsample_with_mask(
         self, inp: torch.Tensor, mask: torch.Tensor
@@ -1061,29 +1035,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
                     vis_est = (vis_est>0.5).float()
                     sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
                 # coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
-                # if torch.isnan(coords_proj_curr).sum()>0:
-                #     import pdb; pdb.set_trace()
-                if False:
-                    point_map_resize = point_map.clone().view(B, T, 3, H, W)
-                    update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
-                    coords_append_resize = coords.clone().detach()
-                    coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
-                    update_track_input = self.norm_xyz(cam_pts_est)*5
-                    update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
-                    update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
-                    update = self.update_pointmap.stablizer(update_input,
-                                 update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
-                    #NOTE: update the point map
-                    point_map_resize += update
-                    point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
-                                    size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
-                    point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
-                    point_map_preds.append(self.denorm_xyz(point_map_refine_out))
-                    point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
-
-                # if torch.isnan(coords).sum()>0:
-                #     import pdb; pdb.set_trace()
+
                 #NOTE: the 2d tracking + unproject depth
                 fix_cam_est = coords_append.clone()
                 fix_cam_est[...,2] = depth_unproj
diff --git a/models/vggt/vggt/__init__.py b/models/SpaTrackV2/models/vggt4track/__init__.py
similarity index 100%
rename from models/vggt/vggt/__init__.py
rename to models/SpaTrackV2/models/vggt4track/__init__.py
diff --git a/models/vggt/vggt/heads/camera_head.py b/models/SpaTrackV2/models/vggt4track/heads/camera_head.py
similarity index 96%
rename from models/vggt/vggt/heads/camera_head.py
rename to models/SpaTrackV2/models/vggt4track/heads/camera_head.py
index ee95d35..d9a4891 100644
--- a/models/vggt/vggt/heads/camera_head.py
+++ b/models/SpaTrackV2/models/vggt4track/heads/camera_head.py
@@ -11,9 +11,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.vggt.vggt.layers import Mlp
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.heads.head_act import activate_pose
+from models.SpaTrackV2.models.vggt4track.layers import Mlp
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
 
 
 class CameraHead(nn.Module):
diff --git a/models/vggt/vggt/heads/dpt_head.py b/models/SpaTrackV2/models/vggt4track/heads/dpt_head.py
similarity index 100%
rename from models/vggt/vggt/heads/dpt_head.py
rename to models/SpaTrackV2/models/vggt4track/heads/dpt_head.py
diff --git a/models/vggt/vggt/heads/head_act.py b/models/SpaTrackV2/models/vggt4track/heads/head_act.py
similarity index 100%
rename from models/vggt/vggt/heads/head_act.py
rename to models/SpaTrackV2/models/vggt4track/heads/head_act.py
diff --git a/models/vggt/vggt/heads/scale_head.py b/models/SpaTrackV2/models/vggt4track/heads/scale_head.py
similarity index 96%
rename from models/vggt/vggt/heads/scale_head.py
rename to models/SpaTrackV2/models/vggt4track/heads/scale_head.py
index 6fdf09b..7e2a551 100644
--- a/models/vggt/vggt/heads/scale_head.py
+++ b/models/SpaTrackV2/models/vggt4track/heads/scale_head.py
@@ -11,9 +11,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.vggt.vggt.layers import Mlp
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.heads.head_act import activate_pose
+from models.SpaTrackV2.models.vggt4track.layers import Mlp
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
 
 
 class ScaleHead(nn.Module):
diff --git a/models/vggt/vggt/heads/track_head.py b/models/SpaTrackV2/models/vggt4track/heads/track_head.py
similarity index 100%
rename from models/vggt/vggt/heads/track_head.py
rename to models/SpaTrackV2/models/vggt4track/heads/track_head.py
diff --git a/models/vggt/vggt/heads/track_modules/__init__.py b/models/SpaTrackV2/models/vggt4track/heads/track_modules/__init__.py
similarity index 100%
rename from models/vggt/vggt/heads/track_modules/__init__.py
rename to models/SpaTrackV2/models/vggt4track/heads/track_modules/__init__.py
diff --git a/models/vggt/vggt/heads/track_modules/base_track_predictor.py b/models/SpaTrackV2/models/vggt4track/heads/track_modules/base_track_predictor.py
similarity index 100%
rename from models/vggt/vggt/heads/track_modules/base_track_predictor.py
rename to models/SpaTrackV2/models/vggt4track/heads/track_modules/base_track_predictor.py
diff --git a/models/vggt/vggt/heads/track_modules/blocks.py b/models/SpaTrackV2/models/vggt4track/heads/track_modules/blocks.py
similarity index 100%
rename from models/vggt/vggt/heads/track_modules/blocks.py
rename to models/SpaTrackV2/models/vggt4track/heads/track_modules/blocks.py
diff --git a/models/vggt/vggt/heads/track_modules/modules.py b/models/SpaTrackV2/models/vggt4track/heads/track_modules/modules.py
similarity index 100%
rename from models/vggt/vggt/heads/track_modules/modules.py
rename to models/SpaTrackV2/models/vggt4track/heads/track_modules/modules.py
diff --git a/models/vggt/vggt/heads/track_modules/utils.py b/models/SpaTrackV2/models/vggt4track/heads/track_modules/utils.py
similarity index 100%
rename from models/vggt/vggt/heads/track_modules/utils.py
rename to models/SpaTrackV2/models/vggt4track/heads/track_modules/utils.py
diff --git a/models/vggt/vggt/heads/utils.py b/models/SpaTrackV2/models/vggt4track/heads/utils.py
similarity index 100%
rename from models/vggt/vggt/heads/utils.py
rename to models/SpaTrackV2/models/vggt4track/heads/utils.py
diff --git a/models/vggt/vggt/layers/__init__.py b/models/SpaTrackV2/models/vggt4track/layers/__init__.py
similarity index 100%
rename from models/vggt/vggt/layers/__init__.py
rename to models/SpaTrackV2/models/vggt4track/layers/__init__.py
diff --git a/models/vggt/vggt/layers/attention.py b/models/SpaTrackV2/models/vggt4track/layers/attention.py
similarity index 100%
rename from models/vggt/vggt/layers/attention.py
rename to models/SpaTrackV2/models/vggt4track/layers/attention.py
diff --git a/models/vggt/vggt/layers/block.py b/models/SpaTrackV2/models/vggt4track/layers/block.py
similarity index 100%
rename from models/vggt/vggt/layers/block.py
rename to models/SpaTrackV2/models/vggt4track/layers/block.py
diff --git a/models/vggt/vggt/layers/drop_path.py b/models/SpaTrackV2/models/vggt4track/layers/drop_path.py
similarity index 100%
rename from models/vggt/vggt/layers/drop_path.py
rename to models/SpaTrackV2/models/vggt4track/layers/drop_path.py
diff --git a/models/vggt/vggt/layers/layer_scale.py b/models/SpaTrackV2/models/vggt4track/layers/layer_scale.py
similarity index 100%
rename from models/vggt/vggt/layers/layer_scale.py
rename to models/SpaTrackV2/models/vggt4track/layers/layer_scale.py
diff --git a/models/vggt/vggt/layers/mlp.py b/models/SpaTrackV2/models/vggt4track/layers/mlp.py
similarity index 100%
rename from models/vggt/vggt/layers/mlp.py
rename to models/SpaTrackV2/models/vggt4track/layers/mlp.py
diff --git a/models/vggt/vggt/layers/patch_embed.py b/models/SpaTrackV2/models/vggt4track/layers/patch_embed.py
similarity index 100%
rename from models/vggt/vggt/layers/patch_embed.py
rename to models/SpaTrackV2/models/vggt4track/layers/patch_embed.py
diff --git a/models/vggt/vggt/layers/rope.py b/models/SpaTrackV2/models/vggt4track/layers/rope.py
similarity index 100%
rename from models/vggt/vggt/layers/rope.py
rename to models/SpaTrackV2/models/vggt4track/layers/rope.py
diff --git a/models/vggt/vggt/layers/swiglu_ffn.py b/models/SpaTrackV2/models/vggt4track/layers/swiglu_ffn.py
similarity index 100%
rename from models/vggt/vggt/layers/swiglu_ffn.py
rename to models/SpaTrackV2/models/vggt4track/layers/swiglu_ffn.py
diff --git a/models/vggt/vggt/layers/vision_transformer.py b/models/SpaTrackV2/models/vggt4track/layers/vision_transformer.py
similarity index 100%
rename from models/vggt/vggt/layers/vision_transformer.py
rename to models/SpaTrackV2/models/vggt4track/layers/vision_transformer.py
diff --git a/models/vggt/vggt/models/aggregator.py b/models/SpaTrackV2/models/vggt4track/models/aggregator.py
similarity index 97%
rename from models/vggt/vggt/models/aggregator.py
rename to models/SpaTrackV2/models/vggt4track/models/aggregator.py
index 3218af6..07fd59e 100644
--- a/models/vggt/vggt/models/aggregator.py
+++ b/models/SpaTrackV2/models/vggt4track/models/aggregator.py
@@ -10,10 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple, Union, List, Dict, Any
 
-from models.vggt.vggt.layers import PatchEmbed
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
-from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
 from torch.utils.checkpoint import checkpoint
 
 logger = logging.getLogger(__name__)
diff --git a/models/vggt/vggt/models/aggregator_front.py b/models/SpaTrackV2/models/vggt4track/models/aggregator_front.py
similarity index 97%
rename from models/vggt/vggt/models/aggregator_front.py
rename to models/SpaTrackV2/models/vggt4track/models/aggregator_front.py
index 6d7a243..db57bc6 100644
--- a/models/vggt/vggt/models/aggregator_front.py
+++ b/models/SpaTrackV2/models/vggt4track/models/aggregator_front.py
@@ -10,10 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple, Union, List, Dict, Any
 
-from models.vggt.vggt.layers import PatchEmbed
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
-from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
 from torch.utils.checkpoint import checkpoint
 
 logger = logging.getLogger(__name__)
diff --git a/models/vggt/vggt/models/tracker_front.py b/models/SpaTrackV2/models/vggt4track/models/tracker_front.py
similarity index 95%
rename from models/vggt/vggt/models/tracker_front.py
rename to models/SpaTrackV2/models/vggt4track/models/tracker_front.py
index 88a7582..fb9c670 100644
--- a/models/vggt/vggt/models/tracker_front.py
+++ b/models/SpaTrackV2/models/vggt4track/models/tracker_front.py
@@ -9,12 +9,12 @@ import torch.nn as nn
 from torch.utils.checkpoint import checkpoint
 from huggingface_hub import PyTorchModelHubMixin # used for model hub
 
-from models.vggt.vggt.models.aggregator_front import Aggregator
-from models.vggt.vggt.heads.camera_head import CameraHead
-from models.vggt.vggt.heads.scale_head import ScaleHead
+from models.SpaTrackV2.models.vggt4track.models.aggregator_front import Aggregator
+from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
+from models.SpaTrackV2.models.vggt4track.heads.scale_head import ScaleHead
 from einops import rearrange
-from models.vggt.vggt.utils.loss import compute_loss
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from models.SpaTrackV2.utils.loss import compute_loss
+from models.SpaTrackV2.utils.pose_enc import pose_encoding_to_extri_intri
 import torch.nn.functional as F
 
 class FrontTracker(nn.Module, PyTorchModelHubMixin):
diff --git a/models/vggt/vggt/models/vggt.py b/models/SpaTrackV2/models/vggt4track/models/vggt.py
similarity index 100%
rename from models/vggt/vggt/models/vggt.py
rename to models/SpaTrackV2/models/vggt4track/models/vggt.py
diff --git a/models/vggt/vggt/models/vggt_moe.py b/models/SpaTrackV2/models/vggt4track/models/vggt_moe.py
similarity index 90%
rename from models/vggt/vggt/models/vggt_moe.py
rename to models/SpaTrackV2/models/vggt4track/models/vggt_moe.py
index 3d2d2e9..dbdf7c5 100644
--- a/models/vggt/vggt/models/vggt_moe.py
+++ b/models/SpaTrackV2/models/vggt4track/models/vggt_moe.py
@@ -8,14 +8,14 @@ import torch
 import torch.nn as nn
 from huggingface_hub import PyTorchModelHubMixin # used for model hub
 
-from models.vggt.vggt.models.aggregator import Aggregator
-from models.vggt.vggt.heads.camera_head import CameraHead
-from models.vggt.vggt.heads.dpt_head import DPTHead
-from models.vggt.vggt.heads.track_head import TrackHead
-from models.vggt.vggt.utils.loss import compute_loss
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from models.SpaTrackV2.models.vggt4track.models.aggregator import Aggregator
+from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
+from models.SpaTrackV2.models.vggt4track.heads.dpt_head import DPTHead
+from models.SpaTrackV2.models.vggt4track.heads.track_head import TrackHead
+from models.SpaTrackV2.models.vggt4track.utils.loss import compute_loss
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
-from models.vggt.vggt.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
 from einops import rearrange
 import torch.nn.functional as F
 
diff --git a/models/vggt/vggt/utils/__init__.py b/models/SpaTrackV2/models/vggt4track/utils/__init__.py
similarity index 100%
rename from models/vggt/vggt/utils/__init__.py
rename to models/SpaTrackV2/models/vggt4track/utils/__init__.py
diff --git a/models/vggt/vggt/utils/geometry.py b/models/SpaTrackV2/models/vggt4track/utils/geometry.py
similarity index 100%
rename from models/vggt/vggt/utils/geometry.py
rename to models/SpaTrackV2/models/vggt4track/utils/geometry.py
diff --git a/models/vggt/vggt/utils/load_fn.py b/models/SpaTrackV2/models/vggt4track/utils/load_fn.py
similarity index 100%
rename from models/vggt/vggt/utils/load_fn.py
rename to models/SpaTrackV2/models/vggt4track/utils/load_fn.py
diff --git a/models/vggt/vggt/utils/loss.py b/models/SpaTrackV2/models/vggt4track/utils/loss.py
similarity index 97%
rename from models/vggt/vggt/utils/loss.py
rename to models/SpaTrackV2/models/vggt4track/utils/loss.py
index 7203036..a5f7867 100644
--- a/models/vggt/vggt/utils/loss.py
+++ b/models/SpaTrackV2/models/vggt4track/utils/loss.py
@@ -15,7 +15,7 @@ from models.moge.train.losses import (
 import torch.nn.functional as F
 from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
 
 def compute_loss(predictions, annots):
     """
diff --git a/models/vggt/vggt/utils/pose_enc.py b/models/SpaTrackV2/models/vggt4track/utils/pose_enc.py
similarity index 100%
rename from models/vggt/vggt/utils/pose_enc.py
rename to models/SpaTrackV2/models/vggt4track/utils/pose_enc.py
diff --git a/models/vggt/vggt/utils/rotation.py b/models/SpaTrackV2/models/vggt4track/utils/rotation.py
similarity index 100%
rename from models/vggt/vggt/utils/rotation.py
rename to models/SpaTrackV2/models/vggt4track/utils/rotation.py
diff --git a/models/vggt/vggt/utils/visual_track.py b/models/SpaTrackV2/models/vggt4track/utils/visual_track.py
similarity index 100%
rename from models/vggt/vggt/utils/visual_track.py
rename to models/SpaTrackV2/models/vggt4track/utils/visual_track.py
diff --git a/models/vggt/setup.py b/models/vggt/setup.py
deleted file mode 100644
index 2774cae..0000000
--- a/models/vggt/setup.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from setuptools import setup, find_packages
-
-setup(
-    name='vggt',
-    version='0.1',
-    packages=find_packages(),
-    description='vggt local package',
-)
\ No newline at end of file
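-- 
A minimal usage sketch of the loading flow this patch switches to. It only uses calls that appear in the hunks above (VGGT4Track.from_pretrained, Predictor.from_pretrained, preprocess_image, the prediction keys, and the track_num attribute); the placeholder clip tensor and its shape are assumptions standing in for frames that app.py and inference.py actually decode with decord, and a CUDA device is assumed.

import torch

from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.predictor import Predictor

# Front-end model: predicts camera poses, depth/point maps, and per-pixel uncertainty.
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")

# Tracker: pick the offline or online checkpoint, mirroring the new --track_mode switch.
model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
model.spatrack.track_num = 756  # grid-point budget, exposed as --vo_points in inference.py
model.eval()
model.to("cuda")

# Placeholder clip in (T, 3, H, W) with values in [0, 255]; real pipelines read frames
# with decord before this step. The shape here is purely illustrative.
video_tensor = torch.rand(16, 3, 336, 336) * 255
video_tensor = preprocess_image(video_tensor)[None]
with torch.no_grad():
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        predictions = vggt4track_model(video_tensor.cuda() / 255)

extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5  # keep only confident depth

Because Predictor now mixes in PyTorchModelHubMixin, the hand-rolled, local-path-only from_pretrained with its manual config.json handling becomes redundant, which is why this patch deletes it: the mixin supplies checkpoint download, config resolution, and weight loading in a single call.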