diff --git a/models/SpaTrackV2/models/predictor.py b/models/SpaTrackV2/models/predictor.py index 10dd942..a6a2861 100644 --- a/models/SpaTrackV2/models/predictor.py +++ b/models/SpaTrackV2/models/predictor.py @@ -27,7 +27,8 @@ class Predictor(torch.nn.Module): def to(self, device: Union[str, torch.device]): self.spatrack.to(device) - self.spatrack.base_model.to(device) + if self.spatrack.base_model is not None: + self.spatrack.base_model.to(device) @classmethod def from_pretrained(