This commit is contained in:
xiaoyuxi 2025-07-08 18:11:58 +08:00
parent ee685d0f59
commit b71363a727

View File

@ -27,7 +27,8 @@ class Predictor(torch.nn.Module):
def to(self, device: Union[str, torch.device]):
self.spatrack.to(device)
self.spatrack.base_model.to(device)
if self.spatrack.base_model is not None:
self.spatrack.base_model.to(device)
@classmethod
def from_pretrained(