File size: 3,920 Bytes
755b512
 
 
 
 
 
 
 
1a974b2
755b512
 
6c5d32a
755b512
8bac254
 
 
 
 
 
 
 
6c5d32a
755b512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03d82a1
 
 
22f2658
755b512
44895ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755b512
44895ee
 
 
 
 
 
755b512
 
44895ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755b512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03d82a1
755b512
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import torch
import spaces
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Load the base FLUX.1-dev pipeline in bfloat16 and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    torch_dtype=torch.bfloat16,
    use_safetensors=True
).to('cuda')

# Load SRPO weights: download the SRPO transformer checkpoint from the
# tencent/SRPO Hub repo and load it into the pipeline's transformer,
# replacing the base FLUX.1-dev transformer weights.
srpo_path = hf_hub_download(
    repo_id="tencent/SRPO",
    filename="diffusion_pytorch_model.safetensors"
)
state_dict = load_file(srpo_path)
pipe.transformer.load_state_dict(state_dict)
# load_state_dict copies the tensors into the transformer's existing
# parameters, so the loaded dict is redundant from here on. Drop the
# module-level reference so the whole checkpoint is not kept alive in host
# memory for the lifetime of the process.
del state_dict

@spaces.GPU(duration=120)
def generate_image(
    prompt,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    seed=-1
):
    """Generate one image from a text prompt with the SRPO-weighted FLUX pipeline.

    Args:
        prompt: Text description of the image to generate.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        seed: RNG seed. ``-1`` or ``None`` (what a cleared Gradio Number field
            sends) selects a random seed in ``[0, 2**32)``.

    Returns:
        A ``(image, seed)`` tuple: the first image produced by the pipeline and
        the integer seed actually used, so the result can be reproduced.
    """
    # Treat an empty seed field like -1 instead of passing None on to
    # manual_seed, which would raise.
    if seed is None or seed == -1:
        seed = torch.randint(0, 2**32, (1,)).item()
    # Generator.manual_seed only accepts integers; UI components may hand
    # over floats (e.g. 42.0), so normalize here. This is also the value
    # reported back to the UI as "Seed Used".
    seed = int(seed)

    generator = torch.Generator(device='cuda').manual_seed(seed)

    # Sizes and step count are cast to int for the same reason: slider values
    # may arrive as floats, while the pipeline expects integer counts.
    image = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator
    ).images[0]

    return image, seed

# Build the Gradio UI: output image on top, prompt box and Generate button,
# an accordion of advanced generation settings, and example prompts.
with gr.Blocks(title="FLUX SRPO Text-to-Image", theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray", neutral_hue="slate")) as demo:
    gr.Markdown("# Flux SRPO")
    gr.Markdown("Generate images using FLUX model enhanced with Tencent's SRPO technique")
    gr.Markdown("Built with [AnyCoder](https://huggingface.co/spaces/akhaliq/anycoder)")

    output_image = gr.Image(label="Generated Image", type="pil")

    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Describe the image you want to generate...",
        lines=3
    )

    generate_btn = gr.Button("Generate Image", variant="primary", size="lg")

    with gr.Accordion("Advanced Settings", open=False):
        # Output resolution; defaults match generate_image's defaults.
        with gr.Row():
            width = gr.Slider(
                minimum=256,
                maximum=2048,
                value=1024,
                step=64,
                label="Width"
            )
            height = gr.Slider(
                minimum=256,
                maximum=2048,
                value=1024,
                step=64,
                label="Height"
            )

        with gr.Row():
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=20.0,
                value=3.5,
                step=0.5,
                label="Guidance Scale"
            )
            num_inference_steps = gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=5,
                label="Inference Steps"
            )

        # -1 asks generate_image to pick a random seed.
        seed = gr.Number(
            label="Seed (-1 for random)",
            value=-1,
            precision=0
        )

        # Filled with the seed generate_image actually used, so a random
        # result can be reproduced by copying it into the Seed field.
        used_seed = gr.Number(label="Seed Used", precision=0)

    # Clicking an example only fills the prompt box; it does not run generation.
    gr.Examples(
        examples=[
            ["The Death of Ophelia by John Everett Millais, Pre-Raphaelite painting, Ophelia floating in a river surrounded by flowers, detailed natural elements, melancholic and tragic atmosphere"],
            ["A serene Japanese garden with cherry blossoms, koi pond, traditional wooden bridge, soft morning light, photorealistic"],
            ["Cyberpunk cityscape at night, neon lights, flying cars, rain-slicked streets, blade runner aesthetic, highly detailed"],
            ["Portrait of a majestic lion in golden hour light, detailed fur texture, intense gaze, African savanna background"],
            ["Abstract colorful explosion of paint in water, high speed photography, vibrant colors mixing, dramatic lighting"],
        ],
        inputs=prompt,
        label="Example Prompts"
    )

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed],
        outputs=[output_image, used_seed]
    )

# Launch only when run as a script (`python app.py`), so importing the module,
# e.g. from Gradio's reload mode or a test, does not start a server.
if __name__ == "__main__":
    demo.launch()