super-light-lab / app.py
akhaliq's picture
akhaliq HF Staff
Update Gradio app with multiple files
28cce91 verified
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import spaces
import os
# Load the BRIA RMBG-2.0 segmentation model once, at import time, and place it
# on the GPU when CUDA is available (CPU otherwise).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True,  # lets transformers run the model code shipped in the hub repo
)
model.eval()  # inference mode
model = model.to(device)
# Preprocessing applied to every model input: resize to a fixed 1024x1024,
# convert to a float tensor, then normalise each of the three colour channels
# with the standard ImageNet mean/std.
image_size = (1024, 1024)
_normalize_mean = [0.485, 0.456, 0.406]
_normalize_std = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(_normalize_mean, _normalize_std),
    ]
)
@spaces.GPU(duration=30)
def remove_background(image):
    """
    Remove the background from an image using the RMBG-2.0 model.

    The model's prediction is turned into a single-channel mask, resized back to
    the input's size, and installed as the alpha channel of an RGBA copy of the
    input. The colour channels are left unchanged.

    Args:
        image (PIL.Image | None): Input image to process. Any PIL mode is
            accepted; the model is always fed an RGB version of it.
    Returns:
        PIL.Image | None: RGBA image whose alpha channel is the predicted mask,
        or None when no image was given.
    """
    if image is None:
        return None

    # Work from an RGBA copy so the output keeps the input's colours, and derive
    # the model input from it: transform_image normalises exactly three
    # channels, so RGBA / greyscale / palette inputs (e.g. files opened directly
    # with Image.open) must be reduced to RGB first or Normalize fails.
    image_rgba = image.convert('RGBA')
    input_images = transform_image(image_rgba.convert('RGB')).unsqueeze(0).to(device)

    # Prediction: take the last model output and squash it into [0, 1].
    # NOTE(review): assumes the last output is the final mask logits — confirm
    # against the RMBG-2.0 model card.
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Back to a PIL mask at the original resolution, used as the alpha channel.
    mask = transforms.ToPILImage()(pred).resize(image.size)
    image_rgba.putalpha(mask)
    return image_rgba
def create_collage(original, processed, gap=20):
    """
    Create a side-by-side comparison of original and processed images.

    Both images are scaled (aspect ratio preserved) to the taller of the two
    heights and placed left/right on a white canvas, separated by ``gap``
    pixels. Transparent pixels are composited over the white canvas, so a
    removed background actually shows up as white.

    Args:
        original (PIL.Image | None): Original image.
        processed (PIL.Image | None): Processed image, typically RGBA with the
            background made transparent.
        gap (int): Horizontal spacing in pixels between the two images.
    Returns:
        PIL.Image | None: RGB collage, or None if either input is missing.
    """
    if original is None or processed is None:
        return None

    target_height = max(original.height, processed.height)

    def _fit_height(img):
        """Scale img to target_height, keeping its aspect ratio."""
        new_width = int(img.width * target_height / img.height)
        return img.resize((new_width, target_height))

    left = _fit_height(original)
    right = _fit_height(processed)

    collage = Image.new('RGB', (left.width + right.width + gap, target_height), color='white')
    for img, x_offset in ((left, 0), (right, left.width + gap)):
        # Without a mask, paste() copies only the colour bands and ignores
        # alpha, so the "removed" background would reappear in the collage.
        # Using the alpha band as the mask composites over the white canvas.
        mask = img.getchannel('A') if 'A' in img.getbands() else None
        collage.paste(img, (x_offset, 0), mask)
    return collage
def download_processed_image(image):
    """
    Prepare the processed image for the download button.

    gr.DownloadButton serves a file rather than an in-memory image, so a PIL
    image is written to a temporary PNG file (PNG keeps the alpha channel) and
    that file's path is returned.

    Args:
        image (PIL.Image | str | None): Image to download. Values without a
            ``save`` method (None, or an existing file path) are returned
            unchanged.
    Returns:
        str | None: Path to the written PNG file, or the input unchanged when
        it is not an image.
    """
    if not hasattr(image, "save"):
        return image

    import tempfile

    # delete=False: the file must outlive this call so Gradio can serve it.
    with tempfile.NamedTemporaryFile(prefix="no_bg_", suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
    return tmp.name
# Create Gradio interface
with gr.Blocks(title="Background Removal App", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Background Removal App
Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
Upload an image to remove its background using the advanced RMBG-2.0 AI model from BRIA AI.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                sources=["upload", "webcam", "clipboard"],
            )
            process_btn = gr.Button("Remove Background", variant="primary", size="lg")
            with gr.Accordion("Advanced Options", open=False):
                show_comparison = gr.Checkbox(
                    label="Show Before/After Comparison",
                    value=True,
                )
        with gr.Column(scale=2):
            gr.Markdown("### Output")
            # This component is also an *input* to the follow-up steps below.
            # With the default image_mode ("RGB") Gradio would hand them an RGB
            # copy and the transparency produced by remove_background would be
            # lost, so request RGBA explicitly.
            output_image = gr.Image(
                label="Background Removed",
                type="pil",
                format="png",
                image_mode="RGBA",
            )
            comparison_image = gr.Image(
                label="Before/After Comparison",
                type="pil",
                visible=True,
            )
            download_btn = gr.DownloadButton(
                label="Download Result",
                variant="secondary",
            )

    # Example images (results are computed once and cached)
    gr.Examples(
        examples=[
            ["https://gradio-builds.s3.amazonaws.com/assets/cheetah-003.jpg"],
            ["https://gradio-builds.s3.amazonaws.com/assets/TheCheethcat.jpg"],
        ],
        inputs=input_image,
        outputs=output_image,
        fn=remove_background,
        cache_examples=True,
    )

    # Event handlers.
    # .success() (rather than .then()) runs the collage and download steps only
    # when the previous step finished without raising; otherwise they would
    # combine the new input with a stale result.
    process_btn.click(
        fn=remove_background,
        inputs=input_image,
        outputs=output_image,
        show_progress="full",
    ).success(
        fn=create_collage,
        inputs=[input_image, output_image],
        outputs=comparison_image,
    ).success(
        fn=download_processed_image,
        inputs=output_image,
        outputs=download_btn,
    )

    # Show/hide the comparison image to match the checkbox.
    show_comparison.change(
        fn=lambda visible: gr.update(visible=visible),
        inputs=show_comparison,
        outputs=comparison_image,
    )
# MCP Server Functions
def remove_background_mcp(image_path: str) -> str:
    """
    Remove the background from an image file and save the result as a PNG.

    The output is written next to the input as ``<name>_no_bg.png``. Only the
    final extension is replaced, so dots elsewhere in the path are preserved,
    and the input file is never overwritten.

    NOTE(review): nothing in this file registers this function with an event
    listener, so it is presumably not exposed as an MCP tool — confirm.

    Args:
        image_path (str): Path to the input image file.
    Returns:
        str: Path to the output PNG with the background removed.
    Raises:
        Exception: Wrapping any error raised while reading, processing or
            saving the image (the original error is chained as ``__cause__``).
    """
    try:
        # Close the file handle once the pixels have been converted; RGB
        # conversion lets RGBA / greyscale / palette files be processed too.
        with Image.open(image_path) as image:
            result = remove_background(image.convert('RGB'))
        root, _ext = os.path.splitext(image_path)
        output_path = f"{root}_no_bg.png"
        result.save(output_path, 'PNG')
        return output_path
    except Exception as e:
        raise Exception(f"Error processing image: {str(e)}") from e
def remove_background_base64(image_data: str) -> str:
    """
    Remove the background from base64-encoded image data.

    NOTE(review): nothing in this file registers this function with an event
    listener, so it is presumably not exposed as an MCP tool — confirm.

    Args:
        image_data (str): Base64-encoded image bytes. A data URL of the form
            ``data:<mime>;base64,<payload>`` is also accepted; base64 ``bytes``
            are accepted as before.
    Returns:
        str: Base64-encoded PNG with the background removed.
    Raises:
        Exception: Wrapping any decoding or processing error (the original
            error is chained as ``__cause__``).
    """
    import base64
    import io

    try:
        # Strip a data-URL header; otherwise b64decode would silently keep
        # its alphanumeric characters and corrupt the decoded bytes.
        if isinstance(image_data, str) and image_data.startswith("data:"):
            image_data = image_data.partition(",")[2]
        image_bytes = base64.b64decode(image_data)

        # RGB conversion lets RGBA / greyscale / palette inputs be processed.
        with Image.open(io.BytesIO(image_bytes)) as image:
            result = remove_background(image.convert('RGB'))

        # Encode result back to base64 (PNG keeps the alpha channel)
        output_buffer = io.BytesIO()
        result.save(output_buffer, format='PNG')
        return base64.b64encode(output_buffer.getvalue()).decode('utf-8')
    except Exception as e:
        raise Exception(f"Error processing image: {str(e)}") from e
def get_supported_formats() -> list:
    """
    List the image file extensions accepted for background removal.

    Returns:
        list: Lower-case extensions with a leading dot; a new list per call,
        so callers may modify it freely.
    """
    supported = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
    return list(supported)
if __name__ == "__main__":
    # Launch the Gradio app when run as a script, with Gradio's MCP server
    # mode enabled via mcp_server=True.
    demo.launch(mcp_server=True)