xiaoyuxi 2025-07-09 18:31:17 +08:00
parent 4393c1a348
commit f427674f3f
41 changed files with 93 additions and 290 deletions

3
.gitignore vendored
View File

@@ -48,4 +48,5 @@ config/fix_2d.yaml
 models/**/build
 models/**/dist
 temp_local
+examples/results

60
app.py
View File

@@ -26,6 +26,9 @@ import logging
 from concurrent.futures import ThreadPoolExecutor
 import atexit
 import uuid
+from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.predictor import Predictor
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -78,20 +81,15 @@ def create_user_temp_dir():
     return temp_dir
 from huggingface_hub import hf_hub_download
-# init the model
-os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
-if os.environ.get("VGGT_DIR", None) is not None:
-    from models.vggt.vggt.models.vggt_moe import VGGT4Track
-    from models.vggt.vggt.utils.load_fn import preprocess_image
-    vggt_model = VGGT4Track()
-    vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
-    vggt_model.eval()
-    vggt_model = vggt_model.to("cuda")
+vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+vggt4track_model.eval()
+vggt4track_model = vggt4track_model.to("cuda")
 # Global model initialization
 print("🚀 Initializing local models...")
-tracker_model, _ = get_tracker_predictor(".", vo_points=756)
+tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
+tracker_model.eval()
 predictor = get_sam_predictor()
 print("✅ Models loaded successfully!")
@@ -129,7 +127,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
 print("Initializing tracker models inside GPU function...")
 out_dir = os.path.join(temp_dir, "results")
 os.makedirs(out_dir, exist_ok=True)
-tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
+tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
+                                                             tracker_model=tracker_model.cuda())
 # Setup paths
 video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@@ -159,25 +158,23 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
 data_npz_load = {}
 # run vggt
-if os.environ.get("VGGT_DIR", None) is not None:
-    # process the image tensor
-    video_tensor = preprocess_image(video_tensor)[None]
-    with torch.no_grad():
-        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-            # Predict attributes including cameras, depth maps, and point maps.
-            predictions = vggt_model(video_tensor.cuda()/255)
-            extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
-            depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
-    depth_tensor = depth_map.squeeze().cpu().numpy()
-    extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
-    extrs = extrinsic.squeeze().cpu().numpy()
-    intrs = intrinsic.squeeze().cpu().numpy()
-    video_tensor = video_tensor.squeeze()
-    #NOTE: 20% of the depth is not reliable
-    # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
-    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+# process the image tensor
+video_tensor = preprocess_image(video_tensor)[None]
+with torch.no_grad():
+    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+        # Predict attributes including cameras, depth maps, and point maps.
+        predictions = vggt4track_model(video_tensor.cuda()/255)
+        extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
+        depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
+depth_tensor = depth_map.squeeze().cpu().numpy()
+extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+extrs = extrinsic.squeeze().cpu().numpy()
+intrs = intrinsic.squeeze().cpu().numpy()
+video_tensor = video_tensor.squeeze()
+#NOTE: 20% of the depth is not reliable
+# threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
+unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
 # Load and process mask
 if os.path.exists(mask_path):
     mask = cv2.imread(mask_path)
@@ -199,7 +196,6 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
 query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
 print(f"Query points shape: {query_xyt.shape}")
 # Run model inference
 with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
     (
@@ -210,8 +206,8 @@
         queries=query_xyt,
         fps=1, full_point=False, iters_track=4,
         query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
         support_frame=len(video_tensor)-1, replace_ratio=0.2)
 # Resize results to avoid large I/O
 max_size = 224
 h, w = video.shape[2:]

View File

@@ -17,16 +17,14 @@ import glob
 from rich import print
 import argparse
 import decord
-from models.vggt.vggt.models.vggt_moe import VGGT4Track
-from models.vggt.vggt.utils.load_fn import preprocess_image
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
-from huggingface_hub import hf_hub_download
+from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--cfg_dir", type=str, default="config/magic_infer_offline.yaml")
+    parser.add_argument("--track_mode", type=str, default="offline")
     parser.add_argument("--data_type", type=str, default="RGBD")
-    parser.add_argument("--VGGT", action="store_true")
     parser.add_argument("--data_dir", type=str, default="assets/example0")
     parser.add_argument("--video_name", type=str, default="snowboard")
     parser.add_argument("--grid_size", type=int, default=10)
@@ -36,20 +34,14 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
-    cfg_dir = args.cfg_dir
     out_dir = args.data_dir + "/results"
     # fps
     fps = int(args.fps)
     mask_dir = args.data_dir + f"/{args.video_name}.png"
-    os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts",
-                                             "spatrack_front.pth") #, force_download=True)
-    VGGT_DIR = os.environ["VGGT_DIR"]
-    assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist"
-    front_track = VGGT4Track()
-    front_track.load_state_dict(torch.load(VGGT_DIR), strict=False)
-    front_track.eval()
-    front_track = front_track.to("cuda")
+    vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+    vggt4track_model.eval()
+    vggt4track_model = vggt4track_model.to("cuda")
     if args.data_type == "RGBD":
         npz_dir = args.data_dir + f"/{args.video_name}.npz"
@@ -76,7 +68,7 @@ if __name__ == "__main__":
     with torch.no_grad():
         with torch.cuda.amp.autocast(dtype=torch.bfloat16):
             # Predict attributes including cameras, depth maps, and point maps.
-            predictions = front_track(video_tensor.cuda()/255)
+            predictions = vggt4track_model(video_tensor.cuda()/255)
             extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
             depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
@@ -103,21 +95,23 @@
     viz = True
     os.makedirs(out_dir, exist_ok=True)
-    with open(cfg_dir, "r") as f:
-        cfg = yaml.load(f, Loader=yaml.FullLoader)
-    cfg = easydict.EasyDict(cfg)
-    cfg.out_dir = out_dir
-    cfg.model.track_num = args.vo_points
-    print(f"Downloading model from HuggingFace: {cfg.ckpts}")
-    checkpoint_path = hf_hub_download(
-        repo_id=cfg.ckpts,
-        repo_type="model",
-        filename="SpaTrack3_offline.pth"
-    )
-    model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
+    # with open(cfg_dir, "r") as f:
+    #     cfg = yaml.load(f, Loader=yaml.FullLoader)
+    # cfg = easydict.EasyDict(cfg)
+    # cfg.out_dir = out_dir
+    # cfg.model.track_num = args.vo_points
+    # print(f"Downloading model from HuggingFace: {cfg.ckpts}")
+    if args.track_mode == "offline":
+        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
+    else:
+        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
+    # config the model; the track_num is the number of points in the grid
+    model.spatrack.track_num = args.vo_points
     model.eval()
     model.to("cuda")
-    viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
+    viser = Visualizer(save_dir=out_dir, grayscale=True,
                        fps=10, pad_value=0, tracks_leave_trace=5)
     grid_size = args.grid_size

View File

@@ -66,15 +66,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
 # Tracker model
 self.Track3D = TrackRefiner3D(Track_cfg)
-track_base_ckpt_dir = Track_cfg.base_ckpt
+track_base_ckpt_dir = Track_cfg["base_ckpt"]
 if os.path.exists(track_base_ckpt_dir):
     track_pretrain = torch.load(track_base_ckpt_dir)
     self.Track3D.load_state_dict(track_pretrain, strict=False)
 # wrap the function of make lora trainable
 self.make_paras_trainable = partial(self.make_paras_trainable,
-                                    mode=ft_cfg.mode,
-                                    paras_name=ft_cfg.paras_name)
+                                    mode=ft_cfg["mode"],
+                                    paras_name=ft_cfg["paras_name"])
 self.track_num = track_num
 def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
@@ -149,7 +149,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
 ):
 # step 1 allocate the query points on the grid
 T, C, H, W = video.shape
 if annots_train is not None:
     vis_gt = annots_train["vis"]
     _, _, N = vis_gt.shape
@@ -300,39 +300,6 @@
 **kwargs, annots=annots)
 if self.training:
     loss += out["loss"].squeeze()
-# from models.SpaTrackV2.utils.visualizer import Visualizer
-# vis_track = Visualizer(grayscale=False,
-#     fps=10, pad_value=50, tracks_leave_trace=0)
-# vis_track.visualize(video=segment,
-#     tracks=out["traj_est"][...,:2],
-#     visibility=out["vis_est"],
-#     save_video=True)
-# # visualize 4d
-# import os, json
-# import os.path as osp
-# viser4d_dir = os.path.join("viser_4d_results")
-# os.makedirs(viser4d_dir, exist_ok=True)
-# depth_est = annots["depth_gt"][0]
-# unc_metric = out["unc_metric"]
-# mask = (unc_metric > 0.5).squeeze(1)
-# # pose_est = out["poses_pred"].squeeze(0)
-# pose_est = annots["traj_mat"][0]
-# rgb_tracks = out["rgb_tracks"].squeeze(0)
-# intrinsics = out["intrs"].squeeze(0)
-# for i_k in range(out["depth"].shape[0]):
-#     img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
-#     img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
-#     cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
-#     if stage == 1:
-#         depth = depth_est[i_k].squeeze().cpu().numpy()
-#         np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
-#     else:
-#         point_map_vis = out["points_map"][i_k].cpu().numpy()
-#         np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
-# np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
-# np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
-# np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
-# np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
 queries_len = len(queries_new)
 # update the track3d and track2d
@@ -724,40 +691,3 @@
 }
 return ret
-# three stages of training
-# stage 1:
-# gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
-# Tracking and Pose as well -> based on gt depth and intrinsics
-# (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
-# stage 2: fixed 3D tracking
-# Joint depth refiner
-# input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
-# estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
-# ongoing two days
-# stage 3: train multi windows by propagation
-# 4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
-# types of scenarioes:
-# 1. auto driving (waymo open dataset)
-# 2. robot
-# 3. internet ego video
-# Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
-# Update Variables:
-# 1. 3D tracks B T N 3 xyz.
-# 2. 2D tracks B T N 2 x y.
-# 3. Dynamic Mask B T H W.
-# 4. Camera Pose B T 4 4.
-# 5. Video Depth.
-# (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
-# Campatiablity by product.

View File

@@ -16,82 +16,21 @@ from typing import Union, Optional
 import cv2
 import os
 import decord
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
-class Predictor(torch.nn.Module):
+class Predictor(torch.nn.Module, PyTorchModelHubMixin):
     def __init__(self, args=None):
         super().__init__()
         self.args = args
         self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
-        self.S_wind = args.Track_cfg.s_wind
-        self.overlap = args.Track_cfg.overlap
+        self.S_wind = args["Track_cfg"]["s_wind"]
+        self.overlap = args["Track_cfg"]["overlap"]
     def to(self, device: Union[str, torch.device]):
         self.spatrack.to(device)
         if self.spatrack.base_model is not None:
             self.spatrack.base_model.to(device)
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, Path],
-        *,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        device: Optional[Union[str, torch.device]] = None,
-        model_cfg: Optional[dict] = None,
-        **kwargs,
-    ) -> "SpaTrack2":
-        """
-        Load a pretrained model from a local file or a remote repository.
-        Args:
-            pretrained_model_name_or_path (str or Path):
-                - Path to a local model file (e.g., `./model.pth`).
-                - HuggingFace Hub model ID (e.g., `username/model-name`).
-            force_download (bool, optional):
-                Whether to force re-download even if cached. Default: False.
-            cache_dir (str, optional):
-                Custom cache directory. Default: None (use default cache).
-            device (str or torch.device, optional):
-                Target device (e.g., "cuda", "cpu"). Default: None (keep original).
-            **kwargs:
-                Additional config overrides.
-        Returns:
-            SpaTrack2: Loaded pretrained model.
-        """
-        # (1) check the path is local or remote
-        if isinstance(pretrained_model_name_or_path, Path):
-            model_path = str(pretrained_model_name_or_path)
-        else:
-            model_path = pretrained_model_name_or_path
-        # (2) if the path is remote, download it
-        if not os.path.exists(model_path):
-            raise NotImplementedError("Remote download not implemented yet. Use a local path.")
-        # (3) load the model weights
-        state_dict = torch.load(model_path, map_location="cpu")
-        # (4) initialize the model (can load config.json if exists)
-        config_path = os.path.join(os.path.dirname(model_path), "config.json")
-        config = {}
-        if os.path.exists(config_path):
-            import json
-            with open(config_path, "r") as f:
-                config.update(json.load(f))
-        config.update(kwargs) # allow override the config
-        if model_cfg is not None:
-            config = model_cfg
-        model = cls(config)
-        if "model" in state_dict:
-            model.spatrack.load_state_dict(state_dict["model"], strict=False)
-        else:
-            model.spatrack.load_state_dict(state_dict, strict=False)
-        # (5) device management
-        if device is not None:
-            model.to(device)
-        return model
     def forward(self, video: str|torch.Tensor|np.ndarray,
                 depth: str|torch.Tensor|np.ndarray=None,
                 unc_metric: str|torch.Tensor|np.ndarray=None,
@@ -146,7 +85,6 @@ class Predictor(torch.nn.Module):
 window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
 fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
 return ret

View File

@@ -30,7 +30,7 @@ from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import
 class TrackRefiner3D(CoTrackerThreeOffline):
     def __init__(self, args=None):
-        super().__init__(**args.base)
+        super().__init__(**args["base"])
         """
         This is 3D warpper from cotracker, which load the cotracker pretrain and
@@ -46,15 +46,7 @@
 self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
 # get the anchor point's embedding, and init the pts refiner
 update_pts = True
-# self.corr_transformer = nn.ModuleList([
-#     CorrPointformer(
-#         dim=128,
-#         num_heads=8,
-#         head_dim=128 // 8,
-#         mlp_ratio=4.0,
-#     )
-#     for _ in range(self.corr_levels)
-# ])
 self.corr_transformer = nn.ModuleList([
     CorrPointformer(
         dim=128,
@@ -67,29 +59,11 @@
 self.fnet = BasicEncoder(input_dim=3,
                          output_dim=self.latent_dim, stride=self.stride)
 self.corr3d_radius = 3
-if args.stablizer:
-    self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
-    self.upsample_kernel_size = 5
-    self.residual_embedding = nn.Parameter(torch.randn(
-        self.latent_dim, self.model_resolution[0]//16,
-        self.model_resolution[1]//16, requires_grad=True))
-    self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
-    self.upsample_factor = 4
-    self.upsample_transformer = UpsampleTransformerAlibi(
-        kernel_size=self.upsample_kernel_size, # kernel_size=3, #
-        stride=self.stride,
-        latent_dim=self.latent_dim,
-        num_attn_blocks=2,
-        upsample_factor=4,
-    )
-else:
-    self.update_pointmap = None
-self.mode = args.mode
+self.mode = args["mode"]
 if self.mode == "online":
-    self.s_wind = args.s_wind
-    self.overlap = args.overlap
+    self.s_wind = args["s_wind"]
+    self.overlap = args["overlap"]
 def upsample_with_mask(
     self, inp: torch.Tensor, mask: torch.Tensor
@@ -1061,29 +1035,7 @@
 vis_est = (vis_est>0.5).float()
 sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
 # coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
-# if torch.isnan(coords_proj_curr).sum()>0:
-#     import pdb; pdb.set_trace()
-if False:
-    point_map_resize = point_map.clone().view(B, T, 3, H, W)
-    update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
-    coords_append_resize = coords.clone().detach()
-    coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
-    update_track_input = self.norm_xyz(cam_pts_est)*5
-    update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
-    update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
-    update = self.update_pointmap.stablizer(update_input,
-        update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
-    #NOTE: update the point map
-    point_map_resize += update
-    point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
-        size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
-    point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
-    point_map_preds.append(self.denorm_xyz(point_map_refine_out))
-    point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
-# if torch.isnan(coords).sum()>0:
-#     import pdb; pdb.set_trace()
 #NOTE: the 2d tracking + unproject depth
 fix_cam_est = coords_append.clone()
 fix_cam_est[...,2] = depth_unproj

View File

@@ -11,9 +11,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from models.vggt.vggt.layers import Mlp
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.heads.head_act import activate_pose
+from models.SpaTrackV2.models.vggt4track.layers import Mlp
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
 class CameraHead(nn.Module):

View File

@@ -11,9 +11,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from models.vggt.vggt.layers import Mlp
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.heads.head_act import activate_pose
+from models.SpaTrackV2.models.vggt4track.layers import Mlp
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
 class ScaleHead(nn.Module):

View File

@@ -10,10 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple, Union, List, Dict, Any
-from models.vggt.vggt.layers import PatchEmbed
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
-from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
 from torch.utils.checkpoint import checkpoint
 logger = logging.getLogger(__name__)

View File

@@ -10,10 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple, Union, List, Dict, Any
-from models.vggt.vggt.layers import PatchEmbed
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
-from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
 from torch.utils.checkpoint import checkpoint
 logger = logging.getLogger(__name__)

View File

@@ -9,12 +9,12 @@ import torch.nn as nn
 from torch.utils.checkpoint import checkpoint
 from huggingface_hub import PyTorchModelHubMixin # used for model hub
-from models.vggt.vggt.models.aggregator_front import Aggregator
-from models.vggt.vggt.heads.camera_head import CameraHead
-from models.vggt.vggt.heads.scale_head import ScaleHead
+from models.SpaTrackV2.models.vggt4track.models.aggregator_front import Aggregator
+from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
+from models.SpaTrackV2.models.vggt4track.heads.scale_head import ScaleHead
 from einops import rearrange
-from models.vggt.vggt.utils.loss import compute_loss
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from models.SpaTrackV2.utils.loss import compute_loss
+from models.SpaTrackV2.utils.pose_enc import pose_encoding_to_extri_intri
 import torch.nn.functional as F
 class FrontTracker(nn.Module, PyTorchModelHubMixin):

View File

@@ -8,14 +8,14 @@ import torch
 import torch.nn as nn
 from huggingface_hub import PyTorchModelHubMixin # used for model hub
-from models.vggt.vggt.models.aggregator import Aggregator
-from models.vggt.vggt.heads.camera_head import CameraHead
-from models.vggt.vggt.heads.dpt_head import DPTHead
-from models.vggt.vggt.heads.track_head import TrackHead
-from models.vggt.vggt.utils.loss import compute_loss
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from models.SpaTrackV2.models.vggt4track.models.aggregator import Aggregator
+from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
+from models.SpaTrackV2.models.vggt4track.heads.dpt_head import DPTHead
+from models.SpaTrackV2.models.vggt4track.heads.track_head import TrackHead
+from models.SpaTrackV2.models.vggt4track.utils.loss import compute_loss
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
-from models.vggt.vggt.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
 from einops import rearrange
 import torch.nn.functional as F

View File

@@ -15,7 +15,7 @@ from models.moge.train.losses import (
 import torch.nn.functional as F
 from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
 def compute_loss(predictions, annots):
     """

View File

@@ -1,8 +0,0 @@
-from setuptools import setup, find_packages
-setup(
-    name='vggt',
-    version='0.1',
-    packages=find_packages(),
-    description='vggt local package',
-)