PAPERMD
This commit is contained in:
parent
ee685d0f59
commit
b71363a727
@ -27,7 +27,8 @@ class Predictor(torch.nn.Module):
|
||||
|
||||
def to(self, device: Union[str, torch.device]):
|
||||
self.spatrack.to(device)
|
||||
self.spatrack.base_model.to(device)
|
||||
if self.spatrack.base_model is not None:
|
||||
self.spatrack.base_model.to(device)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user