File size: 4,630 Bytes
55d13a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# app.py

import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image

# --- Model Loading ---
# Load the model only once, at start-up, so every request reuses the same
# pipeline object.
# On a CUDA GPU we load the float16 ("fp16") weights for memory efficiency and
# speed. Without a GPU we load the default full-precision weights instead:
# half precision is poorly supported on CPU. CPU inference works, but it is
# very slow.
# The device is chosen explicitly up front. Relying on `.to("cuda")` to fail
# would make CPU-only hosts download the fp16 variant for nothing before
# falling back.
MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"

if torch.cuda.is_available():
    try:
        pipe = AutoPipelineForInpainting.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
    except Exception as e:
        # Best-effort: a GPU exists but loading onto it failed (e.g. not enough
        # VRAM), so keep the app usable on CPU rather than crashing.
        print(f"Could not load model on GPU: {e}. Falling back to CPU.")
        pipe = AutoPipelineForInpainting.from_pretrained(MODEL_ID)
else:
    print("No CUDA GPU detected. Loading model on CPU (inference will be very slow).")
    pipe = AutoPipelineForInpainting.from_pretrained(MODEL_ID)

# --- The Inpainting Function ---
# This is the core function that takes user inputs and generates the image
def inpaint_image(input_dict, prompt, negative_prompt, guidance_scale, num_steps):
    """
    Performs inpainting on an image based on a mask and a prompt.

    Args:
        input_dict (dict | None): The value of Gradio's sketch-enabled Image
            component: a dictionary with keys 'image' (the uploaded PIL image)
            and 'mask' (a PIL image in which painted pixels are non-black).
            It is None/empty when the user has not uploaded anything.
        prompt (str): The text prompt describing what to generate in the masked area.
        negative_prompt (str): The text prompt describing what to avoid.
        guidance_scale (float): A value to control how much the generation follows the prompt.
        num_steps (int | float): The number of inference steps. Slider values
            may arrive as floats, so the value is cast to int.

    Returns:
        PIL.Image.Image: The inpainted image, at the same size as the uploaded image.

    Raises:
        gr.Error: If no image was uploaded or no mask area was painted. Gradio
            shows the message to the user instead of a generic failure.
    """
    # Guard: Gradio passes None when the image component is empty.
    if not input_dict or input_dict.get("image") is None:
        raise gr.Error("Please upload an image first.")

    # Separate the image and the mask from the input dictionary
    image = input_dict["image"].convert("RGB")
    raw_mask = input_dict.get("mask")
    if raw_mask is None:
        raise gr.Error("Please paint over the area you want to fix.")
    mask_image = raw_mask.convert("RGB")

    # getbbox() returns None when every pixel is black, i.e. nothing was
    # painted. Running the pipeline then would spend a full generation and
    # return the image without any fix applied.
    if mask_image.convert("L").getbbox() is None:
        raise gr.Error("Please paint over the area you want to fix.")

    # The model works best with square images around 512x512. We let the
    # pipeline handle resizing, and restore the original size afterwards.
    print(f"Starting inpainting with prompt: '{prompt}'")

    # Run the inpainting pipeline
    result_image = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
    ).images[0]

    # The pipeline may render at different dimensions than the upload, either
    # rounded to a multiple of 8 or at a default size depending on the
    # diffusers version. Resize back so the result is a drop-in replacement
    # for the original image.
    if result_image.size != image.size:
        result_image = result_image.resize(image.size, Image.LANCZOS)

    return result_image

# --- Gradio User Interface ---
# This layout uses the Gradio 3.x gr.Image API (`source=` / `tool=` keywords).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎨 AI Image Fixer (Inpainting)
        
        Have an AI-generated image with weird hands, faces, or artifacts? Fix it here!
        
        **How to use:**
        1. Upload your image.
        2. Use the brush tool to "paint" over the parts you want to replace. This is your mask.
        3. Write a prompt describing what you want in the painted-over area.
        4. Adjust the advanced settings if you want more control.
        5. Click "Fix It!" and see the magic happen.
        """
    )

    with gr.Row():
        # Input column
        with gr.Column():
            gr.Markdown("### 1. Upload & Mask Your Image")
            # The Image component with a drawing tool for masking.
            # tool="sketch" together with source="upload" is what makes Gradio
            # hand the callback a dict with "image" and "mask" keys, the shape
            # inpaint_image expects. "brush" is not a recognized tool value.
            input_image = gr.Image(
                label="Upload Image & Draw Mask",
                source="upload",
                tool="sketch",
                type="pil",  # We want to work with PIL images in our function
            )

            gr.Markdown("### 2. Describe Your Fix")
            prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'A beautiful, realistic human hand, detailed fingers'")

            # Accordion for advanced settings to keep the UI clean
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., 'blurry, distorted, extra fingers, cartoon'")
                guidance_scale = gr.Slider(minimum=0, maximum=20, value=8.0, label="Guidance Scale")
                num_steps = gr.Slider(minimum=10, maximum=100, step=1, value=40, label="Inference Steps")

        # Output column
        with gr.Column():
            gr.Markdown("### 3. Get Your Result")
            output_image = gr.Image(
                label="Resulting Image",
                type="pil",
            )

    # The button to trigger the process
    submit_button = gr.Button("Fix It!", variant="primary")

    # Connect the button to the function. The input order must match
    # inpaint_image's parameter order.
    submit_button.click(
        fn=inpaint_image,
        inputs=[input_image, prompt, negative_prompt, guidance_scale, num_steps],
        outputs=output_image,
    )

# Launch the Gradio app.
# The request queue is enabled because a single generation can run for a long
# time (minutes on CPU), which can exceed plain HTTP/proxy timeouts. The
# queue also processes jobs one at a time by default, so concurrent users do
# not contend for the single shared pipeline.
if __name__ == "__main__":
    demo.queue().launch()