imwithye commited on
Commit
ee0d368
·
1 Parent(s): 5f0be19

fix frontend

Browse files
Files changed (1) hide show
  1. rlcube/main.py +7 -9
rlcube/main.py CHANGED
@@ -2,7 +2,8 @@ from typing import List
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from fastapi import HTTPException
5
- from rlcube.envs.cube2 import Cube2
 
6
  import numpy as np
7
 
8
  app = FastAPI()
@@ -22,12 +23,9 @@ def solve(body: StateArgs):
22
  ):
23
  raise HTTPException(status_code=400, detail="state must be a 6x4 matrix")
24
 
25
- env = Cube2()
26
  env.reset(state=np.array(state, dtype=np.int8))
27
-
28
- steps = []
29
- for _ in range(10):
30
- action = env.action_space.sample()
31
- steps.append(action.item())
32
-
33
- return {"steps": steps}
 
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from fastapi import HTTPException
5
+ from rlcube.envs.cube2 import Cube2Env
6
+ from rlcube.models.search import MonteCarloTree
7
  import numpy as np
8
 
9
  app = FastAPI()
 
23
  ):
24
  raise HTTPException(status_code=400, detail="state must be a 6x4 matrix")
25
 
26
+ env = Cube2Env()
27
  env.reset(state=np.array(state, dtype=np.int8))
28
+ tree = MonteCarloTree(env.obs(), max_simulations=300)
29
+ if tree.is_solved:
30
+ return {"steps": [action for _, action in tree.solved_path]}
31
+ raise HTTPException(status_code=400, detail="Unable to solve the cube")