first-commit
This commit is contained in:
parent
c2ed617c97
commit
e0569f6bd7
@ -42,7 +42,8 @@ if __name__ == "__main__":
|
||||
fps = int(args.fps)
|
||||
mask_dir = args.data_dir + f"/{args.video_name}.png"
|
||||
|
||||
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
|
||||
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts",
|
||||
"spatrack_front.pth") #, force_download=True)
|
||||
VGGT_DIR = os.environ["VGGT_DIR"]
|
||||
assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist"
|
||||
front_track = VGGT4Track()
|
||||
|
||||
@ -40,6 +40,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
|
||||
resolution=518,
|
||||
max_len=600, # the maximum video length we can preprocess,
|
||||
track_num=768,
|
||||
moge_as_base=False,
|
||||
):
|
||||
|
||||
self.chunk_size = chunk_size
|
||||
@ -51,12 +52,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
|
||||
backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
|
||||
|
||||
super(SpaTrack2, self).__init__()
|
||||
if os.path.exists(backbone_ckpt_dir)==False:
|
||||
base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
|
||||
if moge_as_base:
|
||||
if os.path.exists(backbone_ckpt_dir)==False:
|
||||
base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
|
||||
else:
|
||||
checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
|
||||
base_model = MoGeModel(**checkpoint["model_config"])
|
||||
base_model.load_state_dict(checkpoint['model'])
|
||||
else:
|
||||
checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
|
||||
base_model = MoGeModel(**checkpoint["model_config"])
|
||||
base_model.load_state_dict(checkpoint['model'])
|
||||
base_model = None
|
||||
# avoid the base_model is a member of SpaTrack2
|
||||
object.__setattr__(self, 'base_model', base_model)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user