save best model
Files changed:
- rlcube/models/.gitignore +2 -0
- rlcube/rlcube/models/models.py +6 -0
- rlcube/rlcube/train/train.py +12 -3
rlcube/models/.gitignore (ADDED)

@@ -0,0 +1,2 @@
+*
+!.gitignore
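
These two rules ignore everything inside rlcube/models/ except the .gitignore itself, so the directory stays tracked in version control while the checkpoints written into it (e.g. model_best.pth) are kept out of the repository.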
rlcube/rlcube/models/models.py (CHANGED)

@@ -73,6 +73,12 @@ class DNN(nn.Module):
         policy = self.fc_policy(out)
         return TensorDict({"value": value, "policy": policy}, batch_size=batch_size)
 
+    def save(self, filepath: str):
+        torch.save(self.state_dict(), filepath)
+
+    def load(self, filepath: str):
+        self.load_state_dict(torch.load(filepath))
+
 
 if __name__ == "__main__":
     print("Testing RewardNet")
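
A minimal round trip with the new helpers might look like the sketch below; the checkpoint path is an assumption borrowed from train.py further down.

net = DNN()
net.save("models/model_best.pth")  # persists only the state_dict, not the whole module
net.load("models/model_best.pth")  # restores the weights in place

Note that torch.load restores tensors to whatever device they were saved from by default, so a checkpoint written after net.to(device) on a GPU machine would need a map_location argument to load on a CPU-only machine; the helper as written does not pass one.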
rlcube/rlcube/train/train.py (CHANGED)

@@ -23,13 +23,17 @@ def train(epochs: int = 100):
     print("Number of epochs:", epochs)
     print()
 
-    dataloader = DataLoader(dataset, batch_size=
+    dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)
     reward = Reward().to(device)
-    net = DNN()
+    net = DNN()
+    if os.path.exists("models/model_best.pth"):
+        net.load("models/model_best.pth")
+    net = net.to(device)
     optimizer = torch.optim.RMSprop(net.parameters(), lr=0.0001)
     value_loss_fn = torch.nn.MSELoss()
     policy_loss_fn = torch.nn.CrossEntropyLoss()
 
+    best_loss = float("inf")
     for epoch in range(epochs):
         epoch_loss = 0
         print(f"Training Epoch {epoch}")

@@ -59,7 +63,12 @@ def train(epochs: int = 100):
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-
+        epoch_loss /= len(dataloader)
+        if epoch_loss < best_loss:
+            best_loss = epoch_loss
+            print(f"Saving model at epoch {epoch}")
+            net.save("models/model_best.pth")
+        print(f"Epoch {epoch} loss: {epoch_loss}")
 
 
 if __name__ == "__main__":
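
Condensed, the best-model checkpointing pattern this commit adds looks roughly like the following; the per-batch loss accumulation is implied by the hunks rather than shown, so treat it as an assumption.

best_loss = float("inf")                  # best average epoch loss seen so far
for epoch in range(epochs):
    epoch_loss = 0
    for batch in dataloader:
        ...                               # forward pass, value/policy losses, backward, step
        epoch_loss += loss.item()         # assumed accumulation, not visible in the hunks
    epoch_loss /= len(dataloader)         # mean per-batch loss for the epoch
    if epoch_loss < best_loss:            # only persist improvements
        best_loss = epoch_loss
        net.save("models/model_best.pth")
    print(f"Epoch {epoch} loss: {epoch_loss}")

Since len(dataloader) is the number of batches, epoch_loss ends up as the average per-batch loss, and that is the quantity compared against best_loss. The new os.path.exists call also assumes os is already imported in train.py, as the hunks shown do not add the import.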