fix_md

This commit is contained in:
parent 4393c1a348
commit f427674f3f

.gitignore (vendored, 3 changes)
@@ -48,4 +48,5 @@ config/fix_2d.yaml
 models/**/build
 models/**/dist
 
-temp_local
+temp_local
+examples/results

app.py (60 changes)
@@ -26,6 +26,9 @@ import logging
 from concurrent.futures import ThreadPoolExecutor
 import atexit
 import uuid
+from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.predictor import Predictor
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -78,20 +81,15 @@ def create_user_temp_dir():
     return temp_dir
 
-from huggingface_hub import hf_hub_download
-# init the model
-os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
-
-if os.environ.get("VGGT_DIR", None) is not None:
-    from models.vggt.vggt.models.vggt_moe import VGGT4Track
-    from models.vggt.vggt.utils.load_fn import preprocess_image
-    vggt_model = VGGT4Track()
-    vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
-    vggt_model.eval()
-    vggt_model = vggt_model.to("cuda")
+vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+vggt4track_model.eval()
+vggt4track_model = vggt4track_model.to("cuda")
 
 # Global model initialization
 print("🚀 Initializing local models...")
-tracker_model, _ = get_tracker_predictor(".", vo_points=756)
+tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
+tracker_model.eval()
 predictor = get_sam_predictor()
 print("✅ Models loaded successfully!")
 
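
Note: the net effect of this hunk is to drop the manual hf_hub_download + load_state_dict flow (and the VGGT_DIR environment indirection) in favor of Hub-hosted weights loaded via from_pretrained. A condensed sketch of the new startup path, using the repo IDs exactly as they appear in the diff (a sketch, not the full app.py):

    import torch
    from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
    from models.SpaTrackV2.models.predictor import Predictor

    # Front model (camera poses, intrinsics, depth) resolved from the HF Hub.
    vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
    vggt4track_model.eval()
    vggt4track_model = vggt4track_model.to("cuda")

    # Offline tracker; weights also resolved through from_pretrained.
    tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
    tracker_model.eval()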
@@ -129,7 +127,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
     print("Initializing tracker models inside GPU function...")
     out_dir = os.path.join(temp_dir, "results")
     os.makedirs(out_dir, exist_ok=True)
-    tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
+    tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
+                                                                 tracker_model=tracker_model.cuda())
 
     # Setup paths
     video_path = os.path.join(temp_dir, f"{video_name}.mp4")
@@ -159,25 +158,23 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
     data_npz_load = {}
 
-    # run vggt
-    if os.environ.get("VGGT_DIR", None) is not None:
-        # process the image tensor
-        video_tensor = preprocess_image(video_tensor)[None]
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                # Predict attributes including cameras, depth maps, and point maps.
-                predictions = vggt_model(video_tensor.cuda()/255)
-                extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
-                depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
-
-        depth_tensor = depth_map.squeeze().cpu().numpy()
-        extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
-        extrs = extrinsic.squeeze().cpu().numpy()
-        intrs = intrinsic.squeeze().cpu().numpy()
-        video_tensor = video_tensor.squeeze()
-        #NOTE: 20% of the depth is not reliable
-        # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
-        unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
+    # process the image tensor
+    video_tensor = preprocess_image(video_tensor)[None]
+    with torch.no_grad():
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            # Predict attributes including cameras, depth maps, and point maps.
+            predictions = vggt4track_model(video_tensor.cuda()/255)
+            extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
+            depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
+
+    depth_tensor = depth_map.squeeze().cpu().numpy()
+    extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
+    extrs = extrinsic.squeeze().cpu().numpy()
+    intrs = intrinsic.squeeze().cpu().numpy()
+    video_tensor = video_tensor.squeeze()
+    #NOTE: 20% of the depth is not reliable
+    # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
+    unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
     # Load and process mask
     if os.path.exists(mask_path):
         mask = cv2.imread(mask_path)
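
Note: this hunk removes the VGGT_DIR gate, so the front model now always runs inside gpu_run_tracker under the vggt4track_model name. A sketch of the inference contract implied by the code above (the helper name and shape comments are my assumptions, not from the commit):

    import torch

    def run_front_model(vggt4track_model, video_tensor, preprocess_image):
        # Assumed input: uint8 video tensor (T, 3, H, W) with values in [0, 255].
        video_tensor = preprocess_image(video_tensor)[None]
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                predictions = vggt4track_model(video_tensor.cuda() / 255)
        extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
        depth_map = predictions["points_map"][..., 2]   # z-channel of the point map
        depth_conf = predictions["unc_metric"]          # per-pixel reliability
        # Pixels below 0.5 confidence are discarded downstream.
        unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
        return extrinsic, intrinsic, depth_map, unc_metric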
@@ -199,7 +196,6 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
-
     query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
     print(f"Query points shape: {query_xyt.shape}")
 
     # Run model inference
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
         (
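
Note: query_xyt packs each grid point as (t, x, y) with t fixed to 0, i.e., every query starts tracking from the first frame. A small sketch of the construction (grid values are hypothetical):

    import torch

    grid_pts = torch.tensor([[[32.0, 48.0], [96.0, 48.0]]])  # (1, N, 2) hypothetical x, y
    t0 = torch.zeros_like(grid_pts[:, :, :1])                 # (1, N, 1) time index 0
    query_xyt = torch.cat([t0, grid_pts], dim=2)[0].cpu().numpy()
    # -> [[0., 32., 48.], [0., 96., 48.]]  each row is (t, x, y)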
@@ -210,8 +206,8 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
         queries=query_xyt,
         fps=1, full_point=False, iters_track=4,
         query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
-        support_frame=len(video_tensor)-1, replace_ratio=0.2)
+        support_frame=len(video_tensor)-1, replace_ratio=0.2)
 
     # Resize results to avoid large I/O
     max_size = 224
     h, w = video.shape[2:]

inference.py (52 changes)
@@ -17,16 +17,14 @@ import glob
 from rich import print
 import argparse
 import decord
-from models.vggt.vggt.models.vggt_moe import VGGT4Track
-from models.vggt.vggt.utils.load_fn import preprocess_image
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
-from huggingface_hub import hf_hub_download
+from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
 
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--cfg_dir", type=str, default="config/magic_infer_offline.yaml")
     parser.add_argument("--track_mode", type=str, default="offline")
     parser.add_argument("--data_type", type=str, default="RGBD")
     parser.add_argument("--VGGT", action="store_true")
     parser.add_argument("--data_dir", type=str, default="assets/example0")
     parser.add_argument("--video_name", type=str, default="snowboard")
     parser.add_argument("--grid_size", type=int, default=10)
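
Note: given these defaults the script still runs end to end with no flags. A hypothetical smoke test of the CLI surface above (not part of the commit):

    # Assumes parse_args() as defined above is importable.
    import sys

    sys.argv = ["inference.py", "--data_type", "RGBD",
                "--data_dir", "assets/example0", "--video_name", "snowboard"]
    args = parse_args()
    assert args.track_mode == "offline" and args.grid_size == 10  # defaults hold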
@@ -36,20 +34,14 @@ def parse_args():
 
 if __name__ == "__main__":
     args = parse_args()
     cfg_dir = args.cfg_dir
     out_dir = args.data_dir + "/results"
     # fps
     fps = int(args.fps)
     mask_dir = args.data_dir + f"/{args.video_name}.png"
 
-    os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts",
-                                             "spatrack_front.pth") #, force_download=True)
-    VGGT_DIR = os.environ["VGGT_DIR"]
-    assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist"
-    front_track = VGGT4Track()
-    front_track.load_state_dict(torch.load(VGGT_DIR), strict=False)
-    front_track.eval()
-    front_track = front_track.to("cuda")
+    vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
+    vggt4track_model.eval()
+    vggt4track_model = vggt4track_model.to("cuda")
 
     if args.data_type == "RGBD":
         npz_dir = args.data_dir + f"/{args.video_name}.npz"
@@ -76,7 +68,7 @@ if __name__ == "__main__":
     with torch.no_grad():
         with torch.cuda.amp.autocast(dtype=torch.bfloat16):
             # Predict attributes including cameras, depth maps, and point maps.
-            predictions = front_track(video_tensor.cuda()/255)
+            predictions = vggt4track_model(video_tensor.cuda()/255)
             extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
             depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
 
@@ -103,21 +95,23 @@ if __name__ == "__main__":
     viz = True
     os.makedirs(out_dir, exist_ok=True)
 
-    with open(cfg_dir, "r") as f:
-        cfg = yaml.load(f, Loader=yaml.FullLoader)
-    cfg = easydict.EasyDict(cfg)
-    cfg.out_dir = out_dir
-    cfg.model.track_num = args.vo_points
-    print(f"Downloading model from HuggingFace: {cfg.ckpts}")
-    checkpoint_path = hf_hub_download(
-        repo_id=cfg.ckpts,
-        repo_type="model",
-        filename="SpaTrack3_offline.pth"
-    )
-    model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
+    # with open(cfg_dir, "r") as f:
+    #     cfg = yaml.load(f, Loader=yaml.FullLoader)
+    # cfg = easydict.EasyDict(cfg)
+    # cfg.out_dir = out_dir
+    # cfg.model.track_num = args.vo_points
+    # print(f"Downloading model from HuggingFace: {cfg.ckpts}")
+    if args.track_mode == "offline":
+        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
+    else:
+        model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
+
+    # config the model; the track_num is the number of points in the grid
+    model.spatrack.track_num = args.vo_points
+
     model.eval()
     model.to("cuda")
-    viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
+    viser = Visualizer(save_dir=out_dir, grayscale=True,
                        fps=10, pad_value=0, tracks_leave_trace=5)
 
     grid_size = args.grid_size
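
Note: model selection now keys off --track_mode, and per-run configuration happens on the loaded object instead of a YAML-derived EasyDict. A condensed sketch of the new flow (repo IDs from the diff; the concrete vo_points value is hypothetical):

    from models.SpaTrackV2.models.predictor import Predictor

    track_mode = "offline"   # or "online", per --track_mode
    vo_points = 756          # per --vo_points; value hypothetical here
    repo = ("Yuxihenry/SpatialTrackerV2-Offline" if track_mode == "offline"
            else "Yuxihenry/SpatialTrackerV2-Online")
    model = Predictor.from_pretrained(repo)
    model.spatrack.track_num = vo_points  # grid/track budget, set after loading
    model.eval()
    model.to("cuda")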
@@ -66,15 +66,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
 
         # Tracker model
         self.Track3D = TrackRefiner3D(Track_cfg)
-        track_base_ckpt_dir = Track_cfg.base_ckpt
+        track_base_ckpt_dir = Track_cfg["base_ckpt"]
         if os.path.exists(track_base_ckpt_dir):
             track_pretrain = torch.load(track_base_ckpt_dir)
             self.Track3D.load_state_dict(track_pretrain, strict=False)
 
         # wrap the function of make lora trainable
         self.make_paras_trainable = partial(self.make_paras_trainable,
-                                            mode=ft_cfg.mode,
-                                            paras_name=ft_cfg.paras_name)
+                                            mode=ft_cfg["mode"],
+                                            paras_name=ft_cfg["paras_name"])
         self.track_num = track_num
 
     def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
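
Note: the attribute-to-subscript change here (and in the predictor and tracker hunks below) reads as a switch from EasyDict objects to plain dicts, presumably so configs can round-trip through the Hub's config.json; that rationale is my inference, not stated in the commit. Schematically (key from the diff, value hypothetical):

    # Old style required an EasyDict-like object:
    #   ckpt = Track_cfg.base_ckpt
    # New style works on any plain mapping, e.g. one parsed from config.json:
    Track_cfg = {"base_ckpt": "ckpts/cotracker_base.pth"}  # hypothetical value
    ckpt = Track_cfg["base_ckpt"]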
@@ -149,7 +149,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
     ):
         # step 1 allocate the query points on the grid
         T, C, H, W = video.shape
-
+
         if annots_train is not None:
             vis_gt = annots_train["vis"]
             _, _, N = vis_gt.shape
@@ -300,39 +300,6 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
                                  **kwargs, annots=annots)
             if self.training:
                 loss += out["loss"].squeeze()
-            # from models.SpaTrackV2.utils.visualizer import Visualizer
-            # vis_track = Visualizer(grayscale=False,
-            #                        fps=10, pad_value=50, tracks_leave_trace=0)
-            # vis_track.visualize(video=segment,
-            #                     tracks=out["traj_est"][...,:2],
-            #                     visibility=out["vis_est"],
-            #                     save_video=True)
-            # # visualize 4d
-            # import os, json
-            # import os.path as osp
-            # viser4d_dir = os.path.join("viser_4d_results")
-            # os.makedirs(viser4d_dir, exist_ok=True)
-            # depth_est = annots["depth_gt"][0]
-            # unc_metric = out["unc_metric"]
-            # mask = (unc_metric > 0.5).squeeze(1)
-            # # pose_est = out["poses_pred"].squeeze(0)
-            # pose_est = annots["traj_mat"][0]
-            # rgb_tracks = out["rgb_tracks"].squeeze(0)
-            # intrinsics = out["intrs"].squeeze(0)
-            # for i_k in range(out["depth"].shape[0]):
-            #     img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
-            #     img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
-            #     cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
-            #     if stage == 1:
-            #         depth = depth_est[i_k].squeeze().cpu().numpy()
-            #         np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
-            #     else:
-            #         point_map_vis = out["points_map"][i_k].cpu().numpy()
-            #         np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
-            # np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
-            # np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
-            # np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
-            # np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
 
             queries_len = len(queries_new)
             # update the track3d and track2d
@@ -724,40 +691,3 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
         }
 
         return ret
-
-
-# three stages of training
-
-# stage 1:
-# gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
-# Tracking and Pose as well -> based on gt depth and intrinsics
-# (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
-
-# stage 2: fixed 3D tracking
-# Joint depth refiner
-# input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
-# estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
-# ongoing two days
-
-# stage 3: train multi windows by propagation
-# 4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
-
-# types of scenarioes:
-# 1. auto driving (waymo open dataset)
-# 2. robot
-# 3. internet ego video
-
-
-# Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
-# Update Variables:
-# 1. 3D tracks B T N 3 xyz.
-# 2. 2D tracks B T N 2 x y.
-# 3. Dynamic Mask B T H W.
-# 4. Camera Pose B T 4 4.
-# 5. Video Depth.
-
-# (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
-# Campatiablity by product.
@@ -16,82 +16,21 @@ from typing import Union, Optional
 import cv2
 import os
 import decord
+from huggingface_hub import PyTorchModelHubMixin  # used for model hub
 
-class Predictor(torch.nn.Module):
+class Predictor(torch.nn.Module, PyTorchModelHubMixin):
     def __init__(self, args=None):
         super().__init__()
         self.args = args
         self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
-        self.S_wind = args.Track_cfg.s_wind
-        self.overlap = args.Track_cfg.overlap
+        self.S_wind = args["Track_cfg"]["s_wind"]
+        self.overlap = args["Track_cfg"]["overlap"]
 
     def to(self, device: Union[str, torch.device]):
         self.spatrack.to(device)
         if self.spatrack.base_model is not None:
             self.spatrack.base_model.to(device)
 
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, Path],
-        *,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        device: Optional[Union[str, torch.device]] = None,
-        model_cfg: Optional[dict] = None,
-        **kwargs,
-    ) -> "SpaTrack2":
-        """
-        Load a pretrained model from a local file or a remote repository.
-
-        Args:
-            pretrained_model_name_or_path (str or Path):
-                - Path to a local model file (e.g., `./model.pth`).
-                - HuggingFace Hub model ID (e.g., `username/model-name`).
-            force_download (bool, optional):
-                Whether to force re-download even if cached. Default: False.
-            cache_dir (str, optional):
-                Custom cache directory. Default: None (use default cache).
-            device (str or torch.device, optional):
-                Target device (e.g., "cuda", "cpu"). Default: None (keep original).
-            **kwargs:
-                Additional config overrides.
-
-        Returns:
-            SpaTrack2: Loaded pretrained model.
-        """
-        # (1) check the path is local or remote
-        if isinstance(pretrained_model_name_or_path, Path):
-            model_path = str(pretrained_model_name_or_path)
-        else:
-            model_path = pretrained_model_name_or_path
-        # (2) if the path is remote, download it
-        if not os.path.exists(model_path):
-            raise NotImplementedError("Remote download not implemented yet. Use a local path.")
-        # (3) load the model weights
-
-        state_dict = torch.load(model_path, map_location="cpu")
-        # (4) initialize the model (can load config.json if exists)
-        config_path = os.path.join(os.path.dirname(model_path), "config.json")
-        config = {}
-        if os.path.exists(config_path):
-            import json
-            with open(config_path, "r") as f:
-                config.update(json.load(f))
-        config.update(kwargs)  # allow override the config
-        if model_cfg is not None:
-            config = model_cfg
-        model = cls(config)
-        if "model" in state_dict:
-            model.spatrack.load_state_dict(state_dict["model"], strict=False)
-        else:
-            model.spatrack.load_state_dict(state_dict, strict=False)
-        # (5) device management
-        if device is not None:
-            model.to(device)
-
-        return model
-
     def forward(self, video: str|torch.Tensor|np.ndarray,
                 depth: str|torch.Tensor|np.ndarray=None,
                 unc_metric: str|torch.Tensor|np.ndarray=None,
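
Note: with PyTorchModelHubMixin in the base classes, Predictor inherits from_pretrained and save_pretrained from huggingface_hub, which is what lets the roughly 60-line hand-written from_pretrained above be deleted and enables the Hub repo IDs used in app.py and inference.py. A toy illustration of the mixin (illustrative class, not the real Predictor):

    import torch.nn as nn
    from huggingface_hub import PyTorchModelHubMixin

    class TinyPredictor(nn.Module, PyTorchModelHubMixin):
        def __init__(self, hidden: int = 8):
            super().__init__()
            self.fc = nn.Linear(hidden, hidden)

    # The mixin persists the init kwargs to config.json next to the weights:
    # TinyPredictor(hidden=16).save_pretrained("tiny-ckpt")
    # model = TinyPredictor.from_pretrained("tiny-ckpt")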
@@ -146,7 +85,6 @@ class Predictor(torch.nn.Module):
             window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
             fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
 
-
         return ret
 
@@ -30,7 +30,7 @@ from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import
 class TrackRefiner3D(CoTrackerThreeOffline):
 
     def __init__(self, args=None):
-        super().__init__(**args.base)
+        super().__init__(**args["base"])
 
         """
         This is 3D warpper from cotracker, which load the cotracker pretrain and
@@ -46,15 +46,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
         # get the anchor point's embedding, and init the pts refiner
         update_pts = True
-        # self.corr_transformer = nn.ModuleList([
-        #     CorrPointformer(
-        #         dim=128,
-        #         num_heads=8,
-        #         head_dim=128 // 8,
-        #         mlp_ratio=4.0,
-        #     )
-        #     for _ in range(self.corr_levels)
-        # ])
+
         self.corr_transformer = nn.ModuleList([
             CorrPointformer(
                 dim=128,
@@ -67,29 +59,11 @@ class TrackRefiner3D(CoTrackerThreeOffline):
         self.fnet = BasicEncoder(input_dim=3,
                                  output_dim=self.latent_dim, stride=self.stride)
         self.corr3d_radius = 3
 
-        if args.stablizer:
-            self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
-            self.upsample_kernel_size = 5
-            self.residual_embedding = nn.Parameter(torch.randn(
-                self.latent_dim, self.model_resolution[0]//16,
-                self.model_resolution[1]//16, requires_grad=True))
-            self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
-            self.upsample_factor = 4
-            self.upsample_transformer = UpsampleTransformerAlibi(
-                kernel_size=self.upsample_kernel_size, # kernel_size=3, #
-                stride=self.stride,
-                latent_dim=self.latent_dim,
-                num_attn_blocks=2,
-                upsample_factor=4,
-            )
-        else:
-            self.update_pointmap = None
-
-        self.mode = args.mode
+        self.mode = args["mode"]
         if self.mode == "online":
-            self.s_wind = args.s_wind
-            self.overlap = args.overlap
+            self.s_wind = args["s_wind"]
+            self.overlap = args["overlap"]
 
     def upsample_with_mask(
         self, inp: torch.Tensor, mask: torch.Tensor
@@ -1061,29 +1035,7 @@ class TrackRefiner3D(CoTrackerThreeOffline):
                 vis_est = (vis_est>0.5).float()
             sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
             # coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
-            # if torch.isnan(coords_proj_curr).sum()>0:
-            #     import pdb; pdb.set_trace()
 
-            if False:
-                point_map_resize = point_map.clone().view(B, T, 3, H, W)
-                update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
-                coords_append_resize = coords.clone().detach()
-                coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
-                update_track_input = self.norm_xyz(cam_pts_est)*5
-                update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
-                update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
-                update = self.update_pointmap.stablizer(update_input,
-                    update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
-                #NOTE: update the point map
-                point_map_resize += update
-                point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
-                    size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
-                point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
-                point_map_preds.append(self.denorm_xyz(point_map_refine_out))
-                point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
-
-            # if torch.isnan(coords).sum()>0:
-            #     import pdb; pdb.set_trace()
             #NOTE: the 2d tracking + unproject depth
             fix_cam_est = coords_append.clone()
             fix_cam_est[...,2] = depth_unproj
@@ -11,9 +11,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.vggt.vggt.layers import Mlp
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.heads.head_act import activate_pose
+from models.SpaTrackV2.models.vggt4track.layers import Mlp
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
 
 
 class CameraHead(nn.Module):
@@ -11,9 +11,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.vggt.vggt.layers import Mlp
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.heads.head_act import activate_pose
+from models.SpaTrackV2.models.vggt4track.layers import Mlp
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.heads.head_act import activate_pose
 
 
 class ScaleHead(nn.Module):
@@ -10,10 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple, Union, List, Dict, Any
 
-from models.vggt.vggt.layers import PatchEmbed
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
-from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
 from torch.utils.checkpoint import checkpoint
 
 logger = logging.getLogger(__name__)
@@ -10,10 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple, Union, List, Dict, Any
 
-from models.vggt.vggt.layers import PatchEmbed
-from models.vggt.vggt.layers.block import Block
-from models.vggt.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
-from models.vggt.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+from models.SpaTrackV2.models.vggt4track.layers import PatchEmbed
+from models.SpaTrackV2.models.vggt4track.layers.block import Block
+from models.SpaTrackV2.models.vggt4track.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from models.SpaTrackV2.models.vggt4track.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
 from torch.utils.checkpoint import checkpoint
 
 logger = logging.getLogger(__name__)
@@ -9,12 +9,12 @@ import torch.nn as nn
 from torch.utils.checkpoint import checkpoint
 from huggingface_hub import PyTorchModelHubMixin  # used for model hub
 
-from models.vggt.vggt.models.aggregator_front import Aggregator
-from models.vggt.vggt.heads.camera_head import CameraHead
-from models.vggt.vggt.heads.scale_head import ScaleHead
+from models.SpaTrackV2.models.vggt4track.models.aggregator_front import Aggregator
+from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
+from models.SpaTrackV2.models.vggt4track.heads.scale_head import ScaleHead
 from einops import rearrange
-from models.vggt.vggt.utils.loss import compute_loss
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from models.SpaTrackV2.utils.loss import compute_loss
+from models.SpaTrackV2.utils.pose_enc import pose_encoding_to_extri_intri
 import torch.nn.functional as F
 
 class FrontTracker(nn.Module, PyTorchModelHubMixin):
@@ -8,14 +8,14 @@ import torch
 import torch.nn as nn
 from huggingface_hub import PyTorchModelHubMixin  # used for model hub
 
-from models.vggt.vggt.models.aggregator import Aggregator
-from models.vggt.vggt.heads.camera_head import CameraHead
-from models.vggt.vggt.heads.dpt_head import DPTHead
-from models.vggt.vggt.heads.track_head import TrackHead
-from models.vggt.vggt.utils.loss import compute_loss
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from models.SpaTrackV2.models.vggt4track.models.aggregator import Aggregator
+from models.SpaTrackV2.models.vggt4track.heads.camera_head import CameraHead
+from models.SpaTrackV2.models.vggt4track.heads.dpt_head import DPTHead
+from models.SpaTrackV2.models.vggt4track.heads.track_head import TrackHead
+from models.SpaTrackV2.models.vggt4track.utils.loss import compute_loss
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
-from models.vggt.vggt.utils.load_fn import preprocess_image
+from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
 from einops import rearrange
 import torch.nn.functional as F
 
@@ -15,7 +15,7 @@ from models.moge.train.losses import (
 import torch.nn.functional as F
 from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
 from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
-from models.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
+from models.SpaTrackV2.models.vggt4track.utils.pose_enc import pose_encoding_to_extri_intri, extri_intri_to_pose_encoding
 
 def compute_loss(predictions, annots):
     """
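
Note: the import hunks above are a mechanical migration from the deleted local vggt package (see the setup.py removal below) to the vendored copy under models/SpaTrackV2/models/vggt4track, with the front tracker's loss and pose_enc landing in models/SpaTrackV2/utils instead. A hypothetical helper that performs the same rewrite (not part of the commit):

    import pathlib
    import re

    # Rewrites old module paths to the vendored location, mirroring these hunks.
    OLD = re.compile(r"models\.vggt\.vggt")
    NEW = "models.SpaTrackV2.models.vggt4track"
    for py in pathlib.Path(".").rglob("*.py"):
        src = py.read_text()
        if OLD.search(src):
            py.write_text(OLD.sub(NEW, src))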
@@ -1,8 +0,0 @@
-from setuptools import setup, find_packages
-
-setup(
-    name='vggt',
-    version='0.1',
-    packages=find_packages(),
-    description='vggt local package',
-)