Spaces:
Build error
Build error
UPD: added setup.py for installation
Browse files- SegmentAnything2AssistApp.py +1 -1
- setup.py +25 -0
- src/{YOLOv10Plugin.py β SegmentAnything2Assist/Plugin/YOLOv10Plugin.py} +0 -0
- src/{__init__.py β SegmentAnything2Assist/Plugin/__init__.py} +0 -0
- src/{SegmentAnything2Assist.py β SegmentAnything2Assist/SegmentAnything2Assist.py} +11 -7
- src/SegmentAnything2Assist/__init__.py +0 -0
- test/assets/liberty.jpg +0 -0
- test/test_module.py +59 -0
SegmentAnything2AssistApp.py
CHANGED
|
@@ -4,7 +4,7 @@ import gradio_imageslider
|
|
| 4 |
import spaces
|
| 5 |
import torch
|
| 6 |
|
| 7 |
-
import src.SegmentAnything2Assist as SegmentAnything2Assist
|
| 8 |
|
| 9 |
example_image_annotation = {
|
| 10 |
"image": "assets/cars.jpg",
|
|
|
|
| 4 |
import spaces
|
| 5 |
import torch
|
| 6 |
|
| 7 |
+
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
|
| 8 |
|
| 9 |
example_image_annotation = {
|
| 10 |
"image": "assets/cars.jpg",
|
setup.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages
|
| 2 |
+
|
| 3 |
+
setup(
|
| 4 |
+
name="SegmentAnything2Assist",
|
| 5 |
+
version="0.1",
|
| 6 |
+
packages=find_packages(where="src"),
|
| 7 |
+
package_dir={"": "src"},
|
| 8 |
+
install_requires=[
|
| 9 |
+
"SAM-2 @ git+https://github.com/facebookresearch/segment-anything-2.git@7e1596c0b6462eb1d1ba7e1492430fed95023598",
|
| 10 |
+
"ultralytics @ git+https://github.com/THU-MIG/yolov10.git@cd2f79c70299c9041fb6d19617ef1296f47575b1",
|
| 11 |
+
"opencv-python==4.10.0.84",
|
| 12 |
+
],
|
| 13 |
+
author="xqt",
|
| 14 |
+
author_email="xqt@users.noreply.huggingface.co",
|
| 15 |
+
description="A package to segment anything and assist in the process",
|
| 16 |
+
long_description=open("README.md").read(),
|
| 17 |
+
long_description_content_type="text/markdown",
|
| 18 |
+
url="https://huggingface.co/spaces/xqt/Segment-Anything-2-Assist",
|
| 19 |
+
classifiers=[
|
| 20 |
+
"Programming Language :: Python :: 3",
|
| 21 |
+
"License :: OSI Approved :: MIT License",
|
| 22 |
+
"Operating System :: OS Independent",
|
| 23 |
+
],
|
| 24 |
+
python_requires=">=3.8.0",
|
| 25 |
+
)
|
src/{YOLOv10Plugin.py β SegmentAnything2Assist/Plugin/YOLOv10Plugin.py}
RENAMED
|
File without changes
|
src/{__init__.py β SegmentAnything2Assist/Plugin/__init__.py}
RENAMED
|
File without changes
|
src/{SegmentAnything2Assist.py β SegmentAnything2Assist/SegmentAnything2Assist.py}
RENAMED
|
@@ -5,12 +5,11 @@ import tqdm
|
|
| 5 |
import requests
|
| 6 |
import torch
|
| 7 |
import numpy
|
| 8 |
-
import pickle
|
| 9 |
|
| 10 |
import sam2.build_sam
|
| 11 |
import sam2.automatic_mask_generator
|
| 12 |
|
| 13 |
-
from . import YOLOv10Plugin
|
| 14 |
|
| 15 |
import cv2
|
| 16 |
|
|
@@ -122,14 +121,17 @@ class SegmentAnything2Assist:
|
|
| 122 |
print(f"SegmentAnything2Assist::is_model_available::{ret}")
|
| 123 |
return ret
|
| 124 |
|
| 125 |
-
def load_model(self) ->
|
| 126 |
if self.is_model_available():
|
| 127 |
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
|
|
|
|
| 128 |
|
| 129 |
-
|
|
|
|
|
|
|
| 130 |
if not force and self.is_model_available():
|
| 131 |
print(f"{self.model_path} already exists. Skipping download.")
|
| 132 |
-
return
|
| 133 |
|
| 134 |
response = requests.get(self.download_url, stream=True)
|
| 135 |
total_size = int(response.headers.get("content-length", 0))
|
|
@@ -141,10 +143,12 @@ class SegmentAnything2Assist:
|
|
| 141 |
file.write(data)
|
| 142 |
progress_bar.update(len(data))
|
| 143 |
|
|
|
|
|
|
|
| 144 |
def generate_automatic_masks(
|
| 145 |
self,
|
| 146 |
-
image,
|
| 147 |
-
points_per_side=
|
| 148 |
points_per_batch=32,
|
| 149 |
pred_iou_thresh=0.8,
|
| 150 |
stability_score_thresh=0.95,
|
|
|
|
| 5 |
import requests
|
| 6 |
import torch
|
| 7 |
import numpy
|
|
|
|
| 8 |
|
| 9 |
import sam2.build_sam
|
| 10 |
import sam2.automatic_mask_generator
|
| 11 |
|
| 12 |
+
from .Plugin import YOLOv10Plugin
|
| 13 |
|
| 14 |
import cv2
|
| 15 |
|
|
|
|
| 121 |
print(f"SegmentAnything2Assist::is_model_available::{ret}")
|
| 122 |
return ret
|
| 123 |
|
| 124 |
+
def load_model(self) -> bool:
|
| 125 |
if self.is_model_available():
|
| 126 |
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
|
| 127 |
+
return True
|
| 128 |
|
| 129 |
+
return False
|
| 130 |
+
|
| 131 |
+
def download_model(self, force: bool = False) -> bool:
|
| 132 |
if not force and self.is_model_available():
|
| 133 |
print(f"{self.model_path} already exists. Skipping download.")
|
| 134 |
+
return False
|
| 135 |
|
| 136 |
response = requests.get(self.download_url, stream=True)
|
| 137 |
total_size = int(response.headers.get("content-length", 0))
|
|
|
|
| 143 |
file.write(data)
|
| 144 |
progress_bar.update(len(data))
|
| 145 |
|
| 146 |
+
return True
|
| 147 |
+
|
| 148 |
def generate_automatic_masks(
|
| 149 |
self,
|
| 150 |
+
image: numpy.ndarray,
|
| 151 |
+
points_per_side=10,
|
| 152 |
points_per_batch=32,
|
| 153 |
pred_iou_thresh=0.8,
|
| 154 |
stability_score_thresh=0.95,
|
src/SegmentAnything2Assist/__init__.py
ADDED
|
File without changes
|
test/assets/liberty.jpg
ADDED
|
test/test_module.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TestSegmentAnything2Assist(unittest.TestCase):
|
| 7 |
+
def setUp(self) -> None:
|
| 8 |
+
return super().setUp()
|
| 9 |
+
|
| 10 |
+
def tearDown(self) -> None:
|
| 11 |
+
return super().tearDown()
|
| 12 |
+
|
| 13 |
+
def _loading_all_sam_model_types(self):
|
| 14 |
+
# Test loading all types of SAM2 models.
|
| 15 |
+
all_sam_models_type = [
|
| 16 |
+
"sam2_hiera_tiny",
|
| 17 |
+
"sam2_hiera_small",
|
| 18 |
+
"sam2_hiera_base_plus",
|
| 19 |
+
"sam2_hiera_large",
|
| 20 |
+
]
|
| 21 |
+
for sam_model_type in all_sam_models_type:
|
| 22 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
| 23 |
+
sam_model_name=sam_model_type, download=True, device="cpu"
|
| 24 |
+
)
|
| 25 |
+
self.assertEqual(sam_model.is_model_available(), True)
|
| 26 |
+
|
| 27 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
| 28 |
+
sam_model_name=sam_model_type,
|
| 29 |
+
download=False,
|
| 30 |
+
model_path=f".tmp/checkpoints/{sam_model_type}.pth",
|
| 31 |
+
device="cpu",
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
with self.assertRaises(Exception):
|
| 35 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
| 36 |
+
sam_model_name=sam_model_type,
|
| 37 |
+
download=False,
|
| 38 |
+
model_path=".",
|
| 39 |
+
device="cpu",
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def test_generate_automatic_mask(self):
|
| 43 |
+
image = cv2.imread("test/assets/liberty.jpg")
|
| 44 |
+
|
| 45 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
| 46 |
+
sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
masks, segmentation_masks, bboxes = sam_model.generate_automatic_masks(image)
|
| 50 |
+
|
| 51 |
+
print(type(masks[0]))
|
| 52 |
+
print(type(segmentation_masks[0]))
|
| 53 |
+
print(type(bboxes[0]))
|
| 54 |
+
|
| 55 |
+
self.assertEqual(len(masks), len(segmentation_masks))
|
| 56 |
+
self.assertEqual(len(masks), len(bboxes))
|
| 57 |
+
|
| 58 |
+
# for mask, segmentation_mask, bbox in zip(masks, segmentation_masks, bboxes):
|
| 59 |
+
self.assertEqual(segmentation_masks[0].shape, image.shape)
|