first-commit

This commit is contained in:
xiaoyuxi 2025-07-08 17:05:25 +08:00
parent c2ed617c97
commit e0569f6bd7
2 changed files with 11 additions and 6 deletions

View File

@ -42,7 +42,8 @@ if __name__ == "__main__":
fps = int(args.fps)
mask_dir = args.data_dir + f"/{args.video_name}.png"
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts",
"spatrack_front.pth") #, force_download=True)
VGGT_DIR = os.environ["VGGT_DIR"]
assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist"
front_track = VGGT4Track()

View File

@ -40,6 +40,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
resolution=518,
max_len=600, # the maximum video length we can preprocess,
track_num=768,
moge_as_base=False,
):
self.chunk_size = chunk_size
@ -51,12 +52,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
super(SpaTrack2, self).__init__()
if os.path.exists(backbone_ckpt_dir)==False:
base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
if moge_as_base:
if os.path.exists(backbone_ckpt_dir)==False:
base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
else:
checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
base_model = MoGeModel(**checkpoint["model_config"])
base_model.load_state_dict(checkpoint['model'])
else:
checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
base_model = MoGeModel(**checkpoint["model_config"])
base_model.load_state_dict(checkpoint['model'])
base_model = None
# avoid the base_model is a member of SpaTrack2
object.__setattr__(self, 'base_model', base_model)