import os
import io
import pickle
import sys
from functools import partial
from inspect import signature

import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import repeat
import fire
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
import torch

from risk_biased.utils.config_argparse import config_argparse
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams, get_cost
from risk_biased.utils.risk import get_risk_estimator
from risk_biased.utils.load_model import load_from_config


def to_device(batch, device):
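    """Move every item of the batch onto the given torch device and return them as a list."""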
    output = []
    for item in batch:
        output.append(item.to(device))
    return output


class CPU_Unpickler(pickle.Unpickler):
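    """Unpickler that remaps torch storages to the CPU.

    This lets metrics that were pickled on a GPU machine be loaded on a
    CPU-only machine by forcing `torch.load` to use `map_location="cpu"`.
    """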
    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        else:
            return super().find_class(module, name)


def distance(pred, truth):
| """ | |
| pred (Tensor): (..., time, xy) | |
| truth (Tensor): (..., time, xy) | |
| mask_loss (Tensor): (..., time) Defaults to None. | |
| """ | |
| return torch.sqrt(torch.sum(torch.square(pred[..., :2] - truth[..., :2]), -1)) | |


def compute_metrics(
    predictor,
    batch,
    cost,
    risk_levels,
    risk_estimator,
    dt,
    unnormalizer,
    n_samples_risk,
    n_samples_stats,
):
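    """Compute per-batch risk, cost, and distance metrics for later aggregation.

    Draws `n_samples_risk` unbiased prediction samples (in packs of
    `n_samples_stats` to limit memory use), then, for each value in
    `risk_levels`, estimates the risk from the unbiased costs and evaluates the
    risk-biased predictions. Slicing with index 1 along the agent dimension is
    assumed to select the agent of interest.

    Returns:
        dict with keys "risk", "cost", "distance", "latent_distribs", and
        "mask", each holding per-batch tensors moved to the CPU.
    """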
    # risk_unbiased
    # risk_biased
    # cost
    # FDE: unbiased, biased(risk_level=[0, 0.3, 0.5, 0.8, 1]) (for all samples so minFDE can be computed later)
    # ADE (for all samples so minADE can be computed later)
    x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
    mask_z = mask_x.any(-1)
    _, z_mean_inference, z_log_std_inference = predictor.model(
        x,
        mask_x,
        map,
        mask_map,
        offset=offset,
        x_ego=x_ego,
        y_ego=y_ego,
        risk_level=None,
    )
    latent_distribs = {
        "inference": {
            "mean": z_mean_inference[:, 1].detach().cpu(),
            "log_std": z_log_std_inference[:, 1].detach().cpu(),
        }
    }
    inference_distances = []
    cost_list = []
    # Split the samples into packs to avoid out-of-memory problems,
    # then compute and store the cost for each pack.
    for _ in range(n_samples_risk // n_samples_stats):
        z_samples_inference = predictor.model.inference_encoder.sample(
            z_mean_inference,
            z_log_std_inference,
            n_samples=n_samples_stats,
        )
        y_samples = predictor.model.decode(
            z_samples=z_samples_inference,
            mask_z=mask_z,
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            offset=offset,
        )
        mask_loss_samples = repeat(mask_loss, "b a t -> b a s t", s=n_samples_stats)
        # Computing unbiased cost
        cost_list.append(
            get_cost(
                cost,
                x,
                y_samples,
                offset,
                x_ego,
                y_ego,
                dt,
                unnormalizer,
                mask_loss_samples,
            )[:, 1:2]
        )
        inference_distances.append(distance(y_samples, y.unsqueeze(2))[:, 1:2])
    cost_dic = {}
    cost_dic["inference"] = torch.cat(cost_list, 2).detach().cpu()
    distance_dic = {}
    distance_dic["inference"] = torch.cat(inference_distances, 2).detach().cpu()
    # Set up the output risk dictionary
    risk_dic = {}
    # Loop over the risk_level values, filling in the risk estimate and computing stats at each level
    for rl in risk_levels:
        risk_level = (
            torch.ones(
                (x.shape[0], x.shape[1]),
                device=x.device,
            )
            * rl
        )
        risk_dic[f"biased_{rl}"] = risk_estimator(
            risk_level[:, 1:2].detach().cpu(), cost_dic["inference"]
        )
        y_samples_biased, z_mean_biased, z_log_std_biased = predictor.model(
            x,
            mask_x,
            map,
            mask_map,
            offset=offset,
            x_ego=x_ego,
            y_ego=y_ego,
            risk_level=risk_level,
            n_samples=n_samples_stats,
        )
        latent_distribs[f"biased_{rl}"] = {
            "mean": z_mean_biased[:, 1].detach().cpu(),
            "log_std": z_log_std_biased[:, 1].detach().cpu(),
        }
        distance_dic[f"biased_{rl}"] = (
            distance(y_samples_biased, y.unsqueeze(2))[:, 1].detach().cpu()
        )
        cost_dic[f"biased_{rl}"] = (
            get_cost(
                cost,
                x,
                y_samples_biased,
                offset,
                x_ego,
                y_ego,
                dt,
                unnormalizer,
                mask_loss_samples,
            )[:, 1]
            .detach()
            .cpu()
        )
    # Return risks for the batch and all risk values
    return {
        "risk": risk_dic,
        "cost": cost_dic,
        "distance": distance_dic,
        "latent_distribs": latent_distribs,
        "mask": mask_loss[:, 1].detach().cpu(),
    }


def cat_metrics_rec(metrics1, metrics2, cat_to):
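    """Recursively concatenate two metrics dictionaries into `cat_to`.

    Tensors stored under matching keys are concatenated along dimension 0 (the
    batch dimension); nested dictionaries are handled recursively. Raises a
    RuntimeError if a key of `metrics1` is missing from `metrics2`.
    """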
    for key in metrics1.keys():
        if key not in metrics2.keys():
            raise RuntimeError(
                f"Trying to concatenate objects with different keys: {key} is not in second argument keys."
            )
        elif isinstance(metrics1[key], dict):
            if key not in cat_to.keys():
                cat_to[key] = {}
            cat_metrics_rec(metrics1[key], metrics2[key], cat_to[key])
        elif isinstance(metrics1[key], torch.Tensor):
            cat_to[key] = torch.cat((metrics1[key], metrics2[key]), 0)


def cat_metrics(metrics1, metrics2):
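    """Return a new dictionary concatenating `metrics1` and `metrics2` batch-wise."""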
    out = {}
    cat_metrics_rec(metrics1, metrics2, out)
    return out


def masked_mean_std_ste(data, mask):
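    """Return the masked mean, standard deviation, and standard error of `data`.

    The mask is reshaped to match `data` and the normalizer is clamped to at
    least 1 to avoid division by zero when the mask is empty.
    """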
    mask = mask.view(data.shape)
    norm = mask.sum().clamp_min(1)
    mean = (data * mask).sum() / norm
    std = torch.sqrt(((data - mean) * mask).square().sum() / norm)
    return mean.item(), std.item(), (std / torch.sqrt(norm)).item()


def masked_mean_range(data, mask):
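    """Return the mean and the 5% / 95% quantiles of the masked entries of `data`."""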
    data = data[mask]
    mean = data.mean()
    min = torch.quantile(data, 0.05)
    max = torch.quantile(data, 0.95)
    return mean, min, max


def masked_mean_dim(data, mask, dim):
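    """Return the masked mean of `data` along dimension `dim`, guarding against empty masks."""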
    norm = mask.sum(dim).clamp_min(1)
    mean = (data * mask).sum(dim) / norm
    return mean


def plot_risk_error(metrics, risk_levels, risk_estimator, max_n_samples, path_save):
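    """Plot the risk-estimation error as a function of the number of samples.

    For each risk level, the Monte-Carlo risk estimate from unbiased
    ("Inference") samples and the sample-mean cost of the risk-biased
    ("Biased") samples are compared against the many-sample risk estimate
    stored in `metrics["risk"]`. One figure per risk level is saved as SVG and
    PNG in `path_save`.
    """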
    cost_inference = metrics["cost"]["inference"]
    cost_biased_0 = metrics["cost"]["biased_0"]
    mask = metrics["mask"].any(1)
    ones_tensor = torch.ones(mask.shape[0])
    n_samples = np.minimum(cost_biased_0.shape[1], max_n_samples)
    for rl in risk_levels:
        key = f"biased_{rl}"
        reference_risk = metrics["risk"][key]
        mean_inference_risk_error_per_samples = np.zeros(n_samples - 1)
        min_inference_risk_error_per_samples = np.zeros(n_samples - 1)
        max_inference_risk_error_per_samples = np.zeros(n_samples - 1)
        # mean_biased_0_risk_error_per_samples = np.zeros(n_samples - 1)
        # min_biased_0_risk_error_per_samples = np.zeros(n_samples - 1)
        # max_biased_0_risk_error_per_samples = np.zeros(n_samples - 1)
        mean_biased_risk_error_per_samples = np.zeros(n_samples - 1)
        min_biased_risk_error_per_samples = np.zeros(n_samples - 1)
        max_biased_risk_error_per_samples = np.zeros(n_samples - 1)
        risk_level_tensor = ones_tensor * rl
        for sub_samples in range(1, n_samples):
            perm = torch.randperm(metrics["cost"][key].shape[1])[:sub_samples]
            risk_error_biased = metrics["cost"][key][:, perm].mean(1) - reference_risk
            (
                mean_biased_risk_error_per_samples[sub_samples - 1],
                min_biased_risk_error_per_samples[sub_samples - 1],
                max_biased_risk_error_per_samples[sub_samples - 1],
            ) = masked_mean_range(risk_error_biased, mask)
            risk_error_inference = (
                risk_estimator(risk_level_tensor, cost_inference[:, :, :sub_samples])
                - reference_risk
            )
            (
                mean_inference_risk_error_per_samples[sub_samples - 1],
                min_inference_risk_error_per_samples[sub_samples - 1],
                max_inference_risk_error_per_samples[sub_samples - 1],
            ) = masked_mean_range(risk_error_inference, mask)
            # risk_error_biased_0 = risk_estimator(risk_level_tensor, cost_biased_0[:, :sub_samples]) - reference_risk
            # (
            #     mean_biased_0_risk_error_per_samples[sub_samples - 1],
            #     min_biased_0_risk_error_per_samples[sub_samples - 1],
            #     max_biased_0_risk_error_per_samples[sub_samples - 1],
            # ) = masked_mean_range(risk_error_biased_0, mask)
        plt.plot(
            range(1, n_samples),
            mean_inference_risk_error_per_samples,
            label="Inference",
        )
        plt.fill_between(
            range(1, n_samples),
            min_inference_risk_error_per_samples,
            max_inference_risk_error_per_samples,
            alpha=0.3,
        )
        # plt.plot(range(1, n_samples), mean_biased_0_risk_error_per_samples, label="Unbiased")
        # plt.fill_between(range(1, n_samples), min_biased_0_risk_error_per_samples, max_biased_0_risk_error_per_samples, alpha=.3)
        plt.plot(
            range(1, n_samples), mean_biased_risk_error_per_samples, label="Biased"
        )
        plt.fill_between(
            range(1, n_samples),
            min_biased_risk_error_per_samples,
            max_biased_risk_error_per_samples,
            alpha=0.3,
        )
        plt.ylim(
            np.min(min_inference_risk_error_per_samples),
            np.max(max_biased_risk_error_per_samples),
        )
        plt.hlines(y=0, xmin=0, xmax=n_samples, colors="black", linestyles="--", lw=0.3)
        plt.xlabel("Number of samples")
        plt.ylabel("Risk estimation error")
        plt.legend()
        plt.title(f"Risk estimation error at risk-level={rl}")
        plt.gcf().set_size_inches(4, 3)
        plt.legend(loc="lower right")
        plt.savefig(fname=os.path.join(path_save, f"risk_level_{rl}.svg"))
        plt.savefig(fname=os.path.join(path_save, f"risk_level_{rl}.png"))
        plt.clf()
        # plt.show()


def compute_stats(metrics, n_samples_mean_cost=4):
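    """Aggregate the concatenated metrics into summary statistics.

    For each risk level this computes the risk estimation error of the biased
    predictions (mean cost over the first `n_samples_mean_cost` samples minus
    the reference risk), risk and cost statistics, and FDE/ADE together with
    minFDE/minADE over 6, 16, and 32 samples.
    """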
    biased_risk_estimate = {}
    for key in metrics["cost"].keys():
        if key == "inference":
            continue
        risk = metrics["risk"][key]
        mean_cost = metrics["cost"][key][:, :n_samples_mean_cost].mean(1)
        risk_error = mean_cost - risk
        biased_risk_estimate[key] = {}
        (
            biased_risk_estimate[key]["mean"],
            biased_risk_estimate[key]["std"],
            biased_risk_estimate[key]["ste"],
        ) = masked_mean_std_ste(risk_error, metrics["mask"].any(1))
        (
            biased_risk_estimate[key]["mean_abs"],
            biased_risk_estimate[key]["std_abs"],
            biased_risk_estimate[key]["ste_abs"],
        ) = masked_mean_std_ste(risk_error.abs(), metrics["mask"].any(1))
    risk_stats = {}
    for key in metrics["risk"].keys():
        risk_stats[key] = {}
        (
            risk_stats[key]["mean"],
            risk_stats[key]["std"],
            risk_stats[key]["ste"],
        ) = masked_mean_std_ste(metrics["risk"][key], metrics["mask"].any(1))
    cost_stats = {}
    for key in metrics["cost"].keys():
        cost_stats[key] = {}
        (
            cost_stats[key]["mean"],
            cost_stats[key]["std"],
            cost_stats[key]["ste"],
        ) = masked_mean_std_ste(
            metrics["cost"][key], metrics["mask"].any(-1, keepdim=True)
        )
    distance_stats = {}
    for key in metrics["distance"].keys():
        distance_stats[key] = {"FDE": {}, "ADE": {}, "minFDE": {}, "minADE": {}}
        (
            distance_stats[key]["FDE"]["mean"],
            distance_stats[key]["FDE"]["std"],
            distance_stats[key]["FDE"]["ste"],
        ) = masked_mean_std_ste(
            metrics["distance"][key][..., -1], metrics["mask"][:, None, -1]
        )
        mean_dist = masked_mean_dim(
            metrics["distance"][key], metrics["mask"][:, None, :], -1
        )
        (
            distance_stats[key]["ADE"]["mean"],
            distance_stats[key]["ADE"]["std"],
            distance_stats[key]["ADE"]["ste"],
        ) = masked_mean_std_ste(mean_dist, metrics["mask"].any(-1, keepdim=True))
        for i in [6, 16, 32]:
            distance_stats[key]["minFDE"][i] = {}
            min_dist, _ = metrics["distance"][key][:, :i, -1].min(1)
            (
                distance_stats[key]["minFDE"][i]["mean"],
                distance_stats[key]["minFDE"][i]["std"],
                distance_stats[key]["minFDE"][i]["ste"],
            ) = masked_mean_std_ste(min_dist, metrics["mask"][:, -1])
            distance_stats[key]["minADE"][i] = {}
            mean_dist, _ = masked_mean_dim(
                metrics["distance"][key][:, :i], metrics["mask"][:, None, :], -1
            ).min(1)
            (
                distance_stats[key]["minADE"][i]["mean"],
                distance_stats[key]["minADE"][i]["std"],
                distance_stats[key]["minADE"][i]["ste"],
            ) = masked_mean_std_ste(mean_dist, metrics["mask"].any(-1))
    return {
        "risk": risk_stats,
        "biased_risk_estimate": biased_risk_estimate,
        "cost": cost_stats,
        "distance": distance_stats,
    }


def print_stats(stats, n_samples_mean_cost=4):
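    """Print the statistics as a LaTeX tabular, one row per model / risk level."""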
    slash = "\\"
    brace_open = "{"
    brace_close = "}"
    print("\\begin{tabular}{lccccc}")
    print("\\hline")
    print(
        f"Predictive Model & ${slash}sigma$ & minFDE(16) & FDE (1) & Risk est. error ({n_samples_mean_cost}) & Risk est. $|$error$|$ ({n_samples_mean_cost}) {slash}{slash}"
    )
    print("\\hline")
    for key in stats["distance"].keys():
        strg = (
            f" ${stats['distance'][key]['minFDE'][16]['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['distance'][key]['minFDE'][16]['ste']:.2f}${brace_close}"
            + f"& ${stats['distance'][key]['FDE']['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['distance'][key]['FDE']['ste']:.2f}${brace_close}"
        )
        if key == "inference":
            strg = (
                "Unbiased CVAE & "
                + f"{slash}scriptsize{brace_open}NA{brace_close} &"
                + strg
                + f"& {slash}scriptsize{brace_open}NA{brace_close} & {slash}scriptsize{brace_open}NA{brace_close} {slash}{slash}"
            )
            print(strg)
            print("\\hline")
        else:
            strg = (
                "Biased CVAE & "
                + f"{key[7:]} & "
                + strg
                + f"& ${stats['biased_risk_estimate'][key]['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['biased_risk_estimate'][key]['ste']:.2f}${brace_close}"
                + f"& ${stats['biased_risk_estimate'][key]['mean_abs']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['biased_risk_estimate'][key]['ste_abs']:.2f}${brace_close}"
                + f"{slash}{slash}"
            )
            print(strg)
            print("\\hline")
    print("\\end{tabular}")


def main(
    log_path,
    force_recompute,
    n_samples_risk=256,
    n_samples_stats=32,
    n_samples_plot=16,
    args_to_parser=[],
):
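    """Compute (or load cached) metrics on the validation set and report them.

    The per-batch metrics are cached as `metrics_<load_from>.pickle` under
    `log_path`; pass `force_recompute=True` to ignore an existing cache. The
    aggregated statistics are printed as a LaTeX table and the risk-error
    plots are saved under `plots_<load_from>`.
    """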
    # Overwrite sys.argv so it doesn't mess up the parser.
    sys.argv = sys.argv[0:1] + args_to_parser
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "risk_biased", "config", "learning_config.py"
    )
    waymo_config_path = os.path.join(
        working_dir, "..", "..", "risk_biased", "config", "waymo_config.py"
    )
    cfg = config_argparse([config_path, waymo_config_path])
    file_path = os.path.join(log_path, f"metrics_{cfg.load_from}.pickle")
    fig_path = os.path.join(log_path, f"plots_{cfg.load_from}")
    if not os.path.exists(fig_path):
        os.makedirs(fig_path)
    risk_levels = [0, 0.3, 0.5, 0.8, 0.95, 1]
    cost = TTCCostTorch(TTCCostParams.from_config(cfg))
    risk_estimator = get_risk_estimator(cfg.risk_estimator)
    n_samples_mean_cost = 4
    if not os.path.exists(file_path) or force_recompute:
        with torch.no_grad():
            if cfg.seed is not None:
                seed_everything(cfg.seed)
            predictor, dataloaders, cfg = load_from_config(cfg)
            device = torch.device(cfg.gpus[0])
            predictor = predictor.to(device)
            val_loader = dataloaders.val_dataloader(shuffle=False, drop_last=False)
            # This loops over batches in the validation dataset
            beg = 0
            metrics_all = None
            for val_batch in tqdm(val_loader):
                end = beg + val_batch[0].shape[0]
                metrics = compute_metrics(
                    predictor=predictor,
                    batch=to_device(val_batch, device),
                    cost=cost,
                    risk_levels=risk_levels,
                    risk_estimator=risk_estimator,
                    dt=cfg.dt,
                    unnormalizer=dataloaders.unnormalize_trajectory,
                    n_samples_risk=n_samples_risk,
                    n_samples_stats=n_samples_stats,
                )
                if metrics_all is None:
                    metrics_all = metrics
                else:
                    metrics_all = cat_metrics(metrics_all, metrics)
                beg = end
        with open(file_path, "wb") as handle:
            pickle.dump(metrics_all, handle)
    else:
        print(f"Loading pre-computed metrics from {file_path}")
        with open(file_path, "rb") as handle:
            metrics_all = CPU_Unpickler(handle).load()
    stats = compute_stats(metrics_all, n_samples_mean_cost=n_samples_mean_cost)
    print_stats(stats, n_samples_mean_cost=n_samples_mean_cost)
    plot_risk_error(metrics_all, risk_levels, risk_estimator, n_samples_plot, fig_path)


if __name__ == "__main__":
    # main("./logs/002/", False, 256, 32, 16)
    # Fire turns the main function into a script, then the risk_biased module's argparse reads the remaining arguments.
    # Usage:
    #   python compute_stats.py <path to existing log dir> <force recompute> <n_samples_risk> <n_samples_stats> <n_samples_plot> <other argparse arguments, e.g. --load_from 1uail32>
    # This is a hack to separate the Fire script arguments from the argparse arguments.
    args_to_parser = sys.argv[len(signature(main).parameters) :]
    partial_main = partial(main, args_to_parser=args_to_parser)
    sys.argv = sys.argv[: len(signature(main).parameters)]
    # Run main as a script
    fire.Fire(partial_main)