# ImageSegment / app.py
import gradio as gr
import torch
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import numpy as np
from PIL import Image
# Load model and processor
model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
processor = SegformerImageProcessor.from_pretrained(model_name)
model = SegformerForSemanticSegmentation.from_pretrained(model_name)
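# model.config.id2label maps each of the 150 ADE20K class indices to a readable label name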

# ADE20K color palette (random colors, used only for visualization)
def create_color_map():
    # Fixed seed so each of the 150 ADE20K classes keeps a stable color across runs
    np.random.seed(42)
    return np.random.randint(0, 255, size=(150, 3), dtype=np.uint8)

color_map = create_color_map()

def segment_image(image):
    # Preprocess the PIL image into model-ready tensors
    inputs = processor(images=image, return_tensors="pt")

    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # SegFormer predicts at 1/4 of the input resolution, so upsample the logits
    # back to the original image size before taking the per-pixel argmax
    upsampled = torch.nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # PIL gives (width, height); interpolate expects (height, width)
        mode="bilinear",
        align_corners=False,
    )
    seg = upsampled.argmax(dim=1)[0].cpu().numpy()

    # Convert class labels to a colored segmentation map
    colored_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label in np.unique(seg):
        colored_seg[seg == label] = color_map[label]

    # Convert to a PIL image for display
    segmented_image = Image.fromarray(colored_seg)
    return [image, segmented_image]

# Create the Gradio interface
demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Original"),
        gr.Image(type="pil", label="Segmented"),
    ],
    title="Image Segmentation with SegFormer",
    description="Upload an image to segment it into semantic regions using a SegFormer model fine-tuned on the ADE20K dataset.",
)

# Launch the app
if __name__ == "__main__":
    demo.launch()
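# Running `python app.py` serves the demo on Gradio's default local address (http://127.0.0.1:7860)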