Spaces:
Runtime error
Runtime error
| import spaces | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from diffusers import DiffusionPipeline | |
# ---- Constants --------------------------------------------------------------
MAX_SEED = (1 << 32) - 1  # largest seed offered by the UI (max unsigned 32-bit value)
MAX_IMAGE_SIZE = 2048  # upper bound for the width/height sliders, in pixels
dtype = torch.bfloat16  # dtype the pipeline weights are loaded in

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ---- Load FLUX model ----------------------------------------------------------
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
if device == "cuda":
    # Model CPU offload moves each sub-model onto the GPU only while it runs.
    # Do NOT call pipe.to("cuda") first: that materialises the whole pipeline
    # on the GPU (risking OOM and defeating the savings) before offload moves
    # it back to the CPU.
    pipe.enable_model_cpu_offload()
else:
    # Model CPU offload requires an accelerator; on CPU-only hosts keep the
    # pipeline where it is instead of failing at startup.
    pipe.to(device)
# Slicing decodes batched latents one image at a time and tiling decodes large
# images in tiles; both reduce the VAE's peak memory during decoding.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
def print_model_shapes(pipe):
    """Print a short debug summary of a few pipeline components to stdout.

    The VAE encoder/decoder lines show the modules' string form (not a
    shape); the remaining two lines show weight shapes. Expects ``pipe`` to
    expose ``vae.encoder``, ``vae.decoder``, ``transformer.x_embedder.weight``
    and ``transformer.transformer_blocks[0].attn.to_q.weight``.
    """
    print("Model component shapes:")
    # Getters are lazy so each attribute is resolved right before its own
    # line is printed, exactly like a straight sequence of print calls.
    entries = (
        ("VAE Encoder", lambda: pipe.vae.encoder),
        ("VAE Decoder", lambda: pipe.vae.decoder),
        ("x_embedder shape", lambda: pipe.transformer.x_embedder.weight.shape),
        (
            "First transformer block shape",
            lambda: pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape,
        ),
    )
    for label, fetch in entries:
        print(f"{label}: {fetch()}")
# Print a one-off summary of the loaded pipeline's components to stdout at import time.
print_model_shapes(pipe)
_img2img_pipe = None  # image-to-image pipeline cached by _get_img2img_pipe()


def _get_img2img_pipe():
    """Return an image-to-image pipeline that shares ``pipe``'s loaded components.

    ``DiffusionPipeline.from_pretrained`` resolves FLUX.1-schnell to a
    text-to-image pipeline whose ``__call__`` has no ``image`` parameter, so
    img2img requests need a dedicated pipeline. ``from_pipe`` reuses the
    already-loaded modules instead of loading weights again. The pipeline is
    built on first use and cached afterwards.
    """
    global _img2img_pipe
    if _img2img_pipe is None:
        # Local import: if the installed diffusers lacks Flux img2img support,
        # only this path fails, and infer() reports it like any other error.
        from diffusers import AutoPipelineForImage2Image

        _img2img_pipe = AutoPipelineForImage2Image.from_pipe(pipe)
    return _img2img_pipe


# On ZeroGPU hardware a GPU is attached only inside functions decorated with
# @spaces.GPU; on other hardware the decorator is documented as a no-op.
@spaces.GPU
def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
    """Generate one image from ``prompt``, optionally starting from ``init_image``.

    Args:
        prompt: Text prompt.
        init_image: Optional PIL image. When given, it is converted to RGB,
            resized to ``(width, height)`` and used as the img2img input.
        seed: RNG seed. ``None`` draws a random seed in ``[0, MAX_SEED]`` so the
            seed reported back to the UI always reproduces the image.
        width, height: Output size in pixels (coerced to ``int``).
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        ``(image, seed)`` where ``seed`` is the integer seed actually used.
        On any exception the error and traceback are printed and a solid red
        placeholder of the requested size is returned instead, so the UI
        callback never raises.
    """
    import random
    import traceback

    # Slider values may arrive as floats; PIL sizes and step counts need ints.
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)
    # Always seed explicitly so the seed shown in the UI reproduces the result.
    seed = random.randint(0, MAX_SEED) if seed is None else int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        if init_image is None:
            print("Running text-to-image generation")
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        else:
            print("Running image-to-image generation")
            init_image = init_image.convert("RGB").resize((width, height))
            image = _get_img2img_pipe()(
                prompt=prompt,
                image=init_image,
                # Pass the size explicitly rather than relying on the
                # pipeline's default resolution.
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=guidance_scale,
            ).images[0]
        return image, seed
    except Exception as e:  # UI boundary: report and return a placeholder instead of raising
        if isinstance(e, RuntimeError):
            if "mat1 and mat2 shapes cannot be multiplied" in str(e):
                print("Matrix multiplication error detected. Tensor shapes:")
                print(e)
            else:
                print(f"RuntimeError during inference: {e}")
        else:
            print(f"Unexpected error during inference: {e}")
        traceback.print_exc()
        # A solid red image of the requested size signals failure in the UI.
        return Image.new("RGB", (width, height), (255, 0, 0)), seed
# ---- Gradio interface -------------------------------------------------------
# Layout: prompt plus optional source image, a Generate button, the result
# next to the seed reported by infer(), and a collapsed advanced-settings panel.
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        init_image = gr.Image(type="pil", label="Initial Image (optional)")

    with gr.Row():
        generate = gr.Button("Generate")

    with gr.Row():
        result = gr.Image(label="Result")
        seed_output = gr.Number(label="Seed")

    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=None, label="Seed")
        width = gr.Slider(minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, label="Width")
        height = gr.Slider(minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, label="Height")
        num_inference_steps = gr.Slider(
            minimum=1, maximum=50, step=1, value=4, label="Number of inference steps"
        )
        guidance_scale = gr.Slider(minimum=0, maximum=20, step=0.1, value=0.0, label="Guidance scale")

    # The inputs list is positional: its order must match infer()'s parameters.
    generate.click(
        fn=infer,
        inputs=[prompt, init_image, seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed_output],
    )

demo.launch()