imwithye committed
Commit 4bea016 · 1 Parent(s): 5529a00
Files changed (3)
  1. rlcube/cube2.ipynb +10 -506
  2. rlcube/envs/cube2.py +12 -8
  3. rlcube/main.py +11 -1
rlcube/cube2.ipynb CHANGED
@@ -2,249 +2,15 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "id": "dff864f2",
+   "execution_count": 3,
+   "id": "624c83c1",
    "metadata": {},
    "outputs": [],
    "source": [
     "import gymnasium as gym\n",
-    "import numpy as np\n",
-    "\n",
-    "F = 0\n",
-    "B = 1\n",
-    "R = 2\n",
-    "L = 3\n",
-    "U = 4\n",
-    "D = 5\n",
-    "\n",
-    "\n",
-    "class Cube2(gym.Env):\n",
-    "    def __init__(self):\n",
-    "        super().__init__()\n",
-    "        self.action_space = gym.spaces.Discrete(12)\n",
-    "        self.observation_space = gym.spaces.Box(\n",
-    "            low=0, high=1, shape=(24, 6), dtype=np.int8\n",
-    "        )\n",
-    "        self.state = np.zeros((6, 4))\n",
-    "        self.step_count = 0\n",
-    "\n",
-    "    def reset(self, seed=None, options=None):\n",
-    "        super().reset(seed=seed, options=options)\n",
-    "        self.state = np.zeros((6, 4))\n",
-    "        self.state[0] = np.ones(4) * F\n",
-    "        self.state[1] = np.ones(4) * B\n",
-    "        self.state[2] = np.ones(4) * R\n",
-    "        self.state[3] = np.ones(4) * L\n",
-    "        self.state[4] = np.ones(4) * U\n",
-    "        self.state[5] = np.ones(4) * D\n",
-    "        self.step_count = 0\n",
-    "        return self._get_obs(), {}\n",
-    "\n",
-    "    def step(self, action):\n",
-    "        self.step_count += 1\n",
-    "        new_state = self.state.copy()\n",
+    "from envs.cube2 import Cube2\n",
     "\n",
-    "        # Front Clockwise\n",
-    "        if action == 0:\n",
-    "            new_state[F, 0] = self.state[F, 2]\n",
-    "            new_state[F, 1] = self.state[F, 0]\n",
-    "            new_state[F, 2] = self.state[F, 3]\n",
-    "            new_state[F, 3] = self.state[F, 1]\n",
-    "            new_state[R, 1] = self.state[U, 3]\n",
-    "            new_state[R, 3] = self.state[U, 1]\n",
-    "            new_state[L, 1] = self.state[D, 3]\n",
-    "            new_state[L, 3] = self.state[D, 1]\n",
-    "            new_state[U, 1] = self.state[L, 1]\n",
-    "            new_state[U, 3] = self.state[L, 3]\n",
-    "            new_state[D, 1] = self.state[R, 1]\n",
-    "            new_state[D, 3] = self.state[R, 3]\n",
-    "        # Front Counter-Clockwise\n",
-    "        elif action == 1:\n",
-    "            new_state[F, 0] = self.state[F, 1]\n",
-    "            new_state[F, 1] = self.state[F, 3]\n",
-    "            new_state[F, 2] = self.state[F, 0]\n",
-    "            new_state[F, 3] = self.state[F, 2]\n",
-    "            new_state[R, 1] = self.state[D, 1]\n",
-    "            new_state[R, 3] = self.state[D, 3]\n",
-    "            new_state[L, 1] = self.state[U, 1]\n",
-    "            new_state[L, 3] = self.state[U, 3]\n",
-    "            new_state[U, 1] = self.state[R, 3]\n",
-    "            new_state[U, 3] = self.state[R, 1]\n",
-    "            new_state[D, 1] = self.state[L, 3]\n",
-    "            new_state[D, 3] = self.state[L, 1]\n",
-    "        # Back Clockwise\n",
-    "        elif action == 2:\n",
-    "            new_state[B, 0] = self.state[B, 1]\n",
-    "            new_state[B, 1] = self.state[B, 3]\n",
-    "            new_state[B, 2] = self.state[B, 0]\n",
-    "            new_state[B, 3] = self.state[B, 2]\n",
-    "            new_state[R, 0] = self.state[D, 0]\n",
-    "            new_state[R, 2] = self.state[D, 2]\n",
-    "            new_state[L, 0] = self.state[U, 0]\n",
-    "            new_state[L, 2] = self.state[U, 2]\n",
-    "            new_state[U, 0] = self.state[R, 2]\n",
-    "            new_state[U, 2] = self.state[R, 0]\n",
-    "            new_state[D, 0] = self.state[L, 2]\n",
-    "            new_state[D, 2] = self.state[L, 0]\n",
-    "        # Back Counter-Clockwise\n",
-    "        elif action == 3:\n",
-    "            new_state[B, 0] = self.state[B, 2]\n",
-    "            new_state[B, 1] = self.state[B, 0]\n",
-    "            new_state[B, 2] = self.state[B, 3]\n",
-    "            new_state[B, 3] = self.state[B, 1]\n",
-    "            new_state[R, 0] = self.state[U, 2]\n",
-    "            new_state[R, 2] = self.state[U, 0]\n",
-    "            new_state[L, 0] = self.state[D, 2]\n",
-    "            new_state[L, 2] = self.state[D, 0]\n",
-    "            new_state[U, 0] = self.state[L, 0]\n",
-    "            new_state[U, 2] = self.state[L, 2]\n",
-    "            new_state[D, 0] = self.state[R, 0]\n",
-    "            new_state[D, 2] = self.state[R, 2]\n",
-    "        # Right Clockwise\n",
-    "        elif action == 4:\n",
-    "            new_state[F, 2] = self.state[D, 2]\n",
-    "            new_state[F, 3] = self.state[D, 3]\n",
-    "            new_state[B, 2] = self.state[U, 2]\n",
-    "            new_state[B, 3] = self.state[U, 3]\n",
-    "            new_state[R, 0] = self.state[R, 2]\n",
-    "            new_state[R, 1] = self.state[R, 0]\n",
-    "            new_state[R, 2] = self.state[R, 3]\n",
-    "            new_state[R, 3] = self.state[R, 1]\n",
-    "            new_state[U, 2] = self.state[F, 3]\n",
-    "            new_state[U, 3] = self.state[F, 2]\n",
-    "            new_state[D, 2] = self.state[B, 3]\n",
-    "            new_state[D, 3] = self.state[B, 2]\n",
-    "        # Right Counter-Clockwise\n",
-    "        elif action == 5:\n",
-    "            new_state[F, 2] = self.state[U, 3]\n",
-    "            new_state[F, 3] = self.state[U, 2]\n",
-    "            new_state[B, 2] = self.state[D, 3]\n",
-    "            new_state[B, 3] = self.state[D, 2]\n",
-    "            new_state[R, 0] = self.state[R, 1]\n",
-    "            new_state[R, 1] = self.state[R, 3]\n",
-    "            new_state[R, 2] = self.state[R, 0]\n",
-    "            new_state[R, 3] = self.state[R, 2]\n",
-    "            new_state[U, 2] = self.state[B, 2]\n",
-    "            new_state[U, 3] = self.state[B, 3]\n",
-    "            new_state[D, 2] = self.state[F, 2]\n",
-    "            new_state[D, 3] = self.state[F, 3]\n",
-    "        # Left Clockwise\n",
-    "        elif action == 6:\n",
-    "            new_state[F, 0] = self.state[U, 1]\n",
-    "            new_state[F, 1] = self.state[U, 0]\n",
-    "            new_state[B, 0] = self.state[D, 1]\n",
-    "            new_state[B, 1] = self.state[D, 0]\n",
-    "            new_state[L, 0] = self.state[L, 1]\n",
-    "            new_state[L, 1] = self.state[L, 3]\n",
-    "            new_state[L, 2] = self.state[L, 0]\n",
-    "            new_state[L, 3] = self.state[L, 2]\n",
-    "            new_state[U, 0] = self.state[B, 0]\n",
-    "            new_state[U, 1] = self.state[B, 1]\n",
-    "            new_state[D, 0] = self.state[F, 0]\n",
-    "            new_state[D, 1] = self.state[F, 1]\n",
-    "        # Left Counter-Clockwise\n",
-    "        elif action == 7:\n",
-    "            new_state[F, 0] = self.state[D, 0]\n",
-    "            new_state[F, 1] = self.state[D, 1]\n",
-    "            new_state[B, 0] = self.state[U, 0]\n",
-    "            new_state[B, 1] = self.state[U, 1]\n",
-    "            new_state[L, 0] = self.state[L, 2]\n",
-    "            new_state[L, 1] = self.state[L, 0]\n",
-    "            new_state[L, 2] = self.state[L, 3]\n",
-    "            new_state[L, 3] = self.state[L, 1]\n",
-    "            new_state[U, 0] = self.state[F, 1]\n",
-    "            new_state[U, 1] = self.state[F, 0]\n",
-    "            new_state[D, 0] = self.state[B, 1]\n",
-    "            new_state[D, 1] = self.state[B, 0]\n",
-    "        # Up Clockwise\n",
-    "        elif action == 8:\n",
-    "            new_state[F, 1] = self.state[R, 3]\n",
-    "            new_state[F, 3] = self.state[R, 2]\n",
-    "            new_state[B, 1] = self.state[L, 3]\n",
-    "            new_state[B, 3] = self.state[L, 2]\n",
-    "            new_state[R, 2] = self.state[B, 1]\n",
-    "            new_state[R, 3] = self.state[B, 3]\n",
-    "            new_state[L, 2] = self.state[F, 1]\n",
-    "            new_state[L, 3] = self.state[F, 3]\n",
-    "            new_state[U, 0] = self.state[U, 1]\n",
-    "            new_state[U, 1] = self.state[U, 3]\n",
-    "            new_state[U, 2] = self.state[U, 0]\n",
-    "            new_state[U, 3] = self.state[U, 2]\n",
-    "        # Up Counter-Clockwise\n",
-    "        elif action == 9:\n",
-    "            new_state[F, 1] = self.state[L, 2]\n",
-    "            new_state[F, 3] = self.state[L, 3]\n",
-    "            new_state[B, 1] = self.state[R, 2]\n",
-    "            new_state[B, 3] = self.state[R, 3]\n",
-    "            new_state[R, 2] = self.state[F, 3]\n",
-    "            new_state[R, 3] = self.state[F, 1]\n",
-    "            new_state[L, 2] = self.state[B, 3]\n",
-    "            new_state[L, 3] = self.state[B, 1]\n",
-    "            new_state[U, 0] = self.state[U, 2]\n",
-    "            new_state[U, 1] = self.state[U, 0]\n",
-    "            new_state[U, 2] = self.state[U, 3]\n",
-    "            new_state[U, 3] = self.state[U, 1]\n",
-    "        # Bottom Clockwise\n",
-    "        elif action == 10:\n",
-    "            new_state[F, 0] = self.state[L, 0]\n",
-    "            new_state[F, 2] = self.state[L, 1]\n",
-    "            new_state[B, 0] = self.state[R, 0]\n",
-    "            new_state[B, 2] = self.state[R, 1]\n",
-    "            new_state[R, 0] = self.state[F, 2]\n",
-    "            new_state[R, 1] = self.state[F, 0]\n",
-    "            new_state[L, 0] = self.state[B, 2]\n",
-    "            new_state[L, 1] = self.state[B, 0]\n",
-    "            new_state[D, 0] = self.state[D, 2]\n",
-    "            new_state[D, 1] = self.state[D, 0]\n",
-    "            new_state[D, 2] = self.state[D, 3]\n",
-    "            new_state[D, 3] = self.state[D, 1]\n",
-    "        # Bottom Counter-Clockwise\n",
-    "        elif action == 11:\n",
-    "            new_state[F, 0] = self.state[R, 1]\n",
-    "            new_state[F, 2] = self.state[R, 0]\n",
-    "            new_state[B, 0] = self.state[L, 1]\n",
-    "            new_state[B, 2] = self.state[L, 0]\n",
-    "            new_state[R, 0] = self.state[B, 0]\n",
-    "            new_state[R, 1] = self.state[B, 2]\n",
-    "            new_state[L, 0] = self.state[F, 0]\n",
-    "            new_state[L, 1] = self.state[F, 2]\n",
-    "            new_state[D, 0] = self.state[D, 1]\n",
-    "            new_state[D, 1] = self.state[D, 3]\n",
-    "            new_state[D, 2] = self.state[D, 0]\n",
-    "            new_state[D, 3] = self.state[D, 2]\n",
-    "        self.state = new_state\n",
-    "        return (\n",
-    "            self._get_obs(),\n",
-    "            1 if self._is_solved() else -1,\n",
-    "            self._is_solved(),\n",
-    "            self.step_count >= 100,\n",
-    "            {},\n",
-    "        )\n",
     "\n",
-    "    def _get_obs(self):\n",
-    "        one_hots = []\n",
-    "        for i in range(6):\n",
-    "            for j in range(4):\n",
-    "                label = int(self.state[i, j])\n",
-    "                zeros = np.zeros(6)\n",
-    "                zeros[label] = 1\n",
-    "                one_hots.append(zeros)\n",
-    "        return np.array(one_hots)\n",
-    "\n",
-    "    def _is_solved(self):\n",
-    "        for i in range(6):\n",
-    "            if np.mean(self.state[i]) != self.state[i][0]:\n",
-    "                return False\n",
-    "        return True"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "624c83c1",
-   "metadata": {},
-   "outputs": [],
-   "source": [
     "class RewardWrapper(gym.Wrapper):\n",
     "    def __init__(self, *args, **kwargs):\n",
     "        super().__init__(*args, **kwargs)\n",
@@ -271,7 +37,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "id": "7a81c85a",
    "metadata": {},
    "outputs": [
@@ -279,30 +45,12 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[[0. 0. 1. 0. 0. 0.]\n",
-      " [1. 0. 0. 0. 0. 0.]\n",
-      " [0. 0. 0. 0. 0. 1.]\n",
-      " [0. 0. 0. 0. 0. 1.]\n",
-      " [0. 0. 0. 1. 0. 0.]\n",
-      " [0. 1. 0. 0. 0. 0.]\n",
-      " [0. 0. 0. 0. 1. 0.]\n",
-      " [0. 0. 0. 0. 1. 0.]\n",
-      " [0. 0. 1. 0. 0. 0.]\n",
-      " [0. 1. 0. 0. 0. 0.]\n",
-      " [0. 0. 1. 0. 0. 0.]\n",
-      " [0. 1. 0. 0. 0. 0.]\n",
-      " [1. 0. 0. 0. 0. 0.]\n",
-      " [1. 0. 0. 0. 0. 0.]\n",
-      " [0. 0. 0. 1. 0. 0.]\n",
-      " [0. 0. 0. 1. 0. 0.]\n",
-      " [0. 0. 0. 0. 1. 0.]\n",
-      " [0. 0. 0. 0. 1. 0.]\n",
-      " [1. 0. 0. 0. 0. 0.]\n",
-      " [0. 0. 1. 0. 0. 0.]\n",
-      " [0. 0. 0. 0. 0. 1.]\n",
-      " [0. 0. 0. 0. 0. 1.]\n",
-      " [0. 1. 0. 0. 0. 0.]\n",
-      " [0. 0. 0. 1. 0. 0.]]\n"
+      "[[3 3 0 4]\n",
+      " [2 2 1 5]\n",
+      " [5 5 0 0]\n",
+      " [1 4 1 4]\n",
+      " [4 0 2 2]\n",
+      " [5 1 3 3]]\n"
      ]
     }
    ],
@@ -311,250 +59,6 @@
     "obs, _ = env.reset()\n",
     "print(env.state())"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "f8b4d968",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Using cpu device\n",
-      "Wrapping the env with a `Monitor` wrapper\n",
-      "Wrapping the env in a DummyVecEnv.\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 94.2 |\n",
-      "| ep_rew_mean | -88.2 |\n",
-      "| exploration_rate | 0.105 |\n",
-      "| time/ | |\n",
-      "| episodes | 100 |\n",
-      "| fps | 4943 |\n",
-      "| time_elapsed | 1 |\n",
-      "| total_timesteps | 9424 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.0004 |\n",
-      "| n_updates | 2330 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 98.1 |\n",
-      "| ep_rew_mean | -96.1 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 200 |\n",
-      "| fps | 4426 |\n",
-      "| time_elapsed | 4 |\n",
-      "| total_timesteps | 19236 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.000292 |\n",
-      "| n_updates | 4783 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 95.2 |\n",
-      "| ep_rew_mean | -90.1 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 300 |\n",
-      "| fps | 4349 |\n",
-      "| time_elapsed | 6 |\n",
-      "| total_timesteps | 28754 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.000103 |\n",
-      "| n_updates | 7163 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 88.4 |\n",
-      "| ep_rew_mean | -76.3 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 400 |\n",
-      "| fps | 4391 |\n",
-      "| time_elapsed | 8 |\n",
-      "| total_timesteps | 37598 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.000121 |\n",
-      "| n_updates | 9374 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 86.6 |\n",
-      "| ep_rew_mean | -72.5 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 500 |\n",
-      "| fps | 4417 |\n",
-      "| time_elapsed | 10 |\n",
-      "| total_timesteps | 46260 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.000169 |\n",
-      "| n_updates | 11539 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 82.6 |\n",
-      "| ep_rew_mean | -64.4 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 600 |\n",
-      "| fps | 4436 |\n",
-      "| time_elapsed | 12 |\n",
-      "| total_timesteps | 54520 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 9.72e-05 |\n",
-      "| n_updates | 13604 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 79.4 |\n",
-      "| ep_rew_mean | -57.2 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 700 |\n",
-      "| fps | 4445 |\n",
-      "| time_elapsed | 14 |\n",
-      "| total_timesteps | 62462 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 6.99e-05 |\n",
-      "| n_updates | 15590 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 75.5 |\n",
-      "| ep_rew_mean | -49.2 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 800 |\n",
-      "| fps | 4456 |\n",
-      "| time_elapsed | 15 |\n",
-      "| total_timesteps | 70012 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.264 |\n",
-      "| n_updates | 17477 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 70.5 |\n",
-      "| ep_rew_mean | -39.2 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 900 |\n",
-      "| fps | 4471 |\n",
-      "| time_elapsed | 17 |\n",
-      "| total_timesteps | 77066 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.000102 |\n",
-      "| n_updates | 19241 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 66.1 |\n",
-      "| ep_rew_mean | -28.8 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 1000 |\n",
-      "| fps | 4489 |\n",
-      "| time_elapsed | 18 |\n",
-      "| total_timesteps | 83678 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.000145 |\n",
-      "| n_updates | 20894 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 66.9 |\n",
-      "| ep_rew_mean | -31.6 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 1100 |\n",
-      "| fps | 4504 |\n",
-      "| time_elapsed | 20 |\n",
-      "| total_timesteps | 90370 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.000488 |\n",
-      "| n_updates | 22567 |\n",
-      "----------------------------------\n",
-      "----------------------------------\n",
-      "| rollout/ | |\n",
-      "| ep_len_mean | 68.6 |\n",
-      "| ep_rew_mean | -34.3 |\n",
-      "| exploration_rate | 0.05 |\n",
-      "| time/ | |\n",
-      "| episodes | 1200 |\n",
-      "| fps | 4517 |\n",
-      "| time_elapsed | 21 |\n",
-      "| total_timesteps | 97230 |\n",
-      "| train/ | |\n",
-      "| learning_rate | 0.0001 |\n",
-      "| loss | 0.00045 |\n",
-      "| n_updates | 24282 |\n",
-      "----------------------------------\n"
-     ]
-    }
-   ],
-   "source": [
-    "from stable_baselines3 import DQN\n",
-    "\n",
-    "env = Cube2()\n",
-    "env = RewardWrapper(env)\n",
-    "model = DQN(\"MlpPolicy\", env, verbose=1)\n",
-    "model.learn(total_timesteps=100000, log_interval=100)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 75,
-   "id": "24132717",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "rotationController.setState([[0.0, 0.0, 3.0, 4.0], [5.0, 2.0, 1.0, 1.0], [3.0, 4.0, 3.0, 2.0], [2.0, 5.0, 4.0, 5.0], [0.0, 3.0, 5.0, 1.0], [1.0, 2.0, 4.0, 0.0]])\n",
-      "rotationController.addRotationStepCode(...[3, 1, 8, 3])\n",
-      "\n",
-      "Solved in 4 steps\n"
-     ]
-    }
-   ],
-   "source": [
-    "# model = DQN.load(\"dqn_cube2.pkl\")\n",
-    "import json\n",
-    "\n",
-    "env = Cube2()\n",
-    "env = RewardWrapper(env)\n",
-    "obs, _ = env.reset()\n",
-    "print(f\"rotationController.setState({json.dumps(env.state().tolist())})\")\n",
-    "\n",
-    "solved_actions = []\n",
-    "for i in range(100):\n",
-    "    action, _ = model.predict(obs, deterministic=True)\n",
-    "    solved_actions.append(action.item())\n",
-    "    obs, reward, terminated, truncated, _ = env.step(action)\n",
-    "    if terminated or truncated:\n",
-    "        break\n",
-    "print(f\"rotationController.addRotationStepCode(...{json.dumps(solved_actions)})\")\n",
-    "\n",
-    "print()\n",
-    "print(f\"Solved in {len(solved_actions)} steps\")"
-   ]
   }
  ],
 "metadata": {
rlcube/envs/cube2.py CHANGED
@@ -19,15 +19,19 @@ class Cube2(gym.Env):
         self.state = np.zeros((6, 4))
         self.step_count = 0
 
-    def reset(self, seed=None, options=None):
+    def reset(self, seed=None, options=None, state: np.ndarray = None):
         super().reset(seed=seed, options=options)
-        self.state = np.zeros((6, 4))
-        self.state[0] = np.ones(4) * F
-        self.state[1] = np.ones(4) * B
-        self.state[2] = np.ones(4) * R
-        self.state[3] = np.ones(4) * L
-        self.state[4] = np.ones(4) * U
-        self.state[5] = np.ones(4) * D
+        if state is None:
+            self.state = np.zeros((6, 4), dtype=np.int8)
+            self.state[0] = np.ones(4, dtype=np.int8) * F
+            self.state[1] = np.ones(4, dtype=np.int8) * B
+            self.state[2] = np.ones(4, dtype=np.int8) * R
+            self.state[3] = np.ones(4, dtype=np.int8) * L
+            self.state[4] = np.ones(4, dtype=np.int8) * U
+            self.state[5] = np.ones(4, dtype=np.int8) * D
+        else:
+            assert state.shape == (6, 4) and state.dtype == np.int8
+            self.state = state
         self.step_count = 0
         return self._get_obs(), {}
 
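The reworked reset() can now adopt an externally supplied sticker layout instead of always starting from the solved cube. A usage sketch follows; the layout is the scrambled state printed by the notebook cell above, and the strict assert is why callers must pass a (6, 4) int8 array.

import numpy as np
from envs.cube2 import Cube2

# A (6, 4) matrix of face labels 0..5; reset() asserts this exact
# shape and the int8 dtype before adopting it as the cube state.
layout = np.array(
    [[3, 3, 0, 4],
     [2, 2, 1, 5],
     [5, 5, 0, 0],
     [1, 4, 1, 4],
     [4, 0, 2, 2],
     [5, 1, 3, 3]],
    dtype=np.int8,
)

env = Cube2()
obs, info = env.reset(state=layout)  # environment resumes from this scramble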
rlcube/main.py CHANGED
@@ -2,6 +2,8 @@ from typing import List
 from fastapi import FastAPI
 from pydantic import BaseModel
 from fastapi import HTTPException
+from envs.cube2 import Cube2
+import numpy as np
 
 app = FastAPI()
 
@@ -20,4 +22,12 @@ def solve(body: StateArgs):
     ):
         raise HTTPException(status_code=400, detail="state must be a 6x4 matrix")
 
-    return {"steps": [1, 2, 1, 1]}
+    env = Cube2()
+    env.reset(state=np.array(state, dtype=np.int8))
+
+    steps = []
+    for _ in range(10):
+        action = env.action_space.sample()
+        steps.append(action.item())
+
+    return {"steps": steps}