import spaces
import gradio as gr
import torch
from PIL import Image
import random
import numpy as np
import os
import json
from datetime import datetime

from pipeline_rf import RectifiedFlowPipeline

# Load the 2-rectified-flow InstaFlow checkpoint (distilled from SD v1.5) and move it to the GPU.
pipe = RectifiedFlowPipeline.from_pretrained("XCLIU/2_rectified_flow_from_sd_1_5", torch_dtype=torch.float32)
pipe.to("cuda")

@spaces.GPU(duration=20)
def process_image(
    image_layers, prompt, seed, randomize_seed, num_inference_steps,
    max_steps, learning_rate, optimization_steps, inverseproblem, mask_input
):
    # Validate the editor payload before indexing into it.
    if image_layers is None:
        return None, "❌ Please upload an image and create a mask."

    image_with_mask = {
        "image": image_layers["background"],
        "mask": image_layers["layers"][0] if mask_input is None else mask_input,
    }

    if randomize_seed or seed is None:
        seed = random.randint(0, 2**32 - 1)
    generator = torch.Generator("cuda").manual_seed(int(seed))

    image = image_with_mask["image"]
    mask = image_with_mask["mask"]
    if image is None or mask is None:
        return None, "❌ Please ensure both image and mask are provided."

    # The mask is an RGBA layer from the editor; its alpha channel marks the region to inpaint.
    image = image.convert("RGB")
    mask = mask.split()[-1]

    if not prompt:
        prompt = ""

    with torch.autocast("cuda"):
        result = pipe(
            prompt=prompt,
            negative_prompt="",
            input_image=image.resize((512, 512)),
            mask_image=mask.resize((512, 512)),
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0,
            generator=generator,
            save_masked_image=False,
            output_path="test.png",
            learning_rate=learning_rate,
            max_steps=max_steps,
            optimization_steps=optimization_steps,
            inverseproblem=inverseproblem,
        ).images[0]
    return result, f"✅ Inpainting completed with seed {seed}."

with gr.Blocks() as demo:
    gr.Markdown(
        """
        <style>
        body {background-color: #f5f5f5; color: #333333;}
        h1 {text-align: center; font-family: 'Helvetica', sans-serif; margin-bottom: 10px;}
        h2 {text-align: center; color: #666666; font-weight: normal; margin-bottom: 30px;}
        .gradio-container {max-width: 800px; margin: auto;}
        .footer {text-align: center; margin-top: 20px; color: #999999; font-size: 12px;}
        </style>
        """
    )
gr.Markdown("<h1>🍲 FlowChef 🍲</h1>") |
|
|
gr.Markdown("<h2>Inversion/Gradient/Training-free Steering of <u>InstaFlow (SDv1.5) for Inpainting (Inverse Problem)</u></h2>") |
|
|
gr.Markdown("<h3><p><a href='https://flowchef.github.io/'>Project Page</a> | <a href='#'>Paper</a></p> (Steering Rectified Flow Models in the Vector Field for Controlled Image Generation)</h3>") |
|
|
|
|
|
gr.Markdown("<h3>⚡ For better performance, check out our demo on <a href='https://huggingface.co/spaces/FlowChef/FlowChef-Flux1-dev'>Flux</a>!</h3>") |
|
|
|
|
|
|
|
|
current_input_image = None |
|
|
current_mask = None |
|
|
current_output_image = None |
|
|
current_params = {} |
|
|
|
|
|
|
|
|
    with gr.Row():
        with gr.Column():
            image_input = gr.ImageMask(
                type="pil",
                label="Input Image and Mask",
                image_mode="RGBA",
                height=512,
                width=512,
            )
        with gr.Column():
            output_image = gr.Image(label="Output Image")
    with gr.Column():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Describe what should appear in the masked area..."
        )
        with gr.Row():
            seed = gr.Number(label="Seed (Optional)", value=None)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=50, maximum=200, value=100
            )

        with gr.Accordion("Advanced Settings", open=False):
            max_steps = gr.Slider(label="Max Steps", minimum=50, maximum=200, value=200)
            learning_rate = gr.Slider(label="Learning Rate", minimum=0.01, maximum=0.5, value=0.02)
            optimization_steps = gr.Slider(label="Optimization Steps", minimum=1, maximum=10, value=1)
            inverseproblem = gr.Checkbox(
                label="Apply mask in pixel space (does not work well with HF ZeroGPU)",
                value=False,
                info="Enables the inverse-problem formulation of inpainting by masking the RGB image itself. To avoid artifacts, the mask is enlarged manually during inference.",
            )
            mask_input = gr.Image(
                type="pil",
                label="Optional Mask",
                image_mode="RGBA",
            )
        with gr.Row():
            run_button = gr.Button("Run", variant="primary")

    def run_and_update_status(
        image_with_mask, prompt, seed, randomize_seed, num_inference_steps,
        max_steps, learning_rate, optimization_steps, inverseproblem, mask_input
    ):
        result_image, result_status = process_image(
            image_with_mask, prompt, seed, randomize_seed, num_inference_steps,
            max_steps, learning_rate, optimization_steps, inverseproblem, mask_input
        )

        # Keep the latest inputs, mask, output, and parameters around so save_data() can persist them.
        global current_input_image, current_mask, current_output_image, current_params

        current_input_image = image_with_mask["background"] if image_with_mask else None
        current_mask = mask_input if mask_input is not None else (image_with_mask["layers"][0] if image_with_mask else None)
        current_output_image = result_image
        current_params = {
            "prompt": prompt,
            "seed": seed,
            "randomize_seed": randomize_seed,
            "num_inference_steps": num_inference_steps,
            "max_steps": max_steps,
            "learning_rate": learning_rate,
            "optimization_steps": optimization_steps,
            "inverseproblem": inverseproblem,
        }

        return result_image

    def save_data():
        # Persist the current input, mask, output, and parameters under a timestamped folder.
        if not os.path.exists("saved_results"):
            os.makedirs("saved_results")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_dir = os.path.join("saved_results", timestamp)
        os.makedirs(save_dir)

        if current_input_image is not None:
            current_input_image.save(os.path.join(save_dir, "input.png"))
        if current_mask is not None:
            current_mask.save(os.path.join(save_dir, "mask.png"))
        if current_output_image is not None:
            current_output_image.save(os.path.join(save_dir, "output.png"))

        with open(os.path.join(save_dir, "parameters.json"), "w") as f:
            json.dump(current_params, f, indent=4)

        return f"✅ Data saved in {save_dir}"

    run_button.click(
        fn=run_and_update_status,
        inputs=[
            image_input,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            max_steps,
            learning_rate,
            optimization_steps,
            inverseproblem,
            mask_input,
        ],
        outputs=output_image,
    )
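
    # The script defines save_data() above but never exposes it in the UI. A minimal
    # sketch of how it could be wired up follows; the "Save Results" button and the
    # status textbox are illustrative assumptions, not part of the original interface.
    with gr.Row():
        save_button = gr.Button("Save Results")
        save_status = gr.Textbox(label="Save Status", interactive=False)
    save_button.click(fn=save_data, inputs=[], outputs=save_status)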

    gr.Markdown(
        "<div class='footer'>Developed with ❤️ using InstaFlow (Stable Diffusion v1.5) and Gradio by <a href='https://maitreyapatel.com'>Maitreya Patel</a></div>"
    )

    def load_example_image_with_mask(image_path):
        # Load an example image and pair it with an empty mask layer in the editor's expected format.
        image = Image.open(image_path)
        mask = Image.new('L', image.size, 0)
        return {"background": image, "layers": [mask], "composite": image}

    examples_dir = "assets"
    volcano_dict = load_example_image_with_mask(os.path.join(examples_dir, "vulcano.jpg"))
    dog_dict = load_example_image_with_mask(os.path.join(examples_dir, "dog.webp"))

    gr.Examples(
        examples=[
            [
                "./saved_results/20241129_210517/input.png",
                "./saved_results/20241129_210517/mask.png",
                "./saved_results/20241129_210517/output.png",
                "a cat", 0, True, 200, 200, 0.1, 1, False,
            ],
            [
                "./saved_results/20241129_211124/input.png",
                "./saved_results/20241129_211124/mask.png",
                "./saved_results/20241129_211124/output.png",
                " ", 0, True, 200, 200, 0.1, 5, False,
            ],
            [
                "./saved_results/20241129_212001/input.png",
                "./saved_results/20241129_212001/mask.png",
                "./saved_results/20241129_212001/output.png",
                " ", 52, False, 200, 200, 0.02, 10, False,
            ],
            [
                "./saved_results/20241129_212052/input.png",
                "./saved_results/20241129_212052/mask.png",
                "./saved_results/20241129_212052/output.png",
                " ", 52, False, 200, 200, 0.02, 10, False,
            ],
            [
                "./saved_results/20241129_212155/input.png",
                "./saved_results/20241129_212155/mask.png",
                "./saved_results/20241129_212155/output.png",
                " ", 52, False, 200, 200, 0.02, 10, False,
            ],
        ],
        inputs=[
            image_input,
            mask_input,
            output_image,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            max_steps,
            learning_rate,
            optimization_steps,
            inverseproblem,
        ],
    )

demo.launch()