from pathlib import Path
from typing import Any, cast
import gradio as gr
import spaces
import torch
from finegrain_toolbox.flux import Model
from finegrain_toolbox.flux.prompt import prompt_with_embeds
from finegrain_toolbox.processors import product_placement
from gradio_image_annotation import image_annotator
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
# Initialize on CPU, then move everything to the GPU once loading is done (ZeroGPU).
DEVICE_CPU = torch.device("cpu")
DTYPE = torch.bfloat16
model = Model.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", device=DEVICE_CPU, dtype=DTYPE)
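# Download the LoRA weights and the matching precomputed prompt embeddings from the
# Hub. Shipping the "Add this in the box" prompt as CLIP / T5 embeddings means the
# text encoders do not need to run at inference time.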
lora_path = Path(
    hf_hub_download(
        repo_id="finegrain/finegrain-product-placement-lora",
        filename="finegrain-placement-v1-rank8.safetensors",
    )
)
prompt_path = Path(
    hf_hub_download(
        repo_id="finegrain/finegrain-product-placement-lora",
        filename="addinbox-prompt.safetensors",
    )
)
prompt_st = load_file(prompt_path, device="cpu")
prompt = prompt_with_embeds(
    text="Add this in the box",
    clip_prompt_embeds=prompt_st["clip"],
    t5_prompt_embeds=prompt_st["t5"],
)
model.transformer.load_lora_adapter(lora_path, adapter_name="placement")
model.transformer.fuse_lora()
model.transformer.unload_lora()
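# Loading and fusing are done: move the model and the prompt embeddings to CUDA
# so the @spaces.GPU calls below can use them directly.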
DEVICE = torch.device("cuda")
model = model.to(device=DEVICE, dtype=DTYPE)
prompt = prompt.to(device=DEVICE, dtype=DTYPE)
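# Enable the "Blend" button only once the user has drawn exactly one box in the
# scene and uploaded a reference image; also surface the box in the debug textbox.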
def on_change(scene: dict[str, Any] | None, reference: Image.Image | None) -> tuple[dict[str, Any], str]:
    bbox_str = ""
    if scene is not None and isinstance(boxes := scene.get("boxes"), list) and len(boxes) == 1:
        box = boxes[0]
        bbox_str = f"({box['xmin']}, {box['ymin']}, {box['xmax']}, {box['ymax']})"
    return (gr.update(interactive=reference is not None and bbox_str != ""), bbox_str)
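# ZeroGPU: a GPU is attached only while this function runs, for at most 120 seconds per call.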
@spaces.GPU(duration=120)
def _process(
    scene: dict[str, Any],
    reference: Image.Image,
    seed: int = 1234,
) -> tuple[tuple[Image.Image, Image.Image], Image.Image, Image.Image]:
    # Validate the annotator payload: one PIL image plus exactly one bounding box.
    assert isinstance(scene_image := scene["image"], Image.Image)
    assert isinstance(boxes := scene["boxes"], list)
    assert len(boxes) == 1
    assert isinstance(box := boxes[0], dict)
    bbox = tuple(box[k] for k in ["xmin", "ymin", "xmax", "ymax"])
    result = product_placement.process(
        model=model,
        scene=scene_image,
        reference=reference,
        bbox=bbox,
        prompt=prompt,
        seed=seed,
        max_short_size=1024,
        max_long_size=2048,
    )
    output = result.output
    # Resize the original scene to the output size so the before / after slider lines up.
    before_after = (scene_image.resize(output.size), output)
    return (before_after, result.reference, result.scene)
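# CPU-side wrapper: reject invalid references before requesting a ZeroGPU slot.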
def process(
    scene: dict[str, Any],
    reference: Image.Image,
    seed: int = 1234,
) -> tuple[tuple[Image.Image, Image.Image], Image.Image, Image.Image]:
    assert reference.mode == "RGBA"
    # A cutout must have some transparency: if the alpha channel's minimum is 255,
    # the image is fully opaque and cannot be blended into the scene.
    extrema = cast(tuple[tuple[int, int], ...], reference.getextrema())
    if extrema[reference.mode.index("A")][0] == 255:
        raise gr.Error("The reference must be a cutout.", duration=5)
    return _process(scene, reference, seed)
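# Programmatic usage sketch (hypothetical paths; the reference must be an RGBA cutout):
#   before_after, ref_debug, scene_debug = process(
#       {"image": Image.open("scene.jpg"), "boxes": [{"xmin": 160, "ymin": 90, "xmax": 380, "ymax": 200}]},
#       Image.open("cutout.png"),
#   )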
TITLE = """
# Finegrain Product Placement LoRA

🧪 An experiment to extend Flux Kontext with product placement capabilities.
The LoRA was trained using EditNet, our before / after image editing dataset.

Just draw a box to set where the subject should be blended, and at what size.

*The reference must be a cutout, i.e. have a transparent background.*
If you do not have a cutout available, you can create one
[here](https://huggingface.co/spaces/finegrain/finegrain-object-cutter).

[Model Card](https://huggingface.co/finegrain/finegrain-product-placement-lora) |
[Blog Post](https://blog.finegrain.ai/posts/product-placement-flux-lora-experiment/) |
[EditNet](https://finegrain.ai/editnet)

🌟 If you like this Space, follow [Finegrain](https://huggingface.co/finegrain) on Hugging Face for more cool free tools!
"""
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column():
            scene = image_annotator(
                label="Scene",
                image_type="pil",
                disable_edit_boxes=True,
                show_download_button=False,
                show_share_button=False,
                single_box=True,
                image_mode="RGB",
            )
            reference = gr.Image(
                label="Product Reference",
                visible=True,
                interactive=True,
                type="pil",
                image_mode="RGBA",
            )
            with gr.Accordion("Options", open=False):
                seed = gr.Slider(
                    minimum=0,
                    maximum=10_000,
                    value=1234,
                    step=1,
                    label="Seed",
                )
            with gr.Row():
                run_btn = gr.ClearButton(value="Blend", interactive=False)
        with gr.Column():
            output_image = gr.ImageSlider(label="Output Image", show_fullscreen_button=False)
            with gr.Accordion("Debug", open=False):
                output_textbox = gr.Textbox(label="Bounding Box", interactive=False)
                output_reference = gr.Image(
                    label="Reference",
                    visible=True,
                    interactive=False,
                    type="pil",
                    image_mode="RGB",
                )
                output_scene = gr.Image(
                    label="Scene",
                    visible=True,
                    interactive=False,
                    type="pil",
                    image_mode="RGB",
                )
    # Clear the previous output whenever the "Blend" button is pressed.
    run_btn.add(output_image)

    # Watch both inputs: the button becomes interactive only once the user has
    # drawn a box in the scene and uploaded a reference image.
    scene.change(fn=on_change, inputs=[scene, reference], outputs=[run_btn, output_textbox])
    reference.change(fn=on_change, inputs=[scene, reference], outputs=[run_btn, output_textbox])
    run_btn.click(
        fn=process,
        inputs=[scene, reference, seed],
        outputs=[output_image, output_reference, output_scene],
    )
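    # Pre-annotated examples: each entry pairs an image_annotator payload (image +
    # bounding box) with a reference cutout.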
    examples = [
        [
            {
                "image": "examples/sunglasses/scene.jpg",
                "boxes": [{"xmin": 164, "ymin": 89, "xmax": 379, "ymax": 204}],
            },
            "examples/sunglasses/reference.webp",
        ],
        [
            {
                "image": "examples/kitchen/scene.webp",
                "boxes": [{"xmin": 165, "ymin": 765, "xmax": 332, "ymax": 883}],
            },
            "examples/kitchen/reference.webp",
        ],
        [
            {
                "image": "examples/glass/scene.webp",
                "boxes": [{"xmin": 389, "ymin": 509, "xmax": 611, "ymax": 1088}],
            },
            "examples/glass/reference.webp",
        ],
        [
            {
                "image": "examples/chair/scene.webp",
                "boxes": [{"xmin": 366, "ymin": 389, "xmax": 623, "ymax": 728}],
            },
            "examples/chair/reference.webp",
        ],
        [
            {
                "image": "examples/lantern/scene.webp",
                "boxes": [{"xmin": 497, "ymin": 690, "xmax": 618, "ymax": 873}],
            },
            "examples/lantern/reference.webp",
        ],
    ]
    ex = gr.Examples(
        examples=examples,
        inputs=[scene, reference],
        outputs=[output_image, output_reference, output_scene],
        fn=process,
        cache_examples=True,
        cache_mode="eager",
    )
demo.launch(show_api=False, ssr_mode=False)