|
|
import spaces |
|
|
import torch |
|
|
from diffusers import BriaPipeline |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
# Hugging Face Hub repository id of the checkpoint served by this app.
MODEL_ID = "briaai/BRIA-3.2"

# Load the pipeline once at import time in bfloat16 and place it on the GPU;
# every request below reuses this single instance. DiffusionPipeline.to()
# returns the pipeline itself, so the call can be chained.
pipe = BriaPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
).to("cuda")
|
|
|
|
|
@spaces.GPU(duration=1500)
def compile_transformer():
    """Ahead-of-time compile ``pipe.transformer`` on a ZeroGPU worker.

    A throwaway generation runs inside ``spaces.aoti_capture`` so that the
    captured call object exposes the ``args``/``kwargs`` the transformer was
    invoked with. Those inputs are fed to ``torch.export.export``, and the
    resulting exported program is handed to ``spaces.aoti_compile``.

    Returns:
        The value produced by ``spaces.aoti_compile`` for the exported
        transformer.
    """
    with spaces.aoti_capture(pipe.transformer) as captured:
        pipe("arbitrary example prompt")

    program = torch.export.export(
        pipe.transformer,
        args=captured.args,
        kwargs=captured.kwargs,
    )
    return spaces.aoti_compile(program)


# Run the compilation once at startup, then attach the compiled artifact to
# the live transformer via spaces.aoti_apply (presumably so subsequent pipeline
# calls go through the compiled graph; see the ZeroGPU AoT docs).
compiled_transformer = compile_transformer()
spaces.aoti_apply(compiled_transformer, pipe.transformer)
|
|
|
|
|
@spaces.GPU(duration=60)
def generate_image(prompt, seed=0):
    """Generate one image for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        seed: RNG seed for the generation. ``0`` or ``None`` (e.g. a cleared
            number field) draws a fresh random seed, as the UI label
            "Seed (0 for random)" promises; any other value is used as-is so
            the result is reproducible.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    import random

    # 0 previously seeded torch with a fixed value, so "random" always gave
    # the same image; None made torch.manual_seed raise.
    if seed is None or int(seed) == 0:
        seed = random.randint(1, 2**31 - 1)

    # A per-call generator keeps the seed local to this request instead of
    # mutating torch's process-global RNG state.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    image = pipe(prompt, generator=generator).images[0]

    return image
|
|
|
|
|
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("# BRIA-3.2 Text-to-Image Generator")
    gr.Markdown("Generate images from text prompts using the BRIA-3.2 model.")

    # Inputs side by side: the text prompt and an integer (precision=0) seed.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="a cat sitting on a chair",
            interactive=True,
        )
        seed = gr.Number(label="Seed (0 for random)", value=0, precision=0)

    generate_btn = gr.Button("Generate Image")
    output = gr.Image(label="Generated Image", type="pil")

    # Button click -> generate_image(prompt, seed) -> output image.
    generate_btn.click(generate_image, [prompt, seed], output)

    # One-click example prompts. Only the prompt is filled in, so
    # generate_image runs with its default seed. Results are cached with
    # Gradio's "lazy" mode.
    gr.Examples(
        examples=[
            [example_prompt]
            for example_prompt in (
                "a futuristic cityscape at sunset",
                "a forest with glowing mushrooms",
                "a steampunk robot drinking tea",
                "an astronaut riding a horse on mars",
            )
        ],
        inputs=[prompt],
        outputs=output,
        fn=generate_image,
        cache_examples="lazy",
    )

demo.launch()