Spaces: Runtime error
| import spaces | |
| import gradio as gr | |
| import torch | |
| from diffusers import DiffusionPipeline | |
# Upper bound for the seed slider: the largest unsigned 32-bit integer.
MAX_SEED = 0xFFFF_FFFF
# Largest width/height, in pixels, offered by the size sliders.
MAX_IMAGE_SIZE = 1 << 11
# Load FLUX.1-schnell once at import time, in bfloat16, and place it on the GPU.
#
# NOTE: the previous version also called pipe.enable_model_cpu_offload() right
# after pipe.to("cuda"). Those are two competing device-placement strategies.
# The diffusers docs advise against moving the pipeline to CUDA before enabling
# model offload. Only the explicit GPU placement is kept here.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")
# Decode latents in slices / tiles to reduce peak VRAM use during VAE decoding.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
@spaces.GPU
def generate_image(prompt, seed, width, height, num_inference_steps):
    """Generate a single image from ``prompt`` with the FLUX.1-schnell pipeline.

    Decorated with ``spaces.GPU`` so that, on a ZeroGPU Space, a GPU is
    attached for the duration of the call. The CUDA generator and pipeline
    run below need it.

    Args:
        prompt: Text prompt describing the desired image.
        seed: RNG seed for reproducible output. Cast to ``int`` because
            Gradio sliders may deliver numeric values as floats.
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        num_inference_steps: Number of denoising steps (cast to ``int``).

    Returns:
        A ``(image, seed)`` tuple. ``image`` is the first image produced by the
        pipeline, or ``None`` if generation raised. In that case the error and
        its traceback are printed rather than propagated to the UI.
    """
    # torch.Generator.manual_seed and the pipeline's shape/step arithmetic
    # need integers; slider values such as 42.0 would otherwise break them.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    try:
        generator = torch.Generator(device="cuda").manual_seed(seed)
        image = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=0.0,  # value used for the schnell checkpoint here
        ).images[0]
    except Exception as e:
        # Best-effort: report the failure and return an empty image slot so
        # the UI still shows which seed was attempted.
        import traceback

        print(f"Error during image generation: {e}")
        traceback.print_exc()
        return None, seed
    return image, seed
# Gradio UI, laid out top to bottom:
# - a prompt box;
# - a Generate button;
# - the output image next to the seed that was used;
# - a collapsed "Advanced Settings" accordion holding the generation parameters.
with gr.Blocks() as demo:
    with gr.Row():
        prompt_box = gr.Textbox(label="Prompt")

    with gr.Row():
        generate_btn = gr.Button("Generate")

    with gr.Row():
        image_out = gr.Image(label="Generated Image")
        seed_out = gr.Number(label="Seed Used")

    with gr.Accordion("Advanced Settings", open=False):
        seed_slider = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=42,
            randomize=True,
        )
        width_slider = gr.Slider(
            label="Width",
            minimum=256,
            maximum=MAX_IMAGE_SIZE,
            step=32,
            value=1024,
        )
        height_slider = gr.Slider(
            label="Height",
            minimum=256,
            maximum=MAX_IMAGE_SIZE,
            step=32,
            value=1024,
        )
        steps_slider = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            step=1,
            value=4,
        )

    # Input order must match generate_image's positional parameters.
    generation_inputs = [prompt_box, seed_slider, width_slider, height_slider, steps_slider]
    generate_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=[image_out, seed_out],
    )

demo.launch()