Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torchvision import transforms | |
| from transformers import AutoModelForImageSegmentation | |
| import spaces | |
| import os | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the RMBG-2.0 segmentation model from the Hugging Face Hub.
# trust_remote_code=True allows transformers to run the model code that is
# shipped inside the Hub repository.
model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)
model = model.eval().to(device)  # inference mode, on the selected device

# Preprocessing applied to every image before inference.
# NOTE(review): 1024x1024 is presumably the model's expected input
# resolution — confirm against the model card.
image_size = (1024, 1024)
transform_image = transforms.Compose(
    [
        # A (h, w) tuple resizes to exactly that size (aspect ratio is not
        # kept); remove_background resizes the predicted mask back afterwards.
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # Per-channel mean / std (the common ImageNet statistics).
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
def remove_background(image):
    """
    Remove background from image using RMBG-2.0 model.

    The input is converted to RGB before inference. RGBA, grayscale ("L")
    and palette ("P") images would otherwise produce a tensor whose channel
    count does not match the 3-value Normalize step in ``transform_image``.

    Args:
        image (PIL.Image): Input image to process, or None.

    Returns:
        PIL.Image: RGBA copy of the image whose alpha channel is the predicted
        foreground mask, or None when ``image`` is None.
    """
    if image is None:
        return None
    # NOTE(review): `spaces` is imported at file level but nothing is
    # decorated with @spaces.GPU; if this Space runs on ZeroGPU hardware the
    # model call below likely needs that decorator — confirm.
    rgb_image = image.convert('RGB')
    # Resize / to-tensor / normalize, then add a leading batch dimension.
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)
    # Prediction: take the last model output and squash it to [0, 1].
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Turn the predicted mask into a PIL image at the input's original size.
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(rgb_image.size)
    # Create transparent background: the mask becomes the alpha channel.
    image_rgba = rgb_image.convert('RGBA')
    image_rgba.putalpha(mask)
    return image_rgba
def create_collage(original, processed, *, gap=20, background='white'):
    """
    Create a side-by-side comparison of original and processed images.

    Both images are scaled (aspect ratio kept) to the taller of the two
    heights and pasted left to right onto an opaque canvas, separated by
    ``gap`` pixels. Each image is pasted through its own alpha channel.
    Removed-background areas therefore show the canvas colour instead of the
    original pixels that remain underneath a zero alpha.

    Args:
        original (PIL.Image): Original image
        processed (PIL.Image): Processed image with background removed
        gap (int): Horizontal spacing in pixels between the two images.
        background: Canvas colour, in any form ``Image.new`` accepts.

    Returns:
        PIL.Image: RGB collage of both images, or None if either input is None.
    """
    if original is None or processed is None:
        return None
    # Resize images to same height
    target_height = max(original.height, processed.height)

    def _scale_to_height(img):
        # Convert first so every image carries an alpha band for pasting;
        # target_height >= img.height, so the new width is never 0.
        rgba = img.convert('RGBA')
        new_width = int(rgba.width * target_height / rgba.height)
        return rgba.resize((new_width, target_height))

    left = _scale_to_height(original)
    right = _scale_to_height(processed)

    collage_width = left.width + gap + right.width
    collage = Image.new('RGB', (collage_width, target_height), color=background)
    # Passing the RGBA image as the mask makes paste use its alpha band;
    # without a mask, paste copies the RGB values and ignores transparency.
    collage.paste(left, (0, 0), left)
    collage.paste(right, (left.width + gap, 0), right)
    return collage
def download_processed_image(image):
    """
    Prepare image for download.

    ``gr.DownloadButton`` takes a file path as its value, not a PIL image.
    The image is therefore written to a temporary PNG file and that path is
    returned.

    Args:
        image (PIL.Image): Image to download, or None.

    Returns:
        str: Path of a temporary ``.png`` file holding the image, or None when
        ``image`` is None.
    """
    import tempfile

    if image is None:
        return None
    # delete=False: the file must outlive this function so Gradio can serve it.
    # NOTE(review): nothing here removes these temp files afterwards.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        image.save(tmp, format='PNG')
    return tmp.name
# Create Gradio interface.
# Layout: an input column (upload, run button, options) beside an output
# column (transparent result, before/after comparison, download button).
with gr.Blocks(title="Background Removal App", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Background Removal App
        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        Upload an image to remove its background using the advanced RMBG-2.0 AI model from BRIA AI.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                sources=["upload", "webcam", "clipboard"],
            )
            process_btn = gr.Button("Remove Background", variant="primary", size="lg")
            with gr.Accordion("Advanced Options", open=False):
                show_comparison = gr.Checkbox(
                    label="Show Before/After Comparison",
                    value=True,
                )
        with gr.Column(scale=2):
            gr.Markdown("### Output")
            output_image = gr.Image(
                label="Background Removed",
                type="pil",
                format="png",
                # This component is read back as an *input* by the chained
                # handlers below. gr.Image converts inputs to `image_mode`,
                # whose default per the Gradio docs is "RGB", which would
                # drop the alpha channel that remove_background just produced.
                image_mode="RGBA",
            )
            comparison_image = gr.Image(
                label="Before/After Comparison",
                type="pil",
                visible=True,
            )
            download_btn = gr.DownloadButton(
                label="Download Result",
                variant="secondary",
            )
    # Example images.
    # NOTE(review): cache_examples=True precomputes these outputs (per the
    # Gradio docs, at launch), so the model must load and run successfully
    # before the app comes up.
    gr.Examples(
        examples=[
            ["https://gradio-builds.s3.amazonaws.com/assets/cheetah-003.jpg"],
            ["https://gradio-builds.s3.amazonaws.com/assets/TheCheethcat.jpg"],
        ],
        inputs=input_image,
        outputs=output_image,
        fn=remove_background,
        cache_examples=True,
    )
    # Event handlers. The follow-up steps use .success() so they only run if
    # the previous step did not raise. With .then() they would run anyway and
    # rebuild the comparison/download from whatever stale output was shown.
    process_btn.click(
        fn=remove_background,
        inputs=input_image,
        outputs=output_image,
        show_progress="full",
    ).success(
        fn=create_collage,
        inputs=[input_image, output_image],
        outputs=comparison_image,
    ).success(
        fn=download_processed_image,
        inputs=output_image,
        outputs=download_btn,
    )
    # Update comparison visibility when the checkbox is toggled.
    show_comparison.change(
        fn=lambda visible: gr.update(visible=visible),
        inputs=show_comparison,
        outputs=comparison_image,
    )
# MCP Server Functions
# NOTE(review): none of the functions below are attached to `demo` in this
# file, so launching with mcp_server=True does not by itself expose them as
# tools — confirm how they are meant to be registered.
def remove_background_mcp(image_path: str) -> str:
    """
    Remove background from an image file and save the result.

    The result is written as a PNG next to the input, named
    ``<input path without its final extension>_no_bg.png``.

    Args:
        image_path (str): Path to the input image file

    Returns:
        str: Path to the output image file with background removed

    Raises:
        Exception: If the image cannot be opened, processed or saved; the
            underlying error is chained as ``__cause__``.
    """
    try:
        # Load image (closed again once processing has produced a new image).
        with Image.open(image_path) as image:
            result = remove_background(image)
        # Strip only the final extension. os.path.splitext ignores dots in
        # directory names, and a path without an extension no longer maps
        # onto itself (which would overwrite the input). Use .png because
        # PNG data is what gets written.
        root, _ext = os.path.splitext(image_path)
        output_path = f"{root}_no_bg.png"
        result.save(output_path, 'PNG')
        return output_path
    except Exception as e:
        raise Exception(f"Error processing image: {str(e)}") from e
def remove_background_base64(image_data: str) -> str:
    """
    Remove background from base64 encoded image data.

    Both bare base64 and data URLs (``data:image/png;base64,...``) are
    accepted.

    Args:
        image_data (str): Base64 encoded image data

    Returns:
        str: Base64 encoded PNG image with background removed

    Raises:
        Exception: If decoding, processing or encoding fails; the underlying
            error is chained as ``__cause__``.
    """
    import base64
    import io
    try:
        # Drop a data-URL header. b64decode's default (non-validating) mode
        # discards ':', ';' and ',' but decodes the header's letters as
        # payload, which silently corrupts the image bytes.
        if isinstance(image_data, str) and image_data.startswith('data:') and ',' in image_data:
            image_data = image_data.split(',', 1)[1]
        # Decode base64
        image_bytes = base64.b64decode(image_data)
        with Image.open(io.BytesIO(image_bytes)) as image:
            # Process image
            result = remove_background(image)
        # Encode result back to base64
        output_buffer = io.BytesIO()
        result.save(output_buffer, format='PNG')
        output_bytes = output_buffer.getvalue()
        return base64.b64encode(output_bytes).decode('utf-8')
    except Exception as e:
        raise Exception(f"Error processing image: {str(e)}") from e
def get_supported_formats() -> list:
    """
    List the image file extensions offered for background removal.

    Nothing in this file consults this list; it is purely informational.

    Returns:
        list: Fresh list of lower-case extensions, each with a leading dot.
    """
    bare_extensions = "jpg jpeg png bmp tiff webp".split()
    return [f".{ext}" for ext in bare_extensions]
if __name__ == "__main__":
    # Serve the UI with Gradio's MCP server switched on.
    # NOTE(review): mcp_server=True relies on the `gradio[mcp]` extra being
    # installed — confirm it is in the Space's requirements.
    launch_options = {"mcp_server": True}
    demo.launch(**launch_options)