Spaces:
Running
Running
generate dataset
Browse files
- .gitattributes +1 -0
- rlcube/.gitignore +3 -0
- rlcube/pyproject.toml +1 -0
- rlcube/rlcube/models/dataset.py +32 -16
- rlcube/rlcube/train/train.py +10 -2
- rlcube/uv.lock +2 -0
.gitattributes
CHANGED
|
@@ -1 +1,2 @@
|
|
| 1 |
*.blend filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 1 |
*.blend filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
rlcube/.gitignore
CHANGED
|
@@ -217,3 +217,6 @@ __marimo__/
|
|
| 217 |
|
| 218 |
# Blender
|
| 219 |
*.blend1
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
# Blender
|
| 219 |
*.blend1
|
| 220 |
+
|
| 221 |
+
# Dataset
|
| 222 |
+
dataset.pt
|
rlcube/pyproject.toml
CHANGED
|
@@ -12,4 +12,5 @@ dependencies = [
|
|
| 12 |
"numpy>=2.3.2",
|
| 13 |
"tensordict>=0.10.0",
|
| 14 |
"torch>=2.8.0",
|
|
|
|
| 15 |
]
|
|
|
|
| 12 |
"numpy>=2.3.2",
|
| 13 |
"tensordict>=0.10.0",
|
| 14 |
"torch>=2.8.0",
|
| 15 |
+
"tqdm>=4.67.1",
|
| 16 |
]
|
rlcube/rlcube/models/dataset.py
CHANGED
|
@@ -1,27 +1,43 @@
|
|
| 1 |
from torch.utils.data import Dataset
|
| 2 |
from rlcube.envs.cube2 import Cube2
|
| 3 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class Cube2Dataset(Dataset):
|
| 7 |
-
def __init__(self,
|
| 8 |
-
self.
|
| 9 |
-
self.
|
| 10 |
-
self.
|
| 11 |
-
self.D = []
|
| 12 |
-
for _ in range(num_envs):
|
| 13 |
-
env = Cube2()
|
| 14 |
-
obs, _ = env.reset()
|
| 15 |
-
for _ in range(num_steps):
|
| 16 |
-
action = env.action_space.sample()
|
| 17 |
-
obs, _, _, _, _ = env.step(action)
|
| 18 |
-
self.states.append(obs)
|
| 19 |
-
self.D.append(env.step_count)
|
| 20 |
-
self.states = np.array(self.states)
|
| 21 |
-
self.D = np.array(self.D)
|
| 22 |
|
| 23 |
def __len__(self):
|
| 24 |
return len(self.states)
|
| 25 |
|
| 26 |
def __getitem__(self, idx):
|
| 27 |
-
return self.states[idx], self.D[idx]
|
|
|
|
| 1 |
from torch.utils.data import Dataset
|
| 2 |
from rlcube.envs.cube2 import Cube2
|
| 3 |
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_dataset(
    num_envs: int = 10000,
    num_steps: int = 50,
    filepath: str = "dataset.pt",
):
    """Roll out random ``Cube2`` episodes and save the samples with ``torch.save``.

    For each of ``num_envs`` freshly reset ``Cube2`` environments, take
    ``num_steps`` actions drawn from ``env.action_space.sample()``. After every
    step, record the observation returned by ``env.step``, the value of
    ``env.neighbors()`` and ``env.step_count``.

    The saved object is a dict of tensors with the keys ``"states"``,
    ``"neighbors"`` and ``"D"``; entry ``i`` of each tensor belongs to the same
    recorded step.

    Args:
        num_envs: Number of independent environments to roll out.
        num_steps: Number of random actions taken in each environment.
        filepath: Destination handed to ``torch.save``. Defaults to
            ``"dataset.pt"``, which keeps the previous hard-coded behavior.
    """
    states = []
    neighbors = []
    step_counts = []  # stored under the key "D"; presumably the scramble depth - confirm
    for _ in tqdm(range(num_envs)):
        env = Cube2()
        env.reset()  # the initial observation is not recorded, only post-step ones
        for _ in range(num_steps):
            action = env.action_space.sample()
            obs, _, _, _, _ = env.step(action)
            # NOTE(review): assumes env.step / env.neighbors return fresh
            # objects on every call; if Cube2 reuses an internal buffer the
            # values must be copied before appending - confirm.
            states.append(obs)
            neighbors.append(env.neighbors())
            step_counts.append(env.step_count)

    dataset = {
        "states": torch.tensor(np.array(states)),
        "neighbors": torch.tensor(np.array(neighbors)),
        "D": torch.tensor(np.array(step_counts)),
    }
    torch.save(dataset, filepath)
|
| 30 |
|
| 31 |
|
| 32 |
class Cube2Dataset(Dataset):
    """Map-style dataset over samples stored in a ``torch.save`` file.

    The file must contain a dict with the tensor entries ``"states"``,
    ``"neighbors"`` and ``"D"``, all sharing the same first-dimension length.
    Item ``idx`` is the tuple ``(states[idx], neighbors[idx], D[idx])``.
    """

    def __init__(self, filepath: str = "dataset.pt"):
        """Load the tensors from ``filepath``.

        Args:
            filepath: Path of the saved dataset dict.

        Raises:
            KeyError: If one of the required entries is missing from the file.
            ValueError: If the three entries do not have the same length.
        """
        super().__init__()
        # The file is pickle-backed; weights_only=True restricts unpickling to
        # tensors and plain containers, which is all this dataset needs.
        self.dataset = torch.load(filepath, weights_only=True)
        self.states = self.dataset["states"]
        self.neighbors = self.dataset["neighbors"]
        self.D = self.dataset["D"]

        # Fail at load time rather than with an IndexError deep inside a
        # training loop when the entries are misaligned.
        expected = len(self.states)
        if len(self.neighbors) != expected or len(self.D) != expected:
            raise ValueError(
                f"dataset entries differ in length: states={expected}, "
                f"neighbors={len(self.neighbors)}, D={len(self.D)}"
            )

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return ``(state, neighbors, D)`` for sample ``idx``."""
        return self.states[idx], self.neighbors[idx], self.D[idx]
|
rlcube/rlcube/train/train.py
CHANGED
|
@@ -1,6 +1,14 @@
|
|
| 1 |
from rlcube.models.dataset import Cube2Dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
if __name__ == "__main__":
|
| 5 |
-
|
| 6 |
-
print(dataset[10])
|
|
|
|
| 1 |
from rlcube.models.dataset import Cube2Dataset
|
| 2 |
+
from rlcube.envs.cube2 import Cube2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def train(epochs: int = 100):
    """Load the default ``Cube2Dataset`` and run an empty loop ``epochs`` times.

    This is a placeholder: the dataset is constructed with its default
    arguments but not consumed, and each iteration does nothing beyond
    advancing the ``tqdm`` progress bar.

    Args:
        epochs: Number of iterations of the (currently empty) loop.
    """
    samples = Cube2Dataset()  # noqa: F841 - loaded but not used yet
    epoch_bar = tqdm(range(epochs))
    for _epoch in epoch_bar:
        # No optimisation step is implemented yet.
        continue


# Script entry point: run the placeholder training loop with default settings.
if __name__ == "__main__":
    train()
|
|
|
rlcube/uv.lock
CHANGED
|
@@ -1536,6 +1536,7 @@ dependencies = [
|
|
| 1536 |
{ name = "numpy" },
|
| 1537 |
{ name = "tensordict" },
|
| 1538 |
{ name = "torch" },
|
|
|
|
| 1539 |
]
|
| 1540 |
|
| 1541 |
[package.metadata]
|
|
@@ -1547,6 +1548,7 @@ requires-dist = [
|
|
| 1547 |
{ name = "numpy", specifier = ">=2.3.2" },
|
| 1548 |
{ name = "tensordict", specifier = ">=0.10.0" },
|
| 1549 |
{ name = "torch", specifier = ">=2.8.0" },
|
|
|
|
| 1550 |
]
|
| 1551 |
|
| 1552 |
[[package]]
|
|
|
|
| 1536 |
{ name = "numpy" },
|
| 1537 |
{ name = "tensordict" },
|
| 1538 |
{ name = "torch" },
|
| 1539 |
+
{ name = "tqdm" },
|
| 1540 |
]
|
| 1541 |
|
| 1542 |
[package.metadata]
|
|
|
|
| 1548 |
{ name = "numpy", specifier = ">=2.3.2" },
|
| 1549 |
{ name = "tensordict", specifier = ">=0.10.0" },
|
| 1550 |
{ name = "torch", specifier = ">=2.8.0" },
|
| 1551 |
+
{ name = "tqdm", specifier = ">=4.67.1" },
|
| 1552 |
]
|
| 1553 |
|
| 1554 |
[[package]]
|