Spaces:
Sleeping
Sleeping
| # Copyright 2021 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Rot3Array Matrix Class.""" | |
| from __future__ import annotations | |
| import dataclasses | |
| from typing import List | |
| import torch | |
| from dockformerpp.utils.geometry import utils | |
| from dockformerpp.utils.geometry import vector | |
| from dockformerpp.utils.tensor_utils import tensor_tree_map | |
# Field names of Rot3Array's nine matrix entries in row-major order: the first
# letter names the row axis and the second the column axis ('xy' = row x,
# column y), matching the layout produced by Rot3Array.to_tensor.
COMPONENTS = [row + col for row in 'xyz' for col in 'xyz']
@dataclasses.dataclass(frozen=True)
class Rot3Array:
    """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.

    Each of the nine fields holds one matrix entry as a tensor. The first
    letter of a field name is the row axis and the second the column axis,
    so ``xy`` is row x, column y (see ``to_tensor``). All components are
    expected to share one shape -- the batch shape of the array of matrices;
    ``to_tensor`` stacks them and therefore requires it.

    Fields cannot be reassigned (frozen dataclass); every method returns a
    new Rot3Array.
    """

    # The dtype metadata is not read by any method of this class; presumably
    # it is consumed by helpers elsewhere -- verify before relying on it.
    xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    xy: torch.Tensor
    xz: torch.Tensor
    yx: torch.Tensor
    yy: torch.Tensor
    yz: torch.Tensor
    zx: torch.Tensor
    zy: torch.Tensor
    zz: torch.Tensor

    # NumPy opt-out: ufuncs refuse this type and ndarray binary operators
    # return NotImplemented, so NumPy never broadcasts over a Rot3Array.
    __array_ufunc__ = None

    def map_tensor_fn(self, fn) -> Rot3Array:
        """Applies ``fn`` to each of the nine component tensors.

        Args:
            fn: Callable taking a component tensor and returning a tensor.

        Returns:
            A new Rot3Array built from the mapped components.
        """
        return Rot3Array(
            **{f.name: fn(getattr(self, f.name)) for f in dataclasses.fields(self)}
        )

    def __getitem__(self, index) -> Rot3Array:
        """Indexes every component with ``index`` (indexing the batch dims)."""
        return self.map_tensor_fn(lambda t: t[index])

    def __mul__(self, other: torch.Tensor) -> Rot3Array:
        """Multiplies every component element-wise by ``other``.

        This scales the matrix entries, so the result is not in general a
        rotation. ``other`` may be anything a component tensor can be
        multiplied by, e.g. a Python scalar or a broadcastable tensor.
        """
        return self.map_tensor_fn(lambda t: t * other)

    def __matmul__(self, other: Rot3Array) -> Rot3Array:
        """Composes two Rot3Arrays: returns the matrix product ``self @ other``.

        Applying the result to a point equals applying ``other`` first and
        then ``self``.
        """
        # Column j of the product is ``self`` applied to column j of ``other``.
        c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
        c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
        c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
        return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)

    def inverse(self) -> Rot3Array:
        """Returns the inverse of this Rot3Array.

        Computed as the transpose, which equals the inverse only when the
        matrix is orthonormal (a proper rotation).
        """
        return Rot3Array(
            self.xx, self.yx, self.zx,
            self.xy, self.yy, self.zy,
            self.xz, self.yz, self.zz,
        )

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies this Rot3Array to ``point`` (matrix-vector product R @ p)."""
        return vector.Vec3Array(
            self.xx * point.x + self.xy * point.y + self.xz * point.z,
            self.yx * point.x + self.yy * point.y + self.yz * point.z,
            self.zx * point.x + self.zy * point.y + self.zz * point.z,
        )

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies the inverse (transpose) of this Rot3Array to ``point``."""
        return self.inverse().apply_to_point(point)

    def unsqueeze(self, dim: int) -> Rot3Array:
        """Inserts a size-1 dimension at ``dim`` into every component.

        ``dim`` indexes the component (batch) tensors, not the 3x3 layout.
        """
        return self.map_tensor_fn(lambda t: t.unsqueeze(dim))

    def stop_gradient(self) -> Rot3Array:
        """Returns a copy whose components are detached from the autograd graph."""
        return self.map_tensor_fn(lambda t: t.detach())

    @classmethod
    def identity(cls, shape, device) -> Rot3Array:
        """Returns identity matrices with batch shape ``shape``.

        Args:
            shape: Batch shape of each component, as accepted by ``torch.ones``.
            device: Device on which to allocate the components.

        Returns:
            A float32 Rot3Array in which every matrix is the 3x3 identity.
        """
        # All off-diagonal entries alias a single zeros tensor (and the
        # diagonal a single ones tensor); do not modify components in place.
        ones = torch.ones(shape, dtype=torch.float32, device=device)
        zeros = torch.zeros(shape, dtype=torch.float32, device=device)
        return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)

    @classmethod
    def from_two_vectors(
        cls, e0: vector.Vec3Array,
        e1: vector.Vec3Array
    ) -> Rot3Array:
        """Constructs a Rot3Array from two vectors via Gram-Schmidt.

        The result is built such that, in the corresponding frame, ``e0`` lies
        on the positive x-axis and ``e1`` lies in the xy-plane with a positive
        y-coordinate. Its columns are the normalized ``e0``, the normalized
        component of ``e1`` perpendicular to ``e0``, and their cross product.

        Args:
            e0: Vector defining the x-axis.
            e1: Vector defining the xy-plane. It should not be parallel to
                ``e0``, since the perpendicular component would then vanish.

        Returns:
            Rot3Array
        """
        # Normalize the unit vector for the x-axis, e0.
        e0 = e0.normalized()
        # Make e1 perpendicular to e0.
        c = e1.dot(e0)
        e1 = (e1 - c * e0).normalized()
        # Compute e2 as cross product of e0 and e1 (right-handed frame).
        e2 = e0.cross(e1)
        return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)

    @classmethod
    def from_array(cls, array: torch.Tensor) -> Rot3Array:
        """Constructs a Rot3Array from a tensor of shape [..., 3, 3].

        ``array[..., i, j]`` is read as row i, column j.
        """
        rows = torch.unbind(array, dim=-2)
        return cls(*[entry for row in rows for entry in torch.unbind(row, dim=-1)])

    def to_tensor(self) -> torch.Tensor:
        """Converts this Rot3Array to a tensor of shape [..., 3, 3]."""
        return torch.stack(
            [
                torch.stack([self.xx, self.xy, self.xz], dim=-1),
                torch.stack([self.yx, self.yy, self.yz], dim=-1),
                torch.stack([self.zx, self.zy, self.zz], dim=-1),
            ],
            dim=-2,
        )

    @classmethod
    def from_quaternion(cls,
                        w: torch.Tensor,
                        x: torch.Tensor,
                        y: torch.Tensor,
                        z: torch.Tensor,
                        normalize: bool = True,
                        eps: float = 1e-6
                        ) -> Rot3Array:
        """Constructs a Rot3Array from the components of a quaternion.

        Args:
            w: Scalar (real) part of the quaternion.
            x: First vector (imaginary) component.
            y: Second vector (imaginary) component.
            z: Third vector (imaginary) component.
            normalize: If True, the quaternion is first divided by its norm,
                with the squared norm clamped below at ``eps``. If False, the
                components are assumed to already form a unit quaternion.
            eps: Lower bound on the squared norm used when normalizing.

        Returns:
            Rot3Array
        """
        if normalize:
            inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
            w = w * inv_norm
            x = x * inv_norm
            y = y * inv_norm
            z = z * inv_norm
        xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
        xy = 2.0 * (x * y - w * z)
        xz = 2.0 * (x * z + w * y)
        yx = 2.0 * (x * y + w * z)
        yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
        yz = 2.0 * (y * z - w * x)
        zx = 2.0 * (x * z - w * y)
        zy = 2.0 * (y * z + w * x)
        zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
        return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)

    def reshape(self, new_shape) -> Rot3Array:
        """Reshapes every component to ``new_shape`` (the new batch shape)."""
        return self.map_tensor_fn(lambda t: t.reshape(new_shape))

    @classmethod
    def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
        """Concatenates Rot3Arrays component-wise along batch dimension ``dim``.

        Args:
            rots: Rot3Arrays whose components are concatenable with torch.cat.
            dim: Dimension of the component tensors to concatenate along.

        Returns:
            Rot3Array
        """
        return cls(
            **{
                f.name: torch.cat([getattr(r, f.name) for r in rots], dim=dim)
                for f in dataclasses.fields(cls)
            }
        )