diff --git a/inference.py b/inference.py index ebe0747..a95c78c 100644 --- a/inference.py +++ b/inference.py @@ -42,7 +42,8 @@ if __name__ == "__main__": fps = int(args.fps) mask_dir = args.data_dir + f"/{args.video_name}.png" - os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True) + os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", + "spatrack_front.pth") #, force_download=True) VGGT_DIR = os.environ["VGGT_DIR"] assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist" front_track = VGGT4Track() diff --git a/models/SpaTrackV2/models/SpaTrack.py b/models/SpaTrackV2/models/SpaTrack.py index 2c4b832..c369733 100644 --- a/models/SpaTrackV2/models/SpaTrack.py +++ b/models/SpaTrackV2/models/SpaTrack.py @@ -40,6 +40,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin): resolution=518, max_len=600, # the maximum video length we can preprocess, track_num=768, + moge_as_base=False, ): self.chunk_size = chunk_size @@ -51,12 +52,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin): backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None) super(SpaTrack2, self).__init__() - if os.path.exists(backbone_ckpt_dir)==False: - base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl') + if moge_as_base: + if os.path.exists(backbone_ckpt_dir)==False: + base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl') + else: + checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True) + base_model = MoGeModel(**checkpoint["model_config"]) + base_model.load_state_dict(checkpoint['model']) else: - checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True) - base_model = MoGeModel(**checkpoint["model_config"]) - base_model.load_state_dict(checkpoint['model']) + base_model = None # avoid the base_model is a member of SpaTrack2 object.__setattr__(self, 'base_model', base_model)