Spaces:
Configuration error
Configuration error
| import os | |
| import json | |
| import torch | |
| import folder_paths | |
| from comfy import utils | |
| from .conf import pixart_conf, pixart_res | |
| from .lora import load_pixart_lora | |
| from .loader import load_pixart | |
class PixArtCheckpointLoader:
    """Load a PixArt checkpoint with an explicitly selected model config.

    The config name chosen in the UI is looked up in ``pixart_conf`` and
    passed to ``load_pixart`` together with the full checkpoint path.
    """

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI's node convention declares INPUT_TYPES as a classmethod so it
        # can be queried on the class itself; without the decorator a
        # class-level call is missing its first argument.
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                "model": (list(pixart_conf.keys()),),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_checkpoint"
    CATEGORY = "ExtraModels/PixArt"
    TITLE = "PixArt Checkpoint Loader"

    def load_checkpoint(self, ckpt_name, model):
        """Load ``ckpt_name`` from the "checkpoints" folder using config ``model``.

        Args:
            ckpt_name: checkpoint file name as listed by ``folder_paths``.
            model: key into ``pixart_conf`` selecting the architecture config.

        Returns:
            A 1-tuple holding the object returned by ``load_pixart``.
        """
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        model_conf = pixart_conf[model]
        # Separate name so the ``model`` argument (a config key) is not
        # rebound to a completely different kind of object.
        pixart_model = load_pixart(
            model_path=ckpt_path,
            model_conf=model_conf,
        )
        return (pixart_model,)
class PixArtCheckpointLoaderSimple(PixArtCheckpointLoader):
    """Checkpoint loader variant that only asks for the checkpoint file.

    No ``model_conf`` is passed to ``load_pixart``, so choosing the
    architecture is left to the loader (presumably inferred from the
    checkpoint contents — NOTE(review): confirm in loader.py).
    All other node metadata is inherited from ``PixArtCheckpointLoader``.
    """

    @classmethod
    def INPUT_TYPES(s):
        # Classmethod per ComfyUI node convention (queried on the class).
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
            }
        }

    TITLE = "PixArt Checkpoint Loader (auto)"

    def load_checkpoint(self, ckpt_name):
        """Load ``ckpt_name`` from the "checkpoints" folder.

        Returns:
            A 1-tuple holding the object returned by ``load_pixart``.
        """
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        model = load_pixart(model_path=ckpt_path)
        return (model,)
class PixArtResolutionSelect:
    """Return the preset (width, height) for a PixArt model and aspect ratio."""

    @classmethod
    def INPUT_TYPES(s):
        # Classmethod per ComfyUI node convention (queried on the class).
        return {
            "required": {
                "model": (list(pixart_res.keys()),),
                # The ratio list is taken from PixArtMS_XL_2 only; this assumes
                # every entry in pixart_res shares the same ratio keys.
                "ratio": (list(pixart_res["PixArtMS_XL_2"].keys()), {"default": "1.00"}),
            }
        }

    RETURN_TYPES = ("INT", "INT")
    RETURN_NAMES = ("width", "height")
    FUNCTION = "get_res"
    CATEGORY = "ExtraModels/PixArt"
    TITLE = "PixArt Resolution Select"

    def get_res(self, model, ratio):
        """Look up ``pixart_res[model][ratio]`` and return it as (width, height).

        Raises:
            KeyError: if ``model`` or ``ratio`` is not present in ``pixart_res``.
        """
        width, height = pixart_res[model][ratio]
        return (width, height)
class PixArtLoraLoader:
    """Apply a LoRA file to a PixArt model, caching the last-loaded weights."""

    def __init__(self):
        # (lora_path, loaded state) of the most recently used LoRA, or None.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        # Classmethod per ComfyUI node convention (queried on the class).
        return {
            "required": {
                "model": ("MODEL",),
                "lora_name": (folder_paths.get_filename_list("loras"), ),
                "strength": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_lora"
    CATEGORY = "ExtraModels/PixArt"
    TITLE = "PixArt Load LoRA"

    def load_lora(self, model, lora_name, strength,):
        """Patch ``model`` with LoRA ``lora_name`` at the given ``strength``.

        The LoRA file is loaded with ``utils.load_torch_file(..., safe_load=True)``
        and kept in ``self.loaded_lora``. A repeated call with the same path
        reuses it instead of reading the file again.

        Returns:
            A 1-tuple with the patched model, or with the unchanged ``model``
            when ``strength`` is 0.
        """
        if strength == 0:
            # Nothing to apply; still return a 1-tuple matching RETURN_TYPES,
            # exactly like the patched path below.
            return (model,)

        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = None
        if self.loaded_lora is not None:
            cached_path, cached_lora = self.loaded_lora
            if cached_path == lora_path:
                lora = cached_lora
            else:
                # Drop the stale cache entry before loading a different file so
                # the previous LoRA is not kept alive alongside the new one.
                self.loaded_lora = None
                del cached_lora

        if lora is None:
            lora = utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        model_lora = load_pixart_lora(model, lora, lora_path, strength,)
        return (model_lora,)
class PixArtResolutionCond:
    """Attach image-size and aspect-ratio conds (``img_hw``/``aspect_ratio``)."""

    @classmethod
    def INPUT_TYPES(s):
        # Classmethod per ComfyUI node convention (queried on the class).
        # Defaults are ints to match the declared "INT" input type.
        return {
            "required": {
                "cond": ("CONDITIONING", ),
                "width": ("INT", {"default": 1024, "min": 0, "max": 8192}),
                "height": ("INT", {"default": 1024, "min": 0, "max": 8192}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("cond",)
    FUNCTION = "add_cond"
    CATEGORY = "ExtraModels/PixArt"
    TITLE = "PixArt Resolution Conditioning"

    def add_cond(self, cond, width, height):
        """Return a copy of ``cond`` with resolution info added to each entry.

        Each entry's options dict (index 1) gets ``img_hw = [[height, width]]``
        and ``aspect_ratio = [[height / width]]``. The input entries and their
        dicts are not modified, so other consumers of the same conditioning
        are unaffected.

        Raises:
            ValueError: if ``width`` is 0 (the aspect ratio is undefined).
        """
        if width == 0:
            raise ValueError("PixArt Resolution Conditioning: width must be non-zero")
        extra = {
            "img_hw": [[height, width]],
            "aspect_ratio": [[height / width]],
        }
        out = []
        for entry in cond:
            new_entry = list(entry)
            # Shallow-copy the options dict instead of updating it in place.
            new_entry[1] = {**entry[1], **extra}
            out.append(new_entry)
        return (out,)
class PixArtControlNetCond:
    """Attach a scaled latent as the ControlNet hint (``cn_hint``) to each cond."""

    # NOTE(review): presumably the SD-style VAE latent scale factor — confirm
    # against the PixArt ControlNet implementation that reads "cn_hint".
    LATENT_SCALE = 0.18215

    @classmethod
    def INPUT_TYPES(s):
        # Classmethod per ComfyUI node convention (queried on the class).
        # TODO: "image", "vae" and "strength" inputs were sketched but are not
        # implemented; only a pre-encoded latent is accepted for now.
        return {
            "required": {
                "cond": ("CONDITIONING",),
                "latent": ("LATENT",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("cond",)
    FUNCTION = "add_cond"
    CATEGORY = "ExtraModels/PixArt"
    TITLE = "PixArt ControlNet Conditioning"

    def add_cond(self, cond, latent):
        """Return a copy of ``cond`` whose entries carry ``cn_hint``.

        ``cn_hint`` is ``latent["samples"] * LATENT_SCALE``. The input entries
        and their options dicts are left unmodified.
        """
        hint = latent["samples"] * self.LATENT_SCALE
        out = []
        for entry in cond:
            new_entry = list(entry)
            # Copy the options dict rather than mutating the caller's.
            new_entry[1] = {**entry[1], "cn_hint": hint}
            out.append(new_entry)
        return (out,)
class PixArtT5TextEncode:
    """
    Reference code, mostly to verify compatibility.
    Once everything works, this should instead inherit from the
    T5 text encode node and simply add the extra conds (res/ar).
    """

    @classmethod
    def INPUT_TYPES(s):
        # Classmethod per ComfyUI node convention (queried on the class).
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "T5": ("T5",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "ExtraModels/PixArt"
    TITLE = "PixArt T5 Text Encode [Reference]"

    def mask_feature(self, emb, mask):
        """Apply the token attention mask to T5 embeddings.

        Args:
            emb: 4-D tensor indexed as (batch, 1, tokens, dim) below; ``encode``
                inserts the singleton axis 1.
            mask: 2-D (batch, tokens) attention mask of 0/1 values.

        Returns:
            ``(features, keep_index)``. With a batch of one, the embeddings are
            truncated to the first ``mask.sum()`` tokens and ``keep_index`` is
            that count (NOTE(review): assumes right-padding, as produced by
            ``padding='max_length'`` in ``encode``). With a larger batch,
            masked positions are zeroed instead and ``keep_index`` is the full
            token length.
        """
        if emb.shape[0] == 1:
            keep_index = mask.sum().item()
            return emb[:, :, :keep_index, :], keep_index
        masked_feature = emb * mask[:, None, :, None]
        return masked_feature, emb.shape[2]

    def encode(self, text, T5):
        """Encode ``text`` with the T5 model carried by the ``T5`` input.

        The text is lower-cased and stripped, then tokenized to 120 tokens
        with padding and truncation. It is run through
        ``T5.cond_stage_model.transformer`` on ``T5.load_device``, and the
        resulting embeddings are masked and moved to the CPU.

        Returns:
            A 1-tuple holding the conditioning list ``[[embeddings, {}]]``;
            no resolution/aspect-ratio options are attached here.
        """
        text = text.lower().strip()
        tokenizer_out = T5.tokenizer.tokenizer(
            text,
            max_length=120,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        tokens = tokenizer_out["input_ids"]
        mask = tokenizer_out["attention_mask"]
        # [:, None] adds the singleton axis that mask_feature indexes as axis 1.
        embs = T5.cond_stage_model.transformer(
            input_ids=tokens.to(T5.load_device),
            attention_mask=mask.to(T5.load_device),
        )['last_hidden_state'].float()[:, None]
        masked_embs, keep_index = self.mask_feature(
            embs.detach().to("cpu"),
            mask.detach().to("cpu"),
        )
        masked_embs = masked_embs.squeeze(0)  # match CLIP/internal
        print("Encoded T5:", masked_embs.shape)
        return ([[masked_embs, {}]], )
class PixArtT5FromSD3CLIP:
    """
    Split the T5 text encoder away from SD3
    """

    @classmethod
    def INPUT_TYPES(s):
        # Classmethod per ComfyUI node convention (queried on the class).
        return {
            "required": {
                "sd3_clip": ("CLIP",),
                "padding": ("INT", {"default": 1, "min": 1, "max": 300}),
            }
        }

    RETURN_TYPES = ("CLIP",)
    RETURN_NAMES = ("t5",)
    FUNCTION = "split"
    CATEGORY = "ExtraModels/PixArt"
    TITLE = "PixArt T5 from SD3 CLIP"

    def split(self, sd3_clip, padding):
        """Return a CLIP clone that holds only the T5 part of an SD3 CLIP.

        The T5 wrapper is deep-copied without its transformer, and the
        transformer is then attached to the copy by reference, so the weights
        are shared rather than duplicated. The "end" special token is removed
        from the copy ("make sure empty tokens match"). A fresh SD3Tokenizer is
        used, with its T5 ``min_length`` set to ``padding``.

        Raises:
            ValueError: if ``sd3_clip`` has no T5 encoder loaded.
        """
        try:
            from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel
        except ImportError:
            # fallback for older ComfyUI versions
            from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel
        import copy

        clip = sd3_clip.clone()
        t5_src = getattr(clip.cond_stage_model, "t5xxl", None)
        if t5_src is None:
            raise ValueError("CLIP must have T5 loaded!")

        # Temporarily detach the transformer so deepcopy only duplicates the
        # lightweight wrapper. NOTE(review): clone() may share cond_stage_model
        # with sd3_clip, so the transformer is restored in ``finally`` to keep
        # the input CLIP intact even if the copy fails.
        transformer = t5_src.transformer
        t5_src.transformer = None
        try:
            tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False)
            tmp.t5xxl = copy.deepcopy(t5_src)
        finally:
            t5_src.transformer = transformer
        tmp.t5xxl.transformer = transformer

        # override special tokens
        tmp.t5xxl.special_tokens = copy.deepcopy(t5_src.special_tokens)
        tmp.t5xxl.special_tokens.pop("end", None)  # make sure empty tokens match

        # tokenizer
        tok = SD3Tokenizer()
        tok.t5xxl.min_length = padding
        clip.cond_stage_model = tmp
        clip.tokenizer = tok
        return (clip, )
# Node registry for ComfyUI: each key is the node type name and each value is
# the class implementing it. Every key equals the class's own name, and the
# entries keep their original order.
NODE_CLASS_MAPPINGS = dict(
    PixArtCheckpointLoader=PixArtCheckpointLoader,
    PixArtCheckpointLoaderSimple=PixArtCheckpointLoaderSimple,
    PixArtResolutionSelect=PixArtResolutionSelect,
    PixArtLoraLoader=PixArtLoraLoader,
    PixArtT5TextEncode=PixArtT5TextEncode,
    PixArtResolutionCond=PixArtResolutionCond,
    PixArtControlNetCond=PixArtControlNetCond,
    PixArtT5FromSD3CLIP=PixArtT5FromSD3CLIP,
)