Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| from ultralytics import FastSAM | |
| from ultralytics.models.fastsam import FastSAMPrompt | |
# Pick the inference device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the FastSAM checkpoint once at import time so every request reuses it.
# "FastSAM-x.pt" is the alternative checkpoint mentioned by the original author.
model = FastSAM("FastSAM-s.pt")
def fig2img(fig):
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    Args:
        fig: Matplotlib figure to render. It is not closed here; the caller
            owns its lifetime.

    Returns:
        PIL.Image.Image decoded from the rendered PNG bytes.
    """
    buf = io.BytesIO()
    # Pin the format explicitly: without it, savefig follows
    # rcParams["savefig.format"], which may be a non-raster format
    # (e.g. svg/pdf) that PIL cannot open.
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode the pixels now so the returned image does not
    # depend on reading from ``buf`` later.
    img.load()
    return img
def plot_masks(annotations, output_shape):
    """Overlay every predicted mask on the original image and return the rendering.

    Args:
        annotations: List of ultralytics ``Results`` (as returned by
            ``FastSAMPrompt.everything_prompt()``). Must be non-empty;
            ``annotations[0].orig_img`` is used as the background.
        output_shape: Mask target size as ``(width, height)``, i.e. the PIL
            ``Image.size`` convention used by the caller.

    Returns:
        PIL.Image.Image of the rendered figure (background plus mask overlay).
    """
    width, height = output_shape
    # ultralytics stores orig_img in BGR channel order; flip to RGB for display.
    background = annotations[0].orig_img[..., ::-1]

    # Gather masks from every result. ``masks`` is None when nothing was
    # detected, so such results are skipped instead of crashing.
    masks = [
        mask
        for ann in annotations
        if ann.masks is not None
        for mask in ann.masks.data
    ]

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(background)
        for index, mask in enumerate(masks, start=1):
            # cv2.resize takes dsize as (width, height), which matches
            # output_shape directly (reversing it transposed the masks).
            # Nearest-neighbour interpolation keeps the mask binary.
            mask_np = cv2.resize(
                mask.cpu().numpy().astype("uint8"),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            # Fill the mask with its own index so each mask gets a distinct
            # colour from the colormap (a constant value made them all the same).
            overlay = np.ma.masked_where(
                mask_np == 0, np.full(mask_np.shape, index, dtype=np.float32)
            )
            # Pass the colormap by name: plt.cm.get_cmap was removed in Matplotlib 3.9.
            ax.imshow(overlay, alpha=0.5, cmap="jet", vmin=0, vmax=len(masks))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Close this specific figure (not pyplot's "current" one) to free it.
        plt.close(fig)
def segment_everything(input_image):
    """Segment all objects in an uploaded image with FastSAM (Gradio handler).

    Args:
        input_image: Image array from the ``gr.Image(type="numpy")`` input,
            or None when nothing was uploaded.

    Returns:
        ``(image, status)`` for the two Gradio outputs: the rendered mask
        overlay (PIL image, or None on failure) and a status message.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."
    try:
        input_image = Image.fromarray(input_image).convert("RGB")
        # Run FastSAM model in "everything" mode.
        everything_results = model(
            input_image,
            device=device,
            retina_masks=True,
            imgsz=1024,
            conf=0.25,
            iou=0.9,
            agnostic_nms=True,
        )
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)
        ann = prompt_process.everything_prompt()
        result_image = plot_masks(ann, input_image.size)
        # ``masks`` is None when nothing was detected; report zero objects
        # instead of failing on len(None).
        num_objects = sum(len(r.masks) for r in ann if r.masks is not None)
        return result_image, f"Segmented everything in the image. Found {num_objects} objects."
    except Exception as e:
        # UI boundary: show the error to the user, but keep the traceback in
        # the server log rather than discarding it.
        import logging

        logging.getLogger(__name__).exception("FastSAM segmentation failed")
        return None, f"An error occurred: {e}"
# Gradio UI: one uploaded image in; the mask overlay and a status message out.
image_input = gr.Image(type="numpy", label="Upload an image")
segmented_output = gr.Image(type="pil", label="Segmented Image")
status_output = gr.Textbox(label="Status")

iface = gr.Interface(
    fn=segment_everything,
    inputs=[image_input],
    outputs=[segmented_output, status_output],
    title="FastSAM Everything Segmentation",
    description="Upload an image to segment all objects using FastSAM.",
)

# Start serving the interface.
iface.launch()