imwithye commited on
Commit
f02352c
·
1 Parent(s): edb87c5
Files changed (2) hide show
  1. rlcube/cube2.ipynb +64 -23
  2. rlcube/rlcube/train/train.py +2 -2
rlcube/cube2.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 60,
6
  "id": "624c83c1",
7
  "metadata": {},
8
  "outputs": [
@@ -32,14 +32,14 @@
32
  ")"
33
  ]
34
  },
35
- "execution_count": 60,
36
  "metadata": {},
37
  "output_type": "execute_result"
38
  }
39
  ],
40
  "source": [
41
  "from rlcube.models.models import DNN\n",
42
- "from rlcube.envs.cube2 import Cube2\n",
43
  "import numpy as np\n",
44
  "import torch\n",
45
  "\n",
@@ -50,7 +50,7 @@
50
  },
51
  {
52
  "cell_type": "code",
53
- "execution_count": 61,
54
  "id": "16736f3a",
55
  "metadata": {},
56
  "outputs": [
@@ -58,32 +58,73 @@
58
  "name": "stdout",
59
  "output_type": "stream",
60
  "text": [
61
- "tensor([[ 0.0166],\n",
62
- " [ 1.0147],\n",
63
- " [ 1.1610],\n",
64
- " [ 0.9844],\n",
65
- " [-0.0268],\n",
66
- " [ 1.1526]], grad_fn=<AddmmBackward0>)\n",
67
- "tensor([10, 1, 5, 0, 10, 1])\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ]
69
  }
70
  ],
71
  "source": [
72
- "env = Cube2()\n",
73
- "obs, _ = env.reset()\n",
74
- "obs1, _, _, _, _ = env.step(0)\n",
75
- "obs2, _, _, _, _ = env.step(0)\n",
76
- "obs3, _, _, _, _ = env.step(2)\n",
77
- "obs4, _, _, _, _ = env.step(2)\n",
78
  "for _ in range(10):\n",
79
- " obsMany, _, _, _, _ = env.step(env.action_space.sample())\n",
80
- "batched_obs = torch.tensor(\n",
81
- " np.array([obs, obs1, obs2, obs3, obs4, obsMany]), dtype=torch.float32\n",
82
- ")\n",
83
  "out = net(batched_obs)\n",
84
- "print(out[\"value\"])\n",
85
- "print(torch.argmax(out[\"policy\"], dim=1))"
 
 
 
 
 
86
  ]
 
 
 
 
 
 
 
 
87
  }
88
  ],
89
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "624c83c1",
7
  "metadata": {},
8
  "outputs": [
 
32
  ")"
33
  ]
34
  },
35
+ "execution_count": 1,
36
  "metadata": {},
37
  "output_type": "execute_result"
38
  }
39
  ],
40
  "source": [
41
  "from rlcube.models.models import DNN\n",
42
+ "from rlcube.envs.cube2 import Cube2Env\n",
43
  "import numpy as np\n",
44
  "import torch\n",
45
  "\n",
 
50
  },
51
  {
52
  "cell_type": "code",
53
+ "execution_count": 9,
54
  "id": "16736f3a",
55
  "metadata": {},
56
  "outputs": [
 
58
  "name": "stdout",
59
  "output_type": "stream",
60
  "text": [
61
+ "rotationController.setState([[0, 0, 4, 4], [1, 1, 5, 5], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 1, 1], [5, 5, 0, 0]]);\n",
62
+ "0.40487873554229736\n",
63
+ "4\n",
64
+ "\n",
65
+ "rotationController.setState([[0, 4, 0, 4], [1, 1, 5, 5], [2, 5, 2, 0], [3, 4, 3, 1], [4, 2, 1, 2], [5, 3, 0, 3]]);\n",
66
+ "0.0839405208826065\n",
67
+ "7\n",
68
+ "\n",
69
+ "rotationController.setState([[0, 4, 0, 4], [5, 1, 5, 1], [1, 5, 4, 0], [0, 4, 5, 1], [3, 2, 3, 2], [2, 3, 2, 3]]);\n",
70
+ "-0.23320673406124115\n",
71
+ "3\n",
72
+ "\n",
73
+ "rotationController.setState([[0, 5, 0, 1], [5, 4, 5, 0], [1, 5, 4, 4], [0, 4, 1, 1], [3, 3, 2, 2], [2, 3, 2, 3]]);\n",
74
+ "0.31869572401046753\n",
75
+ "0\n",
76
+ "\n",
77
+ "rotationController.setState([[5, 5, 1, 1], [4, 4, 0, 0], [5, 5, 4, 4], [0, 0, 1, 1], [3, 3, 2, 2], [3, 3, 2, 2]]);\n",
78
+ "-0.16905824840068817\n",
79
+ "7\n",
80
+ "\n",
81
+ "rotationController.setState([[5, 4, 1, 4], [4, 1, 0, 1], [5, 5, 4, 0], [0, 0, 5, 1], [3, 2, 3, 2], [3, 3, 2, 2]]);\n",
82
+ "0.20266102254390717\n",
83
+ "3\n",
84
+ "\n",
85
+ "rotationController.setState([[2, 3, 1, 4], [3, 3, 0, 1], [5, 5, 4, 0], [0, 1, 0, 5], [4, 1, 3, 2], [5, 4, 2, 2]]);\n",
86
+ "0.6111429333686829\n",
87
+ "3\n",
88
+ "\n",
89
+ "rotationController.setState([[2, 0, 1, 4], [3, 5, 0, 0], [5, 5, 3, 1], [0, 1, 3, 4], [1, 2, 4, 3], [5, 4, 2, 2]]);\n",
90
+ "1.3550236225128174\n",
91
+ "2\n",
92
+ "\n",
93
+ "rotationController.setState([[0, 0, 1, 4], [5, 5, 5, 0], [1, 2, 3, 1], [0, 3, 3, 4], [1, 2, 4, 3], [2, 5, 2, 4]]);\n",
94
+ "0.9975889325141907\n",
95
+ "7\n",
96
+ "\n",
97
+ "rotationController.setState([[2, 0, 1, 4], [3, 5, 0, 0], [5, 5, 3, 1], [0, 1, 3, 4], [1, 2, 4, 3], [5, 4, 2, 2]]);\n",
98
+ "1.3550236225128174\n",
99
+ "2\n",
100
+ "\n"
101
  ]
102
  }
103
  ],
104
  "source": [
105
+ "batch_obs = []\n",
106
+ "env = Cube2Env()\n",
 
 
 
 
107
  "for _ in range(10):\n",
108
+ " obs, _, _, _, _ = env.step(env.action_space.sample())\n",
109
+ " batch_obs.append(torch.tensor(obs, dtype=torch.float32))\n",
110
+ "batched_obs = torch.stack(batch_obs)\n",
 
111
  "out = net(batched_obs)\n",
112
+ "\n",
113
+ "for i in range(10):\n",
114
+ " env = Cube2Env.from_obs(batch_obs[i])\n",
115
+ " env.print_js_code()\n",
116
+ " print(out[\"value\"][i].item())\n",
117
+ " print(torch.argmax(out[\"policy\"][i]).item())\n",
118
+ " print()"
119
  ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "aee2a911",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": []
128
  }
129
  ],
130
  "metadata": {
rlcube/rlcube/train/train.py CHANGED
@@ -17,7 +17,7 @@ print(f"Using device: {device}")
17
 
18
  def train(epochs: int = 100):
19
  if not os.path.exists("dataset.pt"):
20
- create_dataset(num_envs=10000, num_steps=20, filepath="dataset.pt")
21
  dataset = Cube2Dataset("dataset.pt")
22
  print("Number of samples:", len(dataset))
23
  print("Number of epochs:", epochs)
@@ -29,7 +29,7 @@ def train(epochs: int = 100):
29
  if os.path.exists("models/model_best.pth"):
30
  net.load("models/model_best.pth")
31
  net = net.to(device)
32
- optimizer = torch.optim.RMSprop(net.parameters(), lr=0.000001)
33
  value_loss_fn = torch.nn.MSELoss(reduction="none")
34
  policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
35
 
 
17
 
18
  def train(epochs: int = 100):
19
  if not os.path.exists("dataset.pt"):
20
+ create_dataset(num_envs=1000, num_steps=20, filepath="dataset.pt")
21
  dataset = Cube2Dataset("dataset.pt")
22
  print("Number of samples:", len(dataset))
23
  print("Number of epochs:", epochs)
 
29
  if os.path.exists("models/model_best.pth"):
30
  net.load("models/model_best.pth")
31
  net = net.to(device)
32
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.000001)
33
  value_loss_fn = torch.nn.MSELoss(reduction="none")
34
  policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
35