Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| from abc import ABC, abstractmethod | |
| from typing import Optional, Tuple | |
| from numpy import isin | |
| import torch | |
| from risk_biased.mpc_planner.planner_cost import TrackingCost | |
| from risk_biased.utils.cost import BaseCostTorch | |
| from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator | |
def get_rotation_matrix(angle, device):
    """Build the 2x2 counter-clockwise rotation matrix for ``angle``.

    Args:
        angle: rotation angle in radians, either a Python number or a tensor of any
            shape (a batch of angles).
        device: device on which the returned matrix is placed.

    Returns:
        Tensor of shape ``(*angle.shape, 2, 2)`` equal to ``[[cos, -sin], [sin, cos]]``,
        so that ``rot_matrix @ v.unsqueeze(-1)`` rotates column vectors ``v``.
    """
    # torch.cos / torch.sin reject plain Python numbers, although the state classes
    # annotate ``angle`` as ``float``; converting here makes both numbers and tensors
    # work and puts the result on ``device`` directly (a no-op for a matching tensor).
    angle = torch.as_tensor(angle, device=device)
    c = torch.cos(angle)
    s = torch.sin(angle)
    # Stacking the columns [c, s] and [-s, c] along the last axis yields
    # [[c, -s], [s, c]].
    return torch.stack(
        (torch.stack((c, s), dim=-1), torch.stack((-s, c), dim=-1)), dim=-1
    )
class AbstractState(ABC):
    """
    State representation using an underlying tensor. Position, velocity, and angle can be accessed.

    Concrete subclasses keep their data in ``self._states`` (features on the last
    dimension) and their time step in ``self.dt``; ``__getitem__`` and ``shape`` below
    rely on both attributes.
    """

    @property
    @abstractmethod
    def position(self) -> torch.Tensor:
        """Extract position information from the state tensor

        Returns:
            position_tensor of size (..., 2)
        """

    @property
    @abstractmethod
    def velocity(self) -> torch.Tensor:
        """Extract velocity information from the state tensor

        Returns:
            velocity_tensor of size (..., 2)
        """

    @property
    @abstractmethod
    def angle(self) -> torch.Tensor:
        """Extract angle information from the state tensor

        Returns:
            angle_tensor of size (..., 1)
        """

    @abstractmethod
    def get_states(self, dim: int) -> torch.Tensor:
        """Return the underlying states tensor with dim 2, 4 or 5 ([x, y], [x, y, vx, vy], or [x, y, angle, vx, vy])."""

    @abstractmethod
    def rotate(self, angle: float, in_place: bool) -> AbstractState:
        """Rotate the state by the given angle

        Args:
            angle: in radians
            in_place: whether to change the object itself or return a rotated copy
        Returns:
            rotated self or rotated copy of self
        """

    @abstractmethod
    def translate(self, translation: torch.Tensor, in_place: bool) -> AbstractState:
        """Translate the state by the given translation

        Args:
            translation: translation vector in 2 dimensions
            in_place: whether to change the object itself or return a translated copy
        Returns:
            translated self or translated copy of self
        """

    # Define overloading operators to behave as a tensor for some operations
    def __getitem__(self, key) -> AbstractState:
        """
        Use get item on the underlying tensor to get the item at the given key.

        The key indexes the leading (batch/time) dimensions only; the feature dimension
        is always kept whole. Always returns a position/velocity state so that if the
        underlying time sequence is reduced to one step, the velocity is still accessible.
        """
        # Normalise int / slice / tensor keys to a tuple so the feature dimension can be
        # appended uniformly (``Ellipsis in key`` raises TypeError when key is a slice).
        if not isinstance(key, tuple):
            key = (key,)
        # Identity check: ``in`` would compare tensor indices to Ellipsis with ``==``.
        if not any(k is Ellipsis for k in key):
            key = (*key, Ellipsis)
        # Keep the whole trailing feature dimension.
        key = (*key, slice(None, None, None))
        return to_state(
            torch.cat((self.position[key], self.velocity[key]), dim=-1),
            self.dt,
        )

    @property
    def shape(self) -> torch.Size:
        """Shape of the state batch: the underlying tensor shape without its feature dimension."""
        return self._states.shape[:-1]
def to_state(in_tensor: torch.Tensor, dt: float) -> AbstractState:
    """Wrap ``in_tensor`` in the state class that matches its last dimension.

    A trailing size of 2 builds a position sequence and 4 builds a position/velocity
    state. Any larger size goes to the position/angle/velocity representation, whose
    constructor accepts exactly 5 features.
    """
    classes_by_feature_dim = {
        2: PositionSequenceState,
        4: PositionVelocityState,
    }
    feature_dim = in_tensor.shape[-1]
    state_cls = classes_by_feature_dim.get(feature_dim)
    if state_cls is None:
        # Only the richer [x, y, angle, vx, vy] layout is left.
        assert feature_dim > 4
        state_cls = PositionAngleVelocityState
    return state_cls(in_tensor, dt)
class PositionSequenceState(AbstractState):
    """
    State representation with an underlying tensor defining only positions.

    The tensor is a sequence of [x, y] positions of shape (..., num_steps, 2). Velocity
    and angle are not stored: they are derived by finite differences along the time
    dimension (-2), which is why at least two time steps are required.
    """

    def __init__(self, states: torch.Tensor, dt: float) -> None:
        super().__init__()
        # The input tensor must define only the position.
        assert states.shape[-1] == 2
        # It must be a sequence of positions, otherwise velocity cannot be computed.
        assert states.ndim > 1 and states.shape[-2] > 1
        self.dt = dt
        self._states = states.clone()

    @property
    def position(self) -> torch.Tensor:
        """Positions of size (..., num_steps, 2); the internal tensor itself, not a copy."""
        return self._states

    @property
    def velocity(self) -> torch.Tensor:
        """Finite-difference velocities of size (..., num_steps, 2)."""
        vel = (self._states[..., 1:, :] - self._states[..., :-1, :]) / self.dt
        # Only num_steps - 1 differences exist: repeat the first one so the velocity
        # sequence is as long as the position sequence (torch.cat returns a new tensor).
        return torch.cat((vel[..., 0:1, :], vel), dim=-2)

    @property
    def angle(self) -> torch.Tensor:
        """Heading of size (..., num_steps, 1), taken from the velocity direction."""
        vel = self.velocity
        return torch.arctan2(vel[..., 1:2], vel[..., 0:1])

    def get_states(self, dim: int = 2) -> torch.Tensor:
        """Return [x, y] (dim=2), [x, y, vx, vy] (dim=4) or [x, y, angle, vx, vy] (dim=5).

        Raises:
            RuntimeError: if ``dim`` is not 2, 4 or 5.
        """
        if dim == 2:
            return self._states.clone()
        elif dim == 4:
            return torch.cat((self._states.clone(), self.velocity), dim=-1)
        elif dim == 5:
            return torch.cat((self._states.clone(), self.angle, self.velocity), dim=-1)
        else:
            raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(self, angle: float, in_place: bool = False) -> PositionSequenceState:
        """Rotate the positions by the given angle in radians.

        Args:
            angle: rotation angle in radians, forwarded to ``get_rotation_matrix``
            in_place: whether to change this object or return a rotated copy
        Returns:
            self when ``in_place`` is True, otherwise a new rotated state
        """
        rot_matrix = get_rotation_matrix(angle, self._states.device)
        # The matmul allocates a new tensor, so no extra clone is needed.
        rotated = (rot_matrix @ self._states.unsqueeze(-1)).squeeze(-1)
        if in_place:
            self._states = rotated
            return self
        return to_state(rotated, self.dt)

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionSequenceState:
        """Translate the positions by the given translation.

        Args:
            translation: 2-D translation, broadcast (``expand_as``) to every position
            in_place: whether to change this object or return a translated copy
        Returns:
            self when ``in_place`` is True, otherwise a new translated state
        """
        if in_place:
            self._states[..., :2] += translation.expand_as(self._states[..., :2])
            return self
        return to_state(
            self._states[..., :2].clone()
            + translation.expand_as(self._states[..., :2]),
            self.dt,
        )
class PositionVelocityState(AbstractState):
    """
    State representation with an underlying tensor defining position and velocity.

    The last dimension of the tensor is laid out as [x, y, vx, vy].
    """

    def __init__(self, states: torch.Tensor, dt: float) -> None:
        super().__init__()
        assert states.shape[-1] == 4
        self._states = states.clone()
        self.dt = dt

    @property
    def position(self) -> torch.Tensor:
        """Position [x, y] of size (..., 2); a view on the internal tensor (not a copy)."""
        return self._states[..., :2]

    @property
    def velocity(self) -> torch.Tensor:
        """Velocity [vx, vy] of size (..., 2); a view on the internal tensor (not a copy)."""
        return self._states[..., 2:4]

    @property
    def angle(self) -> torch.Tensor:
        """Heading of size (..., 1), computed from the velocity direction."""
        vel = self.velocity
        return torch.arctan2(vel[..., 1:2], vel[..., 0:1])

    def get_states(self, dim: int = 4) -> torch.Tensor:
        """Return [x, y] (dim=2), [x, y, vx, vy] (dim=4) or [x, y, angle, vx, vy] (dim=5).

        Raises:
            RuntimeError: if ``dim`` is not 2, 4 or 5.
        """
        if dim == 2:
            return self._states[..., :2].clone()
        elif dim == 4:
            return self._states.clone()
        elif dim == 5:
            return torch.cat(
                (
                    self._states[..., :2].clone(),
                    self.angle,
                    self._states[..., 2:].clone(),
                ),
                dim=-1,
            )
        else:
            raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(
        self, angle: torch.Tensor, in_place: bool = False
    ) -> PositionVelocityState:
        """Rotate both position and velocity by the given angle in radians.

        Args:
            angle: rotation angle in radians, forwarded to ``get_rotation_matrix``
            in_place: whether to change this object or return a rotated copy
        Returns:
            self when ``in_place`` is True, otherwise a new rotated state
        """
        rot_matrix = get_rotation_matrix(angle, self._states.device)
        rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1)
        rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1)
        rotated_states = torch.cat((rotated_pos, rotated_vel), dim=-1)
        if in_place:
            self._states = rotated_states
            return self
        return to_state(rotated_states, self.dt)

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionVelocityState:
        """Translate the position by the given translation (velocity is unchanged).

        Args:
            translation: 2-D translation, broadcast (``expand_as``) to every position
            in_place: whether to change this object or return a translated copy
        Returns:
            self when ``in_place`` is True, otherwise a new translated state
        """
        if in_place:
            self._states[..., :2] += translation.expand_as(self._states[..., :2])
            return self
        return to_state(
            torch.cat(
                (
                    self._states[..., :2].clone()
                    + translation.expand_as(self._states[..., :2]),
                    self._states[..., 2:].clone(),
                ),
                dim=-1,
            ),
            self.dt,
        )
class PositionAngleVelocityState(AbstractState):
    """
    State representation with an underlying tensor representing position, angle and velocity.

    The last dimension of the tensor is laid out as [x, y, angle, vx, vy].
    """

    def __init__(self, states: torch.Tensor, dt: float) -> None:
        super().__init__()
        assert states.shape[-1] == 5
        self._states = states.clone()
        self.dt = dt

    @property
    def position(self) -> torch.Tensor:
        """Copy of the position [x, y], of size (..., 2)."""
        return self._states[..., :2].clone()

    @property
    def velocity(self) -> torch.Tensor:
        """Copy of the velocity [vx, vy], of size (..., 2)."""
        return self._states[..., 3:5].clone()

    @property
    def angle(self) -> torch.Tensor:
        """Copy of the stored angle, of size (..., 1)."""
        return self._states[..., 2:3].clone()

    def get_states(self, dim: int = 5) -> torch.Tensor:
        """Return [x, y] (dim=2), [x, y, vx, vy] (dim=4) or [x, y, angle, vx, vy] (dim=5).

        Raises:
            RuntimeError: if ``dim`` is not 2, 4 or 5.
        """
        if dim == 2:
            return self._states[..., :2].clone()
        elif dim == 4:
            return torch.cat(
                (self._states[..., :2].clone(), self._states[..., 3:].clone()), dim=-1
            )
        elif dim == 5:
            return self._states.clone()
        else:
            raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(
        self, angle: float, in_place: bool = False
    ) -> PositionAngleVelocityState:
        """Rotate position and velocity by ``angle`` (radians) and add it to the stored angle.

        The resulting angle is not wrapped to any fixed interval.

        Args:
            angle: rotation angle in radians, forwarded to ``get_rotation_matrix``
            in_place: whether to change this object or return a rotated copy
        Returns:
            self when ``in_place`` is True, otherwise a new rotated state
        """
        rot_matrix = get_rotation_matrix(angle, self._states.device)
        rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1)
        rotated_angle = self.angle + angle
        rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1)
        # torch.cat takes a *sequence* of tensors plus ``dim``; passing the tensors as
        # separate positional arguments raises a TypeError.
        rotated_states = torch.cat((rotated_pos, rotated_angle, rotated_vel), dim=-1)
        if in_place:
            self._states = rotated_states
            return self
        return to_state(rotated_states, self.dt)

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionAngleVelocityState:
        """Translate the position by the given translation (angle and velocity unchanged).

        Args:
            translation: 2-D translation, broadcast (``expand_as``) to every position
            in_place: whether to change this object or return a translated copy
        Returns:
            self when ``in_place`` is True, otherwise a new translated state
        """
        if in_place:
            self._states[..., :2] += translation.expand_as(self._states[..., :2])
            return self
        return to_state(
            torch.cat(
                (
                    self._states[..., :2]
                    + translation.expand_as(self._states[..., :2]),
                    self._states[..., 2:],
                ),
                dim=-1,
            ),
            self.dt,
        )
def get_interaction_cost(
    ego_state_future: AbstractState,
    ado_state_future_samples: AbstractState,
    interaction_cost_function: BaseCostTorch,
) -> torch.Tensor:
    """Computes interaction cost samples from predicted ado future trajectories and a batch of ego
    future trajectories

    Args:
        ego_state_future: ((num_control_samples), num_agents, num_steps_future) ego state
            future trajectory; when the control-sample dimension is absent (2-D state
            shape) a single control sample is assumed
        ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
            predicted ado state trajectory samples
        interaction_cost_function: interaction cost function between ego and (stochastic)
            ado, called with keyword arguments x1, x2, v1, v2 and returning a tuple whose
            first element is the cost (presumably of shape (num_control_samples,
            num_prediction_samples, num_agents) given the final permute)

    Returns:
        (num_control_samples, num_agents, num_prediction_samples) interaction cost tensor
    """
    x_ego = ego_state_future.position
    v_ego = ego_state_future.velocity
    if len(ego_state_future.shape) == 2:
        # Unbatched ego trajectory: add a singleton control-sample dimension.
        x_ego = x_ego.unsqueeze(0)
        v_ego = v_ego.unsqueeze(0)
    # Read the count after the optional unsqueeze: for an unbatched ego trajectory,
    # ego_state_future.shape[0] is num_agents, not the single control sample.
    num_control_samples = x_ego.shape[0]

    ado_position = ado_state_future_samples.position
    ado_velocity = ado_state_future_samples.velocity
    # Share the same prediction samples across every control sample (expand, no copy).
    ado_position_future_samples = ado_position.unsqueeze(0).expand(
        num_control_samples, *ado_position.shape
    )
    v_samples = ado_velocity.unsqueeze(0).expand(
        num_control_samples, *ado_velocity.shape
    )
    interaction_cost, _ = interaction_cost_function(
        x1=x_ego.unsqueeze(1),
        x2=ado_position_future_samples,
        v1=v_ego.unsqueeze(1),
        v2=v_samples,
    )
    return interaction_cost.permute(0, 2, 1)
def evaluate_risk(
    risk_level: float,
    cost: torch.Tensor,
    weights: torch.Tensor,
    risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
) -> torch.Tensor:
    """Returns a risk tensor given costs and a risk level

    Args:
        risk_level: a risk-level float. If 0.0, the risk-neutral expectation (mean of the
            cost over prediction samples) is returned.
        cost: (num_control_samples, num_agents, num_prediction_samples) cost tensor
        weights: (num_control_samples, num_agents, num_prediction_samples) probability weight of the cost tensor
        risk_estimator (optional): a Monte Carlo risk estimator, required when
            ``risk_level`` is non-zero. Defaults to None.

    Returns:
        (num_control_samples, num_agents) risk tensor

    Raises:
        AssertionError: if ``risk_level`` is non-zero and no estimator is given.
    """
    num_control_samples, num_agents, _ = cost.shape
    if risk_level == 0.0:
        # NOTE(review): this is an unweighted mean; ``weights`` is only used by the
        # risk estimator. Confirm that the samples are meant to be equally likely here.
        return cost.mean(dim=-1)
    assert risk_estimator is not None, "no risk estimator is specified"
    # Create the per-(control sample, agent) risk levels on the cost's device and dtype
    # so GPU or float64 costs are not mixed with a CPU float32 tensor.
    risk_levels = risk_level * torch.ones(
        num_control_samples, num_agents, dtype=cost.dtype, device=cost.device
    )
    return risk_estimator(risk_levels, cost, weights=weights)
def evaluate_control_sequence(
    control_sequence: torch.Tensor,
    dynamics_model,
    ego_state_history: AbstractState,
    ego_state_target_trajectory: AbstractState,
    ado_state_future_samples: AbstractState,
    sample_weights: torch.Tensor,
    interaction_cost_function: BaseCostTorch,
    tracking_cost_function: TrackingCost,
    risk_level: float = 0.0,
    risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
) -> Tuple[float, float]:
    """Score a control sequence by its interaction risk and its tracking cost.

    The ego future is produced by ``dynamics_model.simulate``, starting from the last
    time step of ``ego_state_history`` and applying ``control_sequence``.

    Args:
        control_sequence: (num_steps_future, control_dim) tensor of control sequence
        dynamics_model: dynamics model for control, providing ``simulate(state, controls)``
        ego_state_history: past ego states; only the last time step is used as the
            simulation start
        ego_state_target_trajectory: (num_steps_future) ego target state trajectory
        ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
            predicted ado trajectory sample states
        sample_weights: (num_prediction_samples, num_agents) tensor of probability weights of the samples
        interaction_cost_function: interaction cost function between ego and (stochastic) ado
        tracking_cost_function: deterministic tracking cost that does not involve ado
        risk_level (optional): a risk-level float. If 0.0, the risk-neutral expectation
            is used. Defaults to 0.0.
        risk_estimator (optional): a Monte Carlo risk estimator. Defaults to None.

    Returns:
        tuple of (interaction risk, tracking cost), each averaged to a Python float
    """
    start_state = ego_state_history[..., -1]
    simulated_future = dynamics_model.simulate(start_state, control_sequence)

    # The simulated state starts with x, y, angle, vx, vy; tracking only needs positions
    # of the simulation and positions/velocities of the target.
    tracking_cost = tracking_cost_function(
        simulated_future.position,
        ego_state_target_trajectory.position,
        ego_state_target_trajectory.velocity,
    )

    interaction_cost = get_interaction_cost(
        simulated_future, ado_state_future_samples, interaction_cost_function
    )
    # Reorder weights from (num_prediction_samples, num_agents) to
    # (1, num_agents, num_prediction_samples), then broadcast over control samples.
    per_sample_weights = sample_weights.permute(1, 0).unsqueeze(0)
    interaction_risk = evaluate_risk(
        risk_level,
        interaction_cost,
        per_sample_weights.expand_as(interaction_cost),
        risk_estimator,
    )

    # TODO: averaging over agents but we might want to reduce a different way
    mean_interaction_risk = interaction_risk.mean().item()
    mean_tracking_cost = tracking_cost.mean().item()
    return (mean_interaction_risk, mean_tracking_cost)