Spaces:
Running
on
Zero
Running
on
Zero
import os
import random
from functools import partial

# The `spaces` package is imported only when the IN_SPACES environment
# variable is present (any value); `in_spaces` records which case applies so
# the GPU() decorator below can decide whether to wrap functions.
if os.environ.get("IN_SPACES", None) is not None:
    in_spaces = True
    import spaces
else:
    in_spaces = False

import gradio as gr
import torch

try:
    # Importing triton first avoids an import error otherwise raised by
    # diffusers/transformers.
    import triton
except ImportError:
    print("Triton not found, skip pre import")

## HDM model dep
import xut.env

# All optional xut backends are switched off here, before hdm.pipeline is
# imported. NOTE(review): presumably these flags are read at import time, so
# the ordering matters -- confirm against xut.env.
xut.env.TORCH_COMPILE = False
xut.env.USE_LIGER = False
xut.env.USE_VANILLA = False
xut.env.USE_XFORMERS = False
xut.env.USE_XFORMERS_LAYERS = False
from hdm.pipeline import HDMXUTPipeline

## TIPO
import kgen.models as kgen_models
import kgen.executor.tipo as tipo
from kgen.formatter import apply_format, seperate_tags

torch.set_float32_matmul_precision("high")
# Template passed as the second argument to apply_format() by prompt_opt()
# and generate(). The surrounding whitespace is stripped once at import time.
# NOTE(review): the <|...|> markers look like per-category placeholders that
# apply_format fills in -- confirm against kgen.formatter.
DEFAULT_FORMAT = """
<|special|>,
<|characters|>, <|copyrights|>,
<|artist|>,
<|general|>,
<|extended|>.
<|quality|>, <|meta|>, <|rating|>
""".strip()
def GPU(func=None, duration=None):
    """Wrap ``func`` with ``spaces.GPU`` when ``in_spaces`` is true.

    Works both as a bare decorator (``@GPU``) and as a decorator factory
    (``@GPU(duration=...)``).

    Args:
        func: The callable to wrap. When omitted, a partial of ``GPU`` that
            remembers ``duration`` is returned so it can be applied later.
        duration: Forwarded to ``spaces.GPU`` as ``duration=`` only when it
            is truthy.

    Returns:
        ``func`` unchanged when ``in_spaces`` is false, otherwise the result
        of ``spaces.GPU``; or a ``functools.partial`` when ``func`` is None.
    """
    if func is None:
        # Factory form: defer wrapping until the function is supplied.
        return partial(GPU, duration=duration)

    if not in_spaces:
        return func

    if duration:
        return spaces.GPU(func, duration=duration)
    return spaces.GPU(func)
def prompt_opt(tags, nl_prompt, aspect_ratio, seed):
    """Run TIPO on the given prompts and return the formatted result.

    Args:
        tags: Comma-separated tag string; it is split on "," and passed
            through ``seperate_tags``.
        nl_prompt: Natural-language prompt handed to
            ``tipo.parse_tipo_request``.
        aspect_ratio: Numeric ratio, stored in the request metadata formatted
            to three decimals.
        seed: Forwarded to ``tipo.tipo_runner`` as ``seed=``.

    Returns:
        The ``apply_format`` output for ``DEFAULT_FORMAT`` with surrounding
        whitespace, then ".", then "," stripped from both ends.
    """
    tag_groups = seperate_tags(tags.split(","))
    meta, operations, general, nl_prompt = tipo.parse_tipo_request(
        tag_groups,
        nl_prompt,
        tag_length_target="long",
        nl_length_target="short",
        generate_extra_nl_prompt=True,
    )
    meta["aspect_ratio"] = "{:.3f}".format(aspect_ratio)

    # The runner also returns timing information, which is not used here.
    result, _timing = tipo.tipo_runner(
        meta, operations, general, nl_prompt, seed=seed
    )

    formatted = apply_format(result, DEFAULT_FORMAT)
    return formatted.strip().strip(".").strip(",")
print("Loading models, please wait...")
device = torch.device("cuda")

# Load the HDM pipeline, cast it to float16, then move it to the device
# (same order as a chained .to(torch.float16).to(device)).
model = HDMXUTPipeline.from_pretrained(
    "KBlueLeaf/HDM-xut-340M-anime",
    trust_remote_code=True,
)
model = model.to(torch.float16)
model = model.to(device)

# Use the first entry of kgen's TIPO model list.
tipo_model_name, gguf_list = kgen_models.tipo_model_list[0]
kgen_models.load_model(tipo_model_name, device="cuda")
print("Models loaded successfully. UI is ready.")
def _parse_aspect_ratio(text):
    """Parse a ``"W:H"`` string into the float ratio ``W / H``.

    Raises:
        ValueError: If ``text`` is not exactly two numbers separated by a
            colon, or if either number (or their ratio) is not a finite,
            strictly positive value.
    """
    parts = str(text).split(":")
    if len(parts) != 2:
        raise ValueError(f"Aspect ratio must look like 'W:H', got {text!r}")
    try:
        ratio_w, ratio_h = float(parts[0]), float(parts[1])
    except ValueError as err:
        raise ValueError(
            f"Aspect ratio must contain two numbers, got {text!r}"
        ) from err

    inf = float("inf")
    # Chained comparisons are False for NaN, so NaN is rejected as well.
    if not (0 < ratio_w < inf and 0 < ratio_h < inf):
        raise ValueError(
            f"Aspect ratio values must be positive and finite, got {text!r}"
        )
    ratio = ratio_w / ratio_h
    if not 0 < ratio < inf:
        raise ValueError(f"Aspect ratio is out of range: {text!r}")
    return ratio


def _target_resolution(size, aspect_ratio, fixed_short_edge):
    """Return ``(width, height)``, each rounded down to a multiple of 16.

    With ``fixed_short_edge`` the shorter side equals ``size`` and the longer
    side is scaled by the ratio; otherwise both sides are scaled by
    ``sqrt(ratio)`` and its inverse, so the pixel area stays near ``size**2``.
    """
    if fixed_short_edge:
        if aspect_ratio > 1:
            h_factor = 1
            w_factor = aspect_ratio
        else:
            h_factor = 1 / aspect_ratio
            w_factor = 1
    else:
        w_factor = aspect_ratio**0.5
        h_factor = 1 / w_factor
    width = int(size * w_factor / 16) * 16
    height = int(size * h_factor / 16) * 16
    return width, height


def generate(
    nl_prompt: str,
    tag_prompt: str,
    negative_prompt: str,
    tipo_enable: bool,
    format_enable: bool,
    num_images: int,
    steps: int,
    cfg_scale: float,
    size: int,
    aspect_ratio: str,
    fixed_short_edge: bool,
    zoom: float,
    x_shift: float,
    y_shift: float,
    tread_gamma1: float,
    tread_gamma2: float,
    seed: int,
    progress=gr.Progress(),
):
    """Build the final prompt and generate images with the HDM pipeline.

    This is a generator used as a Gradio event handler: it first yields the
    final prompt (with no images) so the UI can show it, then yields the
    generated images together with the same prompt.

    Args:
        nl_prompt: Natural-language prompt.
        tag_prompt: Comma-separated tag prompt.
        negative_prompt: Negative prompt; when TIPO is enabled its
            comma-separated entries are also assigned to ``tipo.BAN_TAGS``.
        tipo_enable: Build the prompt with ``prompt_opt`` (TIPO).
        format_enable: When TIPO is off, build the prompt with
            ``apply_format`` instead of plain concatenation.
        num_images: Number of copies of the prompt sent to the pipeline.
        steps: Passed as ``num_inference_steps``.
        cfg_scale: Passed as ``cfg_scale``.
        size: Base edge length used to derive the output width and height.
        aspect_ratio: ``"W:H"`` text such as ``"16:9"``.
        fixed_short_edge: Selects how ``size`` and the ratio are combined
            (see ``_target_resolution``).
        zoom, x_shift, y_shift: Passed to the pipeline as ``camera_param``.
        tread_gamma1, tread_gamma2: Passed through to the pipeline.
        seed: RNG seed; ``-1`` or ``None`` picks a random 32-bit seed.
        progress: Gradio progress tracker (not used in the body).

    Yields:
        ``(None, final_prompt)`` and then ``(images, final_prompt)``.

    Raises:
        gr.Error: If the aspect ratio text is invalid or the resulting
            resolution would have a zero-length side.
    """
    # Validate user-typed input up front so bad values surface as a readable
    # UI error instead of a ValueError / ZeroDivisionError / TypeError.
    try:
        aspect_ratio = _parse_aspect_ratio(aspect_ratio)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    w, h = _target_resolution(size, aspect_ratio, fixed_short_edge)
    if w <= 0 or h <= 0:
        raise gr.Error(
            f"Aspect ratio is too extreme for size {size}: "
            f"resulting resolution would be {w}x{h}"
        )

    # Set seed for reproducibility. A missing seed is treated like -1.
    if seed is None or seed == -1:
        seed = random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)

    # TIPO
    if tipo_enable:
        tipo.BAN_TAGS = [i.strip() for i in negative_prompt.split(",") if i.strip()]
        final_prompt = prompt_opt(tag_prompt, nl_prompt, aspect_ratio, seed)
    elif format_enable:
        # NOTE(review): only nl_prompt is formatted here; tag_prompt is not
        # used in this branch -- confirm this is intended.
        final_prompt = apply_format(nl_prompt, DEFAULT_FORMAT)
    else:
        final_prompt = tag_prompt + "\n" + nl_prompt

    # Show the resolved prompt in the UI before the (slow) image generation.
    yield None, final_prompt

    prompts_to_generate = [final_prompt.replace("\n", " ")] * num_images
    negative_prompts_to_generate = [negative_prompt] * num_images

    print("=" * 100)
    print(
        f"Generating {num_images} image(s) with seed: {seed} and resolution {w}x{h}"
    )
    print("-" * 80)
    print(f"Final prompt: {final_prompt}")
    print("-" * 80)
    print(f"Negative prompt: {negative_prompt}")
    print("-" * 80)

    images = model(
        prompts_to_generate,
        negative_prompts_to_generate,
        num_inference_steps=steps,
        cfg_scale=cfg_scale,
        width=w,
        height=h,
        camera_param={
            "zoom": zoom,
            "x_shift": x_shift,
            "y_shift": y_shift,
        },
        tread_gamma1=tread_gamma1,
        tread_gamma2=tread_gamma2,
    ).images
    yield images, final_prompt
# --- Gradio UI Definition ---
# Layout: a header and collapsible introduction, then a row holding the prompt
# column (left) and the sampling/size settings column (right), then a row with
# the Generate button + final prompt (left) and the image gallery (right).
with gr.Blocks(title="HDM Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# HDM Demo")
    gr.Markdown(
        "### Enter a natural language prompt and/or specific tags to generate an image."
    )
    with gr.Accordion("Introduction", open=False):
        gr.Markdown("""
# HDM: HomeDiffusion Model Project
HDM is a project to implement a series of generative model that can be pretrained at home.
* Project Source code: https://github.com/KBlueLeaf/HDM
* Model: https://huggingface.co/KBlueLeaf/HDM-xut-340M-anime
## Usage
This early model used a model trained on anime image set only,
so you should expect to see anime style images only in this demo.
For prompting, enter danbooru tag prompt to the box "Tag Prompt" with comma separated and remove the underscore.
enter natural language prompt to the box "Natural Language Prompt" and enter negative prompt to the box "Negative Prompt".
If you don't want to spend so much effort on prompting, try to keep "Enable TIPO" selected.
If you don't want to apply any pre-defined format, unselect "Enable TIPO" and "Enable Format".
## Model Spec
- Backbone: 343M XUT(UViT modified) arch
- Text Encoder: Qwen3 0.6B (596M)
- VAE: EQ-SDXL-VAE, an EQ-VAE finetuned sdxl vae.
## Pretraining Dataset
- Danbooru 2023 (latest id around 8M)
- Pixiv famous artist set
- some pvc figure photos
""")
    with gr.Row():
        # Left column: prompts, prompt-processing toggles and camera controls.
        with gr.Column(scale=2):
            nl_prompt_box = gr.Textbox(
                label="Natural Language Prompt",
                placeholder="e.g., A beautiful anime girl standing in a blooming cherry blossom forest",
                lines=3,
            )
            tag_prompt_box = gr.Textbox(
                label="Tag Prompt (comma-separated)",
                placeholder="e.g., 1girl, solo, long hair, cherry blossoms, school uniform",
                lines=3,
            )
            # The default value is sent to the model as the negative prompt
            # and, with TIPO enabled, split on "," into tipo.BAN_TAGS, so the
            # tag spelling matters ("low quality", not "llow quality").
            neg_prompt_box = gr.Textbox(
                label="Negative Prompt",
                value=(
                    "low quality, worst quality, text, signature, jpeg artifacts, bad anatomy, old, early, copyright name, watermark, artist name, signature, weibo username, realistic"
                ),
                lines=3,
            )
            with gr.Row():
                tipo_enable = gr.Checkbox(
                    label="Enable TIPO",
                    value=True,
                )
                format_enable = gr.Checkbox(
                    label="Enable Format",
                    value=True,
                )
            with gr.Row():
                zoom_slider = gr.Slider(
                    label="Zoom", minimum=0.5, maximum=2.0, value=1.0, step=0.01
                )
                x_shift_slider = gr.Slider(
                    label="X Shift", minimum=-0.5, maximum=0.5, value=0.0, step=0.01
                )
                y_shift_slider = gr.Slider(
                    label="Y Shift", minimum=-0.5, maximum=0.5, value=0.0, step=0.01
                )
        # Right column: sampling, seed and output-size settings.
        with gr.Column(scale=1):
            with gr.Row():
                num_images_slider = gr.Slider(
                    label="Number of Images", minimum=1, maximum=4, value=1, step=1
                )
                steps_slider = gr.Slider(
                    label="Inference Steps", minimum=1, maximum=50, value=24, step=1
                )
            with gr.Row():
                cfg_slider = gr.Slider(
                    label="CFG Scale", minimum=1.0, maximum=7.0, value=4.0, step=0.1
                )
                seed_input = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    info="Set to -1 for a random seed.",
                )
            with gr.Row():
                tread_gamma1_slider = gr.Slider(
                    label="Tread Gamma 1",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.05,
                    interactive=True,
                )
                tread_gamma2_slider = gr.Slider(
                    label="Tread Gamma 2",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.05,
                    interactive=True,
                )
            with gr.Row():
                size_slider = gr.Slider(
                    label="Base Image Size",
                    minimum=768,
                    maximum=1280,
                    value=1024,
                    step=16,
                )
            with gr.Row():
                aspect_ratio_box = gr.Textbox(
                    label="Ratio (W:H)",
                    value="1:1",
                )
                fixed_short_edge = gr.Checkbox(
                    label="Fixed Edge",
                    value=True,
                )
    with gr.Row():
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate", variant="primary")
            output_prompt = gr.TextArea(
                label="Final Prompt",
                show_label=True,
                interactive=False,
                lines=32,
                max_lines=32,
            )
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=3,
                height="800px",
            )
    # The order of `inputs` must match the positional parameters of
    # generate(); `outputs` matches the (images, final_prompt) pairs it yields.
    generate_button.click(
        fn=generate,
        inputs=[
            nl_prompt_box,
            tag_prompt_box,
            neg_prompt_box,
            tipo_enable,
            format_enable,
            num_images_slider,
            steps_slider,
            cfg_slider,
            size_slider,
            aspect_ratio_box,
            fixed_short_edge,
            zoom_slider,
            x_shift_slider,
            y_shift_slider,
            tread_gamma1_slider,
            tread_gamma2_slider,
            seed_input,
        ],
        outputs=[output_gallery, output_prompt],
        show_progress_on=output_gallery,
    )
# Launch the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()