Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| #import torch.nn.functional as F | |
| #import torchvision | |
| #import torchvision.transforms as T | |
| from diffusers import StableDiffusionInpaintPipeline | |
| import numpy as np | |
| #import cv2 | |
| import os | |
| import shutil | |
| from gradio_client import Client, handle_file | |
| # Load the model once globally to avoid repeated loading | |
| def load_inpainting_model(): | |
| model_path = "uberRealisticPornMerge_v23Inpainting.safetensors" | |
| #model_path = "pornmasterFantasy_v4-inpainting.safetensors" | |
| #model_path = "pornmasterAmateur_v6Vae-inpainting.safetensors" | |
| device = "cpu" # Explicitly use CPU | |
| pipe = StableDiffusionInpaintPipeline.from_single_file( | |
| model_path, | |
| torch_dtype=torch.float32, # Use float32 for CPU | |
| safety_checker=None | |
| ).to(device) | |
| return pipe | |
| # Preload the model once | |
| inpaint_pipeline = load_inpainting_model() | |
| # Function to resize image (simpler interpolation method for speed) | |
| def resize_to_match(input_image, output_image): | |
| #torch_img = pil_to_torch(input_image) | |
| #torch_img_scaled = F.interpolate(torch_img.unsqueeze(0),mode='trilinear').squeeze(0) | |
| #output_image = torchvision.transforms.functional.to_pil_image(torch_img_scaled, mode=None) | |
| #return output_image | |
| return output_image.resize(input_image.size, Image.BICUBIC) # Use BILINEAR for faster resizing | |
| # Function to generate the mask using Florence SAM Masking API (Replicate) | |
| def generate_mask(image_path, text_prompt="clothing"): | |
| client_sam = Client("SkalskiP/florence-sam-masking") | |
| mask_result = client_sam.predict( | |
| image_input=handle_file(image_path), # Provide your image path here | |
| text_input=text_prompt, # Use "clothing" as the prompt | |
| api_name="/process_image" | |
| ) | |
| return mask_result # This is the local path to the generated mask | |
| # Save the generated mask | |
| def save_mask(mask_local_path, save_path="generated_mask.png"): | |
| try: | |
| shutil.copy(mask_local_path, save_path) | |
| except Exception as e: | |
| print(f"Failed to save the mask: {e}") | |
| # Function to perform inpainting | |
| def inpaint_image(input_image, mask_image): | |
| prompt = "undress, naked" | |
| result = inpaint_pipeline(prompt=prompt, image=input_image, mask_image=mask_image) | |
| inpainted_image = result.images[0] | |
| #inpainted_image = resize_to_match(input_image, inpainted_image) | |
| return inpainted_image | |
| # Function to process input image and mask | |
| def process_image(input_image): | |
| # Save the input image temporarily to process with Replicate | |
| input_image_path = "temp_input_image.png" | |
| input_image.save(input_image_path) | |
| # Generate the mask using Florence SAM API | |
| mask_local_path = generate_mask(image_path=input_image_path) | |
| # Save the generated mask | |
| mask_image_path = "generated_mask.png" | |
| save_mask(mask_local_path, save_path=mask_image_path) | |
| # Open the mask image and perform inpainting | |
| mask_image = Image.open(mask_image_path) | |
| result_image = inpaint_image(input_image, mask_image) | |
| # Clean up temporary files | |
| os.remove(input_image_path) | |
| os.remove(mask_image_path) | |
| return result_image | |
| # Define Gradio interface using Blocks API | |
| with gr.Blocks() as demo: | |
| with gr.Row(): | |
| input_image = gr.Image(label="Upload Input Image", type="pil") | |
| output_image = gr.Image(type="pil", label="Output Image") | |
| # Button to trigger the process | |
| with gr.Row(): | |
| btn = gr.Button("Run Inpainting") | |
| # Function to run when button is clicked | |
| btn.click(fn=process_image, inputs=[input_image], outputs=output_image) | |
| # Launch the Gradio app | |
| demo.launch(share=True) | |