"""Gradio app that removes image backgrounds with the BRIA RMBG-2.0 model.

Besides the interactive UI, the module defines file-path and base64 helper
functions intended for programmatic (MCP) use.
"""

import base64
import io
import os
import tempfile
from pathlib import Path

import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import spaces

# Load RMBG-2.0 model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0', trust_remote_code=True
).eval().to(device)

# Data settings: every input is resized to a fixed 1024x1024 tensor and
# normalized with the per-channel mean/std below (three channels, RGB).
image_size = (1024, 1024)
transform_image = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


@spaces.GPU(duration=30)
def remove_background(image):
    """
    Remove background from image using RMBG-2.0 model.

    Args:
        image (PIL.Image | None): Input image to process. Any PIL mode is
            accepted; it is converted to RGB before being fed to the model.

    Returns:
        PIL.Image | None: RGBA copy of ``image`` whose alpha channel is the
        predicted foreground mask, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None

    # The Normalize step expects exactly three channels, so grayscale,
    # palette and RGBA inputs must be converted to RGB first.
    input_images = transform_image(image.convert('RGB')).unsqueeze(0).to(device)

    # Prediction: sigmoid maps the last model output into [0, 1].
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Bring the 1024x1024 mask back to the original image resolution.
    mask = transforms.ToPILImage()(pred).resize(image.size)

    # Create transparent background
    image_rgba = image.convert('RGBA')
    image_rgba.putalpha(mask)
    return image_rgba


def create_collage(original, processed):
    """
    Create a side-by-side comparison of original and processed images.

    Both images are scaled (aspect ratio preserved) to the taller of the two
    heights and placed on a white canvas with a 20 px gap between them.

    Args:
        original (PIL.Image | None): Original image
        processed (PIL.Image | None): Processed image with background removed

    Returns:
        PIL.Image | None: RGB collage of both images, or ``None`` if either
        input is ``None``.
    """
    if original is None or processed is None:
        return None

    gap = 20  # horizontal spacing between the two halves, in pixels
    target_height = max(original.height, processed.height)

    def _scale_to_height(img):
        # Preserve the aspect ratio while matching the common height.
        return img.resize(
            (int(img.width * target_height / img.height), target_height)
        )

    original_resized = _scale_to_height(original)
    processed_resized = _scale_to_height(processed).convert('RGBA')

    collage_width = original_resized.width + gap + processed_resized.width
    collage = Image.new('RGB', (collage_width, target_height), color='white')

    collage.paste(original_resized.convert('RGB'), (0, 0))
    # Paste using the alpha channel as mask. Without it PIL converts RGBA to
    # RGB, discarding alpha, and the removed background would reappear.
    collage.paste(
        processed_resized,
        (original_resized.width + gap, 0),
        processed_resized,
    )
    return collage


def download_processed_image(image):
    """
    Prepare image for download.

    ``gr.DownloadButton`` serves a file, so the image is written to a
    temporary PNG file and that file's path is returned.

    Args:
        image (PIL.Image | None): Image to download

    Returns:
        str | None: Path to the saved PNG file, or ``None`` if there is no
        image.
    """
    if image is None:
        return None
    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(
        prefix='no_bg_', suffix='.png', delete=False
    ) as tmp:
        image.save(tmp, format='PNG')
    return tmp.name


# Create Gradio interface
with gr.Blocks(title="Background Removal App", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Background Removal App

Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)

Upload an image to remove its background using the advanced RMBG-2.0 AI model from BRIA AI.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                sources=["upload", "webcam", "clipboard"],
            )
            process_btn = gr.Button("Remove Background", variant="primary", size="lg")

            with gr.Accordion("Advanced Options", open=False):
                show_comparison = gr.Checkbox(
                    label="Show Before/After Comparison",
                    value=True,
                )

        with gr.Column(scale=2):
            gr.Markdown("### Output")
            output_image = gr.Image(
                label="Background Removed",
                type="pil",
                format="png",
            )
            comparison_image = gr.Image(
                label="Before/After Comparison",
                type="pil",
                visible=True,
            )
            download_btn = gr.DownloadButton(
                label="Download Result",
                variant="secondary",
            )

    # Example images
    gr.Examples(
        examples=[
            ["https://gradio-builds.s3.amazonaws.com/assets/cheetah-003.jpg"],
            ["https://gradio-builds.s3.amazonaws.com/assets/TheCheethcat.jpg"],
        ],
        inputs=input_image,
        outputs=output_image,
        fn=remove_background,
        cache_examples=True,
    )

    # Event handlers: segment, then build the comparison, then expose the
    # result as a downloadable file.
    process_btn.click(
        fn=remove_background,
        inputs=input_image,
        outputs=output_image,
        show_progress=True,
    ).then(
        fn=create_collage,
        inputs=[input_image, output_image],
        outputs=comparison_image,
    ).then(
        fn=download_processed_image,
        inputs=output_image,
        outputs=download_btn,
    )

    # Update comparison visibility
    show_comparison.change(
        fn=lambda visible: gr.update(visible=visible),
        inputs=show_comparison,
        outputs=comparison_image,
    )


# MCP Server Functions
# NOTE(review): these helpers are defined after `demo` is built and are not
# attached to it as events/endpoints, so they may not be exposed by
# `launch(mcp_server=True)` — confirm how they are meant to be registered.
def remove_background_mcp(image_path: str) -> str:
    """
    Remove background from an image file and save the result.

    The result is written as PNG next to the input, named
    ``<input stem>_no_bg.png``.

    Args:
        image_path (str): Path to the input image file

    Returns:
        str: Path to the output image file with background removed

    Raises:
        Exception: If the image cannot be read, processed or saved; the
            underlying error is chained as ``__cause__``.
    """
    try:
        source = Path(image_path)
        # Context manager closes the file handle; `remove_background`
        # returns a converted copy, so the result outlives the source.
        with Image.open(source) as image:
            result = remove_background(image)

        # Only the file name changes (never directory components), and the
        # extension matches the PNG content actually written.
        output_path = source.with_name(f"{source.stem}_no_bg.png")
        result.save(output_path, 'PNG')
        return str(output_path)
    except Exception as e:
        raise Exception(f"Error processing image: {str(e)}") from e


def remove_background_base64(image_data: str) -> str:
    """
    Remove background from base64 encoded image data.

    Args:
        image_data (str): Base64 encoded image data, either bare or as a
            ``data:<mime>;base64,<payload>`` URL.

    Returns:
        str: Base64 encoded PNG image with background removed

    Raises:
        Exception: If decoding or processing fails; the underlying error is
            chained as ``__cause__``.
    """
    try:
        # A comma never occurs in the base64 alphabet, so everything after
        # the last comma is the payload (the whole string if there is none).
        _, _, payload = image_data.rpartition(',')
        image_bytes = base64.b64decode(payload)

        with Image.open(io.BytesIO(image_bytes)) as image:
            result = remove_background(image)

        # Encode result back to base64
        output_buffer = io.BytesIO()
        result.save(output_buffer, format='PNG')
        return base64.b64encode(output_buffer.getvalue()).decode('utf-8')
    except Exception as e:
        raise Exception(f"Error processing image: {str(e)}") from e


def get_supported_formats() -> list:
    """
    Get list of supported image formats for background removal.

    Returns:
        list: List of supported image formats
    """
    return [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]


if __name__ == "__main__":
    demo.launch(mcp_server=True)