Spaces:
Configuration error
Configuration error
| import os | |
| import json | |
| import torch | |
| import folder_paths | |
| from .conf import dit_conf | |
| from .loader import load_dit | |
class DitCheckpointLoader:
    """ComfyUI node that loads a DiT checkpoint from the ``checkpoints`` folder.

    The architecture is chosen by name from ``dit_conf``. The target image
    size sets the model's ``input_size``.
    """

    # ComfyUI queries the node's inputs on the class itself, so this must be a
    # classmethod; a plain method would fail with a missing-argument error.
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                "model": (list(dit_conf.keys()),),
                "image_size": ([256, 512],),
                # "num_classes": ("INT", {"default": 1000, "min": 0,}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_checkpoint"
    CATEGORY = "ExtraModels/DiT"
    TITLE = "DitCheckpointLoader"

    def load_checkpoint(self, ckpt_name, model, image_size):
        """Load ``ckpt_name`` using the ``dit_conf[model]`` configuration.

        Args:
            ckpt_name: File name inside ComfyUI's ``checkpoints`` folder.
            model: Key into ``dit_conf`` selecting the architecture config.
            image_size: Target image size in pixels (256 or 512).

        Returns:
            A one-element tuple holding whatever ``load_dit`` returns.
        """
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        base_conf = dit_conf[model]
        # Copy the two levels we modify, so the shared module-level dit_conf
        # entry is not mutated. Otherwise a later load with a different
        # image_size would rewrite the config an earlier model was built from.
        model_conf = dict(base_conf)
        model_conf["unet_config"] = dict(base_conf["unet_config"])
        # Presumably the VAE's 8x spatial downsampling (latent resolution);
        # TODO confirm against the DiT model definition.
        model_conf["unet_config"]["input_size"] = image_size // 8
        # model_conf["unet_config"]["num_classes"] = num_classes
        dit = load_dit(
            model_path=ckpt_path,
            model_conf=model_conf,
        )
        return (dit,)
# TODO: this needs frontend code to display properly
def get_label_data(label_file="labels/imagenet1000.json"):
    """Load the class-label mapping from a JSON file.

    Args:
        label_file: Path resolved relative to this module's directory. Per
            ``os.path.join``, an absolute path is used as-is.

    Returns:
        The decoded JSON content. For a JSON object, keys are strings.
    """
    label_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        label_file,
    )
    # Decode as UTF-8 explicitly; the platform default encoding (e.g. cp1252
    # on Windows) is not guaranteed to decode the label names correctly.
    with open(label_path, "r", encoding="utf-8") as f:
        return json.load(f)
# Label table loaded once at import time; the label-selection node reads it
# to build its dropdown and to map a chosen name back to a class ID.
label_data = get_label_data("labels/imagenet1000.json")
class DiTCondLabelSelect:
    """ComfyUI node that turns a human-readable class name into DiT conditioning.

    The available names come from the module-level ``label_data`` mapping
    (class ID -> name).
    """

    # ComfyUI queries the node's inputs on the class itself, so this must be a
    # classmethod. ``label_data`` is only read here, so no ``global`` is needed.
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "label_name": (list(label_data.values()),),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("class",)
    FUNCTION = "cond_label"
    CATEGORY = "ExtraModels/DiT"
    TITLE = "DiTCondLabelSelect"

    def cond_label(self, model, label_name):
        """Return conditioning holding the class ID whose name is ``label_name``.

        Args:
            model: Unused; present so the node is wired to a MODEL input.
            label_name: A value from ``label_data``.

        Returns:
            ``([[y, {}]],)``, where ``y`` is a 1x1 int32 tensor with the class ID.

        Raises:
            ValueError: If ``label_name`` is not a value in ``label_data``.
        """
        # If a name appears under several IDs, the first one in dict order wins.
        class_id = next(
            (int(k) for k, v in label_data.items() if v == label_name),
            None,
        )
        if class_id is None:
            raise ValueError(f"Unknown DiT class label: {label_name!r}")
        y = torch.tensor([[class_id]], dtype=torch.int)
        return ([[y, {}]],)
class DiTCondLabelEmpty:
    """ComfyUI node that emits the "empty" (null-class) DiT conditioning.

    The null class ID equals the model's ``num_classes``, one past the last
    real class ID.
    """

    # ComfyUI queries the node's inputs on the class itself, so this must be a
    # classmethod.
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("empty",)
    FUNCTION = "cond_empty"
    CATEGORY = "ExtraModels/DiT"
    TITLE = "DiTCondLabelEmpty"

    def cond_empty(self, model):
        """Return conditioning holding the null class ID for ``model``.

        Args:
            model: Object exposing ``model.model.model_config.unet_config``,
                a mapping that contains ``"num_classes"``.

        Returns:
            ``([[y, {}]],)``, where ``y`` is a 1x1 int32 tensor with the null ID.
        """
        # [ID of last class + 1] == [num_classes]
        null_id = model.model.model_config.unet_config["num_classes"]
        y = torch.tensor([[null_id]], dtype=torch.int)
        return ([[y, {}]],)
# Node-type name -> implementing class for the nodes defined in this module
# (presumably collected by the package/ComfyUI loader; verify against caller).
NODE_CLASS_MAPPINGS = dict(
    DitCheckpointLoader=DitCheckpointLoader,
    DiTCondLabelSelect=DiTCondLabelSelect,
    DiTCondLabelEmpty=DiTCondLabelEmpty,
)