Spaces:
Running
Running
```python
import torch
from diffusers import FluxPipeline

# FLUX.1-schnell is a FLUX transformer model, not a Stable Diffusion UNet
# model, so it has to be loaded with FluxPipeline. It ships with its own
# flow-matching scheduler, which is kept as-is: a DPM-Solver scheduler is not a
# drop-in replacement here (and `Scheduler.from_config` expects a config dict,
# not a repo id, so the previous scheduler line could not have worked).
model_id = "black-forest-labs/FLUX.1-schnell"

# Prefer the Apple-silicon GPU backend (MPS) when available; otherwise use CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# bfloat16 is the dtype used in the FLUX.1-schnell model card. float16 is
# avoided: FLUX text-encoder activations can overflow in fp16, and fp16 compute
# on the CPU fallback is slow or unsupported for many ops.
# NOTE(review): bfloat16 on MPS needs a recent macOS / PyTorch (macOS 14+);
# confirm on the target machine.
dtype = torch.bfloat16

pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)

# Decode latents through the VAE one image at a time to lower peak memory.
# Attention slicing is intentionally not enabled: it only affects modules that
# expose `set_attention_slice` (UNet-style models), so it does nothing for the
# FLUX transformer.
pipe.enable_vae_slicing()

# Optional torch.compile of the denoising transformer (the hot loop), off by
# default. Compiling the whole pipeline object, as before, traces Python-level
# control flow rather than the model, and replaces `pipe` with a plain function.
# Inductor support on MPS is limited. With only four denoising steps, the
# one-off compile time usually exceeds the speed-up.
compile_transformer = False
if compile_transformer:
    pipe.transformer = torch.compile(pipe.transformer)

# Inference parameters.
prompt = "A cat holding a sign that says hello world"
height = 768   # FluxPipeline expects multiples of 16 (VAE factor 8 x 2x2 patching)
width = 1360
num_inference_steps = 4  # schnell is distilled for 1-4 step sampling

# schnell does not use classifier-free guidance (guidance_scale=0.0 per the
# model card), and its T5 prompt length is limited to 256 tokens.
image = pipe(
    prompt,
    height=height,
    width=width,
    num_inference_steps=num_inference_steps,
    guidance_scale=0.0,
    max_sequence_length=256,
).images[0]

# Save the generated image.
image.save("cat_with_sign.png")
```