Spaces:
Runtime error
Runtime error
File size: 3,920 Bytes
755b512 1a974b2 755b512 6c5d32a 755b512 8bac254 6c5d32a 755b512 03d82a1 22f2658 755b512 44895ee 755b512 44895ee 755b512 44895ee 755b512 03d82a1 755b512 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import spaces  # keep first so it is loaded before torch/diffusers (ZeroGPU)
import gradio as gr
import torch
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Build the base FLUX.1-dev text-to-image pipeline from its safetensors
# checkpoint in bfloat16, then move all of its components onto the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    use_safetensors=True,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")
# Replace the base transformer weights with Tencent's SRPO fine-tuned weights.
# `hf_hub_download` returns the local path of the cached checkpoint file.
srpo_path = hf_hub_download(
    repo_id="tencent/SRPO",
    filename="diffusion_pytorch_model.safetensors",
)
state_dict = load_file(srpo_path)
# `load_state_dict` copies the tensors into the transformer's existing
# parameters (strict by default, so missing/unexpected keys raise). After that
# the loaded dict is a redundant full copy of the weights; drop the module-level
# reference so it can be freed instead of staying resident for the app's life.
pipe.transformer.load_state_dict(state_dict)
del state_dict
def _resolve_seed(seed):
    """Return a concrete int seed; -1 or None (empty seed field) draws a random one."""
    if seed is None or int(seed) == -1:
        # Random seed in [0, 2**32), matching the original range.
        return int(torch.randint(0, 2**32, (1,)).item())
    return int(seed)


# Runs on a ZeroGPU allocation; duration=120 is the per-call time budget requested.
@spaces.GPU(duration=120)
def generate_image(
    prompt,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    seed=-1,
):
    """Generate one image for *prompt* with the SRPO-weighted FLUX pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        width: Output width in pixels (coerced to int).
        height: Output height in pixels (coerced to int).
        guidance_scale: Guidance strength passed to the pipeline (coerced to float).
        num_inference_steps: Number of denoising steps (coerced to int).
        seed: RNG seed. ``-1`` or ``None`` (e.g. a cleared seed field in the UI)
            selects a random seed.

    Returns:
        Tuple ``(image, seed)``: the first entry of the pipeline output's
        ``images`` and the integer seed actually used, so the UI can show it
        for reproduction.
    """
    seed = _resolve_seed(seed)
    # UI widgets may deliver floats or None; the pipeline and manual_seed need
    # plain Python numbers of the right type.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe(
        prompt=prompt,
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    return image, seed
# Prompts offered as one-click examples; each becomes a one-field example row.
_EXAMPLE_PROMPTS = [
    "The Death of Ophelia by John Everett Millais, Pre-Raphaelite painting, Ophelia floating in a river surrounded by flowers, detailed natural elements, melancholic and tragic atmosphere",
    "A serene Japanese garden with cherry blossoms, koi pond, traditional wooden bridge, soft morning light, photorealistic",
    "Cyberpunk cityscape at night, neon lights, flying cars, rain-slicked streets, blade runner aesthetic, highly detailed",
    "Portrait of a majestic lion in golden hour light, detailed fur texture, intense gaze, African savanna background",
    "Abstract colorful explosion of paint in water, high speed photography, vibrant colors mixing, dramatic lighting",
]

_THEME = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
)

# Page layout: header, result image, prompt box + button, collapsible
# generation settings, the seed actually used, examples, and event wiring.
with gr.Blocks(title="FLUX SRPO Text-to-Image", theme=_THEME) as demo:
    gr.Markdown("# Flux SRPO")
    gr.Markdown("Generate images using FLUX model enhanced with Tencent's SRPO technique")
    gr.Markdown("Built with [AnyCoder](https://huggingface.co/spaces/akhaliq/anycoder)")

    output_image = gr.Image(label="Generated Image", type="pil")
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Describe the image you want to generate...",
        lines=3,
    )
    generate_btn = gr.Button("Generate Image", variant="primary", size="lg")

    with gr.Accordion("Advanced Settings", open=False):
        # Output resolution, in multiples of 64 pixels.
        with gr.Row():
            width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Width")
            height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Height")
        # Sampling controls.
        with gr.Row():
            guidance_scale = gr.Slider(
                minimum=1.0, maximum=20.0, value=3.5, step=0.5, label="Guidance Scale"
            )
            num_inference_steps = gr.Slider(
                minimum=10, maximum=100, value=50, step=5, label="Inference Steps"
            )
        seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)

    # Echoes back the seed generate_image actually used.
    used_seed = gr.Number(label="Seed Used", precision=0)

    gr.Examples(
        examples=[[text] for text in _EXAMPLE_PROMPTS],
        inputs=prompt,
        label="Example Prompts",
    )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed],
        outputs=[output_image, used_seed],
    )
# Start the Gradio server only when run as a script (as Spaces does with
# `python app.py`), so importing this module does not launch the app.
if __name__ == "__main__":
    demo.launch()