commit f427674f3f
parent 4393c1a348
Author: xiaoyuxi
Date:   2025-07-09 18:31:17 +08:00

41 changed files with 93 additions and 290 deletions

.gitignore (vendored, 1 line changed)

@@ -49,3 +49,4 @@ models/**/build
models/**/dist
temp_local
examples/results

app.py (56 lines changed)

@@ -26,6 +26,9 @@ import logging
from concurrent.futures import ThreadPoolExecutor
import atexit
import uuid
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.predictor import Predictor
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -78,20 +81,15 @@ def create_user_temp_dir():
return temp_dir
from huggingface_hub import hf_hub_download
# init the model
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
if os.environ.get("VGGT_DIR", None) is not None:
from models.vggt.vggt.models.vggt_moe import VGGT4Track
from models.vggt.vggt.utils.load_fn import preprocess_image
vggt_model = VGGT4Track()
vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
vggt_model.eval()
vggt_model = vggt_model.to("cuda")
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")
# Global model initialization
print("🚀 Initializing local models...")
tracker_model, _ = get_tracker_predictor(".", vo_points=756)
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()
predictor = get_sam_predictor()
print("✅ Models loaded successfully!")
@@ -129,7 +127,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
print("Initializing tracker models inside GPU function...")
out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True)
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
tracker_model=tracker_model.cuda())
# Setup paths
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@@ -159,25 +158,23 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
data_npz_load = {}
# run vggt
if os.environ.get("VGGT_DIR", None) is not None:
# process the image tensor
video_tensor = preprocess_image(video_tensor)[None]
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
# Predict attributes including cameras, depth maps, and point maps.
predictions = vggt_model(video_tensor.cuda()/255)
extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
depth_tensor = depth_map.squeeze().cpu().numpy()
extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
extrs = extrinsic.squeeze().cpu().numpy()
intrs = intrinsic.squeeze().cpu().numpy()
video_tensor = video_tensor.squeeze()
#NOTE: 20% of the depth is not reliable
# threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
# process the image tensor
video_tensor = preprocess_image(video_tensor)[None]
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
# Predict attributes including cameras, depth maps, and point maps.
predictions = vggt4track_model(video_tensor.cuda()/255)
extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
depth_tensor = depth_map.squeeze().cpu().numpy()
extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
extrs = extrinsic.squeeze().cpu().numpy()
intrs = intrinsic.squeeze().cpu().numpy()
video_tensor = video_tensor.squeeze()
#NOTE: 20% of the depth is not reliable
# threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
# Load and process mask
if os.path.exists(mask_path):
mask = cv2.imread(mask_path)
@@ -199,7 +196,6 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
print(f"Query points shape: {query_xyt.shape}")
# Run model inference
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
(
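Taken together, the app.py hunks above replace the manual hf_hub_download checkpoint flow with weights pulled straight from the Hub via from_pretrained. A minimal sketch of the new loading and front-end inference path, assuming the repo IDs shown in the diff and a CUDA device; the dummy video tensor is only a placeholder for the frames app.py actually decodes:

import torch
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.predictor import Predictor

# Front-end model: camera poses, depth and an uncertainty map from raw frames.
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")

# Offline tracker weights, also hosted on the Hub.
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()

# Placeholder input: (T, 3, H, W) frames with values in [0, 255].
video_tensor = torch.zeros(8, 3, 384, 512)
video_tensor = preprocess_image(video_tensor)[None]
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    predictions = vggt4track_model(video_tensor.cuda() / 255)
extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]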

(changed file, name not shown)

@@ -17,16 +17,14 @@ import glob
from rich import print
import argparse
import decord
from models.vggt.vggt.models.vggt_moe import VGGT4Track
from models.vggt.vggt.utils.load_fn import preprocess_image
from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from huggingface_hub import hf_hub_download
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--cfg_dir", type=str, default="config/magic_infer_offline.yaml")
parser.add_argument("--track_mode", type=str, default="offline")
parser.add_argument("--data_type", type=str, default="RGBD")
parser.add_argument("--VGGT", action="store_true")
parser.add_argument("--data_dir", type=str, default="assets/example0")
parser.add_argument("--video_name", type=str, default="snowboard")
parser.add_argument("--grid_size", type=int, default=10)
@@ -36,20 +34,14 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
cfg_dir = args.cfg_dir
out_dir = args.data_dir + "/results"
# fps
fps = int(args.fps)
mask_dir = args.data_dir + f"/{args.video_name}.png"
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts",
"spatrack_front.pth") #, force_download=True)
VGGT_DIR = os.environ["VGGT_DIR"]
assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist"
front_track = VGGT4Track()
front_track.load_state_dict(torch.load(VGGT_DIR), strict=False)
front_track.eval()
front_track = front_track.to("cuda")
vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
vggt4track_model = vggt4track_model.to("cuda")
if args.data_type == "RGBD":
npz_dir = args.data_dir + f"/{args.video_name}.npz"
@@ -76,7 +68,7 @@ if __name__ == "__main__":
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
# Predict attributes including cameras, depth maps, and point maps.
predictions = front_track(video_tensor.cuda()/255)
predictions = vggt4track_model(video_tensor.cuda()/255)
extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
@@ -103,21 +95,23 @@ if __name__ == "__main__":
viz = True
os.makedirs(out_dir, exist_ok=True)
with open(cfg_dir, "r") as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)
cfg = easydict.EasyDict(cfg)
cfg.out_dir = out_dir
cfg.model.track_num = args.vo_points
print(f"Downloading model from HuggingFace: {cfg.ckpts}")
checkpoint_path = hf_hub_download(
repo_id=cfg.ckpts,
repo_type="model",
filename="SpaTrack3_offline.pth"
)
model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
# with open(cfg_dir, "r") as f:
# cfg = yaml.load(f, Loader=yaml.FullLoader)
# cfg = easydict.EasyDict(cfg)
# cfg.out_dir = out_dir
# cfg.model.track_num = args.vo_points
# print(f"Downloading model from HuggingFace: {cfg.ckpts}")
if args.track_mode == "offline":
model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
else:
model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
# config the model; the track_num is the number of points in the grid
model.spatrack.track_num = args.vo_points
model.eval()
model.to("cuda")
viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
viser = Visualizer(save_dir=out_dir, grayscale=True,
fps=10, pad_value=0, tracks_leave_trace=5)
grid_size = args.grid_size
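The inference script mirrors the same switch to Hub-hosted checkpoints, with --track_mode picking between the offline and online repos. A small self-contained sketch of that selection logic; load_predictor is a hypothetical helper written for illustration, while the repo IDs and the track_num assignment come from the hunk above:

from models.SpaTrackV2.models.predictor import Predictor

def load_predictor(track_mode: str, vo_points: int) -> Predictor:
    # Choose the Hub checkpoint by tracking mode, then set the number of grid points.
    repo_id = ("Yuxihenry/SpatialTrackerV2-Offline" if track_mode == "offline"
               else "Yuxihenry/SpatialTrackerV2-Online")
    model = Predictor.from_pretrained(repo_id)
    model.spatrack.track_num = vo_points
    model.eval()
    model.to("cuda")
    return model

model = load_predictor("offline", 756)  # example arguments; the script uses args.track_mode / args.vo_points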

(changed file, name not shown)

@@ -66,15 +66,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
# Tracker model
self.Track3D = TrackRefiner3D(Track_cfg)
track_base_ckpt_dir = Track_cfg.base_ckpt
track_base_ckpt_dir = Track_cfg["base_ckpt"]
if os.path.exists(track_base_ckpt_dir):
track_pretrain = torch.load(track_base_ckpt_dir)
self.Track3D.load_state_dict(track_pretrain, strict=False)
# wrap the function of make lora trainable
self.make_paras_trainable = partial(self.make_paras_trainable,
mode=ft_cfg.mode,
paras_name=ft_cfg.paras_name)
mode=ft_cfg["mode"],
paras_name=ft_cfg["paras_name"])
self.track_num = track_num
def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
@@ -300,39 +300,6 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
**kwargs, annots=annots)
if self.training:
loss += out["loss"].squeeze()
# from models.SpaTrackV2.utils.visualizer import Visualizer
# vis_track = Visualizer(grayscale=False,
# fps=10, pad_value=50, tracks_leave_trace=0)
# vis_track.visualize(video=segment,
# tracks=out["traj_est"][...,:2],
# visibility=out["vis_est"],
# save_video=True)
# # visualize 4d
# import os, json
# import os.path as osp
# viser4d_dir = os.path.join("viser_4d_results")
# os.makedirs(viser4d_dir, exist_ok=True)
# depth_est = annots["depth_gt"][0]
# unc_metric = out["unc_metric"]
# mask = (unc_metric > 0.5).squeeze(1)
# # pose_est = out["poses_pred"].squeeze(0)
# pose_est = annots["traj_mat"][0]
# rgb_tracks = out["rgb_tracks"].squeeze(0)
# intrinsics = out["intrs"].squeeze(0)
# for i_k in range(out["depth"].shape[0]):
# img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
# img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
# cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
# if stage == 1:
# depth = depth_est[i_k].squeeze().cpu().numpy()
# np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
# else:
# point_map_vis = out["points_map"][i_k].cpu().numpy()
# np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
# np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
# np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
# np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
# np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
queries_len = len(queries_new)
# update the track3d and track2d
@@ -724,40 +691,3 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
}
return ret
# three stages of training
# stage 1:
# gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
# Tracking and Pose as well -> based on gt depth and intrinsics
# (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
# stage 2: fixed 3D tracking
# Joint depth refiner
# input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
# estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
# ongoing two days
# stage 3: train multi windows by propagation
# 4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
# types of scenarioes:
# 1. auto driving (waymo open dataset)
# 2. robot
# 3. internet ego video
# Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
# Update Variables:
# 1. 3D tracks B T N 3 xyz.
# 2. 2D tracks B T N 2 x y.
# 3. Dynamic Mask B T H W.
# 4. Camera Pose B T 4 4.
# 5. Video Depth.
# (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
# Campatiablity by product.
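The SpaTrack2 edits above also change config reads from attribute access (Track_cfg.base_ckpt, ft_cfg.mode) to plain dict indexing, which is what is needed once the config arrives as an ordinary dict, for example the kwargs restored from a Hub config.json, rather than an EasyDict. A tiny illustration of the difference, with illustrative values only:

track_cfg = {"base_ckpt": "ckpts/track_base.pth", "s_wind": 16, "overlap": 4}  # plain dict

ckpt_path = track_cfg["base_ckpt"]   # dict indexing works on a plain dict
# ckpt_path = track_cfg.base_ckpt    # AttributeError: 'dict' object has no attribute 'base_ckpt'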

(changed file, name not shown)

@@ -16,82 +16,21 @@ from typing import Union, Optional
import cv2
import os
import decord
from huggingface_hub import PyTorchModelHubMixin # used for model hub
class Predictor(torch.nn.Module):
class Predictor(torch.nn.Module, PyTorchModelHubMixin):
def __init__(self, args=None):
super().__init__()
self.args = args
self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
self.S_wind = args.Track_cfg.s_wind
self.overlap = args.Track_cfg.overlap
self.S_wind = args["Track_cfg"]["s_wind"]
self.overlap = args["Track_cfg"]["overlap"]
def to(self, device: Union[str, torch.device]):
self.spatrack.to(device)
if self.spatrack.base_model is not None:
self.spatrack.base_model.to(device)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, Path],
*,
force_download: bool = False,
cache_dir: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
model_cfg: Optional[dict] = None,
**kwargs,
) -> "SpaTrack2":
"""
Load a pretrained model from a local file or a remote repository.
Args:
pretrained_model_name_or_path (str or Path):
- Path to a local model file (e.g., `./model.pth`).
- HuggingFace Hub model ID (e.g., `username/model-name`).
force_download (bool, optional):
Whether to force re-download even if cached. Default: False.
cache_dir (str, optional):
Custom cache directory. Default: None (use default cache).
device (str or torch.device, optional):
Target device (e.g., "cuda", "cpu"). Default: None (keep original).
**kwargs:
Additional config overrides.
Returns:
SpaTrack2: Loaded pretrained model.
"""
# (1) check the path is local or remote
if isinstance(pretrained_model_name_or_path, Path):
model_path = str(pretrained_model_name_or_path)
else:
model_path = pretrained_model_name_or_path
# (2) if the path is remote, download it
if not os.path.exists(model_path):
raise NotImplementedError("Remote download not implemented yet. Use a local path.")
# (3) load the model weights
state_dict = torch.load(model_path, map_location="cpu")
# (4) initialize the model (can load config.json if exists)
config_path = os.path.join(os.path.dirname(model_path), "config.json")
config = {}
if os.path.exists(config_path):
import json
with open(config_path, "r") as f:
config.update(json.load(f))
config.update(kwargs) # allow override the config
if model_cfg is not None:
config = model_cfg
model = cls(config)
if "model" in state_dict:
model.spatrack.load_state_dict(state_dict["model"], strict=False)
else:
model.spatrack.load_state_dict(state_dict, strict=False)
# (5) device management
if device is not None:
model.to(device)
return model
def forward(self, video: str|torch.Tensor|np.ndarray,
depth: str|torch.Tensor|np.ndarray=None,
unc_metric: str|torch.Tensor|np.ndarray=None,
@@ -146,7 +85,6 @@ class Predictor(torch.nn.Module):
window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
return ret
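Deleting the hand-rolled from_pretrained is possible because Predictor now inherits PyTorchModelHubMixin, which supplies from_pretrained and save_pretrained directly, loading from a local directory or a Hub repo id; recent huggingface_hub versions also serialize simple __init__ kwargs into config.json automatically. A minimal sketch of the mixin pattern on a toy module, not the real Predictor:

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TinyTracker(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(3, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

model = TinyTracker(hidden_dim=64)
model.save_pretrained("tiny_tracker_ckpt")                    # writes weights plus config.json ({"hidden_dim": 64})
restored = TinyTracker.from_pretrained("tiny_tracker_ckpt")   # the same call accepts a Hub repo id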

(changed file, name not shown)

@@ -30,7 +30,7 @@ from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import
class TrackRefiner3D(CoTrackerThreeOffline):
def __init__(self, args=None):
super().__init__(**args.base)
super().__init__(**args["base"])
"""
This is 3D warpper from cotracker, which load the cotracker pretrain and
@@ -46,15 +46,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
# get the anchor point's embedding, and init the pts refiner
update_pts = True
# self.corr_transformer = nn.ModuleList([
# CorrPointformer(
# dim=128,
# num_heads=8,
# head_dim=128 // 8,
# mlp_ratio=4.0,
# )
# for _ in range(self.corr_levels)
# ])
self.corr_transformer = nn.ModuleList([
CorrPointformer(
dim=128,
@@ -68,28 +60,10 @@ class TrackRefiner3D(CoTrackerThreeOffline):
output_dim=self.latent_dim, stride=self.stride)
self.corr3d_radius = 3
if args.stablizer:
self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
self.upsample_kernel_size = 5
self.residual_embedding = nn.Parameter(torch.randn(
self.latent_dim, self.model_resolution[0]//16,
self.model_resolution[1]//16, requires_grad=True))
self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
self.upsample_factor = 4
self.upsample_transformer = UpsampleTransformerAlibi(
kernel_size=self.upsample_kernel_size, # kernel_size=3, #
stride=self.stride,
latent_dim=self.latent_dim,
num_attn_blocks=2,
upsample_factor=4,
)
else:
self.update_pointmap = None
self.mode = args.mode
self.mode = args["mode"]
if self.mode == "online":
self.s_wind = args.s_wind
self.overlap = args.overlap
self.s_wind = args["s_wind"]
self.overlap = args["overlap"]
def upsample_with_mask(
self, inp: torch.Tensor, mask: torch.Tensor
@@ -1061,29 +1035,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
vis_est = (vis_est>0.5).float()
sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
# coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
# if torch.isnan(coords_proj_curr).sum()>0:
# import pdb; pdb.set_trace()
if False:
point_map_resize = point_map.clone().view(B, T, 3, H, W)
update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
coords_append_resize = coords.clone().detach()
coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
update_track_input = self.norm_xyz(cam_pts_est)*5
update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
update = self.update_pointmap.stablizer(update_input,
update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
#NOTE: update the point map
point_map_resize += update
point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
point_map_preds.append(self.denorm_xyz(point_map_refine_out))
point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
# if torch.isnan(coords).sum()>0:
# import pdb; pdb.set_trace()
#NOTE: the 2d tracking + unproject depth
fix_cam_est = coords_append.clone()
fix_cam_est[...,2] = depth_unproj

(changed file, name not shown)

@@ -11,9 +11,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from models.vggt.vggt.layers import Mlp
from models.vggt.vggt.layers.block import Block
from models.vggt.vggt.heads.head_act import activate_pose
from models.SpaTrackV2.models.vggt4track.layers import Mlp
from models.SpaTrackV2.models.vggt4track.layers.block import Block
from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
class CameraHead(nn.Module):

(changed file, name not shown)

@@ -11,9 +11,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from models.vggt.vggt.layers import Mlp
from models.vggt.vggt.layers.block import Block
from models.vggt.vggt.heads.head_act import activate_pose
from models.SpaTrackV2.models.vggt4track.layers import Mlp
from models.SpaTrackV2.models.vggt4track.layers.block import Block
from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
class ScaleHead(nn.Module):

(changed file, name not shown)

@@ -10,10 +10,10 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Any
from models.vggt.vggt.layers import PatchEmbed
from models.vggt.vggt.layers.block import Block
from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
from models.SpaTrackV2.models.vggt4track.layers.block import Block
from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
from torch.utils.checkpoint import checkpoint
logger = logging.getLogger(__name__)

(changed file, name not shown)

@@ -10,10 +10,10 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Any
from models.vggt.vggt.layers import PatchEmbed
from models.vggt.vggt.layers.block import Block
from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
from models.SpaTrackV2.models.vggt4track.layers.block import Block
from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
from torch.utils.checkpoint import checkpoint
logger = logging.getLogger(__name__)

(changed file, name not shown)

@@ -9,12 +9,12 @@ import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from huggingface_hub import PyTorchModelHubMixin # used for model hub
from models.vggt.vggt.models.aggregator_front import Aggregator
from models.vggt.vggt.heads.camera_head import CameraHead
from models.vggt.vggt.heads.scale_head import ScaleHead
from models.SpaTrackV2.models.vggt4track.models.aggregator_front import Aggregator
from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
from models.SpaTrackV2.models.vggt4track.heads.scale_head import ScaleHead
from einops import rearrange
from models.vggt.vggt.utils.loss import compute_loss
from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from models.SpaTrackV2.utils.loss import compute_loss
from models.SpaTrackV2.utils.pose_enc import pose_encoding_to_extri_intri
import torch.nn.functional as F
class FrontTracker(nn.Module, PyTorchModelHubMixin):

(changed file, name not shown)

@@ -8,14 +8,14 @@ import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin # used for model hub
from models.vggt.vggt.models.aggregator import Aggregator
from models.vggt.vggt.heads.camera_head import CameraHead
from models.vggt.vggt.heads.dpt_head import DPTHead
from models.vggt.vggt.heads.track_head import TrackHead
from models.vggt.vggt.utils.loss import compute_loss
from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from models.SpaTrackV2.models.vggt4track.models.aggregator import Aggregator
from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
from models.SpaTrackV2.models.vggt4track.heads.dpt_head import DPTHead
from models.SpaTrackV2.models.vggt4track.heads.track_head import TrackHead
from models.SpaTrackV2.models.vggt4track.utils.loss import compute_loss
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
from models.vggt.vggt.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from einops import rearrange
import torch.nn.functional as F

(changed file, name not shown)

@@ -15,7 +15,7 @@ from models.moge.train.losses import (
import torch.nn.functional as F
from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
def compute_loss(predictions, annots):
"""

(changed file, name not shown)

@@ -1,8 +0,0 @@
from setuptools import setup, find_packages
setup(
name='vggt',
version='0.1',
packages=find_packages(),
description='vggt local package',
)