import gradio as gr
import torch
from diffusers import (
    AutoPipelineForText2Image,
    StableDiffusionXLControlNetPipeline,
    DiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionAdapterPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    ControlNetModel,
    T2IAdapter,
)
import time

import utils

dtype = torch.float16
device = torch.device("cuda")
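# Note: fp16 weights halve memory use relative to fp32; this Space assumes a
# CUDA GPU is available (the code does not fall back to CPU).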

# pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile
# examples = [["SD T2I", 4, True, True]]

pipeline_mapping = {
    "SD T2I": (DiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
    "SD I2I": (StableDiffusionImg2ImgPipeline, "runwayml/stable-diffusion-v1-5"),
    "SD Inpainting": (
        StableDiffusionInpaintPipeline,
        "runwayml/stable-diffusion-inpainting",
    ),
    "SD ControlNet": (
        StableDiffusionControlNetPipeline,
        "runwayml/stable-diffusion-v1-5",
        "lllyasviel/sd-controlnet-canny",
    ),
    "SD T2I Adapters": (
        StableDiffusionAdapterPipeline,
        "CompVis/stable-diffusion-v1-4",
        "TencentARC/t2iadapter_canny_sd14v1",
    ),
    "SDXL T2I": (DiffusionPipeline, "stabilityai/stable-diffusion-xl-base-1.0"),
    "SDXL I2I": (
        StableDiffusionXLImg2ImgPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
    ),
    "SDXL Inpainting": (
        StableDiffusionXLInpaintPipeline,
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    ),
    "SDXL ControlNet": (
        StableDiffusionXLControlNetPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "diffusers/controlnet-canny-sdxl-1.0",
    ),
    "SDXL T2I Adapters": (
        StableDiffusionXLAdapterPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "TencentARC/t2i-adapter-canny-sdxl-1.0",
    ),
    "Kandinsky 2.2 (T2I)": (
        AutoPipelineForText2Image,
        "kandinsky-community/kandinsky-2-2-decoder",
    ),
    "Würstchen (T2I)": (AutoPipelineForText2Image, "warp-ai/wuerstchen"),
}
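# Each entry maps a UI label to (pipeline class, checkpoint). ControlNet and
# T2I-Adapter entries carry a third element: the auxiliary model's checkpoint.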


def load_pipeline(
    pipeline_to_benchmark: str,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
    # Get pipeline details.
    print(f"Loading pipeline: {pipeline_to_benchmark}")
    pipeline_details = pipeline_mapping[pipeline_to_benchmark]
    pipeline_cls = pipeline_details[0]
    pipeline_ckpt = pipeline_details[1]

    # Load adapter if needed.
    if "ControlNet" in pipeline_to_benchmark:
        controlnet_ckpt = pipeline_details[2]
        controlnet = ControlNetModel.from_pretrained(
            controlnet_ckpt, torch_dtype=dtype
        ).to(device)
| elif "Adapters" in pipeline_to_benchmark: | |
| adapter_clpt = pipeline_details[2] | |
| adapter = T2IAdapter.from_pretrained(adapter_clpt, torch_dtype=dtype).to(device) | |
    # Load pipeline.
    if (
        "ControlNet" not in pipeline_to_benchmark
        and "Adapters" not in pipeline_to_benchmark
    ):
        pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, torch_dtype=dtype)
    elif "ControlNet" in pipeline_to_benchmark:
        pipeline = pipeline_cls.from_pretrained(
            pipeline_ckpt, controlnet=controlnet, torch_dtype=dtype
        )
    elif "Adapters" in pipeline_to_benchmark:
        pipeline = pipeline_cls.from_pretrained(
            pipeline_ckpt, adapter=adapter, torch_dtype=dtype
        )
    pipeline.to(device)
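    # For the ControlNet/adapter variants, the auxiliary network is handed to
    # `from_pretrained` so the pipeline wires it into generation.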

    # Optionally set memory layout.
    if use_channels_last:
        print("Setting memory layout.")
        if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
            pipeline.unet.to(memory_format=torch.channels_last)
        elif pipeline_to_benchmark == "Würstchen (T2I)":
            pipeline.prior_prior.to(memory_format=torch.channels_last)
            pipeline.decoder.to(memory_format=torch.channels_last)
        elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
            pipeline.unet.to(memory_format=torch.channels_last)

        if hasattr(pipeline, "controlnet"):
            pipeline.controlnet.to(memory_format=torch.channels_last)
        elif hasattr(pipeline, "adapter"):
            pipeline.adapter.to(memory_format=torch.channels_last)
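    # channels_last stores activations in NHWC order, which tends to speed up
    # fp16 convolutions on recent NVIDIA GPUs (the benefit varies by hardware).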

    # Optional torch compilation.
    if do_torch_compile:
        print("Compiling pipeline.")
        if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
            pipeline.unet = torch.compile(
                pipeline.unet, mode="reduce-overhead", fullgraph=True
            )
        elif pipeline_to_benchmark == "Würstchen (T2I)":
            pipeline.prior_prior = torch.compile(
                pipeline.prior_prior, mode="reduce-overhead", fullgraph=True
            )
            pipeline.decoder = torch.compile(
                pipeline.decoder, mode="reduce-overhead", fullgraph=True
            )
        elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
            pipeline.unet = torch.compile(
                pipeline.unet, mode="reduce-overhead", fullgraph=True
            )

        if hasattr(pipeline, "controlnet"):
            pipeline.controlnet = torch.compile(
                pipeline.controlnet, mode="reduce-overhead", fullgraph=True
            )
        elif hasattr(pipeline, "adapter"):
            pipeline.adapter = torch.compile(
                pipeline.adapter, mode="reduce-overhead", fullgraph=True
            )
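    # torch.compile is lazy: compilation (and, with "reduce-overhead", CUDA
    # graph capture) happens on the first forward pass, so the first benchmark
    # iteration below doubles as a warmup run.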
| print("Pipeline loaded.") | |
| pipeline.set_progress_bar_config(disable=True) | |
| return pipeline | |
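
# A minimal sketch of calling the loader directly, outside the Gradio UI
# (illustrative only; assumes a CUDA GPU and that the checkpoints above are
# accessible):
#
#   pipe = load_pipeline("SDXL T2I", use_channels_last=True, do_torch_compile=True)
#   images = pipe("an astronaut riding a horse", num_inference_steps=20).images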


def generate(
    pipeline_to_benchmark: str,
    num_images_per_prompt: int = 1,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
    if not isinstance(pipeline_to_benchmark, str):
        # This only happens when no pipeline was selected in the dropdown.
        raise ValueError(
            "No pipeline selected. Please select a pipeline to benchmark."
        )
| print("Start...") | |
| print("Torch version", torch.__version__) | |
| print("Torch CUDA version", torch.version.cuda) | |
| pipeline = load_pipeline( | |
| pipeline_to_benchmark=pipeline_to_benchmark, | |
| use_channels_last=use_channels_last, | |
| do_torch_compile=do_torch_compile, | |
| ) | |
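    # Run three timed iterations: with torch.compile enabled, the first call
    # pays the one-off compilation cost, so only later runs reflect
    # steady-state speed. The value returned below comes from the last run.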
    for _ in range(3):
        prompt = 77 * "a"  # Arbitrary dummy prompt (77 echoes CLIP's 77-token context).
        num_inference_steps = 20
        call_args = dict(
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
        )

        # Non-T2I pipelines need conditioning images.
        if pipeline_to_benchmark in ["SD I2I", "SDXL I2I"]:
            image = utils.get_image_for_img_to_img(pipeline_to_benchmark)
            call_args.update({"image": image})
        elif "Inpainting" in pipeline_to_benchmark:
            image, mask_image = utils.get_image_from_inpainting(pipeline_to_benchmark)
            call_args.update({"image": image, "mask_image": mask_image})
        elif "ControlNet" in pipeline_to_benchmark:
            image = utils.get_image_for_controlnet(pipeline_to_benchmark)
            call_args.update({"image": image})
        elif "Adapters" in pipeline_to_benchmark:
            image = utils.get_image_for_adapters(pipeline_to_benchmark)
            call_args.update({"image": image})

        start_time = time.time()
        # Accessing `.images` returns fully decoded outputs, so the wall-clock
        # span covers the complete generation.
        _ = pipeline(**call_args).images
        end_time = time.time()
        print(f"For {num_inference_steps} steps", end_time - start_time)
        print("Avg per step", (end_time - start_time) / num_inference_steps)

    return (
        f"Avg per step: {((end_time - start_time) / num_inference_steps):.4f} seconds."
    )


with gr.Blocks(css="style.css") as demo:
    do_torch_compile = gr.Checkbox(label="Enable torch.compile()?")
    use_channels_last = gr.Checkbox(label="Use `channels_last` memory layout?")
    pipeline_to_benchmark = gr.Dropdown(
        list(pipeline_mapping.keys()),
        value=None,
        multiselect=False,
        label="Pipeline to benchmark",
    )
    batch_size = gr.Slider(
        label="Number of images per prompt",
        minimum=1,
        maximum=16,
        step=1,
        value=1,
    )
    btn = gr.Button("Benchmark!").style(
        margin=False,
        rounded=(False, True, True, False),
        full_width=False,
    )
    result = gr.Text(label="Result")
    # gr.Examples(
    #     examples=examples,
    #     inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
    #     outputs=result,
    #     fn=generate,
    #     cache_examples=True,
    # )
    btn.click(
        fn=generate,
        inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
        outputs=result,
    )

demo.launch(show_error=True)
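# show_error=True surfaces exceptions raised inside `generate` (e.g. a CUDA
# OOM during a benchmark) in the web UI rather than a generic error message.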