Spaces:
Runtime error
Runtime error
| import copy | |
| from dataclasses import dataclass | |
| from mmcv import Config | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from pydantic import NoneBytes | |
| import pytorch_lightning as pl | |
| import torch | |
| import wandb | |
| from risk_biased.scene_dataset.loaders import SceneDataLoaders | |
| from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams | |
| from risk_biased.scene_dataset.scene_plotter import ScenePlotter | |
| from risk_biased.utils.cost import ( | |
| DistanceCostNumpy, | |
| DistanceCostParams, | |
| TTCCostNumpy, | |
| TTCCostParams, | |
| ) | |
| from risk_biased.utils.risk import get_risk_level_sampler | |
| class SwitchTrainingModeCallback(pl.Callback): | |
| """ | |
| This callback switches between CVAE traning and biasing training for the biased_latent_cvae_model | |
| Args: | |
| switch_at_epoch: The number of epoch after which to make the switch. The CVAE is not trained anymore after that point. | |
| """ | |
| def __init__(self, switch_at_epoch: int) -> None: | |
| super().__init__() | |
| self._switch_at_epoch = switch_at_epoch | |
| self._train_has_started = False | |
| def on_train_start( | |
| self, trainer: pl.Trainer, pl_module: pl.LightningModule | |
| ) -> None: | |
| """Store the optimizer list and set the trainer to the first optimizer.""" | |
| self._optimizers = trainer.optimizers | |
| trainer.optimizers = [self._optimizers[0]] | |
| self._train_has_started = True | |
| def on_epoch_start( | |
| self, trainer: pl.Trainer, pl_module: pl.LightningModule | |
| ) -> None: | |
| """ | |
| Check if the switch should be made and if so, | |
| set the trainer on the second optimizer. | |
| """ | |
| if trainer.current_epoch == self._switch_at_epoch and self._train_has_started: | |
| print("Switching to bias training.") | |
| pl_module.set_training_mode("bias") | |
| trainer.optimizers = [self._optimizers[1]] | |
| def get_fast_slow_scenes(params: RandomSceneParams, n_samples: int): | |
| """Define and return two RandomScene objects, one initialized such that slow | |
| pedestrians are safer and the other such that fast pedestrians are safer. | |
| Args: | |
| params: dataclass containing the necessary parameters for a RandomScene object | |
| n_samples: number of samples to draw in each scene | |
| """ | |
| params = copy.deepcopy(params) | |
| params.batch_size = n_samples | |
| scene_safe_slow = RandomScene( | |
| params, | |
| is_torch=False, | |
| ) | |
| percent_right = 0.8 | |
| percent_top = 1.1 | |
| angle = 5 * np.pi / 4 | |
| positions = np.array([[[percent_right, percent_top]]] * n_samples) | |
| angles = np.array([[angle]] * n_samples) | |
| scene_safe_slow.set_pedestrians_states(positions, angles) | |
| scene_safe_fast = RandomScene( | |
| params, | |
| is_torch=False, | |
| ) | |
| percent_right = 0.8 | |
| percent_top = 0.6 | |
| angle = 5 * np.pi / 4 | |
| positions = np.array([[[percent_right, percent_top]]] * n_samples) | |
| angles = np.array([[angle]] * n_samples) | |
| scene_safe_fast.set_pedestrians_states(positions, angles) | |
| return scene_safe_fast, scene_safe_slow | |
| class DrawCallbackParams: | |
| """ | |
| Args: | |
| scene_params: dataclass parameters for the RandomScene | |
| dist_cost_params: dataclass parameters for the DistanceCost | |
| ttc_cost_params: dataclass parameters for the TTCCost | |
| plot_interval_epoch: number of epochs between each plot drawing | |
| histogram_interval_epoch: number of epochs between each histogram drawing | |
| num_steps: number of time steps as defined in the config | |
| num_steps_future: number of time steps in the future as defined in the config | |
| risk_distribution: dict object describing a risk distribution | |
| dt: time step size as defined in the config | |
| """ | |
| scene_params: RandomSceneParams | |
| dist_cost_params: DistanceCostParams | |
| ttc_cost_params: TTCCostParams | |
| plot_interval_epoch: int | |
| histogram_interval_epoch: int | |
| num_steps: int | |
| num_steps_future: int | |
| risk_distribution: dict | |
| dt: float | |
| def from_config(cfg: Config): | |
| return DrawCallbackParams( | |
| scene_params=RandomSceneParams.from_config(cfg), | |
| dist_cost_params=DistanceCostParams.from_config(cfg), | |
| ttc_cost_params=TTCCostParams.from_config(cfg), | |
| plot_interval_epoch=cfg.plot_interval_epoch, | |
| histogram_interval_epoch=cfg.histogram_interval_epoch, | |
| num_steps=cfg.num_steps, | |
| num_steps_future=cfg.num_steps_future, | |
| risk_distribution=cfg.risk_distribution, | |
| dt=cfg.dt, | |
| ) | |
| class HistogramCallback(pl.Callback): | |
| """Logs histograms of distances, distance cost and ttc cost for the data, the predictions at risk_level=0, the predictions at risk_level=1 | |
| Args: | |
| params: dataclass defining the necessary parameters | |
| n_samples: Number of samples to use for the histogram plot | |
| """ | |
| def __init__( | |
| self, | |
| params: DrawCallbackParams, | |
| n_samples=1000, | |
| ): | |
| super().__init__() | |
| self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes( | |
| params.scene_params, n_samples | |
| ) | |
| self.num_steps = params.num_steps | |
| self.n_scenes = n_samples | |
| self.sample_times = params.scene_params.sample_times | |
| self.dist_cost_func = DistanceCostNumpy(params.dist_cost_params) | |
| self.ttc_cost_func = TTCCostNumpy(params.ttc_cost_params) | |
| self.histogram_interval_epoch = params.histogram_interval_epoch | |
| self.ego_traj = self.scene_safe_fast.get_ego_ref_trajectory(self.sample_times) | |
| self._risk_sampler = get_risk_level_sampler(params.risk_distribution) | |
| def _log_scene(self, pl_module: pl.LightningModule, scene: RandomScene, name: str): | |
| """ | |
| Log in WandB three histogram for the given scene: One for the data, one for the predictions at risk_level=0 and one for the predictions at risk_level=1 | |
| Args: | |
| pl_module: LightningModule object | |
| scene: RandomScene object | |
| name: name of the given scene | |
| """ | |
| ped_trajs = scene.get_pedestrians_trajectories() | |
| device = pl_module.device | |
| n_agents = ped_trajs.shape[1] | |
| input_traj = ped_trajs[..., : self.num_steps, :] | |
| normalized_input, offset = SceneDataLoaders.normalize_trajectory( | |
| torch.from_numpy(input_traj.astype("float32")).contiguous().to(device) | |
| ) | |
| mask_input = torch.ones_like(normalized_input[..., 0]) | |
| ego_history = ( | |
| torch.from_numpy(self.ego_traj[..., : self.num_steps, :].astype("float32")) | |
| .expand_as(normalized_input) | |
| .contiguous() | |
| .to(device) | |
| ) | |
| ego_future = ( | |
| torch.from_numpy(self.ego_traj[..., self.num_steps :, :].astype("float32")) | |
| .expand(normalized_input.shape[0], n_agents, -1, -1) | |
| .contiguous() | |
| .to(device) | |
| ) | |
| map = torch.empty(ego_history.shape[0], 0, 0, 2, device=mask_input.device) | |
| mask_map = torch.empty(ego_history.shape[0], 0, 0, device=mask_input.device) | |
| pred_riskier = ( | |
| pl_module.predict_step( | |
| ( | |
| normalized_input, | |
| mask_input, | |
| map, | |
| mask_map, | |
| offset, | |
| ego_history, | |
| ego_future, | |
| ), | |
| 0, | |
| risk_level=self._risk_sampler.get_highest_risk( | |
| batch_size=self.n_scenes, device=device | |
| ) | |
| .unsqueeze(1) | |
| .repeat(1, n_agents), | |
| ) | |
| .cpu() | |
| .detach() | |
| .numpy() | |
| ) | |
| pred = ( | |
| pl_module.predict_step( | |
| ( | |
| normalized_input, | |
| mask_input, | |
| map, | |
| mask_map, | |
| offset, | |
| ego_history, | |
| ego_future, | |
| ), | |
| 0, | |
| risk_level=None, | |
| ) | |
| .cpu() | |
| .detach() | |
| .numpy() | |
| ) | |
| ped_trajs_pred = np.concatenate((input_traj, pred), axis=-2) | |
| ped_trajs_pred_riskier = np.concatenate((input_traj, pred_riskier), axis=-2) | |
| travel_distances = np.sqrt( | |
| np.square(ped_trajs[..., -1, :] - ped_trajs[..., 0, :]).sum(-1) | |
| ) | |
| dist_cost, dist = self.dist_cost_func( | |
| self.ego_traj[..., self.num_steps :, :], | |
| ped_trajs[..., self.num_steps :, :], | |
| ) | |
| ttc_cost, (ttc, dist) = self.ttc_cost_func( | |
| self.ego_traj[..., self.num_steps :, :], | |
| ped_trajs[..., self.num_steps :, :], | |
| scene.get_ego_ref_velocity(), | |
| scene.get_pedestrians_velocities(), | |
| ) | |
| travel_distances_pred = np.sqrt( | |
| np.square(ped_trajs_pred[..., -1, :] - ped_trajs_pred[..., 0, :]).sum(-1) | |
| ) | |
| dist_cost_pred, dist_pred = self.dist_cost_func( | |
| self.ego_traj[..., self.num_steps :, :], | |
| ped_trajs_pred[..., self.num_steps :, :], | |
| ) | |
| sample_times = np.array(self.sample_times) | |
| ped_velocities_pred = ( | |
| ped_trajs_pred[..., 1:, :] - ped_trajs_pred[..., :-1, :] | |
| ) / ((sample_times[1:] - sample_times[:-1])[None, None, :, None]) | |
| ped_velocities_pred = np.concatenate( | |
| (ped_velocities_pred[..., 0:1, :], ped_velocities_pred), -2 | |
| ) | |
| ttc_cost_pred, (ttc_pred, dist_pred) = self.ttc_cost_func( | |
| self.ego_traj[..., self.num_steps :, :], | |
| ped_trajs_pred[..., self.num_steps :, :], | |
| scene.get_ego_ref_velocity(), | |
| ped_velocities_pred[..., self.num_steps :, :], | |
| ) | |
| travel_distances_pred_riskier = np.sqrt( | |
| np.square( | |
| ped_trajs_pred_riskier[..., -1, :] - ped_trajs_pred_riskier[..., 0, :] | |
| ).sum(-1) | |
| ) | |
| dist_cost_pred_riskier, dist_pred_riskier = self.dist_cost_func( | |
| self.ego_traj[..., self.num_steps :, :], | |
| ped_trajs_pred_riskier[..., self.num_steps :, :], | |
| ) | |
| sample_times = np.array(self.sample_times) | |
| ped_velocities_pred_riskier = ( | |
| ped_trajs_pred_riskier[..., 1:, :] - ped_trajs_pred_riskier[..., :-1, :] | |
| ) / ((sample_times[1:] - sample_times[:-1])[None, None, :, None]) | |
| ped_velocities_pred_riskier = np.concatenate( | |
| (ped_velocities_pred_riskier[..., 0:1, :], ped_velocities_pred_riskier), -2 | |
| ) | |
| ttc_cost_pred_riskier, (ttc_pred, dist_pred_riskier) = self.ttc_cost_func( | |
| self.ego_traj[..., self.num_steps :, :], | |
| ped_trajs_pred_riskier[..., self.num_steps :, :], | |
| scene.get_ego_ref_velocity(), | |
| ped_velocities_pred_riskier[..., self.num_steps :, :], | |
| ) | |
| data = [ | |
| [dist, dist_pred, dist_risk] | |
| for (dist, dist_pred, dist_risk) in zip( | |
| travel_distances.flatten(), | |
| travel_distances_pred.flatten(), | |
| travel_distances_pred_riskier.flatten(), | |
| ) | |
| ] | |
| table_travel_distance = wandb.Table( | |
| data=data, | |
| columns=[ | |
| "Travel distance data " + name, | |
| "Travel distance prediction " + name, | |
| "Travel distance riskier " + name, | |
| ], | |
| ) | |
| data = [ | |
| [cost, cost_pred, cost_risk] | |
| for (cost, cost_pred, cost_risk) in zip( | |
| dist_cost.flatten(), | |
| dist_cost_pred.flatten(), | |
| dist_cost_pred_riskier.flatten(), | |
| ) | |
| ] | |
| table_distance_cost = wandb.Table( | |
| data=data, | |
| columns=[ | |
| "Distance cost data " + name, | |
| "Distance cost prediction " + name, | |
| "Distance cost riskier " + name, | |
| ], | |
| ) | |
| data = [ | |
| [ttc, ttc_pred, ttc_risk] | |
| for (ttc, ttc_pred, ttc_risk) in zip( | |
| ttc_cost.flatten(), | |
| ttc_cost_pred.flatten(), | |
| ttc_cost_pred_riskier.flatten(), | |
| ) | |
| ] | |
| table_ttc_cost = wandb.Table( | |
| data=data, | |
| columns=[ | |
| "TTC cost data " + name, | |
| "TTC cost prediction " + name, | |
| "TTC cost riskier " + name, | |
| ], | |
| ) | |
| wandb.log( | |
| { | |
| "Travel distance data " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_01_bins", | |
| data_table=table_travel_distance, | |
| fields={ | |
| "value": "Travel distance data " + name, | |
| "title": "Travel distance data " + name, | |
| }, | |
| ), | |
| "Travel distance prediction " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_01_bins", | |
| data_table=table_travel_distance, | |
| fields={ | |
| "value": "Travel distance prediction " + name, | |
| "title": "Travel distance prediction " + name, | |
| }, | |
| ), | |
| "Travel distance riskier " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_01_bins", | |
| data_table=table_travel_distance, | |
| fields={ | |
| "value": "Travel distance riskier " + name, | |
| "title": "Travel distance riskier " + name, | |
| }, | |
| ), | |
| "Distance cost data " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_0025_bins", | |
| data_table=table_distance_cost, | |
| fields={ | |
| "value": "Distance cost data " + name, | |
| "title": "Distance cost data " + name, | |
| }, | |
| ), | |
| "Distance cost prediction " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_0025_bins", | |
| data_table=table_distance_cost, | |
| fields={ | |
| "value": "Distance cost prediction " + name, | |
| "title": "Distance cost prediction " + name, | |
| }, | |
| ), | |
| "Distance cost riskier " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_0025_bins", | |
| data_table=table_distance_cost, | |
| fields={ | |
| "value": "Distance cost riskier " + name, | |
| "title": "Distance cost riskier " + name, | |
| }, | |
| ), | |
| "TTC cost data " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_005_bins", | |
| data_table=table_ttc_cost, | |
| fields={ | |
| "value": "TTC cost data " + name, | |
| "title": "TTC cost data " + name, | |
| }, | |
| ), | |
| "TTC cost prediction " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_005_bins", | |
| data_table=table_ttc_cost, | |
| fields={ | |
| "value": "TTC cost prediction " + name, | |
| "title": "TTC cost prediction " + name, | |
| }, | |
| ), | |
| "TTC cost riskier " | |
| + name: wandb.plot_table( | |
| vega_spec_name="jmercat/histogram_005_bins", | |
| data_table=table_ttc_cost, | |
| fields={ | |
| "value": "TTC cost riskier " + name, | |
| "title": "TTC cost riskier " + name, | |
| }, | |
| ), | |
| } | |
| ) | |
| def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): | |
| """After a validation at the end of every histogram_interval_epoch, | |
| log the histograms for two scenes: the safer fast scene and the safer slow scene. | |
| """ | |
| if ( | |
| trainer.current_epoch % self.histogram_interval_epoch | |
| == self.histogram_interval_epoch - 1 | |
| ): | |
| self._log_scene(pl_module, self.scene_safe_fast, name="Safer fast") | |
| self._log_scene(pl_module, self.scene_safe_slow, name="Safer slow") | |
| class PlotTrajCallback(pl.Callback): | |
| """Plot trajectory samples for two scenes: | |
| One that is safer for the slow pedestrians | |
| One that is safer for the fast pedestrians | |
| Samples of ground truth, prediction, and biased predictions are superposed. | |
| Last positions are marked to visualize the clusters. | |
| Args: | |
| params: dataclass containing the necessary parameters for a | |
| n_samples: number of sample trajectories to draw | |
| """ | |
| def __init__( | |
| self, | |
| params: DrawCallbackParams, | |
| n_samples: int = 1, | |
| ): | |
| super().__init__() | |
| self.n_samples = n_samples | |
| self.num_steps = params.num_steps | |
| self.dt = params.scene_params.dt | |
| self.scene_params = params.scene_params | |
| self.plot_interval_epoch = params.plot_interval_epoch | |
| self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes( | |
| params.scene_params, n_samples | |
| ) | |
| self.ego_traj = self.scene_safe_fast.get_ego_ref_trajectory( | |
| params.scene_params.sample_times | |
| ) | |
| self._risk_sampler = get_risk_level_sampler(params.risk_distribution) | |
| def _log_scene(self, epoch: int, pl_module, scene: RandomScene, name: str) -> None: | |
| """Add drawing of samples of prediction, biased prediction and ground truth in the scene. | |
| Args: | |
| epoch: current epoch calling the log | |
| pl_module: pytorch lightning module being trained | |
| scene: scene to draw | |
| name: name of the scene | |
| """ | |
| ped_trajs = scene.get_pedestrians_trajectories() | |
| device = pl_module.device | |
| n_agents = ped_trajs.shape[1] | |
| input_traj = ped_trajs[..., : self.num_steps, :] | |
| normalized_input, offset = SceneDataLoaders.normalize_trajectory( | |
| torch.from_numpy(input_traj.astype("float32")).contiguous().to(device) | |
| ) | |
| mask_input = torch.ones_like(normalized_input[..., 0]) | |
| ego_history = ( | |
| torch.from_numpy(self.ego_traj[..., : self.num_steps, :].astype("float32")) | |
| .expand_as(normalized_input) | |
| .contiguous() | |
| .to(device) | |
| ) | |
| ego_future = ( | |
| torch.from_numpy(self.ego_traj[..., self.num_steps :, :].astype("float32")) | |
| .expand(normalized_input.shape[0], n_agents, -1, -1) | |
| .contiguous() | |
| .to(device) | |
| ) | |
| map = torch.empty(ego_history.shape[0], 0, 0, 2, device=mask_input.device) | |
| mask_map = torch.empty(ego_history.shape[0], 0, 0, device=mask_input.device) | |
| pred_riskier = ( | |
| pl_module.predict_step( | |
| ( | |
| normalized_input, | |
| mask_input, | |
| map, | |
| mask_map, | |
| offset, | |
| ego_history, | |
| ego_future, | |
| ), | |
| 0, | |
| risk_level=self._risk_sampler.get_highest_risk( | |
| batch_size=self.n_samples, device=device | |
| ) | |
| .unsqueeze(1) | |
| .repeat(1, n_agents), | |
| ) | |
| .cpu() | |
| .detach() | |
| .numpy() | |
| ) | |
| pred = ( | |
| pl_module.predict_step( | |
| ( | |
| normalized_input, | |
| mask_input, | |
| map, | |
| mask_map, | |
| offset, | |
| ego_history, | |
| ego_future, | |
| ), | |
| 0, | |
| risk_level=None, | |
| ) | |
| .cpu() | |
| .detach() | |
| .numpy() | |
| ) | |
| fig, ax = plt.subplots() | |
| plotter = ScenePlotter(scene, ax=ax) | |
| fig.set_size_inches(h=scene.road_width / 3 + 1, w=scene.road_length / 3) | |
| time = self.dt * self.num_steps | |
| plotter.draw_scene(0, time=time) | |
| alpha = 0.5 / np.log(self.n_samples) | |
| plotter.draw_all_trajectories( | |
| ped_trajs[..., self.num_steps :, :], | |
| color="g", | |
| alpha=alpha, | |
| label="Future ground truth", | |
| ) | |
| plotter.draw_all_trajectories( | |
| input_traj, color="b", alpha=alpha, label="Past input" | |
| ) | |
| plotter.draw_all_trajectories( | |
| pred, color="orange", alpha=alpha, label="Prediction" | |
| ) | |
| plotter.draw_all_trajectories( | |
| pred_riskier, color="r", alpha=alpha, label="Prediction risk-seeking" | |
| ) | |
| plotter.draw_legend() | |
| plt.tight_layout() | |
| wandb.log({"Road scene " + name: wandb.Image(fig), "epoch": epoch}) | |
| plt.close() | |
| def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): | |
| """After a validation at the end of every plot_interval_epoch, | |
| log the prediction samples for two scenes: the safer fast scene and the safer slow scene. | |
| """ | |
| if ( | |
| trainer.current_epoch % self.plot_interval_epoch | |
| == self.plot_interval_epoch - 1 | |
| ): | |
| self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes( | |
| self.scene_params, self.n_samples | |
| ) | |
| self._log_scene( | |
| trainer.current_epoch, pl_module, self.scene_safe_slow, "Safer slow" | |
| ) | |
| self._log_scene( | |
| trainer.current_epoch, pl_module, self.scene_safe_fast, "Safer fast" | |
| ) | |
| # TODO: make the same kind of logs for the Waymo dataset | |