Zhang-Yang-Sustech
Add multi-point input, foreground/background point input, and box input to the EfficientSAM model (#291)
2a88bde
import argparse

import numpy as np
import cv2 as cv

from efficientSAM import EfficientSAM
# Check OpenCV version
opencv_python_version = lambda str_version: tuple(map(int, (str_version.split("."))))
assert opencv_python_version(cv.__version__) >= opencv_python_version("4.10.0"), \
       "Please install latest opencv-python to run this demo: python3 -m pip install --upgrade opencv-python"
# Valid combinations of backends and targets
backend_target_pairs = [
    [cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_TARGET_CPU],
    [cv.dnn.DNN_BACKEND_CUDA, cv.dnn.DNN_TARGET_CUDA],
    [cv.dnn.DNN_BACKEND_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16],
    [cv.dnn.DNN_BACKEND_TIMVX, cv.dnn.DNN_TARGET_NPU],
    [cv.dnn.DNN_BACKEND_CANN, cv.dnn.DNN_TARGET_NPU]
]
parser = argparse.ArgumentParser(description='EfficientSAM Demo')
parser.add_argument('--input', '-i', type=str,
                    help='Set input path to a certain image.')
parser.add_argument('--model', '-m', type=str, default='image_segmentation_efficientsam_ti_2025april.onnx',
                    help='Set model path, defaults to image_segmentation_efficientsam_ti_2025april.onnx.')
parser.add_argument('--backend_target', '-bt', type=int, default=0,
                    help='''Choose one of the backend-target pairs to run this demo:
                        {:d}: (default) OpenCV implementation + CPU,
                        {:d}: CUDA + GPU (CUDA),
                        {:d}: CUDA + GPU (CUDA FP16),
                        {:d}: TIM-VX + NPU,
                        {:d}: CANN + NPU
                    '''.format(*range(len(backend_target_pairs))))
parser.add_argument('--save', '-s', action='store_true',
                    help='Specify to save the visualized result and mask to ./example_outputs/.')
args = parser.parse_args()
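# Usage sketch (the script filename below is hypothetical; -bt indices follow
# backend_target_pairs above):
#   python3 demo_efficientsam.py -i /path/to/image.jpg           # OpenCV backend + CPU
#   python3 demo_efficientsam.py -i /path/to/image.jpg -bt 1 -s  # CUDA backend + GPU, save results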
# Global configuration
WINDOW_SIZE = (800, 600)  # Maximum window size (width, height)
MAX_POINTS = 6            # Maximum number of prompt points
points = []               # Clicked coordinates (original image scale)
labels = []               # Point labels (-1: background, 1: foreground, 2: box top-left, 3: box bottom-right)
backend_point = []        # Mouse-down position, kept until the button is released
rectangle = False         # True while a box is being dragged
current_img = None        # Image currently displayed, with all prompts drawn
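# Prompt encoding example (values are illustrative): a foreground click at
# (120, 80), a background long press at (40, 200), and a box dragged from
# (60, 60) to (300, 260) would accumulate into:
#   points = [[120, 80], [40, 200], [60, 60], [300, 260]]
#   labels = [1, -1, 2, 3]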
def visualize(image, result):
    """
    Visualize the inference result on the input image.

    Args:
        image (np.ndarray): The input image.
        result (np.ndarray): The inference result (mask).

    Returns:
        vis_result (np.ndarray): The visualized result.
    """
    # get image and mask
    vis_result = np.copy(image)
    mask = np.copy(result)

    # binarize the mask
    _, binary = cv.threshold(mask, 127, 255, cv.THRESH_BINARY)
    assert set(np.unique(binary)) <= {0, 255}, "The mask must be a binary image."

    # enhance the red channel to make the segmented region more obvious
    enhancement_factor = 1.8
    red_channel = vis_result[:, :, 2]
    red_channel = np.where(binary == 255, np.minimum(red_channel * enhancement_factor, 255), red_channel)
    vis_result[:, :, 2] = red_channel

    # draw white borders around the mask
    contours, _ = cv.findContours(binary, cv.RETR_LIST, cv.CHAIN_APPROX_TC89_L1)
    cv.drawContours(vis_result, contours, contourIdx=-1, color=(255, 255, 255), thickness=2)
    return vis_result
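# Usage example (hypothetical names): given a BGR image `img` and a
# single-channel uint8 mask `m` of the same height/width from the model,
#   vis = visualize(img, m)
# returns `img` with the masked region tinted red and outlined in white.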
def select(event, x, y, flags, param):
    """Mouse callback: translate clicks, long presses and drags into prompts."""
    global points, labels, backend_point, rectangle, current_img
    orig_img = param['original_img']
    image_window = param['image_window']

    if event == cv.EVENT_LBUTTONDOWN:
        param['mouse_down_time'] = cv.getTickCount()
        backend_point = [x, y]
    elif event == cv.EVENT_MOUSEMOVE:
        if rectangle:
            # preview the box while dragging
            rectangle_change_img = current_img.copy()
            cv.rectangle(rectangle_change_img, (backend_point[0], backend_point[1]), (x, y), (255, 0, 0), 2)
            cv.imshow(image_window, rectangle_change_img)
        elif len(backend_point) != 0 and len(points) < MAX_POINTS:
            rectangle = True
    elif event == cv.EVENT_LBUTTONUP:
        if len(points) >= MAX_POINTS:
            print(f"Maximum of {MAX_POINTS} points reached.")
            return
        if not rectangle:
            # click duration decides the label: long press = background point
            duration = (cv.getTickCount() - param['mouse_down_time']) / cv.getTickFrequency()
            label = -1 if duration > 0.5 else 1
            points.append([backend_point[0], backend_point[1]])
            labels.append(label)
            print(f"Added {'background' if label == -1 else 'foreground'} point {backend_point}.")
        else:
            if len(points) + 1 >= MAX_POINTS:  # a box adds two points
                rectangle = False
                backend_point.clear()
                cv.imshow(image_window, current_img)
                print(f"Adding a box would exceed {MAX_POINTS} points; box not added.")
                return
            # normalize the drag so the box is stored as top-left, bottom-right
            point_leftup = [min(backend_point[0], x), min(backend_point[1], y)]
            point_rightdown = [max(backend_point[0], x), max(backend_point[1], y)]
            points.append(point_leftup)
            points.append(point_rightdown)
            print(f"Added box from {point_leftup} to {point_rightdown}.")
            labels.append(2)
            labels.append(3)
            rectangle = False
        backend_point.clear()

        # redraw all prompts on a fresh copy of the original image
        marked_img = orig_img.copy()
        top_left = None
        for (px, py), lbl in zip(points, labels):
            if lbl == -1:
                cv.circle(marked_img, (px, py), 5, (0, 0, 255), -1)  # background: red
            elif lbl == 1:
                cv.circle(marked_img, (px, py), 5, (0, 255, 0), -1)  # foreground: green
            elif lbl == 2:
                top_left = (px, py)
            elif lbl == 3:  # labels 2 and 3 are always appended as a pair
                bottom_right = (px, py)
                cv.rectangle(marked_img, top_left, bottom_right, (255, 0, 0), 2)
        cv.imshow(image_window, marked_img)
        current_img = marked_img.copy()
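# Corner normalization example: dragging from (300, 60) to (100, 260) stores
# point_leftup = [100, 60] and point_rightdown = [300, 260], so labels 2/3
# always describe a valid top-left/bottom-right pair regardless of drag direction.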
if __name__ == '__main__':
    backend_id = backend_target_pairs[args.backend_target][0]
    target_id = backend_target_pairs[args.backend_target][1]

    # Load the EfficientSAM model; the backendId/targetId keywords assume the
    # wrapper follows the usual opencv_zoo model interface
    model = EfficientSAM(modelPath=args.model, backendId=backend_id, targetId=target_id)
    if args.input is not None:
        # read the image
        image = cv.imread(args.input)
        if image is None:
            print('Could not open or find the image:', args.input)
            exit(1)

        # create the image window
        image_window = "Origin image"
        cv.namedWindow(image_window, cv.WINDOW_NORMAL)

        # shrink the window to fit WINDOW_SIZE while keeping the aspect ratio
        rate = 1
        rate1 = 1
        rate2 = 1
        if image.shape[1] > WINDOW_SIZE[0]:
            rate1 = WINDOW_SIZE[0] / image.shape[1]
        if image.shape[0] > WINDOW_SIZE[1]:
            rate2 = WINDOW_SIZE[1] / image.shape[0]
        rate = min(rate1, rate2)
        WINDOW_SIZE = (int(image.shape[1] * rate), int(image.shape[0] * rate))  # width, height
        cv.resizeWindow(image_window, WINDOW_SIZE[0], WINDOW_SIZE[1])
        # put the window on the left of the screen
        cv.moveWindow(image_window, 50, 100)

        # set the mouse listener that records the user's prompts
        param = {
            'original_img': image,
            'mouse_down_time': 0,
            'image_window': image_window
        }
        cv.setMouseCallback(image_window, select, param)
        # usage tips in the terminal
        print("Click - Select foreground point\n"
              "Long press - Select background point\n"
              "Drag - Create selection box\n"
              "Enter - Run inference\n"
              "Backspace - Clear the prompts\n"
              "Q - Quit")
        # show the image
        cv.imshow(image_window, image)
        current_img = image.copy()

        # create a window to show the visualized result
        vis_image = image.copy()
        vis_result = None  # last visualized result, kept for saving
        result = None      # last inferred mask, kept for saving
        segmentation_window = "Segment result"
        cv.namedWindow(segmentation_window, cv.WINDOW_NORMAL)
        cv.resizeWindow(segmentation_window, WINDOW_SIZE[0], WINDOW_SIZE[1])
        cv.moveWindow(segmentation_window, WINDOW_SIZE[0] + 51, 100)
        cv.imshow(segmentation_window, vis_image)
        # wait for prompts
        while True:
            # stop when either window is closed with the x button
            if (cv.getWindowProperty(image_window, cv.WND_PROP_VISIBLE) < 1 or
                cv.getWindowProperty(segmentation_window, cv.WND_PROP_VISIBLE) < 1):
                break

            # handle keyboard input
            key = cv.waitKey(1)
            if key == 13:  # Enter: run inference
                vis_image = image.copy()
                cv.putText(vis_image, "inferring...",
                           (50, vis_image.shape[0] // 2),
                           cv.FONT_HERSHEY_SIMPLEX, 10, (255, 255, 255), 5)
                cv.imshow(segmentation_window, vis_image)
                result = model.infer(image=image, points=points, labels=labels)
                if len(result) == 0:
                    print("No mask returned; clear the prompts and select points again!")
                else:
                    vis_result = visualize(image, result)
                    cv.imshow(segmentation_window, vis_result)
            elif key == 8 or key == 127:  # Backspace or Delete: clear prompts
                points.clear()
                labels.clear()
                backend_point = []
                rectangle = False
                current_img = image
                print("Points are cleared.")
                cv.imshow(image_window, image)
            elif key == ord('q') or key == ord('Q'):
                break
        cv.destroyAllWindows()
        # save the results if requested
        if args.save and vis_result is not None:
            cv.imwrite('./example_outputs/vis_result.jpg', vis_result)
            cv.imwrite('./example_outputs/mask.jpg', result)
            print('vis_result.jpg and mask.jpg are saved to ./example_outputs/')
    else:
        print('Set the input path to an image with --input/-i.')
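# For reference, a minimal non-interactive sketch of the same prompt format
# (hypothetical paths; reuses the infer(image=..., points=..., labels=...)
# signature from the loop above):
#   model = EfficientSAM(modelPath='image_segmentation_efficientsam_ti_2025april.onnx')
#   img = cv.imread('example.jpg')
#   points = [[100, 80], [360, 420]]  # box corners: top-left, bottom-right
#   labels = [2, 3]
#   mask = model.infer(image=img, points=points, labels=labels)
#   cv.imwrite('mask.jpg', mask)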