fix env
- rlcube/rlcube/cube2.py       +16 -10
- rlcube/rlcube/envs/cube2.py  +19 -4
rlcube/rlcube/cube2.py
CHANGED
@@ -1,15 +1,21 @@
+import gymnasium as gym
 from .envs.cube2 import Cube2
+from stable_baselines3 import DQN
+
+class RewardWrapper(gym.Wrapper):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = super().step(action)
+        return obs, reward, terminated, truncated, info
+
 
 def train():
     env = Cube2()
+    env = RewardWrapper(env)
+
+    model = DQN("MlpPolicy", env, verbose=1)
+    model.learn(total_timesteps=10000, log_interval=10)
 
-
-    for i in range(4):
-        # action = env.action_space.sample()
-        obs, reward, terminated, truncated, _ = env.step(10)
-        print(obs)
-        print("--------------------------------")
-        if terminated or truncated:
-            break
-    print(env._is_solved())
-    env.close()
+    env.close()
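Note: as committed, RewardWrapper.step returns the environment's reward untouched, so the wrapper is a no-op placeholder, presumably a hook for reward shaping later. A minimal sketch of what such shaping could look like (the bonus and penalty values are illustrative assumptions, not part of this commit):

    import gymnasium as gym

    class ShapedRewardWrapper(gym.Wrapper):
        def step(self, action):
            obs, reward, terminated, truncated, info = self.env.step(action)
            # Cube2 terminates only when solved, so `terminated` doubles as the solve flag.
            shaped = 10.0 if terminated else -0.1  # assumed values, replacing the env's +1/-1
            return obs, shaped, terminated, truncated, info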
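train() currently discards the model once learning finishes. A short follow-up sketch using standard Stable-Baselines3 calls (model.save, model.predict); the file name and rollout length are illustrative:

    model.save("dqn_cube2")  # hypothetical path
    obs, _ = env.reset()
    for _ in range(100):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _ = env.step(int(action))
        if terminated or truncated:
            break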
rlcube/rlcube/envs/cube2.py
CHANGED
@@ -1,3 +1,4 @@
+from random import shuffle
 import gymnasium as gym
 import numpy as np
 
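Note: nothing in the visible hunks uses shuffle; the scramble added to reset below draws from self.np_random instead, so this import appears to be left over.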
@@ -11,8 +12,8 @@ DOWN = 5
 class Cube2(gym.Env):
     def __init__(self):
         super().__init__()
-        self.action_space = gym.spaces.Discrete(
-        self.observation_space = gym.spaces.Box(low=0,
+        self.action_space = gym.spaces.Discrete(12)
+        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(24, 6), dtype=np.int8)
         self.state = np.zeros((6, 2, 2))
         self.step_count = 0
 
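Discrete(12) presumably encodes the six faces times two turn directions, and the (24, 6) Box is a per-sticker one-hot over six colours (a 2x2x2 cube has 24 stickers). One caveat: _get_obs below builds its rows with np.zeros(6), which are float64, while the Box declares dtype=np.int8, so the returned observations fall outside the declared space. Gymnasium's checker should catch this (an assumed debugging step, not part of the commit):

    from gymnasium.utils.env_checker import check_env
    check_env(Cube2())  # expected to complain about the float64 vs int8 observation dtype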
@@ -25,8 +26,11 @@
         self.state[3] = np.ones((2, 2)) * LEFT
         self.state[4] = np.ones((2, 2)) * UP
         self.state[5] = np.ones((2, 2)) * DOWN
+        shuffle_steps = self.np_random.integers(0, 20)
+        for i in range(shuffle_steps):
+            self.step(self.action_space.sample())
         self.step_count = 0
-        return self.
+        return self._get_obs(), {}
 
     def step(self, action):
         self.step_count += 1
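The scramble draws between 0 and 19 random moves (np_random.integers(0, 20) excludes the upper bound), so roughly one reset in twenty returns an already-solved cube; scrambling through self.step is also why step_count has to be zeroed again before returning. Because the draw goes through self.np_random, scrambles should be reproducible with a seeded reset, assuming reset forwards its seed to super().reset (that part of the method is outside these hunks):

    env = Cube2()
    obs, info = env.reset(seed=42)  # same seed, same scramble sequence (assumed)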
@@ -154,7 +158,18 @@
         new_state[FRONT, 1, 1] = self.state[RIGHT, 1, 1]
 
         self.state = new_state
-        return self.
+        return self._get_obs(), 1 if self._is_solved() else -1, self._is_solved(), self.step_count >= 100, {}
+
+    def _get_obs(self):
+        one_hots = []
+        for i in range(6):
+            for j in range(2):
+                for k in range(2):
+                    label = int(self.state[i, j, k])
+                    zeros = np.zeros(6)
+                    zeros[label] = 1
+                    one_hots.append(zeros)
+        return np.array(one_hots)
 
     def _is_solved(self):
         for i in range(6):
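The triple loop in _get_obs visits the (6, 2, 2) state in C order, so the same (24, 6) one-hot array can be built in one vectorized step; an int8 identity matrix would also align the returned dtype with the declared observation space. A sketch, not part of the commit:

    def _get_obs(self):
        labels = self.state.astype(int).reshape(-1)  # 24 sticker colour ids, C order
        return np.eye(6, dtype=np.int8)[labels]      # (24, 6) one-hot rows, matches the Box dtype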