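"""Gradio demo for furniture blending.

Upload a room image, draw a mask over the region to replace, provide a
furniture reference image, and the app calls a remote inpainting endpoint
to blend the furniture into the scene.
"""
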
import base64
import io
import os
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Literal, TypedDict, cast

import gradio as gr
import numpy as np
import requests
from gradio.components.image_editor import EditorValue
from PIL import Image

_PASSWORD = os.environ.get("PASSWORD", None)
if not _PASSWORD:
    msg = "PASSWORD is not set"
    raise ValueError(msg)
PASSWORD = cast("str", _PASSWORD)

_ENDPOINT = os.environ.get("ENDPOINT", None)
if not _ENDPOINT:
    msg = "ENDPOINT is not set"
    raise ValueError(msg)
ENDPOINT = cast("str", _ENDPOINT)

# Constants
THUMBNAIL_MAX_SIZE = 2048
REFERENCE_MAX_SIZE = 1024
REQUEST_TIMEOUT = 300  # 5 minutes
DEFAULT_BRUSH_SIZE = 75


def encode_image_as_base64(image: Image.Image) -> str:
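    """Serialize a PIL image to a base64-encoded PNG string."""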
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def make_example(image_path: Path, mask_path: Path | None) -> EditorValue:
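    """Build an ImageMask editor value (background, layers, composite) from example files."""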
    background_image = Image.open(image_path)
    background_image = background_image.convert("RGB")
    background = np.array(background_image)

    if mask_path:
        mask_image = Image.open(mask_path)
        mask_image = mask_image.convert("RGB")
        mask = np.array(mask_image)
        mask = mask[:, :, 0]
        # Invert: white (255) pixels in the mask file become 0, everything else 255
        mask = np.where(mask == 255, 0, 255)  # noqa: PLR2004
    else:
        mask = np.zeros_like(background)
        mask = mask[:, :, 0]

    if background.shape[0] != mask.shape[0] or background.shape[1] != mask.shape[1]:
        msg = "Background and mask must have the same shape"
        raise ValueError(msg)

    # The drawn layer carries the mask in its alpha channel
    layer = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    layer[:, :, 3] = mask

    # The composite shows the background with the masked region cut out
    composite = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }


class InputFurnitureBlendingTypedDict(TypedDict):
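    """Request payload for the furniture-blending endpoint."""
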
    return_type: Literal["zipfile", "s3"]
    model_type: Literal["schnell", "dev"]
    room_image_input: str
    bbox: tuple[int, int, int, int]
    furniture_reference_image: str
    prompt: str
    seed: int
    num_inference_steps: int
    max_dimension: int
    margin: int
    crop: bool
    num_images_per_prompt: int
    bucket: str


# Type hints for the API response
class GenerationResponse(TypedDict):
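    """Parsed response shape (currently unused: the demo reads the returned zipfile directly)."""
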
    images: list[Image.Image]
    error: str | None


def validate_inputs(
    image_and_mask: EditorValue | None,
    furniture_reference: Image.Image | None,
) -> tuple[Literal[True], None] | tuple[Literal[False], str]:
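    """Check that an image, a non-empty mask, and a furniture reference were provided."""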
    if not image_and_mask:
        return False, "Please upload an image and draw a mask"

    image_np = cast("np.ndarray", image_and_mask["background"])
    if np.sum(image_np) == 0:
        return False, "Please upload an image"

    # The mask is whatever the user painted into the layer's alpha channel
    alpha_channel = cast("np.ndarray", image_and_mask["layers"][0])
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if np.sum(mask_np) == 0:
        return False, "Please mark the areas you want to remove"

    if not furniture_reference:
        return False, "Please upload a furniture reference image"

    return True, None


def process_images(
    image_and_mask: EditorValue,
    furniture_reference: Image.Image,
) -> tuple[Image.Image, Image.Image, Image.Image]:
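    """Extract the target image and mask from the editor value and downscale all inputs."""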
    image_np = cast("np.ndarray", image_and_mask["background"])
    alpha_channel = cast("np.ndarray", image_and_mask["layers"][0])
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)

    mask_image = Image.fromarray(mask_np).convert("L")
    target_image = Image.fromarray(image_np).convert("RGB")

    # Resize in place; thumbnail() preserves aspect ratio and never upscales
    mask_image.thumbnail(
        (THUMBNAIL_MAX_SIZE, THUMBNAIL_MAX_SIZE), Image.Resampling.LANCZOS
    )
    target_image.thumbnail(
        (THUMBNAIL_MAX_SIZE, THUMBNAIL_MAX_SIZE), Image.Resampling.LANCZOS
    )
    furniture_reference.thumbnail(
        (REFERENCE_MAX_SIZE, REFERENCE_MAX_SIZE), Image.Resampling.LANCZOS
    )

    return target_image, mask_image, furniture_reference


def predict(
    model_type: Literal["schnell", "dev", "pixart"],
    image_and_mask: EditorValue,
    furniture_reference: Image.Image | None,
    prompt: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 512,
    margin: int = 128,
    crop: bool = True,
    num_images_per_prompt: int = 1,
) -> list[Image.Image] | None:
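    """Validate inputs, call the remote endpoint, and unpack the generated images."""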
    # Validate inputs
    is_valid, error_message = validate_inputs(image_and_mask, furniture_reference)
    if not is_valid and error_message:
        gr.Info(error_message)
        return None

    if model_type == "pixart":
        gr.Info("PixArt is not supported yet")
        return None

    # Process images
    target_image, mask_image, furniture_reference = process_images(
        image_and_mask, cast("Image.Image", furniture_reference)
    )

    bbox = mask_image.getbbox()
    if not bbox:
        gr.Info("Please mark the areas you want to remove")
        return None

    # Prepare API request
    room_image_input_base64 = "data:image/png;base64," + encode_image_as_base64(
        target_image
    )
    furniture_reference_base64 = "data:image/png;base64," + encode_image_as_base64(
        furniture_reference
    )
    body = InputFurnitureBlendingTypedDict(
        return_type="zipfile",
        model_type=model_type,
        room_image_input=room_image_input_base64,
        bbox=bbox,
        furniture_reference_image=furniture_reference_base64,
        prompt=prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        max_dimension=max_dimension,
        margin=margin,
        crop=crop,
        num_images_per_prompt=num_images_per_prompt,
        bucket="furniture-blending",
    )

    try:
        response = requests.post(
            ENDPOINT,
            headers={"accept": "application/json", "Content-Type": "application/json"},
            json=body,
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
    except requests.RequestException as e:
        gr.Info(f"API request failed: {e!s}")
        return None

    # Process response: the endpoint returns the generated images as a zipfile
    try:
        zip_bytes = io.BytesIO(response.content)
        final_image_list: list[Image.Image] = []
        with zipfile.ZipFile(zip_bytes, "r") as zip_file:
            for filename in zip_file.namelist():
                with zip_file.open(filename) as file:
                    image = Image.open(file).convert("RGB")
                    final_image_list.append(image)
    except (OSError, zipfile.BadZipFile) as e:
        gr.Info(f"Failed to process response: {e!s}")
        return None

    return final_image_list


css = r"""
#col-left,
#col-mid,
#col-right {
    margin: 0 auto;
    max-width: 430px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""


with gr.Blocks(css=css) as demo:
    gr.HTML("""
    <div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
        <h1 style="color: #333;">🪑 Furniture Blending Demo</h1>
        <div style="max-width: 800px; margin: 0 auto;">
            <p style="font-size: 16px;">Upload an image, draw a mask over the areas you want to remove, and upload a furniture reference image.</p>
            <p style="font-size: 16px;">
                For the best results, draw square masks.
                Flux dev gives better results than schnell but is slower.
                The furniture reference should be a single object on a white background.
            </p>
            <p style="font-size: 16px;">
                You can edit the object with the prompt.
                For example, add "red couch" to the prompt to make the couch red.
            </p>
            <br>
            <p style="font-size: 16px;">⚠️ Note that the images are compressed to reduce the demo's workload.</p>
        </div>
    </div>
    """)
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                r"""
                <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
                    <div>
                        🪟 Room image with inpainting mask ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(
                    default_size=DEFAULT_BRUSH_SIZE,
                    colors=["#000000"],
                    color_mode="fixed",
                ),
                transforms=[],
            )
            gr.Examples(
                examples=[
                    make_example(path, None)
                    for path in Path("./examples/scenes").glob("*.png")
                ],
                label="Room examples",
                examples_per_page=6,
                inputs=[image_and_mask],
            )
| with gr.Column(elem_id="col-mid"): | |
| gr.HTML( | |
| r""" | |
| <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px"> | |
| <div> | |
| 🪑 Furniture reference image ⬇️ | |
| </div> | |
| </div> | |
| """, | |
| max_height=50, | |
| ) | |
| condition_image = gr.Image( | |
| label="Furniture Reference", | |
| type="pil", | |
| sources=["upload"], | |
| image_mode="RGB", | |
| ) | |
| gr.Examples( | |
| examples=list(Path("./examples/objects").glob("*.png")), | |
| label="Furniture examples", | |
| examples_per_page=6, | |
| inputs=[condition_image], | |
| ) | |
| with gr.Column(elem_id="col-right"): | |
| gr.HTML( | |
| r""" | |
| <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px"> | |
| <div> | |
| 🔥 Press Run ⬇️ | |
| </div> | |
| </div> | |
| """, | |
| max_height=50, | |
| ) | |
| results = gr.Gallery( | |
| label="Result", | |
| format="png", | |
| file_types=["image"], | |
| show_label=False, | |
| columns=2, | |
| allow_preview=True, | |
| preview=True, | |
| ) | |
| model_type = gr.Radio( | |
| choices=["schnell", "dev", "pixart"], | |
| value="dev", | |
| label="Model Type", | |
| ) | |
| run_button = gr.Button("Run") | |

    with gr.Accordion("Advanced Settings", open=False):
        prompt = gr.Textbox(
            label="Prompt",
            value="",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=np.iinfo(np.int32).max,
            step=1,
            value=0,
        )
        num_images_per_prompt = gr.Slider(
            label="Number of images per prompt",
            minimum=1,
            maximum=10,
            step=1,
            value=2,
        )
        crop = gr.Checkbox(
            label="Crop",
            value=False,
        )
        margin = gr.Slider(
            label="Margin",
            minimum=0,
            maximum=256,
            step=16,
            value=128,
        )
        with gr.Column():
            max_dimension = gr.Slider(
                label="Max Dimension",
                minimum=256,
                maximum=1024,
                step=128,
                value=512,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=4,
                maximum=30,
                step=2,
                value=28,
            )

    # Change the number of inference steps based on the model type
    model_type.change(
        fn=lambda x: gr.update(value=4 if x == "schnell" else 28),
        inputs=model_type,
        outputs=num_inference_steps,
    )

    # Loading indicator, hidden until Run is pressed (the original inline
    # display:none kept it invisible even when the component was made visible)
    with gr.Row():
        loading_indicator = gr.HTML(
            '<div id="loading">Processing... Please wait.</div>',
            visible=False,
        )

    # Show the loading indicator while the prediction runs, then hide it
    run_button.click(
        fn=lambda: gr.update(visible=True),
        outputs=[loading_indicator],
    ).then(
        fn=predict,
        inputs=[
            model_type,
            image_and_mask,
            condition_image,
            prompt,
            seed,
            num_inference_steps,
            max_dimension,
            margin,
            crop,
            num_images_per_prompt,
        ],
        outputs=[results],
    ).then(
        fn=lambda: gr.update(visible=False),
        outputs=[loading_indicator],
    )


if __name__ == "__main__":
    demo.launch()
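    # NOTE: PASSWORD is required at startup but never referenced above. A
    # plausible (unconfirmed) intent is Gradio's built-in basic auth, e.g.:
    # demo.launch(auth=("user", PASSWORD))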