From e0569f6bd76b1defe5fca37df6239969647f6d03 Mon Sep 17 00:00:00 2001 From: xiaoyuxi Date: Tue, 8 Jul 2025 17:05:25 +0800 Subject: [PATCH] first-commit --- inference.py | 3 ++- models/SpaTrackV2/models/SpaTrack.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/inference.py b/inference.py index ebe0747..a95c78c 100644 --- a/inference.py +++ b/inference.py @@ -42,7 +42,8 @@ if __name__ == "__main__": fps = int(args.fps) mask_dir = args.data_dir + f"/{args.video_name}.png" - os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True) + os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", + "spatrack_front.pth") #, force_download=True) VGGT_DIR = os.environ["VGGT_DIR"] assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist" front_track = VGGT4Track() diff --git a/models/SpaTrackV2/models/SpaTrack.py b/models/SpaTrackV2/models/SpaTrack.py index 2c4b832..c369733 100644 --- a/models/SpaTrackV2/models/SpaTrack.py +++ b/models/SpaTrackV2/models/SpaTrack.py @@ -40,6 +40,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin): resolution=518, max_len=600, # the maximum video length we can preprocess, track_num=768, + moge_as_base=False, ): self.chunk_size = chunk_size @@ -51,12 +52,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin): backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None) super(SpaTrack2, self).__init__() - if os.path.exists(backbone_ckpt_dir)==False: - base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl') + if moge_as_base: + if os.path.exists(backbone_ckpt_dir)==False: + base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl') + else: + checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True) + base_model = MoGeModel(**checkpoint["model_config"]) + base_model.load_state_dict(checkpoint['model']) else: - checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True) - base_model = MoGeModel(**checkpoint["model_config"]) - base_model.load_state_dict(checkpoint['model']) + base_model = None # avoid the base_model is a member of SpaTrack2 object.__setattr__(self, 'base_model', base_model)