Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import torch | |
| from diffusers import AutoPipelineForInpainting | |
| from PIL import Image | |
# --- Model Loading ---
# Load the model only once, at application start-up, and share it across requests.
# On a CUDA GPU we use float16 weights for memory efficiency and speed.
# Without a usable GPU we fall back to the default (float32) weights on CPU,
# which works but makes inference very slow.
MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"


def _load_inpainting_pipeline():
    """Return the inpainting pipeline placed on the best available device.

    The fp16/GPU path is only attempted when ``torch.cuda.is_available()``
    reports a CUDA device, so CPU-only machines do not download the fp16
    weights just to throw them away. If loading onto the GPU still fails
    (e.g. not enough GPU memory), the error is printed and the default
    weights are loaded on CPU instead.

    Returns:
        The pipeline object returned by ``AutoPipelineForInpainting.from_pretrained``.
    """
    if torch.cuda.is_available():
        try:
            return AutoPipelineForInpainting.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16,
                variant="fp16",
            ).to("cuda")
        except Exception as e:  # deliberate best-effort: any GPU failure -> CPU fallback
            print(f"Could not load model on GPU: {e}. Falling back to CPU.")
    else:
        print("No CUDA GPU available. Loading model on CPU (inference will be very slow).")
    return AutoPipelineForInpainting.from_pretrained(MODEL_ID)


pipe = _load_inpainting_pipeline()
# --- The Inpainting Function ---
# This is the core function that takes user inputs and generates the image.
def inpaint_image(input_dict, prompt, negative_prompt, guidance_scale, num_steps):
    """
    Perform inpainting on an image based on a painted mask and a prompt.

    Args:
        input_dict (dict | None): Value of the Gradio masking Image component,
            i.e. ``{"image": PIL.Image, "mask": PIL.Image | None}``. Gradio
            passes ``None`` when nothing has been uploaded.
        prompt (str): Text describing what to generate in the masked area.
        negative_prompt (str): Text describing what to avoid.
        guidance_scale (float): How strongly the generation follows the prompt.
        num_steps (int | float): Number of inference steps. Truncated to ``int``
            because slider values may arrive as floats.

    Returns:
        PIL.Image.Image: The inpainted image, scaled back to the size of the
        uploaded image.

    Raises:
        gr.Error: If no image was uploaded or no mask was painted. Gradio shows
            the message to the user instead of a generic error.
    """
    # Guard against the inputs Gradio produces before the user has done anything.
    if not isinstance(input_dict, dict) or input_dict.get("image") is None:
        raise gr.Error("Please upload an image and paint over the area you want to fix.")
    raw_mask = input_dict.get("mask")
    if raw_mask is None:
        raise gr.Error("Please paint over the area you want to replace before clicking 'Fix It!'.")

    image = input_dict["image"].convert("RGB")
    # Single-channel mask: white pixels are repainted, black pixels are kept.
    mask_image = raw_mask.convert("L")
    if mask_image.getbbox() is None:
        # getbbox() is None for an all-black mask, i.e. nothing would be repainted.
        raise gr.Error("The mask is empty. Paint over the area you want to replace.")

    original_size = image.size
    print(f"Starting inpainting with prompt: '{prompt}'")
    # Run the inpainting pipeline
    result_image = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
    ).images[0]

    # No height/width is passed, so the pipeline renders at the model's default
    # resolution. Scale the result back so the user gets an image with the same
    # dimensions (and aspect ratio) as the one they uploaded.
    if result_image.size != original_size:
        result_image = result_image.resize(original_size, Image.LANCZOS)
    return result_image
# --- Gradio User Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎨 AI Image Fixer (Inpainting)
        Have an AI-generated image with weird hands, faces, or artifacts? Fix it here!
        **How to use:**
        1. Upload your image.
        2. Use the brush tool to "paint" over the parts you want to replace. This is your mask.
        3. Write a prompt describing what you want in the painted-over area.
        4. Adjust the advanced settings if you want more control.
        5. Click "Fix It!" and see the magic happen.
        """
    )
    with gr.Row():
        # Input column
        with gr.Column():
            gr.Markdown("### 1. Upload & Mask Your Image")
            # Image component with a brush for painting the mask.
            # With source="upload", Gradio 3.x's `tool="sketch"` hands the callback
            # a dict {"image": ..., "mask": ...}, which inpaint_image expects.
            # NOTE(review): `source`/`tool` exist only in Gradio 3.x; Gradio 4
            # removed them (masking moved to gr.ImageEditor), so pin gradio<4.
            input_image = gr.Image(
                label="Upload Image & Draw Mask",
                source="upload",
                tool="sketch",
                type="pil",  # the callback works with PIL images
            )
            gr.Markdown("### 2. Describe Your Fix")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g., 'A beautiful, realistic human hand, detailed fingers'",
            )
            # Accordion for advanced settings to keep the UI clean
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="e.g., 'blurry, distorted, extra fingers, cartoon'",
                )
                guidance_scale = gr.Slider(
                    minimum=0, maximum=20, value=8.0, label="Guidance Scale"
                )
                num_steps = gr.Slider(
                    minimum=10, maximum=100, step=1, value=40, label="Inference Steps"
                )
        # Output column
        with gr.Column():
            gr.Markdown("### 3. Get Your Result")
            output_image = gr.Image(
                label="Resulting Image",
                type="pil",
            )
    # The button that triggers the inpainting run
    submit_button = gr.Button("Fix It!", variant="primary")
    # Wire the button to the inpainting function
    submit_button.click(
        fn=inpaint_image,
        inputs=[input_image, prompt, negative_prompt, guidance_scale, num_steps],
        outputs=output_image,
    )
# Launch the Gradio app
if __name__ == "__main__":
    # Enable the request queue. A diffusion run (especially on CPU) can take far
    # longer than a non-queued request may stay open behind a hosting proxy, and
    # the queue also serializes access to the single shared pipeline.
    demo.queue().launch()