File size: 10,634 Bytes
449d6db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fde9193
449d6db
 
 
 
e813159
 
 
 
 
 
 
449d6db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cd9500
931b456
449d6db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931b456
449d6db
 
 
 
 
 
babe222
 
 
 
 
 
 
 
449d6db
 
 
 
931b456
e5e37ae
449d6db
 
 
931b456
541c781
449d6db
 
541c781
 
 
449d6db
 
 
 
 
 
 
ba36c3d
 
 
 
541c781
 
22352a3
541c781
 
 
 
 
 
 
 
22352a3
449d6db
541c781
 
 
449d6db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931b456
 
 
 
 
 
 
 
 
 
 
449d6db
 
 
 
 
 
 
931b456
 
 
 
 
449d6db
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# PyTorch 2.8 (temporary hack)
import os
# Upgrade to a PyTorch 2.8 nightly (cu126) plus `spaces` before torch is imported below.
# NOTE(review): this runs on every start, and the exit status returned by os.system is
# ignored, so a failed install silently falls through to whatever torch is installed.
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

# Actual demo code
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image, ImageOps

from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# from optimization import optimize_pipeline_

# Upper bound for seeds: the largest signed 32-bit integer. Used as the seed
# slider's maximum and as the range for random seeds in infer().
MAX_SEED = np.iinfo(np.int32).max

# Load FLUX.1 Kontext [dev] in bfloat16 and move it to CUDA once, at import time;
# every request handled by infer() reuses this module-level pipeline.
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
# Attach the Draw2Photo LoRA, then fuse it into the model weights (diffusers fuse_lora).
pipe.load_lora_weights("ovi054/Draw2Photo-Kontext-LoRA")
pipe.fuse_lora()
# Optional pipeline optimization step, currently disabled along with its import above.
# optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')


# NOTE(review): redundant -- `os` is already imported at the top of the file.
import os

# Example gallery inputs: every entry of examples/base (drawings) and
# examples/face (face photos), sorted by filename. os.listdir raises
# FileNotFoundError at startup if either directory is missing and does not
# filter by extension -- assumes the folders contain only image files; confirm.
EXAMPLES_DIR = "examples"
BASE_EXAMPLES = [os.path.join(EXAMPLES_DIR, "base", f) for f in sorted(os.listdir(os.path.join(EXAMPLES_DIR, "base")))]
FACE_EXAMPLES = [os.path.join(EXAMPLES_DIR, "face", f) for f in sorted(os.listdir(os.path.join(EXAMPLES_DIR, "face")))]


def add_overlay(base_img, overlay_img, margin=20, width_divisor=5):
    """
    Paste an overlay image onto the top-right corner of a base image.

    The overlay is resized to ``1/width_divisor`` of the base image's width
    (1/5th by default) while keeping its aspect ratio, then pasted ``margin``
    pixels in from the top and right edges using its own alpha channel as mask.

    Args:
        base_img (PIL.Image.Image | None): The main image.
        overlay_img (PIL.Image.Image | None): The image to place on top.
        margin (int, optional): Pixel margin from the top and right edges. Defaults to 20.
        width_divisor (int, optional): The overlay is scaled to
            ``base.width // width_divisor`` pixels wide (at least 1 px). Defaults to 5.

    Returns:
        PIL.Image.Image | None: ``base_img`` itself if either input is None;
        otherwise the base converted to RGBA, with the overlay pasted on it
        unless the overlay has zero width or height.
    """
    if base_img is None or overlay_img is None:
        return base_img

    base = base_img.convert("RGBA")
    overlay = overlay_img.convert("RGBA")

    w, h = overlay.size
    # A zero-sized overlay cannot be scaled; leave the base untouched.
    if w == 0 or h == 0:
        return base

    # Clamp both dimensions to >= 1 px: narrow bases (< width_divisor px wide) or
    # very wide overlays would otherwise round down to 0, which Pillow's resize rejects.
    target_width = max(1, base.width // width_divisor)
    new_height = max(1, int(h * (target_width / w)))
    overlay = overlay.resize((target_width, new_height), Image.LANCZOS)

    # Top-right corner with a margin; Pillow clips anything that falls off-canvas.
    x = base.width - overlay.width - margin
    y = margin

    # The overlay doubles as its own mask so its transparency is respected.
    base.paste(overlay, (x, y), overlay)
    return base



def _resolve_input_image(input_image, input_image_upload):
    """Return the image to edit: the upload if present, else the canvas composite, else its background, else None."""
    if input_image_upload is not None:
        return input_image_upload
    if isinstance(input_image, dict):
        # gr.Paint delivers a dict; prefer the flattened composite over the raw background.
        for key in ("composite", "background"):
            candidate = input_image.get(key)
            if candidate is not None:
                return candidate
    return None


@spaces.GPU(duration=45)
def infer(input_image, input_image_upload, overlay_image, prompt="make it real", seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
    """
    Perform image editing using the FLUX.1 Kontext pipeline.
    
    This function takes an input image and a text prompt to generate a modified version
    of the image based on the provided instructions. It uses the FLUX.1 Kontext model
    for contextual image editing tasks. An uploaded image takes priority over the drawing
    canvas. If a face photo is given, it is pasted onto the top-right corner of the input
    before editing. Without any input image, the prompt alone drives generation.
    
    Args:
        input_image (dict or PIL.Image.Image): The input from the gr.Paint component.
        input_image_upload (PIL.Image.Image): The input from the gr.Image upload component.
        overlay_image (PIL.Image.Image): The face photo to overlay.
        prompt (str): Text description of the desired edit to apply to the image.
        seed (int, optional): Random seed for reproducible generation (coerced to int).
        randomize_seed (bool, optional): If True, generates a random seed.
        guidance_scale (float, optional): Controls how closely the model follows the prompt.
        steps (int, optional): Controls how many steps to run the diffusion model for (coerced to int).
        progress (gr.Progress, optional): Gradio progress tracker.
    
    Returns:
        tuple: A 4-tuple containing the result image, the processed input image, the seed, and a gr.Button update.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Sliders and MCP clients may deliver whole numbers as floats, but
    # torch.Generator.manual_seed only accepts ints and steps is a step count.
    seed = int(seed)
    steps = int(steps)

    processed_input_image = _resolve_input_image(input_image, input_image_upload)

    pipe_kwargs = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        generator=torch.Generator().manual_seed(seed),
    )
    if processed_input_image is not None:
        if overlay_image is not None:
            processed_input_image = add_overlay(processed_input_image, overlay_image)
        # The pipeline gets plain RGB; add_overlay may have produced RGBA.
        processed_input_image = processed_input_image.convert("RGB")
        width, height = processed_input_image.size
        # Edit mode: generate at the input's own resolution.
        pipe_kwargs.update(image=processed_input_image, width=width, height=height)

    image = pipe(**pipe_kwargs).images[0]

    return image, processed_input_image, seed, gr.Button(visible=False)
    
@spaces.GPU
def infer_example(input_image, prompt):
    """
    Run ``infer`` for an example row of (image, prompt).

    Args:
        input_image (PIL.Image.Image): Example image, treated as an upload.
        prompt (str): Edit instruction.

    Returns:
        tuple: ``(result_image, seed)``.
    """
    # Example images are uploads, not canvas drawings, so they go in the
    # input_image_upload slot. infer returns a 4-tuple; only image and seed are kept.
    image, _processed_input, seed, _button = infer(
        input_image=None,
        input_image_upload=input_image,
        overlay_image=None,
        prompt=prompt,
    )
    return image, seed

# css="""
# #col-container {
#     margin: 0 auto;
#     max-width: 960px;
# }
# """

# Custom stylesheet passed to gr.Blocks below; currently empty, since the
# centered #col-container rule above is commented out.
css=""

# Build the UI as three columns: (1) the drawing, via upload or canvas; (2) the face
# photo; (3) the run button, advanced settings and outputs. Then wire up the Run event.
with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 Kontext [dev] + Draw2Photo LoRA
Turn drawing+face into a realistic photo with FLUX.1 Kontext [dev] + [Draw2Photo LoRA](https://huggingface.co/ovi054/Draw2Photo-Kontext-LoRA)
        """)
        with gr.Row():
            # Column 1: the drawing to turn into a photo.
            with gr.Column():
                gr.Markdown("Step 1.  Select/Upload/Draw a person ⬇️")
                # input_image = gr.Image(label="Upload drawing", type="pil")
                with gr.Row():
                    # NOTE(review): `tabs` is never referenced afterwards.
                    with gr.Tabs() as tabs:
                        with gr.TabItem("Upload"):
                            input_image_upload = gr.Image(label="Upload drawing", type="pil")
                            
                        with gr.TabItem("Draw"):
                            # Freehand canvas with a fixed black brush. Both tabs feed infer();
                            # it uses the upload whenever one is present, regardless of the active tab.
                            input_image = gr.Paint(
                                    type="pil",
                                    brush=gr.Brush(default_size=6, colors=["#000000"], color_mode="fixed"),
                                    canvas_size = (1200,1200),
                                    layers = False
                                )
                # Clicking an example fills the Upload tab, not the canvas.
                gr.Examples(
                    examples=[[img] for img in BASE_EXAMPLES],
                    inputs=[input_image_upload],
                )

            # Column 2: optional face photo, pasted onto the drawing before editing.
            with gr.Column():
                gr.Markdown("Step 2.  Select/Upload a face photo ⬇️")
                with gr.Row():
                    overlay_image = gr.Image(label="Upload face photo", type="pil")
                gr.Examples(
                    examples=[[img] for img in FACE_EXAMPLES],
                    inputs=[overlay_image],
                )
                    
            # Column 3: run control, generation settings and outputs.
            with gr.Column():
                gr.Markdown("Step 3.  Press “Run” to get results ⬇️")
                with gr.Row():
                    run_button = gr.Button("Run")
                with gr.Accordion("Advanced Settings", open=False):

                    prompt = gr.Text(
                        label="Prompt",
                        max_lines=1,
                        value = "make it real",
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    
                    # Only used when "Randomize seed" is unchecked; infer() writes the seed
                    # it actually used back into this slider.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )       
                    
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1
                    )
                # Generated image.
                result = gr.Image(label="Result", show_label=False, interactive=False)
                # The exact image that was fed to the pipeline (drawing plus face overlay).
                result_input = gr.Image(label="Result", show_label=False, interactive=False)
                # Hidden placeholder: infer() always returns gr.Button(visible=False) for it,
                # and its click handler below is commented out.
                reuse_button = gr.Button("Reuse this image", visible=False)
        
            
        # examples = gr.Examples(
        #     examples=[
        #         ["flowers.png", "turn the flowers into sunflowers"],
        #         ["monster.png", "make this monster ride a skateboard on the beach"],
        #         ["cat.png", "make this cat happy"]
        #     ],
        #     inputs=[input_image_upload, prompt],
        #     outputs=[result, seed],
        #     fn=infer_example,
        #     cache_examples="lazy"
        # )
            
    # Run inference on a click of "Run" or on Enter in the prompt box. The inputs and
    # outputs are listed positionally in the same order as infer's parameters and return tuple.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = [input_image, input_image_upload, overlay_image, prompt, seed, randomize_seed, guidance_scale, steps],
        outputs = [result, result_input, seed, reuse_button]
    )
    # reuse_button.click(
    #     fn = lambda image: image,
    #     inputs = [result],
    #     outputs = [input_image]
    # )

# Start the Gradio server. mcp_server=True also starts Gradio's MCP server, which
# (per Gradio's MCP docs) exposes the app's API endpoints, such as infer, as tools.
demo.launch(mcp_server=True)