antoniaebner commited on
Commit
163605e
·
1 Parent(s): 54fc11b

debug model loading

Browse files
Files changed (1) hide show
  1. src/model.py +3 -1
src/model.py CHANGED
@@ -101,7 +101,9 @@ class Tox21SNNClassifier(nn.Module):
101
  return x # x.view(x.size(0), self.num_tasks)
102
 
103
  def load_model(self, path: str):
104
- state_dict = torch.load(path, weights_only=False)["model"]
 
 
105
  self.load_state_dict(state_dict)
106
  self.eval()
107
 
 
101
  return x # x.view(x.size(0), self.num_tasks)
102
 
103
  def load_model(self, path: str):
104
+ state_dict = torch.load(
105
+ path, weights_only=False, map_location=torch.device("cpu")
106
+ )["model"]
107
  self.load_state_dict(state_dict)
108
  self.eval()
109