Spaces:
Runtime error
Runtime error
| import contextlib | |
| import torch | |
| from modules import errors | |
| # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility | |
| has_mps = getattr(torch, 'has_mps', False) | |
| cpu = torch.device("cpu") | |
| def get_optimal_device(): | |
| if torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| if has_mps: | |
| return torch.device("mps") | |
| return cpu | |
| def torch_gc(): | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.ipc_collect() | |
| def enable_tf32(): | |
| if torch.cuda.is_available(): | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| errors.run(enable_tf32, "Enabling TF32") | |
| device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() | |
| dtype = torch.float16 | |
| def randn(seed, shape): | |
| # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. | |
| if device.type == 'mps': | |
| generator = torch.Generator(device=cpu) | |
| generator.manual_seed(seed) | |
| noise = torch.randn(shape, generator=generator, device=cpu).to(device) | |
| return noise | |
| torch.manual_seed(seed) | |
| return torch.randn(shape, device=device) | |
| def randn_without_seed(shape): | |
| # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. | |
| if device.type == 'mps': | |
| generator = torch.Generator(device=cpu) | |
| noise = torch.randn(shape, generator=generator, device=cpu).to(device) | |
| return noise | |
| return torch.randn(shape, device=device) | |
| def autocast(): | |
| from modules import shared | |
| if dtype == torch.float32 or shared.cmd_opts.precision == "full": | |
| return contextlib.nullcontext() | |
| return torch.autocast("cuda") | |