imwithye committed on
Commit bfe342c · 1 Parent(s): a48903f
Files changed (1)
  1. rlcube/rlcube/train/train.py +15 -17
rlcube/rlcube/train/train.py CHANGED
@@ -29,7 +29,7 @@ def train(epochs: int = 100):
     if os.path.exists("models/model_best.pth"):
         net.load("models/model_best.pth")
     net = net.to(device)
-    optimizer = torch.optim.Adam(net.parameters(), lr=0.000001)
+    optimizer = torch.optim.RMSprop(net.parameters(), lr=0.00001)
     value_loss_fn = torch.nn.MSELoss(reduction="none")
     policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
 
@@ -43,25 +43,23 @@ def train(epochs: int = 100):
 
         values, policies = net(states)
 
-        batch_size = neighbors.shape[0]
-        neighbors_reshaped = neighbors.view(-1, 24, 6)
-        neighbors_values, _ = net(neighbors_reshaped)
-        rewards_out = reward(neighbors_reshaped)
+        with torch.no_grad():
+            batch_size = neighbors.shape[0]
+            neighbors_reshaped = neighbors.view(-1, 24, 6)
+            values_out, _ = net(neighbors_reshaped)
+            rewards_out = reward(neighbors_reshaped)
 
-        neighbors_values = neighbors_values.view(batch_size, 12, -1)
-        neighbors_rewards = rewards_out.view(batch_size, 12, -1)
+            nvalues = values_out.view(batch_size, 12, -1)
+            nrewards = rewards_out.view(batch_size, 12, -1)
 
-        target_values, indices = (neighbors_values + neighbors_rewards).max(dim=1)
-        indices = indices.reshape(-1)
+            target_values, indices = (nvalues + nrewards).max(dim=1)
+            target_values = target_values.detach()
+            indices = indices.reshape(-1)
+            weights = 1 / D.reshape(-1).detach()
 
-        loss_v = (
-            value_loss_fn(values, target_values).reshape(-1)
-            / D.reshape(-1).detach()
-        )
-        loss_p = (
-            policy_loss_fn(policies, indices).reshape(-1) / D.reshape(-1).detach()
-        )
-        loss = (loss_v + loss_p).mean()
+        loss_v = value_loss_fn(values, target_values).reshape(-1) * weights
+        loss_p = policy_loss_fn(policies, indices).reshape(-1) * weights
+        loss = (0.2 * loss_v + 0.8 * loss_p).mean()
         epoch_loss += loss.item()
         optimizer.zero_grad()
         loss.backward()
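
For context, the revised training step can be read as one self-contained routine: the neighbor evaluation that produces the bootstrap targets now runs under torch.no_grad(), each sample's value and policy losses are weighted by 1/D, the two terms are mixed 0.2 (value) / 0.8 (policy) before backpropagation, and the optimizer is switched from Adam (lr=0.000001) to RMSprop (lr=0.00001). The sketch below assembles these pieces into a runnable example; DummyNet, reward_fn, training_step, and the tensor shapes in the invocation are illustrative assumptions, not the repo's actual model, reward, or data pipeline.

import torch
import torch.nn as nn


class DummyNet(nn.Module):
    """Illustrative stand-in for the repo's network: a value head and a 12-way policy head."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(24 * 6, 64)
        self.value_head = nn.Linear(64, 1)
        self.policy_head = nn.Linear(64, 12)

    def forward(self, x):
        h = torch.relu(self.body(x.flatten(1).float()))
        return self.value_head(h), self.policy_head(h)


def reward_fn(states):
    # Stand-in for the repo's reward(): one scalar reward per state.
    return torch.rand(states.shape[0], 1)


def training_step(net, optimizer, states, neighbors, D,
                  value_loss_fn, policy_loss_fn):
    values, policies = net(states)

    # Bootstrapped targets come from the 12 neighbor states; the commit wraps
    # this branch in torch.no_grad() so no gradients flow through the targets.
    with torch.no_grad():
        batch_size = neighbors.shape[0]
        neighbors_reshaped = neighbors.view(-1, 24, 6)
        neighbor_values, _ = net(neighbors_reshaped)
        neighbor_rewards = reward_fn(neighbors_reshaped)
        nvalues = neighbor_values.view(batch_size, 12, -1)
        nrewards = neighbor_rewards.view(batch_size, 12, -1)
        # The best neighbor supplies both the value target and the policy label.
        target_values, indices = (nvalues + nrewards).max(dim=1)
        target_values = target_values.detach()  # redundant under no_grad(), kept explicit as in the commit
        indices = indices.reshape(-1)
        # Per-sample weight 1/D: samples far from the solved state count less.
        weights = 1.0 / D.reshape(-1).float()

    # Weighted per-sample losses, mixed 0.2 (value) / 0.8 (policy) as in the commit.
    loss_v = value_loss_fn(values, target_values).reshape(-1) * weights
    loss_p = policy_loss_fn(policies, indices).reshape(-1) * weights
    loss = (0.2 * loss_v + 0.8 * loss_p).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Example invocation with random tensors; the shapes are assumptions for
# illustration (24x6 one-hot cube state, 12 neighbors per sample).
net = DummyNet()
optimizer = torch.optim.RMSprop(net.parameters(), lr=0.00001)
value_loss_fn = nn.MSELoss(reduction="none")
policy_loss_fn = nn.CrossEntropyLoss(reduction="none")

states = torch.rand(32, 24, 6)          # batch of cube states
neighbors = torch.rand(32, 12, 24, 6)   # 12 neighbor states per sample
D = torch.randint(1, 15, (32, 1))       # scramble depth per sample

training_step(net, optimizer, states, neighbors, D,
              value_loss_fn, policy_loss_fn)

Keeping the neighbor forward pass inside torch.no_grad() avoids building an autograd graph for the target branch, which saves memory and makes the stop-gradient on the targets explicit rather than relying on a trailing detach alone.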