first-commit
This commit is contained in:
parent
c2ed617c97
commit
e0569f6bd7
@ -42,7 +42,8 @@ if __name__ == "__main__":
|
|||||||
fps = int(args.fps)
|
fps = int(args.fps)
|
||||||
mask_dir = args.data_dir + f"/{args.video_name}.png"
|
mask_dir = args.data_dir + f"/{args.video_name}.png"
|
||||||
|
|
||||||
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
|
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts",
|
||||||
|
"spatrack_front.pth") #, force_download=True)
|
||||||
VGGT_DIR = os.environ["VGGT_DIR"]
|
VGGT_DIR = os.environ["VGGT_DIR"]
|
||||||
assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist"
|
assert os.path.exists(VGGT_DIR), f"VGGT_DIR {VGGT_DIR} does not exist"
|
||||||
front_track = VGGT4Track()
|
front_track = VGGT4Track()
|
||||||
|
|||||||
@ -40,6 +40,7 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
|
|||||||
resolution=518,
|
resolution=518,
|
||||||
max_len=600, # the maximum video length we can preprocess,
|
max_len=600, # the maximum video length we can preprocess,
|
||||||
track_num=768,
|
track_num=768,
|
||||||
|
moge_as_base=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
@ -51,12 +52,15 @@ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
|
|||||||
backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
|
backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
|
||||||
|
|
||||||
super(SpaTrack2, self).__init__()
|
super(SpaTrack2, self).__init__()
|
||||||
if os.path.exists(backbone_ckpt_dir)==False:
|
if moge_as_base:
|
||||||
base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
|
if os.path.exists(backbone_ckpt_dir)==False:
|
||||||
|
base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
|
||||||
|
base_model = MoGeModel(**checkpoint["model_config"])
|
||||||
|
base_model.load_state_dict(checkpoint['model'])
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
|
base_model = None
|
||||||
base_model = MoGeModel(**checkpoint["model_config"])
|
|
||||||
base_model.load_state_dict(checkpoint['model'])
|
|
||||||
# avoid the base_model is a member of SpaTrack2
|
# avoid the base_model is a member of SpaTrack2
|
||||||
object.__setattr__(self, 'base_model', base_model)
|
object.__setattr__(self, 'base_model', base_model)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user