imwithye committed
Commit 5f0be19 · 1 Parent(s): bfe342c

close to solving!

rlcube/cube2.ipynb CHANGED
@@ -40,6 +40,7 @@
    "source": [
     "from rlcube.models.models import DNN\n",
     "from rlcube.envs.cube2 import Cube2Env\n",
+    "import torch\n",
     "\n",
     "net = DNN()\n",
     "net.load(\"models/model_best.pth\")\n",
@@ -48,190 +49,81 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "id": "16736f3a",
+   "execution_count": 16,
+   "id": "defde44e",
    "metadata": {},
    "outputs": [
     {
-     "name": "stderr",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
-      "100%|██████████| 300/300 [00:02<00:00, 132.06it/s]\n"
+      "[2, 3, 7, 6, 8, 6, 3, 2, 2, 5]\n",
+      "tensor([[ 1.1924],\n",
+      "        [ 0.0826],\n",
+      "        [ 1.0202],\n",
+      "        [ 0.0826],\n",
+      "        [ 1.1121],\n",
+      "        [-0.0302],\n",
+      "        [-1.5963],\n",
+      "        [-0.0302],\n",
+      "        [-1.3707],\n",
+      "        [-2.4068]], grad_fn=<AddmmBackward0>)\n"
      ]
     }
    ],
    "source": [
-    "from rlcube.models.search import MonteCarloTree\n",
+    "import numpy as np\n",
     "\n",
     "env = Cube2Env()\n",
+    "\n",
     "actions = []\n",
-    "for _ in range(3):\n",
+    "obs = []\n",
+    "for _ in range(10):\n",
     "    action = env.action_space.sample()\n",
-    "    actions.append(action)\n",
+    "    actions.append(action.item())\n",
    "    env.step(action)\n",
-    "tree = MonteCarloTree(env.obs())"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "id": "aee2a911",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "node = tree.root"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "048f58c9",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[np.int64(8), np.int64(1), np.int64(4)]"
-      ]
-     },
-     "execution_count": 4,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "actions"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "id": "00994021",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([3.4725e+00, 3.3189e+00, 1.2619e-02, 3.1231e-01, 1.1286e-02, 2.5817e-02,\n",
-       "        1.6722e-02, 2.1334e-02, 3.4603e+00, 7.5021e-02, 2.5891e-02, 2.8712e-03])"
-      ]
-     },
-     "execution_count": 5,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "node.u()"
+    "    obs.append(env.obs())\n",
+    "\n",
+    "obs = torch.tensor(np.array(obs), dtype=torch.float32)\n",
+    "values, policies = net(obs)\n",
+    "print(actions)\n",
+    "print(values)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
-   "id": "fb9ac54c",
+   "execution_count": 18,
+   "id": "cae20b12",
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       "defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
-       "            {0: 276,\n",
-       "             1: 7,\n",
-       "             2: 0,\n",
-       "             3: 0,\n",
-       "             4: 0,\n",
-       "             5: 0,\n",
-       "             6: 0,\n",
-       "             7: 0,\n",
-       "             8: 16,\n",
-       "             9: 0,\n",
-       "             10: 0,\n",
-       "             11: 0})"
-      ]
-     },
-     "execution_count": 6,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "node.N"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "id": "2f8a09d1",
-   "metadata": {},
-   "outputs": [
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 14%|█▍        | 43/300 [00:00<00:02, 127.98it/s]"
+     ]
+    },
     {
-     "data": {
-      "text/plain": [
-       "defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
-       "            {0: tensor([3.4720]),\n",
-       "             1: tensor([1.8959]),\n",
-       "             2: 0,\n",
-       "             3: 0,\n",
-       "             4: 0,\n",
-       "             5: 0,\n",
-       "             6: 0,\n",
-       "             7: 0,\n",
-       "             8: tensor([2.7285]),\n",
-       "             9: 0,\n",
-       "             10: 0,\n",
-       "             11: 0})"
-      ]
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "node.W"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "id": "3e341459",
-   "metadata": {},
-   "outputs": [
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[4, 3, 7, 11]\n"
+     ]
+    },
     {
-     "data": {
-      "text/plain": [
-       "defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
-       "            {0: 4,\n",
-       "             1: 0,\n",
-       "             2: 0,\n",
-       "             3: 0,\n",
-       "             4: 0,\n",
-       "             5: 2,\n",
-       "             6: 0,\n",
-       "             7: 0,\n",
-       "             8: 269,\n",
-       "             9: 0,\n",
-       "             10: 0,\n",
-       "             11: 0})"
-      ]
-     },
-     "execution_count": 8,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
     }
    ],
    "source": [
-    "node.children[0].N"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "51dddf56",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "node.children[8].N"
+    "from rlcube.models.search import MonteCarloTree\n",
+    "\n",
+    "tree = MonteCarloTree(env.obs(), max_simulations=300)\n",
+    "if tree.is_solved:\n",
+    "    print([action for _, action in tree.solved_path])"
    ]
   }
  ],
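
Note: the reworked notebook now runs end to end: it batch-evaluates the net on a 10-move scramble, then searches with max_simulations=300 and prints the solved path. A minimal sketch of replaying that result, assuming the Cube2Env and MonteCarloTree APIs exactly as used in the cells above, and that env.step is deterministic:

    from rlcube.envs.cube2 import Cube2Env
    from rlcube.models.search import MonteCarloTree

    env = Cube2Env()
    for _ in range(4):                      # short scramble
        env.step(env.action_space.sample())

    tree = MonteCarloTree(env.obs(), max_simulations=300)
    if tree.is_solved:
        for _, action in tree.solved_path:  # (node, action) pairs from the root
            env.step(action)
        print(Cube2Env.from_obs(env.obs()).is_solved())  # expected: True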
rlcube/rlcube/models/models.py CHANGED
@@ -79,6 +79,31 @@ class DNN(nn.Module):
         self.load_state_dict(torch.load(filepath))
 
 
+class DNN2(nn.Module):
+    def __init__(self):
+        super(DNN2, self).__init__()
+
+        self.body = nn.Sequential(
+            nn.Linear(24 * 6, 4096), nn.ELU(), nn.Linear(4096, 2048), nn.ELU()
+        )
+        self.policy = nn.Sequential(nn.Linear(2048, 512), nn.ELU(), nn.Linear(512, 12))
+        self.value = nn.Sequential(nn.Linear(2048, 512), nn.ELU(), nn.Linear(512, 1))
+
+    def forward(self, x):
+        batch_size = x.size(0)
+        x = x.view(batch_size, -1)
+        x = self.body(x)
+        value = self.value(x)
+        policy = self.policy(x)
+        return value, policy
+
+    def save(self, filepath: str):
+        torch.save(self.state_dict(), filepath)
+
+    def load(self, filepath: str):
+        self.load_state_dict(torch.load(filepath))
+
+
 if __name__ == "__main__":
     print("Testing RewardNet")
     env = Cube2Env()
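
DNN2 keeps DNN's (value, policy) tuple interface but widens the body to 4096/2048 units. A quick shape check as a sketch; the 24 x 6 observation layout is inferred from the nn.Linear(24 * 6, 4096) input, and the batch size is arbitrary:

    import torch
    from rlcube.models.models import DNN2

    net = DNN2()
    x = torch.zeros(8, 24, 6)  # batch of 8 cube observations
    value, policy = net(x)     # forward flattens each observation to 144 values
    print(value.shape)         # torch.Size([8, 1])
    print(policy.shape)        # torch.Size([8, 12])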
rlcube/rlcube/models/search.py CHANGED
@@ -14,9 +14,9 @@ class Node:
         self.obs = torch.tensor(obs, dtype=torch.float32)
         self.parent = parent
 
-        out = net(self.obs.unsqueeze(0))
-        value = out["value"].detach()
-        policy = torch.softmax(out["policy"].detach(), dim=1)
+        value, policy = net(self.obs.unsqueeze(0))
+        value = value.detach()
+        policy = torch.softmax(policy.detach(), dim=1)
 
         self.is_solved = Cube2Env.from_obs(obs).is_solved()
         self.value = torch.tensor(1) if self.is_solved else value.view(-1)
@@ -55,6 +55,7 @@ class MonteCarloTree:
         self.root = Node(obs)
         self.nodes = [self.root]
         self.is_solved = False
+        self.solved_path = []
         self._build()
 
     def _build(self):
@@ -80,6 +81,8 @@
                 node.children[i] = child
                 self.nodes.append(child)
                 self.is_solved = self.is_solved or child.is_solved
+                if child.is_solved:
+                    self.solved_path = path + [(node, i)]
 
             # Backup
             for parent, action in reversed(path):
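
The Node change assumes net(...) returns a (value, policy) tuple, matching DNN.forward in models.py, rather than a dict keyed by "value" and "policy"; solved_path then records the root-to-goal (node, action) pairs the moment an expanded child is solved, which is what the notebook's final cell prints. A stand-alone sketch of the forward contract (observation shape assumed as in the DNN2 note above):

    import torch
    from rlcube.models.models import DNN

    net = DNN()
    obs = torch.zeros(24, 6)               # single observation
    value, policy = net(obs.unsqueeze(0))  # value: (1, 1), policy: (1, 12)
    probs = torch.softmax(policy.detach(), dim=1)
    print(value.view(-1), probs.sum())     # probabilities sum to 1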
rlcube/rlcube/train/train.py CHANGED
@@ -29,7 +29,7 @@ def train(epochs: int = 100):
     if os.path.exists("models/model_best.pth"):
         net.load("models/model_best.pth")
     net = net.to(device)
-    optimizer = torch.optim.RMSprop(net.parameters(), lr=0.00001)
+    optimizer = torch.optim.RMSprop(net.parameters(), lr=0.000001)
     value_loss_fn = torch.nn.MSELoss(reduction="none")
     policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
 
@@ -42,6 +42,8 @@ def train(epochs: int = 100):
         states, neighbors, D = states.to(device), neighbors.to(device), D.to(device)
 
         values, policies = net(states)
+        rewards = reward(states)
+        masks = torch.where(rewards > 0, 0, 1).unsqueeze(1)
 
         with torch.no_grad():
             batch_size = neighbors.shape[0]
@@ -53,7 +55,9 @@ def train(epochs: int = 100):
             nrewards = rewards_out.view(batch_size, 12, -1)
 
             target_values, indices = (nvalues + nrewards).max(dim=1)
+            target_values = target_values * masks
             target_values = target_values.detach()
+
             indices = indices.reshape(-1)
             weights = 1 / D.reshape(-1).detach()
 
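
The new mask zeroes the bootstrapped target wherever reward(states) is positive, which presumably marks already-solved states, so their value target is pinned to 0 instead of a one-step lookahead; the learning rate also drops from 1e-5 to 1e-6. A toy illustration of the masking arithmetic (numbers made up; reward is the repo's own function, assumed to return one scalar per state):

    import torch

    rewards = torch.tensor([-1.0, 1.0, -1.0])            # middle state is solved
    masks = torch.where(rewards > 0, 0, 1).unsqueeze(1)  # -> [[1], [0], [1]]
    target_values = torch.tensor([[2.3], [1.7], [0.9]])
    print(target_values * masks)                         # [[2.3], [0.0], [0.9]]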