Jensin committed
Commit bff785d · verified · 0 Parent(s)

initial commit

Files changed (4)
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +84 -0
  4. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Drone Eval
+ emoji: 🦀
+ colorFrom: red
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 5.42.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,84 @@
+ # This Gradio app provides a simple interface to evaluate a trained drone hovering policy.
+ # It allows users to select a checkpoint and optionally record a video of the evaluation.
+ # Note: the simulator and RL packages imported below (including torch, on which rsl_rl depends)
+ # must be available in the Space at runtime; they are not listed in requirements.txt.
+
+ import gradio as gr
+ import os
+ import pickle
+
+ import genesis as gs  # Genesis physics simulator (assumed available at runtime)
+ from rsl_rl.runners import OnPolicyRunner  # PPO runner used during training (assumed available)
+ from hover_env import HoverEnv  # drone hovering environment from the training code (assumed available)
+
+ # evaluate_policy mirrors the evaluation script from the training code: it runs the trained
+ # policy in a single environment with visualization.
+
+ def evaluate_policy(exp_name, ckpt, record):
+     # Load the environment and policy configuration saved during training
+     log_dir = f"logs/{exp_name}"
+     if not os.path.exists(log_dir):
+         raise FileNotFoundError(f"Log directory '{log_dir}' does not exist. Did you run training?")
+     with open(f"{log_dir}/cfgs.pkl", "rb") as f:
+         env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(f)
+
+     # For evaluation, disable reward scaling (pure inference)
+     reward_cfg["reward_scales"] = {}
+
+     # Always visualize the target during evaluation
+     env_cfg["visualize_target"] = True
+
+     # Optionally set up a camera for recording
+     env_cfg["visualize_camera"] = record
+     env_cfg["max_visualize_FPS"] = 60
+
+     # Build a single-environment instance with viewer
+     # (Genesis itself is assumed to be initialized by HoverEnv or earlier in the process.)
+     env = HoverEnv(num_envs=1, env_cfg=env_cfg, obs_cfg=obs_cfg, reward_cfg=reward_cfg, command_cfg=command_cfg, show_viewer=True)
+     runner = OnPolicyRunner(env, train_cfg, log_dir, device=gs.device)
+
+     # Load the specified checkpoint
+     resume_path = os.path.join(log_dir, f"model_{ckpt}.pt")
+     runner.load(resume_path)
+
+     # Get the inference policy
+     policy = runner.get_inference_policy(device=gs.device)
+
+     # Reset the environment
+     obs, _ = env.reset()
+
+     # Number of simulation steps equals the episode duration times the visualization FPS
+     max_sim_step = int(env_cfg["episode_length_s"] * env_cfg["max_visualize_FPS"])
+
+     if record and env.cam is not None:
+         env.cam.start_recording()
+         for _ in range(max_sim_step):
+             actions = policy(obs)
+             obs, rews, dones, infos = env.step(actions)
+             env.cam.render()
+         env.cam.stop_recording(save_to_filename="video.mp4", fps=env_cfg["max_visualize_FPS"])
+     else:
+         for _ in range(max_sim_step):
+             actions = policy(obs)
+             obs, rews, dones, infos = env.step(actions)
+
+     return "Evaluation completed successfully."
+
+ # Create the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Drone Hovering Policy Evaluation")
+
+     with gr.Row():
+         exp_name = gr.Textbox(label="Experiment Name", placeholder="drone-hovering")
+         ckpt = gr.Number(label="Checkpoint Index", value=300, precision=0)
+         record = gr.Checkbox(label="Record Video", value=False)
+
+     with gr.Row():
+         evaluate_btn = gr.Button("Evaluate Policy")
+
+     output = gr.Textbox(label="Evaluation Status")
+
+     evaluate_btn.click(fn=evaluate_policy, inputs=[exp_name, ckpt, record], outputs=output)
+
+ # Launch the interface
+ demo.launch(show_error=True)
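
For reference, once the app is running (locally or as a deployed Space), the evaluation handler can also be exercised programmatically through gradio_client, which is installed alongside Gradio. The sketch below is illustrative only: the URL is a placeholder, and "/evaluate_policy" is Gradio's default endpoint name derived from the function name; verify both against client.view_api() for the actual deployment.

    # Hypothetical client-side smoke test; the URL and endpoint name are assumptions.
    from gradio_client import Client

    client = Client("http://127.0.0.1:7860")  # or the public Space URL
    status = client.predict(
        "drone-hovering",   # exp_name
        300,                # ckpt
        False,              # record
        api_name="/evaluate_policy",
    )
    print(status)  # expected: "Evaluation completed successfully."

The call blocks until the whole evaluation loop has finished, so it returns only after the episode completes.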
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ numpy
+ pandas
+ plotly
+ transformers_js_py
+ matplotlib