# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import os
from typing import Literal, Union

import cv2
import numpy as np
import rembg
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from segment_anything import (
    SamAutomaticMaskGenerator,
    SamPredictor,
    sam_model_registry,
)
from transformers import pipeline

from embodied_gen.data.utils import resize_pil, trellis_preprocess
from embodied_gen.utils.process_media import filter_small_connected_components
from embodied_gen.validators.quality_checkers import ImageSegChecker

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

__all__ = [
    "SAMRemover",
    "SAMPredictor",
    "RembgRemover",
    "get_segmented_image_by_agent",
]


class SAMRemover(object):
    """Loads a SAM model and performs background removal on images.

    Attributes:
        checkpoint (str): Path to the model checkpoint.
        model_type (str): Type of the SAM model to load (default: "vit_h").
        area_ratio (float): Area ratio used to filter out small connected
            components from the mask.
    """
    def __init__(
        self,
        checkpoint: str = None,
        model_type: str = "vit_h",
        area_ratio: float = 15,
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_type = model_type
        self.area_ratio = area_ratio

        if checkpoint is None:
            suffix = "sam"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(
                model_path, suffix, "sam_vit_h_4b8939.pth"
            )

        self.mask_generator = self._load_sam_model(checkpoint)

    def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam.to(device=self.device)

        return SamAutomaticMaskGenerator(sam)
    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ) -> Image.Image:
        """Removes the background from an image using the SAM model.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image;
                can be a file path, PIL Image, or numpy array.
            save_path (str): Path to save the output image (default: None).

        Returns:
            Image.Image: The image with background removed,
                including an alpha channel.
        """
        # Convert input to a numpy array
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        image = resize_pil(image)
        image = np.array(image.convert("RGB"))

        # Generate masks
        masks = self.mask_generator.generate(image)
        masks = sorted(masks, key=lambda x: x["area"], reverse=True)
        if not masks:
            logger.warning(
                "Segmentation failed: no mask generated, returning raw image."
            )
            output_image = Image.fromarray(image, mode="RGB")
        else:
            # Use the largest mask
            best_mask = masks[0]["segmentation"]
            mask = (best_mask * 255).astype(np.uint8)
            mask = filter_small_connected_components(
                mask, area_ratio=self.area_ratio
            )
            # Apply the mask to remove the background
            background_removed = cv2.bitwise_and(image, image, mask=mask)
            output_image = np.dstack((background_removed, mask))
            output_image = Image.fromarray(output_image, mode="RGBA")

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image
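
# A minimal usage sketch for SAMRemover (the input/output paths below are
# illustrative assumptions, not part of this module):
#
#   remover = SAMRemover(model_type="vit_h")
#   rgba = remover("path/to/object.jpg", save_path="object_rgba.png")
#   # `rgba` is an RGBA PIL image whose alpha channel comes from the
#   # largest SAM mask, with small connected components filtered out.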


class SAMPredictor(object):
    """Interactive SAM segmentation driven by user-clicked point prompts.

    Attributes:
        checkpoint (str): Path to the model checkpoint.
        model_type (str): Type of the SAM model to load (default: "vit_h").
        binary_thresh (float): Threshold for binarizing the combined mask.
        device (str): Device to run the model on (default: "cuda").
    """

    def __init__(
        self,
        checkpoint: str = None,
        model_type: str = "vit_h",
        binary_thresh: float = 0.1,
        device: str = "cuda",
    ):
        self.device = device
        self.model_type = model_type

        if checkpoint is None:
            suffix = "sam"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(
                model_path, suffix, "sam_vit_h_4b8939.pth"
            )

        self.predictor = self._load_sam_model(checkpoint)
        self.binary_thresh = binary_thresh

    def _load_sam_model(self, checkpoint: str) -> SamPredictor:
        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam.to(device=self.device)

        return SamPredictor(sam)
    def preprocess_image(
        self, image: Union[str, Image.Image, np.ndarray]
    ) -> np.ndarray:
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        image = resize_pil(image)
        image = np.array(image.convert("RGB"))

        return image
    def generate_masks(
        self,
        image: np.ndarray,
        selected_points: list[tuple[list[int], int]],
    ) -> list[tuple[np.ndarray, str]]:
        if len(selected_points) == 0:
            return []

        # Each selected point is an ([x, y], label) pair, where label 1
        # marks foreground and label 0 marks background.
        points = (
            torch.Tensor([p for p, _ in selected_points])
            .to(self.predictor.device)
            .unsqueeze(1)
        )
        labels = (
            torch.Tensor([int(l) for _, l in selected_points])
            .to(self.predictor.device)
            .unsqueeze(1)
        )
        transformed_points = self.predictor.transform.apply_coords_torch(
            points, image.shape[:2]
        )

        masks, scores, _ = self.predictor.predict_torch(
            point_coords=transformed_points,
            point_labels=labels,
            multimask_output=True,
        )
        # Select the best-scoring candidates among the multimask outputs.
        valid_mask = masks[:, torch.argmax(scores, dim=1)]
        masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
        masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
        if len(masks_neg) == 0:
            masks_neg = np.zeros_like(masks_pos)
        if len(masks_pos) == 0:
            masks_pos = np.zeros_like(masks_neg)
        # Union the positive masks, subtract the negatives, then binarize.
        masks_neg = masks_neg.max(axis=0, keepdims=True)
        masks_pos = masks_pos.max(axis=0, keepdims=True)
        valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)
        binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)

        return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]
    def get_segmented_image(
        self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
    ) -> Image.Image:
        seg_image = Image.fromarray(image, mode="RGB")
        alpha_channel = np.zeros(
            (seg_image.height, seg_image.width), dtype=np.uint8
        )
        for mask, _ in masks:
            # Use the maximum to combine multiple masks
            alpha_channel = np.maximum(alpha_channel, mask)
        alpha_channel = np.clip(alpha_channel, 0, 1)
        alpha_channel = (alpha_channel * 255).astype(np.uint8)
        alpha_image = Image.fromarray(alpha_channel, mode="L")
        r, g, b = seg_image.split()
        seg_image = Image.merge("RGBA", (r, g, b, alpha_image))

        return seg_image
    def __call__(
        self,
        image: Union[str, Image.Image, np.ndarray],
        selected_points: list[tuple[list[int], int]],
    ) -> Image.Image:
        image = self.preprocess_image(image)
        self.predictor.set_image(image)
        masks = self.generate_masks(image, selected_points)

        return self.get_segmented_image(image, masks)
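
# A minimal usage sketch for SAMPredictor (the path and click coordinates
# below are illustrative assumptions):
#
#   predictor = SAMPredictor(model_type="vit_h")
#   # One foreground click (label 1) and one background click (label 0).
#   points = [([320, 240], 1), ([20, 20], 0)]
#   rgba = predictor("path/to/object.jpg", points)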


class RembgRemover(object):
    """Removes image backgrounds with a `rembg` u2net session."""

    def __init__(self):
        self.rembg_session = rembg.new_session("u2net")

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ) -> Image.Image:
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = resize_pil(image)
        output_image = rembg.remove(image, session=self.rembg_session)

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image


class BMGG14Remover(object):
    """Removes image backgrounds with the `briaai/RMBG-1.4` model."""

    def __init__(self) -> None:
        self.model = pipeline(
            "image-segmentation",
            model="briaai/RMBG-1.4",
            trust_remote_code=True,
        )

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ):
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = resize_pil(image)
        output_image = self.model(image)

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image
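
# A minimal usage sketch for the rembg- and RMBG-1.4-based removers
# (input/output paths are illustrative assumptions); both share the same
# call signature, so they are drop-in replacements for each other:
#
#   remover = RembgRemover()  # or BMGG14Remover()
#   rgba = remover("path/to/object.jpg", save_path="object_rgba.png")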


def invert_rgba_pil(
    image: Image.Image, mask: Image.Image, save_path: str = None
) -> Image.Image:
    """Composes an RGBA image whose alpha channel is the inverted mask."""
    mask = (255 - np.array(mask))[..., None]
    image_array = np.concatenate([np.array(image), mask], axis=-1)
    inverted_image = Image.fromarray(image_array, "RGBA")

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        inverted_image.save(save_path)

    return inverted_image


def get_segmented_image_by_agent(
    image: Image.Image,
    sam_remover: SAMRemover,
    rbg_remover: RembgRemover,
    seg_checker: ImageSegChecker = None,
    save_path: str = None,
    mode: Literal["loose", "strict"] = "loose",
) -> Image.Image:
    """Segments the image, trying SAM, the inverted SAM mask, then rembg."""

    def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
        if seg_checker is None:
            return True
        return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]

    out_sam = f"{save_path}_sam.png" if save_path else None
    out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
    out_rbg = f"{save_path}_rbg.png" if save_path else None

    seg_image = sam_remover(image, out_sam)
    seg_image = seg_image.convert("RGBA")
    _, _, _, alpha = seg_image.split()
    seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
    seg_image_rbg = rbg_remover(image, out_rbg)

    # Fall through the candidates until one passes the checker.
    final_image = None
    if _is_valid_seg(image, seg_image):
        final_image = seg_image
    elif _is_valid_seg(image, seg_image_inv):
        final_image = seg_image_inv
    elif _is_valid_seg(image, seg_image_rbg):
        logger.warning("Failed to segment by `SAM`, retrying with `rembg`.")
        final_image = seg_image_rbg
    else:
        if mode == "strict":
            raise RuntimeError(
                "Failed to segment by `SAM` or `rembg`, aborting."
            )
        logger.warning("Failed to segment by SAM or rembg, using raw image.")
        final_image = image.convert("RGBA")

    if save_path:
        final_image.save(save_path)

    final_image = trellis_preprocess(final_image)

    return final_image
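
# A minimal usage sketch for the agent-style fallback segmentation (paths
# are illustrative assumptions; passing `seg_checker=None` accepts the
# first candidate without quality checking):
#
#   sam_remover = SAMRemover(model_type="vit_h")
#   rbg_remover = RembgRemover()
#   rgba = get_segmented_image_by_agent(
#       Image.open("path/to/object.jpg"),
#       sam_remover,
#       rbg_remover,
#       seg_checker=None,
#       save_path="object_seg.png",
#   )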


if __name__ == "__main__":
    input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
    output_image = "sample_0_seg2.png"
    # input_image = "outputs/text2image/tmp/coffee_machine.jpeg"
    # output_image = "outputs/text2image/tmp/coffee_machine_seg.png"
    # input_image = "outputs/text2image/tmp/bucket.jpeg"
    # output_image = "outputs/text2image/tmp/bucket_seg.png"

    remover = SAMRemover(model_type="vit_h")
    remover = RembgRemover()  # NOTE: overrides the SAM remover above.
    clean_image = remover(input_image)
    clean_image.save(output_image)

    get_segmented_image_by_agent(
        Image.open(input_image), remover, remover, None, "./test_seg.png"
    )

    remover = BMGG14Remover()
    remover("embodied_gen/models/test_seg.jpg", "./seg.png")