Spaces:
Running
Running
update search
Browse files- rlcube/cube2.ipynb +28 -66
- rlcube/main.py +1 -1
- rlcube/rlcube/models/search.py +37 -34
rlcube/cube2.ipynb
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "624c83c1",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [
|
|
@@ -32,7 +32,7 @@
|
|
| 32 |
")"
|
| 33 |
]
|
| 34 |
},
|
| 35 |
-
"execution_count":
|
| 36 |
"metadata": {},
|
| 37 |
"output_type": "execute_result"
|
| 38 |
}
|
|
@@ -43,13 +43,13 @@
|
|
| 43 |
"import torch\n",
|
| 44 |
"\n",
|
| 45 |
"net = DNN()\n",
|
| 46 |
-
"net.load(\"models/
|
| 47 |
"net.eval()"
|
| 48 |
]
|
| 49 |
},
|
| 50 |
{
|
| 51 |
"cell_type": "code",
|
| 52 |
-
"execution_count":
|
| 53 |
"id": "defde44e",
|
| 54 |
"metadata": {},
|
| 55 |
"outputs": [
|
|
@@ -57,48 +57,44 @@
|
|
| 57 |
"name": "stdout",
|
| 58 |
"output_type": "stream",
|
| 59 |
"text": [
|
| 60 |
-
"[
|
| 61 |
-
"tensor([[
|
| 62 |
-
" [
|
| 63 |
-
" [-0.
|
| 64 |
-
" [-0.
|
| 65 |
-
" [-0.
|
| 66 |
-
" [-1.
|
| 67 |
-
" [-
|
| 68 |
-
" [-1.
|
| 69 |
-
" [-3.
|
| 70 |
-
" [-
|
| 71 |
]
|
| 72 |
},
|
| 73 |
{
|
| 74 |
"name": "stderr",
|
| 75 |
"output_type": "stream",
|
| 76 |
"text": [
|
| 77 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
]
|
| 79 |
},
|
| 80 |
{
|
| 81 |
-
"
|
| 82 |
-
"
|
| 83 |
-
"
|
| 84 |
-
|
| 85 |
-
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 86 |
-
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
|
| 87 |
-
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 16\u001b[39m \u001b[38;5;28mprint\u001b[39m(values)\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrlcube\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodels\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01msearch\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m MonteCarloTree\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m tree = \u001b[43mMonteCarloTree\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_simulations\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m5000\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m tree.is_solved:\n\u001b[32m 22\u001b[39m \u001b[38;5;28mprint\u001b[39m([action \u001b[38;5;28;01mfor\u001b[39;00m _, action \u001b[38;5;129;01min\u001b[39;00m tree.solved_path])\n",
|
| 88 |
-
"\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:59\u001b[39m, in \u001b[36mMonteCarloTree.__init__\u001b[39m\u001b[34m(self, obs, max_simulations)\u001b[39m\n\u001b[32m 57\u001b[39m \u001b[38;5;28mself\u001b[39m.is_solved = \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[32m 58\u001b[39m \u001b[38;5;28mself\u001b[39m.solved_path = []\n\u001b[32m---> \u001b[39m\u001b[32m59\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_build\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 89 |
-
"\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:80\u001b[39m, in \u001b[36mMonteCarloTree._build\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m12\u001b[39m):\n\u001b[32m 79\u001b[39m obs = adjacent_obs[i]\n\u001b[32m---> \u001b[39m\u001b[32m80\u001b[39m child = \u001b[43mNode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 81\u001b[39m node.children[i] = child\n\u001b[32m 82\u001b[39m \u001b[38;5;28mself\u001b[39m.nodes.append(child)\n",
|
| 90 |
-
"\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:21\u001b[39m, in \u001b[36mNode.__init__\u001b[39m\u001b[34m(self, obs, parent)\u001b[39m\n\u001b[32m 18\u001b[39m value = value.detach()\n\u001b[32m 19\u001b[39m policy = torch.softmax(policy.detach(), dim=\u001b[32m1\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m21\u001b[39m \u001b[38;5;28mself\u001b[39m.is_solved = \u001b[43mCube2Env\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfrom_obs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m)\u001b[49m.is_solved()\n\u001b[32m 22\u001b[39m \u001b[38;5;28mself\u001b[39m.value = torch.tensor(\u001b[32m1\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.is_solved \u001b[38;5;28;01melse\u001b[39;00m value.view(-\u001b[32m1\u001b[39m)\n\u001b[32m 23\u001b[39m \u001b[38;5;28mself\u001b[39m.policy = policy.view(-\u001b[32m1\u001b[39m)\n",
|
| 91 |
-
"\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/envs/cube2.py:30\u001b[39m, in \u001b[36mCube2Env.from_obs\u001b[39m\u001b[34m(obs)\u001b[39m\n\u001b[32m 28\u001b[39m idx = i * \u001b[32m4\u001b[39m + j\n\u001b[32m 29\u001b[39m state[i, j] = np.argmax(obs[idx])\n\u001b[32m---> \u001b[39m\u001b[32m30\u001b[39m env = \u001b[43mCube2Env\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 31\u001b[39m env.reset(state=state)\n\u001b[32m 32\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env\n",
|
| 92 |
-
"\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/envs/cube2.py:16\u001b[39m, in \u001b[36mCube2Env.__init__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 14\u001b[39m \u001b[38;5;28msuper\u001b[39m(Cube2Env, \u001b[38;5;28mself\u001b[39m).\u001b[34m__init__\u001b[39m()\n\u001b[32m 15\u001b[39m \u001b[38;5;28mself\u001b[39m.action_space = gym.spaces.Discrete(\u001b[32m12\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[38;5;28mself\u001b[39m.observation_space = \u001b[43mgym\u001b[49m\u001b[43m.\u001b[49m\u001b[43mspaces\u001b[49m\u001b[43m.\u001b[49m\u001b[43mBox\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[43m \u001b[49m\u001b[43mlow\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m=\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m24\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m6\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mint8\u001b[49m\n\u001b[32m 18\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 19\u001b[39m \u001b[38;5;28mself\u001b[39m.state = np.zeros((\u001b[32m6\u001b[39m, \u001b[32m4\u001b[39m), dtype=np.int8)\n\u001b[32m 20\u001b[39m \u001b[38;5;28mself\u001b[39m.reset()\n",
|
| 93 |
-
"\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:149\u001b[39m, in \u001b[36mBox.__init__\u001b[39m\u001b[34m(self, low, high, shape, dtype, seed)\u001b[39m\n\u001b[32m 147\u001b[39m \u001b[38;5;66;03m# Cast `low` and `high` to ndarray for the dtype min and max for out of range tests\u001b[39;00m\n\u001b[32m 148\u001b[39m \u001b[38;5;28mself\u001b[39m.low, \u001b[38;5;28mself\u001b[39m.bounded_below = \u001b[38;5;28mself\u001b[39m._cast_low(low, dtype_min)\n\u001b[32m--> \u001b[39m\u001b[32m149\u001b[39m \u001b[38;5;28mself\u001b[39m.high, \u001b[38;5;28mself\u001b[39m.bounded_above = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_cast_high\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype_max\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 151\u001b[39m \u001b[38;5;66;03m# recheck shape for case where shape and (low or high) are provided\u001b[39;00m\n\u001b[32m 152\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.low.shape != shape:\n",
|
| 94 |
-
"\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:251\u001b[39m, in \u001b[36mBox._cast_high\u001b[39m\u001b[34m(self, high, dtype_max)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_cast_high\u001b[39m(\u001b[38;5;28mself\u001b[39m, high, dtype_max) -> \u001b[38;5;28mtuple\u001b[39m[np.ndarray, np.ndarray]:\n\u001b[32m 242\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Casts the input Box high value to ndarray with provided dtype.\u001b[39;00m\n\u001b[32m 243\u001b[39m \n\u001b[32m 244\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 249\u001b[39m \u001b[33;03m The updated high value and for what values the input is bounded (above)\u001b[39;00m\n\u001b[32m 250\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m251\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mis_float_integer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 252\u001b[39m bounded_above = np.full(\u001b[38;5;28mself\u001b[39m.shape, high, dtype=\u001b[38;5;28mfloat\u001b[39m) < np.inf\n\u001b[32m 254\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m np.isnan(high):\n",
|
| 95 |
-
"\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:32\u001b[39m, in \u001b[36mis_float_integer\u001b[39m\u001b[34m(var)\u001b[39m\n\u001b[32m 28\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(np.min(arr))\n\u001b[32m 29\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(arr)\n\u001b[32m---> \u001b[39m\u001b[32m32\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mis_float_integer\u001b[39m(var: Any) -> \u001b[38;5;28mbool\u001b[39m:\n\u001b[32m 33\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Checks if a scalar variable is an integer or float (does not include bool).\"\"\"\u001b[39;00m\n\u001b[32m 34\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m np.issubdtype(\u001b[38;5;28mtype\u001b[39m(var), np.integer) \u001b[38;5;129;01mor\u001b[39;00m np.issubdtype(\u001b[38;5;28mtype\u001b[39m(var), np.floating)\n",
|
| 96 |
-
"\u001b[31mKeyboardInterrupt\u001b[39m: "
|
| 97 |
]
|
| 98 |
}
|
| 99 |
],
|
| 100 |
"source": [
|
| 101 |
"import numpy as np\n",
|
|
|
|
| 102 |
"\n",
|
| 103 |
"env = Cube2Env()\n",
|
| 104 |
"\n",
|
|
@@ -115,46 +111,12 @@
|
|
| 115 |
"print(actions)\n",
|
| 116 |
"print(values)\n",
|
| 117 |
"\n",
|
| 118 |
-
"from rlcube.models.search import MonteCarloTree\n",
|
| 119 |
"\n",
|
| 120 |
"tree = MonteCarloTree(env.obs(), max_simulations=1000)\n",
|
| 121 |
"if tree.is_solved:\n",
|
| 122 |
" print([action for _, action in tree.solved_path])"
|
| 123 |
]
|
| 124 |
},
|
| 125 |
-
{
|
| 126 |
-
"cell_type": "code",
|
| 127 |
-
"execution_count": 6,
|
| 128 |
-
"id": "a91732d7",
|
| 129 |
-
"metadata": {},
|
| 130 |
-
"outputs": [
|
| 131 |
-
{
|
| 132 |
-
"data": {
|
| 133 |
-
"text/plain": [
|
| 134 |
-
"defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
|
| 135 |
-
" {0: 400,\n",
|
| 136 |
-
" 1: 0,\n",
|
| 137 |
-
" 2: 0,\n",
|
| 138 |
-
" 3: 0,\n",
|
| 139 |
-
" 4: 0,\n",
|
| 140 |
-
" 5: 0,\n",
|
| 141 |
-
" 6: 0,\n",
|
| 142 |
-
" 7: 0,\n",
|
| 143 |
-
" 8: 0,\n",
|
| 144 |
-
" 9: 0,\n",
|
| 145 |
-
" 10: 44,\n",
|
| 146 |
-
" 11: 0})"
|
| 147 |
-
]
|
| 148 |
-
},
|
| 149 |
-
"execution_count": 6,
|
| 150 |
-
"metadata": {},
|
| 151 |
-
"output_type": "execute_result"
|
| 152 |
-
}
|
| 153 |
-
],
|
| 154 |
-
"source": [
|
| 155 |
-
"tree.root.N"
|
| 156 |
-
]
|
| 157 |
-
},
|
| 158 |
{
|
| 159 |
"cell_type": "code",
|
| 160 |
"execution_count": null,
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 101,
|
| 6 |
"id": "624c83c1",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [
|
|
|
|
| 32 |
")"
|
| 33 |
]
|
| 34 |
},
|
| 35 |
+
"execution_count": 101,
|
| 36 |
"metadata": {},
|
| 37 |
"output_type": "execute_result"
|
| 38 |
}
|
|
|
|
| 43 |
"import torch\n",
|
| 44 |
"\n",
|
| 45 |
"net = DNN()\n",
|
| 46 |
+
"net.load(\"models/model_final.pth\")\n",
|
| 47 |
"net.eval()"
|
| 48 |
]
|
| 49 |
},
|
| 50 |
{
|
| 51 |
"cell_type": "code",
|
| 52 |
+
"execution_count": 103,
|
| 53 |
"id": "defde44e",
|
| 54 |
"metadata": {},
|
| 55 |
"outputs": [
|
|
|
|
| 57 |
"name": "stdout",
|
| 58 |
"output_type": "stream",
|
| 59 |
"text": [
|
| 60 |
+
"[7, 11, 6, 7, 7, 10, 1, 0, 3, 3]\n",
|
| 61 |
+
"tensor([[ 0.9634],\n",
|
| 62 |
+
" [-0.0930],\n",
|
| 63 |
+
" [-0.8327],\n",
|
| 64 |
+
" [-0.0930],\n",
|
| 65 |
+
" [-0.8955],\n",
|
| 66 |
+
" [-1.8250],\n",
|
| 67 |
+
" [-4.0525],\n",
|
| 68 |
+
" [-1.8250],\n",
|
| 69 |
+
" [-3.0264],\n",
|
| 70 |
+
" [-3.6782]], grad_fn=<AddmmBackward0>)\n"
|
| 71 |
]
|
| 72 |
},
|
| 73 |
{
|
| 74 |
"name": "stderr",
|
| 75 |
"output_type": "stream",
|
| 76 |
"text": [
|
| 77 |
+
" 1%| | 8/1000 [00:00<00:10, 99.11it/s]"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"name": "stdout",
|
| 82 |
+
"output_type": "stream",
|
| 83 |
+
"text": [
|
| 84 |
+
"[0, 2, 5, 2, 8, 6]\n"
|
| 85 |
]
|
| 86 |
},
|
| 87 |
{
|
| 88 |
+
"name": "stderr",
|
| 89 |
+
"output_type": "stream",
|
| 90 |
+
"text": [
|
| 91 |
+
"\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
]
|
| 93 |
}
|
| 94 |
],
|
| 95 |
"source": [
|
| 96 |
"import numpy as np\n",
|
| 97 |
+
"from rlcube.models.search import MonteCarloTree\n",
|
| 98 |
"\n",
|
| 99 |
"env = Cube2Env()\n",
|
| 100 |
"\n",
|
|
|
|
| 111 |
"print(actions)\n",
|
| 112 |
"print(values)\n",
|
| 113 |
"\n",
|
|
|
|
| 114 |
"\n",
|
| 115 |
"tree = MonteCarloTree(env.obs(), max_simulations=1000)\n",
|
| 116 |
"if tree.is_solved:\n",
|
| 117 |
" print([action for _, action in tree.solved_path])"
|
| 118 |
]
|
| 119 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
{
|
| 121 |
"cell_type": "code",
|
| 122 |
"execution_count": null,
|
rlcube/main.py
CHANGED
|
@@ -28,4 +28,4 @@ def solve(body: StateArgs):
|
|
| 28 |
tree = MonteCarloTree(env.obs(), max_simulations=300)
|
| 29 |
if tree.is_solved:
|
| 30 |
return {"steps": [action for _, action in tree.solved_path]}
|
| 31 |
-
raise HTTPException(status_code=
|
|
|
|
| 28 |
tree = MonteCarloTree(env.obs(), max_simulations=300)
|
| 29 |
if tree.is_solved:
|
| 30 |
return {"steps": [action for _, action in tree.solved_path]}
|
| 31 |
+
raise HTTPException(status_code=422, detail="Unable to solve the cube")
|
rlcube/rlcube/models/search.py
CHANGED
|
@@ -1,8 +1,12 @@
|
|
| 1 |
-
from collections import defaultdict
|
| 2 |
import torch
|
| 3 |
from rlcube.models.models import DNN
|
| 4 |
from rlcube.envs.cube2 import Cube2Env
|
| 5 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
net = DNN()
|
| 8 |
net.load("models/model_final.pth")
|
|
@@ -14,38 +18,29 @@ class Node:
|
|
| 14 |
self.obs = torch.tensor(obs, dtype=torch.float32)
|
| 15 |
self.parent = parent
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
policy = torch.softmax(policy.detach(), dim=1)
|
| 20 |
-
|
| 21 |
self.is_solved = Cube2Env.from_obs(obs).is_solved()
|
| 22 |
-
self.value =
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
-
self.children =
|
| 26 |
-
self.N =
|
| 27 |
-
self.W =
|
|
|
|
| 28 |
|
| 29 |
def is_leaf(self):
|
| 30 |
return len(self.children) == 0
|
| 31 |
|
| 32 |
def u(self):
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
[
|
| 37 |
-
c
|
| 38 |
-
* self.policy[action].item()
|
| 39 |
-
* torch.sqrt(n_sum)
|
| 40 |
-
/ (self.N[action] + 1)
|
| 41 |
-
+ self.W[action]
|
| 42 |
-
for action in range(12)
|
| 43 |
-
]
|
| 44 |
-
)
|
| 45 |
-
return u
|
| 46 |
|
| 47 |
def select_action(self):
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
class MonteCarloTree:
|
|
@@ -68,26 +63,34 @@ class MonteCarloTree:
|
|
| 68 |
|
| 69 |
# Selection
|
| 70 |
while not node.is_leaf():
|
| 71 |
-
|
| 72 |
-
path.append((node,
|
| 73 |
-
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# Expansion
|
| 76 |
env = Cube2Env.from_obs(node.obs)
|
| 77 |
adjacent_obs = env.adjacent_obs()
|
| 78 |
-
for i in range(
|
| 79 |
-
|
| 80 |
-
child = Node(obs, node)
|
| 81 |
node.children[i] = child
|
| 82 |
self.nodes.append(child)
|
| 83 |
-
self.is_solved = self.is_solved or child.is_solved
|
| 84 |
if child.is_solved:
|
|
|
|
| 85 |
self.solved_path = path + [(node, i)]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
# Backup
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
parent.
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
if __name__ == "__main__":
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from rlcube.models.models import DNN
|
| 3 |
from rlcube.envs.cube2 import Cube2Env
|
| 4 |
from tqdm import tqdm
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
ACTIONS = 12
|
| 8 |
+
C_PUCT = 1.414
|
| 9 |
+
VIRTUAL_LOSS = 0.0
|
| 10 |
|
| 11 |
net = DNN()
|
| 12 |
net.load("models/model_final.pth")
|
|
|
|
| 18 |
self.obs = torch.tensor(obs, dtype=torch.float32)
|
| 19 |
self.parent = parent
|
| 20 |
|
| 21 |
+
with torch.no_grad():
|
| 22 |
+
value, policy_logits = net(self.obs.unsqueeze(0))
|
|
|
|
|
|
|
| 23 |
self.is_solved = Cube2Env.from_obs(obs).is_solved()
|
| 24 |
+
self.value = 1.0 if self.is_solved else float(value.item())
|
| 25 |
+
policy = torch.softmax(policy_logits, dim=1).view(-1)
|
| 26 |
+
self.policy = np.array([float(policy[i].item()) for i in range(ACTIONS)])
|
| 27 |
|
| 28 |
+
self.children = {}
|
| 29 |
+
self.N = np.zeros(ACTIONS, dtype=np.int32)
|
| 30 |
+
self.W = np.zeros(ACTIONS, dtype=np.float32) # max value seen (not average)
|
| 31 |
+
self.L = np.zeros(ACTIONS, dtype=np.float32) # virtual loss (for async)
|
| 32 |
|
| 33 |
def is_leaf(self):
|
| 34 |
return len(self.children) == 0
|
| 35 |
|
| 36 |
def u(self):
|
| 37 |
+
n_sum = np.sum(self.N) + 1
|
| 38 |
+
scores = self.policy * C_PUCT * np.sqrt(n_sum) / (self.N + 1) + self.W - self.L
|
| 39 |
+
return scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
def select_action(self):
|
| 42 |
+
scores = self.u()
|
| 43 |
+
return np.argmax(scores)
|
| 44 |
|
| 45 |
|
| 46 |
class MonteCarloTree:
|
|
|
|
| 63 |
|
| 64 |
# Selection
|
| 65 |
while not node.is_leaf():
|
| 66 |
+
a = node.select_action()
|
| 67 |
+
path.append((node, a))
|
| 68 |
+
if VIRTUAL_LOSS:
|
| 69 |
+
node.L[a] += VIRTUAL_LOSS
|
| 70 |
+
node = node.children[a]
|
| 71 |
|
| 72 |
# Expansion
|
| 73 |
env = Cube2Env.from_obs(node.obs)
|
| 74 |
adjacent_obs = env.adjacent_obs()
|
| 75 |
+
for i in range(ACTIONS):
|
| 76 |
+
child = Node(adjacent_obs[i], node)
|
|
|
|
| 77 |
node.children[i] = child
|
| 78 |
self.nodes.append(child)
|
|
|
|
| 79 |
if child.is_solved:
|
| 80 |
+
self.is_solved = True
|
| 81 |
self.solved_path = path + [(node, i)]
|
| 82 |
+
if not path:
|
| 83 |
+
best = np.argmax(node.policy)
|
| 84 |
+
node.N[best] += 1
|
| 85 |
+
node.W[best] = max(node.W[best], float(node.children[best].value))
|
| 86 |
|
| 87 |
# Backup
|
| 88 |
+
leaf_value = float(node.value)
|
| 89 |
+
for parent, a in reversed(path):
|
| 90 |
+
parent.N[a] += 1
|
| 91 |
+
parent.W[a] = max(parent.W[a], leaf_value)
|
| 92 |
+
if VIRTUAL_LOSS:
|
| 93 |
+
parent.L[a] -= VIRTUAL_LOSS
|
| 94 |
|
| 95 |
|
| 96 |
if __name__ == "__main__":
|