# Gradio Space: "ToDo: Token Downsampling for Efficient Generation of
# High-Resolution Images" (web-page scrape artifacts removed from this file).
import math
import time
from threading import Semaphore

import diffusers
import gradio as gr
import numpy as np
import torch
from PIL import Image

from utils import patch_attention_proc, remove_patch

# import spaces  # enable when deploying on HF ZeroGPU hardware

# -- Globals -----------------------------------------------------------------

# Center the page title rendered by gr.Markdown.
css = """
h1 {
    text-align: center;
    display: block;
}
"""

# Pick the best available accelerator, falling back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Pipeline: DreamShaper checkpoint in float16 with an Euler scheduler.
# NOTE(review): float16 on a CPU device is unsupported for many ops — confirm
# the CPU fallback path actually runs before relying on it.
pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to(device, torch.float16)
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.safety_checker = None

# Default Semaphore() has value 1, so it acts as a lock that serializes
# generation and prevents two simultaneous button presses from colliding.
semaphore = Semaphore()
# @spaces.GPU  # enable when deploying on HF ZeroGPU hardware
def generate_baseline(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Run the unpatched (baseline) pipeline once and time it.

    Args:
        prompt: text prompt for generation.
        seed: integer seed for torch's global RNG (reproducibility).
        steps: number of denoising steps.
        height_width: square output resolution in pixels.
        negative_prompt: text to steer generation away from.
        guidance_scale: classifier-free guidance strength.
        method: "todo" or "tome"; unused here, kept so the signature stays
            parallel to generate_merged and both buttons share one input list.

    Returns:
        (PIL.Image.Image, str): generated image and a runtime summary string.
    """
    semaphore.acquire()
    # try/finally fixes a deadlock: previously an exception in the pipeline
    # skipped semaphore.release(), blocking every later button press.
    try:
        torch.manual_seed(seed)
        start_time_base = time.time()
        remove_patch(pipe)  # ensure no token-merging patch from a prior run
        # NOTE: the original also built a token_merge_args dict here that was
        # never used (dead code, and a latent NameError for non-preset
        # resolutions); it has been removed.
        base_img = pipe(prompt,
                        num_inference_steps=steps, height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
        end_time_base = time.time()
        result = f"Baseline Runtime: {end_time_base-start_time_base:.2f} sec"
        return base_img, result
    finally:
        semaphore.release()
# @spaces.GPU  # enable when deploying on HF ZeroGPU hardware
def generate_merged(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Patch the UNet with token merging (ToDo or ToMe), generate, and time it.

    Args:
        prompt: text prompt for generation.
        seed: integer seed for torch's global RNG (reproducibility).
        steps: number of denoising steps.
        height_width: square output resolution in pixels (presets: 1024/1536/2048).
        negative_prompt: text to steer generation away from.
        guidance_scale: classifier-free guidance strength.
        method: "todo" (downsample keys/values) or "tome" (similarity merge all).

    Returns:
        (PIL.Image.Image, str): generated image and a runtime summary string.
    """
    semaphore.acquire()
    # try/finally fixes a deadlock: previously an exception in the pipeline
    # skipped semaphore.release(), blocking every later button press.
    try:
        # Defaults cover any resolution outside the three presets below.
        # (Previously downsample_factor_level_2 / ratio_level_2 were only
        # assigned inside the if/elif chain, raising NameError otherwise.)
        downsample_factor = 2
        ratio = 0.38
        downsample_factor_level_2 = 1
        ratio_level_2 = 0.0
        merge_method = "downsample" if method == "todo" else "similarity"
        merge_tokens = "keys/values" if method == "todo" else "all"
        # Higher resolutions tolerate more aggressive merging.
        if height_width == 1024:
            downsample_factor = 2
            ratio = 0.75
        elif height_width == 1536:
            downsample_factor = 3
            ratio = 0.89
        elif height_width == 2048:
            downsample_factor = 4
            ratio = 0.9375
        token_merge_args = {"ratio": ratio,
                            "merge_tokens": merge_tokens,
                            "merge_method": merge_method,
                            "downsample_method": "nearest",
                            "downsample_factor": downsample_factor,
                            "timestep_threshold_switch": 0.0,
                            "timestep_threshold_stop": 0.0,
                            "downsample_factor_level_2": downsample_factor_level_2,
                            "ratio_level_2": ratio_level_2
                            }
        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
        torch.manual_seed(seed)
        start_time_merge = time.time()
        merged_img = pipe(prompt,
                          num_inference_steps=steps, height=height_width, width=height_width,
                          negative_prompt=negative_prompt,
                          guidance_scale=guidance_scale).images[0]
        end_time_merge = time.time()
        result = f"{'ToDo' if method == 'todo' else 'ToMe'} Runtime: {end_time_merge-start_time_merge:.2f} sec"
        return merged_img, result
    finally:
        semaphore.release()
# -- UI: two columns, baseline on the left, token-merged on the right --------
with gr.Blocks(css=css) as demo:
    gr.Markdown("# ToDo: Token Downsampling for Efficient Generation of High-Resolution Images")
    prompt = gr.Textbox(interactive=True, label="prompt")
    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
    with gr.Row():
        method = gr.Dropdown(
            ["todo", "tome"], value="todo", label="method",
            info="Choose Your Desired Method (Default: todo)")
        height_width = gr.Dropdown(
            [1024, 1536, 2048], value=1024, label="height/width",
            info="Choose Your Desired Height/Width (Default: 1024)")
    with gr.Row():
        guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
        steps = gr.Number(label="steps", value=20, precision=0)
        seed = gr.Number(label="seed", value=1, precision=0)

    # Both buttons take the same inputs; only the handler differs.
    shared_inputs = [prompt, seed, steps, height_width, negative_prompt,
                     guidance_scale, method]
    with gr.Row():
        with gr.Column():
            base_result = gr.Textbox(label="Baseline Runtime")
            base_image = gr.Image(label="baseline_image", type="pil", interactive=False)
            base_button = gr.Button("Generate Baseline")
            base_button.click(generate_baseline, inputs=shared_inputs,
                              outputs=[base_image, base_result])
        with gr.Column():
            output_result = gr.Textbox(label="Runtime")
            output_image = gr.Image(label="image", type="pil", interactive=False)
            merged_button = gr.Button("Generate")
            merged_button.click(generate_merged, inputs=shared_inputs,
                                outputs=[output_image, output_result])

demo.launch(share=True)