import os
import copy
import gc
from contextlib import contextmanager, nullcontext

import torch
import torch.cuda
import torch.nn.functional as F
from omegaconf import OmegaConf
import open_clip

import comfy.utils
import comfy.model_management as mm
import folder_paths

from .sgm.util import instantiate_from_config
from .SUPIR.util import convert_dtype, load_state_dict
from .sgm.modules.distributions.distributions import DiagonalGaussianDistribution

try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except Exception:
    is_accelerate_available = False

from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPTextConfig,
)
script_directory = os.path.dirname(os.path.abspath(__file__))

def dummy_build_vision_tower(*args, **kwargs):
    # Stand-in for open_clip's vision tower builder; only the text tower is needed here.
    return None

@contextmanager
def patch_build_vision_tower():
    # Monkey patch open_clip before creating a CLIP instance so no vision tower is built.
    original_build_vision_tower = open_clip.model._build_vision_tower
    open_clip.model._build_vision_tower = dummy_build_vision_tower
    try:
        yield
    finally:
        open_clip.model._build_vision_tower = original_build_vision_tower
def build_text_model_from_openai_state_dict(
        state_dict: dict,
        device,
        cast_dtype=torch.float16,
):
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    vision_cfg = None
    text_cfg = open_clip.CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    with patch_build_vision_tower():
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            model = open_clip.CLIP(
                embed_dim,
                vision_cfg=vision_cfg,
                text_cfg=text_cfg,
                quick_gelu=True,
                cast_dtype=cast_dtype,
            )

    if is_accelerate_available:
        for key in state_dict:
            set_module_tensor_to_device(model, key, device=device, value=state_dict[key])
    else:
        model.load_state_dict(state_dict, strict=False)

    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model
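
# Rough illustration of the key layout the function above expects, assuming an
# SDXL-style "clip_g" text encoder state dict in OpenAI CLIP naming:
#   text_projection            -> (transformer_width, embed_dim)
#   positional_embedding       -> (context_length, transformer_width)
#   token_embedding.weight     -> (vocab_size, transformer_width)
#   ln_final.weight            -> (transformer_width,)
#   transformer.resblocks.N.*  -> N is counted to derive transformer_layers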
class SUPIR_encode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "SUPIR_VAE": ("SUPIRVAE",),
            "image": ("IMAGE",),
            "use_tiled_vae": ("BOOLEAN", {"default": True}),
            "encoder_tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}),
            "encoder_dtype": (
                    [
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latent",)
    FUNCTION = "encode"
    CATEGORY = "SUPIR"
    def encode(self, SUPIR_VAE, image, encoder_dtype, use_tiled_vae, encoder_tile_size):
        device = mm.get_torch_device()
        mm.unload_all_models()
        if encoder_dtype == 'auto':
            try:
                if mm.should_use_bf16():
                    print("Encoder using bf16")
                    vae_dtype = 'bf16'
                else:
                    print("Encoder using fp32")
                    vae_dtype = 'fp32'
            except:
                raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.")
        else:
            vae_dtype = encoder_dtype
            print(f"Encoder using {vae_dtype}")

        dtype = convert_dtype(vae_dtype)

        image = image.permute(0, 3, 1, 2)
        B, C, H, W = image.shape
        downscale_ratio = 32
        orig_H, orig_W = H, W
        if W % downscale_ratio != 0:
            W = W - (W % downscale_ratio)
        if H % downscale_ratio != 0:
            H = H - (H % downscale_ratio)
        if orig_H % downscale_ratio != 0 or orig_W % downscale_ratio != 0:
            image = F.interpolate(image, size=(H, W), mode="bicubic")
        resized_image = image.to(device)

        if use_tiled_vae:
            from .SUPIR.utils.tilevae import VAEHook
            # Store the `original_forward` only if it hasn't been stored already
            if not hasattr(SUPIR_VAE.encoder, 'original_forward'):
                SUPIR_VAE.encoder.original_forward = SUPIR_VAE.encoder.forward
            SUPIR_VAE.encoder.forward = VAEHook(
                SUPIR_VAE.encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
                fast_encoder=False, color_fix=False, to_gpu=True)
        else:
            # Only assign `original_forward` back if it exists
            if hasattr(SUPIR_VAE.encoder, 'original_forward'):
                SUPIR_VAE.encoder.forward = SUPIR_VAE.encoder.original_forward

        pbar = comfy.utils.ProgressBar(B)
        out = []
        for img in resized_image:
            SUPIR_VAE.to(dtype).to(device)

            autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
            with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
                z = SUPIR_VAE.encode(img.unsqueeze(0))
                z = z * 0.13025
                out.append(z)
                pbar.update(1)

        if len(out[0].shape) == 4:
            samples_out_stacked = torch.cat(out, dim=0)
        else:
            samples_out_stacked = torch.stack(out, dim=0)

        return ({"samples": samples_out_stacked, "original_size": [orig_H, orig_W]},)
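
# Note on the LATENT dict produced above (a sketch of the contract assumed within
# this module, not a ComfyUI-wide standard): "samples" holds SDXL-scale latents
# (already multiplied by 0.13025) and "original_size" carries [H, W] of the input
# so SUPIR_decode can interpolate the decoded image back to the pre-crop resolution.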
class SUPIR_decode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "SUPIR_VAE": ("SUPIRVAE",),
            "latents": ("LATENT",),
            "use_tiled_vae": ("BOOLEAN", {"default": True}),
            "decoder_tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "decode"
    CATEGORY = "SUPIR"
    def decode(self, SUPIR_VAE, latents, use_tiled_vae, decoder_tile_size):
        device = mm.get_torch_device()
        mm.unload_all_models()
        samples = latents["samples"]
        B, C, H, W = samples.shape
        pbar = comfy.utils.ProgressBar(B)

        if mm.should_use_bf16():
            print("Decoder using bf16")
            dtype = torch.bfloat16
        else:
            print("Decoder using fp32")
            dtype = torch.float32
        SUPIR_VAE.to(dtype).to(device)

        samples = samples.to(device)

        if use_tiled_vae:
            from .SUPIR.utils.tilevae import VAEHook
            # Store the `original_forward` only if it hasn't been stored already
            if not hasattr(SUPIR_VAE.decoder, 'original_forward'):
                SUPIR_VAE.decoder.original_forward = SUPIR_VAE.decoder.forward
            SUPIR_VAE.decoder.forward = VAEHook(
                SUPIR_VAE.decoder, decoder_tile_size // 8, is_decoder=True, fast_decoder=False,
                fast_encoder=False, color_fix=False, to_gpu=True)
        else:
            # Only assign `original_forward` back if it exists
            if hasattr(SUPIR_VAE.decoder, 'original_forward'):
                SUPIR_VAE.decoder.forward = SUPIR_VAE.decoder.original_forward

        out = []
        for sample in samples:
            autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
            with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
                sample = 1.0 / 0.13025 * sample
                decoded_image = SUPIR_VAE.decode(sample.unsqueeze(0))
                out.append(decoded_image)
                pbar.update(1)

        decoded_out = torch.cat(out, dim=0).float()

        if "original_size" in latents and latents["original_size"] is not None:
            orig_H, orig_W = latents["original_size"]
            if decoded_out.shape[2] != orig_H or decoded_out.shape[3] != orig_W:
                print("Restoring original dimensions: ", orig_W, "x", orig_H)
                decoded_out = F.interpolate(decoded_out, size=(orig_H, orig_W), mode="bicubic")

        decoded_out = torch.clip(decoded_out, 0, 1)
        decoded_out = decoded_out.cpu().to(torch.float32).permute(0, 2, 3, 1)
        return (decoded_out,)
class SUPIR_first_stage:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "SUPIR_VAE": ("SUPIRVAE",),
            "image": ("IMAGE",),
            "use_tiled_vae": ("BOOLEAN", {"default": True}),
            "encoder_tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}),
            "decoder_tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}),
            "encoder_dtype": (
                    [
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
            }
        }

    RETURN_TYPES = ("SUPIRVAE", "IMAGE", "LATENT",)
    RETURN_NAMES = ("SUPIR_VAE", "denoised_image", "denoised_latents",)
    FUNCTION = "process"
    CATEGORY = "SUPIR"
    DESCRIPTION = """
SUPIR "first stage" processing.

Encodes and decodes the image using SUPIR's "denoise_encoder". Its purpose is to
clean up compression artifacts and similar degradations; this often blurs the image,
which is expected. It can be replaced with any other denoiser/blur, or skipped entirely.
"""
    def process(self, SUPIR_VAE, image, encoder_dtype, use_tiled_vae, encoder_tile_size, decoder_tile_size):
        device = mm.get_torch_device()
        mm.unload_all_models()
        if encoder_dtype == 'auto':
            try:
                if mm.should_use_bf16():
                    print("Encoder using bf16")
                    vae_dtype = 'bf16'
                else:
                    print("Encoder using fp32")
                    vae_dtype = 'fp32'
            except:
                raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.")
        else:
            vae_dtype = encoder_dtype
            print(f"Encoder using {vae_dtype}")

        dtype = convert_dtype(vae_dtype)

        if use_tiled_vae:
            from .SUPIR.utils.tilevae import VAEHook
            # Store the `original_forward` only if it hasn't been stored already
            if not hasattr(SUPIR_VAE.denoise_encoder, 'original_forward'):
                SUPIR_VAE.denoise_encoder.original_forward = SUPIR_VAE.denoise_encoder.forward
                SUPIR_VAE.decoder.original_forward = SUPIR_VAE.decoder.forward
            SUPIR_VAE.denoise_encoder.forward = VAEHook(
                SUPIR_VAE.denoise_encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
                fast_encoder=False, color_fix=False, to_gpu=True)
            SUPIR_VAE.decoder.forward = VAEHook(
                SUPIR_VAE.decoder, decoder_tile_size // 8, is_decoder=True, fast_decoder=False,
                fast_encoder=False, color_fix=False, to_gpu=True)
        else:
            # Only assign `original_forward` back if it exists
            if hasattr(SUPIR_VAE.denoise_encoder, 'original_forward'):
                SUPIR_VAE.denoise_encoder.forward = SUPIR_VAE.denoise_encoder.original_forward
                SUPIR_VAE.decoder.forward = SUPIR_VAE.decoder.original_forward

        image = image.permute(0, 3, 1, 2)
        B, C, H, W = image.shape
        downscale_ratio = 32
        orig_H, orig_W = H, W
        if W % downscale_ratio != 0:
            W = W - (W % downscale_ratio)
        if H % downscale_ratio != 0:
            H = H - (H % downscale_ratio)
        if orig_H % downscale_ratio != 0 or orig_W % downscale_ratio != 0:
            image = F.interpolate(image, size=(H, W), mode="bicubic")
        resized_image = image.to(device)

        pbar = comfy.utils.ProgressBar(B)
        out = []
        out_samples = []
        for img in resized_image:
            SUPIR_VAE.to(dtype).to(device)

            autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
            with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
                h = SUPIR_VAE.denoise_encoder(img.unsqueeze(0))
                moments = SUPIR_VAE.quant_conv(h)
                posterior = DiagonalGaussianDistribution(moments)
                sample = posterior.sample()
                decoded_images = SUPIR_VAE.decode(sample).float()
                out.append(decoded_images.cpu())
                out_samples.append(sample.cpu() * 0.13025)
                pbar.update(1)

        out_stacked = torch.cat(out, dim=0).to(torch.float32).permute(0, 2, 3, 1)
        out_samples_stacked = torch.cat(out_samples, dim=0)
        original_size = [orig_H, orig_W]
        return (SUPIR_VAE, out_stacked, {"samples": out_samples_stacked, "original_size": original_size},)
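
# Summary of the first-stage pass above (descriptive only): each image goes through
# denoise_encoder -> quant_conv -> DiagonalGaussianDistribution.sample() to obtain a
# cleaned latent, which is both decoded back to an image (denoised_image) and
# returned as an SDXL-scale latent (multiplied by 0.13025) for the sampler.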
class SUPIR_sample:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "SUPIR_model": ("SUPIRMODEL",),
            "latents": ("LATENT",),
            "positive": ("SUPIR_cond_pos",),
            "negative": ("SUPIR_cond_neg",),
            "seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
            "steps": ("INT", {"default": 45, "min": 3, "max": 4096, "step": 1}),
            "cfg_scale_start": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.01}),
            "cfg_scale_end": ("FLOAT", {"default": 4.0, "min": 0, "max": 100.0, "step": 0.01}),
            "EDM_s_churn": ("INT", {"default": 5, "min": 0, "max": 40, "step": 1}),
            "s_noise": ("FLOAT", {"default": 1.003, "min": 1.0, "max": 1.1, "step": 0.001}),
            "DPMPP_eta": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}),
            "control_scale_start": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}),
            "control_scale_end": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}),
            "restore_cfg": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20.0, "step": 0.01}),
            "keep_model_loaded": ("BOOLEAN", {"default": False}),
            "sampler": (
                    [
                        'RestoreDPMPP2MSampler',
                        'RestoreEDMSampler',
                        'TiledRestoreDPMPP2MSampler',
                        'TiledRestoreEDMSampler',
                    ], {
                        "default": 'RestoreEDMSampler'
                    }),
            },
            "optional": {
                "sampler_tile_size": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 32}),
                "sampler_tile_stride": ("INT", {"default": 512, "min": 32, "max": 2048, "step": 32}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latent",)
    FUNCTION = "sample"
    CATEGORY = "SUPIR"
    DESCRIPTION = """
- **latent:**
Latent to sample from. When using a SUPIR latent this is only used for the noise shape;
it is otherwise ignored here, making it identical to feeding in an empty ComfyUI latent.
If anything else is fed in, it is used as-is and no noise is added.
- **cfg:**
Linearly scaled CFG is always used: the first step uses the cfg_scale_start value,
which is interpolated to the cfg_scale_end value by the last step.
To disable the scaling, set both values to the same number.
- **EDM_s_churn:**
Controls the rate of adaptation of the diffusion process to changes in noise levels
over time. Has no effect with DPMPP samplers.
- **s_noise:**
Directly controls the amount of noise added to the image at each
step of the diffusion process.
- **DPMPP_eta:**
Scaling factor that influences the diffusion process by adjusting how the denoising
process adapts to changes in noise levels over time.
Has no effect with EDM samplers.
- **control_scale:**
The strength of the SUPIR control model, scaled linearly from start to end.
Lower values allow more freedom from the input image.
- **restore_cfg:**
Controls the degree of restoration towards the original image during the diffusion
process, allowing some fine-tuning of the result.
- **samplers:**
EDM samplers need lots of steps but generally give better quality.
DPMPP samplers work well with lower step counts, which is good for lightning models.
Tiled samplers enable a tiled diffusion process; this is very slow but saves VRAM,
allowing higher resolutions. Tile size should be chosen so the image
is evenly tiled, and tile stride affects the overlap of the tiles. Check the
SUPIR Tiles node for a preview of how the image is tiled.
"""
    def sample(self, SUPIR_model, latents, steps, seed, cfg_scale_end, EDM_s_churn, s_noise, positive, negative,
               cfg_scale_start, control_scale_start, control_scale_end, restore_cfg, keep_model_loaded, DPMPP_eta,
               sampler, sampler_tile_size=1024, sampler_tile_stride=512):
        torch.manual_seed(seed)
        device = mm.get_torch_device()
        mm.unload_all_models()
        mm.soft_empty_cache()

        self.sampler_config = {
            'target': f'.sgm.modules.diffusionmodules.sampling.{sampler}',
            'params': {
                'num_steps': steps,
                'restore_cfg': restore_cfg,
                's_churn': EDM_s_churn,
                's_noise': s_noise,
                'discretization_config': {
                    'target': '.sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization'
                },
                'guider_config': {
                    'target': '.sgm.modules.diffusionmodules.guiders.LinearCFG',
                    'params': {
                        'scale': cfg_scale_start,
                        'scale_min': cfg_scale_end
                    }
                }
            }
        }
        if 'Tiled' in sampler:
            self.sampler_config['params']['tile_size'] = sampler_tile_size // 8
            self.sampler_config['params']['tile_stride'] = sampler_tile_stride // 8
        if 'DPMPP' in sampler:
            self.sampler_config['params']['eta'] = DPMPP_eta
            self.sampler_config['params']['restore_cfg'] = -1
        if not hasattr(self, 'sampler') or self.sampler_config != self.current_sampler_config:
            self.sampler = instantiate_from_config(self.sampler_config)
            self.current_sampler_config = self.sampler_config

        print("sampler_config: ", self.sampler_config)

        SUPIR_model.denoiser.to(device)
        SUPIR_model.model.diffusion_model.to(device)
        SUPIR_model.model.control_model.to(device)

        use_linear_control_scale = control_scale_start != control_scale_end

        denoiser = lambda input, sigma, c, control_scale: SUPIR_model.denoiser(SUPIR_model.model, input, sigma, c, control_scale)

        original_size = positive['original_size']
        positive = positive['cond']
        negative = negative['uncond']
        samples = latents["samples"]
        samples = samples.to(device)
        #print("positives: ", len(positive))
        #print("negatives: ", len(negative))

        out = []
        pbar = comfy.utils.ProgressBar(samples.shape[0])
        for i, sample in enumerate(samples):
            try:
                if 'original_size' in latents:
                    print("Using random noise")
                    noised_z = torch.randn_like(sample.unsqueeze(0), device=samples.device)
                else:
                    print("Using latent from input")
                    noised_z = torch.randn_like(sample.unsqueeze(0), device=samples.device)
                    noised_z += sample.unsqueeze(0)
                if len(positive) != len(samples):
                    print("Tiled sampling")
                    _samples = self.sampler(denoiser, noised_z, cond=positive, uc=negative, x_center=sample.unsqueeze(0), control_scale=control_scale_end,
                                            use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start)
                else:
                    #print("positives[i]: ", len(positive[i]))
                    #print("negatives[i]: ", len(negative[i]))
                    _samples = self.sampler(denoiser, noised_z, cond=positive[i], uc=negative[i], x_center=sample.unsqueeze(0), control_scale=control_scale_end,
                                            use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start)
            except torch.cuda.OutOfMemoryError as e:
                mm.free_memory(mm.get_total_memory(mm.get_torch_device()), mm.get_torch_device())
                SUPIR_model = None
                mm.soft_empty_cache()
                print("It's likely that the image or batch size was too large for SUPIR and it has consumed all of the "
                      "memory it had reserved; you may need to restart ComfyUI. Make sure you are using tiled_vae, "
                      "and you can also try fp8 for reduced memory usage if your system supports it.")
                raise e

            out.append(_samples)
            print("Sampled ", i + 1, " of ", samples.shape[0])
            pbar.update(1)

        if not keep_model_loaded:
            SUPIR_model.denoiser.to('cpu')
            SUPIR_model.model.diffusion_model.to('cpu')
            SUPIR_model.model.control_model.to('cpu')
            mm.soft_empty_cache()

        if len(out[0].shape) == 4:
            samples_out_stacked = torch.cat(out, dim=0)
        else:
            samples_out_stacked = torch.stack(out, dim=0)

        if original_size is None:
            samples_out_stacked = samples_out_stacked / 0.13025
        return ({"samples": samples_out_stacked, "original_size": original_size},)
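
# Rough sketch of the scale schedules configured above, assuming simple linear
# interpolation over the step index i in [0, steps - 1] (the exact curves are
# defined by the LinearCFG guider and the sampler implementations):
#   cfg(i)     = cfg_scale_start     + (cfg_scale_end     - cfg_scale_start)     * i / (steps - 1)
#   control(i) = control_scale_start + (control_scale_end - control_scale_start) * i / (steps - 1)
# Setting start == end disables the ramp for that parameter.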
class SUPIR_conditioner:
    # @classmethod
    # def IS_CHANGED(s):
    #     return ""
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "SUPIR_model": ("SUPIRMODEL",),
            "latents": ("LATENT",),
            "positive_prompt": ("STRING", {"multiline": True, "default": "high quality, detailed", }),
            "negative_prompt": ("STRING", {"multiline": True, "default": "bad quality, blurry, messy", }),
            },
            "optional": {
                "captions": ("STRING", {"forceInput": True, "multiline": False, "default": "", }),
            }
        }

    RETURN_TYPES = ("SUPIR_cond_pos", "SUPIR_cond_neg",)
    RETURN_NAMES = ("positive", "negative",)
    FUNCTION = "condition"
    CATEGORY = "SUPIR"
    DESCRIPTION = """
Creates the conditioning for the sampler.
The caption input is optional: a single caption is appended to the positive prompt.
If a list of captions is given for a single input image, the captions must match the number of tiles;
refer to the SUPIR Tiles node.
If a list of captions is given and it matches the incoming image batch, each image uses its corresponding caption.
"""
    def condition(self, SUPIR_model, latents, positive_prompt, negative_prompt, captions=""):
        device = mm.get_torch_device()
        mm.soft_empty_cache()

        if "original_size" in latents:
            original_size = latents["original_size"]
            samples = latents["samples"]
        else:
            original_size = None
            samples = latents["samples"] * 0.13025

        N, C, H, W = samples.shape

        if not isinstance(captions, list):
            captions_list = []
            captions_list.append([captions])
            captions_list = captions_list * N
        else:
            captions_list = captions

        print("captions: ", captions_list)

        SUPIR_model.conditioner.to(device)
        samples = samples.to(device)

        uc = []
        pbar = comfy.utils.ProgressBar(N)
        autocast_condition = (SUPIR_model.model.dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
        with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=SUPIR_model.model.dtype) if autocast_condition else nullcontext():
            if N != len(captions_list):  # Tiled captioning
                print("Tiled captioning")
                c = []
                uc = []
                for i, caption in enumerate(captions_list):
                    cond = {}
                    cond['original_size_as_tuple'] = torch.tensor([[1024, 1024]]).to(device)
                    cond['crop_coords_top_left'] = torch.tensor([[0, 0]]).to(device)
                    cond['target_size_as_tuple'] = torch.tensor([[1024, 1024]]).to(device)
                    cond['aesthetic_score'] = torch.tensor([[9.0]]).to(device)
                    cond['control'] = samples[0].unsqueeze(0)

                    uncond = copy.deepcopy(cond)
                    uncond['txt'] = [negative_prompt]
                    cond['txt'] = [''.join([caption[0], positive_prompt])]
                    if i == 0:
                        _c, uc = SUPIR_model.conditioner.get_unconditional_conditioning(cond, uncond)
                    else:
                        _c, _ = SUPIR_model.conditioner.get_unconditional_conditioning(cond, None)

                    c.append(_c)
                    pbar.update(1)
            else:  # Batch captioning
                print("Batch captioning")
                c = []
                uc = []
                for i, sample in enumerate(samples):
                    cond = {}
                    cond['original_size_as_tuple'] = torch.tensor([[1024, 1024]]).to(device)
                    cond['crop_coords_top_left'] = torch.tensor([[0, 0]]).to(device)
                    cond['target_size_as_tuple'] = torch.tensor([[1024, 1024]]).to(device)
                    cond['aesthetic_score'] = torch.tensor([[9.0]]).to(device)
                    cond['control'] = sample.unsqueeze(0)

                    uncond = copy.deepcopy(cond)
                    uncond['txt'] = [negative_prompt]
                    cond['txt'] = [''.join([captions_list[i][0], positive_prompt])]

                    _c, _uc = SUPIR_model.conditioner.get_unconditional_conditioning(cond, uncond)

                    c.append(_c)
                    uc.append(_uc)
                    pbar.update(1)

        SUPIR_model.conditioner.to('cpu')

        if "original_size" in latents:
            original_size = latents["original_size"]
        else:
            original_size = None

        return ({"cond": c, "original_size": original_size}, {"uncond": uc},)
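
# The cond/uncond dicts built above follow the sgm conditioner convention; a sketch
# of the assumed key meanings (SDXL-style embedders):
#   'txt'                                        -> prompt text for the CLIP encoders
#   'original_size_as_tuple', 'crop_coords_top_left',
#   'target_size_as_tuple'                       -> SDXL size/crop micro-conditioning
#   'aesthetic_score'                            -> refiner-style aesthetic conditioning
#   'control'                                    -> latent the SUPIR control model is conditioned on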
class SUPIR_model_loader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "supir_model": (folder_paths.get_filename_list("checkpoints"),),
            "sdxl_model": (folder_paths.get_filename_list("checkpoints"),),
            "fp8_unet": ("BOOLEAN", {"default": False}),
            "diffusion_dtype": (
                    [
                        'fp16',
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
            },
        }

    RETURN_TYPES = ("SUPIRMODEL", "SUPIRVAE")
    RETURN_NAMES = ("SUPIR_model", "SUPIR_VAE",)
    FUNCTION = "process"
    CATEGORY = "SUPIR"
    DESCRIPTION = """
Old loader; not recommended for use.
Loads the SUPIR model and the selected SDXL model and merges them.
"""
    def process(self, supir_model, sdxl_model, diffusion_dtype, fp8_unet):
        device = mm.get_torch_device()
        mm.unload_all_models()

        SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model)
        SDXL_MODEL_PATH = folder_paths.get_full_path("checkpoints", sdxl_model)

        config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml")
        clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json")
        tokenizer_path = os.path.join(script_directory, "configs/tokenizer")

        custom_config = {
            'sdxl_model': sdxl_model,
            'diffusion_dtype': diffusion_dtype,
            'supir_model': supir_model,
            'fp8_unet': fp8_unet,
        }

        if diffusion_dtype == 'auto':
            try:
                if mm.should_use_fp16():
                    print("Diffusion using fp16")
                    dtype = torch.float16
                    model_dtype = 'fp16'
                elif mm.should_use_bf16():
                    print("Diffusion using bf16")
                    dtype = torch.bfloat16
                    model_dtype = 'bf16'
                else:
                    print("Diffusion using fp32")
                    dtype = torch.float32
                    model_dtype = 'fp32'
            except:
                raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.")
        else:
            print(f"Diffusion using {diffusion_dtype}")
            dtype = convert_dtype(diffusion_dtype)
            model_dtype = diffusion_dtype

        if not hasattr(self, "model") or self.model is None or self.current_config != custom_config:
            self.current_config = custom_config
            self.model = None
            mm.soft_empty_cache()

            config = OmegaConf.load(config_path)

            if mm.XFORMERS_IS_AVAILABLE:
                print("Using XFORMERS")
                config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers"
                config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers"
                config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers"

            config.model.params.diffusion_dtype = model_dtype
            config.model.target = ".SUPIR.models.SUPIR_model_v2.SUPIRModel"

            pbar = comfy.utils.ProgressBar(5)

            self.model = instantiate_from_config(config.model).cpu()
            self.model.model.dtype = dtype
            pbar.update(1)

            try:
                print(f"Attempting to load SDXL model: [{SDXL_MODEL_PATH}]")
                sdxl_state_dict = load_state_dict(SDXL_MODEL_PATH)
                self.model.load_state_dict(sdxl_state_dict, strict=False)
                if fp8_unet:
                    self.model.model.to(torch.float8_e4m3fn)
                else:
                    self.model.model.to(dtype)
                pbar.update(1)
            except:
                raise Exception("Failed to load SDXL model")

            # First clip model from the SDXL checkpoint
            try:
                print("Loading first clip model from SDXL checkpoint")
                replace_prefix = {}
                replace_prefix["conditioner.embedders.0.transformer."] = ""
                sd = comfy.utils.state_dict_prefix_replace(sdxl_state_dict, replace_prefix, filter_keys=False)
                clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path)
                self.model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
                self.model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config)
                self.model.conditioner.embedders[0].transformer.load_state_dict(sd, strict=False)
                self.model.conditioner.embedders[0].eval()
                self.model.conditioner.embedders[0].to(dtype)
                for param in self.model.conditioner.embedders[0].parameters():
                    param.requires_grad = False
                pbar.update(1)
            except:
                raise Exception("Failed to load first clip model from SDXL checkpoint")

            del sdxl_state_dict

            # Second clip model from the SDXL checkpoint
            try:
                print("Loading second clip model from SDXL checkpoint")
                replace_prefix2 = {}
                replace_prefix2["conditioner.embedders.1.model."] = ""
                sd = comfy.utils.state_dict_prefix_replace(sd, replace_prefix2, filter_keys=True)
                clip_g = build_text_model_from_openai_state_dict(sd, device, cast_dtype=dtype)
                self.model.conditioner.embedders[1].model = clip_g
                self.model.conditioner.embedders[1].to(dtype)
                pbar.update(1)
            except:
                raise Exception("Failed to load second clip model from SDXL checkpoint")

            del sd, clip_g

            try:
                print(f'Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]')
                supir_state_dict = load_state_dict(SUPIR_MODEL_PATH)
                self.model.load_state_dict(supir_state_dict, strict=False)
                if fp8_unet:
                    self.model.model.to(torch.float8_e4m3fn)
                else:
                    self.model.model.to(dtype)
                del supir_state_dict
                pbar.update(1)
            except:
                raise Exception("Failed to load SUPIR model")

        mm.soft_empty_cache()
        return (self.model, self.model.first_stage_model,)
class SUPIR_model_loader_v2:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "clip": ("CLIP",),
            "vae": ("VAE",),
            "supir_model": (folder_paths.get_filename_list("checkpoints"),),
            "fp8_unet": ("BOOLEAN", {"default": False}),
            "diffusion_dtype": (
                    [
                        'fp16',
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
            },
            "optional": {
                "high_vram": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("SUPIRMODEL", "SUPIRVAE")
    RETURN_NAMES = ("SUPIR_model", "SUPIR_VAE",)
    FUNCTION = "process"
    CATEGORY = "SUPIR"
    DESCRIPTION = """
Loads the SUPIR model and merges it with the SDXL model.

Diffusion dtype should be kept on auto unless you have issues loading the model.
fp8_unet casts the unet weights to torch.float8_e4m3fn, which saves a lot of VRAM but has a slight quality impact.
high_vram: uses Accelerate to load the weights straight to the GPU for slightly faster model loading.
"""
    def process(self, supir_model, diffusion_dtype, fp8_unet, model, clip, vae, high_vram=False):
        if high_vram:
            device = mm.get_torch_device()
        else:
            device = mm.unet_offload_device()
        print("Loading weights to: ", device)
        mm.unload_all_models()

        SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model)

        config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml")
        clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json")
        tokenizer_path = os.path.join(script_directory, "configs/tokenizer")

        custom_config = {
            'diffusion_dtype': diffusion_dtype,
            'supir_model': supir_model,
            'fp8_unet': fp8_unet,
            'model': model,
            "clip": clip,
            "vae": vae
        }

        if diffusion_dtype == 'auto':
            try:
                if mm.should_use_fp16():
                    print("Diffusion using fp16")
                    dtype = torch.float16
                elif mm.should_use_bf16():
                    print("Diffusion using bf16")
                    dtype = torch.bfloat16
                else:
                    print("Diffusion using fp32")
                    dtype = torch.float32
            except:
                raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.")
        else:
            print(f"Diffusion using {diffusion_dtype}")
            dtype = convert_dtype(diffusion_dtype)

        if not hasattr(self, "model") or self.model is None or self.current_config != custom_config:
            self.current_config = custom_config
            self.model = None
            mm.soft_empty_cache()

            config = OmegaConf.load(config_path)

            if mm.XFORMERS_IS_AVAILABLE:
                print("Using XFORMERS")
                config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers"
                config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers"
                config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers"

            config.model.target = ".SUPIR.models.SUPIR_model_v2.SUPIRModel"

            pbar = comfy.utils.ProgressBar(5)

            #with (init_empty_weights() if is_accelerate_available else nullcontext()):
            self.model = instantiate_from_config(config.model).cpu()
            self.model.model.dtype = dtype
            pbar.update(1)

            try:
                print("Attempting to load SDXL model from node inputs")
                mm.load_model_gpu(model)
                sdxl_state_dict = model.model.state_dict_for_saving(None, vae.get_sd(), None)
                if is_accelerate_available:
                    for key in sdxl_state_dict:
                        set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=sdxl_state_dict[key])
                else:
                    self.model.load_state_dict(sdxl_state_dict, strict=False)
                if fp8_unet:
                    self.model.model.to(torch.float8_e4m3fn)
                else:
                    self.model.model.to(dtype)
                del sdxl_state_dict
                pbar.update(1)
            except:
                raise Exception("Failed to load SDXL model")

            gc.collect()
            mm.soft_empty_cache()

            # First clip model from the SDXL checkpoint
            try:
                print("Loading first clip model from SDXL checkpoint")
                clip_sd = None
                clip_model = clip.load_model()
                mm.load_model_gpu(clip_model)
                clip_sd = clip.get_sd()
                clip_sd = model.model.model_config.process_clip_state_dict_for_saving(clip_sd)
                replace_prefix = {}
                replace_prefix["conditioner.embedders.0.transformer."] = ""
                clip_l_sd = comfy.utils.state_dict_prefix_replace(clip_sd, replace_prefix, filter_keys=True)
                clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path)
                self.model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
                with (init_empty_weights() if is_accelerate_available else nullcontext()):
                    self.model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config)
                if is_accelerate_available:
                    for key in clip_l_sd:
                        set_module_tensor_to_device(self.model.conditioner.embedders[0].transformer, key, device=device, dtype=dtype, value=clip_l_sd[key])
                else:
                    self.model.conditioner.embedders[0].transformer.load_state_dict(clip_l_sd, strict=False)
                self.model.conditioner.embedders[0].eval()
                for param in self.model.conditioner.embedders[0].parameters():
                    param.requires_grad = False
                self.model.conditioner.embedders[0].to(dtype)
                del clip_l_sd
                pbar.update(1)
            except:
                raise Exception("Failed to load first clip model from SDXL checkpoint")

            gc.collect()
            mm.soft_empty_cache()

            # Second clip model from the SDXL checkpoint
            try:
                print("Loading second clip model from SDXL checkpoint")
                replace_prefix2 = {}
                replace_prefix2["conditioner.embedders.1.model."] = ""
                clip_g_sd = comfy.utils.state_dict_prefix_replace(clip_sd, replace_prefix2, filter_keys=True)
                clip_g = build_text_model_from_openai_state_dict(clip_g_sd, device, cast_dtype=dtype)
                self.model.conditioner.embedders[1].model = clip_g
                self.model.conditioner.embedders[1].model.to(dtype)
                del clip_g_sd
                pbar.update(1)
            except:
                raise Exception("Failed to load second clip model from SDXL checkpoint")

            try:
                print(f'Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]')
                supir_state_dict = load_state_dict(SUPIR_MODEL_PATH)
                # Use Accelerate's per-tensor loading when possible; the Q model only loads
                # correctly through load_state_dict, for reasons that are not understood.
                if "Q" not in supir_model and is_accelerate_available:
                    for key in supir_state_dict:
                        set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=supir_state_dict[key])
                else:
                    self.model.load_state_dict(supir_state_dict, strict=False)
                if fp8_unet:
                    self.model.model.to(torch.float8_e4m3fn)
                else:
                    self.model.model.to(dtype)
                del supir_state_dict
                pbar.update(1)
            except:
                raise Exception("Failed to load SUPIR model")

        mm.soft_empty_cache()
        return (self.model, self.model.first_stage_model,)
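
# Note on the merge order used by the loaders above: the SDXL UNet/VAE weights are
# loaded into the freshly instantiated SUPIR model first, the two CLIP text encoders
# are rebuilt from the supplied CLIP state dicts, and finally the SUPIR checkpoint is
# loaded on top so its control-model and modified UNet weights take precedence.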
class SUPIR_model_loader_v2_clip:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "clip_l": ("CLIP",),
            "clip_g": ("CLIP",),
            "vae": ("VAE",),
            "supir_model": (folder_paths.get_filename_list("checkpoints"),),
            "fp8_unet": ("BOOLEAN", {"default": False}),
            "diffusion_dtype": (
                    [
                        'fp16',
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
            },
            "optional": {
                "high_vram": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("SUPIRMODEL", "SUPIRVAE")
    RETURN_NAMES = ("SUPIR_model", "SUPIR_VAE",)
    FUNCTION = "process"
    CATEGORY = "SUPIR"
    DESCRIPTION = """
Loads the SUPIR model and merges it with the SDXL model, using separate clip_l and clip_g inputs.

Diffusion dtype should be kept on auto unless you have issues loading the model.
fp8_unet casts the unet weights to torch.float8_e4m3fn, which saves a lot of VRAM but has a slight quality impact.
high_vram: uses Accelerate to load the weights straight to the GPU for slightly faster model loading.
"""
    def process(self, supir_model, diffusion_dtype, fp8_unet, model, clip_l, clip_g, vae, high_vram=False):
        if high_vram:
            device = mm.get_torch_device()
        else:
            device = mm.unet_offload_device()
        print("Loading weights to: ", device)
        mm.unload_all_models()

        SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model)

        config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml")
        clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json")
        tokenizer_path = os.path.join(script_directory, "configs/tokenizer")

        custom_config = {
            'diffusion_dtype': diffusion_dtype,
            'supir_model': supir_model,
            'fp8_unet': fp8_unet,
            'model': model,
            "clip": clip_l,
            "clip_g": clip_g,
            "vae": vae
        }

        if diffusion_dtype == 'auto':
            try:
                if mm.should_use_fp16():
                    print("Diffusion using fp16")
                    dtype = torch.float16
                elif mm.should_use_bf16():
                    print("Diffusion using bf16")
                    dtype = torch.bfloat16
                else:
                    print("Diffusion using fp32")
                    dtype = torch.float32
            except:
                raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.")
        else:
            print(f"Diffusion using {diffusion_dtype}")
            dtype = convert_dtype(diffusion_dtype)

        if not hasattr(self, "model") or self.model is None or self.current_config != custom_config:
            self.current_config = custom_config
            self.model = None
            mm.soft_empty_cache()

            config = OmegaConf.load(config_path)

            if mm.XFORMERS_IS_AVAILABLE:
                print("Using XFORMERS")
                config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers"
                config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers"
                config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers"

            config.model.target = ".SUPIR.models.SUPIR_model_v2.SUPIRModel"

            pbar = comfy.utils.ProgressBar(5)

            #with (init_empty_weights() if is_accelerate_available else nullcontext()):
            self.model = instantiate_from_config(config.model).cpu()
            self.model.model.dtype = dtype
            pbar.update(1)

            try:
                print("Attempting to load SDXL model from node inputs")
                mm.load_model_gpu(model)
                sdxl_state_dict = model.model.state_dict_for_saving(None, vae.get_sd(), None)
                if is_accelerate_available:
                    for key in sdxl_state_dict:
                        set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=sdxl_state_dict[key])
                else:
                    self.model.load_state_dict(sdxl_state_dict, strict=False)
                if fp8_unet:
                    self.model.model.to(torch.float8_e4m3fn)
                else:
                    self.model.model.to(dtype)
                del sdxl_state_dict
                pbar.update(1)
            except:
                raise Exception("Failed to load SDXL model")

            gc.collect()
            mm.soft_empty_cache()

            # First clip model from the SDXL checkpoint
            try:
                print("Loading first clip model from SDXL checkpoint")
                clip_l_sd = None
                clip_l_model = clip_l.load_model()
                mm.load_model_gpu(clip_l_model)
                clip_l_sd = clip_l.get_sd()
                clip_l_sd = model.model.model_config.process_clip_state_dict_for_saving(clip_l_sd)
                replace_prefix = {}
                replace_prefix["conditioner.embedders.0.transformer."] = ""
                clip_l_sd = comfy.utils.state_dict_prefix_replace(clip_l_sd, replace_prefix, filter_keys=True)
                clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path)
                self.model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
                with (init_empty_weights() if is_accelerate_available else nullcontext()):
                    self.model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config)
                if is_accelerate_available:
                    for key in clip_l_sd:
                        set_module_tensor_to_device(self.model.conditioner.embedders[0].transformer, key, device=device, dtype=dtype, value=clip_l_sd[key])
                else:
                    self.model.conditioner.embedders[0].transformer.load_state_dict(clip_l_sd, strict=False)
                self.model.conditioner.embedders[0].eval()
                for param in self.model.conditioner.embedders[0].parameters():
                    param.requires_grad = False
                self.model.conditioner.embedders[0].to(dtype)
                del clip_l_sd
                pbar.update(1)
            except:
                raise Exception("Failed to load first clip model from SDXL checkpoint")

            gc.collect()
            mm.soft_empty_cache()

            # Second clip model from the SDXL checkpoint
            try:
                print("Loading second clip model from SDXL checkpoint")
                clip_g_sd = None
                clip_g_model = clip_g.load_model()
                mm.load_model_gpu(clip_g_model)
                clip_g_sd = clip_g.get_sd()
                clip_g_sd = model.model.model_config.process_clip_state_dict_for_saving(clip_g_sd)
                replace_prefix2 = {}
                replace_prefix2["conditioner.embedders.1.model."] = ""
                clip_g_sd = comfy.utils.state_dict_prefix_replace(clip_g_sd, replace_prefix2, filter_keys=True)
                clip_g = build_text_model_from_openai_state_dict(clip_g_sd, device, cast_dtype=dtype)
                self.model.conditioner.embedders[1].model = clip_g
                self.model.conditioner.embedders[1].model.to(dtype)
                del clip_g_sd
                pbar.update(1)
            except:
                raise Exception("Failed to load second clip model from SDXL checkpoint")

            try:
                print(f'Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]')
                supir_state_dict = load_state_dict(SUPIR_MODEL_PATH)
                # Use Accelerate's per-tensor loading when possible; the Q model only loads
                # correctly through load_state_dict, for reasons that are not understood.
                if "Q" not in supir_model and is_accelerate_available:
                    for key in supir_state_dict:
                        set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=supir_state_dict[key])
                else:
                    self.model.load_state_dict(supir_state_dict, strict=False)
                if fp8_unet:
                    self.model.model.to(torch.float8_e4m3fn)
                else:
                    self.model.model.to(dtype)
                del supir_state_dict
                pbar.update(1)
            except:
                raise Exception("Failed to load SUPIR model")

        mm.soft_empty_cache()
        return (self.model, self.model.first_stage_model,)
class SUPIR_tiles:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE",),
            "tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}),
            "tile_stride": ("INT", {"default": 256, "min": 64, "max": 8192, "step": 64}),
            }
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT",)
    RETURN_NAMES = ("image_tiles", "tile_size", "tile_stride",)
    FUNCTION = "tile"
    CATEGORY = "SUPIR"
    DESCRIPTION = """
Tiles the image with the same function the tiled samplers use.
Useful for previewing the tiling and for generating captions per tile (WIP feature).
"""
    def tile(self, image, tile_size, tile_stride):
        def _sliding_windows(h: int, w: int, tile_size: int, tile_stride: int):
            hi_list = list(range(0, h - tile_size + 1, tile_stride))
            if (h - tile_size) % tile_stride != 0:
                hi_list.append(h - tile_size)

            wi_list = list(range(0, w - tile_size + 1, tile_stride))
            if (w - tile_size) % tile_stride != 0:
                wi_list.append(w - tile_size)

            coords = []
            for hi in hi_list:
                for wi in wi_list:
                    coords.append((hi, hi + tile_size, wi, wi + tile_size))
            return coords

        image = image.permute(0, 3, 1, 2)
        _, _, h, w = image.shape

        tiles_iterator = _sliding_windows(h, w, tile_size, tile_stride)

        tiles = []
        for hi, hi_end, wi, wi_end in tiles_iterator:
            tile = image[:, :, hi:hi_end, wi:wi_end]
            tiles.append(tile)
        out = torch.cat(tiles, dim=0).to(torch.float32).permute(0, 2, 3, 1)
        print(out.shape)
        print("len(tiles): ", len(tiles))

        return (out, tile_size, tile_stride,)
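
# Worked example of _sliding_windows above (illustration only): for a 1024x1024 image
# with tile_size=512 and tile_stride=256, hi_list = wi_list = [0, 256, 512], so the
# image is cut into 3 x 3 = 9 overlapping 512x512 tiles and the node outputs a batch of 9.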