imwithye commited on
Commit
521ddf5
·
1 Parent(s): d92472c

update search

Browse files
rlcube/cube2.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "624c83c1",
7
  "metadata": {},
8
  "outputs": [
@@ -32,7 +32,7 @@
32
  ")"
33
  ]
34
  },
35
- "execution_count": 1,
36
  "metadata": {},
37
  "output_type": "execute_result"
38
  }
@@ -43,13 +43,13 @@
43
  "import torch\n",
44
  "\n",
45
  "net = DNN()\n",
46
- "net.load(\"models/model_best.pth\")\n",
47
  "net.eval()"
48
  ]
49
  },
50
  {
51
  "cell_type": "code",
52
- "execution_count": null,
53
  "id": "defde44e",
54
  "metadata": {},
55
  "outputs": [
@@ -57,48 +57,44 @@
57
  "name": "stdout",
58
  "output_type": "stream",
59
  "text": [
60
- "[11, 11, 3, 10, 9, 4, 5, 3, 11, 11]\n",
61
- "tensor([[ 1.2608],\n",
62
- " [ 0.2146],\n",
63
- " [-0.8424],\n",
64
- " [-0.6595],\n",
65
- " [-0.4404],\n",
66
- " [-1.2381],\n",
67
- " [-0.4404],\n",
68
- " [-1.6949],\n",
69
- " [-3.1237],\n",
70
- " [-2.8188]], grad_fn=<AddmmBackward0>)\n"
71
  ]
72
  },
73
  {
74
  "name": "stderr",
75
  "output_type": "stream",
76
  "text": [
77
- " 9%|▉ | 469/5000 [00:04<00:48, 94.14it/s] \n"
 
 
 
 
 
 
 
78
  ]
79
  },
80
  {
81
- "ename": "KeyboardInterrupt",
82
- "evalue": "",
83
- "output_type": "error",
84
- "traceback": [
85
- "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
86
- "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
87
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 20\u001b[39m\n\u001b[32m 16\u001b[39m \u001b[38;5;28mprint\u001b[39m(values)\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrlcube\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodels\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01msearch\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m MonteCarloTree\n\u001b[32m---> \u001b[39m\u001b[32m20\u001b[39m tree = \u001b[43mMonteCarloTree\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_simulations\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m5000\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m tree.is_solved:\n\u001b[32m 22\u001b[39m \u001b[38;5;28mprint\u001b[39m([action \u001b[38;5;28;01mfor\u001b[39;00m _, action \u001b[38;5;129;01min\u001b[39;00m tree.solved_path])\n",
88
- "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:59\u001b[39m, in \u001b[36mMonteCarloTree.__init__\u001b[39m\u001b[34m(self, obs, max_simulations)\u001b[39m\n\u001b[32m 57\u001b[39m \u001b[38;5;28mself\u001b[39m.is_solved = \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[32m 58\u001b[39m \u001b[38;5;28mself\u001b[39m.solved_path = []\n\u001b[32m---> \u001b[39m\u001b[32m59\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_build\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
89
- "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:80\u001b[39m, in \u001b[36mMonteCarloTree._build\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m12\u001b[39m):\n\u001b[32m 79\u001b[39m obs = adjacent_obs[i]\n\u001b[32m---> \u001b[39m\u001b[32m80\u001b[39m child = \u001b[43mNode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 81\u001b[39m node.children[i] = child\n\u001b[32m 82\u001b[39m \u001b[38;5;28mself\u001b[39m.nodes.append(child)\n",
90
- "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/models/search.py:21\u001b[39m, in \u001b[36mNode.__init__\u001b[39m\u001b[34m(self, obs, parent)\u001b[39m\n\u001b[32m 18\u001b[39m value = value.detach()\n\u001b[32m 19\u001b[39m policy = torch.softmax(policy.detach(), dim=\u001b[32m1\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m21\u001b[39m \u001b[38;5;28mself\u001b[39m.is_solved = \u001b[43mCube2Env\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfrom_obs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m)\u001b[49m.is_solved()\n\u001b[32m 22\u001b[39m \u001b[38;5;28mself\u001b[39m.value = torch.tensor(\u001b[32m1\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.is_solved \u001b[38;5;28;01melse\u001b[39;00m value.view(-\u001b[32m1\u001b[39m)\n\u001b[32m 23\u001b[39m \u001b[38;5;28mself\u001b[39m.policy = policy.view(-\u001b[32m1\u001b[39m)\n",
91
- "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/envs/cube2.py:30\u001b[39m, in \u001b[36mCube2Env.from_obs\u001b[39m\u001b[34m(obs)\u001b[39m\n\u001b[32m 28\u001b[39m idx = i * \u001b[32m4\u001b[39m + j\n\u001b[32m 29\u001b[39m state[i, j] = np.argmax(obs[idx])\n\u001b[32m---> \u001b[39m\u001b[32m30\u001b[39m env = \u001b[43mCube2Env\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 31\u001b[39m env.reset(state=state)\n\u001b[32m 32\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env\n",
92
- "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/rlcube/envs/cube2.py:16\u001b[39m, in \u001b[36mCube2Env.__init__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 14\u001b[39m \u001b[38;5;28msuper\u001b[39m(Cube2Env, \u001b[38;5;28mself\u001b[39m).\u001b[34m__init__\u001b[39m()\n\u001b[32m 15\u001b[39m \u001b[38;5;28mself\u001b[39m.action_space = gym.spaces.Discrete(\u001b[32m12\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[38;5;28mself\u001b[39m.observation_space = \u001b[43mgym\u001b[49m\u001b[43m.\u001b[49m\u001b[43mspaces\u001b[49m\u001b[43m.\u001b[49m\u001b[43mBox\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[43m \u001b[49m\u001b[43mlow\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m=\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m24\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m6\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mint8\u001b[49m\n\u001b[32m 18\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 19\u001b[39m \u001b[38;5;28mself\u001b[39m.state = np.zeros((\u001b[32m6\u001b[39m, \u001b[32m4\u001b[39m), dtype=np.int8)\n\u001b[32m 20\u001b[39m \u001b[38;5;28mself\u001b[39m.reset()\n",
93
- "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:149\u001b[39m, in \u001b[36mBox.__init__\u001b[39m\u001b[34m(self, low, high, shape, dtype, seed)\u001b[39m\n\u001b[32m 147\u001b[39m \u001b[38;5;66;03m# Cast `low` and `high` to ndarray for the dtype min and max for out of range tests\u001b[39;00m\n\u001b[32m 148\u001b[39m \u001b[38;5;28mself\u001b[39m.low, \u001b[38;5;28mself\u001b[39m.bounded_below = \u001b[38;5;28mself\u001b[39m._cast_low(low, dtype_min)\n\u001b[32m--> \u001b[39m\u001b[32m149\u001b[39m \u001b[38;5;28mself\u001b[39m.high, \u001b[38;5;28mself\u001b[39m.bounded_above = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_cast_high\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype_max\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 151\u001b[39m \u001b[38;5;66;03m# recheck shape for case where shape and (low or high) are provided\u001b[39;00m\n\u001b[32m 152\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.low.shape != shape:\n",
94
- "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:251\u001b[39m, in \u001b[36mBox._cast_high\u001b[39m\u001b[34m(self, high, dtype_max)\u001b[39m\n\u001b[32m 241\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_cast_high\u001b[39m(\u001b[38;5;28mself\u001b[39m, high, dtype_max) -> \u001b[38;5;28mtuple\u001b[39m[np.ndarray, np.ndarray]:\n\u001b[32m 242\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Casts the input Box high value to ndarray with provided dtype.\u001b[39;00m\n\u001b[32m 243\u001b[39m \n\u001b[32m 244\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 249\u001b[39m \u001b[33;03m The updated high value and for what values the input is bounded (above)\u001b[39;00m\n\u001b[32m 250\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m251\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mis_float_integer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhigh\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 252\u001b[39m bounded_above = np.full(\u001b[38;5;28mself\u001b[39m.shape, high, dtype=\u001b[38;5;28mfloat\u001b[39m) < np.inf\n\u001b[32m 254\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m np.isnan(high):\n",
95
- "\u001b[36mFile \u001b[39m\u001b[32m~/Documents/Workspace/imwithye/rlcube/rlcube/.venv/lib/python3.12/site-packages/gymnasium/spaces/box.py:32\u001b[39m, in \u001b[36mis_float_integer\u001b[39m\u001b[34m(var)\u001b[39m\n\u001b[32m 28\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(np.min(arr))\n\u001b[32m 29\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(arr)\n\u001b[32m---> \u001b[39m\u001b[32m32\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mis_float_integer\u001b[39m(var: Any) -> \u001b[38;5;28mbool\u001b[39m:\n\u001b[32m 33\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Checks if a scalar variable is an integer or float (does not include bool).\"\"\"\u001b[39;00m\n\u001b[32m 34\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m np.issubdtype(\u001b[38;5;28mtype\u001b[39m(var), np.integer) \u001b[38;5;129;01mor\u001b[39;00m np.issubdtype(\u001b[38;5;28mtype\u001b[39m(var), np.floating)\n",
96
- "\u001b[31mKeyboardInterrupt\u001b[39m: "
97
  ]
98
  }
99
  ],
100
  "source": [
101
  "import numpy as np\n",
 
102
  "\n",
103
  "env = Cube2Env()\n",
104
  "\n",
@@ -115,46 +111,12 @@
115
  "print(actions)\n",
116
  "print(values)\n",
117
  "\n",
118
- "from rlcube.models.search import MonteCarloTree\n",
119
  "\n",
120
  "tree = MonteCarloTree(env.obs(), max_simulations=1000)\n",
121
  "if tree.is_solved:\n",
122
  " print([action for _, action in tree.solved_path])"
123
  ]
124
  },
125
- {
126
- "cell_type": "code",
127
- "execution_count": 6,
128
- "id": "a91732d7",
129
- "metadata": {},
130
- "outputs": [
131
- {
132
- "data": {
133
- "text/plain": [
134
- "defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
135
- " {0: 400,\n",
136
- " 1: 0,\n",
137
- " 2: 0,\n",
138
- " 3: 0,\n",
139
- " 4: 0,\n",
140
- " 5: 0,\n",
141
- " 6: 0,\n",
142
- " 7: 0,\n",
143
- " 8: 0,\n",
144
- " 9: 0,\n",
145
- " 10: 44,\n",
146
- " 11: 0})"
147
- ]
148
- },
149
- "execution_count": 6,
150
- "metadata": {},
151
- "output_type": "execute_result"
152
- }
153
- ],
154
- "source": [
155
- "tree.root.N"
156
- ]
157
- },
158
  {
159
  "cell_type": "code",
160
  "execution_count": null,
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 101,
6
  "id": "624c83c1",
7
  "metadata": {},
8
  "outputs": [
 
32
  ")"
33
  ]
34
  },
35
+ "execution_count": 101,
36
  "metadata": {},
37
  "output_type": "execute_result"
38
  }
 
43
  "import torch\n",
44
  "\n",
45
  "net = DNN()\n",
46
+ "net.load(\"models/model_final.pth\")\n",
47
  "net.eval()"
48
  ]
49
  },
50
  {
51
  "cell_type": "code",
52
+ "execution_count": 103,
53
  "id": "defde44e",
54
  "metadata": {},
55
  "outputs": [
 
57
  "name": "stdout",
58
  "output_type": "stream",
59
  "text": [
60
+ "[7, 11, 6, 7, 7, 10, 1, 0, 3, 3]\n",
61
+ "tensor([[ 0.9634],\n",
62
+ " [-0.0930],\n",
63
+ " [-0.8327],\n",
64
+ " [-0.0930],\n",
65
+ " [-0.8955],\n",
66
+ " [-1.8250],\n",
67
+ " [-4.0525],\n",
68
+ " [-1.8250],\n",
69
+ " [-3.0264],\n",
70
+ " [-3.6782]], grad_fn=<AddmmBackward0>)\n"
71
  ]
72
  },
73
  {
74
  "name": "stderr",
75
  "output_type": "stream",
76
  "text": [
77
+ " 1%| | 8/1000 [00:00<00:10, 99.11it/s]"
78
+ ]
79
+ },
80
+ {
81
+ "name": "stdout",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "[0, 2, 5, 2, 8, 6]\n"
85
  ]
86
  },
87
  {
88
+ "name": "stderr",
89
+ "output_type": "stream",
90
+ "text": [
91
+ "\n"
 
 
 
 
 
 
 
 
 
 
 
 
92
  ]
93
  }
94
  ],
95
  "source": [
96
  "import numpy as np\n",
97
+ "from rlcube.models.search import MonteCarloTree\n",
98
  "\n",
99
  "env = Cube2Env()\n",
100
  "\n",
 
111
  "print(actions)\n",
112
  "print(values)\n",
113
  "\n",
 
114
  "\n",
115
  "tree = MonteCarloTree(env.obs(), max_simulations=1000)\n",
116
  "if tree.is_solved:\n",
117
  " print([action for _, action in tree.solved_path])"
118
  ]
119
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  {
121
  "cell_type": "code",
122
  "execution_count": null,
rlcube/main.py CHANGED
@@ -28,4 +28,4 @@ def solve(body: StateArgs):
28
  tree = MonteCarloTree(env.obs(), max_simulations=300)
29
  if tree.is_solved:
30
  return {"steps": [action for _, action in tree.solved_path]}
31
- raise HTTPException(status_code=400, detail="Unable to solve the cube")
 
28
  tree = MonteCarloTree(env.obs(), max_simulations=300)
29
  if tree.is_solved:
30
  return {"steps": [action for _, action in tree.solved_path]}
31
+ raise HTTPException(status_code=422, detail="Unable to solve the cube")
rlcube/rlcube/models/search.py CHANGED
@@ -1,8 +1,12 @@
1
- from collections import defaultdict
2
  import torch
3
  from rlcube.models.models import DNN
4
  from rlcube.envs.cube2 import Cube2Env
5
  from tqdm import tqdm
 
 
 
 
 
6
 
7
  net = DNN()
8
  net.load("models/model_final.pth")
@@ -14,38 +18,29 @@ class Node:
14
  self.obs = torch.tensor(obs, dtype=torch.float32)
15
  self.parent = parent
16
 
17
- value, policy = net(self.obs.unsqueeze(0))
18
- value = value.detach()
19
- policy = torch.softmax(policy.detach(), dim=1)
20
-
21
  self.is_solved = Cube2Env.from_obs(obs).is_solved()
22
- self.value = torch.tensor(1) if self.is_solved else value.view(-1)
23
- self.policy = policy.view(-1)
 
24
 
25
- self.children = defaultdict(lambda: None)
26
- self.N = defaultdict(lambda: 0)
27
- self.W = defaultdict(lambda: 0)
 
28
 
29
  def is_leaf(self):
30
  return len(self.children) == 0
31
 
32
  def u(self):
33
- c = 1.414
34
- n_sum = torch.sum(torch.tensor([self.N[action] for action in range(12)]))
35
- u = torch.tensor(
36
- [
37
- c
38
- * self.policy[action].item()
39
- * torch.sqrt(n_sum)
40
- / (self.N[action] + 1)
41
- + self.W[action]
42
- for action in range(12)
43
- ]
44
- )
45
- return u
46
 
47
  def select_action(self):
48
- return torch.argmax(self.u()).item()
 
49
 
50
 
51
  class MonteCarloTree:
@@ -68,26 +63,34 @@ class MonteCarloTree:
68
 
69
  # Selection
70
  while not node.is_leaf():
71
- action = node.select_action()
72
- path.append((node, action))
73
- node = node.children[action]
 
 
74
 
75
  # Expansion
76
  env = Cube2Env.from_obs(node.obs)
77
  adjacent_obs = env.adjacent_obs()
78
- for i in range(12):
79
- obs = adjacent_obs[i]
80
- child = Node(obs, node)
81
  node.children[i] = child
82
  self.nodes.append(child)
83
- self.is_solved = self.is_solved or child.is_solved
84
  if child.is_solved:
 
85
  self.solved_path = path + [(node, i)]
 
 
 
 
86
 
87
  # Backup
88
- for parent, action in reversed(path):
89
- parent.N[action] += 1
90
- parent.W[action] = max(parent.W[action], node.value)
 
 
 
91
 
92
 
93
  if __name__ == "__main__":
 
 
1
  import torch
2
  from rlcube.models.models import DNN
3
  from rlcube.envs.cube2 import Cube2Env
4
  from tqdm import tqdm
5
+ import numpy as np
6
+
7
+ ACTIONS = 12
8
+ C_PUCT = 1.414
9
+ VIRTUAL_LOSS = 0.0
10
 
11
  net = DNN()
12
  net.load("models/model_final.pth")
 
18
  self.obs = torch.tensor(obs, dtype=torch.float32)
19
  self.parent = parent
20
 
21
+ with torch.no_grad():
22
+ value, policy_logits = net(self.obs.unsqueeze(0))
 
 
23
  self.is_solved = Cube2Env.from_obs(obs).is_solved()
24
+ self.value = 1.0 if self.is_solved else float(value.item())
25
+ policy = torch.softmax(policy_logits, dim=1).view(-1)
26
+ self.policy = np.array([float(policy[i].item()) for i in range(ACTIONS)])
27
 
28
+ self.children = {}
29
+ self.N = np.zeros(ACTIONS, dtype=np.int32)
30
+ self.W = np.zeros(ACTIONS, dtype=np.float32) # max value seen (not average)
31
+ self.L = np.zeros(ACTIONS, dtype=np.float32) # virtual loss (for async)
32
 
33
  def is_leaf(self):
34
  return len(self.children) == 0
35
 
36
  def u(self):
37
+ n_sum = np.sum(self.N) + 1
38
+ scores = self.policy * C_PUCT * np.sqrt(n_sum) / (self.N + 1) + self.W - self.L
39
+ return scores
 
 
 
 
 
 
 
 
 
 
40
 
41
  def select_action(self):
42
+ scores = self.u()
43
+ return np.argmax(scores)
44
 
45
 
46
  class MonteCarloTree:
 
63
 
64
  # Selection
65
  while not node.is_leaf():
66
+ a = node.select_action()
67
+ path.append((node, a))
68
+ if VIRTUAL_LOSS:
69
+ node.L[a] += VIRTUAL_LOSS
70
+ node = node.children[a]
71
 
72
  # Expansion
73
  env = Cube2Env.from_obs(node.obs)
74
  adjacent_obs = env.adjacent_obs()
75
+ for i in range(ACTIONS):
76
+ child = Node(adjacent_obs[i], node)
 
77
  node.children[i] = child
78
  self.nodes.append(child)
 
79
  if child.is_solved:
80
+ self.is_solved = True
81
  self.solved_path = path + [(node, i)]
82
+ if not path:
83
+ best = np.argmax(node.policy)
84
+ node.N[best] += 1
85
+ node.W[best] = max(node.W[best], float(node.children[best].value))
86
 
87
  # Backup
88
+ leaf_value = float(node.value)
89
+ for parent, a in reversed(path):
90
+ parent.N[a] += 1
91
+ parent.W[a] = max(parent.W[a], leaf_value)
92
+ if VIRTUAL_LOSS:
93
+ parent.L[a] -= VIRTUAL_LOSS
94
 
95
 
96
  if __name__ == "__main__":