# Hugging Face Spaces status: Running on Zero
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
from PIL import Image
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)

from esrgan_model import UpscalerESRGAN
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Checkpoint paths for ``ESRGANUpscaler``.

    Extends the base upscaler checkpoints with the ESRGAN weights file used
    for the pre-upscaling step.

    The ``@dataclass`` decorator is required: without it ``esrgan`` is only a
    class annotation, not a constructor field, and ``esrgan=...`` is rejected.
    ``kw_only=True`` lets this required field follow any defaulted fields
    inherited from the base class.
    NOTE(review): assumes ``UpscalerCheckpoints`` is itself a dataclass — confirm.
    """

    # Path to the ESRGAN weights; ESRGANUpscaler.__init__ passes it to UpscalerESRGAN.
    esrgan: Path
class ESRGANUpscaler(MultiUpscaler):
    """``MultiUpscaler`` that runs ESRGAN before the parent's pre-upscale step.

    ESRGAN (with tiling) enlarges the input first. The parent's
    ``pre_upscale`` is then called on ESRGAN's output with the factor reduced
    by the scale ESRGAN already applied.
    """

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Build the base upscaler and load the ESRGAN model.

        Args:
            checkpoints: Base upscaler checkpoints plus the ESRGAN weights path.
            device: Device the models are moved to.
            dtype: Data type the models are cast to.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
        # NOTE(review): the constructor already receives device/dtype; this explicit move is
        # kept in case it does not place the weights itself — confirm and drop if redundant.
        self.esrgan.to(device=device, dtype=dtype)

    def to(self, device: torch.device, dtype: torch.dtype) -> "ESRGANUpscaler":
        """Move the ESRGAN model and ``self.sd`` to ``device`` / ``dtype``.

        Also records the new ``device`` and ``dtype`` on the instance.

        Returns:
            ``self``, so the call can be chained or reassigned like
            ``torch.nn.Module.to`` (the previous version returned ``None``).
        """
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype
        return self

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Upscale with ESRGAN, then defer to the parent for the remaining factor.

        Args:
            image: Input image.
            upscale_factor: Desired overall enlargement relative to ``image``.
            **_: Extra keyword arguments; accepted for interface compatibility
                and ignored.

        Returns:
            The parent's ``pre_upscale`` result for ESRGAN's output.
        """
        esrgan_image = self.esrgan.upscale_with_tiling(image)
        # Measure the scale ESRGAN actually applied instead of assuming x4. This lets a
        # checkpoint with a different native scale still reach the requested overall
        # factor. For an exact x4 output this is exactly 4.0, matching the old divisor.
        esrgan_scale = esrgan_image.width / image.width
        return super().pre_upscale(image=esrgan_image, upscale_factor=upscale_factor / esrgan_scale)