Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import csv | |
| import gc | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Union | |
| import torch | |
| import torch.utils.benchmark as benchmark | |
| GITHUB_SHA = os.getenv("GITHUB_SHA", None) | |
| BENCHMARK_FIELDS = [ | |
| "pipeline_cls", | |
| "ckpt_id", | |
| "batch_size", | |
| "num_inference_steps", | |
| "model_cpu_offload", | |
| "run_compile", | |
| "time (secs)", | |
| "memory (gbs)", | |
| "actual_gpu_memory (gbs)", | |
| "github_sha", | |
| ] | |
| PROMPT = "ghibli style, a fantasy landscape with castles" | |
| BASE_PATH = os.getenv("BASE_PATH", ".") | |
| TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3))) | |
| REPO_ID = "diffusers/benchmarks" | |
| FINAL_CSV_FILE = "collated_results.csv" | |
| class BenchmarkInfo: | |
| time: float | |
| memory: float | |
| def flush(): | |
| """Wipes off memory.""" | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| torch.cuda.reset_max_memory_allocated() | |
| torch.cuda.reset_peak_memory_stats() | |
| def bytes_to_giga_bytes(bytes): | |
| return f"{(bytes / 1024 / 1024 / 1024):.3f}" | |
| def benchmark_fn(f, *args, **kwargs): | |
| t0 = benchmark.Timer( | |
| stmt="f(*args, **kwargs)", | |
| globals={"args": args, "kwargs": kwargs, "f": f}, | |
| num_threads=torch.get_num_threads(), | |
| ) | |
| return f"{(t0.blocked_autorange().mean):.3f}" | |
| def generate_csv_dict( | |
| pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo | |
| ) -> Dict[str, Union[str, bool, float]]: | |
| """Packs benchmarking data into a dictionary for latter serialization.""" | |
| data_dict = { | |
| "pipeline_cls": pipeline_cls, | |
| "ckpt_id": ckpt, | |
| "batch_size": args.batch_size, | |
| "num_inference_steps": args.num_inference_steps, | |
| "model_cpu_offload": args.model_cpu_offload, | |
| "run_compile": args.run_compile, | |
| "time (secs)": benchmark_info.time, | |
| "memory (gbs)": benchmark_info.memory, | |
| "actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}", | |
| "github_sha": GITHUB_SHA, | |
| } | |
| return data_dict | |
| def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]): | |
| """Serializes a dictionary into a CSV file.""" | |
| with open(file_name, mode="w", newline="") as csvfile: | |
| writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS) | |
| writer.writeheader() | |
| writer.writerow(data_dict) | |
| def collate_csv(input_files: List[str], output_file: str): | |
| """Collates multiple identically structured CSVs into a single CSV file.""" | |
| with open(output_file, mode="w", newline="") as outfile: | |
| writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS) | |
| writer.writeheader() | |
| for file in input_files: | |
| with open(file, mode="r") as infile: | |
| reader = csv.DictReader(infile) | |
| for row in reader: | |
| writer.writerow(row) | |