imwithye committed on
Commit
297d0f8
·
1 Parent(s): 54c4741

save best model

Browse files
rlcube/models/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
rlcube/rlcube/models/models.py CHANGED
@@ -73,6 +73,12 @@ class DNN(nn.Module):
73
  policy = self.fc_policy(out)
74
  return TensorDict({"value": value, "policy": policy}, batch_size=batch_size)
75
 
 
 
 
 
 
 
76
 
77
  if __name__ == "__main__":
78
  print("Testing RewardNet")
 
73
  policy = self.fc_policy(out)
74
  return TensorDict({"value": value, "policy": policy}, batch_size=batch_size)
75
 
76
+ def save(self, filepath: str):
77
+ torch.save(self.state_dict(), filepath)
78
+
79
+ def load(self, filepath: str):
80
+ self.load_state_dict(torch.load(filepath))
81
+
82
 
83
  if __name__ == "__main__":
84
  print("Testing RewardNet")
rlcube/rlcube/train/train.py CHANGED
@@ -23,13 +23,17 @@ def train(epochs: int = 100):
23
  print("Number of epochs:", epochs)
24
  print()
25
 
26
- dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
27
  reward = Reward().to(device)
28
- net = DNN().to(device)
 
 
 
29
  optimizer = torch.optim.RMSprop(net.parameters(), lr=0.0001)
30
  value_loss_fn = torch.nn.MSELoss()
31
  policy_loss_fn = torch.nn.CrossEntropyLoss()
32
 
 
33
  for epoch in range(epochs):
34
  epoch_loss = 0
35
  print(f"Training Epoch {epoch}")
@@ -59,7 +63,12 @@ def train(epochs: int = 100):
59
  optimizer.zero_grad()
60
  loss.backward()
61
  optimizer.step()
62
- print(f"Epoch {epoch} loss: {epoch_loss / len(dataloader)}")
 
 
 
 
 
63
 
64
 
65
  if __name__ == "__main__":
 
23
  print("Number of epochs:", epochs)
24
  print()
25
 
26
+ dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)
27
  reward = Reward().to(device)
28
+ net = DNN()
29
+ if os.path.exists("models/model_best.pth"):
30
+ net.load("models/model_best.pth")
31
+ net = net.to(device)
32
  optimizer = torch.optim.RMSprop(net.parameters(), lr=0.0001)
33
  value_loss_fn = torch.nn.MSELoss()
34
  policy_loss_fn = torch.nn.CrossEntropyLoss()
35
 
36
+ best_loss = float("inf")
37
  for epoch in range(epochs):
38
  epoch_loss = 0
39
  print(f"Training Epoch {epoch}")
 
63
  optimizer.zero_grad()
64
  loss.backward()
65
  optimizer.step()
66
+ epoch_loss /= len(dataloader)
67
+ if epoch_loss < best_loss:
68
+ best_loss = epoch_loss
69
+ print(f"Saving model at epoch {epoch}")
70
+ net.save("models/model_best.pth")
71
+ print(f"Epoch {epoch} loss: {epoch_loss}")
72
 
73
 
74
  if __name__ == "__main__":