Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
class BaseScaler():
    """Base class mapping a logSNR value to a pair of scaling factors ``(a, b)``.

    Subclasses provide the raw factors through :meth:`scalers`. Calling an
    instance returns those factors, linearly rescaled with the limits stored
    by :meth:`setup_limits` when such limits have been recorded.
    """

    def __init__(self):
        # Holds [min_a, max_a, min_b, max_b] after setup_limits(); while it is
        # None, __call__ returns the raw output of scalers() unchanged.
        self.stretched_limits = None

    def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1):
        """Compute, store and return the limits used by :meth:`stretch_limits`.

        Args:
            schedule: Callable invoked as ``schedule(t, shift=shift)`` with
                ``torch.ones(1)`` and ``torch.zeros(1)``; its results are
                treated as the minimum and maximum logSNR respectively.
            input_scaler: Callable returning two values exposing ``.item()``
                when given a logSNR value.
            stretch_max: When true, ``min_a`` / ``max_b`` are taken from
                ``input_scaler`` at the minimum logSNR; otherwise 0 and 1.
            stretch_min: When true, ``max_a`` / ``min_b`` are taken from
                ``input_scaler`` at the maximum logSNR; otherwise 1 and 0.
            shift: Forwarded to ``schedule`` as a keyword argument.

        Returns:
            The list ``[min_a, max_a, min_b, max_b]`` that was also stored on
            ``self.stretched_limits``.
        """
        # Both schedule endpoints are always evaluated, even when a stretch
        # flag is off and the corresponding value ends up unused.
        min_logSNR = schedule(torch.ones(1), shift=shift)
        max_logSNR = schedule(torch.zeros(1), shift=shift)

        if stretch_max:
            at_min = [factor.item() for factor in input_scaler(min_logSNR)]
        else:
            at_min = [0, 1]
        min_a, max_b = at_min

        if stretch_min:
            at_max = [factor.item() for factor in input_scaler(max_logSNR)]
        else:
            at_max = [1, 0]
        max_a, min_b = at_max

        self.stretched_limits = [min_a, max_a, min_b, max_b]
        return self.stretched_limits

    def stretch_limits(self, a, b):
        """Rescale ``a`` and ``b`` linearly using the stored limits.

        Each value is mapped so that its stored minimum becomes 0 and its
        stored maximum becomes 1. Requires :meth:`setup_limits` to have run.
        """
        min_a, max_a, min_b, max_b = self.stretched_limits
        stretched_a = (a - min_a) / (max_a - min_a)
        stretched_b = (b - min_b) / (max_b - min_b)
        return stretched_a, stretched_b

    def scalers(self, logSNR):
        """Return the raw ``(a, b)`` pair for ``logSNR``; subclasses must override."""
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR):
        """Return ``(a, b)`` for ``logSNR``, stretched when limits are set."""
        a, b = self.scalers(logSNR)
        if self.stretched_limits is None:
            return a, b
        a, b = self.stretch_limits(a, b)
        return a, b
class VPScaler(BaseScaler):
    """Scaler whose factors satisfy ``a**2 + b**2 == 1``."""

    def scalers(self, logSNR):
        """Return ``(a, b)`` with ``a**2 = sigmoid(logSNR)`` and ``b**2 = 1 - a**2``."""
        signal_power = logSNR.sigmoid()
        noise_power = 1 - signal_power
        return signal_power.sqrt(), noise_power.sqrt()
class LERPScaler(BaseScaler):
    """Scaler whose factors interpolate linearly: ``a + b == 1``."""

    def scalers(self, logSNR):
        """Return ``(a, b)`` with ``a = sigmoid(logSNR / 2)`` and ``b = 1 - a``.

        With ``_a = exp(logSNR) - 1`` the expression
        ``1 + (2 - sqrt(4 + 4 * _a)) / (2 * _a)`` reduces algebraically to
        ``sqrt(1 + _a) / (1 + sqrt(1 + _a))``, i.e. ``sigmoid(logSNR / 2)``.
        The closed form is used because the quotient form is 0/0 at
        ``logSNR == 0``, loses all precision through cancellation for tiny
        non-zero ``logSNR``, and turns into NaN once ``exp(logSNR)``
        overflows. The sigmoid is exact at zero (``a == 0.5``) and is
        well-behaved over the whole input range.
        """
        a = (logSNR / 2).sigmoid()
        b = 1 - a
        return a, b