"""Minimal segmentation manager."""

import logging
import os
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

logger = logging.getLogger(__name__)


class SegmentationManager:
    """Minimal BRIA segmentation.

    Wraps an ``AutoModelForImageSegmentation`` checkpoint and exposes
    :meth:`segment_image_soft`, which maps an OpenCV-style image to a
    per-pixel soft mask in ``[0, 1]`` at the input's original resolution.
    """

    # Model input resolution (height, width) fed to ``transforms.Resize``.
    _INPUT_SIZE = (1024, 1024)
    # Standard ImageNet channel mean / std used for input normalization.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(
        self,
        model_name: str = "briaai/RMBG-2.0",
        device: str = "auto",
        threshold: float = 0.5,
        trust_remote_code: bool = True,
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
    ):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            model_name: Model id or path passed to ``from_pretrained``.
            device: Torch device string, or ``"auto"`` to use ``"cuda"`` when
                available and ``"cpu"`` otherwise.
            threshold: Stored as ``self.threshold``; not used by
                :meth:`segment_image_soft` (presumably consumed elsewhere).
            trust_remote_code: Forwarded to ``from_pretrained``.
            cache_dir: Optional cache directory; falsy values (``None``,
                ``""``) use the library default.
            local_files_only: Forwarded to ``from_pretrained``.
        """
        self.model_name = model_name
        self.threshold = threshold

        # Resolve "auto" to a real device string. Previously "auto" leaked
        # through on CPU-only hosts and made ``.to(self.device)`` raise.
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # HF token from the environment (e.g. set as a Space secret); None if unset.
        hf_token = os.environ.get("HF_TOKEN")

        logger.info("Loading BRIA model: %s", model_name)
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir or None,
            local_files_only=local_files_only,
            token=hf_token,
        ).eval().to(self.device)

        self.transform = transforms.Compose([
            transforms.Resize(self._INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
        ])
        logger.info("BRIA model loaded")

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert an OpenCV grayscale, BGR or BGRA image to 3-channel RGB."""
        # Grayscale may arrive as (H, W) or (H, W, 1); both need GRAY2RGB.
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """Segment ``image`` and return a soft mask in ``[0, 1]``.

        Args:
            image: OpenCV-style array, i.e. one of:
                ``(H, W)`` or ``(H, W, 1)`` grayscale;
                ``(H, W, 3)`` BGR;
                ``(H, W, 4)`` BGRA.
                Presumably ``uint8``; other dtypes may be rejected by
                ``Image.fromarray`` — confirm with callers.

        Returns:
            ``float32`` array of shape ``image.shape[:2]``, clipped to
            ``[0, 1]``. Best-effort: on any failure the error is logged with
            its traceback and an all-zero mask of that shape is returned.
        """
        try:
            logger.info("Segmentation: input image shape=%s, dtype=%s", image.shape, image.dtype)
            rgb_image = self._to_rgb(image)
            pil_image = Image.fromarray(rgb_image)
            # unsqueeze(0) adds a batch dimension of size 1.
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            logger.info("Segmentation: tensor shape=%s, device=%s", input_tensor.shape, self.device)

            with torch.no_grad():
                # The last element of the model output is used; presumably the
                # final prediction map — confirm against the model card.
                # [0] drops the batch dim added above; squeeze(0) drops a
                # leading singleton (presumably channel) dim. .float() makes
                # reduced-precision outputs convertible to NumPy.
                preds = self.model(input_tensor)[-1].sigmoid().float().cpu()[0].squeeze(0).numpy()
            logger.info("Segmentation: raw preds shape=%s, dtype=%s", preds.shape, preds.dtype)

            # cv2.resize takes dsize as (width, height).
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(
                preds.astype(np.float32, copy=False),
                original_size,
                interpolation=cv2.INTER_LINEAR,
            )
            logger.info("Segmentation: resized soft_mask shape=%s, dtype=%s", soft_mask.shape, soft_mask.dtype)
            # Bilinear resize can overshoot slightly; keep the [0, 1] contract.
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            # Deliberate best-effort fallback, but keep the traceback for debugging.
            logger.exception("Segmentation failed: %s", e)
            return np.zeros(image.shape[:2], dtype=np.float32)