# detect_and_segment.py
import torch
import supervision as sv
from typing import List, Tuple, Optional

from PIL import Image, ImageDraw, ImageColor

# ==== 1. One-time global model loading =====================================
from .utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from .utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the models once; they stay in memory for repeated calls.
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# Quick annotators
COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5,
)
MASK_ANNOTATOR = sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
# ==== 2. Inference function ================================================
def detect_and_segment(
    image: Image.Image,
    text_prompts: str | List[str],
    return_image: bool = True,
) -> Tuple[sv.Detections, Optional[Image.Image]]:
    """
    Run Florence-2 open-vocabulary detection + SAM2 mask refinement on a PIL image.

    Parameters
    ----------
    image : PIL.Image
        Input image in RGB.
    text_prompts : str | List[str]
        A single prompt, a comma-separated string (e.g. "dog, tail, leash"),
        or a list of prompts.
    return_image : bool
        If True, also return an annotated PIL image.

    Returns
    -------
    detections : sv.Detections
        Supervision object with xyxy, mask, class_id, etc.
    annotated : PIL.Image | None
        Annotated image (None if return_image=False).
    """
    # Normalize the prompt input into a list of non-empty strings
    if isinstance(text_prompts, str):
        prompts = [p.strip() for p in text_prompts.split(",") if p.strip()]
    else:
        prompts = [p.strip() for p in text_prompts]
    if len(prompts) == 0:
        raise ValueError("Empty prompt list given.")

    # Collect detections from each prompt
    det_list: List[sv.Detections] = []
    for p in prompts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROC,
            device=DEVICE,
            image=image,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=p,
        )
        det = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image.size,
        )
        det = run_sam_inference(SAM_IMAGE_MODEL, image, det)  # SAM2 mask refinement
        det_list.append(det)
    detections = sv.Detections.merge(det_list)

    annotated_img = None
    if return_image:
        annotated_img = image.copy()
        annotated_img = MASK_ANNOTATOR.annotate(annotated_img, detections)
        annotated_img = BOX_ANNOTATOR.annotate(annotated_img, detections)
        annotated_img = LABEL_ANNOTATOR.annotate(annotated_img, detections)
    return detections, annotated_img
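
# --- Usage sketch ----------------------------------------------------------
# Illustrative only; "sample.jpg" is a placeholder path and the prompts are
# arbitrary examples:
#
#     img = Image.open("sample.jpg").convert("RGB")
#     dets, annotated = detect_and_segment(img, "dog, tail, leash")
#     print(f"found {len(dets)} objects")
#     if annotated is not None:
#         annotated.save("sample_annotated.jpg")
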
def fill_detected_bboxes(
    image: Image.Image,
    text: str,
    inflate_pct: float = 0.10,
    fill_color: str | tuple[int, int, int] = "#00FF00",
):
    """
    Detect objects matching `text`, inflate each bounding box by `inflate_pct`,
    fill the area with `fill_color`, and return the resulting image.

    Parameters
    ----------
    image : PIL.Image
        Input image (RGB).
    text : str
        Comma-separated prompt(s) for open-vocabulary detection.
    inflate_pct : float, default 0.10
        Extra margin added to each side (0.10 adds 10 % of the box width and
        height on every side, i.e. the box grows by 20 % in each dimension).
    fill_color : str | tuple, default "#00FF00"
        Solid color used to fill each inflated bbox (hex string or RGB tuple).

    Returns
    -------
    filled_img : PIL.Image
        Image with each detected (inflated) box filled.
    detections : sv.Detections
        Original detection object from `detect_and_segment`.
    """
    # Run the Florence-2 + SAM2 pipeline (detect_and_segment above)
    detections, _ = detect_and_segment(image, text)
    w, h = image.size
    filled_img = image.copy()
    draw = ImageDraw.Draw(filled_img)
    fill_rgb = ImageColor.getrgb(fill_color) if isinstance(fill_color, str) else fill_color

    for box in detections.xyxy:
        # Each xyxy row is a numpy array; cast to float for the margin math
        x1, y1, x2, y2 = box.astype(float)
        dw, dh = (x2 - x1) * inflate_pct, (y2 - y1) * inflate_pct
        # Inflate on every side, then clamp to the image bounds
        x1_i = max(0, x1 - dw)
        y1_i = max(0, y1 - dh)
        x2_i = min(w, x2 + dw)
        y2_i = min(h, y2 + dh)
        draw.rectangle([x1_i, y1_i, x2_i, y2_i], fill=fill_rgb)
    return filled_img, detections
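

# --- Demo entry point -------------------------------------------------------
# A minimal sketch, assuming an image at the hypothetical path "sample.jpg".
# Because this module uses relative imports, run it as part of its package
# (e.g. `python -m <package>.detect_and_segment`), not as a bare script.
if __name__ == "__main__":
    img = Image.open("sample.jpg").convert("RGB")  # placeholder input path

    # Detect + segment, then save the annotated overlay
    dets, annotated = detect_and_segment(img, "dog, tail, leash")
    print(f"detected {len(dets)} objects")
    if annotated is not None:
        annotated.save("sample_annotated.jpg")

    # Fill each inflated detected box with solid green and save
    filled, _ = fill_detected_bboxes(img, "dog", inflate_pct=0.10)
    filled.save("sample_filled.jpg")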