import torch

import comfy.sd
import comfy.utils
from comfy import model_management
from comfy import diffusers_convert


class EXVAE(comfy.sd.VAE):
    def __init__(self, model_path, model_conf, dtype=torch.float32):
        # Latent geometry comes from the model config instead of ComfyUI's built-in defaults
        self.latent_dim = model_conf["embed_dim"]
        self.latent_scale = model_conf["embed_scale"]
        self.device = model_management.vae_device()
        self.offload_device = model_management.vae_offload_device()
        self.vae_dtype = dtype

        sd = comfy.utils.load_torch_file(model_path)
        model = None
        if model_conf["type"] == "AutoencoderKL":
            from .models.kl import AutoencoderKL
            model = AutoencoderKL(config=model_conf)
            # Diffusers-format checkpoints are detected by key name and converted
            if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys():
                sd = diffusers_convert.convert_vae_state_dict(sd)
        elif model_conf["type"] == "AutoencoderKL-VideoDecoder":
            from .models.temporal_ae import AutoencoderKL
            model = AutoencoderKL(config=model_conf)
        elif model_conf["type"] == "VQModel":
            from .models.vq import VQModel
            model = VQModel(config=model_conf)
        elif model_conf["type"] == "ConsistencyDecoder":
            from .models.consistencydecoder import ConsistencyDecoder
            model = ConsistencyDecoder()
            sd = {f"model.{k}": v for k, v in sd.items()}
        elif model_conf["type"] == "MoVQ3":
            from .models.movq3 import MoVQ
            model = MoVQ(model_conf)
        else:
            raise NotImplementedError(f"Unknown VAE type '{model_conf['type']}'")

        self.first_stage_model = model.eval()
        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
        if len(m) > 0: print("Missing VAE keys", m)
        if len(u) > 0: print("Leftover VAE keys", u)
        self.first_stage_model.to(self.vae_dtype).to(self.offload_device)

    ### The encode/decode overrides below are needed because the source repo hardcodes 4 VAE channels and a scale factor of 8
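    # Inside these helpers, pixel tensors are [B, 3, H, W] in the 0..1 range and latents are
    # [B, self.latent_dim, H // self.latent_scale, W // self.latent_scale], so tile sizes and
    # upscale_amount are tied to the configured latent_scale rather than the hardcoded 8.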

    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
        steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)

        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
        # Average three tile layouts to reduce visible seams, then map the shifted decoder output back to [0, 1]
        output = torch.clamp((
            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.latent_scale, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.latent_scale, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.latent_scale, pbar=pbar))
            / 3.0) / 2.0, min=0.0, max=1.0)
        return output

    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)

        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
        samples /= 3.0
        return samples

    def decode(self, samples_in):
        self.first_stage_model = self.first_stage_model.to(self.device)
        try:
            # Rough per-latent memory estimate, used to decide how many latents to decode per forward pass
            memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
            model_management.free_memory(memory_used, self.device)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.latent_scale), round(samples_in.shape[3] * self.latent_scale)), device="cpu")
            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
        except model_management.OOM_EXCEPTION as e:
            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            pixel_samples = self.decode_tiled_(samples_in)

        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        pixel_samples = pixel_samples.cpu().movedim(1, -1)
        return pixel_samples
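    # Worked example of the estimate above (assumed numbers, not from the original source):
    # for a 64x64 latent, memory_used = 2562 * 64 * 64 * 64 * 1.7 ≈ 1.14e9 bytes (~1.1 GB),
    # so with 8 GB free, batch_number = int(8e9 / 1.14e9) = 7 latents per forward pass.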

    def encode(self, pixel_samples):
        self.first_stage_model = self.first_stage_model.to(self.device)
        pixel_samples = pixel_samples.movedim(-1, 1)
        try:
            memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7  # NOTE: this constant, along with the one in decode above, is estimated from the VAE's memory usage and could change.
            model_management.free_memory(memory_used, self.device)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

            samples = torch.empty((pixel_samples.shape[0], self.latent_dim, round(pixel_samples.shape[2] // self.latent_scale), round(pixel_samples.shape[3] // self.latent_scale)), device="cpu")
            for x in range(0, pixel_samples.shape[0], batch_number):
                pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
        except model_management.OOM_EXCEPTION as e:
            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
            samples = self.encode_tiled_(pixel_samples)

        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        return samples
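

# Usage sketch (illustrative only): this file does not show how EXVAE is constructed, so the
# path and config values below are hypothetical placeholders. The only keys EXVAE itself reads
# are "type", "embed_dim" and "embed_scale"; the selected model class (here .models.kl.AutoencoderKL)
# additionally needs its own architecture config keys, which are not shown.
if __name__ == "__main__":
    example_conf = {
        "type": "AutoencoderKL",  # one of the types handled in __init__ above
        "embed_dim": 4,           # number of latent channels
        "embed_scale": 8,         # pixel-to-latent downscale factor
        # ...plus whatever architecture keys .models.kl.AutoencoderKL requires
    }
    vae = EXVAE("models/vae/example.safetensors", example_conf, dtype=torch.float16)
    images = torch.rand(1, 512, 512, 3)  # encode() expects [B, H, W, C] pixels in 0..1
    latents = vae.encode(images)         # -> [B, embed_dim, H // embed_scale, W // embed_scale]
    decoded = vae.decode(latents)        # -> [B, H, W, C] pixels in 0..1, on CPU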