Spaces:
Running
Running
| """ Adapted from https://github.com/SongweiGe/TATS""" | |
| # Copyright (c) Meta Platforms, Inc. All Rights Reserved | |
| import warnings | |
| import torch | |
| import imageio | |
| import math | |
| import numpy as np | |
| import sys | |
| import pdb as pdb_original | |
| # import SimpleITK as sitk | |
| import logging | |
| import imageio.core.util | |
# Only ERROR-level records from the imageio_ffmpeg logger are emitted.
_FFMPEG_LOGGER = logging.getLogger("imageio_ffmpeg")
_FFMPEG_LOGGER.setLevel(logging.ERROR)
def _backend_is_available(backend):
    """Return True if ``backend`` exposes a callable ``is_available()`` that reports True."""
    is_available = getattr(backend, 'is_available', None)
    return callable(is_available) and bool(is_available())


def get_single_device(cpu=True):
    """Select one torch device.

    Args:
        cpu: If True, always return the CPU device.

    Returns:
        ``torch.device('cpu')`` when ``cpu`` is True. Otherwise the first
        available accelerator, checked in the order CUDA, XPU, MPS. Returns
        ``None`` when ``cpu`` is False and none of those is available.
    """
    if cpu:
        return torch.device('cpu')
    if torch.cuda.is_available():
        return torch.device('cuda')
    # ``torch.xpu`` only exists in newer PyTorch builds; look it up defensively
    # so older versions fall through instead of raising AttributeError.
    if _backend_is_available(getattr(torch, 'xpu', None)):
        return torch.device('xpu')
    # ``torch.backends.mps.is_available()`` is the documented MPS availability check.
    if _backend_is_available(getattr(torch.backends, 'mps', None)):
        return torch.device('mps')
    return None
class ForkedPdb(pdb_original.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    For each interactive session, ``sys.stdin`` is temporarily rebound to a
    freshly opened ``/dev/stdin``. The original ``sys.stdin`` is restored
    afterwards, and the temporary handle is closed. (Presumably this is needed
    because the child's own ``sys.stdin`` is not interactive; verify for your
    launcher.) Relies on ``/dev/stdin``, so it is POSIX-only.

    Usage, inside the child process::

        ForkedPdb().set_trace()
    """

    def interaction(self, *args, **kwargs):
        _stdin = sys.stdin
        try:
            # ``with`` closes the temporary handle when the session ends,
            # instead of leaking one file descriptor per breakpoint hit.
            with open('/dev/stdin') as fresh_stdin:
                sys.stdin = fresh_stdin
                pdb_original.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = _stdin
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` to position ``dest_dim``.

    The other dimensions keep their relative order. For example,
    ``shift_dim(x, 1, -1)`` maps ``(b, c, t, h, w)`` to ``(b, t, h, w, c)``.

    Args:
        x: Tensor to permute.
        src_dim: Dimension to move; negative values count from the end.
        dest_dim: Final position of that dimension; negative values count from the end.
        make_contiguous: If True, return ``.contiguous()`` of the permuted tensor.

    Raises:
        AssertionError: if either normalized index falls outside ``[0, len(x.shape))``.
    """
    rank = len(x.shape)
    src = src_dim + rank if src_dim < 0 else src_dim
    dest = dest_dim + rank if dest_dim < 0 else dest_dim
    assert 0 <= src < rank and 0 <= dest < rank

    # Keep every other axis in order and drop ``src`` in at slot ``dest``.
    order = [axis for axis in range(rank) if axis != src]
    order.insert(dest, src)

    moved = x.permute(order)
    return moved.contiguous() if make_contiguous else moved
def view_range(x, i, j, shape):
    """Reshape dimensions ``i`` (inclusive) through ``j`` (exclusive) of ``x``.

    Example: if ``x.shape == (b, thw, c)``, then
    ``view_range(x, 1, 2, (t, h, w))`` returns ``x`` viewed as ``(b, t, h, w, c)``.

    Args:
        x: Tensor to reshape. Uses ``x.view``, so the usual view constraints apply.
        i: First dimension to replace; negative values count from the end.
        j: One past the last dimension to replace. ``None`` means through the
            last dimension; negative values count from the end.
        shape: Iterable of sizes that replaces ``x.shape[i:j]``.

    Returns:
        ``x.view(...)`` with dimensions ``i:j`` replaced by ``shape``.

    Raises:
        AssertionError: unless the normalized indices satisfy
            ``0 <= i < j <= len(x.shape)``.
    """
    new_dims = tuple(shape)
    rank = len(x.shape)

    start = i + rank if i < 0 else i
    if j is None:
        stop = rank
    else:
        stop = j + rank if j < 0 else j
    assert 0 <= start < stop <= rank

    dims = x.shape
    return x.view(dims[:start] + new_dims + dims[stop:])
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Scores with the class dimension at dim 1.
        target: Ground-truth class indices; ``target.size(0)`` is the batch size.
        topk: The k values to evaluate.

    Returns:
        A list with one 1-element tensor per entry of ``topk``, each holding
        ``100 * (#samples whose target is among the top-k scores) / batch``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Top ``largest_k`` class indices per sample, best first. Transposing
        # makes row r hold every sample's rank-r guess, so ``hits[:k]``
        # covers the top-k guesses.
        _, ranked = output.topk(largest_k, 1, True, True)
        ranked = ranked.t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

        scale = 100.0 / n_samples
        scores = [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
    return scores
def tensor_slice(x, begin, size):
    """Return ``x[b0:b0+s0, b1:b1+s1, ...]`` for the given begins and sizes.

    Args:
        x: Array-like with ``.shape`` that supports tuple-of-slices indexing
            (e.g. a torch tensor or a NumPy array).
        begin: Non-negative start index per leading dimension.
        size: Length per dimension; ``-1`` means "through the end of that dimension".

    Returns:
        The sliced ``x``. Trailing dimensions not covered by ``begin``/``size``
        are left whole.

    Raises:
        AssertionError: if any begin is negative or a resolved size is negative.
    """
    assert all(b >= 0 for b in begin)
    size = [dim - b if s == -1 else s
            for s, b, dim in zip(size, begin, x.shape)]
    assert all(s >= 0 for s in size)
    # A tuple is the unambiguous multi-dimensional index. NumPy (>=1.23)
    # rejects a list of slices, and PyTorch only accepts one as legacy behavior.
    slices = tuple(slice(b, b + s) for b, s in zip(begin, size))
    return x[slices]
def adopt_weight(global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step < threshold``, and ``1`` from then on."""
    return value if global_step < threshold else 1
def comp_getattr(args, attr_name, default=None):
    """Return ``args.<attr_name>`` if it exists, otherwise ``default``.

    This is the three-argument built-in ``getattr``. The attribute is
    evaluated exactly once, so a property getter with side effects runs only once.
    """
    return getattr(args, attr_name, default)
def visualize_tensors(t, name=None, nest=0):
    """Print a recursive summary of a (possibly nested) container of tensors.

    - Dicts (including ``OrderedDict``/``defaultdict``) print their keys, then
      one line per entry.
    - Tensors (including ``nn.Parameter``) are shown by shape, not value.
    - Lists and tuples print their length and recurse into each element.
    - Any other value is printed as-is.

    Args:
        t: Object to describe.
        name: Optional label printed, with the nesting depth, on every call.
        nest: Current recursion depth; callers normally leave it at 0.

    Returns:
        Always an empty string.
    """
    if name is not None:
        print(name, "current nest: ", nest)
    print("type: ", type(t))
    # isinstance instead of matching substrings of str(type(t)): subclasses
    # such as OrderedDict and nn.Parameter are classified correctly, and
    # unrelated types whose names merely contain "dict"/"list" are not.
    if isinstance(t, dict):
        print(t.keys())
        for k, v in t.items():
            if v is None:
                print(k, "None")
            elif isinstance(v, torch.Tensor):
                print(k, v.shape)
            elif isinstance(v, dict):
                print(k, 'dict')
                visualize_tensors(v, name, nest + 1)
            elif isinstance(v, (list, tuple)):
                print(k, len(v))
                visualize_tensors(v, name, nest + 1)
            else:
                # Show scalar/other leaves instead of silently skipping them.
                print(k, v)
    elif isinstance(t, (list, tuple)):
        print("list length: ", len(t))
        for item in t:
            visualize_tensors(item, name, nest + 1)
    elif isinstance(t, torch.Tensor):
        print(t.shape)
    else:
        print(t)
    return ""