Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from sgm.modules.diffusionmodules.discretizer import Discretization | |
class Img2ImgDiscretizationWrapper:
    """Wrap a discretization and keep only the tail of its sigma schedule.

    The wrapped discretization is expected to return sigmas ordered from
    largest to smallest. This wrapper keeps the last
    ``max(int(strength * len(sigmas)), 1)`` entries (the smallest sigmas),
    so the returned schedule begins at a smaller sigma than the full one.

    Args:
        discretization: callable returning a tensor of sigmas. Only dim 0 is
            sliced (presumably a 1-D schedule — confirm against callers).
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all
            sigmas are returned); 0.0 still keeps a single sigma.

    Raises:
        AssertionError: if ``strength`` is outside [0.0, 1.0].
    """

    def __init__(self, discretization: "Discretization", strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0, (
            f"strength must be in [0.0, 1.0], got {self.strength}"
        )

    def __call__(self, *args, **kwargs):
        """Return the pruned sigma schedule, still ordered large to small.

        All positional and keyword arguments are forwarded unchanged to the
        wrapped discretization.
        """
        import logging

        logger = logging.getLogger(__name__)

        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        logger.debug(
            "sigmas after discretization, before pruning img2img: %s", sigmas
        )

        # Compute the kept count once, from the *unpruned* length. The old
        # diagnostic recomputed it after pruning and reported a wrong value.
        num_kept = max(int(self.strength * len(sigmas)), 1)
        # Keep the last `num_kept` (smallest) sigmas. Equivalent to the former
        # flip -> [:num_kept] -> flip, without the two intermediate copies.
        sigmas = sigmas[max(len(sigmas) - num_kept, 0):]

        logger.debug("prune index: %d", num_kept)
        logger.debug("sigmas after pruning: %s", sigmas)
        return sigmas
class Txt2NoisyDiscretizationWrapper:
    """Wrap a discretization and drop the smallest sigmas of its schedule.

    The wrapped discretization is expected to return sigmas ordered from
    largest to smallest. This wrapper removes the last ``prune_index``
    entries (the smallest sigmas), where::

        steps = len(sigmas) if original_steps is None else original_steps + 1
        prune_index = clamp(int(strength * steps) - 1, 0, steps - 1)

    The returned schedule therefore stops at a larger sigma than the full one.

    Args:
        discretization: callable returning a tensor of sigmas. Only dim 0 is
            sliced (presumably a 1-D schedule — confirm against callers).
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all
            sigmas are returned); larger values drop more of the smallest
            sigmas.
        original_steps: optional int. When given, ``prune_index`` is computed
            relative to ``original_steps + 1`` sigmas instead of the length of
            the schedule returned by ``discretization`` (presumably for when
            the wrapped discretization is itself already pruned — verify).

    Raises:
        AssertionError: if ``strength`` is outside [0.0, 1.0].
    """

    def __init__(
        self,
        discretization: "Discretization",
        strength: float = 0.0,
        original_steps=None,
    ):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0, (
            f"strength must be in [0.0, 1.0], got {self.strength}"
        )

    def __call__(self, *args, **kwargs):
        """Return the pruned sigma schedule, still ordered large to small.

        All positional and keyword arguments are forwarded unchanged to the
        wrapped discretization.
        """
        import logging

        logger = logging.getLogger(__name__)

        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        logger.debug(
            "sigmas after discretization, before pruning txt2noisy: %s", sigmas
        )

        if self.original_steps is None:
            steps = len(sigmas)
        else:
            steps = self.original_steps + 1
        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)

        # Drop the `prune_index` smallest sigmas from the tail. Equivalent to
        # the former flip -> [prune_index:] -> flip. The clamp at 0 matters
        # when prune_index exceeds len(sigmas).
        # NOTE(review): when original_steps + 1 > len(sigmas), prune_index can
        # reach len(sigmas) and an empty schedule is returned — confirm that
        # callers handle this.
        sigmas = sigmas[: max(len(sigmas) - prune_index, 0)]

        logger.debug("prune index: %d", prune_index)
        logger.debug("sigmas after pruning: %s", sigmas)
        return sigmas