Spaces:
Runtime error
Runtime error
| from datasets import load_dataset, Dataset | |
| import fire | |
| from functools import partial, update_wrapper | |
| import numpy | |
| import os | |
| from typing import Dict, Iterable, Tuple | |
| import sys | |
| import time | |
| import torch | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from mmcv import Config | |
| import plotly.graph_objects as go | |
| from torch.utils.data.dataloader import DataLoader | |
| from risk_biased.utils.load_model import get_predictor | |
| from risk_biased.utils.torch_utils import load_weights | |
| from risk_biased.utils.waymo_dataloader import WaymoDataloaders | |
| from risk_biased.predictors.biased_predictor import ( | |
| LitTrajectoryPredictor, | |
| ) | |
def to_numpy(**kwargs):
    """Convert every keyword tensor argument to a NumPy array.

    Each value is detached from the autograd graph and copied to host memory
    before the conversion.

    Returns:
        A dict mapping each keyword name to the corresponding NumPy array.
    """
    return {name: tensor.detach().cpu().numpy() for name, tensor in kwargs.items()}
def get_scatter_data(x, mask_x, name, **kwargs):
    """Build one ``go.Scatter`` trace per row of ``x``.

    Row ``k`` is drawn from ``x[k, mask_x[k], 0]`` (horizontal coordinates)
    and ``x[k, mask_x[k], 1]`` (vertical coordinates), so only the points
    selected by ``mask_x[k]`` appear. Every trace carries ``name``, but only
    the first one is shown in the legend so the group is listed once. Extra
    keyword arguments are forwarded unchanged to every ``go.Scatter``.

    Returns:
        A list of ``x.shape[0]`` scatter traces.
    """
    traces = []
    for row in range(x.shape[0]):
        traces.append(
            go.Scatter(
                x=x[row, mask_x[row], 0],
                y=x[row, mask_x[row], 1],
                showlegend=row == 0,
                name=name,
                **kwargs,
            )
        )
    return traces
def configuration_paths() -> Iterable[os.PathLike]:
    """Return the paths of the learning and Waymo configuration files.

    The paths are built relative to the symlink-resolved directory of this
    file, under ``../../risk_biased/config``; they are not checked for
    existence.
    """
    config_dir = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "../../risk_biased/config",
    )
    file_names = ("learning_config.py", "waymo_config.py")
    return [os.path.join(config_dir, file_name) for file_name in file_names]
def load_item(index: int, dataset: Dataset, device: str = "cpu") -> Tuple:
    """Load one dataset sample as tensors on ``device``.

    Trajectory, map and offset fields are cast to ``float32``; mask fields are
    cast to ``bool``. Each field is converted with ``torch.from_numpy`` and then
    moved to ``device``.

    Args:
        index: Position of the sample in ``dataset``. It is passed through
            ``int`` so integral floats (e.g. values from a UI slider) work.
        dataset: Indexable collection whose items are mappings with the keys
            ``x``, ``mask_x``, ``y``, ``mask_y``, ``mask_loss``, ``map_data``,
            ``mask_map``, ``offset``, ``x_ego`` and ``y_ego``.
        device: Device the tensors are moved to.

    Returns:
        ``(batch, y, mask_y, mask_loss)``, where ``batch`` is
        ``(x, mask_x, map_data, mask_map, offset, x_ego, y_ego)``.
    """
    # Fetch the row once instead of re-indexing the dataset for every field.
    item = dataset[int(index)]

    def as_tensor(key: str, dtype) -> torch.Tensor:
        return torch.from_numpy(numpy.array(item[key]).astype(dtype)).to(device)

    # Masks use ``numpy.bool_``: the ``numpy.bool8`` alias was deprecated in
    # NumPy 1.24 and removed in NumPy 2.0.
    x = as_tensor("x", numpy.float32)
    mask_x = as_tensor("mask_x", numpy.bool_)
    y = as_tensor("y", numpy.float32)
    mask_y = as_tensor("mask_y", numpy.bool_)
    mask_loss = as_tensor("mask_loss", numpy.bool_)
    map_data = as_tensor("map_data", numpy.float32)
    mask_map = as_tensor("mask_map", numpy.bool_)
    offset = as_tensor("offset", numpy.float32)
    x_ego = as_tensor("x_ego", numpy.float32)
    y_ego = as_tensor("y_ego", numpy.float32)
    return (x, mask_x, map_data, mask_map, offset, x_ego, y_ego), y, mask_y, mask_loss
# Marker size shared by the static markers and the animation frames.
_MARKER_SIZE = 12
# Opacity of the moving markers in the animation frames.
_ANIMATION_OPACITY = 0.5


def _static_traces(x, mask_x, y, mask_y, pred, mask_pred, map_data, mask_map, n_samples):
    """Build the non-animated traces (map, past, ground truth, forecasts, markers)."""
    data_map = get_scatter_data(
        map_data,
        mask_map,
        mode="lines",
        line=dict(width=15, color="gray"),
        opacity=0.3,
        name="Centerline",
    )
    data_x = get_scatter_data(
        x,
        mask_x,
        mode="lines",
        line=dict(width=2, color="black"),
        name="Past",
    )
    data_y = get_scatter_data(
        y,
        mask_y,
        mode="lines",
        line=dict(width=2, color="green"),
        name="Ground truth",
    )
    # Row 0 is drawn as the ego (blue) and row 1 as the predicted agent (green),
    # each at its last past time step.
    ego_present = get_scatter_data(
        x=x[0:1, -1:],
        mask_x=mask_x[0:1, -1:],
        mode="markers",
        marker=dict(color="blue", size=_MARKER_SIZE, opacity=0.5),
        name="Ego",
    )
    agent_present = get_scatter_data(
        x=x[1:2, -1:],
        mask_x=mask_x[1:2, -1:],
        mode="markers",
        marker=dict(color="green", size=_MARKER_SIZE, opacity=0.5),
        name="Agent",
    )
    # ``pred`` is indexed as (agent, sample, time, coordinate).
    data_pred = []
    forecasts_end = []
    for i in range(n_samples):
        data_pred += get_scatter_data(
            pred[:, i],
            mask_pred,
            mode="lines",
            line=dict(width=2, color="red"),
            name="Forecast",
        )
        forecasts_end += get_scatter_data(
            pred[:, i, -1:],
            mask_pred[:, -1:],
            mode="markers",
            marker=dict(color="red", size=_MARKER_SIZE / 2, opacity=0.5, symbol="x"),
            name="Forecast end",
        )
    return data_map + data_x + data_y + data_pred + ego_present + agent_present + forecasts_end


def _past_frames(x, mask_x):
    """Build one frame per past time step: all valid agents in black, ego in blue."""
    frames = []
    for k in range(x.shape[1]):
        frames.append(
            go.Frame(
                data=[
                    go.Scatter(
                        x=x[mask_x[:, k], k, 0],
                        y=x[mask_x[:, k], k, 1],
                        mode="markers",
                        opacity=_ANIMATION_OPACITY,
                        marker=dict(color="black", size=_MARKER_SIZE),
                        showlegend=False,
                    ),
                    go.Scatter(
                        x=x[0:1, k, 0],
                        y=x[0:1, k, 1],
                        mode="markers",
                        opacity=_ANIMATION_OPACITY,
                        marker=dict(color="blue", size=_MARKER_SIZE),
                        showlegend=False,
                    ),
                ]
            )
        )
    return frames


def _future_frames(y, mask_y, pred, mask_pred, n_samples):
    """Build one frame per future time step: agent, other agents, forecasts, ego."""
    frames = []
    for k in range(y.shape[1]):
        gt_agent = go.Scatter(
            x=y[1:2][mask_y[1:2, k], k, 0],
            y=y[1:2][mask_y[1:2, k], k, 1],
            mode="markers",
            opacity=_ANIMATION_OPACITY,
            marker=dict(color="green", size=_MARKER_SIZE),
        )
        gt_others = go.Scatter(
            x=y[2:][mask_y[2:, k], k, 0],
            y=y[2:][mask_y[2:, k], k, 1],
            mode="markers",
            opacity=_ANIMATION_OPACITY,
            marker=dict(color="black", size=_MARKER_SIZE),
        )
        forecasts = [
            go.Scatter(
                x=pred[mask_pred[:, k], i, k, 0],
                y=pred[mask_pred[:, k], i, k, 1],
                mode="markers",
                opacity=_ANIMATION_OPACITY,
                marker=dict(color="red", size=_MARKER_SIZE),
                showlegend=False,
            )
            for i in range(n_samples)
        ]
        ego = go.Scatter(
            x=y[0:1, k, 0],
            y=y[0:1, k, 1],
            mode="markers",
            opacity=_ANIMATION_OPACITY,
            marker=dict(color="blue", size=_MARKER_SIZE),
        )
        frames.append(go.Frame(data=[gt_agent, gt_others, *forecasts, ego]))
    return frames


def build_data(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> Dict[str, list]:
    """Run the predictor on one dataset sample and build Plotly traces and frames.

    Args:
        predictor: Trajectory predictor; the sample is loaded on ``predictor.device``.
        dataset: Dataset the sample is read from.
        index: Index of the sample in ``dataset`` (coerced with ``int``).
        risk_level: Forwarded unchanged to ``predictor.predict_step``.
        n_samples: Number of forecast samples to draw (coerced with ``int``).

    Returns:
        A dict with ``"data"`` (the static traces) and ``"frames"`` (one frame per
        past time step, then one per future time step), ready for
        ``go.Figure(**result)``.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    # UI sliders may deliver numbers as floats; range() and indexing need ints.
    index = int(index)
    n_samples = int(n_samples)
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    batch, y, mask_y, mask_loss = load_item(index, dataset, predictor.device)
    predictions = predictor.predict_step(
        batch=batch,
        risk_level=risk_level,
        n_samples=n_samples,
    )
    # batch is (x, mask_x, map_data, mask_map, offset, x_ego, y_ego).
    offset = batch[4]
    y = predictor._unnormalize_trajectory(y, offset)
    x = predictor._unnormalize_trajectory(batch[0], offset)
    numpy_data = to_numpy(
        predictions=predictions,
        y=y,
        mask_y=mask_y,
        x=x,
        mask_x=batch[1],
        map_data=batch[2],
        mask_map=batch[3],
        mask_pred=mask_loss,
    )
    # Keep only the first element of the leading (batch) dimension.
    sample = {key: value[0] for key, value in numpy_data.items()}
    x, mask_x = sample["x"], sample["mask_x"]
    y, mask_y = sample["y"], sample["mask_y"]
    pred, mask_pred = sample["predictions"], sample["mask_pred"]
    map_data, mask_map = sample["map_data"], sample["mask_map"]

    static_data = _static_traces(
        x, mask_x, y, mask_y, pred, mask_pred, map_data, mask_map, n_samples
    )
    # NOTE(review): the frames pass no ``traces=`` indices, so Plotly applies each
    # frame's traces to the figure's first traces (the map centerlines here).
    # Confirm this is intended.
    frames = _past_frames(x, mask_x) + _future_frames(y, mask_y, pred, mask_pred, n_samples)
    return {"frames": frames, "data": static_data}
def prediction_plot(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int = 1,
    use_biaser: bool = True,
) -> go.Figure:
    """Build the animated road-scene figure for one dataset sample.

    Args:
        predictor: Trajectory predictor used to draw the forecasts.
        dataset: Dataset the sample is read from.
        index: Index of the sample in ``dataset``.
        risk_level: Risk level, converted to ``float`` when ``use_biaser`` is
            truthy; replaced by ``None`` otherwise.
        n_samples: Number of forecast samples to draw.
        use_biaser: Whether to forward ``risk_level`` to the predictor.

    Returns:
        A Plotly figure with fixed axis ranges and Play/Pause animation buttons.
    """
    range_radius = 80
    effective_risk = float(risk_level) if use_biaser else None

    play_button = dict(
        label="Play",
        method="animate",
        args=[
            None,
            dict(
                frame=dict(duration=100, redraw=False),
                mode="immediate",
                fromcurrent=True,
            ),
        ],
    )
    pause_button = dict(
        label="Pause",
        method="animate",
        args=[
            [None],
            {
                "frame": {"duration": 0, "redraw": False},
                "mode": "immediate",
                "transition": {"duration": 0},
            },
        ],
    )
    layout = go.Layout(
        xaxis=dict(
            range=[-0.5 * range_radius, 1.5 * range_radius],
            autorange=False,
            zeroline=False,
        ),
        yaxis=dict(
            range=[-range_radius, range_radius],
            autorange=False,
            zeroline=False,
        ),
        title_text="Road Scene",
        hovermode="closest",
        width=800,
        height=600,
        updatemenus=[dict(type="buttons", buttons=[play_button, pause_button])],
    )

    figure = go.Figure(
        **build_data(predictor, dataset, index, effective_risk, n_samples),
        layout=layout,
    )
    figure.update_geos(projection_type="equirectangular", visible=True, resolution=110)
    return figure
def get_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> go.Figure:
    """Return the prediction figure for sample ``index`` with the risk biaser enabled."""
    figure = prediction_plot(
        predictor,
        dataset,
        index,
        risk_level,
        n_samples,
        use_biaser=True,
    )
    # update_layout() with no arguments changes nothing and returns the figure itself.
    return figure.update_layout()
def update_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
    image = None
) -> go.Figure:
    """Rebuild the prediction figure from the current UI values.

    ``image`` receives the currently displayed plot (it is wired as a UI input)
    but is ignored: the figure is always rebuilt from scratch, exactly as
    ``get_figure`` does.
    """
    return get_figure(predictor, dataset, index, risk_level, n_samples)
def load_predictor_from_hf(
    model_source: str = "TRI-ML/risk_biased_model",
    config_name: str = "learning_config.py",
    checkpoint_name: str = "last.ckpt",
    device: str = "cpu",
) -> LitTrajectoryPredictor:
    """Download a config and a checkpoint from the Hugging Face Hub and build the predictor.

    Args:
        model_source: Hub repository id holding the config and the checkpoint.
        config_name: File name of the mmcv config inside the repository.
        checkpoint_name: File name of the checkpoint inside the repository.
        device: Device the predictor is moved to after loading.

    Returns:
        The predictor with loaded weights, in eval mode, on ``device``.
    """
    # Optional access token from the SECRET_AUTH_TOKEN environment variable (None if unset).
    # NOTE(review): newer huggingface_hub releases deprecate ``use_auth_token`` in
    # favour of ``token``. Verify against the pinned version.
    auth_token = os.getenv("SECRET_AUTH_TOKEN")
    config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=auth_token)
    checkpoint_file = hf_hub_download(
        model_source, filename=checkpoint_name, use_auth_token=auth_token
    )
    # NOTE(security): torch.load unpickles the file, so only load trusted checkpoints.
    # The weights are mapped to CPU first and moved to ``device`` at the end.
    ckpt = torch.load(checkpoint_file, map_location="cpu")
    cfg = Config.fromfile(config_file)
    predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
    predictor = load_weights(predictor, ckpt)
    predictor.eval()
    return predictor.to(device)
def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
    """Load and return the ``test`` split of the Hugging Face dataset ``data_source``."""
    return load_dataset(data_source, split="test")
_DEMO_DESCRIPTION = """
# Risk-Aware Prediction
Make predictions for the green agent with a risk-seeking bias towards the ego vehicle in blue.
The risk level is a value between 0 and 1, where 0 is not risk-seeking and 1 is the most risk-seeking.
Once the sliders are set, click the "Run" button to see the predictions.
The play button will animate the prediction over time (it is slow especially with many samples).
For more information, see the paper [RAP: Risk-Aware Prediction for Robust Planning](https://arxiv.org/abs/2210.01368) published at [CoRL 2022](https://corl2022.org/).
"""


def main(load_from=None, cfg_path=None):
    """Launch the Gradio demo for risk-aware trajectory prediction.

    Args:
        load_from: Optional path to a local checkpoint. When it is given,
            ``cfg_path`` must point to the matching mmcv config. When it is
            omitted, the model is downloaded from the Hugging Face Hub.
        cfg_path: Path to the config used together with ``load_from``.

    Raises:
        ValueError: If ``load_from`` is given without ``cfg_path``.
    """
    # Validate arguments before any download.
    if load_from is not None and cfg_path is None:
        raise ValueError("cfg_path is required when load_from is given")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Getting dataset")
    dataset = load_dataset_from_hf()
    if load_from is not None:
        cfg = Config.fromfile(cfg_path)
        predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
        predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
        # Same post-processing as the Hub path: inference mode, then the target device.
        predictor.eval()
        predictor = predictor.to(device)
    else:
        print("Getting model.")
        predictor = load_predictor_from_hf(device=device)

    ui_update_fn = partial(update_figure, predictor, dataset)
    with gr.Blocks() as interface:
        gr.Markdown(_DEMO_DESCRIPTION)
        initial_index = 27
        initial_n_samples = 10
        image = gr.Plot(get_figure(predictor, dataset, initial_index, 0, initial_n_samples))
        interface.queue()
        index = gr.Slider(
            minimum=0,
            maximum=len(dataset) - 1,
            step=1,
            value=initial_index,
            label="Index",
        )
        risk_level = gr.Slider(minimum=0, maximum=1, step=0.01, label="Risk")
        n_samples = gr.Slider(
            minimum=1,
            maximum=20,
            step=1,
            value=initial_n_samples,
            label="Number of prediction samples",
        )
        # A button's caption is its ``value``. ``label`` does not set it, and
        # newer Gradio releases may reject it outright.
        button = gr.Button(value="Run")
        # Predictions run only on an explicit click. Live ``.change`` handlers ran on
        # the first slider change and ignored changes made during the computation,
        # which left the plot out of sync with the sliders.
        button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
    interface.launch(debug=False)
if __name__ == "__main__":
    # Expose main()'s keyword arguments (load_from, cfg_path) on the command line.
    fire.Fire(component=main)