# app.py
import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image

# --- Model Loading ---
# Load the model only once at the start of the application.
# We use float16 for memory efficiency and speed on GPUs.
# If no GPU is available, this will run on CPU (but it will be very slow).
try:
    pipe = AutoPipelineForInpainting.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
        variant="fp16"
    ).to("cuda")
except Exception as e:
    print(f"Could not load model on GPU: {e}. Falling back to CPU.")
    pipe = AutoPipelineForInpainting.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting"
    )
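
# Optional (assumption, not part of the original app): attention slicing trades a little
# speed for lower VRAM use. Uncomment if the Space runs out of GPU memory.
# pipe.enable_attention_slicing()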

# --- The Inpainting Function ---
# This is the core function that takes user inputs and generates the image.
def inpaint_image(input_dict, prompt, negative_prompt, guidance_scale, num_steps):
    """
    Performs inpainting on an image based on a mask and a prompt.

    Args:
        input_dict (dict): A dictionary from Gradio's Image component containing 'image' and 'mask'.
        prompt (str): The text prompt describing what to generate in the masked area.
        negative_prompt (str): The text prompt describing what to avoid.
        guidance_scale (float): A value controlling how closely the generation follows the prompt.
        num_steps (int): The number of inference steps.

    Returns:
        PIL.Image: The resulting image after inpainting.
    """
    # Separate the image and the mask from the input dictionary
    image = input_dict["image"].convert("RGB")
    mask_image = input_dict["mask"].convert("RGB")
    # The model works best at its training resolution (512x512). For user-friendliness we
    # let the pipeline handle resizing, but square inputs generally give the best results.
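    # Optional (assumption, not part of the original app): force a fixed 512x512 working size
    # and resize the result back afterwards if large or non-square inputs cause artifacts.
    # image = image.resize((512, 512))
    # mask_image = mask_image.resize((512, 512))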
    print(f"Starting inpainting with prompt: '{prompt}'")

    # Run the inpainting pipeline
    result_image = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
    ).images[0]

    return result_image
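
# Note (assumption, not part of the original app): for reproducible outputs you could create
# a seeded generator, e.g.
#   generator = torch.Generator(device=pipe.device).manual_seed(42)
# and pass it to the pipeline call above as `generator=generator`.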

# --- Gradio User Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎨 AI Image Fixer (Inpainting)
        Have an AI-generated image with weird hands, faces, or artifacts? Fix it here!

        **How to use:**
        1. Upload your image.
        2. Use the brush tool to "paint" over the parts you want to replace. This is your mask.
        3. Write a prompt describing what you want in the painted-over area.
        4. Adjust the advanced settings if you want more control.
        5. Click "Fix It!" and see the magic happen.
        """
    )

    with gr.Row():
        # Input column
        with gr.Column():
            gr.Markdown("### 1. Upload & Mask Your Image")
            # The Image component with a drawing tool for masking.
            # Note: `source` and `tool` are Gradio 3.x arguments; the "sketch" tool is what
            # returns the {"image": ..., "mask": ...} dict that inpaint_image() expects.
            input_image = gr.Image(
                label="Upload Image & Draw Mask",
                source="upload",
                tool="sketch",
                type="pil"  # We want to work with PIL images in our function
            )

            gr.Markdown("### 2. Describe Your Fix")
            prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'A beautiful, realistic human hand, detailed fingers'")

            # Accordion for advanced settings to keep the UI clean
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., 'blurry, distorted, extra fingers, cartoon'")
                guidance_scale = gr.Slider(minimum=0, maximum=20, value=8.0, label="Guidance Scale")
                num_steps = gr.Slider(minimum=10, maximum=100, step=1, value=40, label="Inference Steps")

        # Output column
        with gr.Column():
            gr.Markdown("### 3. Get Your Result")
            output_image = gr.Image(
                label="Resulting Image",
                type="pil"
            )

    # The button to trigger the process
    submit_button = gr.Button("Fix It!", variant="primary")

    # Connect the button to the function
    submit_button.click(
        fn=inpaint_image,
        inputs=[input_image, prompt, negative_prompt, guidance_scale, num_steps],
        outputs=output_image
    )

# Launch the Gradio app
if __name__ == "__main__":
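    # Optional (assumption, not part of the original app): enabling the request queue helps
    # when several users trigger long-running GPU jobs at once.
    # demo.queue()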
    demo.launch()