Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from diffusers import ( | |
| StableDiffusionXLControlNetPipeline, | |
| DiffusionPipeline, | |
| StableDiffusionImg2ImgPipeline, | |
| StableDiffusionInpaintPipeline, | |
| StableDiffusionAdapterPipeline, | |
| StableDiffusionControlNetPipeline, | |
| StableDiffusionXLAdapterPipeline, | |
| StableDiffusionXLImg2ImgPipeline, | |
| StableDiffusionXLInpaintPipeline, | |
| ControlNetModel, | |
| T2IAdapter, | |
| ) | |
| import time | |
| import utils | |
# Device and precision shared by the model-loading code below.
device = torch.device("cuda")
dtype = torch.float16
# Maps the name shown in the UI to the details needed to build that pipeline:
#   (pipeline class, base checkpoint id[, ControlNet / T2I-Adapter checkpoint id])
# Entries whose name contains "ControlNet" or "Adapters" must carry the third
# element; `load_pipeline` reads it as `pipeline_details[2]`.
pipeline_mapping = {
    "SD T2I": (DiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
    "SD I2I": (StableDiffusionImg2ImgPipeline, "runwayml/stable-diffusion-v1-5"),
    "SD Inpainting": (
        StableDiffusionInpaintPipeline,
        "runwayml/stable-diffusion-inpainting",
    ),
    "SD ControlNet": (
        StableDiffusionControlNetPipeline,
        "runwayml/stable-diffusion-v1-5",
        "lllyasviel/sd-controlnet-canny",
    ),
    "SD T2I Adapters": (
        StableDiffusionAdapterPipeline,
        # These must be two separate elements: without the comma Python
        # concatenates the adjacent literals into a single invalid id.
        "CompVis/stable-diffusion-v1-4",
        "TencentARC/t2iadapter_canny_sd14v1",
    ),
    "SDXL T2I": (DiffusionPipeline, "stabilityai/stable-diffusion-xl-base-1.0"),
    "SDXL I2I": (
        StableDiffusionXLImg2ImgPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
    ),
    "SDXL Inpainting": (
        StableDiffusionXLInpaintPipeline,
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    ),
    "SDXL ControlNet": (
        StableDiffusionXLControlNetPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "diffusers/controlnet-canny-sdxl-1.0",
    ),
    "SDXL T2I Adapters": (
        StableDiffusionXLAdapterPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "TencentARC/t2i-adapter-canny-sdxl-1.0",
    ),
}
def load_pipeline(
    pipeline_to_benchmark: str,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
    """Build the requested pipeline in fp16 and move it to ``device``.

    Args:
        pipeline_to_benchmark: A key of ``pipeline_mapping``.
        use_channels_last: If true, switch the UNet (and the ControlNet or
            adapter, when the pipeline has one) to ``torch.channels_last``.
        do_torch_compile: If true, wrap the UNet (and the ControlNet or
            adapter, when the pipeline has one) with ``torch.compile`` using
            ``mode="reduce-overhead"`` and ``fullgraph=True``.

    Returns:
        The loaded pipeline object.

    Raises:
        KeyError: If ``pipeline_to_benchmark`` is not in ``pipeline_mapping``.
    """
    # Get pipeline details.
    pipeline_details = pipeline_mapping[pipeline_to_benchmark]
    pipeline_cls = pipeline_details[0]
    pipeline_ckpt = pipeline_details[1]

    # Load the auxiliary model if needed. It is handed to the pipeline under
    # the keyword argument matching its kind ("controlnet" or "adapter").
    extra_components = {}
    if "ControlNet" in pipeline_to_benchmark:
        controlnet_ckpt = pipeline_details[2]
        extra_components["controlnet"] = ControlNetModel.from_pretrained(
            controlnet_ckpt, variant="fp16", torch_dtype=dtype
        ).to(device)
    elif "Adapters" in pipeline_to_benchmark:
        adapter_ckpt = pipeline_details[2]
        extra_components["adapter"] = T2IAdapter.from_pretrained(
            adapter_ckpt, variant="fp16", torch_dtype=dtype
        ).to(device)

    # Load pipeline. Every variant gets the same fp16 settings so the base
    # model and the auxiliary model share one dtype; the previous condition
    # (`not in ... or not in ...`) was always true, so ControlNet / adapter
    # pipelines were never given their extra component.
    pipeline = pipeline_cls.from_pretrained(
        pipeline_ckpt, variant="fp16", torch_dtype=dtype, **extra_components
    )
    pipeline.to(device)

    # Optionally set memory layout.
    if use_channels_last:
        pipeline.unet.to(memory_format=torch.channels_last)
        if hasattr(pipeline, "controlnet"):
            pipeline.controlnet.to(memory_format=torch.channels_last)
        elif hasattr(pipeline, "adapter"):
            pipeline.adapter.to(memory_format=torch.channels_last)

    # Optional torch compilation.
    if do_torch_compile:
        pipeline.unet = torch.compile(
            pipeline.unet, mode="reduce-overhead", fullgraph=True
        )
        if hasattr(pipeline, "controlnet"):
            pipeline.controlnet = torch.compile(
                pipeline.controlnet, mode="reduce-overhead", fullgraph=True
            )
        elif hasattr(pipeline, "adapter"):
            pipeline.adapter = torch.compile(
                pipeline.adapter, mode="reduce-overhead", fullgraph=True
            )

    return pipeline
def generate(
    pipeline_to_benchmark: str,
    num_images_per_prompt: int = 1,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
    """Load a pipeline and time three generation runs, printing the results.

    Args:
        pipeline_to_benchmark: A key of ``pipeline_mapping``.
        num_images_per_prompt: Forwarded to the pipeline call.
        use_channels_last: Forwarded to ``load_pipeline``.
        do_torch_compile: Forwarded to ``load_pipeline``.

    Returns:
        None. For each of the three runs the total time and the average time
        per inference step are printed.
    """
    print("Start...")
    print("Torch version", torch.__version__)
    print("Torch CUDA version", torch.version.cuda)

    pipeline = load_pipeline(
        pipeline_to_benchmark=pipeline_to_benchmark,
        use_channels_last=use_channels_last,
        do_torch_compile=do_torch_compile,
    )

    # The call arguments are identical for every run, so build them (and
    # fetch any conditioning image) once instead of on every iteration.
    num_inference_steps = 20
    call_args = dict(
        prompt=77 * "a",
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
    )
    if pipeline_to_benchmark in ("SD I2I", "SDXL I2I"):
        image = utils.get_image_for_img_to_img(pipeline_to_benchmark)
        call_args.update({"image": image})
    elif "Inpainting" in pipeline_to_benchmark:
        # NOTE(review): this helper is named `get_image_from_inpainting` while
        # its siblings are `get_image_for_*` — confirm the name against utils.
        image, mask_image = utils.get_image_from_inpainting(pipeline_to_benchmark)
        call_args.update({"image": image, "mask_image": mask_image})
    elif "ControlNet" in pipeline_to_benchmark:
        image = utils.get_image_for_controlnet(pipeline_to_benchmark)
        call_args.update({"image": image})
    elif "Adapters" in pipeline_to_benchmark:
        image = utils.get_image_for_adapters(pipeline_to_benchmark)
        call_args.update({"image": image})

    for _ in range(3):
        # perf_counter is monotonic, so the interval cannot be distorted by
        # system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        _ = pipeline(**call_args).images
        elapsed = time.perf_counter() - start_time

        print(f"For {num_inference_steps} steps", elapsed)
        print("Avg per step", elapsed / num_inference_steps)
# Benchmark UI: two option checkboxes, a pipeline selector, a batch-size
# slider and a button that runs `generate` with those values.
with gr.Blocks() as demo:
    do_torch_compile = gr.Checkbox(label="Enable torch.compile()?")
    use_channels_last = gr.Checkbox(label="Use `channels_last` memory layout?")

    # The dropdown must be bound to the component itself (a stray trailing
    # comma previously turned it into a 1-tuple), and its default has to be a
    # single string that is one of the offered choices.
    pipeline_choices = list(pipeline_mapping)
    pipeline_to_benchmark = gr.Dropdown(
        pipeline_choices,
        value=pipeline_choices[0],
        multiselect=False,
        label="Pipeline to benchmark",
    )
    batch_size = gr.Slider(
        label="Number of images per prompt",
        minimum=1,
        maximum=16,
        step=1,
        value=1,
    )

    # NOTE(review): `.style()` appears to be deprecated in later Gradio 3.x
    # releases and removed in 4.x — confirm against the pinned Gradio version.
    btn = gr.Button("Benchmark!").style(
        margin=False,
        rounded=(False, True, True, False),
        full_width=False,
    )

    # Input order matches the positional parameters of `generate`.
    btn.click(
        fn=generate,
        inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
    )

demo.launch()