""" Minimal segmentation manager. """ import numpy as np import cv2 import torch from PIL import Image from torchvision import transforms from transformers import AutoModelForImageSegmentation from typing import Optional import logging logger = logging.getLogger(__name__) class SegmentationManager: """Minimal BRIA segmentation.""" def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto", threshold: float = 0.5, trust_remote_code: bool = True, cache_dir: Optional[str] = None, local_files_only: bool = False): """Initialize segmentation.""" self.model_name = model_name self.threshold = threshold self.device = "cuda" if device == "auto" and torch.cuda.is_available() else device # Get HF token from environment (set as Space secret) import os hf_token = os.environ.get("HF_TOKEN") # Set cache directory to /tmp to avoid persistent storage issues if cache_dir is None: cache_dir = "/tmp/huggingface_cache" logger.info(f"Loading BRIA model: {model_name} (cache: {cache_dir})") self.model = AutoModelForImageSegmentation.from_pretrained( model_name, trust_remote_code=trust_remote_code, cache_dir=cache_dir, local_files_only=local_files_only, token=hf_token, low_cpu_mem_usage=True, # Reduce memory usage during loading ).eval().to(self.device) # Use 384x384 for even faster speed (6x improvement over 1024x1024) self.transform = transforms.Compose([ transforms.Resize((384, 384)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) logger.info(f"BRIA model loaded on device: {self.device}") def segment_image_soft(self, image: np.ndarray) -> np.ndarray: """Segment image and return soft mask [0,1].""" try: logger.info(f"Segmentation: input image shape={image.shape}, dtype={image.dtype}") rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(rgb_image) input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device) try: logger.info(f"Segmentation: tensor shape={input_tensor.shape}, device={self.device}") except Exception: pass with torch.no_grad(): preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy() logger.info(f"Segmentation: raw preds shape={preds.shape}, dtype={preds.dtype}") original_size = (image.shape[1], image.shape[0]) soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR) logger.info(f"Segmentation: resized soft_mask shape={soft_mask.shape}, dtype={soft_mask.dtype}") return np.clip(soft_mask, 0.0, 1.0) except Exception as e: logger.error(f"Segmentation failed: {e}") return np.zeros(image.shape[:2], dtype=np.float32)