import torch
import os
import random
import re
import gc
import json
import psutil

import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
import folder_paths

script_directory = os.path.dirname(os.path.abspath(__file__))
folder_paths.add_model_folder_path("llms", os.path.join(folder_paths.models_dir, "llms", "checkpoints"))

from .kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from .kolors.models.modeling_chatglm import ChatGLMModel, ChatGLMConfig
from .kolors.models.tokenization_chatglm import ChatGLMTokenizer

from diffusers import UNet2DConditionModel
from diffusers import (DPMSolverMultistepScheduler,
                       EulerDiscreteScheduler,
                       EulerAncestralDiscreteScheduler,
                       DEISMultistepScheduler,
                       UniPCMultistepScheduler
                       )
from contextlib import nullcontext

try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except ImportError:
    # Without this fallback, `is_accelerate_available` would be undefined when
    # accelerate is not installed and later checks would raise a NameError.
    is_accelerate_available = False

class DownloadAndLoadKolorsModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": (
                [
                    'Kwai-Kolors/Kolors',
                ],
            ),
            "precision": (['fp16'],
                {
                    "default": 'fp16'
                }),
            },
        }

    RETURN_TYPES = ("KOLORSMODEL",)
    RETURN_NAMES = ("kolors_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "KwaiKolorsWrapper"

    def loadmodel(self, model, precision):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        pbar = ProgressBar(4)

        model_name = model.rsplit('/', 1)[-1]
        model_path = os.path.join(folder_paths.models_dir, "diffusers", model_name)

        if not os.path.exists(model_path):
            print(f"Downloading Kolors model to: {model_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=model,
                              allow_patterns=['*fp16.safetensors*', '*.json'],
                              ignore_patterns=['vae/*', 'text_encoder/*', 'tokenizer/*'],
                              local_dir=model_path,
                              local_dir_use_symlinks=False)
        pbar.update(1)

        ram_rss_start = psutil.Process().memory_info().rss
        scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder='scheduler')
        print('Load UNET...')
        unet = UNet2DConditionModel.from_pretrained(model_path, subfolder='unet', variant="fp16", revision=None, low_cpu_mem_usage=True).to(dtype).eval()
        ram_rss_end = psutil.Process().memory_info().rss
        print(f'Kolors-unet: RAM allocated = {(ram_rss_end - ram_rss_start) / (1024 * 1024 * 1024):.3f}GB')

        pipeline = StableDiffusionXLPipeline(
            unet=unet,
            scheduler=scheduler,
        )

        kolors_model = {
            'pipeline': pipeline,
            'dtype': dtype,
        }

        return (kolors_model,)

class LoadChatGLM3:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "chatglm3_checkpoint": (folder_paths.get_filename_list("llms"),),
            },
        }

    RETURN_TYPES = ("CHATGLM3MODEL",)
    RETURN_NAMES = ("chatglm3_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "KwaiKolorsWrapper"

    def loadmodel(self, chatglm3_checkpoint):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        print(f'chatglm3: device={device}, offload_device={offload_device}')
        pbar = ProgressBar(2)

        chatglm3_path = folder_paths.get_full_path("llms", chatglm3_checkpoint)

        print("Load TEXT_ENCODER...")
        text_encoder_config = os.path.join(script_directory, 'configs', 'text_encoder_config.json')
        with open(text_encoder_config, 'r') as file:
            config = json.load(file)

        text_encoder_config = ChatGLMConfig(**config)

        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            text_encoder = ChatGLMModel(text_encoder_config)
            if '4bit' in chatglm3_checkpoint:
                text_encoder.quantize(4)
            elif '8bit' in chatglm3_checkpoint:
                text_encoder.quantize(8)

        text_encoder_sd = load_torch_file(chatglm3_path)

        if is_accelerate_available:
            for key in text_encoder_sd:
                set_module_tensor_to_device(text_encoder, key, device=offload_device, value=text_encoder_sd[key])
        else:
            text_encoder.load_state_dict(text_encoder_sd)

        tokenizer_path = os.path.join(script_directory, 'configs', "tokenizer")
        tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
        pbar.update(1)

        chatglm3_model = {
            'text_encoder': text_encoder,
            'tokenizer': tokenizer
        }

        return (chatglm3_model,)

class DownloadAndLoadChatGLM3:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "precision": (['fp16', 'quant4', 'quant8'],
                {
                    "default": 'fp16'
                }),
            },
        }

    RETURN_TYPES = ("CHATGLM3MODEL",)
    RETURN_NAMES = ("chatglm3_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "KwaiKolorsWrapper"

    def loadmodel(self, precision):
        pbar = ProgressBar(2)

        model = "Kwai-Kolors/Kolors"
        model_name = model.rsplit('/', 1)[-1]
        model_path = os.path.join(folder_paths.models_dir, "diffusers", model_name)
        text_encoder_path = os.path.join(model_path, "text_encoder")

        if not os.path.exists(text_encoder_path):
            print(f"Downloading ChatGLM3 to: {text_encoder_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=model,
                              allow_patterns=['text_encoder/*'],
                              ignore_patterns=['*.py', '*.pyc'],
                              local_dir=model_path,
                              local_dir_use_symlinks=False)
        pbar.update(1)

        ram_rss_start = psutil.Process().memory_info().rss
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        print(f"Load TEXT_ENCODER..., {precision}, {offload_device}")
        text_encoder = ChatGLMModel.from_pretrained(
            text_encoder_path,
            torch_dtype=torch.float16
        ).to(offload_device)
        if precision == 'quant8':
            text_encoder.quantize(8)
        elif precision == 'quant4':
            text_encoder.quantize(4)
        #device_text = next(text_encoder.parameters()).device
        #print(f'chatglm3: device={device_text}, torch_device={device}, offload_device={offload_device}')

        tokenizer = ChatGLMTokenizer.from_pretrained(text_encoder_path)
        pbar.update(1)

        chatglm3_model = {
            'text_encoder': text_encoder,
            'tokenizer': tokenizer
        }
        ram_rss_end = psutil.Process().memory_info().rss
        print(f'chatglm3: RAM allocated = {(ram_rss_end - ram_rss_start) / (1024 * 1024 * 1024):.3f}GB')

        return (chatglm3_model,)

class KolorsTextEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "chatglm3_model": ("CHATGLM3MODEL", ),
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "negative_prompt": ("STRING", {"multiline": True, "default": ""}),
                "num_images_per_prompt": ("INT", {"default": 1, "min": 1, "max": 128, "step": 1}),
            },
        }

    RETURN_TYPES = ("KOLORS_EMBEDS",)
    RETURN_NAMES = ("kolors_embeds",)
    FUNCTION = "encode"
    CATEGORY = "KwaiKolorsWrapper"

    def encode(self, chatglm3_model, prompt, negative_prompt, num_images_per_prompt):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        mm.unload_all_models()
        mm.soft_empty_cache()

        # Function to randomly select an option from the brackets
        def choose_random_option(match):
            options = match.group(1).split('|')
            return random.choice(options)

        # Randomly choose between options in brackets for prompt and negative_prompt
        prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)
        negative_prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, negative_prompt)

        if "|" in prompt:
            prompt = prompt.split("|")
            negative_prompt = [negative_prompt] * len(prompt)  # Replicate negative_prompt to match length of prompt list
            print(prompt)
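        # Example (hypothetical input): "a {red|blue} car | a {cat|dog}" first resolves
        # the bracketed choices (e.g. to "a blue car | a dog") and is then split on "|"
        # into a batch of two prompts, each paired with the same negative prompt.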
        do_classifier_free_guidance = True

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)

        # Define tokenizers and text encoders
        tokenizer = chatglm3_model['tokenizer']
        text_encoder = chatglm3_model['text_encoder']

        text_encoder.to(device)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=256,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        output = text_encoder(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask'],
            position_ids=text_inputs['position_ids'],
            output_hidden_states=True)

        prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()  # [batch_size, 256, 4096]
        text_proj = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]

        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
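        # Shape bookkeeping (illustrative numbers, not from the source): with bs_embed=1,
        # seq_len=256 and num_images_per_prompt=2, repeat gives (1, 512, 4096) and the
        # view reshapes it to (2, 256, 4096), i.e. one embedding copy per generated image.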
        if do_classifier_free_guidance:
            uncond_tokens = []
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            output = text_encoder(
                input_ids=uncond_input['input_ids'],
                attention_mask=uncond_input['attention_mask'],
                position_ids=uncond_input['position_ids'],
                output_hidden_states=True)
            negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()  # [batch_size, 256, 4096]
            negative_text_proj = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

        bs_embed = text_proj.shape[0]
        text_proj = text_proj.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        negative_text_proj = negative_text_proj.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

        text_encoder.to(offload_device)
        mm.soft_empty_cache()
        gc.collect()

        kolors_embeds = {
            'prompt_embeds': prompt_embeds,
            'negative_prompt_embeds': negative_prompt_embeds,
            'pooled_prompt_embeds': text_proj,
            'negative_pooled_prompt_embeds': negative_text_proj
        }

        return (kolors_embeds,)

class KolorsSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "kolors_model": ("KOLORSMODEL", ),
                "kolors_embeds": ("KOLORS_EMBEDS", ),
                "width": ("INT", {"default": 1024, "min": 64, "max": 2048, "step": 64}),
                "height": ("INT", {"default": 1024, "min": 64, "max": 2048, "step": 64}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}),
                "cfg": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 20.0, "step": 0.01}),
                "scheduler": (
                    [
                        "EulerDiscreteScheduler",
                        "EulerAncestralDiscreteScheduler",
                        "DPMSolverMultistepScheduler",
                        "DPMSolverMultistepScheduler_SDE_karras",
                        "UniPCMultistepScheduler",
                        "DEISMultistepScheduler",
                    ],
                    {
                        "default": 'EulerDiscreteScheduler'
                    }
                ),
            },
            "optional": {
                "latent": ("LATENT", ),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latent",)
    FUNCTION = "process"
    CATEGORY = "KwaiKolorsWrapper"

    def process(self, kolors_model, kolors_embeds, width, height, seed, steps, cfg, scheduler, latent=None, denoise_strength=1.0):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        vae_scaling_factor = 0.13025  # SDXL scaling factor

        mm.soft_empty_cache()
        gc.collect()

        pipeline = kolors_model['pipeline']

        scheduler_config = {
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "beta_end": 0.014,
            "dynamic_thresholding_ratio": 0.995,
            "num_train_timesteps": 1100,
            "prediction_type": "epsilon",
            "rescale_betas_zero_snr": False,
            "steps_offset": 1,
            "timestep_spacing": "leading",
            "trained_betas": None,
        }
        if scheduler == "DPMSolverMultistepScheduler":
            noise_scheduler = DPMSolverMultistepScheduler(**scheduler_config)
        elif scheduler == "DPMSolverMultistepScheduler_SDE_karras":
            scheduler_config.update({"algorithm_type": "sde-dpmsolver++"})
            scheduler_config.update({"use_karras_sigmas": True})
            noise_scheduler = DPMSolverMultistepScheduler(**scheduler_config)
        elif scheduler == "DEISMultistepScheduler":
            scheduler_config.pop("rescale_betas_zero_snr")
            noise_scheduler = DEISMultistepScheduler(**scheduler_config)
        elif scheduler == "EulerDiscreteScheduler":
            scheduler_config.update({"interpolation_type": "linear"})
            scheduler_config.pop("dynamic_thresholding_ratio")
            noise_scheduler = EulerDiscreteScheduler(**scheduler_config)
        elif scheduler == "EulerAncestralDiscreteScheduler":
            scheduler_config.pop("dynamic_thresholding_ratio")
            noise_scheduler = EulerAncestralDiscreteScheduler(**scheduler_config)
        elif scheduler == "UniPCMultistepScheduler":
            scheduler_config.pop("rescale_betas_zero_snr")
            noise_scheduler = UniPCMultistepScheduler(**scheduler_config)
        pipeline.scheduler = noise_scheduler

        generator = torch.Generator(device).manual_seed(seed)

        pipeline.unet.to(device)

        if latent is not None:
            samples_in = latent['samples']
            samples_in = samples_in * vae_scaling_factor
            samples_in = samples_in.to(pipeline.unet.dtype).to(device)

        latent_out = pipeline(
            prompt=None,
            latents=samples_in if latent is not None else None,
            prompt_embeds=kolors_embeds['prompt_embeds'],
            pooled_prompt_embeds=kolors_embeds['pooled_prompt_embeds'],
            negative_prompt_embeds=kolors_embeds['negative_prompt_embeds'],
            negative_pooled_prompt_embeds=kolors_embeds['negative_pooled_prompt_embeds'],
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=cfg,
            num_images_per_prompt=1,
            generator=generator,
            strength=denoise_strength,
        ).images

        pipeline.unet.to(offload_device)

        latent_out = latent_out / vae_scaling_factor

        return ({'samples': latent_out},)

NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadKolorsModel": DownloadAndLoadKolorsModel,
    "DownloadAndLoadChatGLM3": DownloadAndLoadChatGLM3,
    "KolorsSampler": KolorsSampler,
    "KolorsTextEncode": KolorsTextEncode,
    "LoadChatGLM3": LoadChatGLM3
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadKolorsModel": "(Down)load Kolors Model",
    "DownloadAndLoadChatGLM3": "(Down)load ChatGLM3 Model",
    "KolorsSampler": "Kolors Sampler",
    "KolorsTextEncode": "Kolors Text Encode",
    "LoadChatGLM3": "Load ChatGLM3 Model"
}
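# Typical wiring in a ComfyUI graph (a sketch based on the node types above, not a fixed layout):
#   (Down)load ChatGLM3 Model -> Kolors Text Encode  (chatglm3_model)
#   (Down)load Kolors Model   -> Kolors Sampler      (kolors_model)
#   Kolors Text Encode        -> Kolors Sampler      (kolors_embeds)
#   Kolors Sampler            -> a VAE Decode node   (latent; an SDXL VAE is assumed,
#                                matching the 0.13025 scaling factor used in KolorsSampler)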