ovi054's picture
Update app.py
0a2963b verified
# PyTorch 2.8 (temporary hack)
import os
import subprocess
import sys

# Install PyTorch 2.8 (nightly cu126 index) and `spaces` into the environment of
# the interpreter that is running this file. `sys.executable -m pip` guarantees the
# packages land where the `import torch` below will look (a bare `pip` from PATH may
# belong to a different Python). Passing an argument list avoids the shell entirely,
# so "torch<2.9" needs no quoting to keep `<` from being treated as a redirect.
# check=False mirrors the original os.system call: an install failure is ignored and
# the app continues with whatever torch is already installed.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "--upgrade", "--pre",
        "--extra-index-url", "https://download.pytorch.org/whl/nightly/cu126",
        "torch<2.9", "spaces",
    ],
    check=False,
)
# Actual demo code
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image, ImageOps
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
# from optimization import optimize_pipeline_
# Upper bound for seeds: the largest signed 32-bit integer (2**31 - 1); used by the
# seed slider and by random seed generation in infer().
MAX_SEED = np.iinfo(np.int32).max
# Load FLUX.1 Kontext [dev] in bfloat16 and move it to the GPU once, at import time,
# so every request reuses the same pipeline object.
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
# Attach the virtual try-on LoRA, then merge it into the base weights (order matters:
# the LoRA must be loaded before it can be fused).
pipe.load_lora_weights("ovi054/virtual-tryon-kontext-lora")
pipe.fuse_lora()
# Optional ahead-of-time optimization step, currently disabled (its import is commented out as well).
# optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
EXAMPLES_DIR = "examples"


def list_example_files(subdir, root=EXAMPLES_DIR):
    """Return the sorted paths of the example files in ``root/subdir``.

    Hidden entries (names starting with ``.``, such as ``.gitkeep`` or
    ``.DS_Store``) and sub-directories are skipped so they are never offered
    to the UI as example images. A missing directory yields an empty list
    instead of crashing the app at import time.

    Args:
        subdir (str): Sub-folder of ``root`` to scan (e.g. ``"base"``).
        root (str, optional): Examples root directory. Defaults to ``EXAMPLES_DIR``.

    Returns:
        list[str]: Paths joined as ``root/subdir/name``, sorted by file name.
    """
    folder = os.path.join(root, subdir)
    try:
        names = sorted(os.listdir(folder))
    except FileNotFoundError:
        return []
    candidates = (os.path.join(folder, name) for name in names if not name.startswith("."))
    return [path for path in candidates if os.path.isfile(path)]


BASE_EXAMPLES = list_example_files("base")
# Only consumed by the face-photo step of the UI, which is currently disabled.
FACE_EXAMPLES = list_example_files("face")
# def add_overlay(base_img, overlay_img, margin=20):
# """
# Pastes an overlay image onto the top-right corner of a base image.
# The overlay is resized to be 1/5th of the width of the base image,
# maintaining its aspect ratio.
# Args:
# base_img (PIL.Image.Image): The main image.
# overlay_img (PIL.Image.Image): The image to place on top.
# margin (int, optional): The pixel margin from the top and right edges. Defaults to 20.
# Returns:
# PIL.Image.Image: The combined image.
# """
# if base_img is None or overlay_img is None:
# return base_img
# base = base_img.convert("RGBA")
# overlay = overlay_img.convert("RGBA")
# # --- MODIFICATION ---
# # Calculate the target width to be 1/5th of the base image's width
# target_width = base.width // 5
# # Keep aspect ratio, resize overlay to the newly calculated target width
# w, h = overlay.size
# # Add a check to prevent division by zero if the overlay image has no width
# if w == 0:
# return base
# new_height = int(h * (target_width / w))
# overlay = overlay.resize((target_width, new_height), Image.LANCZOS)
# # Position: top-right corner with a margin
# x = base.width - overlay.width - margin
# y = margin
# # Paste the resized overlay onto the base image using its alpha channel for transparency
# base.paste(overlay, (x, y), overlay)
# return base
@spaces.GPU(duration=45)
def infer(input_image_upload, prompt="wear it", seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
    """
    Perform virtual try-on / image editing using the FLUX.1 Kontext pipeline.

    This function takes an input image and a text prompt and generates a modified
    version of the image with the FLUX.1 Kontext model (with the virtual try-on LoRA
    fused in). If no input image is given, the image is generated from the prompt alone.

    Args:
        input_image_upload (PIL.Image.Image): The input from the gr.Image upload component (the model photo with the garment overlaid); may be None.
        prompt (str): Text description of the desired edit to apply to the image.
        seed (int, optional): Random seed for reproducible generation.
        randomize_seed (bool, optional): If True, ignores `seed` and draws a random one.
        guidance_scale (float, optional): Controls how closely the model follows the prompt.
        steps (int, optional): Number of diffusion steps to run.
        progress (gr.Progress, optional): Gradio progress tracker.

    Returns:
        tuple: A 3-tuple containing the result image, the seed that was used, and a gr.Button update that keeps the reuse button hidden.
    """
    # Values from sliders / API (MCP) clients may arrive as floats; torch's
    # manual_seed and the step count both need integers.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    pipe_kwargs = {
        "prompt": prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": int(steps),
        "generator": torch.Generator().manual_seed(seed),
    }
    if input_image_upload is not None:
        # Edit mode: condition on the uploaded image and request an output of the same size.
        source = input_image_upload.convert("RGB")
        pipe_kwargs.update(image=source, width=source.width, height=source.height)
    # Without an image no width/height are passed, so the pipeline's defaults apply.

    image = pipe(**pipe_kwargs).images[0]
    return image, seed, gr.Button(visible=False)
@spaces.GPU
def infer_example(input_image, prompt):
    """Run `infer` on one example row and return just (image, seed).

    All other generation settings use `infer`'s defaults; the reuse-button
    update that `infer` also returns is discarded.
    """
    result_image, used_seed, _button_update = infer(input_image, prompt)
    return result_image, used_seed
# Extra CSS passed to gr.Blocks; intentionally empty, so no custom styles are applied.
css = ""
# Two-step UI: (1) upload/select the combined model + garment image, (2) run the try-on.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX.1 Kontext [dev] + [Virtual Try-On LoRA](https://huggingface.co/ovi054/virtual-tryon-kontext-lora)\n")
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    "Step 1. Select/Upload the combined model and garment image ⬇️<br>\n"
                    "Place the garment onto the model image as an overlay using "
                    "[this tool](https://v0-image-editor-app-eight.vercel.app/).\n"
                )
                with gr.Row():
                    input_image_upload = gr.Image(label="Upload Image", type="pil")
                gr.Examples(
                    examples=[[img] for img in BASE_EXAMPLES],
                    inputs=[input_image_upload],
                )
            # NOTE: a "Select/Upload a face photo" step (fed by FACE_EXAMPLES) used to
            # sit here; it is disabled until the face-overlay feature is re-enabled.
            with gr.Column():
                gr.Markdown("Step 2. Press “Run” to get results ⬇️")
                with gr.Row():
                    run_button = gr.Button("Run")
                with gr.Accordion("Advanced Settings", open=False):
                    prompt = gr.Text(
                        label="Prompt",
                        max_lines=1,
                        value="wear it",
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1,
                    )
                result = gr.Image(label="Result", show_label=False, interactive=False)
                # Kept hidden: infer() always returns gr.Button(visible=False) for it and
                # no click handler is wired up yet.
                reuse_button = gr.Button("Reuse this image", visible=False)

    # Run on button click or when Enter is pressed in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[input_image_upload, prompt, seed, randomize_seed, guidance_scale, steps],
        outputs=[result, seed, reuse_button],
    )

# mcp_server=True also exposes `infer` as an MCP tool (described by its docstring).
demo.launch(mcp_server=True)