1024 support
- app.py +65 -31
- src/generate.py +10 -2
app.py
CHANGED
@@ -8,11 +8,7 @@ import numpy as np
 
 from src.generate import seed_everything, generate
 
-
-
-
-# def init_pipeline():
-#     global pipe
+pipe = None
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
 )
@@ -20,12 +16,17 @@ pipe = pipe.to("cuda")
 pipe.load_lora_weights(
     "Yuanshi/OminiControl",
     weight_name=f"omini/subject_512.safetensors",
-    adapter_name="
+    adapter_name="subject_512",
+)
+pipe.load_lora_weights(
+    "Yuanshi/OminiControl",
+    weight_name=f"omini/subject_1024_beta.safetensors",
+    adapter_name="subject_1024",
 )
 
+
 @spaces.GPU
-def process_image_and_text(image, text):
-    # center crop image
+def process_image_and_text(image, resolution, text):
     w, h, min_size = image.size[0], image.size[1], min(image.size)
     image = image.crop(
         (
@@ -39,16 +40,13 @@ def process_image_and_text(image, text):
 
     condition = Condition("subject", image)
 
-    # if pipe is None:
-    #     init_pipeline()
-
     result_img = generate(
         pipe,
         prompt=text.strip(),
         conditions=[condition],
         num_inference_steps=8,
-        height=
-        width=
+        height=resolution,
+        width=resolution,
     ).images[0]
 
     return result_img
@@ -58,38 +56,74 @@ def get_samples():
     sample_list = [
         {
             "image": "assets/oranges.jpg",
+            "resolution": 512,
             "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
         },
        {
             "image": "assets/penguin.jpg",
+            "resolution": 512,
             "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
         },
         {
             "image": "assets/rc_car.jpg",
+            "resolution": 1024,
             "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
         },
         {
             "image": "assets/clock.jpg",
+            "resolution": 1024,
             "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
         },
     ]
-    return [
-
-
-
-
-
-
-
-
-
-
-
-
+    return [
+        [
+            Image.open(sample["image"]).resize((512, 512)),
+            sample["resolution"],
+            sample["text"],
+        ]
+        for sample in sample_list
+    ]
+
+
+header = """
+# 🌍 OminiControl / FLUX
+
+<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
+<a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
+<a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
+<a href="https://github.com/Yuanshi9815/OminiControl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
+</div>
+"""
+
+
+def create_app():
+    with gr.Blocks() as app:
+        gr.Markdown(header)
+        with gr.Tabs():
+            with gr.Tab("Subject-driven"):
+                gr.Interface(
+                    fn=process_image_and_text,
+                    inputs=[
+                        gr.Image(type="pil", label="Condition Image", width=300),
+                        gr.Radio(
+                            [("512", 512), ("1024(beta)", 1024)],
+                            label="Resolution",
+                            value=512,
+                        ),
+                        # gr.Slider(4, 16, 4, step=4, label="Inference Steps"),
+                        gr.Textbox(lines=2, label="Text Prompt"),
+                    ],
+                    outputs=gr.Image(type="pil"),
+                    examples=get_samples(),
+                )
+            with gr.Tab("Fill"):
+                gr.Markdown("Coming soon")
+            with gr.Tab("Canny"):
+                gr.Markdown("Coming soon")
+            with gr.Tab("Depth"):
+                gr.Markdown("Coming soon")
+    return app
+
 
 if __name__ == "__main__":
-
-    demo.launch(
-        debug=True,
-        ssr_mode=False
-    )
+    create_app().launch(debug=True, ssr_mode=False)
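Note on the pattern above: the Gradio radio only passes 512 or 1024 through as the resolution argument; the actual LoRA switch happens inside generate() (see the src/generate.py diff below). A minimal standalone sketch of that selection step, assuming a diffusers FluxPipeline ("pipe") that already has the two LoRAs loaded under the adapter names "subject_512" and "subject_1024"; select_adapter() is a hypothetical helper for illustration, not code from this repo:

    ADAPTER_BY_RESOLUTION = {512: "subject_512", 1024: "subject_1024"}

    def select_adapter(pipe, resolution: int) -> None:
        # diffusers' set_adapters() activates the named LoRA for subsequent calls
        pipe.set_adapters(ADAPTER_BY_RESOLUTION[resolution])

    # e.g. before a 1024x1024 generation: select_adapter(pipe, 1024)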
src/generate.py
CHANGED
@@ -166,7 +166,12 @@ def generate(
     use_condition = conditions is not None or []
     if use_condition:
         assert len(conditions) <= 1, "Only one condition is supported for now."
-        pipeline.set_adapters(
+        pipeline.set_adapters(
+            {
+                512: "subject_512",
+                1024: "subject_1024",
+            }[height]
+        )
         for condition in conditions:
             tokens, ids, type_id = condition.encode(self)
             condition_latents.append(tokens)  # [batch_size, token_n, token_dim]
@@ -175,7 +180,10 @@ def generate(
         condition_latents = torch.cat(condition_latents, dim=1)
         condition_ids = torch.cat(condition_ids, dim=0)
         if condition.condition_type == "subject":
-
+            delta = 32 if height == 512 else -32
+            # print(f"Condition delta: {delta}")
+            condition_ids[:, 2] += delta
+
         condition_type_ids = torch.cat(condition_type_ids, dim=0)
 
     # 5. Prepare timesteps
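For the subject condition, the position ids of the condition tokens are also shifted by a resolution-dependent offset before generation. A minimal standalone sketch of that adjustment, assuming condition_ids is an (N, 3) tensor of packed-latent position ids whose third column holds the horizontal coordinate (that column interpretation is an assumption; the diff itself only shows the +32/-32 rule):

    import torch

    def shift_condition_ids(condition_ids: torch.Tensor, height: int) -> torch.Tensor:
        # same rule as the diff above: +32 when generating at 512 px, -32 otherwise
        delta = 32 if height == 512 else -32
        shifted = condition_ids.clone()
        shifted[:, 2] += delta
        return shifted

    # e.g. shift_condition_ids(torch.zeros(16, 3), 1024)[:, 2]  -> all -32.0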