imwithye committed
Commit 658f806 · 1 Parent(s): a02161e
Files changed (2)
  1. rlcube/rlcube/cube2.py +16 -10
  2. rlcube/rlcube/envs/cube2.py +19 -4
rlcube/rlcube/cube2.py CHANGED
@@ -1,15 +1,21 @@
+import gymnasium as gym
 from .envs.cube2 import Cube2
+from stable_baselines3 import DQN
+
+class RewardWrapper(gym.Wrapper):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = super().step(action)
+        return obs, reward, terminated, truncated, info
+
 
 def train():
     env = Cube2()
+    env = RewardWrapper(env)
+
+    model = DQN("MlpPolicy", env, verbose=1)
+    model.learn(total_timesteps=10000, log_interval=10)
 
-    obs, _ = env.reset()
-    for i in range(4):
-        # action = env.action_space.sample()
-        obs, reward, terminated, truncated, _ = env.step(10)
-        print(obs)
-        print("--------------------------------")
-        if terminated or truncated:
-            break
-    print(env._is_solved())
-    env.close()
+    env.close()
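
As committed, `RewardWrapper` is a pure pass-through: its `step` forwards the wrapped env's reward unchanged, so DQN trains on the raw +1/-1 signal from `Cube2`. A minimal sketch of the kind of shaping this hook could host, assuming a hypothetical per-step penalty (the wrapper name and `step_penalty` value are illustrative, not part of this commit):

    import gymnasium as gym

    class StepPenaltyWrapper(gym.Wrapper):
        # Hypothetical reward shaping: charge a small cost per move so
        # shorter solutions earn a higher return (penalty value assumed).
        def __init__(self, env, step_penalty=0.05):
            super().__init__(env)
            self.step_penalty = step_penalty

        def step(self, action):
            obs, reward, terminated, truncated, info = self.env.step(action)
            return obs, reward - self.step_penalty, terminated, truncated, info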
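
`train()` now fits the model but never evaluates or saves it. A usage sketch of what could follow `model.learn(...)`, using stable_baselines3's standard `predict` and `save` methods (none of this is in the commit; the filename is made up):

    # Greedy rollout with the trained policy.
    obs, _ = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _ = env.step(int(action))
    model.save("dqn_cube2")  # hypothetical path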
rlcube/rlcube/envs/cube2.py CHANGED
@@ -1,3 +1,4 @@
+from random import shuffle
 import gymnasium as gym
 import numpy as np
 
@@ -11,8 +12,8 @@ DOWN = 5
 class Cube2(gym.Env):
     def __init__(self):
         super().__init__()
-        self.action_space = gym.spaces.Discrete(6)
-        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(24, 6))
+        self.action_space = gym.spaces.Discrete(12)
+        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(24, 6), dtype=np.int8)
         self.state = np.zeros((6, 2, 2))
         self.step_count = 0
 
@@ -25,8 +26,11 @@ class Cube2(gym.Env):
         self.state[3] = np.ones((2, 2)) * LEFT
         self.state[4] = np.ones((2, 2)) * UP
         self.state[5] = np.ones((2, 2)) * DOWN
+        shuffle_steps = self.np_random.integers(0, 20)
+        for i in range(shuffle_steps):
+            self.step(self.action_space.sample())
         self.step_count = 0
-        return self.state, {}
+        return self._get_obs(), {}
 
     def step(self, action):
         self.step_count += 1
@@ -154,7 +158,18 @@ class Cube2(gym.Env):
         new_state[FRONT, 1, 1] = self.state[RIGHT, 1, 1]
 
         self.state = new_state
-        return self.state.copy(), 1 if self._is_solved() else -1, self._is_solved(), self.step_count >= 100, {}
+        return self._get_obs(), 1 if self._is_solved() else -1, self._is_solved(), self.step_count >= 100, {}
+
+    def _get_obs(self):
+        one_hots = []
+        for i in range(6):
+            for j in range(2):
+                for k in range(2):
+                    label = int(self.state[i, j, k])
+                    zeros = np.zeros(6, dtype=np.int8)
+                    zeros[label] = 1
+                    one_hots.append(zeros)
+        return np.array(one_hots)
 
     def _is_solved(self):
         for i in range(6):
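
The new `_get_obs` builds the (24, 6) one-hot observation with three nested loops. An equivalent vectorized form, assuming the same (6, 2, 2) integer-labeled `self.state` layout (a sketch, not what the commit ships):

    import numpy as np

    def _get_obs(self):
        # Flatten the 24 stickers, then index an identity matrix to
        # one-hot encode each face label; result matches Box shape (24, 6).
        labels = self.state.reshape(-1).astype(np.int64)
        return np.eye(6, dtype=np.int8)[labels]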