| """Copyright (c) Meta Platforms, Inc. and affiliates.""" | |
| import math | |
| import torch | |
| from geoopt.manifolds import Sphere as geoopt_Sphere | |
class Sphere(geoopt_Sphere):
    """geoopt ``Sphere`` extended with a closed-form transport and a base density.

    Adds a closed-form ``transp`` override plus helpers for the uniform
    distribution, which is exposed as the "base" distribution through
    ``random_base`` / ``base_logprob``.
    """

    def transp(self, x, y, v):
        """Transport the vector ``v`` from ``x`` to ``y`` in closed form.

        Computes ``v - inner(y, v) / (1 + inner(x, y)) * (x + y)``. When
        ``1 + inner(x, y)`` is not greater than ``1e-3`` (``x`` and ``y``
        close to antipodal) the closed form is ill-conditioned and ``-v`` is
        returned instead.

        Parameters
        ----------
        x : torch.Tensor
            Point the vector is attached to.
        y : torch.Tensor
            Point the vector is transported to.
        v : torch.Tensor
            Vector to transport.
            NOTE(review): presumably tangent at ``x`` -- not checked here.

        Returns
        -------
        torch.Tensor
            The transported vector, broadcast over leading dimensions.
        """
        # ``self.inner`` is inherited; its first argument is the base point.
        denom = 1 + self.inner(x, x, y, keepdim=True)
        cond = denom.gt(1e-3)
        # Mask the divisor *before* dividing. ``torch.where`` hides an inf/NaN
        # in the unselected branch during the forward pass, but autograd still
        # multiplies the zero upstream gradient by that branch's infinite local
        # gradient, yielding NaN gradients for exactly antipodal points.
        safe_denom = torch.where(cond, denom, torch.ones_like(denom))
        res = v - self.inner(x, y, v, keepdim=True) / safe_denom * (x + y)
        return torch.where(cond, res, -v)

    def uniform_logprob(self, x):
        """Return the log-density of the uniform distribution at ``x``.

        The value is ``lgamma(d / 2) - (log(2) + (d / 2) * log(pi))`` with
        ``d = x.shape[-1]``, i.e. minus the log surface area of the unit
        sphere embedded in ``R^d``. It does not depend on the values in ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Points whose last dimension is the ambient coordinate dimension.

        Returns
        -------
        torch.Tensor
            Constant tensor shaped like ``x[..., 0]`` (same dtype and device).
        """
        dim = x.shape[-1]
        half_dim = dim / 2
        log_density = math.lgamma(half_dim) - (
            math.log(2) + half_dim * math.log(math.pi)
        )
        return torch.full_like(x[..., 0], log_density)

    def random_base(self, *args, **kwargs):
        """Sample from the base distribution.

        Forwards all arguments unchanged to the inherited ``random_uniform``.
        """
        return self.random_uniform(*args, **kwargs)

    def base_logprob(self, *args, **kwargs):
        """Log-density of the base distribution.

        Forwards all arguments unchanged to ``uniform_logprob``.
        """
        return self.uniform_logprob(*args, **kwargs)
def geodesic(manifold, start_point, end_point):
    """Build a function that evaluates the geodesic between two points.

    The initial velocity is ``manifold.logmap(start_point, end_point)``,
    computed once. The returned callable scales it by each requested time
    and maps the result back with ``manifold.expmap``.

    Parameters
    ----------
    manifold : object
        Any object exposing ``logmap(x, y)`` and ``expmap(x, u)``.
    start_point : torch.Tensor, shape=[..., dim]
        Point at which the geodesic starts.
    end_point : torch.Tensor, shape=[..., dim]
        Point passed to ``logmap`` as the target.

    Returns
    -------
    callable
        ``path(t)`` mapping a 1-D tensor of times of shape ``[n_points]``
        to points of shape ``[..., n_points, dim]``.
    """
    initial_velocity = manifold.logmap(start_point, end_point)

    def path(t):
        """Evaluate the geodesic at the given times.

        Parameters
        ----------
        t : torch.Tensor, shape=[n_points,]
            Times at which to compute points of the geodesic.
        """
        # Outer product over a new time axis: [..., dim] -> [..., n_points, dim].
        scaled_velocities = torch.einsum("i,...k->...ik", t, initial_velocity)
        # Give the base point a matching singleton time axis for broadcasting.
        base = start_point.unsqueeze(-2)
        return manifold.expmap(base, scaled_velocities)

    return path