Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
def compute_same_pad(kernel_size, stride):
    """Compute one padding value per dimension from kernel size and stride.

    Each padding value is ``((k - 1) * s + 1) // 2`` for the matching
    kernel size ``k`` and stride ``s``.

    Args:
        kernel_size: an int, or a sized iterable of ints (one per dimension).
        stride: an int, or a sized iterable of ints (one per dimension).

    Returns:
        A list with one padding value per dimension.

    Raises:
        AssertionError: if kernel_size and stride differ in length.
    """
    # A bare int is promoted to a single-element list so both arguments
    # can be treated uniformly below.
    kernel_size = [kernel_size] if isinstance(kernel_size, int) else kernel_size
    stride = [stride] if isinstance(stride, int) else stride

    lengths_match = len(stride) == len(kernel_size)
    assert (
        lengths_match
    ), "Pass kernel size and stride both as int, or both as equal length iterable"

    paddings = []
    for k, s in zip(kernel_size, stride):
        paddings.append(((k - 1) * s + 1) // 2)
    return paddings
def uniform_binning_correction(x, n_bits=8):
    """Replaces x^i with q^i(x) = U(x, x + 1.0 / 2**n_bits).

    The input tensor is not modified; the noised values are returned in a
    new tensor.

    Args:
        x: 4-D Tensor of shape (NCHW)
        n_bits: optional. The number of bins is 2**n_bits, so the width of
            the added uniform noise is 1.0 / 2**n_bits (1.0 / 256 by default).

    Returns:
        x: x ~ U(x, x + 1.0 / 2**n_bits), same shape as the input.
        objective: 1-D tensor of length N on x's device, every entry equal
            to -log(2**n_bits) * C * H * W. Equivalent to -q(x)*log(q(x)).
    """
    b, c, h, w = x.size()
    n_bins = 2**n_bits
    chw = c * h * w
    noise = torch.zeros_like(x).uniform_(0, 1.0 / n_bins)
    # Out-of-place add: an in-place `x += noise` would alter the caller's
    # tensor (noise would accumulate if the same batch is passed again) and
    # fails on leaf tensors that require grad.
    x = x + noise
    objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device)
    return x, objective
def split_feature(tensor, type="split"):
    """Split ``tensor`` into two parts along dimension 1.

    Args:
        tensor: tensor with at least two dimensions; dimension 1 is the
            one being split.
        type: one of ["split", "cross"].
            "split": index 0 of dimension 1 versus all remaining indices.
            "cross": even indices of dimension 1 versus odd indices.

    Returns:
        A tuple ``(first, second)`` of slices of ``tensor``.

    Raises:
        ValueError: if ``type`` is neither "split" nor "cross".
    """
    if type == "split":
        # NOTE(review): this cuts after the first channel, not at the
        # midpoint (C // 2) -- confirm that this asymmetric split is intended.
        return tensor[:, :1, ...], tensor[:, 1:, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an unpacking error at the call site.
    raise ValueError(
        f"Unknown split type {type!r}; expected 'split' or 'cross'"
    )