This commit is contained in:
xiaoyuxi 2025-07-08 18:11:58 +08:00
parent ee685d0f59
commit b71363a727

View File

@ -27,6 +27,7 @@ class Predictor(torch.nn.Module):
def to(self, device: Union[str, torch.device]): def to(self, device: Union[str, torch.device]):
self.spatrack.to(device) self.spatrack.to(device)
if self.spatrack.base_model is not None:
self.spatrack.base_model.to(device) self.spatrack.base_model.to(device)
@classmethod @classmethod