rlcube/rlcube/train/train.py (CHANGED: +15 -17)
@@ -29,7 +29,7 @@ def train(epochs: int = 100):
     if os.path.exists("models/model_best.pth"):
         net.load("models/model_best.pth")
     net = net.to(device)
-    optimizer = torch.optim.
+    optimizer = torch.optim.RMSprop(net.parameters(), lr=0.00001)
     value_loss_fn = torch.nn.MSELoss(reduction="none")
     policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

@@ -43,25 +43,23 @@ def train(epochs: int = 100):

         values, policies = net(states)

-
-
-
-
+        with torch.no_grad():
+            batch_size = neighbors.shape[0]
+            neighbors_reshaped = neighbors.view(-1, 24, 6)
+            values_out, _ = net(neighbors_reshaped)
+            rewards_out = reward(neighbors_reshaped)

-
-
+        nvalues = values_out.view(batch_size, 12, -1)
+        nrewards = rewards_out.view(batch_size, 12, -1)

-
-
+        target_values, indices = (nvalues + nrewards).max(dim=1)
+        target_values = target_values.detach()
+        indices = indices.reshape(-1)
+        weights = 1 / D.reshape(-1).detach()

-        loss_v = (
-
-
-        )
-        loss_p = (
-            policy_loss_fn(policies, indices).reshape(-1) / D.reshape(-1).detach()
-        )
-        loss = (loss_v + loss_p).mean()
+        loss_v = value_loss_fn(values, target_values).reshape(-1) * weights
+        loss_p = policy_loss_fn(policies, indices).reshape(-1) * weights
+        loss = (0.2 * loss_v + 0.8 * loss_p).mean()
         epoch_loss += loss.item()
         optimizer.zero_grad()
         loss.backward()
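
For readers skimming the diff, here is a self-contained sketch of the training step the new hunk implements: a value-iteration style update in which each state's target value is the best (reward + predicted value) over its 12 one-move neighbors, the index of that best neighbor is reused as the policy label, and both losses are weighted by 1/D before the 0.2/0.8 value/policy mix. Everything outside the diffed lines is an assumption made for illustration: the TinyNet module, the reward() stub, the 24x6 one-hot state encoding, and the dummy tensor shapes stand in for the real code elsewhere in this repo.

import torch

# Stand-ins so the sketch runs on its own; the real net and reward() differ.
class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.body = torch.nn.Linear(24 * 6, 32)
        self.value_head = torch.nn.Linear(32, 1)    # state-value estimate
        self.policy_head = torch.nn.Linear(32, 12)  # logits over the 12 moves

    def forward(self, x):
        h = torch.relu(self.body(x.flatten(1)))
        return self.value_head(h), self.policy_head(h)

def reward(states):
    # Placeholder: -1 per state, shape (N, 1); the repo's reward() presumably
    # treats the solved state specially.
    return -torch.ones(states.shape[0], 1)

net = TinyNet()
value_loss_fn = torch.nn.MSELoss(reduction="none")
policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

# Dummy batch: 8 scrambled states, their 12 one-move neighbors, scramble depths D.
states = torch.randint(0, 2, (8, 24, 6)).float()
neighbors = torch.randint(0, 2, (8, 12, 24, 6)).float()
D = torch.randint(1, 21, (8, 1)).float()

values, policies = net(states)

# Evaluate all neighbors in one forward pass, without tracking gradients.
with torch.no_grad():
    batch_size = neighbors.shape[0]
    neighbors_reshaped = neighbors.view(-1, 24, 6)
    values_out, _ = net(neighbors_reshaped)
    rewards_out = reward(neighbors_reshaped)

nvalues = values_out.view(batch_size, 12, -1)
nrewards = rewards_out.view(batch_size, 12, -1)

# Bellman-style target: best neighbor value plus its reward; the argmax move
# doubles as the supervised policy target.
target_values, indices = (nvalues + nrewards).max(dim=1)
target_values = target_values.detach()
indices = indices.reshape(-1)
weights = 1 / D.reshape(-1).detach()  # shallow scrambles get larger weight

loss_v = value_loss_fn(values, target_values).reshape(-1) * weights
loss_p = policy_loss_fn(policies, indices).reshape(-1) * weights
loss = (0.2 * loss_v + 0.8 * loss_p).mean()

The 1/D weighting and the max over one-move neighbors follow the Autodidactic Iteration recipe for Rubik's cube solvers, where targets generated close to the solved state are trusted more than those from deep scrambles; the 0.2/0.8 split is simply the mix chosen in this commit.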