Spaces:
Sleeping
Sleeping
| # Copyright 2021 AlQuraishi Laboratory | |
| # Copyright 2021 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from functools import partial | |
| import logging | |
| from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional | |
| import torch | |
| import torch.nn as nn | |
def add(m1, m2, inplace):
    """Add ``m2`` to ``m1``, optionally accumulating into ``m1`` in place.

    Args:
        m1: Left operand. Mutated via ``+=`` when ``inplace`` is truthy.
        m2: Right operand.
        inplace: If truthy, use in-place addition; otherwise build a new
            object with ``m1 + m2``.

    Returns:
        The sum. When ``inplace`` is truthy this is whatever ``m1 += m2``
        leaves bound to ``m1``.
    """
    # The first operation in a checkpoint can't be in-place, but in-place
    # addition is nice to have during inference, hence the switch.
    if inplace:
        m1 += m2
        return m1
    return m1 + m2
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute only the trailing ``len(inds)`` dimensions of ``tensor``.

    Args:
        tensor: Tensor whose last ``len(inds)`` dimensions are reordered.
        inds: New order of the trailing dimensions, expressed as offsets
            within that trailing group (0 is the first of the trailing
            dimensions).

    Returns:
        The result of ``tensor.permute`` with leading dimensions left in
        place and trailing dimensions reordered according to ``inds``.
    """
    n_final = len(inds)
    # Leading dimensions keep their positions.
    leading = list(range(len(tensor.shape[:-n_final])))
    # Trailing dimensions are addressed with negative indices.
    trailing = [offset - n_final for offset in inds]
    return tensor.permute(leading + trailing)
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Collapse the last ``no_dims`` dimensions of ``t`` into one.

    Args:
        t: Tensor to reshape.
        no_dims: Number of trailing dimensions to merge.

    Returns:
        ``t`` reshaped so the leading dimensions are kept and the trailing
        ``no_dims`` dimensions are replaced by a single inferred dimension.
    """
    kept_shape = t.shape[:-no_dims]
    return t.reshape(*kept_shape, -1)
def masked_mean(mask, value, dim, eps=1e-4):
    """Compute a mask-weighted mean of ``value`` along ``dim``.

    Args:
        mask: Tensor expandable to ``value.shape``; used as the weights.
        value: Tensor of values to average.
        dim: Dimension (or dimensions) to reduce over.
        eps: Small constant added to the denominator so an all-zero mask
            does not divide by zero.

    Returns:
        ``sum(mask * value) / (eps + sum(mask))`` reduced over ``dim``.
    """
    expanded_mask = mask.expand(*value.shape)
    weighted_total = (expanded_mask * value).sum(dim=dim)
    mask_total = expanded_mask.sum(dim=dim)
    return weighted_total / (eps + mask_total)
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bucket all pairwise Euclidean distances between points into bins.

    Args:
        pts: Tensor of point coordinates. The last dimension is reduced as
            the coordinate axis and the second-to-last indexes the points.
        min_bin: Value of the first bin boundary.
        max_bin: Value of the last bin boundary.
        no_bins: Number of bins; ``no_bins - 1`` evenly spaced boundaries
            are created between ``min_bin`` and ``max_bin``.

    Returns:
        Integer tensor of bin indices produced by ``torch.bucketize`` for
        every pair of points.
    """
    bin_edges = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    # Broadcast to form the difference between every pair of points.
    pairwise_delta = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    squared_dists = (pairwise_delta ** 2).sum(dim=-1)
    dists = squared_dists.sqrt()
    return torch.bucketize(dists, bin_edges)
def dict_multimap(fn, dicts):
    """Combine a sequence of identically structured dicts key by key.

    The keys (and nesting) of the first dict define the structure of the
    result. For every key, the values from all dicts are collected into a
    list; nested dicts are merged recursively and any other values are
    passed to ``fn`` as that list.

    Args:
        fn: Callable that receives a list of the values stored under one
            key (one value per input dict) and returns the combined value.
        dicts: Non-empty sequence of dicts. Every dict must contain every
            key present in the first one.

    Returns:
        A new plain ``dict`` with the same nested key structure as
        ``dicts[0]``.

    Raises:
        IndexError: If ``dicts`` is empty.
        KeyError: If a later dict is missing a key of the first dict.
    """
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        # isinstance (rather than an exact type check) so dict subclasses
        # such as OrderedDict are recursed into, like tree_map does.
        if isinstance(v, dict):
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)
    return new_dict
def one_hot(x, v_bins):
    """One-hot encode each element of ``x`` by its nearest bin value.

    Args:
        x: Tensor of values to encode.
        v_bins: 1-D tensor of bin values.

    Returns:
        Float tensor with one extra trailing dimension of size
        ``len(v_bins)``; the entry for the bin with the smallest absolute
        difference to the element is 1 and all others are 0.
    """
    n_bins = len(v_bins)
    # Give the bins a singleton axis for every axis of x so they broadcast.
    broadcast_bins = v_bins.view(*([1] * len(x.shape)), n_bins)
    abs_diffs = (x.unsqueeze(-1) - broadcast_bins).abs()
    nearest = abs_diffs.argmin(dim=-1)
    return nn.functional.one_hot(nearest, num_classes=n_bins).float()
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gather entries of ``data`` along ``dim`` independently per batch.

    For each of the first ``no_batch_dims`` dimensions an index range is
    built and shaped to broadcast against ``inds``, so every batch element
    is indexed with its own slice of ``inds``.

    Args:
        data: Tensor to gather from.
        inds: Integer index tensor. Its leading dimensions are expected to
            line up with the batch dimensions of ``data``.
        dim: Dimension of ``data`` to index with ``inds``. A non-negative
            value counts from the start of ``data`` (batch dimensions
            included); a negative value counts from the end.
        no_batch_dims: Number of leading dimensions of ``data`` treated as
            batch dimensions.

    Returns:
        The tensor obtained by advanced-indexing ``data`` with the batch
        ranges, ``inds`` at ``dim``, and full slices everywhere else.
    """
    n_ind_dims = len(inds.shape)
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # Put the range on axis i and pad with singleton axes so it
        # broadcasts against inds.
        view_shape = (1,) * i + (-1,) + (1,) * (n_ind_dims - i - 1)
        ranges.append(torch.arange(s).view(*view_shape))

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)

    # Index with a tuple: indexing with a plain list is deprecated,
    # ambiguous "non-tuple sequence" multidimensional indexing.
    return data[tuple(ranges)]
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    """Apply ``tree_map`` to every value of ``dic``.

    Args:
        fn: Callable applied to each leaf.
        dic: Dict whose values are nested dicts, lists, tuples, or leaves.
        leaf_type: Type (or tuple of types) treated as a leaf.

    Returns:
        A new plain ``dict`` with the same keys and mapped values.
    """
    # tree_map already sends nested dicts back here, so no separate dict
    # branch is needed.
    return {k: tree_map(fn, v, leaf_type) for k, v in dic.items()}


def tree_map(fn, tree, leaf_type):
    """Apply ``fn`` to every leaf of a nested dict/list/tuple structure.

    Args:
        fn: Callable applied to each leaf.
        tree: A dict, list, tuple, or an instance of ``leaf_type``,
            nested arbitrarily.
        leaf_type: Type (or tuple of types) treated as a leaf.

    Returns:
        A structure mirroring ``tree`` with each leaf replaced by
        ``fn(leaf)``. Dicts come back as plain dicts, lists as lists, and
        tuples as plain tuples.

    Raises:
        ValueError: If a node is not a dict, list, tuple, or ``leaf_type``
            instance.
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    if isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    if isinstance(tree, leaf_type):
        return fn(tree)
    # Report the offending type in the exception instead of printing it
    # to stdout.
    raise ValueError(f"Not supported: {type(tree)}")


# tree_map specialised to treat torch.Tensor instances as the leaves.
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)