| """plot_prediction_planning_evaluation.py --load_from <wandb ID> --seed <seed> | |
| --scene_type <safer_fast or safer_slow> --risk_level <a list of risk-levels> | |
| --num_samples <a list of numbers of prediction samples> | |
| This script plots statistics of evaluation results generated by | |
| evaluate_prediction_planning_stack.py or evaluate_prediction_planning_stack_with_replanning.py. | |
| Add --with_replanning flag to plot results with re-planning, otherwise open-loop evaluations are | |
| used. | |
| """ | |
import argparse
import os
import pickle
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
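

# Top-level plotting entry point. For open-loop evaluations it plots computation-time
# scaling (when risk level 0.0 was evaluated) and the safety/efficiency tradeoff over
# risk levels; the risk-in-predictor vs. risk-in-planner comparison is plotted for
# both open-loop and re-planning evaluations.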
def plot_main(
    stats_dir: str,
    scene_type: str,
    risk_level_list: List[float],
    num_prediction_samples_list: List[int],
) -> None:
    if "with_replanning" not in stats_dir:
        if 0.0 in risk_level_list:
            plot_computation_time(
                stats_dir,
                scene_type,
                num_prediction_samples_list=num_prediction_samples_list,
            )
        plot_varying_risk(
            stats_dir,
            scene_type,
            num_prediction_samples_list=[num_prediction_samples_list[-1]],
            risk_level_list=risk_level_list,
            risk_in_planner=True,
        )
        plot_varying_risk(
            stats_dir,
            scene_type,
            num_prediction_samples_list=[num_prediction_samples_list[-1]],
            risk_level_list=risk_level_list,
            risk_in_planner=False,
        )
    plot_policy_comparison(
        stats_dir,
        scene_type,
        num_prediction_samples_list=num_prediction_samples_list,
        risk_level_list=list(filter(lambda r: r != 0.0, risk_level_list)),
    )
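

# The plotting functions below read pickle files written by the evaluation scripts.
# Based on the paths constructed in this script, the files are named
# "{scene_type}_{num_samples}_samples_risk_level_{risk_level}.pkl" for the
# risk-neutral case (risk level 0.0) and carry an additional "_in_planner" or
# "_in_predictor" suffix for nonzero risk levels, depending on where the
# risk-sensitivity is applied.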


# How does computation time scale as we increase the number of samples?
def plot_computation_time(
    stats_dir: str,
    scene_type: str,
    num_prediction_samples_list: List[int],
    alpha_for_confint: float = 0.95,
) -> None:
    risk_level = 0.0
    stats_dict_zero_risk = dict()
    computation_time_mean_list, computation_time_sem_list = [], []
    for num_samples in num_prediction_samples_list:
        file_path = os.path.join(
            stats_dir,
            f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}.pkl",
        )
        assert os.path.exists(
            file_path
        ), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
        with open(file_path, "rb") as f:
            stats_dict_zero_risk[num_samples] = pickle.load(f)
        num_episodes = _get_num_episodes(stats_dict_zero_risk[num_samples])
        computation_time_list = [
            stats_dict_zero_risk[num_samples][idx]["computation_time_ms"]
            for idx in range(num_episodes)
        ]
        computation_time_mean_list.append(np.mean(computation_time_list))
        computation_time_sem_list.append(st.sem(computation_time_list))
    # ref: https://www.statology.org/confidence-intervals-python/
    confint_lower, confint_upper = st.norm.interval(
        alpha=alpha_for_confint,
        loc=computation_time_mean_list,
        scale=computation_time_sem_list,
    )
    _, ax = plt.subplots(1, figsize=(6, 6))
    ax.plot(
        num_prediction_samples_list,
        computation_time_mean_list,
        color="skyblue",
        linewidth=2.0,
    )
    ax.fill_between(
        num_prediction_samples_list,
        confint_upper,
        confint_lower,
        facecolor="skyblue",
        alpha=0.3,
    )
    ax.set_xlabel("Number of Prediction Samples")
    ax.set_ylabel("Computation Time for Prediction and Planning (ms)")
    plt.show()
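

# The mean / standard-error / normal-approximation confidence-interval computation
# above is repeated for every metric in the functions below. The helper here is only
# an illustrative sketch of that pattern and is not called elsewhere in this script.
# Note that st.norm.interval's first argument is named `confidence` (formerly
# `alpha`) in SciPy >= 1.9; passing it positionally works with both names.
def _mean_sem_confint(values: List[float], alpha: float = 0.95):
    """Return (mean, SEM, confint_lower, confint_upper) for a list of values."""
    mean, sem = np.mean(values), st.sem(values)
    confint_lower, confint_upper = st.norm.interval(alpha, loc=mean, scale=sem)
    return mean, sem, confint_lower, confint_upper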


# How do varying risk-levels affect the safety/efficiency of the policy?
def plot_varying_risk(
    stats_dir: str,
    scene_type: str,
    num_prediction_samples_list: List[int],
    risk_level_list: List[float],
    risk_in_planner: bool = False,
    alpha_for_confint: float = 0.95,
) -> None:
    _, ax = plt.subplots(
        1,
        len(num_prediction_samples_list),
        figsize=(6 * len(num_prediction_samples_list), 6),
    )
    if not isinstance(ax, np.ndarray):
        ax = [ax]
    stats_dict = dict()
    suptitle = "Safety-Efficiency Tradeoff of Optimized Policy"
    if "with_replanning" in stats_dir:
        suptitle += " with Replanning"
    if risk_in_planner:
        suptitle += " (Risk in Planner)"
    else:
        suptitle += " (Risk in Predictor)"
    plt.suptitle(suptitle)
    for (plot_idx, num_samples) in enumerate(num_prediction_samples_list):
        stats_dict[num_samples] = dict()
        interaction_cost_mean_list, interaction_cost_sem_list = [], []
        tracking_cost_mean_list, tracking_cost_sem_list = [], []
        for risk_level in risk_level_list:
            if risk_level == 0.0:
                file_path = os.path.join(
                    stats_dir,
                    f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}.pkl",
                )
            elif risk_in_planner:
                file_path = os.path.join(
                    stats_dir,
                    f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_planner.pkl",
                )
            else:
                file_path = os.path.join(
                    stats_dir,
                    f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_predictor.pkl",
                )
            assert os.path.exists(
                file_path
            ), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
            with open(file_path, "rb") as f:
                stats_dict[num_samples][risk_level] = pickle.load(f)
            num_episodes = _get_num_episodes(stats_dict[num_samples][risk_level])
            interaction_cost_list = [
                stats_dict[num_samples][risk_level][idx][
                    "interaction_cost_ground_truth"
                ]
                for idx in range(num_episodes)
            ]
            interaction_cost_mean_list.append(np.mean(interaction_cost_list))
            interaction_cost_sem_list.append(st.sem(interaction_cost_list))
            tracking_cost_list = [
                stats_dict[num_samples][risk_level][idx]["tracking_cost"]
                for idx in range(num_episodes)
            ]
            tracking_cost_mean_list.append(np.mean(tracking_cost_list))
            tracking_cost_sem_list.append(st.sem(tracking_cost_list))
        (
            interaction_cost_confint_lower,
            interaction_cost_confint_upper,
        ) = st.norm.interval(
            alpha=alpha_for_confint,
            loc=interaction_cost_mean_list,
            scale=interaction_cost_sem_list,
        )
        (
            tracking_cost_confint_lower,
            tracking_cost_confint_upper,
        ) = st.norm.interval(
            alpha=alpha_for_confint,
            loc=tracking_cost_mean_list,
            scale=tracking_cost_sem_list,
        )
        ax[plot_idx].plot(
            risk_level_list,
            interaction_cost_mean_list,
            color="orange",
            linewidth=2.0,
            label="ground-truth collision cost",
        )
        ax[plot_idx].fill_between(
            risk_level_list,
            interaction_cost_confint_upper,
            interaction_cost_confint_lower,
            color="orange",
            alpha=0.3,
        )
        ax[plot_idx].plot(
            risk_level_list,
            tracking_cost_mean_list,
            color="lightgreen",
            linewidth=2.0,
            label="trajectory tracking cost",
        )
        ax[plot_idx].fill_between(
            risk_level_list,
            tracking_cost_confint_upper,
            tracking_cost_confint_lower,
            color="lightgreen",
            alpha=0.3,
        )
        if risk_in_planner:
            ax[plot_idx].set_xlabel("Risk-Sensitivity Level (in Planner)")
        else:
            ax[plot_idx].set_xlabel("Risk-Sensitivity Level (in Predictor)")
        ax[plot_idx].set_ylabel("Cost")
        ax[plot_idx].set_title(f"Number of Prediction Samples: {num_samples}")
        ax[plot_idx].legend(loc="upper right")
    plt.show()


# How does (risk-biased predictor + risk-neutral planner) compare with
# (risk-neutral predictor + risk-sensitive planner)
# in terms of characteristics of the optimized policy?
def plot_policy_comparison(
    stats_dir: str,
    scene_type: str,
    num_prediction_samples_list: List[int],
    risk_level_list: List[float],
    alpha_for_confint: float = 0.95,
) -> None:
    assert 0.0 not in risk_level_list
    num_rows = 2 if "with_replanning" in stats_dir else 4
    _, ax = plt.subplots(
        num_rows, len(risk_level_list), figsize=(6 * len(risk_level_list), 6 * num_rows)
    )
    if len(risk_level_list) == 1:
        for row_idx in range(num_rows):
            ax[row_idx] = [ax[row_idx]]
    suptitle = "Characteristics of Optimized Policy"
    if "with_replanning" in stats_dir:
        suptitle += " with Replanning"
    plt.suptitle(suptitle)
    predictor_stats_dict, planner_stats_dict = dict(), dict()
    for (plot_idx, risk_level) in enumerate(risk_level_list):
        predictor_stats_dict[risk_level], planner_stats_dict[risk_level] = (
            dict(),
            dict(),
        )
        predictor_interaction_cost_mean_list, planner_interaction_cost_mean_list = (
            [],
            [],
        )
        predictor_interaction_cost_sem_list, planner_interaction_cost_sem_list = [], []
        predictor_tracking_cost_mean_list, planner_tracking_cost_mean_list = [], []
        predictor_tracking_cost_sem_list, planner_tracking_cost_sem_list = [], []
        if "with_replanning" not in stats_dir:
            predictor_interaction_risk_mean_list, planner_interaction_risk_mean_list = (
                [],
                [],
            )
            predictor_interaction_risk_sem_list, planner_interaction_risk_sem_list = (
                [],
                [],
            )
            predictor_total_objective_mean_list, planner_total_objective_mean_list = (
                [],
                [],
            )
            predictor_total_objective_sem_list, planner_total_objective_sem_list = (
                [],
                [],
            )
        for num_samples in num_prediction_samples_list:
            file_path = os.path.join(
                stats_dir,
                f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_predictor.pkl",
            )
            assert os.path.exists(
                file_path
            ), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
            with open(file_path, "rb") as f:
                predictor_stats_dict[risk_level][num_samples] = pickle.load(f)
            predictor_num_episodes = _get_num_episodes(
                predictor_stats_dict[risk_level][num_samples]
            )
            predictor_interaction_cost_list = [
                predictor_stats_dict[risk_level][num_samples][idx][
                    "interaction_cost_ground_truth"
                ]
                for idx in range(predictor_num_episodes)
            ]
            predictor_interaction_cost_mean_list.append(
                np.mean(predictor_interaction_cost_list)
            )
            predictor_interaction_cost_sem_list.append(
                st.sem(predictor_interaction_cost_list)
            )
            predictor_tracking_cost_list = [
                predictor_stats_dict[risk_level][num_samples][idx]["tracking_cost"]
                for idx in range(predictor_num_episodes)
            ]
            predictor_tracking_cost_mean_list.append(
                np.mean(predictor_tracking_cost_list)
            )
            predictor_tracking_cost_sem_list.append(
                st.sem(predictor_tracking_cost_list)
            )
            if "with_replanning" not in stats_dir:
                predictor_interaction_risk_list = [
                    predictor_stats_dict[risk_level][num_samples][idx][
                        "interaction_risk"
                    ]
                    for idx in range(predictor_num_episodes)
                ]
                predictor_interaction_risk_mean_list.append(
                    np.mean(predictor_interaction_risk_list)
                )
                predictor_interaction_risk_sem_list.append(
                    st.sem(predictor_interaction_risk_list)
                )
                predictor_total_objective_list = [
                    interaction_risk + tracking_cost
                    for (interaction_risk, tracking_cost) in zip(
                        predictor_interaction_risk_list, predictor_tracking_cost_list
                    )
                ]
                predictor_total_objective_mean_list.append(
                    np.mean(predictor_total_objective_list)
                )
                predictor_total_objective_sem_list.append(
                    st.sem(predictor_total_objective_list)
                )
            file_path = os.path.join(
                stats_dir,
                f"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_planner.pkl",
            )
            assert os.path.exists(
                file_path
            ), f"missing experiment with num_samples == {num_samples} and risk_level == {risk_level}"
            with open(file_path, "rb") as f:
                planner_stats_dict[risk_level][num_samples] = pickle.load(f)
            planner_num_episodes = _get_num_episodes(
                planner_stats_dict[risk_level][num_samples]
            )
            planner_interaction_cost_list = [
                planner_stats_dict[risk_level][num_samples][idx][
                    "interaction_cost_ground_truth"
                ]
                for idx in range(planner_num_episodes)
            ]
            planner_interaction_cost_mean_list.append(
                np.mean(planner_interaction_cost_list)
            )
            planner_interaction_cost_sem_list.append(
                st.sem(planner_interaction_cost_list)
            )
            planner_tracking_cost_list = [
                planner_stats_dict[risk_level][num_samples][idx]["tracking_cost"]
                for idx in range(planner_num_episodes)
            ]
            planner_tracking_cost_mean_list.append(np.mean(planner_tracking_cost_list))
            planner_tracking_cost_sem_list.append(st.sem(planner_tracking_cost_list))
            if "with_replanning" not in stats_dir:
                planner_interaction_risk_list = [
                    planner_stats_dict[risk_level][num_samples][idx]["interaction_risk"]
                    for idx in range(planner_num_episodes)
                ]
                planner_interaction_risk_mean_list.append(
                    np.mean(planner_interaction_risk_list)
                )
                planner_interaction_risk_sem_list.append(
                    st.sem(planner_interaction_risk_list)
                )
                planner_total_objective_list = [
                    interaction_risk + tracking_cost
                    for (interaction_risk, tracking_cost) in zip(
                        planner_interaction_risk_list, planner_tracking_cost_list
                    )
                ]
                planner_total_objective_mean_list.append(
                    np.mean(planner_total_objective_list)
                )
                planner_total_objective_sem_list.append(
                    st.sem(planner_total_objective_list)
                )
        (
            predictor_interaction_cost_confint_lower,
            predictor_interaction_cost_confint_upper,
        ) = st.norm.interval(
            alpha=alpha_for_confint,
            loc=predictor_interaction_cost_mean_list,
            scale=predictor_interaction_cost_sem_list,
        )
        (
            predictor_tracking_cost_confint_lower,
            predictor_tracking_cost_confint_upper,
        ) = st.norm.interval(
            alpha=alpha_for_confint,
            loc=predictor_tracking_cost_mean_list,
            scale=predictor_tracking_cost_sem_list,
        )
        if "with_replanning" not in stats_dir:
            (
                predictor_interaction_risk_confint_lower,
                predictor_interaction_risk_confint_upper,
            ) = st.norm.interval(
                alpha=alpha_for_confint,
                loc=predictor_interaction_risk_mean_list,
                scale=predictor_interaction_risk_sem_list,
            )
            (
                predictor_total_objective_confint_lower,
                predictor_total_objective_confint_upper,
            ) = st.norm.interval(
                alpha=alpha_for_confint,
                loc=predictor_total_objective_mean_list,
                scale=predictor_total_objective_sem_list,
            )
        (
            planner_interaction_cost_confint_lower,
            planner_interaction_cost_confint_upper,
        ) = st.norm.interval(
            alpha=alpha_for_confint,
            loc=planner_interaction_cost_mean_list,
            scale=planner_interaction_cost_sem_list,
        )
        (
            planner_tracking_cost_confint_lower,
            planner_tracking_cost_confint_upper,
        ) = st.norm.interval(
            alpha=alpha_for_confint,
            loc=planner_tracking_cost_mean_list,
            scale=planner_tracking_cost_sem_list,
        )
        if "with_replanning" not in stats_dir:
            (
                planner_interaction_risk_confint_lower,
                planner_interaction_risk_confint_upper,
            ) = st.norm.interval(
                alpha=alpha_for_confint,
                loc=planner_interaction_risk_mean_list,
                scale=planner_interaction_risk_sem_list,
            )
            (
                planner_total_objective_confint_lower,
                planner_total_objective_confint_upper,
            ) = st.norm.interval(
                alpha=alpha_for_confint,
                loc=planner_total_objective_mean_list,
                scale=planner_total_objective_sem_list,
            )
        ax[0][plot_idx].plot(
            num_prediction_samples_list,
            planner_interaction_cost_mean_list,
            color="skyblue",
            linewidth=2.0,
            label="risk in planner",
        )
        ax[0][plot_idx].fill_between(
            num_prediction_samples_list,
            planner_interaction_cost_confint_upper,
            planner_interaction_cost_confint_lower,
            color="skyblue",
            alpha=0.3,
        )
        ax[0][plot_idx].plot(
            num_prediction_samples_list,
            predictor_interaction_cost_mean_list,
            color="orange",
            linewidth=2.0,
            label="risk in predictor",
        )
        ax[0][plot_idx].fill_between(
            num_prediction_samples_list,
            predictor_interaction_cost_confint_upper,
            predictor_interaction_cost_confint_lower,
            color="orange",
            alpha=0.3,
        )
        ax[0][plot_idx].set_xlabel("Number of Prediction Samples")
        ax[0][plot_idx].set_ylabel("Ground-Truth Collision Cost")
        ax[0][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
        ax[0][plot_idx].legend(loc="upper right")
        ax[0][plot_idx].set_xscale("log")
        ax[1][plot_idx].plot(
            num_prediction_samples_list,
            planner_tracking_cost_mean_list,
            color="skyblue",
            linewidth=2.0,
            label="risk in planner",
        )
        ax[1][plot_idx].fill_between(
            num_prediction_samples_list,
            planner_tracking_cost_confint_upper,
            planner_tracking_cost_confint_lower,
            color="skyblue",
            alpha=0.3,
        )
        ax[1][plot_idx].plot(
            num_prediction_samples_list,
            predictor_tracking_cost_mean_list,
            color="orange",
            linewidth=2.0,
            label="risk in predictor",
        )
        ax[1][plot_idx].fill_between(
            num_prediction_samples_list,
            predictor_tracking_cost_confint_upper,
            predictor_tracking_cost_confint_lower,
            color="orange",
            alpha=0.3,
        )
        ax[1][plot_idx].set_xlabel("Number of Prediction Samples")
        ax[1][plot_idx].set_ylabel("Trajectory Tracking Cost")
        # ax[1][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
        ax[1][plot_idx].legend(loc="lower right")
        ax[1][plot_idx].set_xscale("log")
        if "with_replanning" not in stats_dir:
            ax[2][plot_idx].plot(
                num_prediction_samples_list,
                planner_interaction_risk_mean_list,
                color="skyblue",
                linewidth=2.0,
                label="risk in planner",
            )
            ax[2][plot_idx].fill_between(
                num_prediction_samples_list,
                planner_interaction_risk_confint_upper,
                planner_interaction_risk_confint_lower,
                color="skyblue",
                alpha=0.3,
            )
            ax[2][plot_idx].plot(
                num_prediction_samples_list,
                predictor_interaction_risk_mean_list,
                color="orange",
                linewidth=2.0,
                label="risk in predictor",
            )
            ax[2][plot_idx].fill_between(
                num_prediction_samples_list,
                predictor_interaction_risk_confint_upper,
                predictor_interaction_risk_confint_lower,
                color="orange",
                alpha=0.3,
            )
            ax[2][plot_idx].set_xlabel("Number of Prediction Samples")
            ax[2][plot_idx].set_ylabel("Collision Risk")
            # ax[2][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
            ax[2][plot_idx].legend(loc="upper right")
            ax[2][plot_idx].set_xscale("log")
            ax[3][plot_idx].plot(
                num_prediction_samples_list,
                planner_total_objective_mean_list,
                color="skyblue",
                linewidth=2.0,
                label="risk in planner",
            )
            ax[3][plot_idx].fill_between(
                num_prediction_samples_list,
                planner_total_objective_confint_upper,
                planner_total_objective_confint_lower,
                color="skyblue",
                alpha=0.3,
            )
            ax[3][plot_idx].plot(
                num_prediction_samples_list,
                predictor_total_objective_mean_list,
                color="orange",
                linewidth=2.0,
                label="risk in predictor",
            )
            ax[3][plot_idx].fill_between(
                num_prediction_samples_list,
                predictor_total_objective_confint_upper,
                predictor_total_objective_confint_lower,
                color="orange",
                alpha=0.3,
            )
            ax[3][plot_idx].set_xlabel("Number of Prediction Samples")
            ax[3][plot_idx].set_ylabel("Planner's Total Objective")
            # ax[3][plot_idx].set_title(f"Risk-Sensitivity Level: {risk_level}")
            ax[3][plot_idx].legend(loc="upper right")
            ax[3][plot_idx].set_xscale("log")
    plt.show()


def _get_num_episodes(stats_dict: dict) -> int:
    # Episode results are stored under integer keys 0..N-1; other keys (if any) are ignored.
    return max(filter(lambda key: isinstance(key, int), stats_dict)) + 1
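

# Note on the expected pickle contents (inferred from the accesses in this script,
# not from the evaluation scripts themselves): each stats pickle maps integer episode
# indices to per-episode metric dicts with keys such as "computation_time_ms",
# "tracking_cost", "interaction_cost_ground_truth" and, for open-loop evaluations,
# "interaction_risk"; the planner's total objective plotted above is the sum of
# "interaction_risk" and "tracking_cost".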


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Visualize evaluation results of evaluate_prediction_planning_stack.py "
        "or evaluate_prediction_planning_stack_with_replanning.py"
    )
    parser.add_argument(
        "--load_from",
        type=str,
        required=True,
        help="WandB run ID specifying the trained predictor",
    )
    parser.add_argument(
        "--seed",
        type=int,
        required=False,
        default=0,
        help="Random seed identifying the evaluation run to load",
    )
    parser.add_argument(
        "--scene_type",
        type=str,
        choices=["safer_fast", "safer_slow"],
        required=True,
        help="Scene type of the evaluation results to plot",
    )
    parser.add_argument(
        "--with_replanning",
        action="store_true",
        help="Plot results of closed-loop evaluations with re-planning",
    )
    parser.add_argument(
        "--risk_level",
        type=float,
        nargs="+",
        help="Risk-sensitivity level(s) to plot",
        default=[0.95, 1.0],
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        nargs="+",
        help="Number(s) of prediction samples to plot",
        default=[1, 4, 16, 64, 256, 1024],
    )
    # Note: --force_config is not referenced elsewhere in this script.
    parser.add_argument(
        "--force_config",
        action="store_true",
        help="""Use this flag to force the use of the local config file
        when loading a model from a checkpoint. Otherwise the checkpoint config file is used.
        In any case the parameters can be overwritten with an argparse argument.""",
    )
    args = parser.parse_args()

    dir_name = (
        "planner_eval_with_replanning" if args.with_replanning else "planner_eval"
    )
    stats_dir = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "logs",
        dir_name,
        f"run-{args.load_from}_{args.seed}",
    )
    postfix_string = "_with_replanning" if args.with_replanning else ""
    assert os.path.exists(
        stats_dir
    ), f"{stats_dir} does not exist. Did you run 'evaluate_prediction_planning_stack{postfix_string}.py --load_from {args.load_from} --seed {args.seed}'?"
    plot_main(stats_dir, args.scene_type, args.risk_level, args.num_samples)