imwithye commited on
Commit
6537541
·
1 Parent(s): 314d0a6

generate dataset

Browse files
.gitattributes CHANGED
@@ -1 +1,2 @@
1
  *.blend filter=lfs diff=lfs merge=lfs -text
 
 
1
  *.blend filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
rlcube/.gitignore CHANGED
@@ -217,3 +217,6 @@ __marimo__/
217
 
218
  # Blender
219
  *.blend1
 
 
 
 
217
 
218
  # Blender
219
  *.blend1
220
+
221
+ # Dataset
222
+ dataset.pt
rlcube/pyproject.toml CHANGED
@@ -12,4 +12,5 @@ dependencies = [
12
  "numpy>=2.3.2",
13
  "tensordict>=0.10.0",
14
  "torch>=2.8.0",
 
15
  ]
 
12
  "numpy>=2.3.2",
13
  "tensordict>=0.10.0",
14
  "torch>=2.8.0",
15
+ "tqdm>=4.67.1",
16
  ]
rlcube/rlcube/models/dataset.py CHANGED
@@ -1,27 +1,43 @@
1
  from torch.utils.data import Dataset
2
  from rlcube.envs.cube2 import Cube2
3
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  class Cube2Dataset(Dataset):
7
- def __init__(self, num_envs: int = 1000, num_steps: int = 20):
8
- self.num_envs = num_envs
9
- self.num_steps = num_steps
10
- self.states = []
11
- self.D = []
12
- for _ in range(num_envs):
13
- env = Cube2()
14
- obs, _ = env.reset()
15
- for _ in range(num_steps):
16
- action = env.action_space.sample()
17
- obs, _, _, _, _ = env.step(action)
18
- self.states.append(obs)
19
- self.D.append(env.step_count)
20
- self.states = np.array(self.states)
21
- self.D = np.array(self.D)
22
 
23
  def __len__(self):
24
  return len(self.states)
25
 
26
  def __getitem__(self, idx):
27
- return self.states[idx], self.D[idx]
 
1
  from torch.utils.data import Dataset
2
  from rlcube.envs.cube2 import Cube2
3
  import numpy as np
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+
8
+ def create_dataset(num_envs: int = 10000, num_steps: int = 50):
9
+ states = []
10
+ neighbors = []
11
+ D = []
12
+ for _ in tqdm(range(num_envs)):
13
+ env = Cube2()
14
+ obs, _ = env.reset()
15
+ for _ in range(num_steps):
16
+ action = env.action_space.sample()
17
+ obs, _, _, _, _ = env.step(action)
18
+ states.append(obs)
19
+ neighbors.append(env.neighbors())
20
+ D.append(env.step_count)
21
+ states = np.array(states)
22
+ neighbors = np.array(neighbors)
23
+ D = np.array(D)
24
+ dataseet = {
25
+ "states": torch.tensor(states),
26
+ "neighbors": torch.tensor(neighbors),
27
+ "D": torch.tensor(D),
28
+ }
29
+ torch.save(dataseet, "dataset.pt")
30
 
31
 
32
  class Cube2Dataset(Dataset):
33
+ def __init__(self, filepath: str = "dataset.pt"):
34
+ self.dataset = torch.load(filepath)
35
+ self.states = self.dataset["states"]
36
+ self.neighbors = self.dataset["neighbors"]
37
+ self.D = self.dataset["D"]
 
 
 
 
 
 
 
 
 
 
38
 
39
  def __len__(self):
40
  return len(self.states)
41
 
42
  def __getitem__(self, idx):
43
+ return self.states[idx], self.neighbors[idx], self.D[idx]
rlcube/rlcube/train/train.py CHANGED
@@ -1,6 +1,14 @@
1
  from rlcube.models.dataset import Cube2Dataset
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  if __name__ == "__main__":
5
- dataset = Cube2Dataset(num_envs=10, num_steps=20)
6
- print(dataset[10])
 
1
  from rlcube.models.dataset import Cube2Dataset
2
+ from rlcube.envs.cube2 import Cube2
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+
6
+
7
+ def train(epochs: int = 100):
8
+ dataset = Cube2Dataset()
9
+ for _ in tqdm(range(epochs)):
10
+ pass
11
 
12
 
13
  if __name__ == "__main__":
14
+ train()
 
rlcube/uv.lock CHANGED
@@ -1536,6 +1536,7 @@ dependencies = [
1536
  { name = "numpy" },
1537
  { name = "tensordict" },
1538
  { name = "torch" },
 
1539
  ]
1540
 
1541
  [package.metadata]
@@ -1547,6 +1548,7 @@ requires-dist = [
1547
  { name = "numpy", specifier = ">=2.3.2" },
1548
  { name = "tensordict", specifier = ">=0.10.0" },
1549
  { name = "torch", specifier = ">=2.8.0" },
 
1550
  ]
1551
 
1552
  [[package]]
 
1536
  { name = "numpy" },
1537
  { name = "tensordict" },
1538
  { name = "torch" },
1539
+ { name = "tqdm" },
1540
  ]
1541
 
1542
  [package.metadata]
 
1548
  { name = "numpy", specifier = ">=2.3.2" },
1549
  { name = "tensordict", specifier = ">=0.10.0" },
1550
  { name = "torch", specifier = ">=2.8.0" },
1551
+ { name = "tqdm", specifier = ">=4.67.1" },
1552
  ]
1553
 
1554
  [[package]]