"""Gradio demo that generates images from text prompts with Stable Diffusion v1-4."""
import os

import torch
from diffusers import StableDiffusionPipeline

# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
auth_token = os.getenv("HF_TOKEN")

# Prefer the GPU; half precision is only requested when running on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else None

model_id = "CompVis/stable-diffusion-v1-4"

# The "fp16" branch holds half-precision weights intended for GPU inference.
# On CPU (where torch_dtype stays None) load the default branch instead.
revision = "fp16" if device == "cuda" else "main"

pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    # diffusers names this keyword `use_auth_token` (`token` in newer
    # releases); an unrecognized keyword such as `auth_token` is ignored,
    # so the token would never reach the Hub download.
    use_auth_token=auth_token,
    revision=revision,
    torch_dtype=torch_dtype,
).to(device)
def predict(prompt):
    """Generate an image for ``prompt`` using the module-level ``pipe``.

    Args:
        prompt: Text description forwarded unchanged to the pipeline.

    Returns:
        The first element of the pipeline result's ``images`` sequence.
    """
    result = pipe(prompt)
    generated = result.images
    return generated[0]
import gradio as gr

# The web UI has one multi-line text input, and its value is passed to
# predict. The single output is rendered as an image, and one example prompt
# is offered to users.
prompt_box = gr.Textbox(lines=2, label="Paste some text here")

gradio_ui = gr.Interface(
    fn=predict,
    inputs=[prompt_box],
    outputs=["image"],
    title="Stable Diffusion Demo",
    description="Enter a description of an image you'd like to generate!",
    examples=[["a photograph of an astronaut riding a horse"]],
)

# Start serving the interface.
gradio_ui.launch()