Spaces: Running on Zero
| import glob | |
| import subprocess | |
| import sys | |
| from typing import List | |
| sys.path.append(".") | |
| from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402 | |
# Glob pattern selecting the benchmark scripts to run. It is relative, so it is
# matched against the current working directory when passed to `glob.glob`.
PATTERN = "benchmark_*.py"
class SubprocessCallException(Exception):
    """Raised by `run_command` when the executed command exits with a non-zero status."""

    pass


# Taken from `test_examples_utils.py`
def run_command(command: List[str], return_stdout=False):
    """
    Run `command` with `subprocess.check_output` and optionally return its output.

    The command is executed directly (no shell). stderr is merged into stdout, so
    both the returned text and the failure message contain both streams.

    Args:
        command: Argument vector to execute.
        return_stdout: When True, return the combined output decoded as UTF-8.
            Otherwise return None.

    Raises:
        SubprocessCallException: If the command exits with a non-zero status. The
            message contains the command line and the captured output; the
            original `CalledProcessError` is chained as the cause.
    """
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        # Decode leniently: a stray non-UTF-8 byte in the child's output must not
        # turn into a UnicodeDecodeError that hides the real command failure.
        captured = e.output.decode("utf-8", errors="replace") if e.output else ""
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{captured}"
        ) from e
    if return_stdout:
        if hasattr(output, "decode"):
            output = output.decode("utf-8")
        return output
def _run_with_and_without_compile(args: List[str]) -> None:
    """Run one benchmark invocation as given, then again with `--run_compile` appended."""
    run_command(args)
    run_command([*args, "--run_compile"])


def main():
    """
    Run every script matching `PATTERN` in the current working directory.

    Two passes are made, and every invocation runs twice (as given, then with
    `--run_compile`):

    1. Canonical settings: each script except `benchmark_text_to_image.py` and
       `benchmark_ip_adapters.py` is run with no extra arguments.
    2. Variants: selected scripts are re-run with specific `--ckpt` values
       (and, for turbo checkpoints, a reduced `--num_inference_steps`).

    Commands run under the current interpreter (`sys.executable`) rather than
    whatever `python` is first on PATH. Any failing command raises
    `SubprocessCallException` (from `run_command`) and stops the run.
    """
    # Sorted so the run order is reproducible; glob's order is filesystem-dependent.
    python_files = sorted(glob.glob(PATTERN))

    for file in python_files:
        print(f"****** Running file: {file} ******")
        # Run with canonical settings.
        if file not in ("benchmark_text_to_image.py", "benchmark_ip_adapters.py"):
            _run_with_and_without_compile([sys.executable, file])

    # Run variants.
    for file in python_files:
        # See: https://github.com/pytorch/pytorch/issues/129637
        # Together with the canonical-pass exclusion above, this means the
        # IP-Adapter benchmark is never executed by this runner.
        if file == "benchmark_ip_adapters.py":
            continue

        if file == "benchmark_text_to_image.py":
            for ckpt in ALL_T2I_CKPTS:
                args = [sys.executable, file, "--ckpt", ckpt]
                if "turbo" in ckpt:
                    args += ["--num_inference_steps", "1"]
                _run_with_and_without_compile(args)
        elif file == "benchmark_sd_img.py":
            for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
                args = [sys.executable, file, "--ckpt", ckpt]
                if ckpt == "stabilityai/sdxl-turbo":
                    args += ["--num_inference_steps", "2"]
                _run_with_and_without_compile(args)
        elif file == "benchmark_sd_inpainting.py":
            sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
            _run_with_and_without_compile([sys.executable, file, "--ckpt", sdxl_ckpt])
        elif file in ("benchmark_controlnet.py", "benchmark_t2i_adapter.py"):
            sdxl_ckpt = (
                "diffusers/controlnet-canny-sdxl-1.0"
                if "controlnet" in file
                else "TencentARC/t2i-adapter-canny-sdxl-1.0"
            )
            _run_with_and_without_compile([sys.executable, file, "--ckpt", sdxl_ckpt])


if __name__ == "__main__":
    main()