Draw2Photo / app.py
ovi054's picture
Update app.py
0cd9500 verified
# PyTorch 2.8 (temporary hack)
import os
import subprocess
import sys

# Upgrade torch (nightly cu126 index, pinned below 2.9) and `spaces` before they
# are imported further down. pip is invoked through the running interpreter so
# the packages land in the same environment, and an argument list is passed so
# no shell is involved. check=False keeps the original best-effort behaviour:
# a failed install does not abort start-up.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "--upgrade", "--pre",
        "--extra-index-url", "https://download.pytorch.org/whl/nightly/cu126",
        "torch<2.9", "spaces",
    ],
    check=False,
)
# Actual demo code
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image, ImageOps
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
# from optimization import optimize_pipeline_
# Largest seed value (int32 max); bounds random seeds and the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Load FLUX.1 Kontext [dev] in bfloat16 and move it to the GPU at import time.
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
# Attach the Draw2Photo LoRA, then fuse it into the pipeline (diffusers `fuse_lora`).
pipe.load_lora_weights("ovi054/Draw2Photo-Kontext-LoRA")
pipe.fuse_lora()
# Disabled: optional ahead-of-time optimization (its import above is also commented out).
# optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
import os

# Root folder holding the example drawings ("base") and face photos ("face").
EXAMPLES_DIR = "examples"


def _example_paths(subdir):
    """Return the paths of every entry in EXAMPLES_DIR/<subdir>, sorted by name."""
    folder = os.path.join(EXAMPLES_DIR, subdir)
    names = sorted(os.listdir(folder))
    return [os.path.join(folder, name) for name in names]


BASE_EXAMPLES = _example_paths("base")
FACE_EXAMPLES = _example_paths("face")
def add_overlay(base_img, overlay_img, margin=20):
    """
    Paste an overlay image onto the top-right corner of a base image.

    The overlay is resized to 1/5th of the base image's width, keeping its
    aspect ratio, and composited using its own alpha channel.

    Args:
        base_img (PIL.Image.Image): The main image.
        overlay_img (PIL.Image.Image): The image to place on top.
        margin (int, optional): The pixel margin from the top and right edges. Defaults to 20.

    Returns:
        PIL.Image.Image: The combined image in RGBA mode. If either input is
        None, ``base_img`` is returned unchanged. If no visible overlay can be
        produced (the overlay has a zero dimension, or the base is narrower
        than 5 px so the target width rounds to 0), the RGBA-converted base is
        returned without an overlay.
    """
    if base_img is None or overlay_img is None:
        return base_img

    base = base_img.convert("RGBA")
    overlay = overlay_img.convert("RGBA")

    # Target width is 1/5th of the base image's width.
    target_width = base.width // 5
    w, h = overlay.size

    # Degenerate sizes: avoid division by zero and a zero-sized resize, which
    # PIL rejects. There is nothing visible to paste in these cases.
    if w <= 0 or h <= 0 or target_width <= 0:
        return base

    # Keep the aspect ratio. Clamp to >= 1 px so a very wide overlay cannot
    # round down to a zero-height resize.
    new_height = max(1, int(h * (target_width / w)))
    overlay = overlay.resize((target_width, new_height), Image.LANCZOS)

    # Top-right corner, inset by `margin` from the top and right edges.
    x = base.width - overlay.width - margin
    y = margin

    # The third argument uses the overlay's alpha channel as the paste mask.
    base.paste(overlay, (x, y), overlay)
    return base
def _resolve_input_image(input_image, input_image_upload):
    """
    Choose the image to edit from the two UI inputs.

    The uploaded image takes priority. Otherwise the gr.Paint value is checked
    for a "composite" image, then a "background" image.

    Args:
        input_image (dict or None): Value of the gr.Paint component.
        input_image_upload (PIL.Image.Image or None): Value of the gr.Image upload component.

    Returns:
        PIL.Image.Image or None: The selected image, or None when neither input provides one.
    """
    if input_image_upload is not None:
        return input_image_upload
    if isinstance(input_image, dict):
        composite = input_image.get("composite")
        if composite is not None:
            return composite
        # Missing key and explicit None both fall through to None here.
        return input_image.get("background")
    # Neither an upload nor a valid canvas dict.
    return None


@spaces.GPU(duration=45)
def infer(input_image, input_image_upload, overlay_image, prompt="make it real", seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
    """
    Perform image editing using the FLUX.1 Kontext pipeline.

    This function takes an input image and a text prompt to generate a modified version
    of the image based on the provided instructions. It uses the FLUX.1 Kontext model
    for contextual image editing tasks. When no input image is available, it falls
    back to generating from the prompt alone.

    Args:
        input_image (dict or PIL.Image.Image): The input from the gr.Paint component.
        input_image_upload (PIL.Image.Image): The input from the gr.Image upload component. Takes priority over the canvas.
        overlay_image (PIL.Image.Image): The face photo to overlay.
        prompt (str): Text description of the desired edit to apply to the image.
        seed (int, optional): Random seed for reproducible generation.
        randomize_seed (bool, optional): If True, generates a random seed.
        guidance_scale (float, optional): Controls how closely the model follows the prompt.
        steps (int, optional): Controls how many steps to run the diffusion model for.
        progress (gr.Progress, optional): Gradio progress tracker.

    Returns:
        tuple: A 4-tuple containing the result image, the processed input image, the seed, and a gr.Button update.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats. manual_seed and the step count need ints.
    seed = int(seed)
    steps = int(steps)

    processed_input_image = _resolve_input_image(input_image, input_image_upload)
    generator = torch.Generator().manual_seed(seed)

    if processed_input_image is not None:
        if overlay_image is not None:
            processed_input_image = add_overlay(processed_input_image, overlay_image)
        # The pipeline gets an RGB image. add_overlay returns RGBA, and uploads may be RGBA too.
        processed_input_image = processed_input_image.convert("RGB")
        width, height = processed_input_image.size
        image = pipe(
            image=processed_input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
    else:
        # Text-to-image case: no input image was provided.
        image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]

    return image, processed_input_image, seed, gr.Button(visible=False)
@spaces.GPU
def infer_example(input_image, prompt):
    """
    Run ``infer`` for an example row consisting of a drawing and a prompt.

    Args:
        input_image (PIL.Image.Image): Example drawing, routed to the upload slot.
        prompt (str): Text description of the desired edit.

    Returns:
        tuple: (result image, seed used).
    """
    # infer's first positional parameter is the Paint canvas and it returns a
    # 4-tuple. Pass arguments by keyword and unpack all four values.
    image, _, seed, _ = infer(
        input_image=None,
        input_image_upload=input_image,
        overlay_image=None,
        prompt=prompt,
    )
    return image, seed
# css="""
# #col-container {
# margin: 0 auto;
# max-width: 960px;
# }
# """
# Custom CSS is currently disabled (the rule above is kept for reference).
# An empty string is passed to gr.Blocks instead.
css=""

# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been lost; confirm the layout against the deployed Space.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 Kontext [dev] + Draw2Photo LoRA
Turn drawing+face into a realistic photo with FLUX.1 Kontext [dev] + [Draw2Photo LoRA](https://huggingface.co/ovi054/Draw2Photo-Kontext-LoRA)
""")
        with gr.Row():
            # Step 1: the drawing, either uploaded or drawn on a canvas.
            with gr.Column():
                gr.Markdown("Step 1. Select/Upload/Draw a person ⬇️")
                # input_image = gr.Image(label="Upload drawing", type="pil")
                with gr.Row():
                    # The two tabs feed two separate inputs. infer() uses the
                    # upload whenever it is set, even if the Draw tab is active,
                    # and only falls back to the canvas otherwise.
                    with gr.Tabs() as tabs:
                        with gr.TabItem("Upload"):
                            input_image_upload = gr.Image(label="Upload drawing", type="pil")
                        with gr.TabItem("Draw"):
                            # Fixed black brush on a 1200x1200 single-layer canvas.
                            input_image = gr.Paint(
                                type="pil",
                                brush=gr.Brush(default_size=6, colors=["#000000"], color_mode="fixed"),
                                canvas_size = (1200,1200),
                                layers = False
                            )
                # Example drawings fill the Upload tab's image component.
                gr.Examples(
                    examples=[[img] for img in BASE_EXAMPLES],
                    inputs=[input_image_upload],
                )
            # Step 2: the face photo that infer() pastes onto the drawing.
            with gr.Column():
                gr.Markdown("Step 2. Select/Upload a face photo ⬇️")
                with gr.Row():
                    overlay_image = gr.Image(label="Upload face photo", type="pil")
                gr.Examples(
                    examples=[[img] for img in FACE_EXAMPLES],
                    inputs=[overlay_image],
                )
            # Step 3: run button, generation settings, and outputs.
            with gr.Column():
                gr.Markdown("Step 3. Press “Run” to get results ⬇️")
                with gr.Row():
                    run_button = gr.Button("Run")
                with gr.Accordion("Advanced Settings", open=False):
                    prompt = gr.Text(
                        label="Prompt",
                        max_lines=1,
                        value = "make it real",
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    # Checked by default, so each run replaces the slider value with a random seed.
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1
                    )
                # Generated photo.
                result = gr.Image(label="Result", show_label=False, interactive=False)
                # The image actually sent to the model (drawing plus face overlay), as returned by infer().
                # NOTE(review): shares the "Result" label with the output above; consider a distinct label.
                result_input = gr.Image(label="Result", show_label=False, interactive=False)
                # Always hidden: infer() returns gr.Button(visible=False) and the click handler below is commented out.
                reuse_button = gr.Button("Reuse this image", visible=False)
        # examples = gr.Examples(
        #     examples=[
        #         ["flowers.png", "turn the flowers into sunflowers"],
        #         ["monster.png", "make this monster ride a skateboard on the beach"],
        #         ["cat.png", "make this cat happy"]
        #     ],
        #     inputs=[input_image_upload, prompt],
        #     outputs=[result, seed],
        #     fn=infer_example,
        #     cache_examples="lazy"
        # )
    # Run on a button click or on Enter in the prompt box. The outputs are the
    # generated photo, the model's input image, the seed actually used (written
    # back to the slider), and the reuse-button update.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = [input_image, input_image_upload, overlay_image, prompt, seed, randomize_seed, guidance_scale, steps],
        outputs = [result, result_input, seed, reuse_button]
    )
    # reuse_button.click(
    #     fn = lambda image: image,
    #     inputs = [result],
    #     outputs = [input_image]
    # )

# mcp_server=True turns on Gradio's MCP server mode (see the Gradio MCP docs).
demo.launch(mcp_server=True)