imwithye committed on
Commit 5b4ee5e · Parent: 297d0f8
Files changed (2)
  1. rlcube/cube2.ipynb +75 -4
  2. rlcube/rlcube/train/train.py +6 -6
rlcube/cube2.ipynb CHANGED
@@ -2,14 +2,85 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 60,
    "id": "624c83c1",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "DNN(\n",
+       "  (fc_in): Linear(in_features=144, out_features=512, bias=True)\n",
+       "  (residual_blocks): ModuleList(\n",
+       "    (0-3): 4 x ResidualBlock(\n",
+       "      (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+       "      (fc1): Linear(in_features=512, out_features=1024, bias=True)\n",
+       "      (ln2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+       "      (fc2): Linear(in_features=1024, out_features=512, bias=True)\n",
+       "    )\n",
+       "  )\n",
+       "  (fc_value): Sequential(\n",
+       "    (0): Linear(in_features=512, out_features=64, bias=True)\n",
+       "    (1): ReLU()\n",
+       "    (2): Linear(in_features=64, out_features=1, bias=True)\n",
+       "  )\n",
+       "  (fc_policy): Sequential(\n",
+       "    (0): Linear(in_features=512, out_features=64, bias=True)\n",
+       "    (1): ReLU()\n",
+       "    (2): Linear(in_features=64, out_features=12, bias=True)\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 60,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
-    "from rlcube.train.train import train\n",
+    "from rlcube.models.models import DNN\n",
+    "from rlcube.envs.cube2 import Cube2\n",
+    "import numpy as np\n",
+    "import torch\n",
     "\n",
-    "train()"
+    "net = DNN()\n",
+    "net.load(\"models/model_best.pth\")\n",
+    "net.eval()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 61,
+   "id": "16736f3a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[ 0.0166],\n",
+      "        [ 1.0147],\n",
+      "        [ 1.1610],\n",
+      "        [ 0.9844],\n",
+      "        [-0.0268],\n",
+      "        [ 1.1526]], grad_fn=<AddmmBackward0>)\n",
+      "tensor([10,  1,  5,  0, 10,  1])\n"
+     ]
+    }
+   ],
+   "source": [
+    "env = Cube2()\n",
+    "obs, _ = env.reset()\n",
+    "obs1, _, _, _, _ = env.step(0)\n",
+    "obs2, _, _, _, _ = env.step(0)\n",
+    "obs3, _, _, _, _ = env.step(2)\n",
+    "obs4, _, _, _, _ = env.step(2)\n",
+    "for _ in range(10):\n",
+    "    obsMany, _, _, _, _ = env.step(env.action_space.sample())\n",
+    "batched_obs = torch.tensor(np.array([obs, obs1, obs2, obs3, obs4, obsMany]), dtype=torch.float32)\n",
+    "out = net(batched_obs)\n",
+    "print(out[\"value\"])\n",
+    "print(torch.argmax(out[\"policy\"], dim=1))"
    ]
   }
  ],
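Note on the model printed by the first new cell: the repr pins down the layer shapes (a 144-dimensional observation, presumably the 24 stickers of the 2x2x2 cube one-hot over 6 colours; four 512-wide residual MLP blocks; a scalar value head and a 12-way policy head). For readers without rlcube/models/models.py at hand, a minimal sketch of a module that would print this repr is below. The activations, the placement of the skip connection, and the constructor defaults are assumptions, since the repr does not show forward(), and the project's custom net.load() helper is omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, dim: int = 512, hidden: int = 1024):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden)
        self.ln2 = nn.LayerNorm(hidden)
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm MLP with a skip connection (assumed: the repr only lists
        # the submodules, not how forward() wires them together).
        h = F.relu(self.fc1(self.ln1(x)))
        h = self.fc2(self.ln2(h))
        return x + h

class DNN(nn.Module):
    def __init__(self, obs_dim: int = 144, dim: int = 512,
                 n_blocks: int = 4, n_actions: int = 12):
        super().__init__()
        self.fc_in = nn.Linear(obs_dim, dim)
        self.residual_blocks = nn.ModuleList(ResidualBlock(dim) for _ in range(n_blocks))
        self.fc_value = nn.Sequential(nn.Linear(dim, 64), nn.ReLU(), nn.Linear(64, 1))
        self.fc_policy = nn.Sequential(nn.Linear(dim, 64), nn.ReLU(), nn.Linear(64, n_actions))

    def forward(self, x: torch.Tensor) -> dict:
        h = F.relu(self.fc_in(x))
        for block in self.residual_blocks:
            h = block(h)
        # The notebook reads the output as out["value"] and out["policy"].
        return {"value": self.fc_value(h), "policy": self.fc_policy(h)}

Printing DNN() from this sketch reproduces the layer list shown in the cell output; the trained weights themselves still come from models/model_best.pth via the repository's own load helper.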
rlcube/rlcube/train/train.py CHANGED
@@ -29,9 +29,9 @@ def train(epochs: int = 100):
     if os.path.exists("models/model_best.pth"):
         net.load("models/model_best.pth")
     net = net.to(device)
-    optimizer = torch.optim.RMSprop(net.parameters(), lr=0.0001)
-    value_loss_fn = torch.nn.MSELoss()
-    policy_loss_fn = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.RMSprop(net.parameters(), lr=0.000001)
+    value_loss_fn = torch.nn.MSELoss(reduction="none")
+    policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
 
     best_loss = float("inf")
     for epoch in range(epochs):
@@ -56,9 +56,9 @@ def train(epochs: int = 100):
         target_values, indices = (neighbors_values + neighbors_rewards).max(dim=1)
         indices = indices.reshape(-1)
 
-        loss_v = value_loss_fn(values, target_values)
-        loss_p = policy_loss_fn(policies, indices)
-        loss = loss_v + loss_p
+        loss_v = value_loss_fn(values, target_values).reshape(-1) / D.reshape(-1).detach()
+        loss_p = policy_loss_fn(policies, indices).reshape(-1) / D.reshape(-1).detach()
+        loss = (loss_v + loss_p).mean()
         epoch_loss += loss.item()
         optimizer.zero_grad()
         loss.backward()
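The substantive change in train.py is that both losses switch to reduction="none" and each sample's loss is divided by D before the mean (the learning-rate drop from 1e-4 to 1e-6 is a separate tweak in the same hunk). D is not visible in this diff; presumably it is the scramble depth of each training state, so states close to the solved cube carry more weight, in the spirit of the 1/depth weighting used in autodidactic-iteration-style training. A small self-contained sketch of the same pattern, with hypothetical tensor names and random data, follows.

import torch

def weighted_loss(values, target_values, policies, target_indices, depths):
    # Per-sample losses (reduction="none"), scaled by 1/depth before averaging,
    # mirroring the loss_v/loss_p lines in the diff above. `depths` stands in
    # for D; all argument names here are hypothetical.
    value_loss_fn = torch.nn.MSELoss(reduction="none")
    policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    w = 1.0 / depths.reshape(-1).detach()  # down-weight deeply scrambled states
    loss_v = value_loss_fn(values, target_values).reshape(-1) * w
    loss_p = policy_loss_fn(policies, target_indices) * w
    return (loss_v + loss_p).mean()

# Toy batch: 6 states at scramble depths 1..6, 12 possible moves.
values = torch.randn(6, 1)
target_values = torch.randn(6, 1)
policies = torch.randn(6, 12)
target_indices = torch.randint(0, 12, (6,))
depths = torch.arange(1, 7, dtype=torch.float32)
print(weighted_loss(values, target_values, policies, target_indices, depths))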